ROOT
v6-36
Reference Guide
Loading...
Searching...
No Matches
ClassificationKeras.py
Go to the documentation of this file.
## \file
## \ingroup tutorial_tmva_keras
## \notebook -nodraw
## This tutorial shows how to do classification in TMVA with neural networks
## trained with Keras.
##
## \macro_code
##
## \date 2017
## \author TMVA Team
12
from
ROOT
import
TMVA, TFile, TCut, gROOT
13
from
subprocess
import
call
14
from
os.path
import
isfile
15
16
from
tensorflow.keras.models
import
Sequential
17
from
tensorflow.keras.layers
import
Dense
18
from
tensorflow.keras.optimizers
import
SGD
19
20
21
def create_model(input_dim=4, filename='modelClassification.h5'):
    """Build, compile and save the untrained Keras classifier used by TMVA.

    The network has one hidden layer of 64 ReLU units and a 2-unit softmax
    output. It is compiled with categorical cross-entropy, plain SGD
    (learning rate 0.01) and accuracy as a weighted metric.

    Args:
        input_dim: Number of input features. It must equal the number of
            variables registered with the TMVA DataLoader in ``run`` (one per
            branch of the signal tree). The default (4) is the value the
            tutorial was written for.
        filename: Path the model is saved to. It must match the
            ``FilenameModel`` option of the PyKeras booking in ``run``.

    Returns:
        The compiled ``Sequential`` model (it is also written to ``filename``).
    """
    from tensorflow.keras.layers import Input

    model = Sequential()
    # Declare the input with an explicit Input layer. Passing ``input_dim``
    # to ``Dense`` is deprecated in Keras 3 and emits a warning.
    model.add(Input(shape=(input_dim,)))
    model.add(Dense(64, activation='relu'))
    model.add(Dense(2, activation='softmax'))

    # Set loss and optimizer
    model.compile(loss='categorical_crossentropy',
                  optimizer=SGD(learning_rate=0.01),
                  weighted_metrics=['accuracy'])

    # Store the untrained model; TMVA's PyKeras method loads it from disk.
    model.save(filename)
    model.summary()
    return model
36
37
38
def run():
    """Train, test and evaluate a Fisher discriminant and a Keras network with TMVA.

    Signal (``TreeS``) and background (``TreeB``) events are read from the
    ROOT tutorial file ``machine_learning/data/tmva_class_example.root``.
    Every branch of the signal tree is registered as an input variable.
    All TMVA results go to ``TMVA_Classification_Keras.root``, which is
    recreated on every call.

    The PyKeras method reads the untrained model from
    ``modelClassification.h5`` (as written by ``create_model``). It stores
    the trained model in ``trainedModelClassification.h5``.

    Raises:
        FileNotFoundError: If the tutorial input file does not exist.
        RuntimeError: If the input file lacks the ``TreeS`` or ``TreeB`` tree.
    """
    import os

    data_path = os.path.join(str(gROOT.GetTutorialDir()),
                             'machine_learning', 'data', 'tmva_class_example.root')
    # Fail fast with a readable message. Otherwise the 'RECREATE' below
    # truncates the output file before opening the input fails with an
    # opaque null-pointer error.
    if not os.path.isfile(data_path):
        raise FileNotFoundError(f'TMVA tutorial input file not found: {data_path}')

    with TFile.Open('TMVA_Classification_Keras.root', 'RECREATE') as output, \
            TFile.Open(data_path) as data:
        factory = TMVA.Factory(
            'TMVAClassification', output,
            '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=Classification')

        signal = data.Get('TreeS')
        background = data.Get('TreeB')
        # TDirectory::Get yields a null pointer for a missing key, and that
        # object is falsy in PyROOT. Report it here rather than failing later
        # on a null-pointer access.
        if not signal or not background:
            raise RuntimeError(f"'TreeS' and/or 'TreeB' not found in {data_path}")

        dataloader = TMVA.DataLoader('dataset')
        # Every signal-tree branch becomes an input feature, so the Keras
        # model's input width (see create_model) must equal the branch count.
        for branch in signal.GetListOfBranches():
            dataloader.AddVariable(branch.GetName())

        dataloader.AddSignalTree(signal, 1.0)
        dataloader.AddBackgroundTree(background, 1.0)
        dataloader.PrepareTrainingAndTestTree(
            TCut(''),
            'nTrain_Signal=4000:nTrain_Background=4000:SplitMode=Random:NormMode=NumEvents:!V')

        # Book methods
        factory.BookMethod(dataloader, TMVA.Types.kFisher, 'Fisher',
                           '!H:!V:Fisher:VarTransform=D,G')
        factory.BookMethod(
            dataloader, TMVA.Types.kPyKeras, 'PyKeras',
            'H:!V:VarTransform=D,G:FilenameModel=modelClassification.h5:FilenameTrainedModel=trainedModelClassification.h5:NumEpochs=20:BatchSize=32')

        # Run training, test and evaluation
        factory.TrainAllMethods()
        factory.TestAllMethods()
        factory.EvaluateAllMethods()
65
66
67
if __name__ == "__main__":
    # Bring up TMVA's tools instance and PyMVA's Python support before any
    # model is created or any method is booked.
    TMVA.Tools.Instance()
    TMVA.PyMethodBase.PyInitialize()

    # Stage 1 writes the untrained Keras model to disk.
    # Stage 2 trains, tests and evaluates it next to the Fisher baseline.
    for _stage in (create_model, run):
        _stage()
TRangeDynCast
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Definition
TCollection.h:358
ROOT::Detail::TRangeCast
Definition
TCollection.h:311
TCut
A specialized string object used for TTree selections.
Definition
TCut.h:25
TMVA::DataLoader
Definition
DataLoader.h:50
TMVA::Factory
This is the main MVA steering class.
Definition
Factory.h:80
tutorials
machine_learning
keras
ClassificationKeras.py
ROOT v6-36 - Reference Guide Generated on Thu Aug 21 2025 04:30:47 (GVA Time) using Doxygen 1.10.0