30#ifndef TMVA_DNN_LSTM_LAYER 
   31#define TMVA_DNN_LSTM_LAYER 
   55template<
typename Architecture_t>
 
   61   using Matrix_t = 
typename Architecture_t::Matrix_t;
 
   62   using Scalar_t = 
typename Architecture_t::Scalar_t;
 
   63   using Tensor_t = 
typename Architecture_t::Tensor_t;
 
  147   TBasicLSTMLayer(
size_t batchSize, 
size_t stateSize, 
size_t inputSize, 
size_t timeSteps, 
bool rememberState = 
false,
 
  148                   bool returnSequence = 
false,
 
  174                 const Tensor_t &activations_backward);
 
  183                           const Matrix_t & precStateActivations, 
const Matrix_t & precCellActivations,
 
  340template <
typename Architecture_t>
 
  346        batchSize, 1, timeSteps, inputSize, 1, (returnSequence) ? timeSteps : 1, stateSize, 8,
 
  347        {stateSize, stateSize, stateSize, stateSize, stateSize, stateSize, stateSize, stateSize},
 
  348        {inputSize, inputSize, inputSize, inputSize, stateSize, stateSize, stateSize, stateSize}, 4,
 
  349        {stateSize, stateSize, stateSize, stateSize}, {1, 1, 1, 1}, batchSize, (returnSequence) ? timeSteps : 1,
 
  351     fStateSize(stateSize), fCellSize(stateSize), fTimeSteps(timeSteps), fRememberState(rememberState),
 
  352     fReturnSequence(returnSequence), fF1(
f1), fF2(f2), fInputValue(batchSize, stateSize),
 
  353     fCandidateValue(batchSize, stateSize), fForgetValue(batchSize, stateSize), fOutputValue(batchSize, stateSize),
 
  354     fState(batchSize, stateSize), fCell(batchSize, stateSize), fWeightsInputGate(this->GetWeightsAt(0)),
 
  355     fWeightsInputGateState(this->GetWeightsAt(4)), fInputGateBias(this->GetBiasesAt(0)),
 
  356     fWeightsForgetGate(this->GetWeightsAt(1)), fWeightsForgetGateState(this->GetWeightsAt(5)),
 
  357     fForgetGateBias(this->GetBiasesAt(1)), fWeightsCandidate(this->GetWeightsAt(2)),
 
  358     fWeightsCandidateState(this->GetWeightsAt(6)), fCandidateBias(this->GetBiasesAt(2)),
 
  359     fWeightsOutputGate(this->GetWeightsAt(3)), fWeightsOutputGateState(this->GetWeightsAt(7)),
 
  360     fOutputGateBias(this->GetBiasesAt(3)), fWeightsInputGradients(this->GetWeightGradientsAt(0)),
 
  361     fWeightsInputStateGradients(this->GetWeightGradientsAt(4)), fInputBiasGradients(this->GetBiasGradientsAt(0)),
 
  362     fWeightsForgetGradients(this->GetWeightGradientsAt(1)),
 
  363     fWeightsForgetStateGradients(this->GetWeightGradientsAt(5)), fForgetBiasGradients(this->GetBiasGradientsAt(1)),
 
  364     fWeightsCandidateGradients(this->GetWeightGradientsAt(2)),
 
  365     fWeightsCandidateStateGradients(this->GetWeightGradientsAt(6)),
 
  366     fCandidateBiasGradients(this->GetBiasGradientsAt(2)), fWeightsOutputGradients(this->GetWeightGradientsAt(3)),
 
  367     fWeightsOutputStateGradients(this->GetWeightGradientsAt(7)), fOutputBiasGradients(this->GetBiasGradientsAt(3))
 
  369   for (
size_t i = 0; i < timeSteps; ++i) {
 
  378      cell_value.emplace_back(batchSize, stateSize);
 
  380   Architecture_t::InitializeLSTMTensors(
this);
 
  384template <
typename Architecture_t>
 
  387      fStateSize(layer.fStateSize),
 
  388      fCellSize(layer.fCellSize),
 
  389      fTimeSteps(layer.fTimeSteps),
 
  390      fRememberState(layer.fRememberState),
 
  391      fReturnSequence(layer.fReturnSequence),
 
  392      fF1(layer.GetActivationFunctionF1()),
 
  393      fF2(layer.GetActivationFunctionF2()),
 
  394      fInputValue(layer.GetBatchSize(), layer.GetStateSize()),
 
  395      fCandidateValue(layer.GetBatchSize(), layer.GetStateSize()),
 
  396      fForgetValue(layer.GetBatchSize(), layer.GetStateSize()),
 
  397      fOutputValue(layer.GetBatchSize(), layer.GetStateSize()),
 
  398      fState(layer.GetBatchSize(), layer.GetStateSize()),
 
  399      fCell(layer.GetBatchSize(), layer.GetCellSize()),
 
  400      fWeightsInputGate(this->GetWeightsAt(0)),
 
  401      fWeightsInputGateState(this->GetWeightsAt(4)),
 
  402      fInputGateBias(this->GetBiasesAt(0)),
 
  403      fWeightsForgetGate(this->GetWeightsAt(1)),
 
  404      fWeightsForgetGateState(this->GetWeightsAt(5)),
 
  405      fForgetGateBias(this->GetBiasesAt(1)),
 
  406      fWeightsCandidate(this->GetWeightsAt(2)),
 
  407      fWeightsCandidateState(this->GetWeightsAt(6)),
 
  408      fCandidateBias(this->GetBiasesAt(2)),
 
  409      fWeightsOutputGate(this->GetWeightsAt(3)),
 
  410      fWeightsOutputGateState(this->GetWeightsAt(7)),
 
  411      fOutputGateBias(this->GetBiasesAt(3)),
 
  412      fWeightsInputGradients(this->GetWeightGradientsAt(0)),
 
  413      fWeightsInputStateGradients(this->GetWeightGradientsAt(4)),
 
  414      fInputBiasGradients(this->GetBiasGradientsAt(0)),
 
  415      fWeightsForgetGradients(this->GetWeightGradientsAt(1)),
 
  416      fWeightsForgetStateGradients(this->GetWeightGradientsAt(5)),
 
  417      fForgetBiasGradients(this->GetBiasGradientsAt(1)),
 
  418      fWeightsCandidateGradients(this->GetWeightGradientsAt(2)),
 
  419      fWeightsCandidateStateGradients(this->GetWeightGradientsAt(6)),
 
  420      fCandidateBiasGradients(this->GetBiasGradientsAt(2)),
 
  421      fWeightsOutputGradients(this->GetWeightGradientsAt(3)),
 
  422      fWeightsOutputStateGradients(this->GetWeightGradientsAt(7)),
 
  423      fOutputBiasGradients(this->GetBiasGradientsAt(3))
 
  464   Architecture_t::InitializeLSTMTensors(
this);
 
  468template <
typename Architecture_t>
 
  473   Architecture_t::InitializeLSTMDescriptors(fDescriptors, 
this);
 
  474   Architecture_t::InitializeLSTMWorkspace(fWorkspace, fDescriptors, 
this);
 
  478template <
typename Architecture_t>
 
  486   Matrix_t tmpState(fInputValue.GetNrows(), fInputValue.GetNcols());
 
  487   Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsInputGateState);
 
  488   Architecture_t::MultiplyTranspose(fInputValue, 
input, fWeightsInputGate);
 
  489   Architecture_t::ScaleAdd(fInputValue, tmpState);
 
  490   Architecture_t::AddRowWise(fInputValue, fInputGateBias);
 
  491   DNN::evaluateDerivativeMatrix<Architecture_t>(di, fInp, fInputValue);
 
  492   DNN::evaluateMatrix<Architecture_t>(fInputValue, fInp);
 
  496template <
typename Architecture_t>
 
  504   Matrix_t tmpState(fForgetValue.GetNrows(), fForgetValue.GetNcols());
 
  505   Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsForgetGateState);
 
  506   Architecture_t::MultiplyTranspose(fForgetValue, 
input, fWeightsForgetGate);
 
  507   Architecture_t::ScaleAdd(fForgetValue, tmpState);
 
  508   Architecture_t::AddRowWise(fForgetValue, fForgetGateBias);
 
  509   DNN::evaluateDerivativeMatrix<Architecture_t>(df, fFor, fForgetValue);
 
  510   DNN::evaluateMatrix<Architecture_t>(fForgetValue, fFor);
 
  514template <
typename Architecture_t>
 
  522   Matrix_t tmpState(fCandidateValue.GetNrows(), fCandidateValue.GetNcols());
 
  523   Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsCandidateState);
 
  524   Architecture_t::MultiplyTranspose(fCandidateValue, 
input, fWeightsCandidate);
 
  525   Architecture_t::ScaleAdd(fCandidateValue, tmpState);
 
  526   Architecture_t::AddRowWise(fCandidateValue, fCandidateBias);
 
  527   DNN::evaluateDerivativeMatrix<Architecture_t>(dc, fCan, fCandidateValue);
 
  528   DNN::evaluateMatrix<Architecture_t>(fCandidateValue, fCan);
 
  532template <
typename Architecture_t>
 
  540   Matrix_t tmpState(fOutputValue.GetNrows(), fOutputValue.GetNcols());
 
  541   Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsOutputGateState);
 
  542   Architecture_t::MultiplyTranspose(fOutputValue, 
input, fWeightsOutputGate);
 
  543   Architecture_t::ScaleAdd(fOutputValue, tmpState);
 
  544   Architecture_t::AddRowWise(fOutputValue, fOutputGateBias);
 
  545   DNN::evaluateDerivativeMatrix<Architecture_t>(dout, fOut, fOutputValue);
 
  546   DNN::evaluateMatrix<Architecture_t>(fOutputValue, fOut);
 
  552template <
typename Architecture_t>
 
  558   if (Architecture_t::IsCudnn()) {
 
  561      assert(
input.GetStrides()[1] == this->GetInputSize());
 
  565      Architecture_t::Rearrange(
x, 
input);
 
  567      const auto &weights = this->GetWeightsAt(0);
 
  572      auto &hx = this->fState;
 
  574      auto &cx = this->fCell; 
 
  576      auto &hy = this->fState;
 
  577      auto &cy = this->fCell;
 
  582      Architecture_t::RNNForward(
x, hx, cx, weights, 
y, hy, cy, rnnDesc, rnnWork, isTraining);
 
  584      if (fReturnSequence) {
 
  585         Architecture_t::Rearrange(this->GetOutput(), 
y); 
 
  588         Tensor_t tmp = (
y.At(
y.GetShape()[0] - 1)).Reshape({
y.GetShape()[1], 1, 
y.GetShape()[2]});
 
  589         Architecture_t::Copy(this->GetOutput(), tmp);
 
  602   Tensor_t arrInput( fTimeSteps, this->GetBatchSize(), this->GetInputWidth());
 
  605   Architecture_t::Rearrange(arrInput, 
input); 
 
  607   Tensor_t arrOutput ( fTimeSteps, this->GetBatchSize(), fStateSize);
 
  610   if (!this->fRememberState) {
 
  616   for (
size_t t = 0; t < fTimeSteps; ++t) {
 
  619      InputGate(arrInputMt, fDerivativesInput[t]);
 
  620      ForgetGate(arrInputMt, fDerivativesForget[t]);
 
  621      CandidateValue(arrInputMt, fDerivativesCandidate[t]);
 
  622      OutputGate(arrInputMt, fDerivativesOutput[t]);
 
  624      Architecture_t::Copy(this->GetInputGateTensorAt(t), fInputValue);
 
  625      Architecture_t::Copy(this->GetForgetGateTensorAt(t), fForgetValue);
 
  626      Architecture_t::Copy(this->GetCandidateGateTensorAt(t), fCandidateValue);
 
  627      Architecture_t::Copy(this->GetOutputGateTensorAt(t), fOutputValue);
 
  629      CellForward(fInputValue, fForgetValue, fCandidateValue, fOutputValue);
 
  630      Matrix_t arrOutputMt = arrOutput[t];
 
  631      Architecture_t::Copy(arrOutputMt, fState);
 
  632      Architecture_t::Copy(this->GetCellTensorAt(t), fCell);
 
  637      Architecture_t::Rearrange(this->GetOutput(), arrOutput); 
 
  640      Tensor_t tmp = arrOutput.At(fTimeSteps - 1); 
 
  643      tmp = tmp.Reshape( {tmp.GetShape()[0], tmp.GetShape()[1], 1});
 
  644      assert(tmp.GetSize() == this->GetOutput().GetSize());
 
  645      assert( tmp.GetShape()[0] == this->GetOutput().GetShape()[2]);  
 
  646      Architecture_t::Rearrange(this->GetOutput(), tmp);
 
  653template <
typename Architecture_t>
 
  660   Architecture_t::Hadamard(fCell, forgetGateValues);
 
  661   Architecture_t::Hadamard(inputGateValues, candidateValues);
 
  662   Architecture_t::ScaleAdd(fCell, inputGateValues);
 
  664   Matrix_t cache(fCell.GetNrows(), fCell.GetNcols());
 
  665   Architecture_t::Copy(cache, fCell);
 
  669   DNN::evaluateMatrix<Architecture_t>(cache, fAT);
 
  674   Architecture_t::Copy(fState, cache);
 
  675   Architecture_t::Hadamard(fState, outputGateValues);
 
  679template <
typename Architecture_t>
 
  681                                                      const Tensor_t &activations_backward)   
 
  686   if (Architecture_t::IsCudnn()) {
 
  694      assert(activations_backward.GetStrides()[1] == this->GetInputSize());
 
  696      Architecture_t::Rearrange(
x, activations_backward);
 
  698      if (!fReturnSequence) {
 
  701         Architecture_t::InitializeZero(dy);
 
  706         Tensor_t tmp2 = dy.At(dy.GetShape()[0] - 1).Reshape({dy.GetShape()[1], 1, dy.GetShape()[2]});
 
  709         Architecture_t::Copy(tmp2, this->GetActivationGradients());
 
  711         Architecture_t::Rearrange(
y, this->GetOutput());
 
  712         Architecture_t::Rearrange(dy, this->GetActivationGradients());
 
  718      const auto &weights = this->GetWeightsTensor();
 
  719      auto &weightGradients = this->GetWeightGradientsTensor();
 
  722      Architecture_t::InitializeZero(weightGradients);
 
  725      auto &hx = this->GetState();
 
  726      auto &cx = this->GetCell();
 
  737      Architecture_t::RNNBackward(
x, hx, cx, 
y, dy, dhy, dcy, weights, dx, dhx, dcx, weightGradients, rnnDesc, rnnWork);
 
  741      if (gradients_backward.GetSize() != 0)
 
  742         Architecture_t::Rearrange(gradients_backward, dx);
 
  751   Matrix_t state_gradients_backward(this->GetBatchSize(), fStateSize); 
 
  755   Matrix_t cell_gradients_backward(this->GetBatchSize(), fStateSize); 
 
  760   if (gradients_backward.GetSize() == 0 || gradients_backward[0].GetNrows() == 0 || gradients_backward[0].GetNcols() == 0) {
 
  765   Tensor_t arr_gradients_backward ( fTimeSteps, this->GetBatchSize(), this->GetInputSize());
 
  770   Tensor_t arr_activations_backward ( fTimeSteps, this->GetBatchSize(), this->GetInputSize());
 
  772   Architecture_t::Rearrange(arr_activations_backward, activations_backward); 
 
  776   Tensor_t arr_output (  fTimeSteps, this->GetBatchSize(), fStateSize);
 
  778   Matrix_t initState(this->GetBatchSize(), fCellSize); 
 
  783   Tensor_t arr_actgradients(fTimeSteps, this->GetBatchSize(), fStateSize);
 
  785   if (fReturnSequence) {
 
  786      Architecture_t::Rearrange(arr_output, this->GetOutput());
 
  787      Architecture_t::Rearrange(arr_actgradients, this->GetActivationGradients());
 
  791      Architecture_t::InitializeZero(arr_actgradients);
 
  793      Tensor_t tmp_grad = arr_actgradients.At(fTimeSteps - 1).Reshape( {this->GetBatchSize(), fStateSize, 1});
 
  794      assert(tmp_grad.GetSize() == this->GetActivationGradients().GetSize());
 
  795      assert(tmp_grad.GetShape()[0] == this->GetActivationGradients().GetShape()[2]);  
 
  797      Architecture_t::Rearrange(tmp_grad, this->GetActivationGradients());
 
  804   fWeightsInputGradients.Zero();
 
  805   fWeightsInputStateGradients.Zero();
 
  806   fInputBiasGradients.Zero();
 
  809   fWeightsForgetGradients.Zero();
 
  810   fWeightsForgetStateGradients.Zero();
 
  811   fForgetBiasGradients.Zero();
 
  814   fWeightsCandidateGradients.Zero();
 
  815   fWeightsCandidateStateGradients.Zero();
 
  816   fCandidateBiasGradients.Zero();
 
  819   fWeightsOutputGradients.Zero();
 
  820   fWeightsOutputStateGradients.Zero();
 
  821   fOutputBiasGradients.Zero();
 
  824   for (
size_t t = fTimeSteps; t > 0; t--) {
 
  826      Architecture_t::ScaleAdd(state_gradients_backward, arr_actgradients[t-1]);
 
  828         const Matrix_t &prevStateActivations = arr_output[t-2];
 
  829         const Matrix_t &prevCellActivations = this->GetCellTensorAt(t-2);
 
  831         Matrix_t dx = arr_gradients_backward[t-1];
 
  832         CellBackward(state_gradients_backward, cell_gradients_backward,
 
  833                      prevStateActivations, prevCellActivations,
 
  834                      this->GetInputGateTensorAt(t-1), this->GetForgetGateTensorAt(t-1),
 
  835                      this->GetCandidateGateTensorAt(t-1), this->GetOutputGateTensorAt(t-1),
 
  836                      arr_activations_backward[t-1], dx,
 
  837                      fDerivativesInput[t-1], fDerivativesForget[t-1],
 
  838                      fDerivativesCandidate[t-1], fDerivativesOutput[t-1], t-1);
 
  840         const Matrix_t &prevStateActivations = initState;
 
  841         const Matrix_t &prevCellActivations = initState;
 
  842         Matrix_t dx = arr_gradients_backward[t-1];
 
  843         CellBackward(state_gradients_backward, cell_gradients_backward,
 
  844                      prevStateActivations, prevCellActivations,
 
  845                      this->GetInputGateTensorAt(t-1), this->GetForgetGateTensorAt(t-1),
 
  846                      this->GetCandidateGateTensorAt(t-1), this->GetOutputGateTensorAt(t-1),
 
  847                      arr_activations_backward[t-1], dx,
 
  848                      fDerivativesInput[t-1], fDerivativesForget[t-1],
 
  849                      fDerivativesCandidate[t-1], fDerivativesOutput[t-1], t-1);
 
  854      Architecture_t::Rearrange(gradients_backward, arr_gradients_backward );
 
  861template <
typename Architecture_t>
 
  864                                                          const Matrix_t & precStateActivations, 
const Matrix_t & precCellActivations,
 
  878   Matrix_t cell_gradient(this->GetCellTensorAt(t).GetNrows(), this->GetCellTensorAt(t).GetNcols());
 
  879   DNN::evaluateDerivativeMatrix<Architecture_t>(cell_gradient, fAT, this->GetCellTensorAt(t));
 
  882   Matrix_t cell_tanh(this->GetCellTensorAt(t).GetNrows(), this->GetCellTensorAt(t).GetNcols());
 
  883   Architecture_t::Copy(cell_tanh, this->GetCellTensorAt(t));
 
  884   DNN::evaluateMatrix<Architecture_t>(cell_tanh, fAT);
 
  886   return Architecture_t::LSTMLayerBackward(state_gradients_backward, cell_gradients_backward,
 
  887                                            fWeightsInputGradients, fWeightsForgetGradients, fWeightsCandidateGradients,
 
  888                                            fWeightsOutputGradients, fWeightsInputStateGradients, fWeightsForgetStateGradients,
 
  889                                            fWeightsCandidateStateGradients, fWeightsOutputStateGradients, fInputBiasGradients, fForgetBiasGradients,
 
  890                                            fCandidateBiasGradients, fOutputBiasGradients, di, df, dc, dout,
 
  891                                            precStateActivations, precCellActivations,
 
  892                                            input_gate, forget_gate, candidate_gate, output_gate,
 
  893                                            fWeightsInputGate, fWeightsForgetGate, fWeightsCandidate, fWeightsOutputGate,
 
  894                                            fWeightsInputGateState, fWeightsForgetGateState, fWeightsCandidateState,
 
  895                                            fWeightsOutputGateState, 
input, input_gradient,
 
  896                                            cell_gradient, cell_tanh);
 
  900template <
typename Architecture_t>
 
  909template<
typename Architecture_t>
 
  913   std::cout << 
" LSTM Layer: \t ";
 
  914   std::cout << 
" (NInput = " << this->GetInputSize();  
 
  915   std::cout << 
", NState = " << this->GetStateSize();  
 
  916   std::cout << 
", NTime  = " << this->GetTimeSteps() << 
" )";  
 
  917   std::cout << 
"\tOutput = ( " << this->GetOutput().GetFirstSize() << 
" , " << this->GetOutput()[0].GetNrows() << 
" , " << this->GetOutput()[0].GetNcols() << 
" )\n";
 
  921template <
typename Architecture_t>
 
  936   this->WriteMatrixToXML(layerxml, 
"InputWeights", this->GetWeightsAt(0));
 
  937   this->WriteMatrixToXML(layerxml, 
"InputStateWeights", this->GetWeightsAt(1));
 
  938   this->WriteMatrixToXML(layerxml, 
"InputBiases", this->GetBiasesAt(0));
 
  939   this->WriteMatrixToXML(layerxml, 
"ForgetWeights", this->GetWeightsAt(2));
 
  940   this->WriteMatrixToXML(layerxml, 
"ForgetStateWeights", this->GetWeightsAt(3));
 
  941   this->WriteMatrixToXML(layerxml, 
"ForgetBiases", this->GetBiasesAt(1));
 
  942   this->WriteMatrixToXML(layerxml, 
"CandidateWeights", this->GetWeightsAt(4));
 
  943   this->WriteMatrixToXML(layerxml, 
"CandidateStateWeights", this->GetWeightsAt(5));
 
  944   this->WriteMatrixToXML(layerxml, 
"CandidateBiases", this->GetBiasesAt(2));
 
  945   this->WriteMatrixToXML(layerxml, 
"OuputWeights", this->GetWeightsAt(6));
 
  946   this->WriteMatrixToXML(layerxml, 
"OutputStateWeights", this->GetWeightsAt(7));
 
  947   this->WriteMatrixToXML(layerxml, 
"OutputBiases", this->GetBiasesAt(3));
 
  951template <
typename Architecture_t>
 
  956   this->ReadMatrixXML(parent, 
"InputWeights", this->GetWeightsAt(0));
 
  957   this->ReadMatrixXML(parent, 
"InputStateWeights", this->GetWeightsAt(1));
 
  958   this->ReadMatrixXML(parent, 
"InputBiases", this->GetBiasesAt(0));
 
  959   this->ReadMatrixXML(parent, 
"ForgetWeights", this->GetWeightsAt(2));
 
  960   this->ReadMatrixXML(parent, 
"ForgetStateWeights", this->GetWeightsAt(3));
 
  961   this->ReadMatrixXML(parent, 
"ForgetBiases", this->GetBiasesAt(1));
 
  962   this->ReadMatrixXML(parent, 
"CandidateWeights", this->GetWeightsAt(4));
 
  963   this->ReadMatrixXML(parent, 
"CandidateStateWeights", this->GetWeightsAt(5));
 
  964   this->ReadMatrixXML(parent, 
"CandidateBiases", this->GetBiasesAt(2));
 
  965   this->ReadMatrixXML(parent, 
"OuputWeights", this->GetWeightsAt(6));
 
  966   this->ReadMatrixXML(parent, 
"OutputStateWeights", this->GetWeightsAt(7));
 
  967   this->ReadMatrixXML(parent, 
"OutputBiases", this->GetBiasesAt(3));
 
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
void InputGate(const Matrix_t &input, Matrix_t &di)
Decides the values we'll update (NN with Sigmoid)
const Matrix_t & GetForgetGateTensorAt(size_t i) const
Matrix_t & GetWeightsOutputGateState()
const std::vector< Matrix_t > & GetOutputGateTensor() const
Tensor_t fWeightsTensor
Tensor for all weights.
const std::vector< Matrix_t > & GetInputGateTensor() const
std::vector< Matrix_t > & GetDerivativesOutput()
const Matrix_t & GetWeigthsForgetStateGradients() const
Matrix_t & GetWeightsForgetGate()
typename Architecture_t::Matrix_t Matrix_t
Matrix_t & GetCandidateGateTensorAt(size_t i)
void InitState(DNN::EInitialization m=DNN::EInitialization::kZero)
Initialize the hidden state and cell state method.
Matrix_t & fWeightsCandidateGradients
Gradients w.r.t the candidate gate - input weights.
const Matrix_t & GetOutputGateBias() const
Matrix_t & GetWeightsCandidateStateGradients()
Matrix_t & GetWeightsInputGate()
Matrix_t & GetWeightsInputGateState()
const std::vector< Matrix_t > & GetCandidateGateTensor() const
const Matrix_t & GetInputGateTensorAt(size_t i) const
std::vector< Matrix_t > & GetForgetGateTensor()
std::vector< Matrix_t > cell_value
cell value for every time step
Matrix_t & fWeightsOutputGradients
Gradients w.r.t the output gate - input weights.
Matrix_t & GetOutputGateBias()
Matrix_t & fOutputBiasGradients
Gradients w.r.t the output gate - bias weights.
DNN::EActivationFunction fF1
Activation function: sigmoid.
virtual void Initialize()
Initialize the weights according to the given initialization method.
Tensor_t fDy
cached activation gradient (input of backward) as T x B x S
Matrix_t & fWeightsOutputGate
Output Gate weights for input, fWeights[3].
Matrix_t & GetForgetGateBias()
Matrix_t & fWeightsCandidateStateGradients
Gradients w.r.t the candidate gate - hidden state weights.
void Forward(Tensor_t &input, bool isTraining=true)
Computes the next hidden state and next cell state with given input matrix.
const Matrix_t & GetInputGateBias() const
typename Architecture_t::Scalar_t Scalar_t
size_t GetInputSize() const
Getters.
Matrix_t & GetForgetGateTensorAt(size_t i)
const Matrix_t & GetOutputGateTensorAt(size_t i) const
const Matrix_t & GetCellTensorAt(size_t i) const
Tensor_t fX
cached input tensor as T x B x I
DNN::EActivationFunction GetActivationFunctionF2() const
Matrix_t & GetCellTensorAt(size_t i)
Matrix_t & fWeightsInputStateGradients
Gradients w.r.t the input gate - hidden state weights.
void CellForward(Matrix_t &inputGateValues, const Matrix_t &forgetGateValues, const Matrix_t &candidateValues, const Matrix_t &outputGateValues)
Forward for a single cell (time unit)
Matrix_t & CellBackward(Matrix_t &state_gradients_backward, Matrix_t &cell_gradients_backward, const Matrix_t &precStateActivations, const Matrix_t &precCellActivations, const Matrix_t &input_gate, const Matrix_t &forget_gate, const Matrix_t &candidate_gate, const Matrix_t &output_gate, const Matrix_t &input, Matrix_t &input_gradient, Matrix_t &di, Matrix_t &df, Matrix_t &dc, Matrix_t &dout, size_t t)
Backward for a single time unit at the corresponding call to Forward(...).
const Matrix_t & GetWeightsInputStateGradients() const
std::vector< Matrix_t > fDerivativesOutput
First fDerivatives of the activations output gate.
size_t GetStateSize() const
Matrix_t & fWeightsForgetGateState
Forget Gate weights for prev state, fWeights[5].
Matrix_t & fOutputGateBias
Output Gate bias.
std::vector< Matrix_t > fDerivativesCandidate
First fDerivatives of the activations candidate gate.
const Matrix_t & GetInputDerivativesAt(size_t i) const
Matrix_t & fWeightsForgetGate
Forget Gate weights for input, fWeights[1].
Matrix_t & fWeightsInputGradients
Gradients w.r.t the input gate - input weights.
typename Architecture_t::Tensor_t Tensor_t
const std::vector< Matrix_t > & GetDerivativesInput() const
Matrix_t & GetWeightsCandidate()
Matrix_t & fForgetGateBias
Forget Gate bias.
Matrix_t & GetWeightsInputGradients()
Matrix_t & GetCandidateBiasGradients()
Matrix_t & GetWeightsOutputGradients()
Matrix_t & fCandidateBias
Candidate Gate bias.
Matrix_t fCandidateValue
Computed candidate values.
Tensor_t & GetWeightGradientsTensor()
bool DoesRememberState() const
const Matrix_t & GetWeightsOutputGradients() const
typename Architecture_t::RecurrentDescriptor_t LayerDescriptor_t
const Matrix_t & GetWeightsInputGradients() const
Matrix_t & GetWeightsCandidateState()
Matrix_t & GetInputBiasGradients()
const Matrix_t & GetInputBiasGradients() const
size_t GetTimeSteps() const
DNN::EActivationFunction fF2
Activation function: tanh.
Matrix_t & fInputBiasGradients
Gradients w.r.t the input gate - bias weights.
Matrix_t & GetWeightsOutputStateGradients()
Matrix_t & fWeightsCandidateState
Candidate Gate weights for prev state, fWeights[6].
void AddWeightsXMLTo(void *parent)
Writes the information and the weights about the layer in an XML node.
Matrix_t & GetForgetGateValue()
std::vector< Matrix_t > fDerivativesForget
First fDerivatives of the activations forget gate.
const Tensor_t & GetWeightGradientsTensor() const
Matrix_t & GetForgetDerivativesAt(size_t i)
const Matrix_t & GetWeightsInputGateState() const
Matrix_t & GetWeightsInputStateGradients()
typename Architecture_t::DropoutDescriptor_t HelperDescriptor_t
Matrix_t & fForgetBiasGradients
Gradients w.r.t the forget gate - bias weights.
const Matrix_t & GetCandidateBias() const
std::vector< Matrix_t > output_gate_value
output gate value for every time step
const std::vector< Matrix_t > & GetDerivativesCandidate() const
size_t fStateSize
Hidden state size for LSTM.
void CandidateValue(const Matrix_t &input, Matrix_t &dc)
Decides the new candidate values (NN with Tanh)
std::vector< Matrix_t > fDerivativesInput
First fDerivatives of the activations input gate.
const Matrix_t & GetWeightsForgetGateState() const
Matrix_t & GetWeightsForgetGateState()
const Matrix_t & GetWeightsInputGate() const
const Matrix_t & GetInputGateValue() const
void Update(const Scalar_t learningRate)
bool DoesReturnSequence() const
Tensor_t fDx
cached gradient on the input (output of backward) as T x B x I
typename Architecture_t::RNNWorkspace_t RNNWorkspace_t
Matrix_t & GetOutputGateValue()
TBasicLSTMLayer(size_t batchSize, size_t stateSize, size_t inputSize, size_t timeSteps, bool rememberState=false, bool returnSequence=false, DNN::EActivationFunction f1=DNN::EActivationFunction::kSigmoid, DNN::EActivationFunction f2=DNN::EActivationFunction::kTanh, bool training=true, DNN::EInitialization fA=DNN::EInitialization::kZero)
Constructor.
Matrix_t & GetWeightsForgetStateGradients()
const Matrix_t & GetOutputBiasGradients() const
typename Architecture_t::TensorDescriptor_t TensorDescriptor_t
const Matrix_t & GetWeightsOutputStateGradients() const
Matrix_t & fWeightsOutputStateGradients
Gradients w.r.t the output gate - hidden state weights.
bool fReturnSequence
Return in output full sequence or just last element.
Matrix_t & GetWeightsForgetGradients()
Matrix_t & GetWeightsCandidateGradients()
const Matrix_t & GetWeightsForgetGradients() const
Matrix_t fCell
Cell state of LSTM.
std::vector< Matrix_t > & GetDerivativesCandidate()
const Matrix_t & GetForgetBiasGradients() const
std::vector< Matrix_t > & GetOutputGateTensor()
Matrix_t & GetCandidateValue()
const Matrix_t & GetForgetDerivativesAt(size_t i) const
Matrix_t fState
Hidden state of LSTM.
void OutputGate(const Matrix_t &input, Matrix_t &dout)
Computes output values (NN with Sigmoid)
const Matrix_t & GetForgetGateValue() const
std::vector< Matrix_t > candidate_gate_value
candidate gate value for every time step
Matrix_t & GetInputGateValue()
const Matrix_t & GetState() const
const Matrix_t & GetWeightsCandidateState() const
Matrix_t & GetCandidateBias()
const std::vector< Matrix_t > & GetForgetGateTensor() const
const std::vector< Matrix_t > & GetDerivativesOutput() const
const std::vector< Matrix_t > & GetCellTensor() const
const Tensor_t & GetWeightsTensor() const
Matrix_t & fWeightsInputGate
Input Gate weights for input, fWeights[0].
std::vector< Matrix_t > & GetCandidateGateTensor()
const Matrix_t & GetOutputDerivativesAt(size_t i) const
void ReadWeightsFromXML(void *parent)
Read the information and the weights about the layer from XML node.
const Matrix_t & GetCell() const
Matrix_t & fWeightsForgetStateGradients
Gradients w.r.t the forget gate - hidden state weights.
const Matrix_t & GetCandidateGateTensorAt(size_t i) const
Matrix_t fOutputValue
Computed output gate values.
size_t fCellSize
Cell state size of LSTM.
Matrix_t & GetOutputDerivativesAt(size_t i)
Matrix_t & GetInputGateTensorAt(size_t i)
std::vector< Matrix_t > & GetDerivativesInput()
Matrix_t & fWeightsOutputGateState
Output Gate weights for prev state, fWeights[7].
const std::vector< Matrix_t > & GetDerivativesForget() const
Matrix_t & GetForgetBiasGradients()
const Matrix_t & GetForgetGateBias() const
const Matrix_t & GetCandidateDerivativesAt(size_t i) const
Matrix_t & GetInputGateBias()
Matrix_t & GetOutputGateTensorAt(size_t i)
size_t fTimeSteps
Timesteps for LSTM.
const Matrix_t & GetCandidateBiasGradients() const
const Matrix_t & GetCandidateValue() const
typename Architecture_t::FilterDescriptor_t WeightsDescriptor_t
Matrix_t & fInputGateBias
Input Gate bias.
const Matrix_t & GetWeightsForgetGate() const
std::vector< Matrix_t > input_gate_value
input gate value for every time step
const Matrix_t & GetWeightsCandidateStateGradients() const
Tensor_t & GetWeightsTensor()
Matrix_t & fWeightsForgetGradients
Gradients w.r.t the forget gate - input weights.
std::vector< Matrix_t > & GetDerivativesForget()
const Matrix_t & GetWeightsOutputGate() const
void ForgetGate(const Matrix_t &input, Matrix_t &df)
Forgets the past values (NN with Sigmoid)
std::vector< Matrix_t > & GetInputGateTensor()
Matrix_t & GetOutputBiasGradients()
const Matrix_t & GetOutputGateValue() const
const Matrix_t & GetWeightsOutputGateState() const
Matrix_t & GetCandidateDerivativesAt(size_t i)
Matrix_t fInputValue
Computed input gate values.
Matrix_t & GetWeightsOutputGate()
const Matrix_t & GetWeightsCandidate() const
void Print() const
Prints the info about the layer.
void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward)
Backpropagates the error.
const Matrix_t & GetWeightsCandidateGradients() const
Tensor_t fWeightGradientsTensor
Tensor for all weight gradients.
Matrix_t & GetInputDerivativesAt(size_t i)
typename Architecture_t::RNNDescriptors_t RNNDescriptors_t
DNN::EActivationFunction GetActivationFunctionF1() const
Tensor_t fY
cached output tensor as T x B x S
std::vector< Matrix_t > forget_gate_value
forget gate value for every time step
Matrix_t & fWeightsCandidate
Candidate Gate weights for input, fWeights[2].
bool fRememberState
Remember state in next pass.
Matrix_t & fWeightsInputGateState
Input Gate weights for prev state, fWeights[4].
TDescriptors * fDescriptors
Keeps all the RNN descriptors.
std::vector< Matrix_t > & GetCellTensor()
size_t GetCellSize() const
Matrix_t & fCandidateBiasGradients
Gradients w.r.t the candidate gate - bias weights.
Matrix_t fForgetValue
Computed forget gate values.
Generic General Layer class.
virtual void Initialize()
Initialize the weights and biases according to the given initialization method.
size_t GetBatchSize() const
Getters.
size_t GetInputWidth() const
XMLNodePointer_t NewChild(XMLNodePointer_t parent, XMLNsPointer_t ns, const char *name, const char *content=nullptr)
create new child element for parent node
XMLAttrPointer_t NewAttr(XMLNodePointer_t xmlnode, XMLNsPointer_t, const char *name, const char *value)
creates new attribute for xmlnode, namespaces are not supported for attributes
EActivationFunction
Enum that represents layer activation functions.
create variable transformations