// @(#)root/tmva/tmva/dnn:$Id$
// Author: Simon Pfreundschuh 08/08/16

/*************************************************************************
 * Copyright (C) 2016, Simon Pfreundschuh                                *
 * All rights reserved.                                                  *
 *                                                                       *
 * For the licensing terms see $ROOTSYS/LICENSE.                         *
 * For the list of contributors see $ROOTSYS/README/CREDITS.             *
 *************************************************************************/

/////////////////////////////////////////////////////////////////////
// Generic data loader for neural network input data. Provides a   //
// high level abstraction for the transfer of training data to the //
// device.                                                         //
/////////////////////////////////////////////////////////////////////

#ifndef TMVA_DNN_DATALOADER
#define TMVA_DNN_DATALOADER

#include "TMatrix.h"
#include "TMVA/Event.h"

#include <algorithm>
#include <iostream>
#include <random>
#include <vector>

namespace TMVA {
namespace DNN  {

//
// Input Data Types
//______________________________________________________________________________
// Training data given as a pair of references to two ROOT matrices
// (presumably the input matrix first and the output matrix second -- confirm
// against the backend CopyInput()/CopyOutput() specializations).
using MatrixInput_t = std::pair<const TMatrixT<Double_t> &, const TMatrixT<Double_t> &>;

// Training data given as a vector of pointers to TMVA events.
using TMVAInput_t = std::vector<Event *>;

// Iterator into a vector of sample indices; this is the type handed to
// TDataLoader::CopyInput() and TDataLoader::CopyOutput().
using IndexIterator_t = std::vector<size_t>::iterator;

/** TBatch
 *
 * A single training batch: one matrix with the input data and one matrix with
 * the corresponding output data. Both matrices are held by value as
 * AArchitecture::Matrix_t members and are exposed through the GetInput() and
 * GetOutput() member functions.
 *
 * \tparam AArchitecture The underlying architecture; it must provide the
 *                       nested type Matrix_t.
 */
//______________________________________________________________________________
template <typename AArchitecture>
class TBatch
{
private:

   using Matrix_t = typename AArchitecture::Matrix_t;

   Matrix_t fInputMatrix;  ///< Input data of the batch.
   Matrix_t fOutputMatrix; ///< Output data of the batch.

public:

   /** Create a batch from an input matrix (first argument) and an output
    *  matrix (second argument). */
   TBatch(Matrix_t &, Matrix_t &);

   // Copy and move operations are the compiler-generated ones.
   TBatch(const TBatch &) = default;
   TBatch(TBatch &&) = default;
   TBatch & operator=(const TBatch &) = default;
   TBatch & operator=(TBatch &&) = default;

   /** Return the matrix representing the input data. */
   Matrix_t & GetInput()
   {
      return fInputMatrix;
   }

   /** Return the matrix representing the output data. */
   Matrix_t & GetOutput()
   {
      return fOutputMatrix;
   }
};

template<typename Data_t, typename AArchitecture> class TDataLoader;

/** TBatchIterator
 *
 * Simple iterator class for the iteration over the training batches in
 * a given data set represented by a TDataLoader object.
 *
 * The iterator itself only stores a batch counter. Dereferencing it forwards
 * to TDataLoader::GetBatch(), so the batch that is returned is determined by
 * the loader's own internal counter, not by the iterator's position; the
 * iterator's counter is only used to detect the end of the epoch.
 *
 * \tparam Data_t        The input data type.
 * \tparam AArchitecture The underlying architecture type.
 */
template<typename Data_t, typename AArchitecture>
class TBatchIterator
{
private:

   TDataLoader<Data_t, AArchitecture> & fDataLoader; ///< Loader the batches are fetched from.
   size_t fBatchIndex;                               ///< Position of this iterator within the epoch.

public:

   /** Create an iterator over the batches of \p dataLoader positioned at
    *  batch number \p index (the first batch by default). */
   TBatchIterator(TDataLoader<Data_t, AArchitecture> & dataLoader, size_t index = 0)
   : fDataLoader(dataLoader), fBatchIndex(index)
   {
      // Nothing to do here.
   }

   /** Fetch the next batch from the underlying data loader. */
   TBatch<AArchitecture> operator*() {return fDataLoader.GetBatch();}

   /** Advance to the next batch. Returns a reference to this iterator (the
    *  canonical form of prefix increment) instead of a copy of it. */
   TBatchIterator & operator++() {++fBatchIndex; return *this;}

   /** Two iterators differ when they are positioned at different batch
    *  numbers; the data loaders they refer to are not compared. */
   bool operator!=(const TBatchIterator & other) const {
      return fBatchIndex != other.fBatchIndex;
   }
};

/** TDataLoader
 *
 * Service class that streams training data from the input data type to the
 * accelerator device or the CPU. A TDataLoader object owns a number of host
 * and device buffer pairs which are used in a round-robin manner for the
 * transfer of batches to the device.
 *
 * Each TDataLoader object has an associated batch size and a total number of
 * samples in the data set. One epoch consists of fNSamples / fBatchSize
 * batches (integer division). The begin() and end() member functions allow
 * the user to iterate over the batches of one epoch.
 *
 * \tparam Data_t        The input data type.
 * \tparam AArchitecture The architecture class of the underlying architecture.
 */
template<typename Data_t, typename AArchitecture>
class TDataLoader
{
private:

   using HostBuffer_t    = typename AArchitecture::HostBuffer_t;
   using DeviceBuffer_t  = typename AArchitecture::DeviceBuffer_t;
   using Matrix_t        = typename AArchitecture::Matrix_t;
   using BatchIterator_t = TBatchIterator<Data_t, AArchitecture>;

   const Data_t & fData;                        ///< The training data; only referenced, not owned.

   size_t fNSamples;                            ///< Total number of samples in the data set.
   size_t fBatchSize;                           ///< Number of samples per batch.
   size_t fNInputFeatures;                      ///< Number of input features per sample.
   size_t fNOutputFeatures;                     ///< Number of output features per sample.
   size_t fBatchIndex;                          ///< Index of the batch returned by the next GetBatch() call.

   size_t fNStreams;                            ///< Number of buffer pairs.
   std::vector<DeviceBuffer_t> fDeviceBuffers;  ///< One device buffer per stream.
   std::vector<HostBuffer_t>   fHostBuffers;    ///< One host buffer per stream.

   std::vector<size_t> fSampleIndices;          ///< Ordering of the samples in the epoch.

public:

   /** Create a loader for \p data with the given number of samples, batch
    *  size, feature counts and number of host/device buffer pairs. */
   TDataLoader(const Data_t & data, size_t nSamples, size_t batchSize,
               size_t nInputFeatures, size_t nOutputFeatures, size_t nStreams = 1);

   // Compiler-generated copy and move operations. Because fData is a
   // reference member, the defaulted assignment operators are defined as
   // deleted by the language.
   TDataLoader(const TDataLoader &) = default;
   TDataLoader(TDataLoader &&) = default;
   TDataLoader & operator=(const TDataLoader &) = default;
   TDataLoader & operator=(TDataLoader &&) = default;

   /** Copy input matrix into the given host buffer. Function to be specialized
    *  by the architecture-specific backend. */
   void CopyInput(HostBuffer_t &buffer, IndexIterator_t begin, size_t batchSize);

   /** Copy output matrix into the given host buffer. Function to be
    *  specialized by the architecture-specific backend. */
   void CopyOutput(HostBuffer_t &buffer, IndexIterator_t begin, size_t batchSize);

   /** Iterator positioned at the first batch of the epoch. */
   BatchIterator_t begin()
   {
      return BatchIterator_t(*this);
   }

   /** Iterator positioned one past the last batch of the epoch, i.e. at batch
    *  number fNSamples / fBatchSize. */
   BatchIterator_t end()
   {
      return BatchIterator_t(*this, fNSamples / fBatchSize);
   }

   /** Shuffle the order of the samples in the batch. The shuffling is indirect,
    *  i.e. only the indices are shuffled. No input data is moved by this
    *  routine. */
   void Shuffle();

   /** Return the next batch from the training set. The TDataLoader object
    *  keeps an internal counter that cycles over the batches in the training
    *  set. */
   TBatch<AArchitecture> GetBatch();

};

//
// TBatch Class.
//______________________________________________________________________________
/** Create a batch whose input and output matrices are copy-constructed from
 *  \p inputMatrix and \p outputMatrix respectively. */
template<typename AArchitecture>
TBatch<AArchitecture>::TBatch(Matrix_t & inputMatrix, Matrix_t & outputMatrix)
    : fInputMatrix(inputMatrix),
      fOutputMatrix(outputMatrix)
{
    // All work is done in the member initializer list.
}

//
// TDataLoader Class.
//______________________________________________________________________________
/** Set up a data loader for the given training data.
 *
 *  Allocates \p nStreams host/device buffer pairs, each large enough to hold
 *  the input and the output matrix of one batch, and initializes the sample
 *  ordering to 0, 1, ..., nSamples - 1.
 *
 *  \param data            The training data; only a reference to it is stored.
 *  \param nSamples        Total number of samples in the data set.
 *  \param batchSize       Number of samples per batch.
 *  \param nInputFeatures  Number of input features per sample.
 *  \param nOutputFeatures Number of output features per sample.
 *  \param nStreams        Number of host/device buffer pairs to allocate.
 */
template<typename Data_t, typename AArchitecture>
TDataLoader<Data_t, AArchitecture>::TDataLoader(
    const Data_t & data, size_t nSamples, size_t batchSize,
    size_t nInputFeatures, size_t nOutputFeatures, size_t nStreams)
    : fData(data),
      fNSamples(nSamples),
      fBatchSize(batchSize),
      fNInputFeatures(nInputFeatures),
      fNOutputFeatures(nOutputFeatures),
      fBatchIndex(0),
      fNStreams(nStreams),
      fDeviceBuffers(),
      fHostBuffers(),
      fSampleIndices()
{
   // A buffer stores the input matrix of a batch followed by its output matrix.
   const size_t bufferSize = fBatchSize * fNInputFeatures
                           + fBatchSize * fNOutputFeatures;

   for (size_t stream = 0; stream != fNStreams; ++stream) {
      fHostBuffers.push_back(HostBuffer_t(bufferSize));
      fDeviceBuffers.push_back(DeviceBuffer_t(bufferSize));
   }

   // Start out with the samples in their natural order.
   fSampleIndices.reserve(fNSamples);
   for (size_t sample = 0; sample != fNSamples; ++sample) {
      fSampleIndices.push_back(sample);
   }
}

//______________________________________________________________________________
/** Assemble the next batch of the epoch and return it.
 *
 *  The internal batch counter is first wrapped into the range
 *  [0, fNSamples / fBatchSize). The host/device buffer pair with index
 *  fBatchIndex % fNStreams is then selected, the batch's input and output data
 *  are written into the host buffer through CopyInput() and CopyOutput(), the
 *  host buffer is copied into the device buffer with CopyFrom(), and the
 *  input and output matrices are constructed from the device sub-buffers.
 *
 *  \return A TBatch holding the input and output matrix of the current batch.
 *          The internal batch counter is incremented before returning.
 *
 *  NOTE(review): fNSamples / fBatchSize is an integer division. It is zero
 *  when fNSamples < fBatchSize, which makes the modulo below a modulo by zero,
 *  and fBatchSize == 0 is a division by zero. Neither case is checked here.
 */
template<typename Data_t, typename AArchitecture>
TBatch<AArchitecture> TDataLoader<Data_t, AArchitecture>::GetBatch()
{
   fBatchIndex %= (fNSamples / fBatchSize); // Wrap around after the last full batch.


   // Number of elements of the input and of the output matrix of one batch.
   size_t inputMatrixSize  = fBatchSize * fNInputFeatures;
   size_t outputMatrixSize = fBatchSize * fNOutputFeatures;

   // Pick the buffer pair for this batch; consecutive batches use
   // consecutive pairs, wrapping around after fNStreams.
   size_t streamIndex = fBatchIndex % fNStreams;
   HostBuffer_t   & hostBuffer   = fHostBuffers[streamIndex];
   DeviceBuffer_t & deviceBuffer = fDeviceBuffers[streamIndex];

   // Split the host buffer into an input part followed by an output part.
   // Presumably the arguments of GetSubBuffer() are (offset, size) -- confirm
   // against the backend buffer classes.
   HostBuffer_t inputHostBuffer  = hostBuffer.GetSubBuffer(0, inputMatrixSize);
   HostBuffer_t outputHostBuffer = hostBuffer.GetSubBuffer(inputMatrixSize,
                                                           outputMatrixSize);

   // Same split for the device buffer.
   DeviceBuffer_t inputDeviceBuffer  = deviceBuffer.GetSubBuffer(0, inputMatrixSize);
   DeviceBuffer_t outputDeviceBuffer = deviceBuffer.GetSubBuffer(inputMatrixSize,
                                                                 outputMatrixSize);
   // Position in fSampleIndices of the first sample of this batch.
   size_t sampleIndex = fBatchIndex * fBatchSize;
   IndexIterator_t sampleIndexIterator = fSampleIndices.begin() + sampleIndex;

   // Both copies start from the same position in the sample ordering.
   CopyInput(inputHostBuffer,   sampleIndexIterator, fBatchSize);
   CopyOutput(outputHostBuffer, sampleIndexIterator, fBatchSize);

   // Copy the complete host buffer (input and output part) in one call.
   deviceBuffer.CopyFrom(hostBuffer);
   Matrix_t  inputMatrix(inputDeviceBuffer,  fBatchSize, fNInputFeatures);
   Matrix_t outputMatrix(outputDeviceBuffer, fBatchSize, fNOutputFeatures);

   fBatchIndex++;
   return TBatch<AArchitecture>(inputMatrix, outputMatrix);
}

//______________________________________________________________________________
/** Randomly permute the order in which the samples are visited.
 *
 *  Only the entries of fSampleIndices are permuted; the training data itself
 *  is not touched.
 *
 *  std::random_shuffle was deprecated in C++14 and removed in C++17, so
 *  std::shuffle with an explicit random number engine is used instead. The
 *  engine is default-seeded and kept alive between calls: successive calls
 *  produce different permutations, while the sequence of permutations is
 *  reproducible from one program run to the next.
 */
template<typename Data_t, typename AArchitecture>
void TDataLoader<Data_t, AArchitecture>::Shuffle()
{
   // One engine per thread, so concurrent calls do not share generator state.
   thread_local std::mt19937 generator;
   std::shuffle(fSampleIndices.begin(), fSampleIndices.end(), generator);
}

} // namespace DNN
} // namespace TMVA

#endif
 DataLoader.h:1
 DataLoader.h:2
 DataLoader.h:3
 DataLoader.h:4
 DataLoader.h:5
 DataLoader.h:6
 DataLoader.h:7
 DataLoader.h:8
 DataLoader.h:9
 DataLoader.h:10
 DataLoader.h:11
 DataLoader.h:12
 DataLoader.h:13
 DataLoader.h:14
 DataLoader.h:15
 DataLoader.h:16
 DataLoader.h:17
 DataLoader.h:18
 DataLoader.h:19
 DataLoader.h:20
 DataLoader.h:21
 DataLoader.h:22
 DataLoader.h:23
 DataLoader.h:24
 DataLoader.h:25
 DataLoader.h:26
 DataLoader.h:27
 DataLoader.h:28
 DataLoader.h:29
 DataLoader.h:30
 DataLoader.h:31
 DataLoader.h:32
 DataLoader.h:33
 DataLoader.h:34
 DataLoader.h:35
 DataLoader.h:36
 DataLoader.h:37
 DataLoader.h:38
 DataLoader.h:39
 DataLoader.h:40
 DataLoader.h:41
 DataLoader.h:42
 DataLoader.h:43
 DataLoader.h:44
 DataLoader.h:45
 DataLoader.h:46
 DataLoader.h:47
 DataLoader.h:48
 DataLoader.h:49
 DataLoader.h:50
 DataLoader.h:51
 DataLoader.h:52
 DataLoader.h:53
 DataLoader.h:54
 DataLoader.h:55
 DataLoader.h:56
 DataLoader.h:57
 DataLoader.h:58
 DataLoader.h:59
 DataLoader.h:60
 DataLoader.h:61
 DataLoader.h:62
 DataLoader.h:63
 DataLoader.h:64
 DataLoader.h:65
 DataLoader.h:66
 DataLoader.h:67
 DataLoader.h:68
 DataLoader.h:69
 DataLoader.h:70
 DataLoader.h:71
 DataLoader.h:72
 DataLoader.h:73
 DataLoader.h:74
 DataLoader.h:75
 DataLoader.h:76
 DataLoader.h:77
 DataLoader.h:78
 DataLoader.h:79
 DataLoader.h:80
 DataLoader.h:81
 DataLoader.h:82
 DataLoader.h:83
 DataLoader.h:84
 DataLoader.h:85
 DataLoader.h:86
 DataLoader.h:87
 DataLoader.h:88
 DataLoader.h:89
 DataLoader.h:90
 DataLoader.h:91
 DataLoader.h:92
 DataLoader.h:93
 DataLoader.h:94
 DataLoader.h:95
 DataLoader.h:96
 DataLoader.h:97
 DataLoader.h:98
 DataLoader.h:99
 DataLoader.h:100
 DataLoader.h:101
 DataLoader.h:102
 DataLoader.h:103
 DataLoader.h:104
 DataLoader.h:105
 DataLoader.h:106
 DataLoader.h:107
 DataLoader.h:108
 DataLoader.h:109
 DataLoader.h:110
 DataLoader.h:111
 DataLoader.h:112
 DataLoader.h:113
 DataLoader.h:114
 DataLoader.h:115
 DataLoader.h:116
 DataLoader.h:117
 DataLoader.h:118
 DataLoader.h:119
 DataLoader.h:120
 DataLoader.h:121
 DataLoader.h:122
 DataLoader.h:123
 DataLoader.h:124
 DataLoader.h:125
 DataLoader.h:126
 DataLoader.h:127
 DataLoader.h:128
 DataLoader.h:129
 DataLoader.h:130
 DataLoader.h:131
 DataLoader.h:132
 DataLoader.h:133
 DataLoader.h:134
 DataLoader.h:135
 DataLoader.h:136
 DataLoader.h:137
 DataLoader.h:138
 DataLoader.h:139
 DataLoader.h:140
 DataLoader.h:141
 DataLoader.h:142
 DataLoader.h:143
 DataLoader.h:144
 DataLoader.h:145
 DataLoader.h:146
 DataLoader.h:147
 DataLoader.h:148
 DataLoader.h:149
 DataLoader.h:150
 DataLoader.h:151
 DataLoader.h:152
 DataLoader.h:153
 DataLoader.h:154
 DataLoader.h:155
 DataLoader.h:156
 DataLoader.h:157
 DataLoader.h:158
 DataLoader.h:159
 DataLoader.h:160
 DataLoader.h:161
 DataLoader.h:162
 DataLoader.h:163
 DataLoader.h:164
 DataLoader.h:165
 DataLoader.h:166
 DataLoader.h:167
 DataLoader.h:168
 DataLoader.h:169
 DataLoader.h:170
 DataLoader.h:171
 DataLoader.h:172
 DataLoader.h:173
 DataLoader.h:174
 DataLoader.h:175
 DataLoader.h:176
 DataLoader.h:177
 DataLoader.h:178
 DataLoader.h:179
 DataLoader.h:180
 DataLoader.h:181
 DataLoader.h:182
 DataLoader.h:183
 DataLoader.h:184
 DataLoader.h:185
 DataLoader.h:186
 DataLoader.h:187
 DataLoader.h:188
 DataLoader.h:189
 DataLoader.h:190
 DataLoader.h:191
 DataLoader.h:192
 DataLoader.h:193
 DataLoader.h:194
 DataLoader.h:195
 DataLoader.h:196
 DataLoader.h:197
 DataLoader.h:198
 DataLoader.h:199
 DataLoader.h:200
 DataLoader.h:201
 DataLoader.h:202
 DataLoader.h:203
 DataLoader.h:204
 DataLoader.h:205
 DataLoader.h:206
 DataLoader.h:207
 DataLoader.h:208
 DataLoader.h:209
 DataLoader.h:210
 DataLoader.h:211
 DataLoader.h:212
 DataLoader.h:213
 DataLoader.h:214
 DataLoader.h:215
 DataLoader.h:216
 DataLoader.h:217
 DataLoader.h:218
 DataLoader.h:219
 DataLoader.h:220
 DataLoader.h:221
 DataLoader.h:222
 DataLoader.h:223
 DataLoader.h:224
 DataLoader.h:225
 DataLoader.h:226
 DataLoader.h:227
 DataLoader.h:228
 DataLoader.h:229
 DataLoader.h:230
 DataLoader.h:231
 DataLoader.h:232
 DataLoader.h:233
 DataLoader.h:234
 DataLoader.h:235
 DataLoader.h:236
 DataLoader.h:237
 DataLoader.h:238
 DataLoader.h:239
 DataLoader.h:240
 DataLoader.h:241
 DataLoader.h:242
 DataLoader.h:243
 DataLoader.h:244
 DataLoader.h:245
 DataLoader.h:246
 DataLoader.h:247
 DataLoader.h:248
 DataLoader.h:249
 DataLoader.h:250
 DataLoader.h:251
 DataLoader.h:252
 DataLoader.h:253
 DataLoader.h:254
 DataLoader.h:255
 DataLoader.h:256
 DataLoader.h:257
 DataLoader.h:258
 DataLoader.h:259
 DataLoader.h:260