Module awave.utils.evaluate
from copy import deepcopy
import torch
from awave.losses import _reconstruction_loss, _lsum_loss, _hsum_loss, _L2norm_loss, _CMF_loss, \
_conv_loss, _L1_wave_loss, _L1_attribution_loss
from awave.utils.wave_attributions import Attributer
from awave.trim import TrimModel
class Validator:
    """
    Class to handle validation of the model.

    Parameters
    ----------
    model: torch.nn.Module
        Model to be evaluated.
    data_loader: torch.utils.data.DataLoader
        Loader yielding (data, label) batches.
    device: torch.device, optional
        Device on which to run the code.
    use_residuals : boolean, optional
        Use residuals to compute the TRIM score.
    """
    def __init__(self, model, data_loader,
                 device=torch.device("cuda"),
                 use_residuals=True):
        self.device = device
        self.model = model.to(self.device)
        self.data_loader = data_loader
        self.use_residuals = use_residuals
    def __call__(self, w_transform, target=1):
        """
        Evaluates the model with the given wavelet transform for one epoch.

        Parameters
        ----------
        w_transform: torch.nn.Module
            Wavelet transformer
        target: int, optional
            Class index for which attributions are computed.

        Returns
        -------
        tuple of float
            Mean per-batch losses: (rec, lsum, hsum, L2norm, CMF,
            conv, L1wave, L1saliency, L1inputxgrad).
        """
        w_transform = w_transform.to(self.device)
        w_transform = w_transform.eval()

        # unwrap nn.DataParallel (multi-GPU) so the filter-based losses
        # see the underlying wavelet transform
        is_parallel = 'data_parallel' in str(type(w_transform))
        wt = w_transform.module if is_parallel else w_transform
        wt_inverse = wt.inverse
        mt = TrimModel(self.model, wt_inverse, use_residuals=self.use_residuals)
        Saliency = Attributer(mt, attr_methods='Saliency', is_train=False, device=self.device)
        Inputxgrad = Attributer(mt, attr_methods='InputXGradient', is_train=False, device=self.device)

        rec_loss = 0.
        lsum_loss = 0.
        hsum_loss = 0.
        L2norm_loss = 0.
        CMF_loss = 0.
        conv_loss = 0.
        L1wave_loss = 0.
        L1saliency_loss = 0.
        L1inputxgrad_loss = 0.
        for batch_idx, (data, _) in enumerate(self.data_loader):
            data = data.to(self.device)
            data_t = w_transform(data)
            recon_data = wt_inverse(data_t)
            saliency = Saliency(data_t, target=target, additional_forward_args=deepcopy(data))
            inputxgrad = Inputxgrad(data_t, target=target, additional_forward_args=deepcopy(data))

            rec_loss += _reconstruction_loss(data, recon_data).item()
            lsum_loss += _lsum_loss(wt).item()
            hsum_loss += _hsum_loss(wt).item()
            L2norm_loss += _L2norm_loss(wt).item()
            CMF_loss += _CMF_loss(wt).item()
            conv_loss += _conv_loss(wt).item()
            L1wave_loss += _L1_wave_loss(data_t).item()
            L1saliency_loss += _L1_attribution_loss(saliency).item()
            L1inputxgrad_loss += _L1_attribution_loss(inputxgrad).item()
        n_batches = batch_idx + 1  # batches seen in the loop above
        mean_rec_loss = rec_loss / n_batches
        mean_lsum_loss = lsum_loss / n_batches
        mean_hsum_loss = hsum_loss / n_batches
        mean_L2norm_loss = L2norm_loss / n_batches
        mean_CMF_loss = CMF_loss / n_batches
        mean_conv_loss = conv_loss / n_batches
        mean_L1wave_loss = L1wave_loss / n_batches
        mean_L1saliency_loss = L1saliency_loss / n_batches
        mean_L1inputxgrad_loss = L1inputxgrad_loss / n_batches
        return (mean_rec_loss, mean_lsum_loss, mean_hsum_loss,
                mean_L2norm_loss, mean_CMF_loss, mean_conv_loss,
                mean_L1wave_loss, mean_L1saliency_loss, mean_L1inputxgrad_loss)
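
For orientation, a minimal usage sketch follows. The classifier, the fake dataset, and the DWT2d import path are assumptions for illustration only; any frozen torch.nn.Module and any awave transform exposing a forward pass and an inverse method should fit the same pattern. The returned tuple follows the ordering of the return statement above.

import torch
import torchvision
from torch.utils.data import DataLoader

from awave.utils.evaluate import Validator
from awave.transform2d import DWT2d  # assumed import path for the 2d wavelet transform

# hypothetical frozen classifier whose predictions will be attributed
model = torchvision.models.resnet18(weights='DEFAULT').eval()

# the loader must yield (data, label) batches, as Validator's loop expects
dataset = torchvision.datasets.FakeData(size=64, image_size=(3, 224, 224),
                                        transform=torchvision.transforms.ToTensor())
loader = DataLoader(dataset, batch_size=16)

w_transform = DWT2d(wave='db3', J=4)  # wavelet transform under evaluation
validator = Validator(model, loader, device=torch.device('cpu'))
(rec, lsum, hsum, l2, cmf, conv,
 l1wave, l1sal, l1ixg) = validator(w_transform, target=1)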
Classes
class Validator(model, data_loader, device=torch.device('cuda'), use_residuals=True)

    Class to handle validation of the model.

    Parameters
    ----------
    model : torch.nn.Module
        Model to be evaluated.
    data_loader : torch.utils.data.DataLoader
        Loader yielding (data, label) batches.
    device : torch.device, optional
        Device on which to run the code.
    use_residuals : boolean, optional
        Use residuals to compute the TRIM score.
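
Because __call__ unwraps nn.DataParallel before reading the transform's filters, the same validator can score a multi-GPU transform directly. A sketch, assuming the objects from the example above and more than one CUDA device:

import torch

# wrap the transform for multi-GPU evaluation; Validator detects the wrapper
# via its is_parallel check and reads filters from w_transform.module
if torch.cuda.device_count() > 1:
    losses = Validator(model, loader)(torch.nn.DataParallel(w_transform), target=1)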