TMVA_RNN_Classification.C
/// \file
/// \ingroup tutorial_ml
/// \notebook
/// TMVA Classification Example Using a Recurrent Neural Network
///
/// This is an example of using a RNN in TMVA. We perform classification using a toy
/// time-dependent data set that is generated when running this example macro.
///
/// \macro_code
///
/// \author Lorenzo Moneta

/***

 # TMVA Classification Example Using a Recurrent Neural Network

 This is an example of using a RNN in TMVA.
 We perform the classification using a toy data set containing a time series of `ntime`
 samples per event, each of dimension `ndim`, generated when running the provided
 function `MakeTimeData(nevents, ntime, ndim)`.

**/

#include <TROOT.h>

#include "TMVA/Factory.h"
#include "TMVA/DataLoader.h"
#include "TMVA/DataSetInfo.h"
#include "TMVA/Config.h"
#include "TMVA/MethodDL.h"

#include "TFile.h"
#include "TTree.h"

/// Helper function to generate the time-dependent data set:
/// n events are generated, each consisting of ntime time steps
/// with ndim values per time step.
void MakeTimeData(int n, int ntime, int ndim)
{

   // const int ntime = 10;
   // const int ndim = 30; // number of dim/time
   TString fname = TString::Format("time_data_t%d_d%d.root", ntime, ndim);
   std::vector<TH1 *> v1(ntime);
   std::vector<TH1 *> v2(ntime);
   for (int i = 0; i < ntime; ++i) {
      v1[i] = new TH1D(TString::Format("h1_%d", i), "h1", ndim, 0, 10);
      v2[i] = new TH1D(TString::Format("h2_%d", i), "h2", ndim, 0, 10);
   }

   auto f1 = new TF1("f1", "gaus");
   auto f2 = new TF1("f2", "gaus");

   TFile f(fname, "RECREATE");
   TTree sgn("sgn", "sgn");
   TTree bkg("bkg", "bkg");

   std::vector<std::vector<float>> x1(ntime);
   std::vector<std::vector<float>> x2(ntime);

   for (int i = 0; i < ntime; ++i) {
      x1[i] = std::vector<float>(ndim);
      x2[i] = std::vector<float>(ndim);
   }

   for (auto i = 0; i < ntime; i++) {
      bkg.Branch(Form("vars_time%d", i), "std::vector<float>", &x1[i]);
      sgn.Branch(Form("vars_time%d", i), "std::vector<float>", &x2[i]);
   }

   sgn.SetDirectory(&f);
   bkg.SetDirectory(&f);
   gRandom->SetSeed(0);

   std::vector<double> mean1(ntime);
   std::vector<double> mean2(ntime);
   std::vector<double> sigma1(ntime);
   std::vector<double> sigma2(ntime);
   for (int j = 0; j < ntime; ++j) {
      mean1[j] = 5. + 0.2 * sin(TMath::Pi() * j / double(ntime));
      mean2[j] = 5. + 0.2 * cos(TMath::Pi() * j / double(ntime));
      sigma1[j] = 4 + 0.3 * sin(TMath::Pi() * j / double(ntime));
      sigma2[j] = 4 + 0.3 * cos(TMath::Pi() * j / double(ntime));
   }
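
   // Note: signal and background differ only in the time evolution of the Gaussian
   // parameters: a sine modulation for the background sample (mean1, sigma1) and a
   // cosine modulation for the signal sample (mean2, sigma2).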
   for (int i = 0; i < n; ++i) {

      if (i % 1000 == 0)
         std::cout << "Generating event ... " << i << std::endl;

      for (int j = 0; j < ntime; ++j) {
         auto h1 = v1[j];
         auto h2 = v2[j];
         h1->Reset();
         h2->Reset();

         f1->SetParameters(1, mean1[j], sigma1[j]);
         f2->SetParameters(1, mean2[j], sigma2[j]);

         h1->FillRandom("f1", 1000);
         h2->FillRandom("f2", 1000);

         for (int k = 0; k < ndim; ++k) {
            // std::cout << j*10+k << " ";
            x1[j][k] = h1->GetBinContent(k + 1) + gRandom->Gaus(0, 10);
            x2[j][k] = h2->GetBinContent(k + 1) + gRandom->Gaus(0, 10);
         }
      }
      // std::cout << std::endl;
      sgn.Fill();
      bkg.Fill();

      if (n == 1) {
         auto c1 = new TCanvas();
         c1->Divide(ntime, 2);
         for (int j = 0; j < ntime; ++j) {
            c1->cd(j + 1);
            v1[j]->Draw();
         }
         for (int j = 0; j < ntime; ++j) {
            c1->cd(ntime + j + 1);
            v2[j]->Draw();
         }
         gPad->Update();
      }
   }
   if (n > 1) {
      sgn.Write();
      bkg.Write();
      sgn.Print();
      bkg.Print();
      f.Close();
   }
}
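
// The file written by MakeTimeData thus contains two trees, "sgn" and "bkg"; each event
// holds ntime branches vars_time0 ... vars_time<ntime-1>, and each branch stores a
// std::vector<float> of size ndim. Calling the helper with a single event, e.g.
// MakeTimeData(1, 10, 30), only draws the generated time slices on a canvas
// (see the n == 1 branch above) without writing the trees.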
/// Macro for performing a classification using a Recurrent Neural Network.
/// @param nevts Number of events used (default 2000; increase for better classification results)
/// @param use_type Type of recurrent network:
///          - use_type = 0: use a simple RNN network
///          - use_type = 1: use an LSTM network
///          - use_type = 2: use a GRU network
///          - use_type = 3: build 3 different networks with RNN, LSTM and GRU
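///
/// For example, running `TMVA_RNN_Classification(2000, 3)` from the ROOT prompt trains
/// and compares all three recurrent network types on 2000 generated events per class.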

void TMVA_RNN_Classification(int nevts = 2000, int use_type = 1)
{

   const int ninput = 30;
   const int ntime = 10;
   const int batchSize = 100;
   const int maxepochs = 20;

   int nTotEvts = nevts; // total events to be generated for signal or background

   bool useKeras = true;

   bool useTMVA_RNN = true;
   bool useTMVA_DNN = true;
   bool useTMVA_BDT = false;

   std::vector<std::string> rnn_types = {"RNN", "LSTM", "GRU"};
   std::vector<bool> use_rnn_type = {1, 1, 1};
   if (use_type >= 0 && use_type < 3) {
      use_rnn_type = {0, 0, 0};
      use_rnn_type[use_type] = 1;
   }
   bool useGPU = true; // use GPU for TMVA if available

#ifndef R__HAS_TMVAGPU
   useGPU = false;
#ifndef R__HAS_TMVACPU
   Warning("TMVA_RNN_Classification",
           "TMVA is not built with GPU or CPU multi-thread support. Cannot use TMVA Deep Learning for RNN");
   useTMVA_RNN = false;
#endif
#endif

   TString archString = (useGPU) ? "GPU" : "CPU";

   bool writeOutputFile = true;

   const char *rnn_type = "RNN";

#ifdef R__HAS_PYMVA
   TMVA::PyMethodBase::PyInitialize();
#else
   useKeras = false;
#endif

#ifdef R__USE_IMT
   int num_threads = 4; // use max 4 threads
   // switch off MT in OpenBLAS to avoid conflict with tbb
   gSystem->Setenv("OMP_NUM_THREADS", "1");

   // do enable MT running
   if (num_threads >= 0) {
      ROOT::EnableImplicitMT(num_threads);
   }
#endif

   TMVA::Config::Instance();

   std::cout << "Running with nthreads = " << ROOT::GetThreadPoolSize() << std::endl;

   TString inputFileName = "time_data_t10_d30.root";

   bool fileExist = !gSystem->AccessPathName(inputFileName);

   // if the file does not exist, create it
   if (!fileExist) {
      MakeTimeData(nTotEvts, ntime, ninput);
   }

   auto inputFile = TFile::Open(inputFileName);
   if (!inputFile) {
      Error("TMVA_RNN_Classification", "Error opening input file %s - exit", inputFileName.Data());
      return;
   }

   std::cout << "--- RNNClassification : Using input file: " << inputFile->GetName() << std::endl;

   // Create a ROOT output file where TMVA will store ntuples, histograms, etc.
   TString outfileName(TString::Format("data_RNN_%s.root", archString.Data()));
   TFile *outputFile = nullptr;
   if (writeOutputFile)
      outputFile = TFile::Open(outfileName, "RECREATE");

   /**
    ## Declare Factory

    Create the Factory class. Later you can choose the methods
    whose performance you'd like to investigate.

    The factory is the major TMVA object you have to interact with. Here is the list of
    parameters you need to pass:

    - The first argument is the base of the name of all the output
      weight files in the directory weight/ that will be created with the
      method parameters.

    - The second argument is the output file for the training results.

    - The third argument is a string option defining some general configuration for the TMVA session.
      For example, all TMVA output can be suppressed by removing the "!" (not) in front of the
      "Silent" argument in the option string.

   **/

   // Creating the factory object
   TMVA::Factory *factory = new TMVA::Factory("TMVAClassification", outputFile,
                                              "!V:!Silent:Color:DrawProgressBar:Transformations=None:!Correlations:"
                                              "AnalysisType=Classification:ModelPersistence");

   TMVA::DataLoader *dataloader = new TMVA::DataLoader("dataset");
   TTree *signalTree = (TTree *)inputFile->Get("sgn");
   TTree *background = (TTree *)inputFile->Get("bkg");

   const int nvar = ninput * ntime;

   /// add variables - use the new AddVariablesArray function
   for (auto i = 0; i < ntime; i++) {
      dataloader->AddVariablesArray(Form("vars_time%d", i), ninput);
   }

   dataloader->AddSignalTree(signalTree, 1.0);
   dataloader->AddBackgroundTree(background, 1.0);

   // check the given input
   auto &datainfo = dataloader->GetDataSetInfo();
   auto vars = datainfo.GetListOfVariables();
   std::cout << "number of variables is " << vars.size() << std::endl;
   for (auto &v : vars)
      std::cout << v << ",";
   std::cout << std::endl;

   int nTrainSig = 0.8 * nTotEvts;
   int nTrainBkg = 0.8 * nTotEvts;

   // build the string options for DataLoader::PrepareTrainingAndTestTree
   TString prepareOptions = TString::Format("nTrain_Signal=%d:nTrain_Background=%d:SplitMode=Random:SplitSeed=100:"
                                            "NormMode=NumEvents:!V:!CalcCorrelations",
                                            nTrainSig, nTrainBkg);

   // Apply additional cuts on the signal and background samples (can be different)
   TCut mycuts = ""; // for example: TCut mycuts = "abs(var1)<0.5 && abs(var2-0.5)<1";
   TCut mycutb = "";

   dataloader->PrepareTrainingAndTestTree(mycuts, mycutb, prepareOptions);

   std::cout << "prepared DATA LOADER " << std::endl;

   /**
    ## Book TMVA recurrent models

    Book the different types of recurrent models in TMVA (SimpleRNN, LSTM or GRU)

   **/

   if (useTMVA_RNN) {

      for (int i = 0; i < 3; ++i) {

         if (!use_rnn_type[i])
            continue;

         const char *rnn_type = rnn_types[i].c_str();

         /// define the input layout string for the RNN;
         /// the input data should be organized as time x ndim
         TString inputLayoutString = TString::Format("InputLayout=%d|%d", ntime, ninput);

         /// Define the RNN layer layout; it should be:
         /// LayerType (RNN, LSTM or GRU) | number of units | number of inputs | time steps |
         /// remember output (typically no = 0) | return full sequence
         TString rnnLayout = TString::Format("%s|10|%d|%d|0|1", rnn_type, ninput, ntime);
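         // For the defaults used here (ninput = 30, ntime = 10) and e.g. the LSTM case,
         // the layout string assembled above reads "LSTM|10|30|10|0|1".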

         /// add after the RNN a reshape layer (needed to flatten the output), a dense layer
         /// with 64 units and a final linear layer.
         /// Note the last layer is linear because when using Crossentropy a Sigmoid is applied already
         TString layoutString = TString("Layout=") + rnnLayout + TString(",RESHAPE|FLAT,DENSE|64|TANH,LINEAR");

         /// Defining Training strategies. Different training strings can be concatenated; use however only one
         TString trainingString1 = TString::Format("LearningRate=1e-3,Momentum=0.0,Repetitions=1,"
                                                   "ConvergenceSteps=5,BatchSize=%d,TestRepetitions=1,"
                                                   "WeightDecay=1e-2,Regularization=None,MaxEpochs=%d,"
                                                   "Optimizer=ADAM,DropConfig=0.0+0.+0.+0.",
                                                   batchSize, maxepochs);

         TString trainingStrategyString("TrainingStrategy=");
         trainingStrategyString += trainingString1; // + "|" + trainingString2

         /// Define the full RNN option string adding the final options for all networks
         TString rnnOptions("!H:V:ErrorStrategy=CROSSENTROPY:VarTransform=None:"
                            "WeightInitialization=XAVIERUNIFORM:ValidationSize=0.2:RandomSeed=1234");

         rnnOptions.Append(":");
         rnnOptions.Append(inputLayoutString);
         rnnOptions.Append(":");
         rnnOptions.Append(layoutString);
         rnnOptions.Append(":");
         rnnOptions.Append(trainingStrategyString);
         rnnOptions.Append(":");
         rnnOptions.Append(TString::Format("Architecture=%s", archString.Data()));
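         // For illustration, the fully assembled option string for the LSTM case on CPU
         // has the form (abbreviated):
         //   "!H:V:ErrorStrategy=CROSSENTROPY:...:InputLayout=10|30:
         //    Layout=LSTM|10|30|10|0|1,RESHAPE|FLAT,DENSE|64|TANH,LINEAR:
         //    TrainingStrategy=LearningRate=1e-3,...:Architecture=CPU"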

         TString rnnName = "TMVA_" + TString(rnn_type);
         factory->BookMethod(dataloader, TMVA::Types::kDL, rnnName, rnnOptions);
      }
   }

   /**
    ## Book TMVA fully connected dense layer models

   **/

   if (useTMVA_DNN) {
      // Method DL with a Dense Layer
      TString inputLayoutString = TString::Format("InputLayout=1|1|%d", ntime * ninput);
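      // unlike the RNN case, the dense network sees each event as a flat vector
      // of ntime * ninput = 300 values, hence InputLayout=1|1|300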

      TString layoutString("Layout=DENSE|64|TANH,DENSE|64|TANH,DENSE|64|TANH,LINEAR");
      // Training strategies.
      TString trainingString1("LearningRate=1e-3,Momentum=0.0,Repetitions=1,"
                              "ConvergenceSteps=10,BatchSize=256,TestRepetitions=1,"
                              "WeightDecay=1e-4,Regularization=None,MaxEpochs=20,"
                              "DropConfig=0.0+0.+0.+0.,Optimizer=ADAM");
      TString trainingStrategyString("TrainingStrategy=");
      trainingStrategyString += trainingString1; // + "|" + trainingString2

      // General Options.
      TString dnnOptions("!H:V:ErrorStrategy=CROSSENTROPY:VarTransform=None:"
                         "WeightInitialization=XAVIER:RandomSeed=0");

      dnnOptions.Append(":");
      dnnOptions.Append(inputLayoutString);
      dnnOptions.Append(":");
      dnnOptions.Append(layoutString);
      dnnOptions.Append(":");
      dnnOptions.Append(trainingStrategyString);
      dnnOptions.Append(":");
      dnnOptions.Append(TString::Format("Architecture=%s", archString.Data()));

      TString dnnName = "TMVA_DNN";
      factory->BookMethod(dataloader, TMVA::Types::kDL, dnnName, dnnOptions);
   }

   /**
    ## Book Keras recurrent models

    Book the different types of recurrent models in Keras (SimpleRNN, LSTM or GRU)

   **/

   if (useKeras) {

      for (int i = 0; i < 3; i++) {

         if (use_rnn_type[i]) {

            TString modelName = TString::Format("model_%s.h5", rnn_types[i].c_str());
            TString trainedModelName = TString::Format("trained_model_%s.h5", rnn_types[i].c_str());

            Info("TMVA_RNN_Classification", "Building recurrent keras model using a %s layer", rnn_types[i].c_str());
            // create a python script which can be executed to build the recurrent
            // Keras model (Reshape + RNN/LSTM/GRU + Dense layers)
            TMacro m;
            m.AddLine("import tensorflow");
            m.AddLine("from tensorflow.keras.models import Sequential");
            m.AddLine("from tensorflow.keras.optimizers import Adam");
            m.AddLine("from tensorflow.keras.layers import Input, Dense, Dropout, Flatten, SimpleRNN, GRU, LSTM, Reshape, "
                      "BatchNormalization");
            m.AddLine("");
            m.AddLine("model = Sequential() ");
            m.AddLine("model.add(Reshape((10, 30), input_shape = (10*30, )))");
            // add the recurrent layer depending on the type; use the option to return the full output sequence
            if (rnn_types[i] == "LSTM")
               m.AddLine("model.add(LSTM(units=10, return_sequences=True) )");
            else if (rnn_types[i] == "GRU")
               m.AddLine("model.add(GRU(units=10, return_sequences=True) )");
            else
               m.AddLine("model.add(SimpleRNN(units=10, return_sequences=True) )");

            // m.AddLine("model.add(BatchNormalization())");
            m.AddLine("model.add(Flatten())"); // needed if returning the full time output sequence
            m.AddLine("model.add(Dense(64, activation = 'tanh')) ");
            m.AddLine("model.add(Dense(2, activation = 'sigmoid')) ");
            m.AddLine("model.compile(loss = 'binary_crossentropy', optimizer = Adam(learning_rate = 0.001), "
                      "weighted_metrics = ['accuracy'])");
            m.AddLine(TString::Format("modelName = '%s'", modelName.Data()));
            m.AddLine("model.save(modelName)");
            m.AddLine("model.summary()");

            m.SaveSource("make_rnn_model.py");
            // execute the python script to make the model
            auto ret = (TString *)gROOT->ProcessLine("TMVA::Python_Executable()");
            TString python_exe = (ret) ? *(ret) : "python";
            gSystem->Exec(python_exe + " make_rnn_model.py");

            if (gSystem->AccessPathName(modelName)) {
               Warning("TMVA_RNN_Classification", "Error creating Keras recurrent model file - skip using Keras");
               useKeras = false;
            } else {
               // book the PyKeras method only if the Keras model could be created
               Info("TMVA_RNN_Classification", "Booking Keras %s model", rnn_types[i].c_str());
               factory->BookMethod(dataloader, TMVA::Types::kPyKeras,
                                   TString::Format("PyKeras_%s", rnn_types[i].c_str()),
                                   TString::Format("!H:!V:VarTransform=None:FilenameModel=%s:tf.keras:"
                                                   "FilenameTrainedModel=%s:GpuOptions=allow_growth=True:"
                                                   "NumEpochs=%d:BatchSize=%d",
                                                   modelName.Data(), trainedModelName.Data(), maxepochs, batchSize));
            }
         }
      }
   }

   // use a BDT as fallback in case neither Keras nor the TMVA DL methods are used
   if (!useKeras && !useTMVA_RNN && !useTMVA_DNN)
      useTMVA_BDT = true;

   /**
    ## Book TMVA BDT
   **/

   if (useTMVA_BDT) {

      factory->BookMethod(dataloader, TMVA::Types::kBDT, "BDTG",
                          "!H:!V:NTrees=100:MinNodeSize=2.5%:BoostType=Grad:Shrinkage=0.10:UseBaggedBoost:"
                          "BaggedSampleFraction=0.5:nCuts=20:"
                          "MaxDepth=2");
   }

   /// Train all methods
   factory->TrainAllMethods();

   std::cout << "nthreads = " << ROOT::GetThreadPoolSize() << std::endl;

   // ---- Evaluate all MVAs using the set of test events
   factory->TestAllMethods();

   // ----- Evaluate and compare performance of all configured MVAs
   factory->EvaluateAllMethods();

   // plot the ROC curve
   auto c1 = factory->GetROCCurve(dataloader);
   c1->Draw();

   if (outputFile)
      outputFile->Close();
}
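
// After the macro completes, the training results stored in the output file
// (e.g. data_RNN_CPU.root when running on CPU) can be inspected with the standard
// TMVA GUI, for example:
//
//   TMVA::TMVAGui("data_RNN_CPU.root");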