ROOT Reference Guide

TMVA_SOFIE_Keras.C File Reference

Detailed Description

This macro provides a simple example of parsing a Keras .h5 file into an RModel object and then generating the .hxx header file used for inference.
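The mapping from the input model file to the generated header name can be sketched as follows. This is a minimal illustration, not ROOT code: `DefaultHeaderName` is a hypothetical helper, and it assumes the model name is the input file's base name with its extension removed, which is how the macro's "modelName.hxx" comment is read here.

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// Hypothetical helper (not a ROOT/SOFIE function): derive the default header
// name from a model file path, assuming the model name is the file's base name
// without its extension.
std::string DefaultHeaderName(const std::string &modelFile)
{
   assert(!modelFile.empty());
   // Drop any leading directory components
   const std::size_t slash = modelFile.find_last_of("/\\");
   std::string name = (slash == std::string::npos) ? modelFile : modelFile.substr(slash + 1);
   // Drop the last extension, e.g. ".h5"
   const std::size_t dot = name.rfind('.');
   if (dot != std::string::npos)
      name.erase(dot);
   return name + ".hxx";
}
```

Under that assumption, the `KerasModel.h5` file written by the macro would correspond to a header named `KerasModel.hxx`.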

using namespace TMVA::Experimental;

// Python script that builds, trains, and saves a small Keras model
TString pythonSrc = "\
import os\n\
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n\
\n\
import numpy as np\n\
from tensorflow.keras.models import Model\n\
from tensorflow.keras.layers import Input,Dense,Activation,ReLU\n\
from tensorflow.keras.optimizers import SGD\n\
\n\
input=Input(shape=(64,),batch_size=4)\n\
x=Dense(32)(input)\n\
x=Activation('relu')(x)\n\
x=Dense(16,activation='relu')(x)\n\
x=Dense(8,activation='relu')(x)\n\
x=Dense(4)(x)\n\
output=ReLU()(x)\n\
model=Model(inputs=input,outputs=output)\n\
\n\
randomGenerator=np.random.RandomState(0)\n\
x_train=randomGenerator.rand(4,64)\n\
y_train=randomGenerator.rand(4,4)\n\
\n\
model.compile(loss='mean_squared_error', optimizer=SGD(learning_rate=0.01))\n\
model.fit(x_train, y_train, epochs=5, batch_size=4)\n\
model.save('KerasModel.h5')\n";
void TMVA_SOFIE_Keras(const char * modelFile = nullptr, bool printModelInfo = true){

   // Initialize the Python interpreter
   TMVA::PyMethodBase::PyInitialize();

   // Run the Python script to generate the Keras .h5 file if none was given
   if (modelFile == nullptr) {
      TMacro m;
      m.AddLine(pythonSrc);
      m.SaveSource("make_keras_model.py");
      gSystem->Exec(TMVA::Python_Executable() + " make_keras_model.py");
      modelFile = "KerasModel.h5";
   }

   // Parse the saved Keras .h5 file into an RModel object
   SOFIE::RModel model = SOFIE::PyKeras::Parse(modelFile);

   // Generate the inference code
   model.Generate();
   // Write the output header. By default it will be modelName.hxx
   model.OutputGenerated();

   if (!printModelInfo) return;

   // Print the required input tensors
   std::cout<<"\n\n";
   model.PrintRequiredInputTensors();

   // Print the initialized tensors (weights)
   std::cout<<"\n\n";
   model.PrintInitializedTensors();

   // Print the intermediate tensors
   std::cout<<"\n\n";
   model.PrintIntermediateTensors();

   // Print the generated inference code
   std::cout<<"\n\n";
   model.PrintGenerated();
}
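The Keras model built by the embedded Python script stacks four Dense layers (64 to 32 to 16 to 8 to 4), each followed by a ReLU. As a rough picture of the computation that inference code for such a model has to carry out, here is a minimal standalone sketch. It is not the code SOFIE emits: `Layer`, `Dense`, `ReLU`, and `Forward` are hypothetical names introduced for illustration, and the kernel is assumed to be stored row-major with shape [input x output], as in a Keras Dense layer.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical container for one Dense layer's parameters (illustration only).
struct Layer {
   std::vector<float> W; // kernel, row-major, shape [inDim x outDim]
   std::vector<float> b; // bias, length outDim
};

// y = x * W + b for a single input sample
std::vector<float> Dense(const std::vector<float> &x, const Layer &layer)
{
   const std::size_t inDim = x.size();
   const std::size_t outDim = layer.b.size();
   assert(layer.W.size() == inDim * outDim);
   std::vector<float> y(layer.b); // start from the bias
   for (std::size_t i = 0; i < inDim; ++i)
      for (std::size_t j = 0; j < outDim; ++j)
         y[j] += x[i] * layer.W[i * outDim + j];
   return y;
}

// Element-wise max(v, 0)
std::vector<float> ReLU(std::vector<float> v)
{
   for (float &e : v)
      e = std::max(e, 0.0f);
   return v;
}

// Apply Dense followed by ReLU for every layer in order
std::vector<float> Forward(std::vector<float> x, const std::vector<Layer> &layers)
{
   for (const Layer &layer : layers)
      x = ReLU(Dense(x, layer));
   return x;
}
```

For the model in this macro, `layers` would hold four entries with kernel shapes 64x32, 32x16, 16x8, and 8x4, and `Forward` would be called once per row of the batch of 4 input samples.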
R__EXTERN TSystem * gSystem
Definition TSystem.h:572
void OutputGenerated(std::string filename="", bool append=false)
Definition RModel.cxx:1303
void Generate(std::underlying_type_t< Options > options, int batchSize=-1, long pos=0, bool verbose=false)
Definition RModel.cxx:917
static void PyInitialize()
Initialize Python interpreter.
TMacro
Class supporting a collection of lines with C++ code.
Definition TMacro.h:31
TString
Basic string class.
Definition TString.h:139
virtual Int_t Exec(const char *shellcmd)
Execute a command.
Definition TSystem.cxx:653
TString Python_Executable()
Function to find current Python executable used by ROOT If "Python3" is installed,...
Author
Sanjiban Sengupta

Definition in file TMVA_SOFIE_Keras.C.