Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ROperator_Einsum.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_ROperator_Einsum
2#define TMVA_SOFIE_ROperator_Einsum
3
5#include "TMVA/ROperator.hxx"
6#include "TMVA/RModel.hxx"
7
8#include <sstream>
9#include <cassert>
10
11namespace TMVA{
12namespace Experimental{
13namespace SOFIE{
14
15
16
17template<typename T>
// NOTE(review): the class declaration line (doxygen line 18, presumably
// `class ROperator_Einsum final : public ROperator {`) is missing from this
// extraction -- confirm against the upstream header.
19private:
20
// Flag kept for parity with other SOFIE operators; not used in the visible code.
21 bool fIsInputBoolTensor = false;
22
23
// Names of the input tensors and the (cleaned) output tensor name.
24 std::vector<std::string> fNInputs;
25 std::string fNY;
26
// Index labels parsed from the einsum equation: one label string per input,
// plus the labels of the output tensor.
27 std::vector<std::string> fInputLabels;
28 std::string fOutputLabels;
29 std::string fSumLabels; // string containing the reducing labels
// "nn"/"nt"/"tn"/"tt" when the contraction maps onto a single GEMM call,
// empty string otherwise (decided in Initialize()).
30 std::string fGemmType;
31
32 std::vector<int> fSumDims; // dimension of the labels we use to perform summing
33
// Shapes of the inputs and of the output, filled in Initialize().
34 std::vector<std::vector<size_t>> fShapeInputs;
35 std::vector<size_t> fShapeY;
36
37
38
39
40public:
// Construct the Einsum operator from an einsum equation string (e.g. "ij,jk->ik"),
// the list of input tensor names and the output tensor name.
// Throws std::runtime_error if the equation cannot be parsed.
42 ROperator_Einsum(const std::string & equation, const std::vector<std::string> & namesX, const std::string & nameY):
43 fNInputs(namesX.size()), fNY(UTILITY::Clean_name(nameY))
44 {
45 for (size_t i = 0; i < namesX.size(); i++)
// NOTE(review): doxygen line 46 (the loop body, presumably
// `fNInputs[i] = UTILITY::Clean_name(namesX[i]);`) is missing from this
// extraction -- confirm against the upstream header.
47
// parse the equation to find the index labels
// NOTE(review): doxygen line 49 (presumably `if (!ParseEquation(equation))`)
// is missing here; the throw below is its failure branch.
50 throw std::runtime_error("TMVA SOFIE Einsum Op: Error parsing the equation " + equation);
51
// Expose the input names as string_views for the base-class bookkeeping.
52 fInputTensorNames.resize(fNInputs.size());
53 std::transform(fNInputs.begin(), fNInputs.end(), fInputTensorNames.begin(),
54 [](const std::string& s) -> std::string_view { return s; });
56 }
57
// Parse an einsum equation of the form "ab,bc->ac" (blank characters allowed)
// into per-input label strings (fInputLabels, split on ',').
// Returns false, after printing a diagnostic to std::cout, when the equation
// is malformed; returns true on success.
// NOTE(review): doxygen line 104 is missing from this extraction -- it presumably
// assigns `fOutputLabels = outputStr;`, since fOutputLabels is read elsewhere but
// never set in the visible code. Confirm against the upstream header.
58 bool ParseEquation(const std::string & input_equation) {
59 std::string eq (input_equation);
60 // remove blank spaces
61 eq.erase(std::remove(eq.begin(), eq.end(), ' '), eq.end());
62 // look for '->' finding the first occurrence
63 std::string target("->");
64 size_t pos = eq.find(target);
65 if (pos == std::string::npos) {
66 std::cout << "'->' not found in the equation." << std::endl;
67 return false;
68 }
69 // Substring before the target
70 std::string inputStr = eq.substr(0, pos);
71 // Substring after the target
72 std::string outputStr = eq.substr(pos + target.length());
73
74 // look now for the group of labels separated by "," in the inputs
75 size_t start = 0;
76 size_t pos1 = 0;
77 // Extract labels separated by commas
78 while ((pos1 = inputStr.find(',', start)) != std::string::npos) {
79 std::string labels = inputStr.substr(start, pos1 - start);
80 fInputLabels.push_back(labels);
81 start = pos1 + 1; // Move past the comma
82 }
83 // Add the last label (after the final comma)
84 fInputLabels.push_back(inputStr.substr(start));
85
// check that the labels contain only alphanumeric characters
// NOTE(review): passing a plain (possibly negative) char to std::isalnum is
// undefined behavior; the argument should be cast to unsigned char.
87 auto checkLabel = [](const std::string & label) {
88 for (char c : label) {
89 if (!std::isalnum(c)) {
90 std::cout << "Wrong tensor label " << label << std::endl;
91 return false;
92 }
93 }
94 // empty label is OK , is a scalar
95 return true;
96 };
97 for (auto & label : fInputLabels) {
98 if (!checkLabel(label)) return false;
99 }
100 if (!checkLabel(outputStr)) {
101 std::cout << "invalid output label" << std::endl;
102 return false;
103 }
105
// one label group per input tensor is required
106 if (fInputLabels.size() != fNInputs.size()) {
107 std::cout << "Invalid number of input labels found " << fInputLabels.size() << " for #inputs = " << fNInputs.size() << std::endl;
108 return false;
109 }
110 // ignore for the time being broadcasting, empty output label and other features
111 return true;
112 }
113
114 // type of output given input
115 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
116 return input;
117 }
118
119 // shape of output tensors given input tensors
120 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
121 // assume now inputs have same shape (no broadcasting)
122 auto ret = std::vector<std::vector<size_t>>(1, input[0]); // return vector size 1 with first input
123 return ret;
124 }
125
// Validate the input tensors against the model, record their shapes, derive the
// output shape (fShapeY) from the equation labels, collect the summation labels
// and dims (fSumLabels/fSumDims), and decide whether the contraction can be
// lowered to a single GEMM call (fGemmType).
126 void Initialize(RModel& model) override {
127 // input must be a graph input, or already initialized intermediate tensor
128 size_t i = 0;
129 std::map<char, int> labelsMap;
130 for ( auto & name : fNInputs) {
// NOTE(review): doxygen line 131 (presumably
// `if (!model.CheckIfTensorAlreadyExist(name))`) is missing from this
// extraction; the throw below is its failure branch. Confirm upstream.
132 throw std::runtime_error(std::string("TMVA SOFIE Einsum Op Input Tensor ") + name + "is not found in model");
133
134 // if (model.IsDynamicTensor(name) || model.IsDimInputTensor(name) ) {
135 // // not yet supported
136 // } else {
137 auto shape = model.GetTensorShape(name);
138 fShapeInputs.push_back(shape);
139 //}
140 // fill the label maps: each equation label maps to its tensor dimension
141 std::string labels = fInputLabels[i];
142 for (size_t j = 0; j < shape.size(); j++) {
143 if (j >= labels.length()) {
144 throw std::runtime_error(std::string("TMVA SOFIE Einsum Op Input Tensor has invalid label or shape ") + labels + " " + ConvertShapeToString(shape));
145 }
146 labelsMap[labels[j]] = shape[j];
147 }
148 i++;
149 }
150 // get output shape from label maps
151 for (char l : fOutputLabels) {
152 if (labelsMap.count(l) == 0)
// NOTE(review): std::string(&l) treats &l as a NUL-terminated C string, reading
// past the single char -- undefined behavior; should be std::string(1, l).
153 throw std::runtime_error(std::string("TMVA SOFIE Einsum Op : output label ") + std::string(&l) + " is not present in inputs")
154 fShapeY.push_back(labelsMap[l]);
155 }
156 // we need to get the labels we are going to sum
157 // these are the labels not present in the output
158 fSumLabels = "";
159 fSumDims.clear();
160 for (auto & l : labelsMap) {
161 if (fOutputLabels.find(l.first) == std::string::npos) {
162 fSumLabels += l.first;
163 fSumDims.push_back(l.second);
164 }
165 }
166
167 // check if we can use MatMul for EinSum
168 // need to have one sum labels in the last 2 and have the first in common
169 if (fNInputs.size() == 2 && fSumDims.size() == 1 && fShapeInputs[0].size() >=2 && fShapeInputs[1].size() >= 2 ) {
170 // find positions of sum labels
171 char l = fSumLabels[0];
172 size_t pos1 = fInputLabels[0].find(l);
173 size_t pos2 = fInputLabels[1].find(l);
174 // check if summing is done in the last 2 indices of tensor
175
176 if (pos1 == fInputLabels[0].length() - 1 && pos2 == fInputLabels[1].length() - 2)
177 fGemmType = "nn";
178 else if (pos1 == fInputLabels[0].length() - 2 && pos2 == fInputLabels[1].length() - 2)
179 fGemmType = "tn";
180 else if (pos1 == fInputLabels[0].length() - 1 && pos2 == fInputLabels[1].length() - 1)
181 fGemmType = "nt";
182 else if (pos1 == fInputLabels[0].length() - 2 && pos2 == fInputLabels[1].length() - 1)
183 fGemmType = "tt";
184 else
185 fGemmType = "";
186 }
187
// NOTE(review): doxygen line 188 is missing from this extraction -- presumably
// the output registration, e.g. model.AddIntermediateTensor(fNY, type, fShapeY)
// (AddIntermediateTensor appears in this header's symbol index). Confirm upstream.
189
190 if (model.Verbose()) {
191 std::cout << "Einsum op ";
192 for (i = 0; i < fNInputs.size(); i++) {
193 if (i > 0) std::cout << ", ";
194 std::cout << fNInputs[i] << " " << ConvertShapeToString(fShapeInputs[i]) << " " << fInputLabels[i];
195 }
196 std::cout << " --> " << fNY << " " << ConvertShapeToString(fShapeY) << " " << fOutputLabels << std::endl;
197 }
198
199 }
200
201 std::string GenerateInitCode() override {
202 std::stringstream out;
203 return out.str();
204 }
205
// Emit the C++ inference code for this Einsum node.
// Two code paths: a generic set of nested loops over output and summation
// indices, or -- when Initialize() set fGemmType -- a single BLAS sgemm call
// (optionally wrapped in loops over the leading "stacked" output dimensions).
// Throws if called before Initialize() has filled fShapeY.
206 std::string Generate(std::string opName) override {
207
208 if (fIsOutputConstant) return "";
209
210 opName = "op_" + opName;
211
212 if (fShapeY.size() != fOutputLabels.length()) {
213 throw std::runtime_error("TMVA SOFIE Einsum Op called to Generate without being initialized first");
214 }
215
216 // function to write compute expression index from strides
// Builds the textual index expression "s0*a + s1*b + ... + z" for a tensor with
// the given strides, using the equation labels as loop-variable names.
217 auto tensorIndex = [](const std::vector<size_t> & stride, const std::string & labels) {
218 std::stringstream strst;
219 int dims = labels.length();
220 // scalar case
221 if (dims == 0) return std::string("0");
222 assert (dims == (int) stride.size());
223 for (int i = 0; i < dims-1; i++) {
224 strst << stride[i] << "*" << std::string{labels[i]} << " + ";
225 }
226 strst << std::string{labels[dims-1]};
227 return strst.str();
228 };
229
230 std::stringstream out;
231 out << SP << "\n//-------- Einsum \n";
// NOTE(review): doxygen lines 232-233 are missing from this extraction -- they
// presumably compute the strides and the `outputIndex`/`inputIndex` expressions
// used below (likely via ComputeStrideFromShape and the tensorIndex lambda,
// both visible in this header's symbol index). Confirm upstream.
234
235 // loops on the output indices i0,....iN
236 if (fGemmType.empty()) {
237 int outDims = fShapeY.size();
238 int inDims = fSumLabels.length();
239 assert(outDims == int(fOutputLabels.size()));
240 assert(inDims == int(fSumDims.size()));
241 for (int i = 0; i < outDims; i++) {
242 for (int j = 0; j < i; j++) out << SP;
243 std::string l {fOutputLabels[i]};
244 out << "for (int " << l << " = 0; " << l << " < " << fShapeY[i] << "; " << l << "++) {\n";
245 }
246 // reset to zero output tensor
// NOTE(review): doxygen line 247 (presumably the definition of `outputIndex`)
// is missing from this extraction. Confirm upstream.
248
249 for (int j = 0; j < outDims; j++) out << SP;
250 out << "tensor_" << fNY << "[" << outputIndex << "] = 0;\n";
251 // loop on remaining indices where we perform the sum
252 for (int i = 0; i < inDims; i++) {
253 for (int j = 0; j < outDims + i; j++) out << SP;
254 std::string l {fSumLabels[i]};
255 out << "for (int " << l << " = 0; " << l << " < " << fSumDims[i] << "; " << l << "++) {\n";
256 }
257 for (int j = 0; j < outDims+inDims; j++) out << SP;
258 // tensor_out[outId] += t_in_0[ind0] * t_in1[ind1] *....
259 out << "tensor_" << fNY << "[" << outputIndex << "] +=\n";
260 for (size_t k = 0; k < fNInputs.size(); k++) {
// NOTE(review): doxygen lines 261-262 (presumably the per-input definition of
// `inputIndex`, built from the input strides and fInputLabels[k]) are missing
// from this extraction. Confirm upstream.
263 for (int j = 0; j < outDims+inDims; j++) out << SP;
264 out << SP << "tensor_" << fNInputs[k] << "[" << inputIndex << "]";
265 if (fNInputs.size() > 1 && k < fNInputs.size() -1) out << " *\n";
266 }
267 out << ";\n";
268
269 // end loops on all indices i0,....iN
270 for (int i = outDims+inDims-1; i >= 0; i--) {
271 for (int j = 0; j < i; j++) out << SP;
272 out << "}\n";
273 }
274
275
276 } else {
277 // case we use Gemm
278 out << SP << "// implementing Einsum using MatMul \n";
279 // note A is second input and B first one - due to transpose of Fortran rep.
280 out << SP << "char " << opName << "_transA = '" << fGemmType[0] << "';\n";
281 out << SP << "char " << opName << "_transB = '" << fGemmType[1] << "';\n";
282 // need to consider case A and B have dim > 2 (for MatMul)
283 int64_t dimA = fShapeInputs[0].size();
284 int64_t dimB = fShapeInputs[1].size();
285
// m, n, k of the GEMM, picked from the last two dims of each input according
// to whether that operand is transposed.
286 auto m = (fGemmType[0] == 't') ? fShapeInputs[0][dimA-1] : fShapeInputs[0][dimA-2];
287 auto n = (fGemmType[1] == 't') ? fShapeInputs[1][dimB-2] : fShapeInputs[1][dimB-1];
288 auto k = (fGemmType[0] == 't') ? fShapeInputs[0][dimA-2] : fShapeInputs[0][dimA-1];
289
290 out << SP << "int " << opName << "_m = " << m << ";\n";
291 out << SP << "int " << opName << "_n = " << n << ";\n";
292 out << SP << "int " << opName << "_k = " << k << ";\n";
293 out << SP << "float " << opName << "_alpha = 1.0;\n";
294 out << SP << "float " << opName << "_beta = 0.0;\n";
295 out << SP << "int " << opName << "_lda = " << ((fGemmType[0] == 't') ? m : k) << ";\n";
296 out << SP << "int " << opName << "_ldb = " << ((fGemmType[1] == 't') ? k : n) << ";\n";
297
// NOTE(review): doxygen lines 298-299 are missing from this extraction -- they
// presumably define `inputStrideA`, `inputStrideB` and `outputStride` used in
// the tensorOffset calls below. Confirm upstream.
300
// loop over the leading (stacked) output dimensions, one GEMM per slice
301 int stackDims = fShapeY.size()-2;
302 for (int i = 0; i < stackDims; i++) {
303 for (int j = 0; j < i; j++) out << SP;
304 std::string l {fOutputLabels[i]};
305 out << "for (int " << l << " = 0; " << l << " < " << fShapeY[i] << "; " << l << "++) {\n";
306 }
// Like tensorIndex, but skips the last two (matrix) dims: computes the element
// offset of the current 2D slice within the stacked tensor.
307 auto tensorOffset = [](const std::vector<size_t> & stride, const std::string & labels) {
308 std::stringstream strst;
309 int dims = labels.length()-2;
310 // scalar case
311 if (dims == 0) return std::string("0");
312 assert (dims +2 == (int) stride.size());
313 for (int i = 0; i < dims; i++) {
314 strst << stride[i] << "*" << std::string{labels[i]};
315 if (i < dims-1) strst << " + ";
316 }
317 return strst.str();
318 };
319 // only float type supported
320 out << SP << "BLAS::sgemm_(&" << opName << "_transB, &" << opName << "_transA, &" << opName
321 << "_n, &" << opName << "_m, &" << opName << "_k, &" << opName << "_alpha, "
322 << "&tensor_" << fNInputs[1] << "[" << tensorOffset(inputStrideB, fInputLabels[1])
323 << "], &" << opName << "_ldb, "
324 << "&tensor_" << fNInputs[0] << "[" << tensorOffset(inputStrideA, fInputLabels[0] ) << "], &" << opName << "_lda, &" << opName << "_beta, "
325 << "&tensor_" << fNY << "[" << tensorOffset(outputStride,fOutputLabels) << "], &" << opName << "_n);\n";
326
327
328 for (int i = stackDims-1; i >= 0; i--) {
329 for (int j = 0; j < i; j++) out << SP;
330 out << "}\n";
331 }
332
333 }
334
335
336 return out.str();
337 }
338
339 std::vector<std::string> GetBlasRoutines() override {
340 return { std::string("Gemm") };
341 }
342};
343
344}//SOFIE
345}//Experimental
346}//TMVA
347
348
349#endif //TMVA_SOFIE_ROperator_Einsum
#define c(i)
Definition RSha256.hxx:101
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t target
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h length
char name[80]
Definition TGX11.cxx:110
const_iterator begin() const
const_iterator end() const
void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector< Dim > dim_shape)
Definition RModel.cxx:200
bool CheckIfTensorAlreadyExist(std::string tensor_name)
Definition RModel.cxx:95
const ETensorType & GetTensorType(std::string name) const
Definition RModel.cxx:67
const std::vector< size_t > & GetTensorShape(std::string name) const
Definition RModel.cxx:29
std::vector< std::vector< size_t > > fShapeInputs
bool ParseEquation(const std::string &input_equation)
std::string Generate(std::string opName) override
std::vector< std::vector< size_t > > ShapeInference(std::vector< std::vector< size_t > > input) override
ROperator_Einsum(const std::string &equation, const std::vector< std::string > &namesX, const std::string &nameY)
std::vector< std::string > GetBlasRoutines() override
std::vector< ETensorType > TypeInference(std::vector< ETensorType > input) override
std::vector< std::string_view > fInputTensorNames
Definition ROperator.hxx:46
bool fIsOutputConstant
flag to identify if operator has a constant output (no need to generate code)
Definition ROperator.hxx:44
const std::string SP
space used to correctly indent the generated C++ code
Definition ROperator.hxx:42
std::vector< std::string_view > fOutputTensorNames
Definition ROperator.hxx:47
const Int_t n
Definition legend1.C:16
std::string Clean_name(std::string input_tensor_name)
std::vector< size_t > ComputeStrideFromShape(const std::vector< size_t > &shape)
compute stride of a tensor given its shape (assume layout is row-major)
std::string ConvertShapeToString(std::vector< size_t > shape)
create variable transformations
TMarker m
Definition textangle.C:8
TLine l
Definition textangle.C:4