ROOT Reference Guide

RBatchGenerator_TensorFlow.py File Reference

Detailed Description

Example of getting batches of events from a ROOT dataset into a basic TensorFlow workflow.
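One way to picture what the batch generator does with `batch_size`, `chunk_size` and `validation_split` is this: read the dataset a chunk at a time, split each chunk into a training part and a validation part, and cut those parts into batches. The toy Python sketch below models that bookkeeping on plain lists of event indices. It is not ROOT's implementation. The helper `chunked_batches`, the per-chunk shuffling, and the choice to drop leftover events that do not fill a whole batch are all assumptions made for illustration.

```python
import random
from typing import List, Sequence, Tuple


def chunked_batches(
    events: Sequence[int],
    batch_size: int,
    chunk_size: int,
    validation_split: float,
    seed: int = 0,
) -> Tuple[List[List[int]], List[List[int]]]:
    """Toy model of chunked batch generation (hypothetical, not ROOT's code).

    Events are processed one chunk at a time. Each chunk is shuffled, split
    into a training and a validation part, and each part is cut into full
    batches. Leftover events that do not fill a batch are dropped.
    """
    rng = random.Random(seed)
    train_batches: List[List[int]] = []
    valid_batches: List[List[int]] = []
    for start in range(0, len(events), chunk_size):
        # Only this chunk would need to be held in memory
        chunk = list(events[start:start + chunk_size])
        rng.shuffle(chunk)
        # The last validation_split fraction of the chunk goes to validation
        n_valid = int(len(chunk) * validation_split)
        n_train = len(chunk) - n_valid
        parts = ((chunk[:n_train], train_batches), (chunk[n_train:], valid_batches))
        for part, out in parts:
            # Step through the part in whole batches only
            for b in range(0, len(part) - batch_size + 1, batch_size):
                out.append(part[b:b + batch_size])
    return train_batches, valid_batches


events = list(range(12_000))  # stand-in for 12,000 event indices
train, valid = chunked_batches(events, batch_size=128, chunk_size=5_000, validation_split=0.3)
print(len(train), len(valid))  # 64 26
```

With these made-up 12,000 events, the sketch yields 64 training batches and 26 validation batches. The real counts for the Higgs dataset depend on its size and on how ROOT handles leftover events.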

import ROOT
# TensorFlow has to be imported after ROOT to avoid LLVM symbol clashes if ROOT
# was built with LLVM in Debug mode and TensorFlow>=2.20.0.
import tensorflow as tf
tree_name = "sig_tree"
file_name = str(ROOT.gROOT.GetTutorialDir()) + "/machine_learning/data/Higgs_data.root"
batch_size = 128
chunk_size = 5_000
rdataframe = ROOT.RDataFrame(tree_name, file_name)
target = "Type"
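# Returns two tf.data.Dataset objects: one yielding training batches and one
# yielding validation batches.
ds_train, ds_valid = ROOT.TMVA.Experimental.CreateTFDatasets(
    rdataframe,
    batch_size,
    chunk_size,
    validation_split=0.3,
    target=target,
)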
num_of_epochs = 2
# Datasets have to be repeated as many times as there are epochs
ds_train_repeated = ds_train.repeat(num_of_epochs)
ds_valid_repeated = ds_valid.repeat(num_of_epochs)
# Number of batches per epoch must be given for model.fit
train_batches_per_epoch = ds_train.number_of_batches
validation_batches_per_epoch = ds_valid.number_of_batches
# Get a list of the columns used for training
input_columns = ds_train.train_columns
num_features = len(input_columns)
##############################################################################
# AI example
##############################################################################
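# Define TensorFlow model
model = tf.keras.Sequential(
    [
        tf.keras.layers.Input(shape=(num_features,)),
        # Hidden layers (illustrative sizes and activations)
        tf.keras.layers.Dense(300, activation=tf.nn.tanh),
        tf.keras.layers.Dense(300, activation=tf.nn.tanh),
        tf.keras.layers.Dense(300, activation=tf.nn.tanh),
        # One sigmoid output for the binary target
        tf.keras.layers.Dense(1, activation=tf.nn.sigmoid),
    ]
)
# Binary cross-entropy to match the single sigmoid output
loss_fn = tf.keras.losses.BinaryCrossentropy()
model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])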
# Train model
model.fit(
    ds_train_repeated,
    steps_per_epoch=train_batches_per_epoch,
    validation_data=ds_valid_repeated,
    validation_steps=validation_batches_per_epoch,
    epochs=num_of_epochs,
)
Epoch 1/2
 1/54 ━━━━━━━━━━━━━━━━━━━━ 1:56 2s/step - accuracy: 0.6172 - loss: 0.6638
16/54 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - accuracy: 0.9191 - loss: 0.1538
32/54 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - accuracy: 0.9514 - loss: 0.0930
49/54 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - accuracy: 0.9650 - loss: 0.0672
54/54 ━━━━━━━━━━━━━━━━━━━━ 3s 8ms/step - accuracy: 0.9680 - loss: 0.0615 - val_accuracy: 1.0000 - val_loss: 3.4227e-07
Epoch 2/2
 1/54 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - accuracy: 1.0000 - loss: 3.0330e-07
17/54 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - accuracy: 1.0000 - loss: 3.7345e-07
33/54 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - accuracy: 1.0000 - loss: 3.6823e-07
51/54 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - accuracy: 1.0000 - loss: 3.6505e-07
54/54 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - accuracy: 1.0000 - loss: 3.6421e-07 - val_accuracy: 0.9545 - val_loss: 3.0931e-07
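The listing repeats both datasets `num_of_epochs` times and passes `steps_per_epoch` and `validation_steps` explicitly. The log shows 54 training steps per epoch. A plausible reason for this bookkeeping involves how Keras treats datasets whose length is not known up front: with `steps_per_epoch` set, it keeps drawing from one iterator across epochs rather than restarting it, at least under that assumption. A dataset holding only one epoch's worth of batches would then run dry in the second epoch. The plain-Python sketch below models that behaviour. The functions `repeat_batches` and `toy_fit` are hypothetical stand-ins for `tf.data.Dataset.repeat` and `model.fit`, not their actual implementations.

```python
from itertools import chain, islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def repeat_batches(batches: List[T], times: int) -> Iterator[T]:
    """Stand-in for Dataset.repeat(times): replay a finite list `times` times."""
    return chain.from_iterable(batches for _ in range(times))


def toy_fit(data: Iterable[T], steps_per_epoch: int, epochs: int) -> List[List[T]]:
    """Stand-in for model.fit: every epoch pulls `steps_per_epoch` batches
    from ONE shared iterator, which is never restarted between epochs."""
    stream = iter(data)
    history = []
    for epoch in range(epochs):
        batches = list(islice(stream, steps_per_epoch))
        if len(batches) < steps_per_epoch:
            raise RuntimeError(f"input ran out of data in epoch {epoch + 1}")
        history.append(batches)
    return history


batches = list(range(54))  # 54 batches per epoch, as in the training log
num_of_epochs = 2

# Repeated dataset: both epochs get all 54 batches
history = toy_fit(repeat_batches(batches, num_of_epochs), 54, num_of_epochs)
print([len(h) for h in history])  # [54, 54]

# Unrepeated dataset: the second epoch finds the iterator exhausted
try:
    toy_fit(batches, 54, num_of_epochs)
except RuntimeError as err:
    print(err)  # input ran out of data in epoch 2
```

The sketch raises an exception to make the failure explicit. Keras itself typically prints a warning that the input ran out of data and interrupts training instead.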
Author
Dante Niewenhuis

Definition in file RBatchGenerator_TensorFlow.py.