// @(#)root/tmva $Id$
// Author: Peter Speckmayer

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class  : MethodDNN                                                              *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description:                                                                   *
 *      NeuralNetwork                                                             *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Peter Speckmayer      <peter.speckmayer@gmx.at>  - CERN, Switzerland      *
 *      Simon Pfreundschuh    <s.pfreundschuh@gmail.com> - CERN, Switzerland      *
 *                                                                                *
 * Copyright (c) 2005-2015:                                                       *
 *      CERN, Switzerland                                                         *
 *      U. of Victoria, Canada                                                    *
 *      MPI-K Heidelberg, Germany                                                 *
 *      U. of Bonn, Germany                                                       *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (http://tmva.sourceforge.net/LICENSE)                                          *
 **********************************************************************************/

//#pragma once

#ifndef ROOT_TMVA_MethodDNN
#define ROOT_TMVA_MethodDNN

//////////////////////////////////////////////////////////////////////////
//                                                                      //
// MethodDNN                                                             //
//                                                                      //
// Neural Network implementation                                        //
//                                                                      //
//////////////////////////////////////////////////////////////////////////

#include <vector>
#ifndef ROOT_TString
#include "TString.h"
#endif
#ifndef ROOT_TTree
#include "TTree.h"
#endif
#ifndef ROOT_TRandom3
#include "TRandom3.h"
#endif
#ifndef ROOT_TH1F
#include "TH1F.h"
#endif
#ifndef ROOT_TMVA_MethodBase
#include "TMVA/MethodBase.h"
#endif
#ifndef TMVA_NEURAL_NET
#include "TMVA/NeuralNet.h"
#endif

#include "TMVA/Tools.h"

#include "TMVA/DNN/Net.h"
#include "TMVA/DNN/Minimizers.h"
#include "TMVA/DNN/Architectures/Reference.h"

#ifdef DNNCPU
#include "TMVA/DNN/Architectures/Cpu.h"
#endif

#ifdef DNNCUDA
#include "TMVA/DNN/Architectures/Cuda.h"
#endif

using namespace TMVA::DNN;

namespace TMVA {

class MethodDNN : public MethodBase
{
    using Architecture_t = TReference<Double_t>;
    using Net_t          = TNet<Architecture_t>;
    using Matrix_t       = typename Architecture_t::Matrix_t;

private:

   using LayoutVector_t   = std::vector<std::pair<int, EActivationFunction>>;
   using KeyValueVector_t = std::vector<std::map<TString, TString>>;

   struct TTrainingSettings
   {
       size_t                batchSize;
       size_t                testInterval;
       size_t                convergenceSteps;
       ERegularization       regularization;
       Double_t              learningRate;
       Double_t              momentum;
       Double_t              weightDecay;
       std::vector<Double_t> dropoutProbabilities;
       bool                  multithreading;
   };

   // the option handling methods
   void DeclareOptions();
   void ProcessOptions();

   // general helper functions
   void     Init();

   Net_t             fNet;
   EInitialization   fWeightInitialization;
   EOutputFunction   fOutputFunction;

   TString                        fLayoutString;
   TString                        fErrorStrategy;
   TString                        fTrainingStrategyString;
   TString                        fWeightInitializationString;
   TString                        fArchitectureString;
   LayoutVector_t                 fLayout;
   std::vector<TTrainingSettings> fTrainingSettings;
   bool                           fResume;

   KeyValueVector_t fSettings;

   ClassDef(MethodDNN,0); // neural network

   static inline void WriteMatrixXML(void *parent, const char *name,
                                     const TMatrixT<Double_t> &X);
   static inline void ReadMatrixXML(void *xml, const char *name,
                                    TMatrixT<Double_t> &X);
protected:

   void MakeClassSpecific( std::ostream&, const TString& ) const;
   void GetHelpMessage() const;

public:

   // Standard Constructors
   MethodDNN(const TString& jobName,
             const TString&  methodTitle,
             DataSetInfo& theData,
             const TString& theOption);
   MethodDNN(DataSetInfo& theData,
             const TString& theWeightFile);
   virtual ~MethodDNN();

   virtual Bool_t HasAnalysisType(Types::EAnalysisType type,
                                  UInt_t numberClasses,
                                  UInt_t numberTargets );
   LayoutVector_t   ParseLayoutString(TString layerSpec);
   KeyValueVector_t ParseKeyValueString(TString parseString,
                                      TString blockDelim,
                                      TString tokenDelim);
   void Train();
   void TrainGpu();
   void TrainCpu();

   virtual Double_t GetMvaValue( Double_t* err=0, Double_t* errUpper=0 );
   virtual const std::vector<Float_t>& GetRegressionValues();
   virtual const std::vector<Float_t>& GetMulticlassValues();

   using MethodBase::ReadWeightsFromStream;

   // write weights to stream
   void AddWeightsXMLTo     ( void* parent ) const;

   // read weights from stream
   void ReadWeightsFromStream( std::istream & i );
   void ReadWeightsFromXML   ( void* wghtnode );

   // ranking of input variables
   const Ranking* CreateRanking();

};

inline void MethodDNN::WriteMatrixXML(void *parent,
                                      const char *name,
                                      const TMatrixT<Double_t> &X)
{
   std::stringstream matrixStringStream("");
   matrixStringStream.precision( 16 );

   for (size_t i = 0; i < (size_t) X.GetNrows(); i++)
   {
      for (size_t j = 0; j < (size_t) X.GetNcols(); j++)
      {
         matrixStringStream << std::scientific << X(i,j) << " ";
      }
   }
   std::string s = matrixStringStream.str();
   void* matxml = gTools().xmlengine().NewChild(parent, 0, name);
   gTools().xmlengine().NewAttr(matxml, 0, "rows",
                                gTools().StringFromInt((int)X.GetNrows()));
   gTools().xmlengine().NewAttr(matxml, 0, "cols",
                                gTools().StringFromInt((int)X.GetNcols()));
   gTools().xmlengine().AddRawLine (matxml, s.c_str());
}

inline void MethodDNN::ReadMatrixXML(void *xml,
                                     const char *name,
                                     TMatrixT<Double_t> &X)
{
   void *matrixXML = gTools().GetChild(xml, name);
   size_t rows, cols;
   gTools().ReadAttr(matrixXML, "rows", rows);
   gTools().ReadAttr(matrixXML, "cols", cols);

   const char * matrixString = gTools().xmlengine().GetNodeContent(matrixXML);
   std::stringstream matrixStringStream(matrixString);

   for (size_t i = 0; i < rows; i++)
   {
      for (size_t j = 0; j < cols; j++)
      {
         matrixStringStream >> X(i,j);
      }
   }
}
} // namespace TMVA

#endif
 MethodDNN.h:1
 MethodDNN.h:2
 MethodDNN.h:3
 MethodDNN.h:4
 MethodDNN.h:5
 MethodDNN.h:6
 MethodDNN.h:7
 MethodDNN.h:8
 MethodDNN.h:9
 MethodDNN.h:10
 MethodDNN.h:11
 MethodDNN.h:12
 MethodDNN.h:13
 MethodDNN.h:14
 MethodDNN.h:15
 MethodDNN.h:16
 MethodDNN.h:17
 MethodDNN.h:18
 MethodDNN.h:19
 MethodDNN.h:20
 MethodDNN.h:21
 MethodDNN.h:22
 MethodDNN.h:23
 MethodDNN.h:24
 MethodDNN.h:25
 MethodDNN.h:26
 MethodDNN.h:27
 MethodDNN.h:28
 MethodDNN.h:29
 MethodDNN.h:30
 MethodDNN.h:31
 MethodDNN.h:32
 MethodDNN.h:33
 MethodDNN.h:34
 MethodDNN.h:35
 MethodDNN.h:36
 MethodDNN.h:37
 MethodDNN.h:38
 MethodDNN.h:39
 MethodDNN.h:40
 MethodDNN.h:41
 MethodDNN.h:42
 MethodDNN.h:43
 MethodDNN.h:44
 MethodDNN.h:45
 MethodDNN.h:46
 MethodDNN.h:47
 MethodDNN.h:48
 MethodDNN.h:49
 MethodDNN.h:50
 MethodDNN.h:51
 MethodDNN.h:52
 MethodDNN.h:53
 MethodDNN.h:54
 MethodDNN.h:55
 MethodDNN.h:56
 MethodDNN.h:57
 MethodDNN.h:58
 MethodDNN.h:59
 MethodDNN.h:60
 MethodDNN.h:61
 MethodDNN.h:62
 MethodDNN.h:63
 MethodDNN.h:64
 MethodDNN.h:65
 MethodDNN.h:66
 MethodDNN.h:67
 MethodDNN.h:68
 MethodDNN.h:69
 MethodDNN.h:70
 MethodDNN.h:71
 MethodDNN.h:72
 MethodDNN.h:73
 MethodDNN.h:74
 MethodDNN.h:75
 MethodDNN.h:76
 MethodDNN.h:77
 MethodDNN.h:78
 MethodDNN.h:79
 MethodDNN.h:80
 MethodDNN.h:81
 MethodDNN.h:82
 MethodDNN.h:83
 MethodDNN.h:84
 MethodDNN.h:85
 MethodDNN.h:86
 MethodDNN.h:87
 MethodDNN.h:88
 MethodDNN.h:89
 MethodDNN.h:90
 MethodDNN.h:91
 MethodDNN.h:92
 MethodDNN.h:93
 MethodDNN.h:94
 MethodDNN.h:95
 MethodDNN.h:96
 MethodDNN.h:97
 MethodDNN.h:98
 MethodDNN.h:99
 MethodDNN.h:100
 MethodDNN.h:101
 MethodDNN.h:102
 MethodDNN.h:103
 MethodDNN.h:104
 MethodDNN.h:105
 MethodDNN.h:106
 MethodDNN.h:107
 MethodDNN.h:108
 MethodDNN.h:109
 MethodDNN.h:110
 MethodDNN.h:111
 MethodDNN.h:112
 MethodDNN.h:113
 MethodDNN.h:114
 MethodDNN.h:115
 MethodDNN.h:116
 MethodDNN.h:117
 MethodDNN.h:118
 MethodDNN.h:119
 MethodDNN.h:120
 MethodDNN.h:121
 MethodDNN.h:122
 MethodDNN.h:123
 MethodDNN.h:124
 MethodDNN.h:125
 MethodDNN.h:126
 MethodDNN.h:127
 MethodDNN.h:128
 MethodDNN.h:129
 MethodDNN.h:130
 MethodDNN.h:131
 MethodDNN.h:132
 MethodDNN.h:133
 MethodDNN.h:134
 MethodDNN.h:135
 MethodDNN.h:136
 MethodDNN.h:137
 MethodDNN.h:138
 MethodDNN.h:139
 MethodDNN.h:140
 MethodDNN.h:141
 MethodDNN.h:142
 MethodDNN.h:143
 MethodDNN.h:144
 MethodDNN.h:145
 MethodDNN.h:146
 MethodDNN.h:147
 MethodDNN.h:148
 MethodDNN.h:149
 MethodDNN.h:150
 MethodDNN.h:151
 MethodDNN.h:152
 MethodDNN.h:153
 MethodDNN.h:154
 MethodDNN.h:155
 MethodDNN.h:156
 MethodDNN.h:157
 MethodDNN.h:158
 MethodDNN.h:159
 MethodDNN.h:160
 MethodDNN.h:161
 MethodDNN.h:162
 MethodDNN.h:163
 MethodDNN.h:164
 MethodDNN.h:165
 MethodDNN.h:166
 MethodDNN.h:167
 MethodDNN.h:168
 MethodDNN.h:169
 MethodDNN.h:170
 MethodDNN.h:171
 MethodDNN.h:172
 MethodDNN.h:173
 MethodDNN.h:174
 MethodDNN.h:175
 MethodDNN.h:176
 MethodDNN.h:177
 MethodDNN.h:178
 MethodDNN.h:179
 MethodDNN.h:180
 MethodDNN.h:181
 MethodDNN.h:182
 MethodDNN.h:183
 MethodDNN.h:184
 MethodDNN.h:185
 MethodDNN.h:186
 MethodDNN.h:187
 MethodDNN.h:188
 MethodDNN.h:189
 MethodDNN.h:190
 MethodDNN.h:191
 MethodDNN.h:192
 MethodDNN.h:193
 MethodDNN.h:194
 MethodDNN.h:195
 MethodDNN.h:196
 MethodDNN.h:197
 MethodDNN.h:198
 MethodDNN.h:199
 MethodDNN.h:200
 MethodDNN.h:201
 MethodDNN.h:202
 MethodDNN.h:203
 MethodDNN.h:204
 MethodDNN.h:205
 MethodDNN.h:206
 MethodDNN.h:207
 MethodDNN.h:208
 MethodDNN.h:209
 MethodDNN.h:210
 MethodDNN.h:211
 MethodDNN.h:212
 MethodDNN.h:213
 MethodDNN.h:214
 MethodDNN.h:215
 MethodDNN.h:216
 MethodDNN.h:217
 MethodDNN.h:218
 MethodDNN.h:219
 MethodDNN.h:220
 MethodDNN.h:221