"""Train a Fisher discriminant and a Keras network on a four-class TMVA task.

The script writes its results to ``TMVA_multiclass.root``, generates the
input sample with ROOT's ``createData.C`` tutorial macro when the file is not
already on disk, and stores the untrained Keras model in
``modelMultiClass.h5`` so that the PyKeras method can load it.
"""
from os.path import isfile

from ROOT import TMVA, TFile, TTree, TCut, gROOT
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras.optimizers import SGD

# --- Set up TMVA ------------------------------------------------------------
# NOTE(review): the two initialisation calls and the ``TMVA.Factory(`` opener
# were missing from the extracted text (only the option string survived).
# They are restored here so that ``factory`` is defined before it is used
# below; confirm the job name 'TMVAClassification' against the upstream
# tutorial.
TMVA.Tools.Instance()
TMVA.PyMethodBase.PyInitialize()

output = TFile.Open('TMVA_multiclass.root', 'RECREATE')
factory = TMVA.Factory(
    'TMVAClassification', output,
    '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=multiclass')

# --- Load data --------------------------------------------------------------
# Generate the input file with the tutorial macro only when it is absent.
if not isfile('tmva_example_multiple_background.root'):
    createDataMacro = str(gROOT.GetTutorialDir()) + '/tmva/createData.C'
    print(createDataMacro)
    gROOT.ProcessLine('.L {}'.format(createDataMacro))
    gROOT.ProcessLine('create_MultipleBackground(4000)')

data = TFile.Open('tmva_example_multiple_background.root')
signal = data.Get('TreeS')
background0 = data.Get('TreeB0')
background1 = data.Get('TreeB1')
background2 = data.Get('TreeB2')

# NOTE(review): the DataLoader construction was also dropped by the
# extraction; the name 'dataset' is restored from the upstream tutorial --
# confirm.
dataloader = TMVA.DataLoader('dataset')

# Every branch of the signal tree becomes an input variable.
for branch in signal.GetListOfBranches():
    dataloader.AddVariable(branch.GetName())

# One class per tree: the signal plus three background classes.
dataloader.AddTree(signal, 'Signal')
dataloader.AddTree(background0, 'Background_0')
dataloader.AddTree(background1, 'Background_1')
dataloader.AddTree(background2, 'Background_2')
dataloader.PrepareTrainingAndTestTree(
    TCut(''),
    'SplitMode=Random:NormMode=NumEvents:!V')

# --- Define the Keras model -------------------------------------------------
# NOTE(review): ``model = Sequential()`` was missing from the extracted text
# and is restored so that ``model.add`` has a target.
model = Sequential()
# input_dim=4 assumes the signal tree exposes four branches -- TODO confirm.
model.add(Dense(32, activation='relu', input_dim=4))
# Four softmax outputs, one per class registered with the dataloader above.
model.add(Dense(4, activation='softmax'))

model.compile(loss='categorical_crossentropy',
              optimizer=SGD(learning_rate=0.01),
              weighted_metrics=['accuracy', ])

# The PyKeras method reads the model from this file (see FilenameModel below).
model.save('modelMultiClass.h5')

# --- Book methods -----------------------------------------------------------
factory.BookMethod(dataloader, TMVA.Types.kFisher, 'Fisher',
                   '!H:!V:Fisher:VarTransform=D,G')
factory.BookMethod(dataloader, TMVA.Types.kPyKeras, 'PyKeras',
                   'H:!V:VarTransform=D,G:FilenameModel=modelMultiClass.h5:FilenameTrainedModel=trainedModelMultiClass.h5:NumEpochs=20:BatchSize=32')

# --- Run TMVA ---------------------------------------------------------------
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()
 
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t format
TCut: A specialized string object used for TTree selections.
TFile::Open: static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
    Create / open a file.
TMVA::Factory: This is the main MVA steering class.
TMVA::PyMethodBase::PyInitialize: static void PyInitialize()
    Initialize the Python interpreter.