31 print(
"Running in serial mode since ROOT does not support MT")
54 fname =
"time_data_t" + str(ntime) +
"_d" + str(ndim) +
".root"
58 for i
in range(ntime):
67 f =
TFile(fname,
"RECREATE")
72 for i
in range(ntime):
76 for i
in range(ntime):
77 bkg.Branch(
"vars_time" + str(i),
"std::vector<float>", x1[i])
78 sgn.Branch(
"vars_time" + str(i),
"std::vector<float>", x2[i])
89 for j
in range(ntime):
97 print(
"Generating event ... %d", i)
99 for j
in range(ntime):
111 for k
in range(ntime):
122 for j
in range(ntime):
125 for j
in range(ntime):
168 ROOT.Warning(
"TMVA_RNN_Classification",
"Skip using Keras since tensorflow cannot be imported")
172rnn_types = [
"RNN",
"LSTM",
"GRU"]
173use_rnn_type = [1, 1, 1]
176 use_rnn_type = [0, 0, 0]
177 use_rnn_type[use_type] = 1
186 "TMVA_RNN_Classification",
187 "TMVA is not build with GPU or CPU multi-thread support. Cannot use TMVA Deep Learning for RNN",
190archString =
"GPU" if useGPU
else "CPU"
192writeOutputFile =
True
203inputFileName =
"time_data_t10_d30.root"
220outfileName =
"data_RNN_" + archString +
".root"
225 outputFile =
TFile.Open(outfileName,
"RECREATE")
249 "TMVAClassification",
254 DrawProgressBar=
True,
255 Transformations=
None,
257 AnalysisType=
"Classification",
258 ModelPersistence=
True,
268for i
in range(ntime):
284nTrainSig = 0.8 * nTotEvts
285nTrainBkg = 0.8 * nTotEvts
295 nTrain_Signal=nTrainSig,
296 nTrain_Background=nTrainBkg,
299 NormMode=
"NumEvents",
301 CalcCorrelations=
False,
304print(
"prepared DATA LOADER ")
314 if not use_rnn_type[i]:
317 rnn_type = rnn_types[i]
321 rnnLayout = str(rnn_type) +
"|10|" + str(ninput) +
"|" + str(ntime) +
"|0|1,RESHAPE|FLAT,DENSE|64|TANH,LINEAR"
324 trainingString1 =
"LearningRate=1e-3,Momentum=0.0,Repetitions=1,ConvergenceSteps=5,BatchSize=" + str(batchSize)
325 trainingString1 +=
",TestRepetitions=1,WeightDecay=1e-2,Regularization=None,MaxEpochs=" + str(maxepochs)
326 trainingString1 +=
"Optimizer=ADAM,DropConfig=0.0+0.+0.+0."
334 rnnName =
"TMVA_" + str(rnn_type)
341 ErrorStrategy=
"CROSSENTROPY",
343 WeightInitialization=
"XAVIERUNIFORM",
346 InputLayout=str(ntime) +
"|" + str(ninput),
348 TrainingStrategy=trainingString1,
349 Architecture=archString
358 "LearningRate=1e-3,Momentum=0.0,Repetitions=1,"
359 "ConvergenceSteps=10,BatchSize=256,TestRepetitions=1,"
360 "WeightDecay=1e-4,Regularization=None,MaxEpochs=20"
361 "DropConfig=0.0+0.+0.+0.,Optimizer=ADAM:"
372 ErrorStrategy=
"CROSSENTROPY",
374 WeightInitialization=
"XAVIER",
376 InputLayout=
"1|1|" + str(ntime * ninput),
377 Layout=
"DENSE|64|TANH,DENSE|TANH|64,DENSE|TANH|64,LINEAR",
378 TrainingStrategy=trainingString1
390 modelName =
"model_" + rnn_types[i] +
".h5"
391 trainedModelName =
"trained_" + modelName
392 print(
"Building recurrent keras model using a", rnn_types[i],
"layer")
400 from tensorflow.keras.layers import Input, Dense, Dropout, Flatten, SimpleRNN, GRU, LSTM, Reshape, BatchNormalization
403 model.add(Reshape((10, 30), input_shape=(10 * 30,)))
405 if rnn_types[i] ==
"LSTM":
407 elif rnn_types[i] ==
"GRU":
415 model.compile(loss=
"binary_crossentropy", optimizer=
Adam(learning_rate=0.001), weighted_metrics=[
"accuracy"])
418 print(
"saved recurrent model", modelName)
422 print(
"Error creating Keras recurrent model file - Skip using Keras")
425 print(
"Booking Keras model ", rnn_types[i])
429 "PyKeras_" + rnn_types[i],
433 FilenameModel=modelName,
434 FilenameTrainedModel=
"trained_" + modelName,
437 GpuOptions=
"allow_growth=True",
442if not useKeras
or not useTMVA_BDT:
461 BaggedSampleFraction=0.5,
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t format
A ROOT file is an on-disk file, usually with extension .root, that stores objects in a file-system-li...
This is the main MVA steering class.