Module awave.utils.train
Source code
from copy import deepcopy
import numpy as np
import torch
from awave.utils.wave_attributions import Attributer
from awave.trim import TrimModel
class Trainer():
"""
Class to handle training of model.
Parameters
----------
    model: torch.nn.Module, optional
        Pretrained model to interpret; when given, Trainer builds a TrimModel
        and an Attributer so the attribution (TRIM) penalty can be computed.
    optimizer: torch.optim.Optimizer
    w_transform: torch.nn.Module
        Wavelet transform (optionally wrapped in torch.nn.DataParallel).
    loss_f: callable
        Loss with signature (w_transform, data, recon_data, data_t, attributions)
        and a lamL1attr attribute.
    target: int, optional
        Target class for the attributions.
    device: torch.device, optional
        Device on which to run the code.
    use_residuals: boolean, optional
        Use residuals to compute TRIM score.
    attr_methods: str, optional
        Attribution method passed to the Attributer.
    n_print: int, optional
        Print progress every n_print epochs.
    """
def __init__(self,
model=None,
w_transform=None,
optimizer=None,
loss_f=None,
target=1,
device=torch.device("cuda"),
use_residuals=True,
attr_methods='InputXGradient',
n_print=1):
self.device = device
        # detect whether w_transform is wrapped in torch.nn.DataParallel;
        # if so, the inverse transform lives on the underlying module
        self.is_parallel = 'data_parallel' in str(type(w_transform))
        self.wt_inverse = w_transform.module.inverse if self.is_parallel else w_transform.inverse
if model is not None:
self.model = model.to(self.device)
self.mt = TrimModel(model, self.wt_inverse, use_residuals=use_residuals)
self.attributer = Attributer(self.mt, attr_methods=attr_methods, device=self.device)
else:
self.model = None
self.mt = None
self.attributer = None
self.w_transform = w_transform.to(self.device)
self.optimizer = optimizer
self.loss_f = loss_f
self.target = target
self.n_print = n_print
def __call__(self, train_loader, test_loader=None, epochs=10):
"""
Trains the model.
Parameters
----------
        train_loader: torch.utils.data.DataLoader
        test_loader: torch.utils.data.DataLoader, optional
epochs: int, optional
Number of epochs to train the model for.
"""
print("Starting Training Loop...")
self.train_losses = np.empty(epochs)
self.test_losses = np.empty(epochs)
for epoch in range(epochs):
if test_loader is not None:
mean_epoch_loss = self._train_epoch(train_loader, epoch)
mean_epoch_test_loss = self._test_epoch(test_loader)
if epoch % self.n_print == 0:
print('\n====> Epoch: {} Average train loss: {:.4f} (Test set loss: {:.4f})'.format(epoch,
mean_epoch_loss,
mean_epoch_test_loss))
self.train_losses[epoch] = mean_epoch_loss
self.test_losses[epoch] = mean_epoch_test_loss
else:
mean_epoch_loss = self._train_epoch(train_loader, epoch)
if epoch % self.n_print == 0:
print('\n====> Epoch: {} Average train loss: {:.4f}'.format(epoch, mean_epoch_loss))
                try:
                    self.train_losses[epoch] = mean_epoch_loss
                except TypeError:
                    # the loss may be complex-valued; store only the real part
                    self.train_losses[epoch] = mean_epoch_loss.real
def _train_epoch(self, data_loader, epoch):
"""
Trains the model for one epoch.
Parameters
----------
data_loader: torch.utils.data.DataLoader
epoch: int
Epoch number
Return
------
mean_epoch_loss: float
"""
self.w_transform.train()
epoch_loss = 0.
for batch_idx, (data, _) in enumerate(data_loader):
iter_loss = self._train_iteration(data)
epoch_loss += iter_loss
if epoch % self.n_print == 0:
print('\rTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(data_loader.dataset),
100. * batch_idx / len(data_loader), iter_loss), end='')
mean_epoch_loss = epoch_loss / (batch_idx + 1)
self.w_transform.eval()
return mean_epoch_loss
def _train_iteration(self, data):
"""
Trains the model for one iteration on a batch of data.
Parameters
----------
data: torch.Tensor
A batch of data. Shape : (batch_size, channel, height, width).
"""
data = data.to(self.device)
# zero grad
self.optimizer.zero_grad()
# transform
data_t = self.w_transform(data)
# reconstruction
recon_data = self.wt_inverse(data_t)
# TRIM score
if self.attributer is not None:
            # cuDNN is disabled here because some of its kernels may not
            # support the double backward needed to differentiate through
            # gradient-based attributions
            if self.loss_f.lamL1attr > 0:
                with torch.backends.cudnn.flags(enabled=False):
                    attributions = self.attributer(data_t, target=self.target,
                                                   additional_forward_args=deepcopy(data))
            else:
                attributions = None
else:
attributions = None
# loss
if self.is_parallel:
loss = self.loss_f(self.w_transform.module, data, recon_data, data_t, attributions)
else:
loss = self.loss_f(self.w_transform, data, recon_data, data_t, attributions)
# backward
loss.backward()
# update step
self.optimizer.step()
return loss.item()
def _test_epoch(self, data_loader):
"""
Tests the model for one epoch.
Parameters
----------
        data_loader: torch.utils.data.DataLoader
Return
------
mean_epoch_loss: float
"""
self.w_transform.eval()
epoch_loss = 0.
for batch_idx, (data, _) in enumerate(data_loader):
data = data.to(self.device)
data_t = self.w_transform(data)
recon_data = self.wt_inverse(data_t)
            if self.attributer is not None:
                attributions = self.attributer(data_t, target=self.target,
                                               additional_forward_args=deepcopy(data))
            else:
                attributions = None
loss = self.loss_f(self.w_transform, data, recon_data, data_t, attributions)
iter_loss = loss.item()
epoch_loss += iter_loss
            print('\rTest: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                batch_idx * len(data), len(data_loader.dataset),
                100. * batch_idx / len(data_loader), iter_loss), end='')
mean_epoch_loss = epoch_loss / (batch_idx + 1)
return mean_epoch_loss
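For orientation, here is a minimal, self-contained usage sketch. The IdentityTransform and ReconLoss classes below are hypothetical stand-ins, not part of awave: any module exposing forward/inverse and any loss object exposing a lamL1attr attribute and this call signature should slot in the same way.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from awave.utils.train import Trainer

class IdentityTransform(nn.Module):
    """Toy stand-in for a wavelet transform; forward and inverse cancel."""
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))  # gives the optimizer something to update
    def forward(self, x):
        return x * self.scale
    def inverse(self, x):
        return x / self.scale

class ReconLoss:
    """Minimal loss matching the interface Trainer expects."""
    lamL1attr = 0.0  # zero, so Trainer skips the attribution computation
    def __call__(self, w_transform, data, recon_data, data_t, attributions):
        return ((recon_data - data) ** 2).mean()

w_transform = IdentityTransform()
optimizer = torch.optim.Adam(w_transform.parameters(), lr=1e-3)
trainer = Trainer(w_transform=w_transform,  # model=None, so no TRIM term
                  optimizer=optimizer,
                  loss_f=ReconLoss(),
                  device=torch.device("cpu"))

x = torch.randn(64, 1, 28, 28)
train_loader = DataLoader(TensorDataset(x, torch.zeros(64)), batch_size=16)
trainer(train_loader, epochs=2)
print(trainer.train_losses)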
Classes
class Trainer (model=None, w_transform=None, optimizer=None, loss_f=None, target=1, device=device(type='cuda'), use_residuals=True, attr_methods='InputXGradient', n_print=1)
Class to handle training of model.

Parameters

model : torch.nn.Module, optional
    Pretrained model; enables the attribution (TRIM) penalty.
optimizer : torch.optim.Optimizer
w_transform : torch.nn.Module
    Wavelet transform (optionally wrapped in torch.nn.DataParallel).
loss_f : callable
    Loss with signature (w_transform, data, recon_data, data_t, attributions) and a lamL1attr attribute.
target : int, optional
    Target class for the attributions.
device : torch.device, optional
    Device on which to run the code.
use_residuals : boolean, optional
    Use residuals to compute TRIM score.
attr_methods : str, optional
    Attribution method passed to the Attributer.
n_print : int, optional
    Print progress every n_print epochs.
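When several GPUs are available, the wavelet transform can be wrapped in torch.nn.DataParallel before being handed to Trainer; the constructor detects the wrapper from its class name and reaches through .module for the inverse transform and the loss call. A sketch, with w_transform, optimizer and loss_f as in the usage example above:

import torch
import torch.nn as nn

if torch.cuda.device_count() > 1:
    w_transform = nn.DataParallel(w_transform)  # Trainer detects this and uses w_transform.module
trainer = Trainer(w_transform=w_transform, optimizer=optimizer, loss_f=loss_f,
                  device=torch.device("cuda"))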