#ifndef ROOT_TMVA_DataSetInfo
#define ROOT_TMVA_DataSetInfo
#include <iosfwd>
#ifndef ROOT_TObject
#include "TObject.h"
#endif
#ifndef ROOT_TString
#include "TString.h"
#endif
#ifndef ROOT_TTree
#include "TTree.h"
#endif
#ifndef ROOT_TCut
#include "TCut.h"
#endif
#ifndef ROOT_TMatrixDfwd
#include "TMatrixDfwd.h"
#endif
#ifndef ROOT_TMVA_Types
#include "TMVA/Types.h"
#endif
#ifndef ROOT_TMVA_VariableInfo
#include "TMVA/VariableInfo.h"
#endif
#ifndef ROOT_TMVA_ClassInfo
#include "TMVA/ClassInfo.h"
#endif
#ifndef ROOT_TMVA_Event
#include "TMVA/Event.h"
#endif
class TH2;
namespace TMVA {
class DataSet;
class VariableTransformBase;
class MsgLogger;
class DataSetManager;
class DataSetInfo : public TObject {
public:
DataSetInfo(const TString& name = "Default");
virtual ~DataSetInfo();
virtual const char* GetName() const { return fName.Data(); }
void ClearDataSet() const;
DataSet* GetDataSet() const;
VariableInfo& AddVariable( const TString& expression, const TString& title = "", const TString& unit = "",
Double_t min = 0, Double_t max = 0, char varType='F',
Bool_t normalized = kTRUE, void* external = 0 );
VariableInfo& AddVariable( const VariableInfo& varInfo );
VariableInfo& AddTarget ( const TString& expression, const TString& title, const TString& unit,
Double_t min, Double_t max, Bool_t normalized = kTRUE, void* external = 0 );
VariableInfo& AddTarget ( const VariableInfo& varInfo );
VariableInfo& AddSpectator ( const TString& expression, const TString& title, const TString& unit,
Double_t min, Double_t max, char type = 'F', Bool_t normalized = kTRUE, void* external = 0 );
VariableInfo& AddSpectator ( const VariableInfo& varInfo );
ClassInfo* AddClass ( const TString& className );
std::vector<VariableInfo>& GetVariableInfos() { return fVariables; }
const std::vector<VariableInfo>& GetVariableInfos() const { return fVariables; }
VariableInfo& GetVariableInfo( Int_t i ) { return fVariables.at(i); }
const VariableInfo& GetVariableInfo( Int_t i ) const { return fVariables.at(i); }
std::vector<VariableInfo>& GetTargetInfos() { return fTargets; }
const std::vector<VariableInfo>& GetTargetInfos() const { return fTargets; }
VariableInfo& GetTargetInfo( Int_t i ) { return fTargets.at(i); }
const VariableInfo& GetTargetInfo( Int_t i ) const { return fTargets.at(i); }
std::vector<VariableInfo>& GetSpectatorInfos() { return fSpectators; }
const std::vector<VariableInfo>& GetSpectatorInfos() const { return fSpectators; }
VariableInfo& GetSpectatorInfo( Int_t i ) { return fSpectators.at(i); }
const VariableInfo& GetSpectatorInfo( Int_t i ) const { return fSpectators.at(i); }
UInt_t GetNVariables() const { return fVariables.size(); }
UInt_t GetNTargets() const { return fTargets.size(); }
UInt_t GetNSpectators(bool all=kTRUE) const;
const TString& GetNormalization() const { return fNormalization; }
void SetNormalization( const TString& norm ) { fNormalization = norm; }
void SetTrainingSumSignalWeights(Double_t trainingSumSignalWeights){fTrainingSumSignalWeights = trainingSumSignalWeights;}
void SetTrainingSumBackgrWeights(Double_t trainingSumBackgrWeights){fTrainingSumBackgrWeights = trainingSumBackgrWeights;}
void SetTestingSumSignalWeights (Double_t testingSumSignalWeights ){fTestingSumSignalWeights = testingSumSignalWeights ;}
void SetTestingSumBackgrWeights (Double_t testingSumBackgrWeights ){fTestingSumBackgrWeights = testingSumBackgrWeights ;}
Double_t GetTrainingSumSignalWeights();
Double_t GetTrainingSumBackgrWeights();
Double_t GetTestingSumSignalWeights ();
Double_t GetTestingSumBackgrWeights ();
Int_t GetClassNameMaxLength() const;
Int_t GetVariableNameMaxLength() const;
Int_t GetTargetNameMaxLength() const;
ClassInfo* GetClassInfo( Int_t clNum ) const;
ClassInfo* GetClassInfo( const TString& name ) const;
void PrintClasses() const;
UInt_t GetNClasses() const { return fClasses.size(); }
Bool_t IsSignal( const Event* ev ) const;
std::vector<Float_t>* GetTargetsForMulticlass( const Event* ev );
UInt_t GetSignalClassIndex(){return fSignalClass;}
Int_t FindVarIndex( const TString& ) const;
const TString GetWeightExpression(Int_t i) const { return GetClassInfo(i)->GetWeight(); }
void SetWeightExpression( const TString& exp, const TString& className = "" );
const TCut& GetCut (Int_t i) const { return GetClassInfo(i)->GetCut(); }
const TCut& GetCut ( const TString& className ) const { return GetClassInfo(className)->GetCut(); }
void SetCut ( const TCut& cut, const TString& className );
void AddCut ( const TCut& cut, const TString& className );
Bool_t HasCuts() const;
std::vector<TString> GetListOfVariables() const;
const TMatrixD* CorrelationMatrix ( const TString& className ) const;
void SetCorrelationMatrix ( const TString& className, TMatrixD* matrix );
void PrintCorrelationMatrix( const TString& className );
TH2* CreateCorrelationMatrixHist( const TMatrixD* m,
const TString& hName,
const TString& hTitle ) const;
void SetSplitOptions(const TString& so) { fSplitOptions = so; fNeedsRebuilding = kTRUE; }
const TString& GetSplitOptions() const { return fSplitOptions; }
void SetRootDir(TDirectory* d) { fOwnRootDir = d; }
TDirectory* GetRootDir() const { return fOwnRootDir; }
void SetMsgType( EMsgType t ) const;
DataSetManager* GetDataSetManager(){return fDataSetManager;}
private:
TMVA::DataSetManager* fDataSetManager;
void SetDataSetManager( DataSetManager* dsm ) { fDataSetManager = dsm; }
friend class DataSetManager;
DataSetInfo( const DataSetInfo& ) : TObject() {}
void PrintCorrelationMatrix( TTree* theTree );
TString fName;
mutable DataSet* fDataSet;
mutable Bool_t fNeedsRebuilding;
std::vector<VariableInfo> fVariables;
std::vector<VariableInfo> fTargets;
std::vector<VariableInfo> fSpectators;
mutable std::vector<ClassInfo*> fClasses;
TString fNormalization;
TString fSplitOptions;
Double_t fTrainingSumSignalWeights;
Double_t fTrainingSumBackgrWeights;
Double_t fTestingSumSignalWeights ;
Double_t fTestingSumBackgrWeights ;
TDirectory* fOwnRootDir;
Bool_t fVerbose;
UInt_t fSignalClass;
std::vector<Float_t>* fTargetsForMulticlass;
mutable MsgLogger* fLogger;
MsgLogger& Log() const { return *fLogger; }
public:
ClassDef(DataSetInfo,1);
};
}
#endif