// @(#)root/tmva $Id$
// Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class  : DataSet                                                               *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description:                                                                   *
 *      Contains all the data information                                         *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
 *      Joerg Stelzer   <Joerg.Stelzer@cern.ch>  - CERN, Switzerland              *
 *      Peter Speckmayer <Peter.Speckmayer@cern.ch>  - CERN, Switzerland          *
 *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
 *                                                                                *
 * Copyright (c) 2006:                                                            *
 *      CERN, Switzerland                                                         *
 *      U. of Victoria, Canada                                                    *
 *      MPI-K Heidelberg, Germany                                                 *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (http://tmva.sourceforge.net/LICENSE)                                          *
 **********************************************************************************/

#ifndef ROOT_TMVA_DataSet
#define ROOT_TMVA_DataSet

//////////////////////////////////////////////////////////////////////////
//                                                                      //
// DataSet                                                              //
//                                                                      //
// Class that contains all the data information                         //
//                                                                      //
//////////////////////////////////////////////////////////////////////////

#include <vector>
#include <map>
#include <string>

#ifndef ROOT_TObject
#include "TObject.h"
#endif
#ifndef ROOT_TNamed
#include "TNamed.h"
#endif
#ifndef ROOT_TString
#include "TString.h"
#endif
#ifndef ROOT_TTree
#include "TTree.h"
#endif
//#ifndef ROOT_TCut
//#include "TCut.h"
//#endif
//#ifndef ROOT_TMatrixDfwd
//#include "TMatrixDfwd.h"
//#endif
//#ifndef ROOT_TPrincipal
//#include "TPrincipal.h"
//#endif
#ifndef ROOT_TRandom3
#include "TRandom3.h"
#endif

#ifndef ROOT_TMVA_Types
#include "TMVA/Types.h"
#endif
#ifndef ROOT_TMVA_VariableInfo
#include "TMVA/VariableInfo.h"
#endif

namespace TMVA {

   class Event;
   class DataSetInfo;
   class MsgLogger;
   class Results;

   // DataSet
   // -------
   // Holds the events of one data set, split into separate samples ("trees"):
   // training, testing, validation and the original training sample.
   // TreeIndex() maps each Types::ETreeType onto the slot used by the
   // per-sample containers (fEventCollection, fSampling, fSamplingSelected, ...).
   // Besides the event vectors it keeps sampling state, per-class event
   // counters and Results objects stored per sample and per method identifier.
   //
   // A "current sample" cursor (fCurrentTreeIdx) and a "current event" cursor
   // (fCurrentEventIdx) live in mutable members. Several const accessors move
   // them: GetEvent(ievt), GetEvent(ievt,type), SetCurrentEvent() and SetCurrentType().
   // NOTE(review): because const calls modify this shared cursor without any
   // locking, a single DataSet is presumably not safe to use from several
   // threads at once -- confirm before doing so.
   class DataSet :public TNamed {

   public:
      DataSet();
      // builds the data set for the given DataSetInfo
      // (the fdsi member is documented as "datasetinfo that created this dataset")
      DataSet(const DataSetInfo&);
      virtual ~DataSet();

      // adds an event to the sample of the given type
      // NOTE(review): ownership of the Event* is not visible here -- check DataSet.cxx
      void      AddEvent( Event *, Types::ETreeType );

      // Number of events in a sample. The default kMaxTreeType means "the
      // currently selected sample". If sampling is enabled for that sample,
      // only the currently selected (sampled) events are counted.
      Long64_t  GetNEvents( Types::ETreeType type = Types::kMaxTreeType ) const;
      Long64_t  GetNTrainingEvents()              const { return GetNEvents(Types::kTraining); }
      Long64_t  GetNTestEvents()                  const { return GetNEvents(Types::kTesting); }

      // const getters
      // NOTE: the overloads taking an index first update the mutable cursors
      // (fCurrentEventIdx, plus fCurrentTreeIdx when a type is given) and then
      // delegate to GetEvent(). Later GetEvent() calls see the new position.
      const Event*    GetEvent()                        const; // returns event without transformations
      const Event*    GetEvent        ( Long64_t ievt ) const { fCurrentEventIdx = ievt; return GetEvent(); } // returns event without transformations
      const Event*    GetTrainingEvent( Long64_t ievt ) const { return GetEvent(ievt, Types::kTraining); }
      const Event*    GetTestEvent    ( Long64_t ievt ) const { return GetEvent(ievt, Types::kTesting); }
      const Event*    GetEvent        ( Long64_t ievt, Types::ETreeType type ) const 
      {
         fCurrentTreeIdx = TreeIndex(type); fCurrentEventIdx = ievt; return GetEvent();
      }

      UInt_t    GetNVariables()   const;
      UInt_t    GetNTargets()     const;
      UInt_t    GetNSpectators()  const;

      // move the cursors; these are const because the cursor members are mutable
      void      SetCurrentEvent( Long64_t ievt         ) const { fCurrentEventIdx = ievt; }
      void      SetCurrentType ( Types::ETreeType type ) const { fCurrentTreeIdx = TreeIndex(type); }
      // tree type of the currently selected sample
      // (kMaxTreeType if the cursor is not one of the four known slots)
      Types::ETreeType GetCurrentType() const;

      // replaces the event collection of a sample
      // NOTE(review): deleteEvents presumably controls whether the events of the
      // previous collection are deleted (cf. DestroyCollection) -- confirm in DataSet.cxx
      void                       SetEventCollection( std::vector<Event*>*, Types::ETreeType, Bool_t deleteEvents = true );
      // read-only event vector of a sample (default: the current sample)
      const std::vector<Event*>& GetEventCollection( Types::ETreeType type = Types::kMaxTreeType ) const;
      const TTree*               GetEventCollectionAsTree();

      Long64_t  GetNEvtSigTest();
      Long64_t  GetNEvtBkgdTest();
      Long64_t  GetNEvtSigTrain();
      Long64_t  GetNEvtBkgdTrain();

      Bool_t    HasNegativeEventWeights() const { return fHasNegativeEventWeights; }

      // access to the Results objects, which are kept per sample and per
      // method identifier (see fResults)
      Results*  GetResults   ( const TString &,
                               Types::ETreeType type,
                               Types::EAnalysisType analysistype );
      void      DeleteResults   ( const TString &,
                                  Types::ETreeType type,
                                  Types::EAnalysisType analysistype );

      // no-op: the argument is ignored
      void      SetVerbose( Bool_t ) {}

      // Sets the number of blocks into which the training set is divided.
      // Some blocks may be given to the validation sample; by default all of
      // them belong to the training set.
      void      DivideTrainingSet( UInt_t blockNum );

      // assigns a certain block of the original training set to either the training or the validation set
      void      MoveTrainingBlock( Int_t blockInd,Types::ETreeType dest, Bool_t applyChanges = kTRUE );

      // per-class event counters (stored in fClassEvents)
      // NOTE(review): 'type' presumably is the sample slot as returned by TreeIndex() -- confirm
      void      IncrementNClassEvents( Int_t type, UInt_t classNumber );
      Long64_t  GetNClassEvents      ( Int_t type, UInt_t classNumber );
      void      ClearNClassEvents    ( Int_t type );

      TTree*    GetTree( Types::ETreeType type );

      // accessors for random and importance sampling
      void      InitSampling( Float_t fraction, Float_t weight, UInt_t seed = 0 );
      void      EventResult( Bool_t successful, Long64_t evtNumber = -1 );
      void      CreateSampling() const;

      // Maps a tree type onto the index of the per-sample containers:
      // kTraining=0, kTesting=1, kValidation=2, kTrainingOriginal=3.
      // kMaxTreeType (and any other value) yields the current index.
      UInt_t    TreeIndex(Types::ETreeType type) const;

   private:

      // releases the event collection of one sample
      // NOTE(review): deleteEvents presumably also deletes the Event objects -- confirm in DataSet.cxx
      void DestroyCollection( Types::ETreeType type, Bool_t deleteEvents );

      // data members
      // NOTE(review): ROOT's dictionary generator reads the trailing comments.
      // "//!" marks a transient member and "//->" a non-null pointer.
      // Keep these markers intact when editing.
      const DataSetInfo         *fdsi;                       //-> datasetinfo that created this dataset

      std::vector< std::vector<Event*>  > fEventCollection;  // list of events for training/testing/...

      std::vector< std::map< TString, Results* > > fResults; //!  [train/test/...][method-identifier]

      // cursors used by the const accessors (mutable so const getters can move them)
      mutable UInt_t             fCurrentTreeIdx;
      mutable Long64_t           fCurrentEventIdx;

      // event sampling
      std::vector<Char_t>        fSampling;                   // random or importance sampling (not all events are taken) !! Bool_t are stored ( no std::vector<bool> taken for speed (performance) issues )
      std::vector<Int_t>         fSamplingNEvents;            // number of events which should be sampled
      std::vector<Float_t>       fSamplingWeight;             // weight change factor [weight is indicating if sampling is random (1.0) or importance (<1.0)] 
      mutable std::vector< std::vector< std::pair< Float_t, Long64_t > > > fSamplingEventList;  // weights and indices for sampling
      mutable std::vector< std::vector< std::pair< Float_t, Long64_t > > > fSamplingSelected;   // selected events
      TRandom3                   *fSamplingRandom;             //-> random generator for sampling


      // further things
      std::vector< std::vector<Long64_t> > fClassEvents;       // number of events of class 0,1,2,... in training[0] 
                                                               // and testing[1] (+validation, trainingoriginal)

      Bool_t                     fHasNegativeEventWeights;     // true if at least one signal or bkg event has negative weight

      mutable MsgLogger*         fLogger;                      //! message logger
      // access to the message logger (dereferences fLogger; assumes it has been set)
      MsgLogger& Log() const { return *fLogger; }
      std::vector<Char_t>        fBlockBelongToTraining;       // when dividing the dataset to blocks, sets whether 
                                                               // the certain block is in the Training set or else 
                                                               // in the validation set 
                                                               // booleans are stored as Char_t (std::vector<Char_t> instead of std::vector<Bool_t> for performance reasons)
      Long64_t                   fTrainingBlockSize;           // block size into which the training dataset is divided

      void  ApplyTrainingBlockDivision();
      void  ApplyTrainingSetDivision();
   public:
      // ROOT dictionary / streamer support, class version 1
       ClassDef(DataSet,1);
   };
}


//_______________________________________________________________________
// Translate a public tree type into the slot index used by the per-sample
// containers (fEventCollection, fSampling, fSamplingSelected):
//   kTraining -> 0, kTesting -> 1, kValidation -> 2, kTrainingOriginal -> 3.
// kMaxTreeType, and any value not listed above, means "stay on the currently
// selected sample" and yields fCurrentTreeIdx.
inline UInt_t TMVA::DataSet::TreeIndex(Types::ETreeType type) const
{
   if (type == Types::kTraining)         return 0;
   if (type == Types::kTesting)          return 1;
   if (type == Types::kValidation)       return 2;
   if (type == Types::kTrainingOriginal) return 3;

   // kMaxTreeType or an unrecognised type: keep the current slot
   return fCurrentTreeIdx;
}

//_______________________________________________________________________
// Inverse of TreeIndex() for the four concrete slots: returns the tree type
// whose slot fCurrentTreeIdx currently points to, or kMaxTreeType when the
// cursor holds any other value.
inline TMVA::Types::ETreeType TMVA::DataSet::GetCurrentType() const
{
   Types::ETreeType current = Types::kMaxTreeType;   // fallback for unknown slots

   if      (fCurrentTreeIdx == 0) current = Types::kTraining;
   else if (fCurrentTreeIdx == 1) current = Types::kTesting;
   else if (fCurrentTreeIdx == 2) current = Types::kValidation;
   else if (fCurrentTreeIdx == 3) current = Types::kTrainingOriginal;

   return current;
}

//_______________________________________________________________________
// Number of events in the sample selected by 'type'. The default
// kMaxTreeType means the currently selected sample, as resolved by TreeIndex().
// If sampling is switched on for that sample (fSampling[idx] non-zero), only
// the events currently picked by the sampling (fSamplingSelected[idx]) are
// counted. Otherwise the size of the full event collection is returned.
inline Long64_t TMVA::DataSet::GetNEvents(Types::ETreeType type) const
{
   // TreeIndex() returns UInt_t; keep the slot unsigned so it matches the
   // unsigned container sizes it is compared with and indexes into
   const UInt_t treeIdx = TreeIndex(type);

   // fSampling may be shorter than the number of samples (or empty) when
   // sampling was never initialised, hence the bounds check first
   if (treeIdx < fSampling.size() && fSampling.at(treeIdx)) {
      return static_cast<Long64_t>(fSamplingSelected.at(treeIdx).size());
   }
   return static_cast<Long64_t>(GetEventCollection(type).size());
}

//_______________________________________________________________________
// Read-only access to the event vector of the sample selected by 'type'
// (default kMaxTreeType: the currently selected sample, see TreeIndex()).
// The lookup uses std::vector::at, so a slot with no collection throws
// std::out_of_range instead of reading out of bounds.
inline const std::vector<TMVA::Event*>& TMVA::DataSet::GetEventCollection( TMVA::Types::ETreeType type ) const
{
   const UInt_t slot = TreeIndex(type);
   const std::vector<Event*>& collection = fEventCollection.at(slot);
   return collection;
}


#endif
 DataSet.h:1
 DataSet.h:2
 DataSet.h:3
 DataSet.h:4
 DataSet.h:5
 DataSet.h:6
 DataSet.h:7
 DataSet.h:8
 DataSet.h:9
 DataSet.h:10
 DataSet.h:11
 DataSet.h:12
 DataSet.h:13
 DataSet.h:14
 DataSet.h:15
 DataSet.h:16
 DataSet.h:17
 DataSet.h:18
 DataSet.h:19
 DataSet.h:20
 DataSet.h:21
 DataSet.h:22
 DataSet.h:23
 DataSet.h:24
 DataSet.h:25
 DataSet.h:26
 DataSet.h:27
 DataSet.h:28
 DataSet.h:29
 DataSet.h:30
 DataSet.h:31
 DataSet.h:32
 DataSet.h:33
 DataSet.h:34
 DataSet.h:35
 DataSet.h:36
 DataSet.h:37
 DataSet.h:38
 DataSet.h:39
 DataSet.h:40
 DataSet.h:41
 DataSet.h:42
 DataSet.h:43
 DataSet.h:44
 DataSet.h:45
 DataSet.h:46
 DataSet.h:47
 DataSet.h:48
 DataSet.h:49
 DataSet.h:50
 DataSet.h:51
 DataSet.h:52
 DataSet.h:53
 DataSet.h:54
 DataSet.h:55
 DataSet.h:56
 DataSet.h:57
 DataSet.h:58
 DataSet.h:59
 DataSet.h:60
 DataSet.h:61
 DataSet.h:62
 DataSet.h:63
 DataSet.h:64
 DataSet.h:65
 DataSet.h:66
 DataSet.h:67
 DataSet.h:68
 DataSet.h:69
 DataSet.h:70
 DataSet.h:71
 DataSet.h:72
 DataSet.h:73
 DataSet.h:74
 DataSet.h:75
 DataSet.h:76
 DataSet.h:77
 DataSet.h:78
 DataSet.h:79
 DataSet.h:80
 DataSet.h:81
 DataSet.h:82
 DataSet.h:83
 DataSet.h:84
 DataSet.h:85
 DataSet.h:86
 DataSet.h:87
 DataSet.h:88
 DataSet.h:89
 DataSet.h:90
 DataSet.h:91
 DataSet.h:92
 DataSet.h:93
 DataSet.h:94
 DataSet.h:95
 DataSet.h:96
 DataSet.h:97
 DataSet.h:98
 DataSet.h:99
 DataSet.h:100
 DataSet.h:101
 DataSet.h:102
 DataSet.h:103
 DataSet.h:104
 DataSet.h:105
 DataSet.h:106
 DataSet.h:107
 DataSet.h:108
 DataSet.h:109
 DataSet.h:110
 DataSet.h:111
 DataSet.h:112
 DataSet.h:113
 DataSet.h:114
 DataSet.h:115
 DataSet.h:116
 DataSet.h:117
 DataSet.h:118
 DataSet.h:119
 DataSet.h:120
 DataSet.h:121
 DataSet.h:122
 DataSet.h:123
 DataSet.h:124
 DataSet.h:125
 DataSet.h:126
 DataSet.h:127
 DataSet.h:128
 DataSet.h:129
 DataSet.h:130
 DataSet.h:131
 DataSet.h:132
 DataSet.h:133
 DataSet.h:134
 DataSet.h:135
 DataSet.h:136
 DataSet.h:137
 DataSet.h:138
 DataSet.h:139
 DataSet.h:140
 DataSet.h:141
 DataSet.h:142
 DataSet.h:143
 DataSet.h:144
 DataSet.h:145
 DataSet.h:146
 DataSet.h:147
 DataSet.h:148
 DataSet.h:149
 DataSet.h:150
 DataSet.h:151
 DataSet.h:152
 DataSet.h:153
 DataSet.h:154
 DataSet.h:155
 DataSet.h:156
 DataSet.h:157
 DataSet.h:158
 DataSet.h:159
 DataSet.h:160
 DataSet.h:161
 DataSet.h:162
 DataSet.h:163
 DataSet.h:164
 DataSet.h:165
 DataSet.h:166
 DataSet.h:167
 DataSet.h:168
 DataSet.h:169
 DataSet.h:170
 DataSet.h:171
 DataSet.h:172
 DataSet.h:173
 DataSet.h:174
 DataSet.h:175
 DataSet.h:176
 DataSet.h:177
 DataSet.h:178
 DataSet.h:179
 DataSet.h:180
 DataSet.h:181
 DataSet.h:182
 DataSet.h:183
 DataSet.h:184
 DataSet.h:185
 DataSet.h:186
 DataSet.h:187
 DataSet.h:188
 DataSet.h:189
 DataSet.h:190
 DataSet.h:191
 DataSet.h:192
 DataSet.h:193
 DataSet.h:194
 DataSet.h:195
 DataSet.h:196
 DataSet.h:197
 DataSet.h:198
 DataSet.h:199
 DataSet.h:200
 DataSet.h:201
 DataSet.h:202
 DataSet.h:203
 DataSet.h:204
 DataSet.h:205
 DataSet.h:206
 DataSet.h:207
 DataSet.h:208
 DataSet.h:209
 DataSet.h:210
 DataSet.h:211
 DataSet.h:212
 DataSet.h:213
 DataSet.h:214
 DataSet.h:215
 DataSet.h:216
 DataSet.h:217
 DataSet.h:218
 DataSet.h:219
 DataSet.h:220
 DataSet.h:221
 DataSet.h:222
 DataSet.h:223
 DataSet.h:224
 DataSet.h:225
 DataSet.h:226
 DataSet.h:227
 DataSet.h:228
 DataSet.h:229
 DataSet.h:230
 DataSet.h:231
 DataSet.h:232
 DataSet.h:233
 DataSet.h:234
 DataSet.h:235
 DataSet.h:236
 DataSet.h:237
 DataSet.h:238
 DataSet.h:239
 DataSet.h:240
 DataSet.h:241
 DataSet.h:242
 DataSet.h:243
 DataSet.h:244
 DataSet.h:245