#ifndef ROOT_TMVA_DataSet
#define ROOT_TMVA_DataSet
#include <vector>
#include <map>
#include <string>
#ifndef ROOT_TObject
#include "TObject.h"
#endif
#ifndef ROOT_TNamed
#include "TNamed.h"
#endif
#ifndef ROOT_TString
#include "TString.h"
#endif
#ifndef ROOT_TTree
#include "TTree.h"
#endif
#ifndef ROOT_TRandom3
#include "TRandom3.h"
#endif
#ifndef ROOT_TMVA_Types
#include "TMVA/Types.h"
#endif
#ifndef ROOT_TMVA_VariableInfo
#include "TMVA/VariableInfo.h"
#endif
namespace TMVA {
class Event;
class DataSetInfo;
class MsgLogger;
class Results;
// DataSet: container for the events of one TMVA data set, kept separately per
// tree type (training, testing, validation, original training), together with
// Results objects and optional event-sampling / training-block state.
//
// Events live in fEventCollection, one std::vector<Event*> per tree slot; the
// slot for a Types::ETreeType is computed by TreeIndex() (kTraining=0,
// kTesting=1, kValidation=2, kTrainingOriginal=3). Passing Types::kMaxTreeType,
// the default of several accessors, selects the currently active slot
// (fCurrentTreeIdx).
//
// NOTE(review): GetEvent(ievt), GetEvent(ievt, type), SetCurrentEvent and
// SetCurrentType are const but write the `mutable` cursor members, so a const
// DataSet is not safe to use from several threads without external locking.
// NOTE(review): no copy constructor / copy assignment is declared, so the
// implicit ones copy the raw pointers fdsi, fSamplingRandom, fLogger and the
// Results* values in fResults. Whether ~DataSet deletes any of these is decided
// in the implementation file -- confirm before copying a DataSet.
class DataSet :public TNamed {
public:
   // Default constructor.
   DataSet();
   // Construct from a data-set description (presumably stored in fdsi -- confirm).
   DataSet(const DataSetInfo&);
   virtual ~DataSet();

   // Add an event to the collection of the given tree type.
   // NOTE(review): ownership of the Event* is defined in the .cxx -- confirm.
   void AddEvent( Event *, Types::ETreeType );

   // Number of events in the given tree (default: the current tree). When
   // sampling is active for that tree, the size of the sampled subset is returned.
   Long64_t GetNEvents( Types::ETreeType type = Types::kMaxTreeType ) const;
   Long64_t GetNTrainingEvents() const { return GetNEvents(Types::kTraining); }
   Long64_t GetNTestEvents() const { return GetNEvents(Types::kTesting); }

   // Event at the current cursor; the definition is not in this header
   // (presumably uses fCurrentTreeIdx / fCurrentEventIdx -- confirm).
   const Event* GetEvent() const;
   // Move the event cursor to ievt (current tree unchanged), then return GetEvent().
   const Event* GetEvent ( Long64_t ievt ) const { fCurrentEventIdx = ievt; return GetEvent(); }
   // Side effect: both of these also switch the current tree to training / testing.
   const Event* GetTrainingEvent( Long64_t ievt ) const { return GetEvent(ievt, Types::kTraining); }
   const Event* GetTestEvent ( Long64_t ievt ) const { return GetEvent(ievt, Types::kTesting); }
   // Select tree `type` and event ievt as the cursor, then return GetEvent().
   // The tree selection persists for later calls that use kMaxTreeType.
   const Event* GetEvent ( Long64_t ievt, Types::ETreeType type ) const
   {
      fCurrentTreeIdx = TreeIndex(type); fCurrentEventIdx = ievt; return GetEvent();
   }

   UInt_t GetNVariables() const;
   UInt_t GetNTargets() const;
   UInt_t GetNSpectators() const;

   // Set the event cursor without fetching an event.
   void SetCurrentEvent( Long64_t ievt ) const { fCurrentEventIdx = ievt; }
   // Set the current tree; kMaxTreeType (or a type TreeIndex does not map)
   // leaves the current tree unchanged.
   void SetCurrentType ( Types::ETreeType type ) const { fCurrentTreeIdx = TreeIndex(type); }
   // Tree type of the current slot, or kMaxTreeType for an unknown slot.
   Types::ETreeType GetCurrentType() const;

   // Replace the event collection of the given tree type; deleteEvents presumably
   // controls whether previously held events are deleted -- confirm in the .cxx.
   void SetEventCollection( std::vector<Event*>*, Types::ETreeType, Bool_t deleteEvents = true );
   // Event vector of the given tree (default: the current tree); accessed with
   // std::vector::at, so an unallocated slot throws std::out_of_range.
   const std::vector<Event*>& GetEventCollection( Types::ETreeType type = Types::kMaxTreeType ) const;
   const TTree* GetEventCollectionAsTree();
   Long64_t GetNEvtSigTest();
   Long64_t GetNEvtBkgdTest();
   Long64_t GetNEvtSigTrain();
   Long64_t GetNEvtBkgdTrain();
   // Returns fHasNegativeEventWeights (where it is set is not visible here).
   Bool_t HasNegativeEventWeights() const { return fHasNegativeEventWeights; }

   // Results lookup / removal keyed by a name (presumably the method name),
   // tree type and analysis type; storage is fResults.
   Results* GetResults ( const TString &,
                         Types::ETreeType type,
                         Types::EAnalysisType analysistype );
   void DeleteResults ( const TString &,
                        Types::ETreeType type,
                        Types::EAnalysisType analysistype );
   // No-op: the argument is ignored.
   void SetVerbose( Bool_t ) {}

   // Training-set block division (state in fBlockBelongToTraining /
   // fTrainingBlockSize); semantics are defined in the implementation file.
   void DivideTrainingSet( UInt_t blockNum );
   void MoveTrainingBlock( Int_t blockInd,Types::ETreeType dest, Bool_t applyChanges = kTRUE );

   // Per-class event counters stored in fClassEvents; `type` presumably
   // selects a tree slot -- confirm in the .cxx.
   void IncrementNClassEvents( Int_t type, UInt_t classNumber );
   Long64_t GetNClassEvents ( Int_t type, UInt_t classNumber );
   void ClearNClassEvents ( Int_t type );
   TTree* GetTree( Types::ETreeType type );

   // Event sampling. InitSampling configures it and EventResult reports a
   // per-event outcome. CreateSampling is const while the sampling lists are
   // mutable, so it presumably (re)fills them -- confirm. Semantics of
   // fraction / weight / seed are in the implementation file.
   void InitSampling( Float_t fraction, Float_t weight, UInt_t seed = 0 );
   void EventResult( Bool_t successful, Long64_t evtNumber = -1 );
   void CreateSampling() const;

   // Map a tree type to its slot: kTraining=0, kTesting=1, kValidation=2,
   // kTrainingOriginal=3; kMaxTreeType or any other value gives fCurrentTreeIdx.
   UInt_t TreeIndex(Types::ETreeType type) const;

private:
   void DestroyCollection( Types::ETreeType type, Bool_t deleteEvents );

   // Data-set description (ownership not visible here -- confirm in the .cxx).
   const DataSetInfo *fdsi;
   // One event vector per tree slot, indexed by TreeIndex().
   std::vector< std::vector<Event*> > fEventCollection;
   // Results keyed by name, one map per slot (presumably per tree -- confirm).
   std::vector< std::map< TString, Results* > > fResults;
   // Cursor; mutable so the const GetEvent / SetCurrent* accessors can move it.
   mutable UInt_t fCurrentTreeIdx;
   mutable Long64_t fCurrentEventIdx;
   // Sampling state, one entry per tree slot. A non-zero fSampling entry makes
   // GetNEvents report the size of the matching fSamplingSelected entry.
   std::vector<Char_t> fSampling;
   std::vector<Int_t> fSamplingNEvents;
   std::vector<Float_t> fSamplingWeight;
   mutable std::vector< std::vector< std::pair< Float_t, Long64_t > > > fSamplingEventList;
   mutable std::vector< std::vector< std::pair< Float_t, Long64_t > > > fSamplingSelected;
   TRandom3 *fSamplingRandom;
   // Per-class event counters (see IncrementNClassEvents / GetNClassEvents).
   std::vector< std::vector<Long64_t> > fClassEvents;
   Bool_t fHasNegativeEventWeights;
   // Logger accessed through Log(), which dereferences it without a null check.
   mutable MsgLogger* fLogger;
   MsgLogger& Log() const { return *fLogger; }
   // Training-block division state (see DivideTrainingSet / MoveTrainingBlock).
   std::vector<Char_t> fBlockBelongToTraining;
   Long64_t fTrainingBlockSize;
   void ApplyTrainingBlockDivision();
   void ApplyTrainingSetDivision();
public:
   // ROOT dictionary / I/O support, class version 1.
   ClassDef(DataSet,1);
};
}
// Translate a tree type into its slot in the per-tree containers:
//   kTraining -> 0, kTesting -> 1, kValidation -> 2, kTrainingOriginal -> 3.
// kMaxTreeType, and any value not listed above, means "the tree currently
// selected", so fCurrentTreeIdx is returned unchanged.
inline UInt_t TMVA::DataSet::TreeIndex(Types::ETreeType type) const
{
   if (type == Types::kTraining)         return 0;
   if (type == Types::kTesting)          return 1;
   if (type == Types::kValidation)       return 2;
   if (type == Types::kTrainingOriginal) return 3;
   // kMaxTreeType or unrecognised type: keep the current selection.
   return fCurrentTreeIdx;
}
// Inverse of TreeIndex(): the tree type whose slot is currently selected.
// Slots 0..3 map to kTraining, kTesting, kValidation, kTrainingOriginal;
// any other slot value yields kMaxTreeType.
inline TMVA::Types::ETreeType TMVA::DataSet::GetCurrentType() const
{
   // Indexed by slot number; must stay in sync with TreeIndex().
   static const Types::ETreeType slotToType[] = {
      Types::kTraining, Types::kTesting, Types::kValidation, Types::kTrainingOriginal
   };
   const UInt_t nSlots = sizeof(slotToType) / sizeof(slotToType[0]);
   return (fCurrentTreeIdx < nSlots) ? slotToType[fCurrentTreeIdx] : Types::kMaxTreeType;
}
// Number of events in the tree of the given type (kMaxTreeType: the current tree).
// When sampling is switched on for that tree (its fSampling entry is non-zero),
// the size of the sampled subset in fSamplingSelected is reported instead of
// the size of the full event collection.
inline Long64_t TMVA::DataSet::GetNEvents(Types::ETreeType type) const
{
   const UInt_t slot = TreeIndex(type);
   const bool samplingActive = (slot < fSampling.size()) && (fSampling.at(slot) != 0);
   if (!samplingActive) {
      return GetEventCollection(type).size();
   }
   return fSamplingSelected.at(slot).size();
}
// Event vector for the tree of the given type (kMaxTreeType: the current tree).
// Access goes through std::vector::at, so a slot that was never allocated in
// fEventCollection raises std::out_of_range instead of reading out of bounds.
inline const std::vector<TMVA::Event*>& TMVA::DataSet::GetEventCollection( TMVA::Types::ETreeType type ) const
{
   const UInt_t slot = TreeIndex(type);
   return fEventCollection.at(slot);
}
#endif