## \file
## \ingroup tutorial_ml
## \notebook
## TMVA Classification Example Using a Convolutional Neural Network
##
## This is an example of using a CNN in TMVA. The classification is performed on a toy
## image data set that is generated when running this macro.
##
## \macro_image
## \macro_output
## \macro_code
##
## \author Harshal Shende


# TMVA Classification Example Using a Convolutional Neural Network


## Helper function to create the input image data:
## we create signal and background 2D histograms from 2D Gaussians
## whose location (means in X and Y) differs event by event.
## The difference between signal and background lies in the Gaussian width:
## the background width is slightly larger than the signal width (by a few percent).

import ROOT

import os
import importlib.util

useKerasCNN = False

if ROOT.gSystem.GetFromPipe("root-config --has-tmva-pymva") == "yes":
    useKerasCNN = True

opt = [1, 1, 1, 1, 1]
useTMVACNN = opt[0] if len(opt) > 0 else False
useKerasCNN = opt[1] if len(opt) > 1 else useKerasCNN
useTMVADNN = opt[2] if len(opt) > 2 else False
useTMVABDT = opt[3] if len(opt) > 3 else False
usePyTorchCNN = opt[4] if len(opt) > 4 else False

if useKerasCNN:
    try:
        import tensorflow
    except:
        ROOT.Warning("TMVA_CNN_Classification", "Skip using Keras since tensorflow cannot be imported")
        useKerasCNN = False

# PyTorch has to be imported before ROOT to avoid crashes because of clashing
# std::regexp symbols that are exported by cppyy.
# See also: https://github.com/wlav/cppyy/issues/227
torch_spec = importlib.util.find_spec("torch")
if torch_spec is None:
    usePyTorchCNN = False
    print("TMVA_CNN_Classification", "Skip using PyTorch since torch is not installed")
else:
    try:
        import torch
    except:
        ROOT.Warning("TMVA_CNN_Classification", "Skip using PyTorch since it cannot be imported")
        usePyTorchCNN = False


import ROOT


TMVA = ROOT.TMVA
TFile = ROOT.TFile

def MakeImagesTree(n, nh, nw):
    # image size (nh x nw)
    ntot = nh * nw
    fileOutName = "images_data_16x16.root"
    nRndmEvts = 10000  # number of events we use to fill each image
    delta_sigma = 0.1  # small difference in the sigma between signal and background
    pixelNoise = 5

    sX1 = 3
    sY1 = 3
    sX2 = sX1 + delta_sigma
    sY2 = sY1 - delta_sigma
    h1 = ROOT.TH2D("h1", "h1", nh, 0, 10, nw, 0, 10)
    h2 = ROOT.TH2D("h2", "h2", nh, 0, 10, nw, 0, 10)
    f1 = ROOT.TF2("f1", "xygaus")
    f2 = ROOT.TF2("f2", "xygaus")
    sgn = ROOT.TTree("sig_tree", "signal_tree")
    bkg = ROOT.TTree("bkg_tree", "background_tree")

    f = TFile(fileOutName, "RECREATE")
    x1 = ROOT.std.vector["float"](ntot)
    x2 = ROOT.std.vector["float"](ntot)

    # create signal and background trees with a single branch,
    # an std::vector<float> of size nh x nw containing the image data
    bkg.Branch("vars", "std::vector<float>", x1)
    sgn.Branch("vars", "std::vector<float>", x2)

    # attach the trees to the output file
    sgn.SetDirectory(f)
    bkg.SetDirectory(f)

    f1.SetParameters(1, 5, sX1, 5, sY1)
    f2.SetParameters(1, 5, sX2, 5, sY2)
    ROOT.Info("TMVA_CNN_Classification", "Filling ROOT tree \n")
    for i in range(n):
        if i % 1000 == 0:
            print("Generating image event ...", i)

        h1.Reset()
        h2.Reset()
        # generate random means in range [3,7] to be not too much on the border
        f1.SetParameter(1, ROOT.gRandom.Uniform(3, 7))
        f1.SetParameter(3, ROOT.gRandom.Uniform(3, 7))
        f2.SetParameter(1, ROOT.gRandom.Uniform(3, 7))
        f2.SetParameter(3, ROOT.gRandom.Uniform(3, 7))

        h1.FillRandom(f1, nRndmEvts)
        h2.FillRandom(f2, nRndmEvts)

        for k in range(nh):
            for l in range(nw):
                m = k * nw + l
                # add some noise in each bin
                x1[m] = h1.GetBinContent(k + 1, l + 1) + ROOT.gRandom.Gaus(0, pixelNoise)
                x2[m] = h2.GetBinContent(k + 1, l + 1) + ROOT.gRandom.Gaus(0, pixelNoise)

        sgn.Fill()
        bkg.Fill()

    sgn.Write()
    bkg.Write()

    print("Signal and background tree with images data written to the file %s" % f.GetName())
    sgn.Print()
    bkg.Print()
    f.Close()

hasGPU = "tmva-gpu" in ROOT.gROOT.GetConfigFeatures()
hasCPU = "tmva-cpu" in ROOT.gROOT.GetConfigFeatures()

nevt = 1000  # use a larger value to get better results

if not hasCPU and not hasGPU:
    ROOT.Warning("TMVA_CNN_Classification", "ROOT is not built with tmva-cpu or tmva-gpu support - skip using TMVA-DNN and TMVA-CNN")
    useTMVACNN = False
    useTMVADNN = False

if "tmva-pymva" not in ROOT.gROOT.GetConfigFeatures():
    useKerasCNN = False
    usePyTorchCNN = False
else:
    TMVA.PyMethodBase.PyInitialize()

if not useTMVACNN:
    ROOT.Warning(
        "TMVA_CNN_Classification",
        "TMVA is not built with GPU or CPU multi-thread support. Cannot use TMVA Deep Learning for CNN",
    )

writeOutputFile = True

num_threads = 4   # use max 4 threads
max_epochs = 10   # maximum number of epochs used for training


# enable implicit multi-threading if ROOT was built with MT support
if "imt" in ROOT.gROOT.GetConfigFeatures():
    ROOT.EnableImplicitMT(num_threads)
    ROOT.gSystem.Setenv("OMP_NUM_THREADS", "1")  # switch OFF MT in OpenBLAS
    print("Running with nthreads = {}".format(ROOT.GetThreadPoolSize()))
else:
    print("Running in serial mode since ROOT does not support MT")



outputFile = None
if writeOutputFile:
    outputFile = TFile.Open("TMVA_CNN_ClassificationOutput.root", "RECREATE")


## Create TMVA Factory

# Create the Factory class. Later you can choose the methods
# whose performance you'd like to investigate.

# The factory is the major TMVA object you have to interact with. Here is the list of parameters you need to pass:

# - The first argument is the base of the name of all the output
#   weight files in the directory weight/ that will be created with the
#   method parameters

# - The second argument is the output file for the training results

# - The third argument is a string option defining some general configuration for the TMVA session.
#   For example, all TMVA output can be suppressed by setting the Silent option to True
#   (in the C++ option string this corresponds to removing the "!" in front of "Silent")

# - Note that we disable any pre-transformation of the input variables and we avoid computing correlations between
#   input variables

factory = TMVA.Factory(
    "TMVA_CNN_Classification",
    outputFile,
    V=False,
    ROC=True,
    Silent=False,
    Color=True,
    AnalysisType="Classification",
    Transformations=None,
    Correlations=False,
)

## Declare DataLoader(s)

# The next step is to declare the DataLoader class that deals with input variables.

# Define the input variables that shall be used for the MVA training;
# note that you may also use variable expressions, which can be parsed by TTree::Draw("expression").

# In this case the input data consist of an image of 16x16 pixels. Each single pixel is a branch in a ROOT TTree.

loader = TMVA.DataLoader("dataset")


## Setup Dataset(s)

# Define the input data file and the signal and background trees

imgSize = 16 * 16
inputFileName = "images_data_16x16.root"

# if the input file does not exist, create it
if ROOT.gSystem.AccessPathName(inputFileName):
    MakeImagesTree(nevt, 16, 16)

inputFile = TFile.Open(inputFileName)
if inputFile is None:
    ROOT.Warning("TMVA_CNN_Classification", "Error opening input file %s - exit", inputFileName)


# inputFileName = "tmva_class_example.root"


# --- Register the training and test trees

signalTree = inputFile.Get("sig_tree")
backgroundTree = inputFile.Get("bkg_tree")

nEventsSig = signalTree.GetEntries()
nEventsBkg = backgroundTree.GetEntries()

# global event weights per tree (see below for setting event-wise weights)
signalWeight = 1.0
backgroundWeight = 1.0

# You can add an arbitrary number of signal or background trees
loader.AddSignalTree(signalTree, signalWeight)
loader.AddBackgroundTree(backgroundTree, backgroundWeight)

## add event variables (image)
## use the new method (available since ROOT 6.20) to add a variable array for all image data
loader.AddVariablesArray("vars", imgSize)

# Set individual event weights (the variables must exist in the original TTree)
# for signal    : factory->SetSignalWeightExpression("weight1*weight2");
# for background: factory->SetBackgroundWeightExpression("weight1*weight2");
# loader->SetBackgroundWeightExpression( "weight" );
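# In this PyROOT version the equivalent calls would be made on the DataLoader.
# Shown only as an illustration (the toy trees generated above contain no
# per-event weight branch, so these calls stay commented out):
#
#   loader.SetSignalWeightExpression("weight1*weight2")
#   loader.SetBackgroundWeightExpression("weight1*weight2")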

# Apply additional cuts on the signal and background samples (can be different)
mycuts = ""  # for example: TCut mycuts = "abs(var1)<0.5 && abs(var2-0.5)<1";
mycutb = ""  # for example: TCut mycutb = "abs(var1)<0.5";

# Tell the factory how to use the training and testing events.
# If no numbers of events are given, half of the events in the tree are used
# for training, and the other half for testing:
#     loader.PrepareTrainingAndTestTree(mycut, "SplitMode=random:!V")
# It is also possible to specify the number of training and testing events;
# note we disable the computation of the correlation matrix of the input variables

nTrainSig = 0.8 * nEventsSig
nTrainBkg = 0.8 * nEventsBkg

# build the string options for DataLoader::PrepareTrainingAndTestTree
loader.PrepareTrainingAndTestTree(
    mycuts,
    mycutb,
    nTrain_Signal=nTrainSig,
    nTrain_Background=nTrainBkg,
    SplitMode="Random",
    SplitSeed=100,
    NormMode="NumEvents",
    V=False,
    CalcCorrelations=False,
)


# Expected output:
# DataSetInfo              : [dataset] : Added class "Signal"
#                          : Add Tree sig_tree of type Signal with 10000 events
# DataSetInfo              : [dataset] : Added class "Background"
#                          : Add Tree bkg_tree of type Background with 10000 events

# signalTree.Print()

# Booking Methods

# Here we book the TMVA methods. We book a Boosted Decision Tree method (BDT)

# Boosted Decision Trees
if useTMVABDT:
    factory.BookMethod(
        loader,
        TMVA.Types.kBDT,
        "BDT",
        V=False,
        NTrees=400,
        MinNodeSize="2.5%",
        MaxDepth=2,
        BoostType="AdaBoost",
        AdaBoostBeta=0.5,
        UseBaggedBoost=True,
        BaggedSampleFraction=0.5,
        SeparationType="GiniIndex",
        nCuts=20,
    )


#### Booking Deep Neural Network

# Here we book the DNN of TMVA. See the example TMVA_Higgs_Classification.C for a detailed description of the
# options

if useTMVADNN:
    layoutString = ROOT.TString(
        "DENSE|100|RELU,BNORM,DENSE|100|RELU,BNORM,DENSE|100|RELU,BNORM,DENSE|100|RELU,DENSE|1|LINEAR"
    )

    # Training strategies:
    # one can concatenate several training strings with different parameters (e.g. learning rates or regularization
    # parameters); the single training strings must be concatenated with the `|` delimiter
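    # For illustration only (not part of the original macro), a second phase with a
    # lower learning rate could be appended like this and passed below via
    # TrainingStrategy:
    #
    #   trainingString2 = "LearningRate=1e-4,Momentum=0.9,ConvergenceSteps=5,BatchSize=100,MaxEpochs=" + str(max_epochs)
    #   trainingStrategy = trainingString1 + "|" + trainingString2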
    trainingString1 = ROOT.TString(
        "LearningRate=1e-3,Momentum=0.9,Repetitions=1,"
        "ConvergenceSteps=5,BatchSize=100,TestRepetitions=1,"
        "WeightDecay=1e-4,Regularization=None,"
        "Optimizer=ADAM,DropConfig=0.0+0.0+0.0+0.0"
    )  # + "|" + trainingString2 + ...
    trainingString1 += ",MaxEpochs=" + str(max_epochs)

    # Build now the full DNN option string
    dnnMethodName = "TMVA_DNN_CPU"

    # use GPU if available
    dnnOptions = "CPU"
    if hasGPU:
        dnnOptions = "GPU"
        dnnMethodName = "TMVA_DNN_GPU"

    factory.BookMethod(
        loader,
        TMVA.Types.kDL,
        dnnMethodName,
        H=False,
        V=True,
        ErrorStrategy="CROSSENTROPY",
        VarTransform=None,
        WeightInitialization="XAVIER",
        Layout=layoutString,
        TrainingStrategy=trainingString1,
        Architecture=dnnOptions,
    )


### Book Convolutional Neural Network in TMVA

# For building a CNN one needs to define

# - Input Layout : number of channels (in this case = 1) | image height | image width
# - Batch Layout : batch size | number of channels | image size = (height*width)

# Then one adds Convolutional and MaxPool layers.

# - For a Convolutional layer the option string has to be:
#   - CONV | number of units | filter height | filter width | stride height | stride width | padding height | padding
#     width | activation function

#   - note that in this case we use a 3x3 filter with padding=1 and stride=1, so the output dimension of the
#     conv layer is equal to the input

#   - note that we use a batch normalization layer after the first convolutional layer; this seems to help
#     the convergence significantly

# - For the MaxPool layer:
#   - MAXPOOL | pool height | pool width | stride height | stride width

# The RESHAPE layer is needed to flatten the output before the Dense layer

# Note that running the CNN requires CPU or GPU support
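
# As a rough cross-check (not part of the original tutorial), the output size of the
# convolution and pooling layers booked below can be computed with the usual formula
# out = (in + 2*pad - kernel) // stride + 1. The helper below is purely illustrative.
def conv_output_size(size, kernel, stride, pad):
    # output size of a convolution / pooling layer along one dimension
    return (size + 2 * pad - kernel) // stride + 1

# a 3x3 convolution with stride 1 and padding 1 keeps the 16x16 image size,
# while the 2x2 max-pool with stride 1 (and no padding) reduces it to 15x15
assert conv_output_size(16, kernel=3, stride=1, pad=1) == 16
assert conv_output_size(16, kernel=2, stride=1, pad=0) == 15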
if useTMVACNN:
    # Training strategies.
    trainingString1 = ROOT.TString(
        "LearningRate=1e-3,Momentum=0.9,Repetitions=1,"
        "ConvergenceSteps=5,BatchSize=100,TestRepetitions=1,"
        "WeightDecay=1e-4,Regularization=None,"
        "Optimizer=ADAM,DropConfig=0.0+0.0+0.0+0.0"
    )
    trainingString1 += ",MaxEpochs=" + str(max_epochs)

    ## New DL (CNN)
    cnnMethodName = "TMVA_CNN_CPU"
    cnnOptions = "CPU"
    # use GPU if available
    if hasGPU:
        cnnOptions = "GPU"
        cnnMethodName = "TMVA_CNN_GPU"

    factory.BookMethod(
        loader,
        TMVA.Types.kDL,
        cnnMethodName,
        H=False,
        V=True,
        ErrorStrategy="CROSSENTROPY",
        VarTransform=None,
        WeightInitialization="XAVIER",
        InputLayout="1|16|16",
        Layout="CONV|10|3|3|1|1|1|1|RELU,BNORM,CONV|10|3|3|1|1|1|1|RELU,MAXPOOL|2|2|1|1,RESHAPE|FLAT,DENSE|100|RELU,DENSE|1|LINEAR",
        TrainingStrategy=trainingString1,
        Architecture=cnnOptions,
    )


### Book Convolutional Neural Network in PyTorch and in Keras (using generated models)

if usePyTorchCNN:
    ROOT.Info("TMVA_CNN_Classification", "Using Convolutional PyTorch Model")
    pyTorchFileName = str(ROOT.gROOT.GetTutorialDir())
    pyTorchFileName += "/machine_learning/PyTorch_Generate_CNN_Model.py"
    # check that PyTorch can be imported and that the file defining the model exists
    torch_spec = importlib.util.find_spec("torch")
    if torch_spec is not None and os.path.exists(pyTorchFileName):
        # cmd = str(ROOT.TMVA.Python_Executable()) + " " + pyTorchFileName
        # os.system(cmd)
        # import PyTorch_Generate_CNN_Model
        ROOT.Info("TMVA_CNN_Classification", "Booking PyTorch CNN model")
        factory.BookMethod(
            loader,
            TMVA.Types.kPyTorch,
            "PyTorch",
            H=True,
            V=False,
            VarTransform=None,
            FilenameModel="PyTorchModelCNN.pt",
            FilenameTrainedModel="PyTorchTrainedModelCNN.pt",
            NumEpochs=max_epochs,
            BatchSize=100,
            UserCode=str(pyTorchFileName),
        )
    else:
        ROOT.Warning(
            "TMVA_CNN_Classification",
            "PyTorch is not installed or the model building file does not exist - skip using PyTorch",
        )

if useKerasCNN:
    ROOT.Info("TMVA_CNN_Classification", "Building convolutional keras model")
    # build the Keras model directly here:
    # 2 Conv2D layers + MaxPooling + Dense
    import tensorflow
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.optimizers import Adam

    # from keras.initializers import TruncatedNormal
    # from keras import initializations
    from tensorflow.keras.layers import Input, Dense, Dropout, Flatten, Conv2D, MaxPooling2D, Reshape

    # from keras.callbacks import ReduceLROnPlateau
    model = Sequential()
    model.add(Reshape((16, 16, 1), input_shape=(256,)))
    model.add(Conv2D(10, kernel_size=(3, 3), kernel_initializer="TruncatedNormal", activation="relu", padding="same"))
    model.add(Conv2D(10, kernel_size=(3, 3), kernel_initializer="TruncatedNormal", activation="relu", padding="same"))
    # stride for maxpool is equal to pool size
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dense(64, activation="tanh"))
    # model.add(Dropout(0.2))
    model.add(Dense(2, activation="sigmoid"))
    model.compile(loss="binary_crossentropy", optimizer=Adam(learning_rate=0.001), weighted_metrics=["accuracy"])
    model.save("model_cnn.h5")

    if not os.path.exists("model_cnn.h5"):
        raise FileNotFoundError("Error creating Keras model file - skip using Keras")
    else:
        # book the PyKeras method only if the Keras model could be created
        ROOT.Info("TMVA_CNN_Classification", "Booking convolutional keras model")
        factory.BookMethod(
            loader,
            TMVA.Types.kPyKeras,
            "PyKeras",
            H=True,
            V=False,
            VarTransform=None,
            FilenameModel="model_cnn.h5",
            FilenameTrainedModel="trained_model_cnn.h5",
            NumEpochs=max_epochs,
            BatchSize=100,
            GpuOptions="allow_growth=True",
        )  # needed for RTX NVidia cards and to avoid that TF allocates all the GPU memory


## Train Methods

factory.TrainAllMethods()

## Test and Evaluate Methods

factory.TestAllMethods()

factory.EvaluateAllMethods()

## Plot ROC Curve

c1 = factory.GetROCCurve(loader)
c1.Draw()
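
# Optionally (not part of the original tutorial) one can also print the ROC integral (AUC)
# of a booked method before closing the output file, e.g. for the BDT:
if useTMVABDT:
    print("BDT ROC integral:", factory.GetROCIntegral(loader, "BDT"))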
# close the output file to save the results
outputFile.Close()