// @(#)root/tmva $Id$
// Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class  : DataSetInfo                                                           *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description:                                                                   *
 *      Contains all the data information                                         *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Peter Speckmayer <speckmay@mail.cern.ch> - CERN, Switzerland              *
 *      Joerg Stelzer   <Joerg.Stelzer@cern.ch>  - DESY, Germany                  *
 *                                                                                *
 * Copyright (c) 2008-2011:                                                       *
 *      CERN, Switzerland                                                         *
 *      MPI-K Heidelberg, Germany                                                 *
 *      DESY Hamburg, Germany                                                     *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (http://tmva.sourceforge.net/LICENSE)                                          *
 **********************************************************************************/

#ifndef ROOT_TMVA_DataSetInfo
#define ROOT_TMVA_DataSetInfo

//////////////////////////////////////////////////////////////////////////
//                                                                      //
// DataSetInfo                                                          //
//                                                                      //
// Class that contains all the data information                         //
//                                                                      //
//////////////////////////////////////////////////////////////////////////

#include <iosfwd>

#ifndef ROOT_TObject
#include "TObject.h"
#endif
#ifndef ROOT_TString
#include "TString.h"
#endif
#ifndef ROOT_TTree
#include "TTree.h"
#endif
#ifndef ROOT_TCut
#include "TCut.h"
#endif
#ifndef ROOT_TMatrixDfwd
#include "TMatrixDfwd.h"
#endif

#ifndef ROOT_TMVA_Types
#include "TMVA/Types.h"
#endif
#ifndef ROOT_TMVA_VariableInfo
#include "TMVA/VariableInfo.h"
#endif
#ifndef ROOT_TMVA_ClassInfo
#include "TMVA/ClassInfo.h"
#endif
#ifndef ROOT_TMVA_Event
#include "TMVA/Event.h"
#endif

class TH2;

namespace TMVA {

   class DataSet;
   class VariableTransformBase;
   class MsgLogger;
   class DataSetManager;

   class DataSetInfo : public TObject {

   public:

      DataSetInfo(const TString& name = "Default");
      virtual ~DataSetInfo();

      virtual const char* GetName() const { return fName.Data(); }

      // the data set
      void        ClearDataSet() const;
      DataSet*    GetDataSet() const;

      // ---
      // the variable data
      // ---
      VariableInfo&     AddVariable( const TString& expression, const TString& title = "", const TString& unit = "", 
                                     Double_t min = 0, Double_t max = 0, char varType='F', 
                                     Bool_t normalized = kTRUE, void* external = 0 );
      VariableInfo&     AddVariable( const VariableInfo& varInfo );

      VariableInfo&     AddTarget  ( const TString& expression, const TString& title, const TString& unit, 
                                     Double_t min, Double_t max, Bool_t normalized = kTRUE, void* external = 0 );
      VariableInfo&     AddTarget  ( const VariableInfo& varInfo );

      VariableInfo&     AddSpectator ( const TString& expression, const TString& title, const TString& unit, 
                                       Double_t min, Double_t max, char type = 'F', Bool_t normalized = kTRUE, void* external = 0 );
      VariableInfo&     AddSpectator ( const VariableInfo& varInfo );

      ClassInfo*        AddClass   ( const TString& className );

      // accessors

      // general
      std::vector<VariableInfo>&       GetVariableInfos()         { return fVariables; }
      const std::vector<VariableInfo>& GetVariableInfos() const   { return fVariables; }
      VariableInfo&                    GetVariableInfo( Int_t i ) { return fVariables.at(i); }
      const VariableInfo&              GetVariableInfo( Int_t i ) const { return fVariables.at(i); }

      std::vector<VariableInfo>&       GetTargetInfos()         { return fTargets; }
      const std::vector<VariableInfo>& GetTargetInfos() const   { return fTargets; }
      VariableInfo&                    GetTargetInfo( Int_t i ) { return fTargets.at(i); }
      const VariableInfo&              GetTargetInfo( Int_t i ) const { return fTargets.at(i); }

      std::vector<VariableInfo>&       GetSpectatorInfos()         { return fSpectators; }
      const std::vector<VariableInfo>& GetSpectatorInfos() const   { return fSpectators; }
      VariableInfo&                    GetSpectatorInfo( Int_t i ) { return fSpectators.at(i); }
      const VariableInfo&              GetSpectatorInfo( Int_t i ) const { return fSpectators.at(i); }


      UInt_t                           GetNVariables()    const { return fVariables.size(); }
      UInt_t                           GetNTargets()      const { return fTargets.size(); }
      UInt_t                           GetNSpectators(bool all=kTRUE)   const;

      const TString&                   GetNormalization() const { return fNormalization; }
      void                             SetNormalization( const TString& norm )   { fNormalization = norm; }

      void SetTrainingSumSignalWeights(Double_t trainingSumSignalWeights){fTrainingSumSignalWeights = trainingSumSignalWeights;}
      void SetTrainingSumBackgrWeights(Double_t trainingSumBackgrWeights){fTrainingSumBackgrWeights = trainingSumBackgrWeights;}
      void SetTestingSumSignalWeights (Double_t testingSumSignalWeights ){fTestingSumSignalWeights  = testingSumSignalWeights ;}
      void SetTestingSumBackgrWeights (Double_t testingSumBackgrWeights ){fTestingSumBackgrWeights  = testingSumBackgrWeights ;}

      Double_t GetTrainingSumSignalWeights();
      Double_t GetTrainingSumBackgrWeights();
      Double_t GetTestingSumSignalWeights ();
      Double_t GetTestingSumBackgrWeights ();



      // classification information
      Int_t              GetClassNameMaxLength() const;
      Int_t              GetVariableNameMaxLength() const;
      Int_t              GetTargetNameMaxLength() const;
      ClassInfo*         GetClassInfo( Int_t clNum ) const;
      ClassInfo*         GetClassInfo( const TString& name ) const;
      void               PrintClasses() const;
      UInt_t             GetNClasses() const { return fClasses.size(); }
      Bool_t             IsSignal( const Event* ev ) const;
      std::vector<Float_t>* GetTargetsForMulticlass( const Event* ev );
      UInt_t             GetSignalClassIndex(){return fSignalClass;}

      // by variable
      Int_t              FindVarIndex( const TString& )      const;

      // weights
      const TString      GetWeightExpression(Int_t i)      const { return GetClassInfo(i)->GetWeight(); }
      void               SetWeightExpression( const TString& exp, const TString& className = "" );

      // cuts
      const TCut&        GetCut (Int_t i)                         const { return GetClassInfo(i)->GetCut(); }
      const TCut&        GetCut ( const TString& className )      const { return GetClassInfo(className)->GetCut(); }
      void               SetCut ( const TCut& cut, const TString& className );
      void               AddCut ( const TCut& cut, const TString& className );
      Bool_t             HasCuts() const;

      std::vector<TString> GetListOfVariables() const;

      // correlation matrix
      const TMatrixD*    CorrelationMatrix     ( const TString& className ) const;
      void               SetCorrelationMatrix  ( const TString& className, TMatrixD* matrix );
      void               PrintCorrelationMatrix( const TString& className );
      TH2*               CreateCorrelationMatrixHist( const TMatrixD* m,
                                                      const TString& hName,
                                                      const TString& hTitle ) const;

      // options
      void               SetSplitOptions(const TString& so) { fSplitOptions = so; fNeedsRebuilding = kTRUE; }
      const TString&     GetSplitOptions() const { return fSplitOptions; }

      // root dir
      void               SetRootDir(TDirectory* d) { fOwnRootDir = d; }
      TDirectory*        GetRootDir() const { return fOwnRootDir; }

      void               SetMsgType( EMsgType t ) const;

      DataSetManager*   GetDataSetManager(){return fDataSetManager;}
   private:

      TMVA::DataSetManager*            fDataSetManager; // DSMTEST
      void                       SetDataSetManager( DataSetManager* dsm ) { fDataSetManager = dsm; } // DSMTEST
      friend class DataSetManager;  // DSMTEST (datasetmanager test)

   DataSetInfo( const DataSetInfo& ) : TObject() {}

      void PrintCorrelationMatrix( TTree* theTree );

      TString                    fName;              // name of the dataset info object

      mutable DataSet*           fDataSet;           // dataset, owned by this datasetinfo object
      mutable Bool_t             fNeedsRebuilding;   // flag if rebuilding of dataset is needed (after change of cuts, vars, etc.)

      // expressions/formulas
      std::vector<VariableInfo>  fVariables;         // list of variable expressions/internal names
      std::vector<VariableInfo>  fTargets;           // list of targets expressions/internal names
      std::vector<VariableInfo>  fSpectators;        // list of spectators expressions/internal names

      // the classes
      mutable std::vector<ClassInfo*> fClasses;      // name and other infos of the classes

      TString                    fNormalization;     //
      TString                    fSplitOptions;      //

      Double_t                   fTrainingSumSignalWeights;
      Double_t                   fTrainingSumBackgrWeights;
      Double_t                   fTestingSumSignalWeights ;
      Double_t                   fTestingSumBackgrWeights ;


      
      TDirectory*                fOwnRootDir;        // ROOT output dir
      Bool_t                     fVerbose;           // Verbosity

      UInt_t                     fSignalClass;       // index of the class with the name signal

      std::vector<Float_t>*      fTargetsForMulticlass;//-> all targets 0 except the one with index==classNumber
      
      mutable MsgLogger*         fLogger;            //! message logger
      MsgLogger& Log() const { return *fLogger; }

   public:
       
       ClassDef(DataSetInfo,1);
   };
}

#endif
 DataSetInfo.h:1
 DataSetInfo.h:2
 DataSetInfo.h:3
 DataSetInfo.h:4
 DataSetInfo.h:5
 DataSetInfo.h:6
 DataSetInfo.h:7
 DataSetInfo.h:8
 DataSetInfo.h:9
 DataSetInfo.h:10
 DataSetInfo.h:11
 DataSetInfo.h:12
 DataSetInfo.h:13
 DataSetInfo.h:14
 DataSetInfo.h:15
 DataSetInfo.h:16
 DataSetInfo.h:17
 DataSetInfo.h:18
 DataSetInfo.h:19
 DataSetInfo.h:20
 DataSetInfo.h:21
 DataSetInfo.h:22
 DataSetInfo.h:23
 DataSetInfo.h:24
 DataSetInfo.h:25
 DataSetInfo.h:26
 DataSetInfo.h:27
 DataSetInfo.h:28
 DataSetInfo.h:29
 DataSetInfo.h:30
 DataSetInfo.h:31
 DataSetInfo.h:32
 DataSetInfo.h:33
 DataSetInfo.h:34
 DataSetInfo.h:35
 DataSetInfo.h:36
 DataSetInfo.h:37
 DataSetInfo.h:38
 DataSetInfo.h:39
 DataSetInfo.h:40
 DataSetInfo.h:41
 DataSetInfo.h:42
 DataSetInfo.h:43
 DataSetInfo.h:44
 DataSetInfo.h:45
 DataSetInfo.h:46
 DataSetInfo.h:47
 DataSetInfo.h:48
 DataSetInfo.h:49
 DataSetInfo.h:50
 DataSetInfo.h:51
 DataSetInfo.h:52
 DataSetInfo.h:53
 DataSetInfo.h:54
 DataSetInfo.h:55
 DataSetInfo.h:56
 DataSetInfo.h:57
 DataSetInfo.h:58
 DataSetInfo.h:59
 DataSetInfo.h:60
 DataSetInfo.h:61
 DataSetInfo.h:62
 DataSetInfo.h:63
 DataSetInfo.h:64
 DataSetInfo.h:65
 DataSetInfo.h:66
 DataSetInfo.h:67
 DataSetInfo.h:68
 DataSetInfo.h:69
 DataSetInfo.h:70
 DataSetInfo.h:71
 DataSetInfo.h:72
 DataSetInfo.h:73
 DataSetInfo.h:74
 DataSetInfo.h:75
 DataSetInfo.h:76
 DataSetInfo.h:77
 DataSetInfo.h:78
 DataSetInfo.h:79
 DataSetInfo.h:80
 DataSetInfo.h:81
 DataSetInfo.h:82
 DataSetInfo.h:83
 DataSetInfo.h:84
 DataSetInfo.h:85
 DataSetInfo.h:86
 DataSetInfo.h:87
 DataSetInfo.h:88
 DataSetInfo.h:89
 DataSetInfo.h:90
 DataSetInfo.h:91
 DataSetInfo.h:92
 DataSetInfo.h:93
 DataSetInfo.h:94
 DataSetInfo.h:95
 DataSetInfo.h:96
 DataSetInfo.h:97
 DataSetInfo.h:98
 DataSetInfo.h:99
 DataSetInfo.h:100
 DataSetInfo.h:101
 DataSetInfo.h:102
 DataSetInfo.h:103
 DataSetInfo.h:104
 DataSetInfo.h:105
 DataSetInfo.h:106
 DataSetInfo.h:107
 DataSetInfo.h:108
 DataSetInfo.h:109
 DataSetInfo.h:110
 DataSetInfo.h:111
 DataSetInfo.h:112
 DataSetInfo.h:113
 DataSetInfo.h:114
 DataSetInfo.h:115
 DataSetInfo.h:116
 DataSetInfo.h:117
 DataSetInfo.h:118
 DataSetInfo.h:119
 DataSetInfo.h:120
 DataSetInfo.h:121
 DataSetInfo.h:122
 DataSetInfo.h:123
 DataSetInfo.h:124
 DataSetInfo.h:125
 DataSetInfo.h:126
 DataSetInfo.h:127
 DataSetInfo.h:128
 DataSetInfo.h:129
 DataSetInfo.h:130
 DataSetInfo.h:131
 DataSetInfo.h:132
 DataSetInfo.h:133
 DataSetInfo.h:134
 DataSetInfo.h:135
 DataSetInfo.h:136
 DataSetInfo.h:137
 DataSetInfo.h:138
 DataSetInfo.h:139
 DataSetInfo.h:140
 DataSetInfo.h:141
 DataSetInfo.h:142
 DataSetInfo.h:143
 DataSetInfo.h:144
 DataSetInfo.h:145
 DataSetInfo.h:146
 DataSetInfo.h:147
 DataSetInfo.h:148
 DataSetInfo.h:149
 DataSetInfo.h:150
 DataSetInfo.h:151
 DataSetInfo.h:152
 DataSetInfo.h:153
 DataSetInfo.h:154
 DataSetInfo.h:155
 DataSetInfo.h:156
 DataSetInfo.h:157
 DataSetInfo.h:158
 DataSetInfo.h:159
 DataSetInfo.h:160
 DataSetInfo.h:161
 DataSetInfo.h:162
 DataSetInfo.h:163
 DataSetInfo.h:164
 DataSetInfo.h:165
 DataSetInfo.h:166
 DataSetInfo.h:167
 DataSetInfo.h:168
 DataSetInfo.h:169
 DataSetInfo.h:170
 DataSetInfo.h:171
 DataSetInfo.h:172
 DataSetInfo.h:173
 DataSetInfo.h:174
 DataSetInfo.h:175
 DataSetInfo.h:176
 DataSetInfo.h:177
 DataSetInfo.h:178
 DataSetInfo.h:179
 DataSetInfo.h:180
 DataSetInfo.h:181
 DataSetInfo.h:182
 DataSetInfo.h:183
 DataSetInfo.h:184
 DataSetInfo.h:185
 DataSetInfo.h:186
 DataSetInfo.h:187
 DataSetInfo.h:188
 DataSetInfo.h:189
 DataSetInfo.h:190
 DataSetInfo.h:191
 DataSetInfo.h:192
 DataSetInfo.h:193
 DataSetInfo.h:194
 DataSetInfo.h:195
 DataSetInfo.h:196
 DataSetInfo.h:197
 DataSetInfo.h:198
 DataSetInfo.h:199
 DataSetInfo.h:200
 DataSetInfo.h:201
 DataSetInfo.h:202
 DataSetInfo.h:203
 DataSetInfo.h:204
 DataSetInfo.h:205
 DataSetInfo.h:206
 DataSetInfo.h:207
 DataSetInfo.h:208
 DataSetInfo.h:209
 DataSetInfo.h:210
 DataSetInfo.h:211
 DataSetInfo.h:212
 DataSetInfo.h:213
 DataSetInfo.h:214
 DataSetInfo.h:215
 DataSetInfo.h:216
 DataSetInfo.h:217
 DataSetInfo.h:218
 DataSetInfo.h:219
 DataSetInfo.h:220
 DataSetInfo.h:221
 DataSetInfo.h:222
 DataSetInfo.h:223
 DataSetInfo.h:224
 DataSetInfo.h:225
 DataSetInfo.h:226
 DataSetInfo.h:227
 DataSetInfo.h:228
 DataSetInfo.h:229
 DataSetInfo.h:230
 DataSetInfo.h:231
 DataSetInfo.h:232
 DataSetInfo.h:233
 DataSetInfo.h:234
 DataSetInfo.h:235
 DataSetInfo.h:236
 DataSetInfo.h:237
 DataSetInfo.h:238
 DataSetInfo.h:239
 DataSetInfo.h:240
 DataSetInfo.h:241
 DataSetInfo.h:242
 DataSetInfo.h:243