42 fUseSession = model.UseSession();
44 if (!model.CheckIfTensorAlreadyExist(fNX)) {
45 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " + fNX +
" is not found in model.");
47 fShapeX = model.GetTensorShape(fNX);
48 if (fShapeX.size() != 3) {
49 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " + fNX +
" is not of 3 dimensions.");
51 if (!model.CheckIfTensorAlreadyExist(fNW)) {
52 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " + fNW +
" is not found in model.");
54 fShapeW = model.GetTensorShape(fNW);
55 if (fShapeW.size() != 3) {
56 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " + fNW +
" is not of 3 dimensions.");
58 if (!model.CheckIfTensorAlreadyExist(fNR)) {
59 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " + fNR +
" is not found in model.");
61 fShapeR = model.GetTensorShape(fNR);
62 if (fShapeR.size() != 3) {
63 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " + fNR +
" is not of 3 dimensions.");
66 if (!model.CheckIfTensorAlreadyExist(fNB)) {
67 throw std::runtime_error(
"TMVA SOFIE LSTM op input tensor " + fNB +
" is not found in model.");
69 fShapeB = model.GetTensorShape(fNB);
70 if (fShapeB.size() != 2 && fShapeB.size() != 5) {
71 throw std::runtime_error(
"TMVA SOFIE LSTM op input tensor " + fNB +
" is not of 2 or 5 dimensions.");
73 if (fShapeB.size() == 2) {
77 size_t seq_length = (fAttrLayout == 0)? fShapeX[0] : fShapeX[1];
78 size_t batch_size = (fAttrLayout == 0)? fShapeX[1] : fShapeX[0];
79 if (fType ==
"float") {
83 std::vector<float>
sum(fAttrHiddenSize);
86 for (
size_t h = 0;
h < fAttrHiddenSize;
h++) {
102 fShapeB = model.GetTensorShape(fNB);
106 if (!fNSequence_lens.empty()) {
107 if (!model.CheckIfTensorAlreadyExist(fNSequence_lens)) {
108 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " +
110 "is not found in model.");
112 fShapeSequence_lens = model.GetTensorShape(fNSequence_lens);
113 if (fShapeSequence_lens.size() != 1) {
114 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " +
116 " is not of 1 dimension.");
119 if (!fNInitial_h.empty()) {
120 if (!model.CheckIfTensorAlreadyExist(fNInitial_h)) {
121 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " +
122 fNInitial_h +
" is not found in model.");
124 fShapeInitial_h = model.GetTensorShape(fNInitial_h);
125 if (fShapeInitial_h.size() != 3) {
126 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " +
127 fNInitial_h +
" is not of 3 dimensions.");
130 if (!fNInitial_c.empty()) {
131 if (!model.CheckIfTensorAlreadyExist(fNInitial_c)) {
132 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " +
133 fNInitial_c +
" is not found in model.");
135 fShapeInitial_c = model.GetTensorShape(fNInitial_c);
136 if (fShapeInitial_c.size() != 3) {
137 throw std::runtime_error(
"TMVA SOFIE LSTM Op input tensor " +
138 fNInitial_c +
" is not of 3 dimensions.");
142 if (!model.CheckIfTensorAlreadyExist(fNP)) {
143 throw std::runtime_error(
"TMVA SOFIE LSTM op input tensor " + fNP +
" is not found in model.");
145 fShapeP = model.GetTensorShape(fNP);
146 if (fShapeP.size() != 2 && fShapeP.size() != 4) {
147 throw std::runtime_error(
"TMVA SOFIE LSTM op input tensor " + fNP +
" is not of 2 or 4 dimensions.");
149 if (fShapeP.size() == 2) {
153 size_t batch_size = (fAttrLayout == 0)? fShapeX[1] : fShapeX[0];
154 if (fType ==
"float") {
169 std::shared_ptr<void>
new_p_ptr(
new_p, std::default_delete<
float[]>());
171 fShapeP = model.GetTensorShape(fNP);
176 fShapeY = ShapeInference({fShapeX, fShapeW})[0];
177 if (!model.CheckIfTensorAlreadyExist(fNY)) {
178 model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
181 if (!fNY_h.empty()) {
182 fShapeY_h = ShapeInference({fShapeX, fShapeW})[1];
183 if (!model.CheckIfTensorAlreadyExist(fNY_h)) {
184 model.AddIntermediateTensor(fNY_h, model.GetTensorType(fNX), fShapeY_h);
187 if (!fNY_c.empty()) {
188 fShapeY_c = ShapeInference({fShapeX, fShapeW})[2];
189 if (!model.CheckIfTensorAlreadyExist(fNY_c)) {
190 model.AddIntermediateTensor(fNY_c, model.GetTensorType(fNX), fShapeY_c);
201 throw std::runtime_error(
"TMVA SOFIE - Activation function " +
205 if (fAttrDirection !=
"forward" && fAttrDirection !=
"backward" &&
206 fAttrDirection !=
"bidirectional") {
207 throw std::runtime_error(
208 "TMVA SOFIE - Invalid LSTM direction fAttrDirection = " +
211 if (4 * fAttrHiddenSize != fShapeW[1]) {
212 throw std::runtime_error(
213 "TMVA SOFIE - fAttrHiddenSize must be equal to " +
214 std::to_string(fShapeW[1] / 4));
216 if (fAttrInputForget > 1) {
217 throw std::runtime_error(
218 "TMVA SOFIE - fAttrInputForget = " + std::to_string(fAttrInputForget)
219 +
" must be 0 or 1.");
221 if (fAttrLayout > 1) {
222 throw std::runtime_error(
"TMVA SOFIE - Layout fAttrLayout = " +
223 std::to_string(fAttrLayout) +
224 " must be 0 (timewise) or 1 (batchwise)");
226 if (fAttrActivations.empty()) {
227 if (fAttrDirection ==
"bidirectional") {
228 fAttrActivations = {
"Sigmoid",
"Tanh",
"Tanh",
"Sigmoid",
"Tanh",
"Tanh"};
230 fAttrActivations = {
"Sigmoid",
"Tanh",
"Tanh"};
286 std::stringstream out;
288 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
289 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
294 if (fAttrLayout == 0) {
295 out << SP << fType <<
" *" <<
OpName <<
"_input = tensor_" << fNX <<
";\n";
298 out << SP << fType <<
" * " <<
OpName <<
"_input = fVec_" <<
OpName <<
"_input.data();\n";
302 out << SP <<
"for(size_t seq = 0; seq < " <<
seq_length <<
"; seq++) {\n";
303 out << SP << SP <<
"for(size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
304 out << SP << SP << SP <<
"for(size_t i = 0; i < " <<
input_size <<
"; i++) {\n";
306 <<
" + batch * " <<
input_size <<
" + i] = " <<
"tensor_" << fNX <<
"[batch * "
308 out << SP << SP << SP <<
"}\n";
309 out << SP << SP <<
"}\n";
314 if (!fNInitial_h.empty()) {
315 if (fAttrLayout == 0) {
316 out << SP << fType <<
" *" <<
OpName <<
"_initial_hidden_state = " <<
" tensor_"
317 << fNInitial_h <<
";\n";
320 out << SP << fType <<
" * " <<
OpName <<
"_initial_hidden_state = fVec_" <<
OpName
321 <<
"_initial_hidden_state.data();\n";
324 fAttrHiddenSize <<
"] = {0};\n";
327 out << SP <<
"for(size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
328 out << SP << SP <<
"for(size_t h = 0; h < " << fAttrHiddenSize <<
"; h++) {\n";
329 out << SP << SP << SP <<
OpName <<
"_initial_hidden_state["
331 <<
" + h] = tensor_" << fNInitial_h <<
"[batch * " <<
num_directions * fAttrHiddenSize
332 <<
" + " <<
direction * fAttrHiddenSize <<
" + h];\n";
333 out << SP << SP <<
"}\n";
340 if (!fNInitial_c.empty()) {
341 if (fAttrLayout == 0) {
342 out << SP << fType <<
" *" <<
OpName <<
"_initial_cell_state = " <<
" tensor_"
343 << fNInitial_c <<
";\n";
346 out << SP << fType <<
" * " <<
OpName <<
"_initial_cell_state = fVec_" <<
OpName
347 <<
"_initial_cell_state.data();\n";
350 fAttrHiddenSize <<
"] = {0};\n";
353 out << SP <<
"for(size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
354 out << SP << SP <<
"for(size_t h = 0; h < " << fAttrHiddenSize <<
"; h++) {\n";
355 out << SP << SP << SP <<
OpName <<
"_initial_cell_state["
357 <<
" + h] = tensor_" << fNInitial_c <<
"[batch * " <<
num_directions * fAttrHiddenSize
358 <<
" + " <<
direction * fAttrHiddenSize <<
" + h];\n";
359 out << SP << SP <<
"}\n";
368 out << SP << fType <<
" * " <<
OpName <<
"_ff_input_gate = fVec_" <<
OpName <<
"_ff_input_gate.data();\n";
369 out << SP << fType <<
" * " <<
OpName <<
"_ff_output_gate = fVec_" <<
OpName <<
"_ff_output_gate.data();\n";
370 out << SP << fType <<
" * " <<
OpName <<
"_ff_cell_gate = fVec_" <<
OpName <<
"_ff_cell_gate.data();\n";
371 if (fAttrInputForget == 0) {
372 out << SP << fType <<
" * " <<
OpName <<
"_ff_forget_gate = fVec_" <<
OpName <<
"_ff_forget_gate.data();\n";
375 out << SP << fType <<
" " <<
OpName <<
"_ff_input_gate[" <<
ff_size <<
"] = {0};\n";
376 out << SP << fType <<
" " <<
OpName <<
"_ff_output_gate[" <<
ff_size <<
"] = {0};\n";
377 out << SP << fType <<
" " <<
OpName <<
"_ff_cell_gate[" <<
ff_size <<
"] = {0};\n";
378 if (fAttrInputForget == 0) {
379 out << SP << fType <<
" " <<
OpName <<
"_ff_forget_gate[" <<
ff_size <<
"] = {0};\n";
385 out << SP << fType <<
" * " <<
OpName <<
"_input_gate = fVec_" <<
OpName <<
"_input_gate.data();\n";
386 out << SP << fType <<
" * " <<
OpName <<
"_output_gate = fVec_" <<
OpName <<
"_output_gate.data();\n";
387 out << SP << fType <<
" * " <<
OpName <<
"_cell_gate = fVec_" <<
OpName <<
"_cell_gate.data();\n";
388 if (fAttrInputForget == 0) {
389 out << SP << fType <<
" * " <<
OpName <<
"_forget_gate = fVec_" <<
OpName <<
"_forget_gate.data();\n";
395 if (fAttrInputForget == 0) {
401 out << SP << fType <<
" * " <<
OpName <<
"_cell_state = fVec_" <<
OpName <<
"_cell_state.data();\n";
402 out << SP << fType <<
" * " <<
OpName <<
"_new_cell_state = fVec_" <<
OpName <<
"_new_cell_state.data();\n";
409 if (fAttrLayout == 0 && !fNY.empty()) {
410 out << SP << fType <<
" *" <<
OpName <<
"_hidden_state = tensor_" << fNY <<
";\n";
413 out << SP << fType <<
" * " <<
OpName <<
"_hidden_state = fVec_" <<
OpName <<
"_hidden_state.data();\n";
419 out << SP <<
"char " <<
OpName <<
"_transA = 'N';\n";
420 out << SP <<
"char " <<
OpName <<
"_transB = 'T';\n";
422 out << SP <<
"int " <<
OpName <<
"_n = " << fAttrHiddenSize <<
";\n";
424 if (fType ==
"float") {
425 out << SP << fType <<
" " <<
OpName <<
"_alpha = 1.;\n";
426 out << SP << fType <<
" " <<
OpName <<
"_beta = 0.;\n";
430 out << SP <<
"int " <<
OpName <<
"_incx = 1;\n";
431 out << SP <<
"int " <<
OpName <<
"_incy = 1;\n";
436 if (fType ==
"float") {
438 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
444 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
450 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
456 if (fType ==
"float") {
459 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
465 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
471 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
477 if (fAttrInputForget == 0) {
480 if (fType ==
"float") {
482 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
488 if (fType ==
"float") {
490 out << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
501 if (fType ==
"float") {
503 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
504 << fNB <<
", &" <<
OpName <<
"_incx, " <<
OpName <<
"_ff_input_gate, &" <<
OpName <<
"_incy);\n";
507 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
512 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
517 if (fType ==
"float") {
520 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
526 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
532 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
537 if (fAttrInputForget == 0) {
540 if (fType ==
"float") {
542 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
547 if (fType ==
"float") {
550 out << SP <<
"BLAS::saxpy_(&" <<
OpName <<
"_bias_size, &" <<
OpName <<
"_alpha, tensor_"
560 out << SP <<
"for (size_t seq = 0; seq < " <<
seq_length <<
"; seq++) {\n";
561 out << SP << SP <<
"size_t ff_offset = seq * " <<
batch_size * fAttrHiddenSize <<
";\n";
567 <<
" + " <<
batch_size * fAttrHiddenSize <<
";\n";
570 out << SP << SP <<
"std::copy(" <<
OpName <<
"_ff_input_gate + ff_offset, " <<
OpName
571 <<
"_ff_input_gate + ff_offset + " <<
ff_seq_size <<
", " <<
OpName <<
"_input_gate + gate_offset);\n";
572 out << SP << SP <<
"std::copy(" <<
OpName <<
"_ff_output_gate + ff_offset, " <<
OpName
573 <<
"_ff_output_gate + ff_offset + " <<
ff_seq_size <<
", " <<
OpName <<
"_output_gate + gate_offset);\n";
574 out << SP << SP <<
"std::copy(" <<
OpName <<
"_ff_cell_gate + ff_offset, " <<
OpName
575 <<
"_ff_cell_gate + ff_offset + " <<
ff_seq_size <<
", " <<
OpName <<
"_cell_gate + gate_offset);\n";
576 if (fAttrInputForget == 0) {
577 out << SP << SP <<
"std::copy(" <<
OpName <<
"_ff_forget_gate + ff_offset, " <<
OpName
578 <<
"_ff_forget_gate + ff_offset + " <<
ff_seq_size <<
", " <<
OpName <<
"_forget_gate + gate_offset);\n";
582 out << SP <<
"for (size_t seq = 0; seq < " <<
seq_length <<
"; seq++) {\n";
583 if (fAttrDirection ==
"backward" ||
direction == 1) {
584 out << SP << SP <<
"size_t index = " <<
seq_length - 1 <<
" - seq;\n";
586 out << SP << SP <<
"size_t index = seq;\n";
588 out << SP << SP <<
"int m2 = " <<
batch_size <<
";\n";
594 <<
" + " <<
batch_size * fAttrHiddenSize <<
";\n";
598 out << SP << SP <<
"if (seq == 0) {\n";
599 if (!fNInitial_h.empty()) {
601 if (fType ==
"float") {
602 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
603 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
", &"
605 <<
"_alpha, " <<
OpName <<
"_input_gate + offset, &" <<
OpName <<
"_n);\n";
606 size_t ro_offset = fAttrHiddenSize * fAttrHiddenSize;
607 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
608 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
610 <<
"_n, &" <<
OpName <<
"_alpha, " <<
OpName <<
"_output_gate + offset, &" <<
OpName <<
"_n);\n";
611 size_t rc_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
612 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
613 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
615 <<
"_n, &" <<
OpName <<
"_alpha, " <<
OpName <<
"_cell_gate + offset, &" <<
OpName <<
"_n);\n";
616 if (fAttrInputForget == 0) {
617 size_t rf_offset = 2 * fAttrHiddenSize * fAttrHiddenSize;
618 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
619 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
621 <<
"_n, &" <<
OpName <<
"_alpha, " <<
OpName <<
"_forget_gate + offset, &" <<
OpName <<
"_n);\n";
625 if (fType ==
"float") {
626 size_t ri_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
627 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
628 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
630 <<
"_n, &" <<
OpName <<
"_alpha, " <<
OpName <<
"_input_gate + offset, &" <<
OpName <<
"_n);\n";
631 size_t ro_offset = 4 * fAttrHiddenSize * fAttrHiddenSize + 1 * fAttrHiddenSize * fAttrHiddenSize;
632 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
633 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
635 <<
"_n, &" <<
OpName <<
"_alpha, " <<
OpName <<
"_output_gate + offset, &" <<
OpName <<
"_n);\n";
636 size_t rc_offset = 4 * fAttrHiddenSize * fAttrHiddenSize + 3 * fAttrHiddenSize * fAttrHiddenSize;
637 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
638 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
640 <<
"_n, &" <<
OpName <<
"_alpha, " <<
OpName <<
"_cell_gate + offset, &" <<
OpName <<
"_n);\n";
641 if (fAttrInputForget == 0) {
642 size_t rf_offset = 4 * fAttrHiddenSize * fAttrHiddenSize + 2 * fAttrHiddenSize * fAttrHiddenSize;
643 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
644 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
646 <<
"_n, &" <<
OpName <<
"_alpha, " <<
OpName <<
"_forget_gate + offset, &" <<
OpName <<
"_n);\n";
651 out << SP << SP <<
"} else {\n";
654 if (fAttrDirection ==
"backward") {
655 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
658 out << SP << SP << SP <<
"size_t previous_offset = (seq - 1) * "
661 if (fType ==
"float") {
662 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
663 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
", &"
664 <<
OpName <<
"_n, " <<
OpName <<
"_hidden_state + previous_offset, &" <<
OpName <<
"_n, &"
666 size_t ro_offset = 1 * fAttrHiddenSize * fAttrHiddenSize;
667 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
668 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
672 size_t rc_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
673 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
674 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
678 if (fAttrInputForget == 0) {
679 size_t rf_offset = 2 * fAttrHiddenSize * fAttrHiddenSize;
680 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
681 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
688 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
690 if (fType ==
"float") {
691 size_t ri_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
692 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
693 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
697 size_t ro_offset = 4 * fAttrHiddenSize * fAttrHiddenSize + fAttrHiddenSize * fAttrHiddenSize;
698 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
699 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
703 size_t rc_offset = 4 * fAttrHiddenSize * fAttrHiddenSize + 3 * fAttrHiddenSize * fAttrHiddenSize;
704 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
705 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
709 if (fAttrInputForget == 0) {
710 size_t rf_offset = 4 * fAttrHiddenSize * fAttrHiddenSize + 2 * fAttrHiddenSize * fAttrHiddenSize;
711 out << SP << SP << SP <<
"BLAS::sgemm_(&" <<
OpName <<
"_transB, &" <<
OpName <<
"_transA, &"
712 <<
OpName <<
"_n, &m2, &" <<
OpName <<
"_n, &" <<
OpName <<
"_alpha, tensor_" << fNR <<
" + "
719 out << SP << SP <<
"}\n";
722 if (fAttrClip > .0) {
723 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
724 if (fType ==
"float") {
725 out << SP << SP << SP <<
"float x = (" <<
OpName <<
"_cell_gate[i] > " << -fAttrClip <<
") ? "
726 <<
OpName <<
"_cell_gate[i] : " << -fAttrClip <<
";\n";
728 out << SP << SP << SP <<
OpName <<
"_cell_gate[i] = (x < " << fAttrClip <<
") ? x : "
729 << fAttrClip <<
";\n";
730 out << SP << SP <<
"}\n";
733 if (fAttrActivations[
direction * 3 + 1] ==
"Relu") {
734 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
735 out << SP << SP << SP <<
"if (" <<
OpName <<
"_cell_gate[i] < 0.)\n";
736 out << SP << SP << SP << SP <<
OpName <<
"_cell_gate[i] = 0.;\n";
737 out << SP << SP <<
"}\n";
738 }
else if (fAttrActivations[
direction * 3 + 1] ==
"Tanh") {
739 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
740 if (fType ==
"float") {
741 out << SP << SP << SP <<
"float ex = exp(-2 * " <<
OpName <<
"_cell_gate[i]);\n";
743 out << SP << SP << SP << SP <<
OpName <<
"_cell_gate[i] = (1. - ex) / (1. + ex);\n";
744 out << SP << SP <<
"}\n";
745 }
else if (fAttrActivations[
direction * 3 + 1] ==
"Sigmoid") {
746 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
747 out << SP << SP << SP << SP <<
OpName <<
"_cell_gate[i] = 1. / (1. + exp(-" <<
OpName
748 <<
"_cell_gate[i]));\n";
749 out << SP << SP <<
"}\n";
750 }
else if (fAttrActivations[
direction * 3 + 1] ==
"Affine") {
751 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
752 out << SP << SP << SP << SP <<
OpName <<
"_cell_gate[i] = "
753 << fAttrActivationAlpha[
direction * 3 + 1] <<
" * " <<
OpName <<
"_cell_gate[i] + "
754 << fAttrActivationBeta[
direction * 3 + 1] <<
";\n";
755 out << SP << SP <<
"}\n";
756 }
else if (fAttrActivations[
direction * 3 + 1] ==
"ScaledTanh") {
757 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
758 if (fType ==
"float") {
759 out << SP << SP << SP <<
"float ex = exp(-2 * " << fAttrActivationBeta[
direction * 3 + 1]
760 <<
" * "<<
OpName <<
"_cell_gate[i]);\n";
762 out << SP << SP << SP << SP <<
OpName <<
"_cell_gate[i] = "
763 << fAttrActivationAlpha[
direction * 3 + 1] <<
" * (1. - ex) / (1. + ex);\n";
764 out << SP << SP <<
"}\n";
765 }
else if (fAttrActivations[
direction * 3 + 1] ==
"HardSigmoid") {
766 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
767 if (fType ==
"float") {
768 out << SP << SP << SP <<
"float a = " << fAttrActivationAlpha[
direction * 3 + 1] <<
" * "
769 <<
OpName <<
"_cell_gate[i] + " << fAttrActivationBeta[
direction * 3 + 1] <<
";\n";
770 out << SP << SP << SP <<
"float b = (a > 0.) ? a : 0.;\n";
772 out << SP << SP << SP << SP <<
OpName <<
"_cell_gate[i] = (b < 1.) ? b : 1.;\n";
773 out << SP << SP <<
"}\n";
774 }
else if (fAttrActivations[
direction * 3 + 1] ==
"LeakyRelu") {
775 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
776 out << SP << SP << SP <<
"if (" <<
OpName <<
"_cell_gate[i] < 0.)\n";
777 out << SP << SP << SP << SP <<
OpName <<
"_cell_gate[i] = "
778 << fAttrActivationAlpha[
direction * 3 + 1] <<
" * " <<
OpName <<
"_cell_gate[i];\n";
779 out << SP << SP <<
"}\n";
780 }
else if (fAttrActivations[
direction * 3 + 1] ==
"ThresholdRelu") {
781 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
782 out << SP << SP << SP <<
"if (" <<
OpName <<
"_cell_gate[i] < "
783 << fAttrActivationAlpha[
direction * 3 + 1] <<
")\n";
784 out << SP << SP << SP << SP <<
OpName <<
"_cell_gate[i] = 0.;\n";
785 out << SP << SP <<
"}";
786 }
else if (fAttrActivations[
direction * 3 + 1] ==
"Elu") {
787 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
788 out << SP << SP << SP <<
"if (" <<
OpName <<
"_cell_gate[i] < 0.)\n";
789 out << SP << SP << SP << SP <<
OpName <<
"_cell_gate[i] = "
790 << fAttrActivationAlpha[
direction * 3 + 1] <<
" * exp(" <<
OpName <<
"_cell_gate[i] - 1.);\n";
791 out << SP << SP <<
"}\n";
792 }
else if (fAttrActivations[
direction * 3 + 1] ==
"Softsign") {
793 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
794 out << SP << SP << SP << SP <<
OpName <<
"_cell_gate[i] = " <<
OpName
795 <<
"_cell_gate[i] / (1. + abs(" <<
OpName <<
"_cell_gate[i]));\n";
796 out << SP << SP <<
"}\n";
798 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
799 out << SP << SP << SP << SP <<
OpName <<
"_cell_gate[i] = log(1. + exp("
800 <<
OpName <<
"_cell_gate[i]));\n";
801 out << SP << SP <<
"}\n";
807 out << SP << SP <<
"if (seq == 0) {\n";
808 if (!fNInitial_c.empty()) {
810 out << SP << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
811 out << SP << SP << SP << SP <<
OpName <<
"_input_gate[i + offset] += tensor_" << fNP
812 <<
"[i] * " <<
OpName <<
"_initial_cell_state[i];\n";
813 out << SP << SP << SP <<
"}\n";
814 if (fAttrInputForget == 0) {
816 out << SP << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
817 out << SP << SP << SP << SP <<
OpName <<
"_forget_gate[i + offset] += tensor_" << fNP
818 <<
"[i + " <<
pf_offset <<
"] * " <<
OpName <<
"_initial_cell_state[i];\n";
819 out << SP << SP << SP <<
"}\n";
824 out << SP << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
825 out << SP << SP << SP << SP <<
OpName <<
"_input_gate[i + offset] += tensor_" << fNP
828 out << SP << SP << SP <<
"}\n";
829 if (fAttrInputForget == 0) {
831 out << SP << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
832 out << SP << SP << SP << SP <<
OpName <<
"_forget_gate[i + offset] += tensor_" << fNP
835 out << SP << SP << SP <<
"}\n";
839 out << SP << SP <<
"} else {\n";
841 if (fAttrDirection ==
"backward") {
842 out << SP << SP << SP <<
"size_t c_offset = (index + 1) * "
845 out << SP << SP << SP <<
"size_t c_offset = (seq - 1) * "
848 out << SP << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
849 out << SP << SP << SP << SP <<
OpName <<
"_input_gate[i + offset] += tensor_" << fNP
850 <<
"[i] * " <<
OpName <<
"_cell_state[i + c_offset];\n";
851 out << SP << SP << SP <<
"}\n";
852 if (fAttrInputForget == 0) {
854 out << SP << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
855 out << SP << SP << SP << SP <<
OpName <<
"_forget_gate[i + offset] += tensor_" << fNP
856 <<
"[i + " <<
pf_offset <<
"] * " <<
OpName <<
"_cell_state[i + c_offset];\n";
857 out << SP << SP << SP <<
"}\n";
861 out << SP << SP << SP <<
"size_t c_offset = (index + 1) * "
863 out << SP << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
864 out << SP << SP << SP << SP <<
OpName <<
"_input_gate[i + offset] += tensor_" << fNP
865 <<
"[i + " <<
pi_offset <<
"] * " <<
OpName <<
"_cell_state[i + c_offset];\n";
866 out << SP << SP << SP <<
"}\n";
867 if (fAttrInputForget == 0) {
869 out << SP << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
870 out << SP << SP << SP << SP <<
OpName <<
"_forget_gate[i + offset] += tensor_" << fNP
871 <<
"[i + " <<
pf_offset <<
"] * " <<
OpName <<
"_cell_state[i + c_offset];\n";
872 out << SP << SP << SP <<
"}\n";
875 out << SP << SP <<
"}\n";
879 if (fAttrClip > .0) {
880 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
881 if (fType ==
"float") {
882 out << SP << SP << SP <<
"float x = (" <<
OpName <<
"_input_gate[i] > " << -fAttrClip <<
") ? "
883 <<
OpName <<
"_input_gate[i] : " << -fAttrClip <<
";\n";
885 out << SP << SP << SP <<
OpName <<
"_input_gate[i] = (x < " << fAttrClip <<
") ? x : "
886 << fAttrClip <<
";\n";
887 out << SP << SP <<
"}\n";
890 if (fAttrActivations[
direction * 3] ==
"Relu") {
891 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
892 out << SP << SP << SP <<
"if (" <<
OpName <<
"_input_gate[i] < 0.)\n";
893 out << SP << SP << SP << SP <<
OpName <<
"_input_gate[i] = 0.;\n";
894 out << SP << SP <<
"}\n";
895 }
else if (fAttrActivations[
direction * 3] ==
"Tanh") {
896 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
897 if (fType ==
"float") {
898 out << SP << SP << SP <<
"float ex = exp(-2 * " <<
OpName <<
"_input_gate[i]);\n";
900 out << SP << SP << SP << SP <<
OpName <<
"_input_gate[i] = (1. - ex) / (1. + ex);\n";
901 out << SP << SP <<
"}\n";
902 }
else if (fAttrActivations[
direction * 3] ==
"Sigmoid") {
903 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
904 out << SP << SP << SP << SP <<
OpName <<
"_input_gate[i] = 1. / (1. + exp(-" <<
OpName
905 <<
"_input_gate[i]));\n";
906 out << SP << SP <<
"}\n";
907 }
else if (fAttrActivations[
direction * 3] ==
"Affine") {
908 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
909 out << SP << SP << SP << SP <<
OpName <<
"_input_gate[i] = "
910 << fAttrActivationAlpha[
direction * 3] <<
" * " <<
OpName <<
"_input_gate[i] + "
911 << fAttrActivationBeta[
direction * 3] <<
";\n";
912 out << SP << SP <<
"}\n";
913 }
else if (fAttrActivations[
direction * 3] ==
"ScaledTanh") {
914 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
915 if (fType ==
"float") {
916 out << SP << SP << SP <<
"float ex = exp(-2 * " << fAttrActivationBeta[
direction * 3]
917 <<
" * "<<
OpName <<
"_input_gate[i]);\n";
919 out << SP << SP << SP << SP <<
OpName <<
"_input_gate[i] = "
920 << fAttrActivationAlpha[
direction * 3] <<
" * (1. - ex) / (1. + ex);\n";
921 out << SP << SP <<
"}\n";
922 }
else if (fAttrActivations[
direction * 3] ==
"HardSigmoid") {
923 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
924 if (fType ==
"float") {
925 out << SP << SP << SP <<
"float a = " << fAttrActivationAlpha[
direction * 3] <<
" * "
926 <<
OpName <<
"_input_gate[i] + " << fAttrActivationBeta[
direction * 3] <<
";\n";
927 out << SP << SP << SP <<
"float b = (a > 0.) ? a : 0.;\n";
929 out << SP << SP << SP << SP <<
OpName <<
"_input_gate[i] = (b < 1.) ? b : 1.;\n";
930 out << SP << SP <<
"}\n";
931 }
else if (fAttrActivations[
direction * 3] ==
"LeakyRelu") {
932 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
933 out << SP << SP << SP <<
"if (" <<
OpName <<
"_input_gate[i] < 0.)\n";
934 out << SP << SP << SP << SP <<
OpName <<
"_input_gate[i] = "
935 << fAttrActivationAlpha[
direction * 3] <<
" * " <<
OpName <<
"_input_gate[i];\n";
936 out << SP << SP <<
"}\n";
937 }
else if (fAttrActivations[
direction * 3] ==
"ThresholdRelu") {
938 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
939 out << SP << SP << SP <<
"if (" <<
OpName <<
"_input_gate[i] < "
940 << fAttrActivationAlpha[
direction * 3] <<
")\n";
941 out << SP << SP << SP << SP <<
OpName <<
"_input_gate[i] = 0.;\n";
942 out << SP << SP <<
"}";
943 }
else if (fAttrActivations[
direction * 3] ==
"Elu") {
944 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
945 out << SP << SP << SP <<
"if (" <<
OpName <<
"_input_gate[i] < 0.)\n";
946 out << SP << SP << SP << SP <<
OpName <<
"_input_gate[i] = "
947 << fAttrActivationAlpha[
direction * 3] <<
" * exp(" <<
OpName <<
"_input_gate[i] - 1.);\n";
948 out << SP << SP <<
"}\n";
949 }
else if (fAttrActivations[
direction * 3] ==
"Softsign") {
950 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
951 out << SP << SP << SP << SP <<
OpName <<
"_input_gate[i] = " <<
OpName
952 <<
"_input_gate[i] / (1. + abs(" <<
OpName <<
"_input_gate[i]));\n";
953 out << SP << SP <<
"}\n";
955 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
956 out << SP << SP << SP << SP <<
OpName <<
"_input_gate[i] = log(1. + exp("
957 <<
OpName <<
"_input_gate[i]));\n";
958 out << SP << SP <<
"}\n";
961 if (fAttrInputForget == 0) {
963 if (fAttrClip > .0) {
964 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
965 if (fType ==
"float") {
966 out << SP << SP << SP <<
"float x = (" <<
OpName <<
"_forget_gate[i] > "
967 << -fAttrClip <<
") ? " <<
OpName <<
"_forget_gate[i] : " << -fAttrClip <<
";\n";
969 out << SP << SP << SP <<
OpName <<
"_forget_gate[i] = (x < " << fAttrClip
970 <<
") ? x : " << fAttrClip <<
";\n";
971 out << SP << SP <<
"}\n";
974 if (fAttrActivations[
direction * 3] ==
"Relu") {
975 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
976 out << SP << SP << SP <<
"if (" <<
OpName <<
"_forget_gate[i] < 0.)\n";
977 out << SP << SP << SP << SP <<
OpName <<
"_forget_gate[i] = 0.;\n";
978 out << SP << SP <<
"}\n";
979 }
else if (fAttrActivations[
direction * 3] ==
"Tanh") {
980 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
981 if (fType ==
"float") {
982 out << SP << SP << SP <<
"float ex = exp(-2 * " <<
OpName <<
"_forget_gate[i]);\n";
984 out << SP << SP << SP << SP <<
OpName <<
"_forget_gate[i] = (1. - ex) / (1. + ex);\n";
985 out << SP << SP <<
"}\n";
986 }
else if (fAttrActivations[
direction * 3] ==
"Sigmoid") {
987 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
988 out << SP << SP << SP << SP <<
OpName <<
"_forget_gate[i] = 1. / (1. + exp(-"
989 <<
OpName <<
"_forget_gate[i]));\n";
990 out << SP << SP <<
"}\n";
991 }
else if (fAttrActivations[
direction * 3] ==
"Affine") {
992 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
993 out << SP << SP << SP << SP <<
OpName <<
"_forget_gate[i] = "
994 << fAttrActivationAlpha[
direction * 3] <<
" * " <<
OpName <<
"_forget_gate[i] + "
995 << fAttrActivationBeta[
direction * 3] <<
";\n";
996 out << SP << SP <<
"}\n";
997 }
else if (fAttrActivations[
direction * 3] ==
"ScaledTanh") {
998 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
999 if (fType ==
"float") {
1000 out << SP << SP << SP <<
"float ex = exp(-2 * " << fAttrActivationBeta[
direction * 3]
1001 <<
" * "<<
OpName <<
"_forget_gate[i]);\n";
1003 out << SP << SP << SP << SP <<
OpName <<
"_forget_gate[i] = "
1004 << fAttrActivationAlpha[
direction * 3] <<
" * (1. - ex) / (1. + ex);\n";
1005 out << SP << SP <<
"}\n";
1006 }
else if (fAttrActivations[
direction * 3] ==
"HardSigmoid") {
1007 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1008 if (fType ==
"float") {
1009 out << SP << SP << SP <<
"float a = " << fAttrActivationAlpha[
direction * 3] <<
" * "
1010 <<
OpName <<
"_forget_gate[i] + " << fAttrActivationBeta[
direction * 3] <<
";\n";
1011 out << SP << SP << SP <<
"float b = (a > 0.) ? a : 0.;\n";
1013 out << SP << SP << SP << SP <<
OpName <<
"_forget_gate[i] = (b < 1.) ? b : 1.;\n";
1014 out << SP << SP <<
"}\n";
1015 }
else if (fAttrActivations[
direction * 3] ==
"LeakyRelu") {
1016 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1017 out << SP << SP << SP <<
"if (" <<
OpName <<
"_forget_gate[i] < 0.)\n";
1018 out << SP << SP << SP << SP <<
OpName <<
"_forget_gate[i] = "
1019 << fAttrActivationAlpha[
direction * 3] <<
" * " <<
OpName <<
"_forget_gate[i];\n";
1020 out << SP << SP <<
"}\n";
1021 }
else if (fAttrActivations[
direction * 3] ==
"ThresholdRelu") {
1022 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1023 out << SP << SP << SP <<
"if (" <<
OpName <<
"_forget_gate[i] < "
1024 << fAttrActivationAlpha[
direction * 3] <<
")\n";
1025 out << SP << SP << SP << SP <<
OpName <<
"_forget_gate[i] = 0.;\n";
1026 out << SP << SP <<
"}";
1027 }
else if (fAttrActivations[
direction * 3] ==
"Elu") {
1028 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1029 out << SP << SP << SP <<
"if (" <<
OpName <<
"_forget_gate[i] < 0.)\n";
1030 out << SP << SP << SP << SP <<
OpName <<
"_forget_gate[i] = "
1031 << fAttrActivationAlpha[
direction * 3] <<
" * exp(" <<
OpName <<
"_forget_gate[i] - 1.);\n";
1032 out << SP << SP <<
"}\n";
1033 }
else if (fAttrActivations[
direction * 3] ==
"Softsign") {
1034 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1035 out << SP << SP << SP << SP <<
OpName <<
"_forget_gate[i] = " <<
OpName
1036 <<
"_forget_gate[i] / (1. + abs(" <<
OpName <<
"_forget_gate[i]));\n";
1037 out << SP << SP <<
"}\n";
1039 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1040 out << SP << SP << SP << SP <<
OpName <<
"_forget_gate[i] = log(1. + exp("
1041 <<
OpName <<
"_forget_gate[i]));\n";
1042 out << SP << SP <<
"}\n";
1047 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1048 out << SP << SP << SP <<
OpName <<
"_cell_state[i] = " <<
OpName <<
"_input_gate[i] * "
1049 <<
OpName <<
"_cell_gate[i];\n";
1050 out << SP << SP <<
"}\n";
1052 if (fAttrInputForget == 0) {
1053 out << SP << SP <<
"if (seq == 0) {\n";
1054 if (!fNInitial_c.empty()) {
1056 out << SP << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
1057 out << SP << SP << SP << SP <<
OpName <<
"_cell_state[i + offset] += "
1058 <<
OpName <<
"_forget_gate[i + offset] * " <<
OpName <<
"_initial_cell_state[i];\n";
1059 out << SP << SP << SP <<
"}\n";
1061 out << SP << SP <<
"} else {\n";
1064 if (fAttrDirection ==
"backward") {
1065 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
1068 out << SP << SP << SP <<
"size_t previous_offset = (seq - 1) * "
1072 out << SP << SP << SP <<
"size_t previous_offset = (index + 1) * "
1075 out << SP << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
1076 out << SP << SP << SP << SP <<
OpName <<
"_cell_state[i + offset] += "
1077 <<
OpName <<
"_forget_gate[i + offset] * " <<
OpName <<
"_cell_state[i + previous_offset];\n";
1078 out << SP << SP << SP <<
"}\n";
1079 out << SP << SP <<
"}\n";
1086 out << SP << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
1087 out << SP << SP << SP << SP <<
OpName <<
"_output_gate[i + offset] += tensor_"
1088 << fNP <<
"[i + " <<
p_offset <<
"] * " <<
OpName <<
"_cell_state[i + offset];\n";
1089 out << SP << SP << SP <<
"}\n";
1092 out << SP << SP << SP <<
"for (size_t i = 0; i < " <<
size <<
"; i++) {\n";
1093 out << SP << SP << SP << SP <<
OpName <<
"_output_gate[i + offset] += tensor_"
1094 << fNP <<
"[i + " <<
p_offset <<
"] * " <<
OpName <<
"_cell_state[i + offset];\n";
1095 out << SP << SP << SP <<
"}\n";
1100 if (fAttrClip > .0) {
1101 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1102 if (fType ==
"float") {
1103 out << SP << SP << SP <<
"float x = (" <<
OpName <<
"_output_gate[i] > " << -fAttrClip
1104 <<
") ? " <<
OpName <<
"_output_gate[i] : " << -fAttrClip <<
";\n";
1106 out << SP << SP << SP <<
OpName <<
"_output_gate[i] = (x < " << fAttrClip <<
") ? x : "
1107 << fAttrClip <<
";\n";
1108 out << SP << SP <<
"}\n";
1111 if (fAttrActivations[
direction * 3] ==
"Relu") {
1112 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1113 out << SP << SP << SP <<
"if (" <<
OpName <<
"_output_gate[i] < 0.)\n";
1114 out << SP << SP << SP << SP <<
OpName <<
"_output_gate[i] = 0.;\n";
1115 out << SP << SP <<
"}\n";
1116 }
else if (fAttrActivations[
direction * 3] ==
"Tanh") {
1117 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1118 if (fType ==
"float") {
1119 out << SP << SP << SP <<
"float ex = exp(-2 * " <<
OpName <<
"_output_gate[i]);\n";
1121 out << SP << SP << SP << SP <<
OpName <<
"_output_gate[i] = (1. - ex) / (1. + ex);\n";
1122 out << SP << SP <<
"}\n";
1123 }
else if (fAttrActivations[
direction * 3] ==
"Sigmoid") {
1124 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1125 out << SP << SP << SP << SP <<
OpName <<
"_output_gate[i] = 1. / (1. + exp(-" <<
OpName
1126 <<
"_output_gate[i]));\n";
1127 out << SP << SP <<
"}\n";
1128 }
else if (fAttrActivations[
direction * 3] ==
"Affine") {
1129 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1130 out << SP << SP << SP << SP <<
OpName <<
"_output_gate[i] = "
1131 << fAttrActivationAlpha[
direction * 3] <<
" * " <<
OpName <<
"_output_gate[i] + "
1132 << fAttrActivationBeta[
direction * 3] <<
";\n";
1133 out << SP << SP <<
"}\n";
1134 }
else if (fAttrActivations[
direction * 3] ==
"ScaledTanh") {
1135 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1136 if (fType ==
"float") {
1137 out << SP << SP << SP <<
"float ex = exp(-2 * " << fAttrActivationBeta[
direction * 3]
1138 <<
" * "<<
OpName <<
"_output_gate[i]);\n";
1140 out << SP << SP << SP << SP <<
OpName <<
"_output_gate[i] = "
1141 << fAttrActivationAlpha[
direction * 3] <<
" * (1. - ex) / (1. + ex);\n";
1142 out << SP << SP <<
"}\n";
1143 }
else if (fAttrActivations[
direction * 3] ==
"HardSigmoid") {
1144 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1145 if (fType ==
"float") {
1146 out << SP << SP << SP <<
"float a = " << fAttrActivationAlpha[
direction * 3] <<
" * "
1147 <<
OpName <<
"_output_gate[i] + " << fAttrActivationBeta[
direction * 3] <<
";\n";
1148 out << SP << SP << SP <<
"float b = (a > 0.) ? a : 0.;\n";
1150 out << SP << SP << SP << SP <<
OpName <<
"_output_gate[i] = (b < 1.) ? b : 1.;\n";
1151 out << SP << SP <<
"}\n";
1152 }
else if (fAttrActivations[
direction * 3] ==
"LeakyRelu") {
1153 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1154 out << SP << SP << SP <<
"if (" <<
OpName <<
"_output_gate[i] < 0.)\n";
1155 out << SP << SP << SP << SP <<
OpName <<
"_output_gate[i] = "
1156 << fAttrActivationAlpha[
direction * 3] <<
" * " <<
OpName <<
"_output_gate[i];\n";
1157 out << SP << SP <<
"}\n";
1158 }
else if (fAttrActivations[
direction * 3] ==
"ThresholdRelu") {
1159 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1160 out << SP << SP << SP <<
"if (" <<
OpName <<
"_output_gate[i] < "
1161 << fAttrActivationAlpha[
direction * 3] <<
")\n";
1162 out << SP << SP << SP << SP <<
OpName <<
"_output_gate[i] = 0.;\n";
1163 out << SP << SP <<
"}";
1164 }
else if (fAttrActivations[
direction * 3] ==
"Elu") {
1165 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1166 out << SP << SP << SP <<
"if (" <<
OpName <<
"_output_gate[i] < 0.)\n";
1167 out << SP << SP << SP << SP <<
OpName <<
"_output_gate[i] = "
1168 << fAttrActivationAlpha[
direction * 3] <<
" * exp(" <<
OpName <<
"_output_gate[i] - 1.);\n";
1169 out << SP << SP <<
"}\n";
1170 }
else if (fAttrActivations[
direction * 3] ==
"Softsign") {
1171 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1172 out << SP << SP << SP << SP <<
OpName <<
"_output_gate[i] = " <<
OpName
1173 <<
"_output_gate[i] / (1. + abs(" <<
OpName <<
"_output_gate[i]));\n";
1174 out << SP << SP <<
"}\n";
1176 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1177 out << SP << SP << SP << SP <<
OpName <<
"_output_gate[i] = log(1. + exp("
1178 <<
OpName <<
"_output_gate[i]));\n";
1179 out << SP << SP <<
"}\n";
1183 out << SP << SP <<
"std::copy(" <<
OpName <<
"_cell_state + offset, " <<
OpName
1184 <<
"_cell_state + offset + " <<
size <<
", "<<
OpName <<
"_new_cell_state + offset);\n";
1186 if (fAttrClip > .0) {
1187 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1188 if (fType ==
"float") {
1189 out << SP << SP << SP <<
"float x = (" <<
OpName <<
"_new_cell_state[i] > " << -fAttrClip
1190 <<
") ? " <<
OpName <<
"_new_cell_state[i] : " << -fAttrClip <<
";\n";
1192 out << SP << SP << SP <<
OpName <<
"_new_cell_state[i] = (x < " << fAttrClip <<
") ? x : "
1193 << fAttrClip <<
";\n";
1194 out << SP << SP <<
"}\n";
1197 if (fAttrActivations[
direction * 3 + 2] ==
"Relu") {
1198 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1199 out << SP << SP << SP <<
"if (" <<
OpName <<
"_new_cell_state[i] < 0.)\n";
1200 out << SP << SP << SP << SP <<
OpName <<
"_new_cell_state[i] = 0.;\n";
1201 out << SP << SP <<
"}\n";
1202 }
else if (fAttrActivations[
direction * 3 + 2] ==
"Tanh") {
1203 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1204 if (fType ==
"float") {
1205 out << SP << SP << SP <<
"float ex = exp(-2 * " <<
OpName <<
"_new_cell_state[i]);\n";
1207 out << SP << SP << SP << SP <<
OpName <<
"_new_cell_state[i] = (1. - ex) / (1. + ex);\n";
1208 out << SP << SP <<
"}\n";
1209 }
else if (fAttrActivations[
direction * 3 + 2] ==
"Sigmoid") {
1210 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1211 out << SP << SP << SP << SP <<
OpName <<
"_new_cell_state[i] = 1. / (1. + exp(-" <<
OpName
1212 <<
"_new_cell_state[i]));\n";
1213 out << SP << SP <<
"}\n";
1214 }
else if (fAttrActivations[
direction * 3 + 2] ==
"Affine") {
1215 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1216 out << SP << SP << SP << SP <<
OpName <<
"_new_cell_state[i] = "
1217 << fAttrActivationAlpha[
direction * 3 + 2] <<
" * " <<
OpName <<
"_new_cell_state[i] + "
1218 << fAttrActivationBeta[
direction * 3 + 2] <<
";\n";
1219 out << SP << SP <<
"}\n";
1220 }
else if (fAttrActivations[
direction * 3 + 2] ==
"ScaledTanh") {
1221 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1222 if (fType ==
"float") {
1223 out << SP << SP << SP <<
"float ex = exp(-2 * " << fAttrActivationBeta[
direction * 3 + 2]
1224 <<
" * "<<
OpName <<
"_new_cell_state[i]);\n";
1226 out << SP << SP << SP << SP <<
OpName <<
"_new_cell_state[i] = "
1227 << fAttrActivationAlpha[
direction * 3 + 2] <<
" * (1. - ex) / (1. + ex);\n";
1228 out << SP << SP <<
"}\n";
1229 }
else if (fAttrActivations[
direction * 3 + 2] ==
"HardSigmoid") {
1230 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1231 if (fType ==
"float") {
1232 out << SP << SP << SP <<
"float a = " << fAttrActivationAlpha[
direction * 3 + 2] <<
" * "
1233 <<
OpName <<
"_new_cell_state[i] + " << fAttrActivationBeta[
direction * 3 + 2] <<
";\n";
1234 out << SP << SP << SP <<
"float b = (a > 0.) ? a : 0.;\n";
1236 out << SP << SP << SP << SP <<
OpName <<
"_new_cell_state[i] = (b < 1.) ? b : 1.;\n";
1237 out << SP << SP <<
"}\n";
1238 }
else if (fAttrActivations[
direction * 3 + 2] ==
"LeakyRelu") {
1239 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1240 out << SP << SP << SP <<
"if (" <<
OpName <<
"_new_cell_state[i] < 0.)\n";
1241 out << SP << SP << SP << SP <<
OpName <<
"_new_cell_state[i] = "
1242 << fAttrActivationAlpha[
direction * 3 + 2] <<
" * " <<
OpName <<
"_new_cell_state[i];\n";
1243 out << SP << SP <<
"}\n";
1244 }
else if (fAttrActivations[
direction * 3 + 2] ==
"ThresholdRelu") {
1245 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1246 out << SP << SP << SP <<
"if (" <<
OpName <<
"_new_cell_state[i] < "
1247 << fAttrActivationAlpha[
direction * 3 + 2] <<
")\n";
1248 out << SP << SP << SP << SP <<
OpName <<
"_new_cell_state[i] = 0.;\n";
1249 out << SP << SP <<
"}";
1250 }
else if (fAttrActivations[
direction * 3 + 2] ==
"Elu") {
1251 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1252 out << SP << SP << SP <<
"if (" <<
OpName <<
"_new_cell_state[i] < 0.)\n";
1253 out << SP << SP << SP << SP <<
OpName <<
"_new_cell_state[i] = "
1254 << fAttrActivationAlpha[
direction * 3 + 2] <<
" * exp(" <<
OpName <<
"_new_cell_state[i] - 1.);\n";
1255 out << SP << SP <<
"}\n";
1256 }
else if (fAttrActivations[
direction * 3 + 2] ==
"Softsign") {
1257 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1258 out << SP << SP << SP << SP <<
OpName <<
"_new_cell_state[i] = " <<
OpName
1259 <<
"_new_cell_state[i] / (1. + abs(" <<
OpName <<
"_new_cell_state[i]));\n";
1260 out << SP << SP <<
"}\n";
1262 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1263 out << SP << SP << SP << SP <<
OpName <<
"_new_cell_state[i] = log(1. + exp("
1264 <<
OpName <<
"_new_cell_state[i]));\n";
1265 out << SP << SP <<
"}\n";
1269 out << SP << SP <<
"for (size_t i = offset; i < offset + " <<
size <<
"; i++) {\n";
1270 out << SP << SP << SP <<
OpName <<
"_hidden_state[i] = " <<
OpName <<
"_output_gate[i] * "
1271 <<
OpName <<
"_new_cell_state[i];\n";
1272 out << SP << SP <<
"}\n";
1277 if (!fNSequence_lens.empty()) {
1278 out << SP <<
"for (size_t seq = 0; seq < " <<
seq_length <<
"; seq++) {\n";
1279 out << SP << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
1280 out << SP << SP << SP <<
"if (seq >= tensor_" << fNSequence_lens <<
"[batch]) {\n";
1282 out << SP << SP << SP << SP << SP <<
"for (size_t h = 0; h < " << fAttrHiddenSize <<
"; h++) {\n";
1283 out << SP << SP << SP << SP << SP << SP <<
"size_t idx = seq * "
1285 <<
" + batch * " << fAttrHiddenSize <<
" + h;\n";
1286 out << SP << SP << SP << SP << SP << SP <<
OpName <<
"_cell_state[idx] = 0.;\n";
1287 out << SP << SP << SP << SP << SP << SP <<
OpName <<
"_hidden_state[idx] = 0.;\n";
1288 out << SP << SP << SP << SP << SP <<
"}\n";
1290 out << SP << SP << SP <<
"}\n";
1291 out << SP << SP <<
"}\n";
1296 if (fAttrLayout == 0) {
1297 if (!fNY_h.empty()) {
1299 if (fNSequence_lens.empty()) {
1301 if (fAttrDirection ==
"backward") {
1302 out << SP <<
"std::copy(" <<
OpName <<
"_hidden_state, " <<
OpName <<
"_hidden_state + "
1303 <<
y_h_size <<
", tensor_" << fNY_h <<
");\n";
1306 out << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + " <<
offset <<
", " <<
OpName
1307 <<
"_hidden_state + " <<
offset <<
" + " <<
y_h_size <<
", tensor_" << fNY_h <<
");\n";
1311 <<
"_hidden_state + " << 2 *
y_h_size <<
", tensor_" << fNY_h <<
" + " <<
y_h_size <<
");\n";
1314 if (fAttrDirection ==
"backward") {
1315 out << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
1316 out << SP << SP <<
"size_t offset = batch * " << fAttrHiddenSize <<
";\n";
1317 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
1318 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + offset);\n";
1321 out << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
1322 out << SP << SP <<
"size_t seq = " <<
"tensor_" << fNSequence_lens <<
"[batch] - 1;\n";
1324 <<
" + batch * " << fAttrHiddenSize <<
";\n";
1325 out << SP << SP <<
"size_t y_h_offset = batch * " << fAttrHiddenSize <<
";\n";
1326 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
1327 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + y_h_offset);\n";
1331 out << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
1332 out << SP << SP <<
"size_t offset = " <<
batch_size * fAttrHiddenSize
1333 <<
" + batch * " << fAttrHiddenSize <<
";\n";
1334 out << SP << SP <<
"size_t y_h_offset = " <<
batch_size * fAttrHiddenSize
1335 <<
" + batch * " << fAttrHiddenSize <<
";\n";
1336 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
1337 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + y_h_offset);\n";
1342 if (!fNY_c.empty()) {
1344 if (fNSequence_lens.empty()) {
1346 if (fAttrDirection ==
"backward") {
1347 out << SP <<
"std::copy(" <<
OpName <<
"_cell_state, " <<
OpName <<
"_hidden_state + "
1348 <<
y_h_size <<
", tensor_" << fNY_c <<
");\n";
1352 <<
"_cell_state + " <<
offset <<
" + " <<
y_h_size <<
", tensor_" << fNY_c <<
");\n";
1356 <<
"_cell_state + " << 2 *
y_h_size <<
", tensor_" << fNY_c <<
" + " <<
y_h_size <<
");\n";
1359 if (fAttrDirection ==
"backward") {
1360 out << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
1361 out << SP << SP <<
"size_t offset = batch * " << fAttrHiddenSize <<
";\n";
1362 out << SP << SP <<
"std::copy(" <<
OpName <<
"_cell_state + offset, " <<
OpName
1363 <<
"_cell_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_c <<
" + offset);\n";
1366 out << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
1367 out << SP << SP <<
"size_t seq = " <<
"tensor_" << fNSequence_lens <<
"[batch] - 1;\n";
1369 <<
" + batch * " << fAttrHiddenSize <<
";\n";
1370 out << SP << SP <<
"size_t y_h_offset = batch * " << fAttrHiddenSize <<
";\n";
1371 out << SP << SP <<
"std::copy(" <<
OpName <<
"_cell_state + offset, " <<
OpName
1372 <<
"_cell_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_c <<
" + y_h_offset);\n";
1376 out << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
1377 out << SP << SP <<
"size_t offset = " <<
batch_size * fAttrHiddenSize
1378 <<
" + batch * " << fAttrHiddenSize <<
";\n";
1379 out << SP << SP <<
"size_t y_h_offset = " <<
batch_size * fAttrHiddenSize
1380 <<
" + batch * " << fAttrHiddenSize <<
";\n";
1381 out << SP << SP <<
"std::copy(" <<
OpName <<
"_cell_state + offset, " <<
OpName
1382 <<
"_cell_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_c <<
" + y_h_offset);\n";
1391 out << SP <<
"for (size_t seq = 0; seq < " <<
seq_length <<
"; seq++) {\n";
1392 out << SP << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
1394 <<
" + " <<
direction *
batch_size * fAttrHiddenSize <<
" + batch * " << fAttrHiddenSize <<
";\n";
1397 out << SP << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
1398 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY <<
" + y_offset);\n";
1399 out << SP << SP <<
"}\n";
1403 if (!fNY_h.empty()) {
1405 if (fAttrDirection ==
"backward") {
1406 out << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
1407 out << SP << SP <<
"size_t offset = batch * " << fAttrHiddenSize <<
";\n";
1408 out << SP << SP <<
"size_t y_h_offset = batch * " <<
num_directions * fAttrHiddenSize <<
";\n";
1409 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
1410 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + y_h_offset);\n";
1413 out << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
1414 if (fNSequence_lens.empty()) {
1415 out << SP << SP <<
"size_t seq = " <<
seq_length - 1 <<
";\n";
1417 out << SP << SP <<
"size_t seq = " <<
"tensor_" << fNSequence_lens <<
"[batch] - 1;\n";
1420 <<
" + batch * " << fAttrHiddenSize <<
";\n";
1421 out << SP << SP <<
"size_t y_h_offset = batch * " <<
num_directions * fAttrHiddenSize <<
";\n";
1422 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
1423 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + y_h_offset);\n";
1427 out << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
1428 out << SP << SP <<
"size_t offset = " <<
batch_size * fAttrHiddenSize <<
" + batch * "
1429 << fAttrHiddenSize <<
";\n";
1430 out << SP << SP <<
"size_t y_h_offset = batch * " <<
num_directions * fAttrHiddenSize <<
" + "
1431 << fAttrHiddenSize <<
";\n";
1432 out << SP << SP <<
"std::copy(" <<
OpName <<
"_hidden_state + offset, " <<
OpName
1433 <<
"_hidden_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_h <<
" + y_h_offset);\n";
1438 if (!fNY_c.empty()) {
1440 if (fAttrDirection ==
"backward") {
1441 out << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
1442 out << SP << SP <<
"size_t offset = batch * " << fAttrHiddenSize <<
";\n";
1443 out << SP << SP <<
"size_t y_h_offset = batch * " <<
num_directions * fAttrHiddenSize <<
";\n";
1444 out << SP << SP <<
"std::copy(" <<
OpName <<
"_cell_state + offset, " <<
OpName
1445 <<
"_cell_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_c <<
" + y_h_offset);\n";
1448 out << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
1449 if (fNSequence_lens.empty()) {
1450 out << SP << SP <<
"size_t seq = " <<
seq_length - 1 <<
";\n";
1452 out << SP << SP <<
"size_t seq = " <<
"tensor_" << fNSequence_lens <<
"[batch] - 1;\n";
1455 <<
" + batch * " << fAttrHiddenSize <<
";\n";
1456 out << SP << SP <<
"size_t y_h_offset = batch * " <<
num_directions * fAttrHiddenSize <<
";\n";
1457 out << SP << SP <<
"std::copy(" <<
OpName <<
"_cell_state + offset, " <<
OpName
1458 <<
"_cell_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_c <<
" + y_h_offset);\n";
1462 out << SP <<
"for (size_t batch = 0; batch < " <<
batch_size <<
"; batch++) {\n";
1463 out << SP << SP <<
"size_t offset = " <<
batch_size * fAttrHiddenSize <<
" + batch * "
1464 << fAttrHiddenSize <<
";\n";
1465 out << SP << SP <<
"size_t y_h_offset = batch * " <<
num_directions * fAttrHiddenSize <<
" + "
1466 << fAttrHiddenSize <<
";\n";
1467 out << SP << SP <<
"std::copy(" <<
OpName <<
"_cell_state + offset, " <<
OpName
1468 <<
"_cell_state + offset + " << fAttrHiddenSize <<
", tensor_" << fNY_c <<
" + y_h_offset);\n";