30#ifndef TMVA_DNN_GRU_LAYER 
   31#define TMVA_DNN_GRU_LAYER 
   55template<
typename Architecture_t>
 
   61   using Matrix_t = 
typename Architecture_t::Matrix_t;
 
   62   using Scalar_t = 
typename Architecture_t::Scalar_t;
 
   63   using Tensor_t = 
typename Architecture_t::Tensor_t;
 
  140   TBasicGRULayer(
size_t batchSize, 
size_t stateSize, 
size_t inputSize,
 
  141                   size_t timeSteps, 
bool rememberState = 
false, 
bool returnSequence = 
false,
 
  142                   bool resetGateAfter = 
false,
 
  167                 const Tensor_t &activations_backward);
 
  175                           const Matrix_t & precStateActivations,
 
  307template <
typename Architecture_t>
 
  312   : 
VGeneralLayer<Architecture_t>(batchSize, 1, timeSteps, inputSize, 1, (returnSequence) ? timeSteps : 1, stateSize,
 
  313                                   6, {stateSize, stateSize, stateSize, stateSize, stateSize, stateSize},
 
  314                                   {inputSize, inputSize, inputSize, stateSize, stateSize, stateSize}, 3,
 
  315                                   {stateSize, stateSize, stateSize}, {1, 1, 1}, batchSize,
 
  316                                   (returnSequence) ? timeSteps : 1, stateSize, fA),
 
  317     fStateSize(stateSize), fTimeSteps(timeSteps), fRememberState(rememberState), fReturnSequence(returnSequence), fResetGateAfter(resetGateAfter),
 
  318     fF1(
f1), fF2(f2), fResetValue(batchSize, stateSize), fUpdateValue(batchSize, stateSize),
 
  319     fCandidateValue(batchSize, stateSize), fState(batchSize, stateSize), fWeightsResetGate(this->GetWeightsAt(0)),
 
  320     fWeightsResetGateState(this->GetWeightsAt(3)), fResetGateBias(this->GetBiasesAt(0)),
 
  321     fWeightsUpdateGate(this->GetWeightsAt(1)), fWeightsUpdateGateState(this->GetWeightsAt(4)),
 
  322     fUpdateGateBias(this->GetBiasesAt(1)), fWeightsCandidate(this->GetWeightsAt(2)),
 
  323     fWeightsCandidateState(this->GetWeightsAt(5)), fCandidateBias(this->GetBiasesAt(2)),
 
  324     fWeightsResetGradients(this->GetWeightGradientsAt(0)), fWeightsResetStateGradients(this->GetWeightGradientsAt(3)),
 
  325     fResetBiasGradients(this->GetBiasGradientsAt(0)), fWeightsUpdateGradients(this->GetWeightGradientsAt(1)),
 
  326     fWeightsUpdateStateGradients(this->GetWeightGradientsAt(4)), fUpdateBiasGradients(this->GetBiasGradientsAt(1)),
 
  327     fWeightsCandidateGradients(this->GetWeightGradientsAt(2)),
 
  328     fWeightsCandidateStateGradients(this->GetWeightGradientsAt(5)),
 
  329     fCandidateBiasGradients(this->GetBiasGradientsAt(2))
 
  331   for (
size_t i = 0; i < timeSteps; ++i) {
 
  339   Architecture_t::InitializeGRUTensors(
this);
 
  343template <
typename Architecture_t>
 
  346      fStateSize(layer.fStateSize),
 
  347      fTimeSteps(layer.fTimeSteps),
 
  348      fRememberState(layer.fRememberState),
 
  349      fReturnSequence(layer.fReturnSequence),
 
  350      fResetGateAfter(layer.fResetGateAfter),
 
  351      fF1(layer.GetActivationFunctionF1()),
 
  352      fF2(layer.GetActivationFunctionF2()),
 
  353      fResetValue(layer.GetBatchSize(), layer.GetStateSize()),
 
  354      fUpdateValue(layer.GetBatchSize(), layer.GetStateSize()),
 
  355      fCandidateValue(layer.GetBatchSize(), layer.GetStateSize()),
 
  356      fState(layer.GetBatchSize(), layer.GetStateSize()),
 
  357      fWeightsResetGate(this->GetWeightsAt(0)),
 
  358      fWeightsResetGateState(this->GetWeightsAt(3)),
 
  359      fResetGateBias(this->GetBiasesAt(0)),
 
  360      fWeightsUpdateGate(this->GetWeightsAt(1)),
 
  361      fWeightsUpdateGateState(this->GetWeightsAt(4)),
 
  362      fUpdateGateBias(this->GetBiasesAt(1)),
 
  363      fWeightsCandidate(this->GetWeightsAt(2)),
 
  364      fWeightsCandidateState(this->GetWeightsAt(5)),
 
  365      fCandidateBias(this->GetBiasesAt(2)),
 
  366      fWeightsResetGradients(this->GetWeightGradientsAt(0)),
 
  367      fWeightsResetStateGradients(this->GetWeightGradientsAt(3)),
 
  368      fResetBiasGradients(this->GetBiasGradientsAt(0)),
 
  369      fWeightsUpdateGradients(this->GetWeightGradientsAt(1)),
 
  370      fWeightsUpdateStateGradients(this->GetWeightGradientsAt(4)),
 
  371      fUpdateBiasGradients(this->GetBiasGradientsAt(1)),
 
  372      fWeightsCandidateGradients(this->GetWeightGradientsAt(2)),
 
  373      fWeightsCandidateStateGradients(this->GetWeightGradientsAt(5)),
 
  374      fCandidateBiasGradients(this->GetBiasGradientsAt(2))
 
  404   Architecture_t::InitializeGRUTensors(
this);
 
  408template <
typename Architecture_t>
 
  413   Architecture_t::InitializeGRUDescriptors(fDescriptors, 
this);
 
  414   Architecture_t::InitializeGRUWorkspace(fWorkspace, fDescriptors, 
this);
 
  417   if (Architecture_t::IsCudnn())
 
  418      fResetGateAfter = 
true;
 
  422template <
typename Architecture_t>
 
  430   Matrix_t tmpState(fResetValue.GetNrows(), fResetValue.GetNcols());
 
  431   Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsResetGateState);
 
  432   Architecture_t::MultiplyTranspose(fResetValue, 
input, fWeightsResetGate);
 
  433   Architecture_t::ScaleAdd(fResetValue, tmpState);
 
  434   Architecture_t::AddRowWise(fResetValue, fResetGateBias);
 
  435   DNN::evaluateDerivativeMatrix<Architecture_t>(dr, fRst, fResetValue);
 
  436   DNN::evaluateMatrix<Architecture_t>(fResetValue, fRst);
 
  440template <
typename Architecture_t>
 
  448   Matrix_t tmpState(fUpdateValue.GetNrows(), fUpdateValue.GetNcols());
 
  449   Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsUpdateGateState);
 
  450   Architecture_t::MultiplyTranspose(fUpdateValue, 
input, fWeightsUpdateGate);
 
  451   Architecture_t::ScaleAdd(fUpdateValue, tmpState);
 
  452   Architecture_t::AddRowWise(fUpdateValue, fUpdateGateBias);
 
  453   DNN::evaluateDerivativeMatrix<Architecture_t>(du, fUpd, fUpdateValue);
 
  454   DNN::evaluateMatrix<Architecture_t>(fUpdateValue, fUpd);
 
  458template <
typename Architecture_t>
 
  475   Matrix_t tmp(fCandidateValue.GetNrows(), fCandidateValue.GetNcols());
 
  476   if (!fResetGateAfter) {
 
  478      Architecture_t::Hadamard(tmpState, fState);
 
  479      Architecture_t::MultiplyTranspose(tmp, tmpState, fWeightsCandidateState);
 
  482      Architecture_t::MultiplyTranspose(tmp, fState, fWeightsCandidateState);
 
  483      Architecture_t::Hadamard(tmp, fResetValue);
 
  485   Architecture_t::MultiplyTranspose(fCandidateValue, 
input, fWeightsCandidate);
 
  486   Architecture_t::ScaleAdd(fCandidateValue, tmp);
 
  487   Architecture_t::AddRowWise(fCandidateValue, fCandidateBias);
 
  488   DNN::evaluateDerivativeMatrix<Architecture_t>(dc, fCan, fCandidateValue);
 
  489   DNN::evaluateMatrix<Architecture_t>(fCandidateValue, fCan);
 
  493template <
typename Architecture_t>
 
  498   if (Architecture_t::IsCudnn()) {
 
  501      assert(
input.GetStrides()[1] == this->GetInputSize());
 
  505      Architecture_t::Rearrange(
x, 
input);
 
  507      const auto &weights = this->GetWeightsAt(0);
 
  509      auto &hx = this->fState;
 
  510      auto &cx = this->fCell;
 
  512      auto &hy = this->fState;
 
  513      auto &cy = this->fCell;
 
  518      Architecture_t::RNNForward(
x, hx, cx, weights, 
y, hy, cy, rnnDesc, rnnWork, isTraining);
 
  520      if (fReturnSequence) {
 
  521         Architecture_t::Rearrange(this->GetOutput(), 
y); 
 
  524         Tensor_t tmp = (
y.At(
y.GetShape()[0] - 1)).Reshape({
y.GetShape()[1], 1, 
y.GetShape()[2]});
 
  525         Architecture_t::Copy(this->GetOutput(), tmp);
 
  536   Tensor_t arrInput ( fTimeSteps, this->GetBatchSize(), this->GetInputWidth());
 
  540   Architecture_t::Rearrange(arrInput, 
input); 
 
  542   Tensor_t arrOutput ( fTimeSteps, this->GetBatchSize(), fStateSize );
 
  547   if (!this->fRememberState) {
 
  553   for (
size_t t = 0; t < fTimeSteps; ++t) {
 
  555      ResetGate(arrInput[t], fDerivativesReset[t]);
 
  556      Architecture_t::Copy(this->GetResetGateTensorAt(t), fResetValue);
 
  557      UpdateGate(arrInput[t], fDerivativesUpdate[t]);
 
  558      Architecture_t::Copy(this->GetUpdateGateTensorAt(t), fUpdateValue);
 
  560      CandidateValue(arrInput[t], fDerivativesCandidate[t]);
 
  561      Architecture_t::Copy(this->GetCandidateGateTensorAt(t), fCandidateValue);
 
  564      CellForward(fUpdateValue, fCandidateValue);
 
  568      Matrix_t arrOutputMt = arrOutput[t];
 
  569      Architecture_t::Copy(arrOutputMt, fState);
 
  573      Architecture_t::Rearrange(this->GetOutput(), arrOutput); 
 
  576      Tensor_t tmp = arrOutput.At(fTimeSteps - 1); 
 
  579      tmp = tmp.Reshape({tmp.GetShape()[0], tmp.GetShape()[1], 1});
 
  580      assert(tmp.GetSize() == this->GetOutput().GetSize());
 
  581      assert(tmp.GetShape()[0] == this->GetOutput().GetShape()[2]); 
 
  582      Architecture_t::Rearrange(this->GetOutput(), tmp);
 
  589template <
typename Architecture_t>
 
  593   Architecture_t::Hadamard(fState, updateGateValues);
 
  597   for (
size_t j = 0; j < (size_t) tmp.GetNcols(); j++) {
 
  598      for (
size_t i = 0; i < (size_t) tmp.GetNrows(); i++) {
 
  599         tmp(i,j) = 1 - tmp(i,j);
 
  604   Architecture_t::Hadamard(candidateValues, tmp);
 
  605   Architecture_t::ScaleAdd(fState, candidateValues);
 
  609template <
typename Architecture_t>
 
  611                                                      const Tensor_t &activations_backward)   
 
  615   if (Architecture_t::IsCudnn()) {
 
  623      assert(activations_backward.GetStrides()[1] == this->GetInputSize());
 
  626      Architecture_t::Rearrange(
x, activations_backward);
 
  628      if (!fReturnSequence) {
 
  631         Architecture_t::InitializeZero(dy);
 
  634         Tensor_t tmp2 = dy.At(dy.GetShape()[0] - 1).Reshape({dy.GetShape()[1], 1, dy.GetShape()[2]});
 
  637         Architecture_t::Copy(tmp2, this->GetActivationGradients());
 
  639         Architecture_t::Rearrange(
y, this->GetOutput());
 
  640         Architecture_t::Rearrange(dy, this->GetActivationGradients());
 
  646      const auto &weights = this->GetWeightsTensor();
 
  647      auto &weightGradients = this->GetWeightGradientsTensor();
 
  651      Architecture_t::InitializeZero(weightGradients);
 
  654      auto &hx = this->GetState();
 
  655      auto &cx = this->GetCell();
 
  665      Architecture_t::RNNBackward(
x, hx, cx, 
y, dy, dhy, dcy, weights, dx, dhx, dcx, weightGradients, rnnDesc, rnnWork);
 
  669      if (gradients_backward.GetSize() != 0)
 
  670         Architecture_t::Rearrange(gradients_backward, dx);
 
  678   Matrix_t state_gradients_backward(this->GetBatchSize(), fStateSize); 
 
  683   if (gradients_backward.GetSize() == 0 || gradients_backward[0].GetNrows() == 0 || gradients_backward[0].GetNcols() == 0) {
 
  687   Tensor_t arr_gradients_backward ( fTimeSteps, this->GetBatchSize(), this->GetInputSize());
 
  692   Tensor_t arr_activations_backward ( fTimeSteps, this->GetBatchSize(), this->GetInputSize());
 
  694   Architecture_t::Rearrange(arr_activations_backward, activations_backward); 
 
  698   Tensor_t arr_output ( fTimeSteps, this->GetBatchSize(), fStateSize);
 
  700   Matrix_t initState(this->GetBatchSize(), fStateSize); 
 
  704   Tensor_t arr_actgradients ( fTimeSteps, this->GetBatchSize(), fStateSize);
 
  706   if (fReturnSequence) {
 
  707      Architecture_t::Rearrange(arr_output, this->GetOutput());
 
  708      Architecture_t::Rearrange(arr_actgradients, this->GetActivationGradients());
 
  712      Architecture_t::InitializeZero(arr_actgradients);
 
  714      Tensor_t tmp_grad = arr_actgradients.At(fTimeSteps - 1).Reshape({this->GetBatchSize(), fStateSize, 1});
 
  715      assert(tmp_grad.GetSize() == this->GetActivationGradients().GetSize());
 
  716      assert(tmp_grad.GetShape()[0] ==
 
  717             this->GetActivationGradients().GetShape()[2]); 
 
  719      Architecture_t::Rearrange(tmp_grad, this->GetActivationGradients());
 
  726   fWeightsResetGradients.Zero();
 
  727   fWeightsResetStateGradients.Zero();
 
  728   fResetBiasGradients.Zero();
 
  731   fWeightsUpdateGradients.Zero();
 
  732   fWeightsUpdateStateGradients.Zero();
 
  733   fUpdateBiasGradients.Zero();
 
  736   fWeightsCandidateGradients.Zero();
 
  737   fWeightsCandidateStateGradients.Zero();
 
  738   fCandidateBiasGradients.Zero();
 
  741   for (
size_t t = fTimeSteps; t > 0; t--) {
 
  743      Architecture_t::ScaleAdd(state_gradients_backward, arr_actgradients[t-1]);
 
  745         const Matrix_t &prevStateActivations = arr_output[t-2];
 
  746         Matrix_t dx = arr_gradients_backward[t-1];
 
  748         CellBackward(state_gradients_backward, prevStateActivations,
 
  749                      this->GetResetGateTensorAt(t-1), this->GetUpdateGateTensorAt(t-1),
 
  750                      this->GetCandidateGateTensorAt(t-1),
 
  751                      arr_activations_backward[t-1], dx ,
 
  752                      fDerivativesReset[t-1], fDerivativesUpdate[t-1],
 
  753                      fDerivativesCandidate[t-1]);
 
  755         const Matrix_t &prevStateActivations = initState;
 
  756         Matrix_t dx = arr_gradients_backward[t-1];
 
  757         CellBackward(state_gradients_backward, prevStateActivations,
 
  758                      this->GetResetGateTensorAt(t-1), this->GetUpdateGateTensorAt(t-1),
 
  759                      this->GetCandidateGateTensorAt(t-1),
 
  760                      arr_activations_backward[t-1], dx ,
 
  761                      fDerivativesReset[t-1], fDerivativesUpdate[t-1],
 
  762                      fDerivativesCandidate[t-1]);
 
  767      Architecture_t::Rearrange(gradients_backward, arr_gradients_backward );
 
  774template <
typename Architecture_t>
 
  776                                                          const Matrix_t & precStateActivations,
 
  785   return Architecture_t::GRULayerBackward(state_gradients_backward,
 
  786                                           fWeightsResetGradients, fWeightsUpdateGradients, fWeightsCandidateGradients,
 
  787                                           fWeightsResetStateGradients, fWeightsUpdateStateGradients,
 
  788                                           fWeightsCandidateStateGradients, fResetBiasGradients, fUpdateBiasGradients,
 
  789                                           fCandidateBiasGradients, dr, du, dc,
 
  790                                           precStateActivations,
 
  791                                           reset_gate, update_gate, candidate_gate,
 
  792                                           fWeightsResetGate, fWeightsUpdateGate, fWeightsCandidate,
 
  793                                           fWeightsResetGateState, fWeightsUpdateGateState, fWeightsCandidateState,
 
  794                                           input, input_gradient, fResetGateAfter);
 
  799template <
typename Architecture_t>
 
  807template<
typename Architecture_t>
 
  811   std::cout << 
" GRU Layer: \t ";
 
  812   std::cout << 
" (NInput = " << this->GetInputSize();  
 
  813   std::cout << 
", NState = " << this->GetStateSize();  
 
  814   std::cout << 
", NTime  = " << this->GetTimeSteps() << 
" )";  
 
  815   std::cout << 
"\tOutput = ( " << this->GetOutput().GetFirstSize() << 
" , " << this->GetOutput()[0].GetNrows() << 
" , " << this->GetOutput()[0].GetNcols() << 
" )\n";
 
  819template <
typename Architecture_t>
 
  834   this->WriteMatrixToXML(layerxml, 
"ResetWeights", this->GetWeightsAt(0));
 
  835   this->WriteMatrixToXML(layerxml, 
"ResetStateWeights", this->GetWeightsAt(1));
 
  836   this->WriteMatrixToXML(layerxml, 
"ResetBiases", this->GetBiasesAt(0));
 
  837   this->WriteMatrixToXML(layerxml, 
"UpdateWeights", this->GetWeightsAt(2));
 
  838   this->WriteMatrixToXML(layerxml, 
"UpdateStateWeights", this->GetWeightsAt(3));
 
  839   this->WriteMatrixToXML(layerxml, 
"UpdateBiases", this->GetBiasesAt(1));
 
  840   this->WriteMatrixToXML(layerxml, 
"CandidateWeights", this->GetWeightsAt(4));
 
  841   this->WriteMatrixToXML(layerxml, 
"CandidateStateWeights", this->GetWeightsAt(5));
 
  842   this->WriteMatrixToXML(layerxml, 
"CandidateBiases", this->GetBiasesAt(2));
 
  846template <
typename Architecture_t>
 
  851   this->ReadMatrixXML(parent, 
"ResetWeights", this->GetWeightsAt(0));
 
  852   this->ReadMatrixXML(parent, 
"ResetStateWeights", this->GetWeightsAt(1));
 
  853   this->ReadMatrixXML(parent, 
"ResetBiases", this->GetBiasesAt(0));
 
  854   this->ReadMatrixXML(parent, 
"UpdateWeights", this->GetWeightsAt(2));
 
  855   this->ReadMatrixXML(parent, 
"UpdateStateWeights", this->GetWeightsAt(3));
 
  856   this->ReadMatrixXML(parent, 
"UpdateBiases", this->GetBiasesAt(1));
 
  857   this->ReadMatrixXML(parent, 
"CandidateWeights", this->GetWeightsAt(4));
 
  858   this->ReadMatrixXML(parent, 
"CandidateStateWeights", this->GetWeightsAt(5));
 
  859   this->ReadMatrixXML(parent, 
"CandidateBiases", this->GetBiasesAt(2));
 
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
const Matrix_t & GetWeightsCandidate() const
Matrix_t & GetWeightsCandidateStateGradients()
typename Architecture_t::RecurrentDescriptor_t LayerDescriptor_t
Matrix_t & GetWeightsResetGate()
Matrix_t & fResetBiasGradients
Gradients w.r.t the reset gate - bias weights.
std::vector< Matrix_t > & GetUpdateGateTensor()
typename Architecture_t::Tensor_t Tensor_t
std::vector< Matrix_t > reset_gate_value
Reset gate value for every time step.
Matrix_t & CellBackward(Matrix_t &state_gradients_backward, const Matrix_t &precStateActivations, const Matrix_t &reset_gate, const Matrix_t &update_gate, const Matrix_t &candidate_gate, const Matrix_t &input, Matrix_t &input_gradient, Matrix_t &dr, Matrix_t &du, Matrix_t &dc)
Backward pass for a single time unit, corresponding to the matching call to Forward(...).
size_t fStateSize
Hidden state size for GRU.
const Matrix_t & GetWeightsResetGradients() const
const Matrix_t & GetUpdateBiasGradients() const
bool fReturnSequence
Return in output full sequence or just last element.
void Forward(Tensor_t &input, bool isTraining=true)
Computes the next hidden state with the given input tensor (the GRU has no separate cell state).
const Matrix_t & GetWeightsResetStateGradients() const
std::vector< Matrix_t > fDerivativesReset
First derivatives of the reset gate activations.
const Tensor_t & GetWeightsTensor() const
std::vector< Matrix_t > & GetResetGateTensor()
Matrix_t & GetWeightsUpdateGateState()
const std::vector< Matrix_t > & GetCandidateGateTensor() const
const Matrix_t & GetUpdateDerivativesAt(size_t i) const
Matrix_t & GetWeightsUpdateStateGradients()
size_t GetInputSize() const
Getters.
Matrix_t fState
Hidden state of GRU.
Matrix_t & GetWeightsResetGradients()
Tensor_t & GetWeightGradientsTensor()
const Matrix_t & GetCandidateBias() const
std::vector< Matrix_t > update_gate_value
Update gate value for every time step.
Tensor_t & GetWeightsTensor()
Tensor_t fX
cached input tensor as T x B x I
Matrix_t & GetCandidateGateTensorAt(size_t i)
Matrix_t & GetResetBiasGradients()
Matrix_t & GetCandidateValue()
Matrix_t & GetWeightsResetGateState()
DNN::EActivationFunction fF1
Activation function: sigmoid.
const Matrix_t & GetWeightsUpdateGate() const
const std::vector< Matrix_t > & GetDerivativesReset() const
const Matrix_t & GetUpdateGateBias() const
Matrix_t & fWeightsResetGradients
Gradients w.r.t the reset gate - input weights.
std::vector< Matrix_t > & GetDerivativesUpdate()
Matrix_t & fCandidateBiasGradients
Gradients w.r.t the candidate gate - bias weights.
Matrix_t & fCandidateBias
Candidate Gate bias.
Matrix_t & GetUpdateGateTensorAt(size_t i)
DNN::EActivationFunction fF2
Activation function: tanh.
const Matrix_t & GetWeightsUpdateGradients() const
Matrix_t & GetWeightsCandidateGradients()
Matrix_t & fWeightsUpdateStateGradients
Gradients w.r.t the update gate - hidden state weights.
void AddWeightsXMLTo(void *parent)
Writes the information and the weights about the layer in an XML node.
Matrix_t & GetWeightsCandidate()
Matrix_t & fWeightsUpdateGradients
Gradients w.r.t the update gate - input weights.
Matrix_t & GetUpdateGateValue()
size_t fTimeSteps
Timesteps for GRU.
std::vector< Matrix_t > fDerivativesCandidate
First derivatives of the candidate gate activations.
const Tensor_t & GetWeightGradientsTensor() const
typename Architecture_t::FilterDescriptor_t WeightsDescriptor_t
Tensor_t fWeightGradientsTensor
Tensor for all weight gradients.
Matrix_t & fUpdateBiasGradients
Gradients w.r.t the update gate - bias weights.
Matrix_t & GetWeightsResetStateGradients()
std::vector< Matrix_t > & GetCandidateGateTensor()
Matrix_t & fWeightsResetGate
Reset Gate weights for input, fWeights[0].
const Matrix_t & GetResetDerivativesAt(size_t i) const
Matrix_t & GetWeightsUpdateGate()
typename Architecture_t::Matrix_t Matrix_t
void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward)
Backpropagates the error.
const Matrix_t & GetCandidateGateTensorAt(size_t i) const
Matrix_t & GetWeightsCandidateState()
const Matrix_t & GetCandidateBiasGradients() const
Matrix_t & GetResetGateTensorAt(size_t i)
Matrix_t & fResetGateBias
Reset Gate bias.
const std::vector< Matrix_t > & GetResetGateTensor() const
Matrix_t fCell
Empty matrix for GRU.
std::vector< Matrix_t > candidate_gate_value
Candidate gate value for every time step.
typename Architecture_t::Scalar_t Scalar_t
const Matrix_t & GetWeigthsUpdateStateGradients() const
const Matrix_t & GetCandidateValue() const
Matrix_t & GetCandidateBiasGradients()
Matrix_t & fWeightsCandidateStateGradients
Gradients w.r.t the candidate gate - hidden state weights.
const std::vector< Matrix_t > & GetDerivativesUpdate() const
const Matrix_t & GetCell() const
void UpdateGate(const Matrix_t &input, Matrix_t &df)
Computes the update gate, which scales how much of the previous state is carried over (NN with Sigmoid)
const Matrix_t & GetCandidateDerivativesAt(size_t i) const
Matrix_t fResetValue
Computed reset gate values.
DNN::EActivationFunction GetActivationFunctionF2() const
Matrix_t & GetResetGateBias()
typename Architecture_t::RNNWorkspace_t RNNWorkspace_t
Matrix_t fUpdateValue
Computed update gate values.
const Matrix_t & GetResetBiasGradients() const
bool fResetGateAfter
GRU variant that applies the reset gate multiplication after the state-weights matrix product (used by cuDNN)
const Matrix_t & GetWeightsCandidateGradients() const
DNN::EActivationFunction GetActivationFunctionF1() const
bool DoesReturnSequence() const
Matrix_t & GetUpdateBiasGradients()
const Matrix_t & GetUpdateGateTensorAt(size_t i) const
Matrix_t & fWeightsResetGateState
Reset Gate weights for prev state, fWeights[3].
Matrix_t & fWeightsUpdateGateState
Update Gate weights for prev state, fWeights[4].
const std::vector< Matrix_t > & GetDerivativesCandidate() const
virtual void Initialize()
Initialize the weights according to the given initialization method.
Tensor_t fWeightsTensor
Tensor for all weights.
void ReadWeightsFromXML(void *parent)
Read the information and the weights about the layer from XML node.
typename Architecture_t::RNNDescriptors_t RNNDescriptors_t
const Matrix_t & GetResetGateBias() const
Matrix_t & GetResetDerivativesAt(size_t i)
const Matrix_t & GetUpdateGateValue() const
const Matrix_t & GetResetGateTensorAt(size_t i) const
TDescriptors * fDescriptors
Keeps all the RNN descriptors.
void CellForward(Matrix_t &updateGateValues, Matrix_t &candidateValues)
Forward for a single cell (time unit)
Matrix_t & GetWeightsUpdateGradients()
Matrix_t & fWeightsResetStateGradients
Gradients w.r.t the reset gate - hidden state weights.
Matrix_t & fWeightsCandidateState
Candidate Gate weights for prev state, fWeights[5].
void Print() const
Prints the info about the layer.
size_t GetStateSize() const
std::vector< Matrix_t > & GetDerivativesReset()
Matrix_t & fUpdateGateBias
Update Gate bias.
const Matrix_t & GetWeightsCandidateStateGradients() const
void ResetGate(const Matrix_t &input, Matrix_t &di)
Computes the reset gate, which scales the previous state's contribution to the candidate value (NN with Sigmoid)
const Matrix_t & GetWeightsResetGate() const
Tensor_t fDx
cached gradient on the input (output of backward) as T x B x I
Matrix_t & GetCandidateBias()
typename Architecture_t::TensorDescriptor_t TensorDescriptor_t
bool fRememberState
Remember state in next pass.
Matrix_t & fWeightsCandidate
Candidate Gate weights for input, fWeights[2].
Matrix_t & fWeightsCandidateGradients
Gradients w.r.t the candidate gate - input weights.
const Matrix_t & GetWeightsCandidateState() const
const std::vector< Matrix_t > & GetUpdateGateTensor() const
const Matrix_t & GetResetGateValue() const
void Update(const Scalar_t learningRate)
Tensor_t fY
cached output tensor as T x B x S
Matrix_t fCandidateValue
Computed candidate values.
const Matrix_t & GetState() const
Matrix_t & GetUpdateGateBias()
void InitState(DNN::EInitialization m=DNN::EInitialization::kZero)
Initializes the hidden state and cell state with the given initialization method.
Tensor_t fDy
cached activation gradient (input of backward) as T x B x S
Matrix_t & GetCandidateDerivativesAt(size_t i)
std::vector< Matrix_t > fDerivativesUpdate
First derivatives of the update gate activations.
size_t GetTimeSteps() const
const Matrix_t & GetWeightsUpdateGateState() const
std::vector< Matrix_t > & GetDerivativesCandidate()
bool DoesRememberState() const
const Matrix_t & GetWeightsResetGateState() const
void CandidateValue(const Matrix_t &input, Matrix_t &dc)
Decides the new candidate values (NN with Tanh)
TBasicGRULayer(size_t batchSize, size_t stateSize, size_t inputSize, size_t timeSteps, bool rememberState=false, bool returnSequence=false, bool resetGateAfter=false, DNN::EActivationFunction f1=DNN::EActivationFunction::kSigmoid, DNN::EActivationFunction f2=DNN::EActivationFunction::kTanh, bool training=true, DNN::EInitialization fA=DNN::EInitialization::kZero)
Constructor.
Matrix_t & GetUpdateDerivativesAt(size_t i)
Matrix_t & fWeightsUpdateGate
Update Gate weights for input, fWeights[1].
typename Architecture_t::DropoutDescriptor_t HelperDescriptor_t
Matrix_t & GetResetGateValue()
Generic General Layer class.
virtual void Initialize()
Initialize the weights and biases according to the given initialization method.
size_t GetBatchSize() const
Getters.
size_t GetInputWidth() const
XMLNodePointer_t NewChild(XMLNodePointer_t parent, XMLNsPointer_t ns, const char *name, const char *content=nullptr)
create new child element for parent node
XMLAttrPointer_t NewAttr(XMLNodePointer_t xmlnode, XMLNsPointer_t, const char *name, const char *value)
creates new attribute for xmlnode, namespaces are not supported for attributes
EActivationFunction
Enum that represents layer activation functions.
create variable transformations