Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
ParseScatterElements.cxx
Go to the documentation of this file.
#include <memory>
#include <stdexcept>
#include <string>

#include "onnx_proto3.pb.h"
4
5namespace TMVA {
6namespace Experimental {
7namespace SOFIE {
8
10
11 if (nodeproto.input_size() != 3) {
12 throw std::runtime_error("TMVA::SOFIE ONNX Parser ScatterElements op has invalid input size");
13 }
14 // data is input 0
15 if (!parser.IsRegisteredTensorType(nodeproto.input(0))){
16 throw std::runtime_error("TMVA::SOFIE ONNX Parser ScatterElements op has input tensor " + nodeproto.input(0)
17 + " but its type is not yet registered");
18 }
19 if (!parser.IsRegisteredTensorType(nodeproto.input(1))){
20 throw std::runtime_error("TMVA::SOFIE ONNX Parser ScatterElements op has input tensor " + nodeproto.input(1)
21 + " but its type is not yet registered");
22 }
23 if (!parser.IsRegisteredTensorType(nodeproto.input(2))){
24 throw std::runtime_error("TMVA::SOFIE ONNX Parser ScatterElements op has input tensor " + nodeproto.input(2)
25 + " but its type is not yet registered");
26 }
28 if (parser.GetTensorType(nodeproto.input(2)) != input_type) {
29 throw std::runtime_error("TMVA::SOFIE ONNX parser ScatterElements op has input tensors of different types: " +
30 nodeproto.input(2) + " : " + ConvertTypeToString(parser.GetTensorType(nodeproto.input(2))) +
31 " and " + nodeproto.input(0) + " : " + ConvertTypeToString(input_type));
32 }
33
34 int axis = 0;
35 std::string reduction;
36 for (int i = 0; i < nodeproto.attribute_size(); i++) {
37 std::string attribute_name = nodeproto.attribute(i).name();
38 if (attribute_name == "axis")
39 axis = nodeproto.attribute(i).i();
40 else if (attribute_name == "reduction")
41 reduction = nodeproto.attribute(i).s();
42 }
43
44 std::unique_ptr<ROperator> op;
45 std::string output_name = nodeproto.output(0);
46
47 op.reset(new ROperator_ScatterElements(nodeproto.input(0), nodeproto.input(1), nodeproto.input(2),
48 output_name, axis, reduction));
49
50 // Infer the output type
53 }
54
55 return op;
56};
57
58
59} // namespace SOFIE
60} // namespace Experimental
61} // namespace TMVA
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
void RegisterTensorType(const std::string &, ETensorType)
ETensorType GetTensorType(const std::string &name)
std::function< std::unique_ptr< ROperator >(RModelParser_ONNX &, const onnx::NodeProto &)> ParserFuncSignature
std::string ConvertTypeToString(ETensorType type)
ParserFuncSignature ParseScatterElements
create variable transformations