// @(#)root/tmva $Id$
// Author: Omar Zapata, Thomas James Stevenson and Pourya Vakilipourtakalou. 2016

#ifndef ROOT_TMVA_CrossValidation
#define ROOT_TMVA_CrossValidation

#ifndef ROOT_TString
#include "TString.h"
#endif

#ifndef ROOT_TMultiGraph
#include "TMultiGraph.h"
#endif

#ifndef ROOT_TMVA_IMethod
#include "TMVA/IMethod.h"
#endif

#ifndef ROOT_TMVA_Configurable
#include "TMVA/Configurable.h"
#endif

#ifndef ROOT_TMVA_Types
#include "TMVA/Types.h"
#endif

#ifndef ROOT_TMVA_DataSet
#include "TMVA/DataSet.h"
#endif

#ifndef ROOT_TMVA_Event
#include "TMVA/Event.h"
#endif

#ifndef ROOT_TMVA_Results
#include <TMVA/Results.h>
#endif

#ifndef ROOT_TMVA_Factory
#include <TMVA/Factory.h>
#endif

#ifndef ROOT_TMVA_DataLoader
#include <TMVA/DataLoader.h>
#endif

#ifndef ROOT_TMVA_OptionMap
#include <TMVA/OptionMap.h>
#endif

#ifndef ROOT_TMVA_Envelope
#include <TMVA/Envelope.h>
#endif

namespace TMVA {

   /// Container for the figures of merit gathered by a CrossValidation run.
   ///
   /// CrossValidation (declared friend below) fills the private members; users
   /// read them back through the accessors. The inline accessors are const so
   /// they can be called on the `const CrossValidationResult&` returned by
   /// CrossValidation::GetResults().
   class CrossValidationResult {
      friend class CrossValidation;

   private:
      // ROC integral values; the UInt_t key is presumably the fold index — TODO confirm.
      std::map<UInt_t,Float_t> fROCs;
      // ROC curves collected into one multigraph (ownership via shared_ptr).
      std::shared_ptr<TMultiGraph> fROCCurves;

      // Further figures of merit, presumably one entry per fold (TODO confirm
      // against CrossValidation::Evaluate). Judging by the names: significance,
      // separation, efficiencies at the 0.01/0.10/0.30 working points, efficiency
      // area, and the corresponding training-sample efficiencies.
      std::vector<Double_t> fSigs;
      std::vector<Double_t> fSeps;
      std::vector<Double_t> fEff01s;
      std::vector<Double_t> fEff10s;
      std::vector<Double_t> fEff30s;
      std::vector<Double_t> fEffAreas;
      std::vector<Double_t> fTrainEff01s;
      std::vector<Double_t> fTrainEff10s;
      std::vector<Double_t> fTrainEff30s;

   public:
      CrossValidationResult();
      CrossValidationResult(const CrossValidationResult &);
      // fROCCurves is a shared_ptr and releases its graph by itself, so no
      // explicit cleanup is required here.
      ~CrossValidationResult() = default;

      /// @return a copy of the stored ROC values.
      std::map<UInt_t,Float_t> GetROCValues() const {return fROCs;}
      /// Average of the stored ROC values (implemented out of line).
      Float_t GetROCAverage() const;
      /// Standard deviation of the stored ROC values (implemented out of line).
      Float_t GetROCStandardDeviation() const;
      /// @param fLegend whether a legend is requested (defaults to kTRUE).
      /// Kept non-const to match its out-of-line definition.
      TMultiGraph *GetROCCurves(Bool_t fLegend=kTRUE);
      /// Print the results (implemented out of line).
      void Print() const ;

      /// Draw the results on a canvas with the given name (implemented out of line).
      TCanvas* Draw(const TString name="CrossValidation") const;

      // Read-only accessors; each returns a copy of the corresponding vector.
      std::vector<Double_t> GetSigValues() const {return fSigs;}
      std::vector<Double_t> GetSepValues() const {return fSeps;}
      std::vector<Double_t> GetEff01Values() const {return fEff01s;}
      std::vector<Double_t> GetEff10Values() const {return fEff10s;}
      std::vector<Double_t> GetEff30Values() const {return fEff30s;}
      std::vector<Double_t> GetEffAreaValues() const {return fEffAreas;}
      std::vector<Double_t> GetTrainEff01Values() const {return fTrainEff01s;}
      std::vector<Double_t> GetTrainEff10Values() const {return fTrainEff10s;}
      std::vector<Double_t> GetTrainEff30Values() const {return fTrainEff30s;}
   };


   /// Envelope that runs cross-validation with a configurable number of folds
   /// over the data provided by a DataLoader, collecting a CrossValidationResult.
   class CrossValidation : public Envelope {
      // The "//!" markers below flag these members as transient for ROOT I/O;
      // keep them exactly as written (rootcling parses them).
      UInt_t                 fNumFolds;     //!
      CrossValidationResult  fResults;      //!
      Bool_t                 fFoldStatus;   //!
   public:
      /// @param loader data loader handed to the (out-of-line) constructor.
      explicit CrossValidation(DataLoader *loader);
      ~CrossValidation();

      /// Set the number of folds (implemented out of line).
      void SetNumFolds(UInt_t i);
      /// @return the configured number of folds. Const so it can be queried
      /// through a const reference.
      UInt_t GetNumFolds() const {return fNumFolds;}

      /// Run the cross-validation (implemented out of line).
      virtual void Evaluate();
//    void EvaluateFold(UInt_t fold);//used in ParallelExecution

      /// Read-only access to the collected results.
      const CrossValidationResult& GetResults() const;

   private:
      // Factory instance owned exclusively by this object (unique_ptr);
      // presumably used internally to book and train the methods — TODO confirm.
      std::unique_ptr<Factory> fClassifier;
      // ROOT dictionary macro; class version 0.
      ClassDef(CrossValidation, 0);
   };

} // namespace TMVA

#endif // ROOT_TMVA_CrossValidation
 CrossValidation.h:1
 CrossValidation.h:2
 CrossValidation.h:3
 CrossValidation.h:4
 CrossValidation.h:5
 CrossValidation.h:6
 CrossValidation.h:7
 CrossValidation.h:8
 CrossValidation.h:9
 CrossValidation.h:10
 CrossValidation.h:11
 CrossValidation.h:12
 CrossValidation.h:13
 CrossValidation.h:14
 CrossValidation.h:15
 CrossValidation.h:16
 CrossValidation.h:17
 CrossValidation.h:18
 CrossValidation.h:19
 CrossValidation.h:20
 CrossValidation.h:21
 CrossValidation.h:22
 CrossValidation.h:23
 CrossValidation.h:24
 CrossValidation.h:25
 CrossValidation.h:26
 CrossValidation.h:27
 CrossValidation.h:28
 CrossValidation.h:29
 CrossValidation.h:30
 CrossValidation.h:31
 CrossValidation.h:32
 CrossValidation.h:33
 CrossValidation.h:34
 CrossValidation.h:35
 CrossValidation.h:36
 CrossValidation.h:37
 CrossValidation.h:38
 CrossValidation.h:39
 CrossValidation.h:40
 CrossValidation.h:41
 CrossValidation.h:42
 CrossValidation.h:43
 CrossValidation.h:44
 CrossValidation.h:45
 CrossValidation.h:46
 CrossValidation.h:47
 CrossValidation.h:48
 CrossValidation.h:49
 CrossValidation.h:50
 CrossValidation.h:51
 CrossValidation.h:52
 CrossValidation.h:53
 CrossValidation.h:54
 CrossValidation.h:55
 CrossValidation.h:56
 CrossValidation.h:57
 CrossValidation.h:58
 CrossValidation.h:59
 CrossValidation.h:60
 CrossValidation.h:61
 CrossValidation.h:62
 CrossValidation.h:63
 CrossValidation.h:64
 CrossValidation.h:65
 CrossValidation.h:66
 CrossValidation.h:67
 CrossValidation.h:68
 CrossValidation.h:69
 CrossValidation.h:70
 CrossValidation.h:71
 CrossValidation.h:72
 CrossValidation.h:73
 CrossValidation.h:74
 CrossValidation.h:75
 CrossValidation.h:76
 CrossValidation.h:77
 CrossValidation.h:78
 CrossValidation.h:79
 CrossValidation.h:80
 CrossValidation.h:81
 CrossValidation.h:82
 CrossValidation.h:83
 CrossValidation.h:84
 CrossValidation.h:85
 CrossValidation.h:86
 CrossValidation.h:87
 CrossValidation.h:88
 CrossValidation.h:89
 CrossValidation.h:90
 CrossValidation.h:91
 CrossValidation.h:92
 CrossValidation.h:93
 CrossValidation.h:94
 CrossValidation.h:95
 CrossValidation.h:96
 CrossValidation.h:97
 CrossValidation.h:98
 CrossValidation.h:99
 CrossValidation.h:100
 CrossValidation.h:101
 CrossValidation.h:102
 CrossValidation.h:103
 CrossValidation.h:104
 CrossValidation.h:105
 CrossValidation.h:106
 CrossValidation.h:107
 CrossValidation.h:108
 CrossValidation.h:109
 CrossValidation.h:110
 CrossValidation.h:111
 CrossValidation.h:112
 CrossValidation.h:113
 CrossValidation.h:114
 CrossValidation.h:115
 CrossValidation.h:116
 CrossValidation.h:117
 CrossValidation.h:118
 CrossValidation.h:119
 CrossValidation.h:120
 CrossValidation.h:121
 CrossValidation.h:122