Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ROperator_ConvTranspose.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROPERATOR_CONVTRANSPOSE_HXX
2#define TMVA_SOFIE_ROPERATOR_CONVTRANSPOSE_HXX
3
5#include <TMVA/ROperator.hxx>
6#include <TMVA/RModel.hxx>
7
8#include <memory>
9#include <sstream>
10#include <algorithm>
11#include <stdexcept>
12#include <vector>
13#include <cassert>
14
15namespace TMVA {
16namespace Experimental {
17namespace SOFIE {
18
19/*! \brief Transposed Convolution operator
20 *
21 * Inference code generation for a transposed convolution layer.
22 * See the <a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#convtranspose">ONNX documentation</a> for
23 * details about the transposed conv layer.
 *
 * \tparam T element type of the tensors; only float is accepted (the attribute
 *           constructor throws std::runtime_error for any other type).
 *
 * The generated inference code requires the BLAS routines Gemm and Axpy
 * (as reported by GetBlasRoutines()).
24 */
25template <typename T>
// NOTE(review): listing line 26 (the class head, presumably deriving from ROperator since the
// members use `override` and fInputTensorNames) is missing from this extraction — confirm.
27private:
28 std::string fAttrAutopad;
29 std::vector<size_t> fAttrDilations;
30 size_t fAttrGroup;
31 std::vector<size_t> fAttrKernelShape;
32 std::vector<size_t> fAttrOutputPadding;
33 std::vector<size_t> fAttrOutputShape;
34 std::vector<size_t> fAttrPads;
35 std::vector<size_t> fAttrStrides;
36
37 std::string fNX;
38 std::string fNW;
39 std::string fNB;
40 std::string fNBroadcastedB;
41 std::string fNY;
42
43 std::string fConvK;
44 std::string fImcol;
45
46 std::vector<size_t> fShapeX;
47 std::vector<size_t> fShapeW;
48 std::vector<size_t> fShapeB;
49 std::vector<size_t> fShapeY;
50
51 std::string fType;
52
53 size_t fDim; // dimension of the convolution
54
55public:
56 /*! Default constructor of ROperator_ConvTranspose */
// NOTE(review): listing line 57 (the default constructor itself) is missing from this extraction.
58
59 /*! \brief Constructor of ROperator_ConvTranspose from the attributes
60 *
61 * \param autopad auto_pad attribute (padding mode, see the ONNX ConvTranspose documentation)
62 * \param dilations dilations of the kernel
63 * \param group number of groups
64 * \param kernelShape shape of the kernel
65 * \param outputPadding padding of the output
66 * \param outputShape shape of the output
67 * \param pads padding of the input
68 * \param strides strides
69 * \param nameX name of the input tensor X (cleaned with UTILITY::Clean_name)
70 * \param nameW name of the weight tensor W (cleaned with UTILITY::Clean_name)
71 * \param nameB name of the optional bias tensor B (cleaned with UTILITY::Clean_name);
 *              if the cleaned name is empty, no bias input is registered
72 * \param nameY name of the output tensor Y (cleaned with UTILITY::Clean_name)
 *
 * \throws std::runtime_error if T is not float
73 */
74 ROperator_ConvTranspose(std::string autopad, std::vector<size_t> dilations, size_t group,
75 std::vector<size_t> kernelShape, std::vector<size_t> outputPadding,
76 std::vector<size_t> outputShape, std::vector<size_t> pads, std::vector<size_t> strides,
77 std::string nameX, std::string nameW, std::string nameB, std::string nameY)
// NOTE(review): listing lines 78-79 (start of the member-initializer list) are missing from this extraction.
80 fNX(UTILITY::Clean_name(nameX)), fNW(UTILITY::Clean_name(nameW)), fNB(UTILITY::Clean_name(nameB)),
81 fNY(UTILITY::Clean_name(nameY))
82 {
// NOTE(review): listing lines 83-84 are missing from this extraction; presumably they register
// X/W in fInputTensorNames and Y in fOutputTensorNames — confirm.
// The bias is optional: it is appended as an extra input only when its cleaned name is non-empty.
85 if (!fNB.empty()) {
86 fInputTensorNames.emplace_back(fNB);
87 }
88
// Only T = float is supported; any other element type makes construction throw.
// NOTE(review): the message below says "Conv operator" although this is ConvTranspose — consider updating it.
// NOTE(review): std::is_same needs <type_traits>, which is not among the visible includes
// (it may arrive transitively) — confirm.
89 if (std::is_same<T, float>::value) {
90 fType = "float";
91 } else {
92 throw std::runtime_error("TMVA SOFIE Encountered unsupported type parsing a Conv operator");
93 }
94 }
95
96 /*! \brief Infers the type of the output tensor
97 * \param input type of the input tensors
98 */
99 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override
100 {
101 ETensorType out = input[0];
102 return {out};
103 }
104
105 /*! \brief Infers the shape of the output tensor from the shapes of the input tensors
106 * \param input shapes of the input tensors (unnamed in this declaration)
 * \return presumably the shape(s) of the output tensor(s) — confirm against the implementation,
 *         which is not part of this header listing
107 */
108 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> /*input*/) override;
109
110 /*! \brief Initializes this operator against the model that contains it
111 * \param model the RModel the operator belongs to (unnamed in this declaration);
 *              what is read from or registered in it is not visible here — see the implementation
112 */
113 void Initialize(RModel &) override;
114
115 /*! \brief Generate code for initializing the op
 * \return the generated initialization code as a string
116 */
117 std::string GenerateInitCode() override;
118
119 /*! \brief Generate the inference code
120 * \param opName name of the operator
 * \return the generated inference code as a string
121 */
122 std::string Generate(std::string opName) override;
123
124 /*! \brief Returns the blas routines needed to compile the generated code
125 */
126 std::vector<std::string> GetBlasRoutines() override { return { std::string("Gemm"), std::string("Axpy") }; }
127};
128
129} // namespace SOFIE
130} // namespace Experimental
131} // namespace TMVA
132
133// Implementation of the ROperator_ConvTranspose class
135
136#endif
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
std::vector< std::string > GetBlasRoutines() override
Returns the blas routines needed to compile the generated code.
ROperator_ConvTranspose(std::string autopad, std::vector< size_t > dilations, size_t group, std::vector< size_t > kernelShape, std::vector< size_t > outputPadding, std::vector< size_t > outputShape, std::vector< size_t > pads, std::vector< size_t > strides, std::string nameX, std::string nameW, std::string nameB, std::string nameY)
Constructor of ROperator_ConvTranspose from the attributes.
void Initialize(RModel &) override
Initialize the model.
ROperator_ConvTranspose()
Default constructor of ROperator_ConvTranspose.
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input) override
Infers the type of the output tensor.
std::string GenerateInitCode() override
Generate code for initializing the op.
std::string Generate(std::string opName) override
Generate the inference code.
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > >) override
Infers the shape of the output tensor from the shapes of the input tensors.
std::vector< std::string_view > fInputTensorNames
Definition ROperator.hxx:46
std::vector< std::string_view > fOutputTensorNames
Definition ROperator.hxx:47
create variable transformations