TMVA_RNN_Classification.py
## \file
## \ingroup tutorial_ml
## \notebook
## TMVA Classification Example Using a Recurrent Neural Network
##
## This is an example of using an RNN in TMVA. We do classification using a toy time-dependent data set
## that is generated when running this example macro
##
## \macro_code
##
## \author Harshal Shende


# TMVA Classification Example Using a Recurrent Neural Network

# This is an example of using an RNN in TMVA.
# We do the classification using a toy data set containing a time series of ntime samples,
# each of dimension ndim, that is generated when running the provided function `MakeTimeData(nevents, ntime, ndim)`

import ROOT

num_threads = 4  # use max 4 threads
# do enable MT running
if "imt" in ROOT.gROOT.GetConfigFeatures():
    ROOT.EnableImplicitMT(num_threads)
    # switch off MT in OpenBLAS to avoid conflict with tbb
    ROOT.gSystem.Setenv("OMP_NUM_THREADS", "1")
    print("Running with nthreads = {}".format(ROOT.GetThreadPoolSize()))
else:
    print("Running in serial mode since ROOT does not support MT")


TMVA = ROOT.TMVA
TFile = ROOT.TFile

import os
import importlib

## Helper function to generate the time data set:
## for each of the ntime time steps, fill signal and background histograms of ndim bins
## from Gaussians whose mean and width vary slowly with the time step.


def MakeTimeData(n, ntime, ndim):
    # ntime = 10
    # ndim = 30  # number of dimensions per time step

    fname = "time_data_t" + str(ntime) + "_d" + str(ndim) + ".root"
    v1 = []
    v2 = []

    for i in range(ntime):
        v1.append(ROOT.TH1D("h1_" + str(i), "h1", ndim, 0, 10))
        v2.append(ROOT.TH1D("h2_" + str(i), "h2", ndim, 0, 10))

    f1 = ROOT.TF1("f1", "gaus")
    f2 = ROOT.TF1("f2", "gaus")

    sgn = ROOT.TTree("sgn", "sgn")
    bkg = ROOT.TTree("bkg", "bkg")
    f = TFile(fname, "RECREATE")

    x1 = []
    x2 = []

    for i in range(ntime):
        x1.append(ROOT.std.vector["float"](ndim))
        x2.append(ROOT.std.vector["float"](ndim))

    for i in range(ntime):
        bkg.Branch("vars_time" + str(i), "std::vector<float>", x1[i])
        sgn.Branch("vars_time" + str(i), "std::vector<float>", x2[i])

    # attach the trees to the output file and fix the random seed
    sgn.SetDirectory(f)
    bkg.SetDirectory(f)
    ROOT.gRandom.SetSeed(0)

    mean1 = ROOT.std.vector["double"](ntime)
    mean2 = ROOT.std.vector["double"](ntime)
    sigma1 = ROOT.std.vector["double"](ntime)
    sigma2 = ROOT.std.vector["double"](ntime)

    for j in range(ntime):
        mean1[j] = 5.0 + 0.2 * ROOT.TMath.Sin(ROOT.TMath.Pi() * j / float(ntime))
        mean2[j] = 5.0 + 0.2 * ROOT.TMath.Cos(ROOT.TMath.Pi() * j / float(ntime))
        sigma1[j] = 4 + 0.3 * ROOT.TMath.Sin(ROOT.TMath.Pi() * j / float(ntime))
        sigma2[j] = 4 + 0.3 * ROOT.TMath.Cos(ROOT.TMath.Pi() * j / float(ntime))

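    # e.g. at j = 0 the background Gaussian has mean1[0] = 5.0 + 0.2*sin(0) = 5.0 while the
    # signal Gaussian has mean2[0] = 5.0 + 0.2*cos(0) = 5.2 (widths 4.0 vs 4.3): the two
    # classes differ only through this slow sin/cos modulation of mean and width over time
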
    for i in range(n):
        if i % 1000 == 0:
            print("Generating event ... %d" % i)

        for j in range(ntime):
            h1 = v1[j]
            h2 = v2[j]
            h1.Reset()
            h2.Reset()

            f1.SetParameters(1, mean1[j], sigma1[j])
            f2.SetParameters(1, mean2[j], sigma2[j])

            h1.FillRandom(f1, 1000)
            h2.FillRandom(f2, 1000)

            # fill all ndim components for this time step, smearing the bin contents
            for k in range(ndim):
                x1[j][k] = h1.GetBinContent(k + 1) + ROOT.gRandom.Gaus(0, 10)
                x2[j][k] = h2.GetBinContent(k + 1) + ROOT.gRandom.Gaus(0, 10)

        sgn.Fill()
        bkg.Fill()

    if n == 1:
        c1 = ROOT.TCanvas()
        c1.Divide(ntime, 2)
        for j in range(ntime):
            c1.cd(j + 1)
            v1[j].Draw()
        for j in range(ntime):
            c1.cd(ntime + j + 1)
            v2[j].Draw()
        ROOT.gPad.Update()

    if n > 1:
        sgn.Write()
        bkg.Write()
        sgn.Print()
        bkg.Print()
        f.Close()

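# A minimal usage sketch: with n == 1 the function draws the per-time-step signal and
# background histograms instead of writing the trees (see the n == 1 branch above), e.g.
#
#   MakeTimeData(1, 10, 30)  # draw a single toy event with ntime = 10, ndim = 30
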
## macro for performing a classification using a Recurrent Neural Network
## @param use_type
##   use_type = 0  use a simple RNN network
##   use_type = 1  use an LSTM network
##   use_type = 2  use a GRU network
##   use_type = 3  build 3 different networks with RNN, LSTM and GRU

use_type = 1
ninput = 30
ntime = 10
batchSize = 100
maxepochs = 10

nTotEvts = 2000  # total events to be generated for signal or background

useKeras = False

useTMVA_RNN = True
useTMVA_DNN = True
useTMVA_BDT = False

if ROOT.gSystem.GetFromPipe("root-config --has-tmva-pymva") == "yes":
    useKeras = True

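# (equivalently, one could check for "tmva-pymva" in ROOT.gROOT.GetConfigFeatures(),
# as is done further below before initializing PyMVA)
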
if useKeras:
    try:
        import tensorflow
    except ImportError:
        ROOT.Warning("TMVA_RNN_Classification", "Skip using Keras since tensorflow cannot be imported")
        useKeras = False


rnn_types = ["RNN", "LSTM", "GRU"]
use_rnn_type = [1, 1, 1]

if 0 <= use_type < 3:
    use_rnn_type = [0, 0, 0]
    use_rnn_type[use_type] = 1

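# e.g. the default use_type = 1 selects only the LSTM (use_rnn_type == [0, 1, 0]),
# while any value outside 0-2 (such as use_type = 3) keeps [1, 1, 1] and builds all three networks
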
# use GPU for TMVA if available
useGPU = "tmva-gpu" in ROOT.gROOT.GetConfigFeatures()
useTMVA_RNN = ("tmva-cpu" in ROOT.gROOT.GetConfigFeatures()) or useGPU

if not useTMVA_RNN:
    ROOT.Warning(
        "TMVA_RNN_Classification",
        "TMVA is not built with GPU or CPU multi-thread support. Cannot use TMVA Deep Learning for RNN",
    )

archString = "GPU" if useGPU else "CPU"

writeOutputFile = True

rnn_type = "RNN"

if "tmva-pymva" in ROOT.gROOT.GetConfigFeatures():
    TMVA.PyMethodBase.PyInitialize()
else:
    useKeras = False


inputFileName = "time_data_t10_d30.root"

# note: AccessPathName returns True if the file does NOT exist
fileDoesNotExist = ROOT.gSystem.AccessPathName(inputFileName)

# if the file does not exist, create it
if fileDoesNotExist:
    MakeTimeData(nTotEvts, ntime, ninput)


inputFile = TFile.Open(inputFileName)
if not inputFile:
    raise RuntimeError("Error opening input file %s" % inputFileName)


print("--- RNNClassification : Using input file: {}".format(inputFile.GetName()))

# Create a ROOT output file where TMVA will store ntuples, histograms, etc.
outfileName = "data_RNN_" + archString + ".root"
outputFile = None


if writeOutputFile:
    outputFile = TFile.Open(outfileName, "RECREATE")


## Declare Factory

# Create the Factory class. Later you can choose the methods
# whose performance you'd like to investigate.

# The factory is the major TMVA object you have to interact with. Here is the list of parameters you need to pass:

# - The first argument is the base of the name of all the output
#   weight files in the directory weight/ that will be created with the
#   method parameters

# - The second argument is the output file for the training results

# - The third argument is a string option defining some general configuration for the TMVA session.
#   For example all TMVA output can be suppressed by removing the "!" (not) in front of the "Silent" argument in
#   the option string

# create the factory object
factory = TMVA.Factory(
    "TMVAClassification",
    outputFile,
    V=False,
    Silent=False,
    Color=True,
    DrawProgressBar=True,
    Transformations=None,
    Correlations=False,
    AnalysisType="Classification",
    ModelPersistence=True,
)
dataloader = TMVA.DataLoader("dataset")

signalTree = inputFile.Get("sgn")
background = inputFile.Get("bkg")

nvar = ninput * ntime

## add variables - use the new AddVariablesArray function
for i in range(ntime):
    dataloader.AddVariablesArray("vars_time" + str(i), ninput)
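
# each AddVariablesArray call above declares ninput (30) input variables per time step,
# vars_time{i}[0] ... vars_time{i}[29], matching the std::vector<float> branches created
# in MakeTimeData, for a total of nvar = ntime * ninput = 300 variables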

dataloader.AddSignalTree(signalTree, 1.0)
dataloader.AddBackgroundTree(background, 1.0)

# check the given input
datainfo = dataloader.GetDataSetInfo()
vars = datainfo.GetListOfVariables()
print("number of variables is {}".format(vars.size()))

for v in vars:
    print(v)

nTrainSig = 0.8 * nTotEvts
nTrainBkg = 0.8 * nTotEvts

# Apply additional cuts on the signal and background samples (can be different)
mycuts = ""  # for example: TCut mycuts = "abs(var1)<0.5 && abs(var2-0.5)<1";
mycutb = ""

# build the string options for DataLoader::PrepareTrainingAndTestTree
dataloader.PrepareTrainingAndTestTree(
    mycuts,
    mycutb,
    nTrain_Signal=nTrainSig,
    nTrain_Background=nTrainBkg,
    SplitMode="Random",
    SplitSeed=100,
    NormMode="NumEvents",
    V=False,
    CalcCorrelations=False,
)

print("prepared DATA LOADER ")


## Book TMVA recurrent models

# Book the different types of recurrent models in TMVA (SimpleRNN, LSTM or GRU)

if useTMVA_RNN:
    for i in range(3):
        if not use_rnn_type[i]:
            continue

        rnn_type = rnn_types[i]

        ## Define the RNN layer layout:
        ## LayerType (RNN, LSTM or GRU) | number of units | number of inputs | time steps | remember output (typically no = 0) | return full sequence (= 1)
        rnnLayout = str(rnn_type) + "|10|" + str(ninput) + "|" + str(ntime) + "|0|1,RESHAPE|FLAT,DENSE|64|TANH,LINEAR"
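
        # For example, with the defaults (ninput = 30, ntime = 10) and rnn_type = "LSTM"
        # the layout string is
        #   "LSTM|10|30|10|0|1,RESHAPE|FLAT,DENSE|64|TANH,LINEAR"
        # i.e. an LSTM with 10 units reading 30 inputs over 10 time steps and returning the
        # full sequence, followed by a flattening reshape, a 64-unit TANH dense layer and a
        # linear output layer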

        ## Define the training strategy. Different training strings could be concatenated with "|"; here we use only one
        trainingString1 = "LearningRate=1e-3,Momentum=0.0,Repetitions=1,ConvergenceSteps=5,BatchSize=" + str(batchSize)
        trainingString1 += ",TestRepetitions=1,WeightDecay=1e-2,Regularization=None,MaxEpochs=" + str(maxepochs)
        trainingString1 += ",Optimizer=ADAM,DropConfig=0.0+0.+0.+0."

        ## Define the input layout string for the RNN: the input data should be organized as time x ndim
        ## After the RNN add a reshape layer (needed to flatten the RNN output) and a dense layer with 64 units plus a final layer
        ## Note the last layer is linear because when using Crossentropy a Sigmoid is applied already
        ## Define the full RNN option string adding the final options for the whole network
        rnnName = "TMVA_" + str(rnn_type)
        factory.BookMethod(
            dataloader,
            TMVA.Types.kDL,
            rnnName,
            H=False,
            V=True,
            ErrorStrategy="CROSSENTROPY",
            VarTransform=None,
            WeightInitialization="XAVIERUNIFORM",
            ValidationSize=0.2,
            RandomSeed=1234,
            InputLayout=str(ntime) + "|" + str(ninput),
            Layout=rnnLayout,
            TrainingStrategy=trainingString1,
            Architecture=archString,
        )

## Book TMVA fully connected dense layer models
if useTMVA_DNN:
    # Method DL with Dense Layer
    # Training strategy
    trainingString1 = ROOT.TString(
        "LearningRate=1e-3,Momentum=0.0,Repetitions=1,"
        "ConvergenceSteps=10,BatchSize=256,TestRepetitions=1,"
        "WeightDecay=1e-4,Regularization=None,MaxEpochs=20,"
        "DropConfig=0.0+0.+0.+0.,Optimizer=ADAM"
    )
    # General options: pass the architecture explicitly, as for the RNN methods above
    dnnName = "TMVA_DNN"
    factory.BookMethod(
        dataloader,
        TMVA.Types.kDL,
        dnnName,
        H=False,
        V=True,
        ErrorStrategy="CROSSENTROPY",
        VarTransform=None,
        WeightInitialization="XAVIER",
        RandomSeed=0,
        InputLayout="1|1|" + str(ntime * ninput),
        Layout="DENSE|64|TANH,DENSE|64|TANH,DENSE|64|TANH,LINEAR",
        TrainingStrategy=trainingString1,
        Architecture=archString,
    )
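
# Note: with InputLayout = "1|1|" + str(ntime * ninput) the dense network sees the time
# series flattened into 300 values, i.e. the same information as the RNNs but without
# the explicit time structure.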

## Book Keras recurrent models

# Book the different types of recurrent models in Keras (SimpleRNN, LSTM or GRU)

if useKeras:
    for i in range(3):
        if use_rnn_type[i]:
            modelName = "model_" + rnn_types[i] + ".h5"
            trainedModelName = "trained_" + modelName
            print("Building recurrent keras model using a", rnn_types[i], "layer")
            # build the Keras model: a recurrent layer followed by dense layers
            from tensorflow.keras.models import Sequential
            from tensorflow.keras.optimizers import Adam
            from tensorflow.keras.layers import Input, Dense, Dropout, Flatten, SimpleRNN, GRU, LSTM, Reshape, BatchNormalization

            model = Sequential()
            model.add(Reshape((10, 30), input_shape=(10 * 30,)))
            # add the recurrent layer depending on the type; return the full output sequence
            if rnn_types[i] == "LSTM":
                model.add(LSTM(units=10, return_sequences=True))
            elif rnn_types[i] == "GRU":
                model.add(GRU(units=10, return_sequences=True))
            else:
                model.add(SimpleRNN(units=10, return_sequences=True))
            # model.add(BatchNormalization())
            model.add(Flatten())  # needed when returning the full time output sequence
            model.add(Dense(64, activation="tanh"))
            model.add(Dense(2, activation="sigmoid"))
            model.compile(loss="binary_crossentropy", optimizer=Adam(learning_rate=0.001), weighted_metrics=["accuracy"])
            model.save(modelName)
            model.summary()
            print("saved recurrent model", modelName)
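
            # Note (assumption): the final Dense(2, "sigmoid") with binary_crossentropy gives
            # one output unit per class (signal, background), which is the output shape the
            # TMVA PyKeras classification method expects for a two-class problem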

            if not os.path.exists(modelName):
                useKeras = False
                print("Error creating Keras recurrent model file - skip using Keras")
            else:
                # book the PyKeras method only if the Keras model could be created
                print("Booking Keras model", rnn_types[i])
                factory.BookMethod(
                    dataloader,
                    TMVA.Types.kPyKeras,
                    "PyKeras_" + rnn_types[i],
                    H=True,
                    V=False,
                    VarTransform=None,
                    FilenameModel=modelName,
                    FilenameTrainedModel="trained_" + modelName,
                    NumEpochs=maxepochs,
                    BatchSize=batchSize,
                    GpuOptions="allow_growth=True",
                )


# use BDT in case not using Keras or TMVA DL
if not useKeras or not useTMVA_RNN:
    useTMVA_BDT = True


## Book TMVA BDT

if useTMVA_BDT:
    factory.BookMethod(
        dataloader,
        TMVA.Types.kBDT,
        "BDTG",
        H=True,
        V=False,
        NTrees=100,
        MinNodeSize="2.5%",
        BoostType="Grad",
        Shrinkage=0.10,
        UseBaggedBoost=True,
        BaggedSampleFraction=0.5,
        nCuts=20,
        MaxDepth=2,
    )


## Train all methods
factory.TrainAllMethods()

print("nthreads = {}".format(ROOT.GetThreadPoolSize()))

# ---- Evaluate all MVAs using the set of test events
factory.TestAllMethods()

# ----- Evaluate and compare performance of all configured MVAs
factory.EvaluateAllMethods()

# check method

# plot ROC curve
c1 = factory.GetROCCurve(dataloader)
c1.Draw()

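# Optionally save the ROC canvas to an image file (hypothetical output name):
# c1.SaveAs("ROC_RNN_" + archString + ".png")
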
if outputFile:
    outputFile.Close()