Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ParseEinsum.cxx
Go to the documentation of this file.
3#include "onnx_proto3.pb.h"
4
5namespace TMVA {
6namespace Experimental {
7namespace SOFIE {
8
// ONNX parser callback for the Einsum operator.
//
// Checks that every input tensor of `nodeproto` has a registered type and that
// all inputs share the type of the first input, collects the input names,
// reads the equation string from the first attribute, and returns `op`.
//
// Throws std::runtime_error when an input type is not registered, when the
// inputs have different types, when the node has no attribute, when the first
// attribute is not named "equation", or when the switch on the input type
// reaches its default branch.
//
// NOTE(review): the embedded listing numbers jump (10->12, 41->44, 50->53), so
// some statements are not visible in this listing - presumably the declaration
// of `input_type`, the supported `case` label(s) that fill `op`, and the
// statement opened before the closing brace on listing line 53. Confirm
// against the original file before relying on this text.
9ParserFuncSignature ParseEinsum = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
10
12 int input_size = nodeproto.input_size();
13 std::vector<std::string> input_names(input_size);
14 for (int i = 0; i < input_size; i++) {
// Every input must already have a type known to the parser.
15 if (!parser.IsRegisteredTensorType(nodeproto.input(i))){
16 throw std::runtime_error("TMVA::SOFIE ONNX Parser Einsum op has input tensor " + nodeproto.input(i)
17 + " but its type is not yet registered");
18 }
// The first input fixes the reference type that all other inputs must match.
19 if (i == 0)
20 input_type = parser.GetTensorType(nodeproto.input(0));
21 if (parser.GetTensorType(nodeproto.input(i)) != input_type) {
// Report the type of the offending input i; a fixed index would name the
// wrong tensor and may lie past the end of the input list.
22 throw std::runtime_error("TMVA::SOFIE ONNX parser Einsum op has input tensors of different types: " +
23 nodeproto.input(i) + " : " + ConvertTypeToString(parser.GetTensorType(nodeproto.input(i))) +
24 " and " + nodeproto.input(0) + " : " + ConvertTypeToString(input_type));
25 }
26 input_names[i] = nodeproto.input(i);
27 }
28
29 // The equation attribute must be present and must be the first attribute.
30 if (nodeproto.attribute_size() == 0)
31 throw std::runtime_error("TMVA::SOFIE ONNX Parser Einsum op has no attribute defining the equation");
32 if (nodeproto.attribute(0).name() != "equation")
33 throw std::runtime_error("TMVA::SOFIE ONNX Parser Einsum op has wrong attribute name: " + nodeproto.attribute(0).name());
34 std::string equation = nodeproto.attribute(0).s();
35
36 std::unique_ptr<ROperator> op;
37 std::string output_name = nodeproto.output(0);
38
39
40
// Dispatch on the common input type; any type without a dedicated case is rejected.
41 switch (input_type) {
44 break;
45 default:
46 throw std::runtime_error("TMVA::SOFIE - Unsupported - Einsum Operator does not yet support input type " +
47 std::to_string(static_cast<int>(input_type)));
48 }
49
50 // Infer the output type
53 }
54
55 return op;
56};
57
58
59} // namespace SOFIE
60} // namespace Experimental
61} // namespace TMVA
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
void RegisterTensorType(const std::string &, ETensorType)
ETensorType GetTensorType(const std::string &name)
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
ParserFuncSignature ParseEinsum
std::string ConvertTypeToString(ETensorType type)
create variable transformations