Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RegressionKeras.py
Go to the documentation of this file.
## \file
## \ingroup tutorial_tmva_keras
## \notebook -nodraw
## This tutorial shows how to do regression in TMVA with neural networks
## trained with Keras.
##
## \macro_code
##
## \date 2017
## \author TMVA Team

from os.path import isfile
from subprocess import call

from ROOT import TMVA, TFile, TCut, gROOT

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import SGD


def create_model():
    """Build, compile and save the Keras network used by TMVA's PyKeras method.

    The network has a 2-dimensional input (``input_dim=2``; presumably the
    number of non-target branches in the ``TreeR`` tree -- confirm), one
    hidden layer of 64 ``tanh`` units and a single linear output unit. It is
    compiled with a mean-squared-error loss and SGD (learning rate 0.01) and
    saved to ``modelRegression.h5``, the file named by
    ``FilenameModel=modelRegression.h5`` when the PyKeras method is booked.
    """
    # Define model
    model = Sequential()
    model.add(Dense(64, activation='tanh', input_dim=2))
    model.add(Dense(1, activation='linear'))

    # Set loss and optimizer
    # NOTE(review): weighted_metrics=[] is presumably passed to silence the
    # Keras warning about sample weights used without weighted metrics.
    model.compile(loss='mean_squared_error', optimizer=SGD(
        learning_rate=0.01), weighted_metrics=[])

    # Store model to file so TMVA can load the untrained architecture.
    model.save('modelRegression.h5')
    # Print the layer/parameter overview to stdout.
    model.summary()


def run():
    """Train, test and evaluate a Keras and a BDT regression with TMVA.

    Opens ``TMVA_Regression_Keras.root`` for output (``RECREATE``) and reads the
    ``TreeR`` tree from the tutorial file
    ``machine_learning/data/tmva_reg_example.root``. Every branch except
    ``fvalue`` becomes an input variable and ``fvalue`` becomes the regression
    target.

    Two methods are booked:

    * ``PyKeras``, which loads ``modelRegression.h5`` and stores the trained
      model in ``trainedModelRegression.h5``.
    * ``BDTG``, a gradient-boosted decision tree.

    Both are then trained, tested and evaluated.
    """
    data_path = str(gROOT.GetTutorialDir()) + '/machine_learning/data/tmva_reg_example.root'
    with TFile.Open('TMVA_Regression_Keras.root', 'RECREATE') as output, \
            TFile.Open(data_path) as data:
        factory = TMVA.Factory('TMVARegression', output,
                               '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=Regression')

        tree = data.Get('TreeR')

        dataloader = TMVA.DataLoader('dataset')
        # Every branch except the target is used as an input variable.
        for branch in tree.GetListOfBranches():
            name = branch.GetName()
            if name != 'fvalue':
                dataloader.AddVariable(name)
        dataloader.AddTarget('fvalue')

        # Register the tree's events (global event weight 1.0) with the loader.
        dataloader.AddRegressionTree(tree, 1.0)
        # use only 1000 events since evaluation is very slow (especially on MacOS). Increase it to get meaningful results
        dataloader.PrepareTrainingAndTestTree(TCut(''),
                                              'nTrain_Regression=1000:SplitMode=Random:NormMode=NumEvents:!V')

        # Book methods
        factory.BookMethod(dataloader, TMVA.Types.kPyKeras, 'PyKeras',
                           'H:!V:VarTransform=D,G:FilenameModel=modelRegression.h5:FilenameTrainedModel=trainedModelRegression.h5:NumEpochs=20:BatchSize=32')
        factory.BookMethod(dataloader, TMVA.Types.kBDT, 'BDTG',
                           '!H:!V:VarTransform=D,G:NTrees=1000:BoostType=Grad:Shrinkage=0.1:UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=4')

        # Run TMVA
        factory.TrainAllMethods()
        factory.TestAllMethods()
        factory.EvaluateAllMethods()


if __name__ == "__main__":
    # Setup TMVA
    TMVA.Tools.Instance()
    # Initialise the Python interpreter used by the PyMVA (PyKeras) method.
    TMVA.PyMethodBase.PyInitialize()

    # Generate model: writes modelRegression.h5, which run() books via
    # FilenameModel=modelRegression.h5.
    create_model()

    # Run TMVA
    run()
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
TCut: a specialized string object used for TTree selections.
Definition TCut.h:25
TMVA::Factory: the main MVA steering class.
Definition Factory.h:80