Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
TMVA_SOFIE_ONNX.py
Go to the documentation of this file.
## \file
## \ingroup tutorial_ml
## \notebook -nodraw
## This macro provides a simple example of:
## - creating a model with PyTorch and exporting it to ONNX
## - parsing the ONNX file with SOFIE and generating C++ code
## - compiling the model using ROOT Cling
## - running the code and optionally comparing the result with ONNXRuntime
##
##
## \macro_code
## \macro_output
## \author Lorenzo Moneta
14
15
16import contextlib
17import inspect
18import warnings
19
20import numpy as np
21import ROOT
22import torch
23import torch.nn as nn
24
25
@contextlib.contextmanager
def expect_warning(category, message):
    """Silence a known third-party warning and raise if it stops firing.

    Notifies us to drop the workaround once the upstream library is fixed.

    Parameters
    ----------
    category : type[Warning]
        Warning class (or base class) to silence.
    message : str
        Substring that must appear in the text of the silenced warning.

    Raises
    ------
    RuntimeError
        If the wrapped block completes without emitting a matching warning.

    Any other warning raised inside the block is re-emitted after the block,
    so it goes through the caller's normal warning filters instead of being
    swallowed. If the block itself raises, that exception propagates and the
    "was it emitted" check is skipped.
    """
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning (even ones normally shown once or ignored) so
        # the expected one cannot be missed.
        warnings.simplefilter("always")
        yield
    seen = False
    for w in caught:
        if issubclass(w.category, category) and message in str(w.message):
            seen = True
        else:
            # Not the warning we expect: hand it back to the warnings
            # machinery (now outside the "always" filter) so it is not lost.
            warnings.warn_explicit(w.message, w.category, w.filename, w.lineno, source=w.source)
    if not seen:
        raise RuntimeError(
            f"Expected {category.__name__} containing {message!r} was not "
            "emitted. This tutorial's workaround can probably be removed."
        )
46
47
def CreateAndTrainModel(modelName):
    """Build a small PyTorch MLP, fit it on random data and export it to ONNX.

    Parameters
    ----------
    modelName : str
        Base name of the model; the ONNX file is written to
        ``modelName + ".onnx"``.

    Returns
    -------
    str
        Path of the exported ONNX file.

    If ``torch.onnx.export`` raises ``TypeError`` (export not supported by the
    installed PyTorch version), a message is printed and the script exits
    with status 0 via ``SystemExit``.
    """
    model = nn.Sequential(
        nn.Linear(32, 16),
        nn.ReLU(),
        nn.Linear(16, 8),
        nn.ReLU(),
        nn.Linear(8, 2),
        nn.Softmax(dim=1),
    )

    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # train model with the random data
    for _ in range(500):
        x = torch.randn(2, 32)
        y = torch.randn(2, 2)
        y_pred = model(x)
        loss = criterion(y_pred, y)
        # Without backpropagation and an optimizer step the weights would
        # never change and the loop would only evaluate the loss.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # *******************************************************
    ## EXPORT to ONNX
    #
    # need to evaluate the model before exporting to ONNX
    # and to provide a dummy input tensor to set the input model shape
    model.eval()

    modelFile = modelName + ".onnx"
    dummy_x = torch.randn(1, 32)
    model(dummy_x)

    # check for torch.onnx.export parameters: keep only the keyword
    # arguments that the installed PyTorch version actually accepts
    def filtered_kwargs(func, **candidate_kwargs):
        sig = inspect.signature(func)
        return {k: v for k, v in candidate_kwargs.items() if k in sig.parameters}

    kwargs = filtered_kwargs(
        torch.onnx.export,
        input_names=["input"],
        output_names=["output"],
        external_data=False,  # may not exist
        dynamo=True,  # may not exist
    )
    print("calling torch.onnx.export with parameters", kwargs)

    # torch.onnx.export (dynamo path) pickles its export program through
    # copyreg, which still references the deprecated LeafSpec. The warning
    # is emitted from inside PyTorch and cannot be avoided from user code.
    # It only comes from the dynamo path, so do not expect it when this
    # PyTorch version has no ``dynamo`` parameter (legacy exporter).
    if "dynamo" in kwargs:
        warning_guard = expect_warning(FutureWarning, "isinstance(treespec, LeafSpec)")
    else:
        warning_guard = contextlib.nullcontext()

    try:
        with warning_guard:
            torch.onnx.export(model, dummy_x, modelFile, **kwargs)
    except TypeError:
        print("Cannot export model from pytorch to ONNX - with version ", torch.__version__)
        print("Skip tutorial execution")
        raise SystemExit(0)

    print("model exported to ONNX as", modelFile)
    return modelFile
102
103
def ParseModel(modelFile, verbose=False):
    """Parse an ONNX file with SOFIE and generate the C++ inference code.

    Parameters
    ----------
    modelFile : str
        Path of the ONNX file to parse.
    verbose : bool, optional
        If True, the parser runs verbosely, two weight tensors are printed
        and the generated code is printed.

    Returns
    -------
    str
        Name of the generated header file: ``modelFile`` with ``.onnx``
        replaced by ``.hxx``.
    """
    parser = ROOT.TMVA.Experimental.SOFIE.RModelParser_ONNX()
    model = parser.Parse(modelFile, verbose)
    #
    # print model weights
    # NOTE(review): initializer names depend on the ONNX exporter used;
    # confirm "0weight"/"2weight" exist in the exported model.
    if verbose:
        data = model.GetTensorData["float"]("0weight")
        print("0weight", data)
        data = model.GetTensorData["float"]("2weight")
        print("2weight", data)

    # Generating inference code
    model.Generate()
    # generate header file (and .dat file) with modelName+.hxx
    model.OutputGenerated()
    if verbose:
        model.PrintGenerated()

    # NOTE(review): assumes OutputGenerated() writes the header under the
    # same name as the ONNX file with a .hxx suffix (true for a bare file
    # name such as "LinearModel.onnx") -- confirm for paths with directories.
    modelCode = modelFile.replace(".onnx", ".hxx")
    print("Generated model header file ", modelCode)
    return modelCode
127
128
###################################################################
## Step 1 : Create and Train model
###################################################################

# use an arbitrary modelName
modelName = "LinearModel"
modelFile = CreateAndTrainModel(modelName)


###################################################################
## Step 2 : Parse model and generate inference code with SOFIE
###################################################################

modelCode = ParseModel(modelFile, verbose=False)

###################################################################
## Step 3 : Compile the generated C++ model code
###################################################################

# Declare returns False if Cling cannot compile the header; stop here with a
# clear message rather than failing later on a missing namespace.
if not ROOT.gInterpreter.Declare('#include "' + modelCode + '"'):
    raise RuntimeError("Failed to compile the generated SOFIE code " + modelCode)

###################################################################
## Step 4: Evaluate the model
###################################################################

# get first the SOFIE session namespace
sofie = getattr(ROOT, "TMVA_SOFIE_" + modelName)
session = sofie.Session()

x = np.random.normal(0, 1, (1, 32)).astype(np.float32)
print("\n************************************************************")
print("Running inference with SOFIE ")
print("\ninput to model is ", x)
y = session.infer(x)
# output shape is (1,2)
y_sofie = np.asarray(y.data())
print("-> output using SOFIE = ", y_sofie)

# check inference with onnx (only if ONNXRuntime is installed)
try:
    import onnxruntime as ort
except ImportError:
    print("Missing ONNXRuntime: skipping comparison test")
else:
    # Load model
    print("Running inference with ONNXRuntime ")
    ort_session = ort.InferenceSession(modelFile)

    # Run inference
    outputs = ort_session.run(None, {"input": x})
    y_ort = outputs[0]
    print("-> output using ORT =", y_ort)

    # Compare flattened values (SOFIE's buffer is flat, ORT's output is
    # (1, 2)). np.allclose rejects NaN, which an element-wise
    # `abs(diff) > tol` test would let through.
    same_size = y_sofie.size == y_ort.size
    if not (same_size and np.allclose(y_sofie.ravel(), y_ort.ravel(), rtol=0, atol=0.01)):
        raise RuntimeError("Result is different between SOFIE and ONNXRT")
    else:
        print("OK")
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.