Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
PyTorch_Generate_CNN_Model.py
Go to the documentation of this file.
## \file
## \ingroup tutorial_ml
## \notebook
##
## \macro_image
## \macro_output
## \macro_code
##
## \author Harshal Shende
10
import torch
from torch import nn

# Define model

# Progress banner printed as soon as this script starts executing.
print("running Torch code defining the model....")
17
# Custom Reshape Layer
class Reshape(torch.nn.Module):
    """Layer that views its input as a batch of single-channel 16x16 images."""

    def forward(self, x):
        """Return ``x`` viewed with shape ``(-1, 1, 16, 16)``.

        The leading dimension is inferred, so the number of elements in
        ``x`` must be a multiple of 256 (16 * 16).
        """
        return x.view(-1, 1, 16, 16)
22
# CNN Model Definition
# The input is viewed as 1x16x16 images. Both 3x3 convolutions use padding=1
# and keep the 16x16 size; the 2x2 max-pool halves it, hence the 10*8*8
# features entering the first linear layer.
net = torch.nn.Sequential(
    Reshape(),
    nn.Conv2d(1, 10, kernel_size=3, padding=1),
    nn.ReLU(),
    # NOTE(review): the listing's numbering skips a line at this position, so
    # a layer may have been lost in extraction -- confirm against upstream.
    nn.Conv2d(10, 10, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),
    nn.Flatten(),
    nn.Linear(10*8*8, 256),
    nn.ReLU(),
    nn.Linear(256, 2),
    # nn.BCELoss requires inputs in [0, 1], so squash the two outputs.
    nn.Sigmoid(),
)
38
# Construct loss function and Optimizer.
# `criterion` is a ready-to-call loss instance; `optimizer` is the Adam
# *class*, instantiated later with the model parameters.
criterion = nn.BCELoss()
optimizer = torch.optim.Adam
42
43
def fit(model, train_loader, val_loader, num_epochs, batch_size, optimizer, criterion, save_best, scheduler):
    """Train ``model`` and evaluate it on ``val_loader`` after every epoch.

    Args:
        model: ``torch.nn.Module`` to optimise; moved to CUDA when available,
            otherwise kept on the CPU.
        train_loader: Iterable yielding ``(X, y)`` training mini-batches.
        val_loader: Iterable yielding ``(X, y)`` validation mini-batches;
            must support ``len()`` and be non-empty.
        num_epochs: Number of passes over ``train_loader``.
        batch_size: Accepted but not used by this function.
        optimizer: Optimizer class or factory, called as
            ``optimizer(model.parameters(), lr=0.01)``.
        criterion: Callable ``criterion(output, target)`` returning the loss.
        save_best: Optional callable ``save_best(model, curr_val, best_val)``
            whose return value becomes the new best validation loss; skipped
            when falsy.
        scheduler: Pair ``(schedule, schedulerSteps)``. When ``schedule`` is
            truthy it is called once per epoch as
            ``schedule(trainer, epoch, schedulerSteps)`` with the constructed
            optimizer instance.

    Returns:
        The trained model, left in eval mode once at least one epoch has run.
    """
    schedule, schedulerSteps = scheduler
    best_val = None

    # Setup GPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Build the optimizer after the model sits on its final device.
    trainer = optimizer(model.parameters(), lr=0.01)

    epochs_done = 0
    for epoch in range(num_epochs):
        # Training Loop
        model.train()
        running_train_loss = 0.0
        for i, (X, y) in enumerate(train_loader):
            # Gradients accumulate by default, so clear them for each batch.
            trainer.zero_grad()
            X, y = X.to(device), y.to(device)
            train_loss = criterion(model(X), y)
            train_loss.backward()
            trainer.step()

            # print train statistics
            running_train_loss += train_loss.item()
            if i % 4 == 3:  # print every 4 mini-batches
                print(f"[{epoch+1}, {i+1}] train loss: {running_train_loss / 4 :.3f}")
                running_train_loss = 0.0

        if schedule:
            # Pass the optimizer instance: the class itself carries no state
            # (such as parameter groups) for a schedule to adjust.
            schedule(trainer, epoch, schedulerSteps)

        # Validation Loop
        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for X, y in val_loader:
                X, y = X.to(device), y.to(device)
                running_val_loss += criterion(model(X), y).item()

        curr_val = running_val_loss / len(val_loader)
        if save_best:
            if best_val is None:
                best_val = curr_val
            best_val = save_best(model, curr_val, best_val)

        # print val statistics per epoch
        print(f"[{epoch+1}] val loss: {curr_val :.3f}")
        epochs_done = epoch + 1

    # Use a counter rather than the loop variable, which is unbound when
    # num_epochs is 0.
    print(f"Finished Training on {epochs_done} Epochs!")

    return model
101
102
def predict(model, test_X, batch_size=100):
    """Run ``model`` over ``test_X`` in mini-batches and return a NumPy array.

    Args:
        model: ``torch.nn.Module`` to evaluate; moved to CUDA when available
            and switched to eval mode.
        test_X: Data accepted by ``torch.Tensor(...)``, first dimension
            indexing samples. Must contain at least one sample, because the
            per-batch outputs are concatenated.
        batch_size: Number of samples per forward pass.

    Returns:
        The model outputs for all samples, in input order, as a NumPy array
        on the CPU.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Set to eval mode
    model.eval()

    test_dataset = torch.utils.data.TensorDataset(torch.Tensor(test_X))
    # shuffle=False keeps the predictions aligned with the rows of test_X.
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    with torch.no_grad():
        # Each batch comes back as a one-element sequence holding the inputs.
        predictions = [model(batch[0].to(device)) for batch in test_loader]
    preds = torch.cat(predictions)

    return preds.cpu().numpy()
124
125
# Bundle the optimizer class, the loss instance and the train/predict entry
# points under fixed keys.
# NOTE(review): presumably read by whatever loads this model -- confirm the
# expected keys against the consumer.
load_model_custom_objects = {"optimizer": optimizer, "criterion": criterion, "train_func": fit, "predict_func": predict}

# Store model to file
m = torch.jit.script(net)
torch.jit.save(m, "PyTorchModelCNN.pt")
print("The PyTorch CNN model is created and saved as PyTorchModelCNN.pt")
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t UChar_t len