Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
TMVA_SOFIE_Keras.py
Go to the documentation of this file.
1### \file
2### \ingroup tutorial_ml
3### \notebook -nodraw
### This macro provides a simple example of parsing a Keras .keras file
### into an RModel object and then generating the .hxx header file for inference.
6###
7### \macro_code
8### \macro_output
9### \author Sanjiban Sengupta and Lorenzo Moneta
10
11
12import contextlib
13import warnings
14
15import numpy as np
16import ROOT
17from tensorflow.keras.layers import Activation, Dense, Input, Softmax
18from tensorflow.keras.models import Model
19
# Enable ROOT in batch mode (same effect as -nodraw)
ROOT.gROOT.SetBatch(True)

23
@contextlib.contextmanager
def expect_warning(category, message):
    """Silence a known third-party warning and raise if it stops firing.

    Every warning emitted inside the ``with`` body is recorded. Warnings whose
    category is a subclass of ``category`` and whose text contains ``message``
    are swallowed; all other recorded warnings are re-emitted after the body
    has finished, so unrelated diagnostics are not lost.

    Notifies us to drop the workaround once the upstream library is fixed.

    Args:
        category: Warning class (or tuple of classes) of the expected warning.
        message: Substring that must occur in the expected warning's text.

    Raises:
        RuntimeError: If the body completed without emitting a matching
            warning. If the body itself raises, that exception propagates and
            this check is skipped.
    """
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning, including ones the active filters would hide.
        warnings.simplefilter("always")
        yield

    seen = False
    for w in caught:
        if issubclass(w.category, category) and message in str(w.message):
            seen = True
        else:
            # Recording captured *all* warnings; hand the unrelated ones back
            # to the restored filters, keeping their original location.
            warnings.warn_explicit(w.message, w.category, w.filename, w.lineno)

    if not seen:
        raise RuntimeError(
            f"Expected {category.__name__} containing {message!r} was not "
            "emitted. This tutorial's workaround can probably be removed."
        )
44
45
# -----------------------------------------------------------------------------
# Step 1: Create and train a simple Keras model (via embedded Python)
# -----------------------------------------------------------------------------

# Small dense network: 4 inputs -> 32 -> 16 -> 8 -> 2 -> softmax, built with a
# fixed batch size of 2.
input_layer = Input(shape=(4,), batch_size=2)
x = Dense(32)(input_layer)
x = Activation("relu")(x)
x = Dense(16, activation="relu")(x)
x = Dense(8, activation="relu")(x)
x = Dense(2)(x)
output = Softmax()(x)
model = Model(inputs=input_layer, outputs=output)

# Fixed seed so the random training data is reproducible between runs.
randomGenerator = np.random.RandomState(0)
x_train = randomGenerator.rand(4, 4)
y_train = randomGenerator.rand(4, 2)

model.compile(loss="mse", optimizer="adam")
model.fit(x_train, y_train, epochs=3, batch_size=2)

# Keras' internal ``np.array(x)`` (TensorFlow backend) does not yet implement
# the NumPy 2.0 ``__array__(copy=...)`` signature, so saving the model emits a
# DeprecationWarning that we cannot fix from user code.
numpy_version = tuple(int(part) for part in np.__version__.split(".")[:2])
if numpy_version >= (2, 0):
    ctx = expect_warning(DeprecationWarning, "__array__ implementation doesn't accept a copy keyword")
else:
    # No such warning is expected before NumPy 2.0, so use a no-op context.
    ctx = contextlib.nullcontext()

with ctx:
    model.save("KerasModel.keras")

78
# -----------------------------------------------------------------------------
# Step 2: Use TMVA::SOFIE to parse the Keras model
# -----------------------------------------------------------------------------

# Parse the saved Keras model; from here on ``model`` is the SOFIE model
# returned by the parser, no longer the Keras one.

model = ROOT.TMVA.Experimental.SOFIE.PyKeras.Parse("KerasModel.keras")

# Generate inference code and write it out, so that the header included below
# exists on disk.
model.Generate()
model.OutputGenerated()

# print generated code
print("\n**************************************************")
print(" Generated code")
print("**************************************************\n")
model.PrintGenerated()
print("**************************************************\n\n\n")

# Compile the generated code
ROOT.gInterpreter.Declare('#include "KerasModel.hxx"')
99
100
# -----------------------------------------------------------------------------
# Step 3: Run inference
# -----------------------------------------------------------------------------

# instantiate SOFIE session class
# NOTE(review): the namespace and weight-file name assume SOFIE's default
# naming for a model called "KerasModel" - confirm against the generated header.
session = ROOT.TMVA_SOFIE_KerasModel.Session("KerasModel.dat")

# Input tensor: one batch of 2 samples with 4 features each, float32
x = np.array([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]], dtype=np.float32)

# Run inference
y = session.infer(x)

print("Inference output:", y)
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.