Module awave.losses
Expand source code
import numpy as np
import torch
import torch.nn.functional as F
from awave.utils.misc import low_to_high
def get_loss_f(**kwargs_parse):
    """Return the loss function given the argparse arguments.

    Parameters
    ----------
    kwargs_parse : dict
        Keyword arguments holding the loss hyperparameters. The keys
        ``lamlSum``, ``lamhSum``, ``lamL2norm``, ``lamCMF``, ``lamConv``,
        ``lamL1wave`` and ``lamL1attr`` are required. ``lamHighfreq`` is
        optional and defaults to 0.0. Any other key is ignored.

    Returns
    -------
    Loss
        Loss object configured with the given hyperparameters.

    Raises
    ------
    KeyError
        If one of the required keys is missing.
    """
    required = ("lamlSum", "lamhSum", "lamL2norm", "lamCMF", "lamConv",
                "lamL1wave", "lamL1attr")
    weights = {name: kwargs_parse[name] for name in required}
    # Optional, so callers that do not supply it keep Loss's own default
    # (0.0, i.e. the high-frequency penalty stays disabled).
    weights["lamHighfreq"] = kwargs_parse.get("lamHighfreq", 0.0)
    return Loss(**weights)
class Loss:
    """Weighted sum of a reconstruction loss and wavelet-filter penalties.

    Calling an instance evaluates every enabled term, stores each term on
    the instance (``rec_loss``, ``lsum_loss``, ``hsum_loss``,
    ``L2norm_loss``, ``CMF_loss``, ``conv_loss``, ``L1wave_loss``,
    ``L1attr_loss``, ``highfreq_loss``) and returns the weighted total.
    """

    def __init__(self, lamlSum=1., lamhSum=1., lamL2norm=1., lamCMF=1., lamConv=1., lamL1wave=1., lamL1attr=1.,
                 lamHighfreq=0.0):
        """
        Parameters
        ----------
        lamlSum : float
            Weight of the penalty on the sum of the lowpass filter
        lamhSum : float
            Weight of the penalty on the sum of the highpass filter
        lamL2norm : float
            Weight of the unit-norm constraint on the lowpass filter
        lamCMF : float
            Weight of the conjugate mirror filter constraint
        lamConv : float
            Weight of the convolution constraint
        lamL1wave : float
            Weight of the L1 norm of the wavelet coefficients
        lamL1attr : float
            Weight of the L1 norm of the attributions
        lamHighfreq : float
            Weight of the penalty on the high frequencies of the lowpass
            filter (disabled by default)
        """
        # penalties on the filter sums
        self.lamlSum, self.lamhSum = lamlSum, lamhSum
        # wavelet-filter constraints
        self.lamL2norm, self.lamCMF, self.lamConv = lamL2norm, lamCMF, lamConv
        # sparsity penalties
        self.lamL1wave, self.lamL1attr = lamL1wave, lamL1attr
        # high-frequency penalty on h0
        self.lamHighfreq = lamHighfreq

    def __call__(self, w_transform, data, recon_data, data_t, attributions=None):
        """
        Evaluate the total loss and cache the individual terms.

        Parameters
        ----------
        w_transform : wavelet object
            Object exposing the lowpass filter as ``h0``.
        data : torch.Tensor
            Input data (e.g. batch of images). Shape : (batch_size, n_chan,
            height, width).
        recon_data : torch.Tensor
            Reconstructed data. Shape : (batch_size, n_chan, height, width).
        data_t : list of torch.Tensor
            Input data after wavelet transform.
        attributions : sequence of torch.Tensor, optional
            Input attribution scores. When None, the attribution term is
            skipped.

        Return
        ------
        loss : torch.Tensor
        """
        # A term whose weight is not positive is never evaluated and is
        # recorded as the plain number 0.
        self.rec_loss = _reconstruction_loss(data, recon_data)
        self.lsum_loss = _lsum_loss(w_transform) if self.lamlSum > 0 else 0
        self.hsum_loss = _hsum_loss(w_transform) if self.lamhSum > 0 else 0
        self.L2norm_loss = _L2norm_loss(w_transform) if self.lamL2norm > 0 else 0
        self.CMF_loss = _CMF_loss(w_transform) if self.lamCMF > 0 else 0
        self.conv_loss = _conv_loss(w_transform) if self.lamConv > 0 else 0
        self.L1wave_loss = _L1_wave_loss(data_t) if self.lamL1wave > 0 else 0
        if self.lamL1attr > 0 and attributions is not None:
            self.L1attr_loss = _L1_attribution_loss(attributions)
        else:
            self.L1attr_loss = 0
        self.highfreq_loss = _penalty_high_freq(w_transform) if self.lamHighfreq > 0 else 0

        weighted_terms = (
            (self.lamlSum, self.lsum_loss),
            (self.lamhSum, self.hsum_loss),
            (self.lamL2norm, self.L2norm_loss),
            (self.lamCMF, self.CMF_loss),
            (self.lamConv, self.conv_loss),
            (self.lamL1wave, self.L1wave_loss),
            (self.lamL1attr, self.L1attr_loss),
            (self.lamHighfreq, self.highfreq_loss),
        )
        # Out-of-place accumulation so the cached rec_loss is not modified.
        total = self.rec_loss
        for weight, term in weighted_terms:
            total = total + weight * term
        return total
def _reconstruction_loss(data, recon_data):
    """
    Squared reconstruction error summed over all elements and divided by
    the batch size.

    Parameters
    ----------
    data : torch.Tensor
        Input data (e.g. batch of images). Shape : (batch_size, n_chan,
        height, width).
    recon_data : torch.Tensor
        Reconstructed data. Shape : (batch_size, n_chan, height, width).

    Returns
    -------
    loss : torch.Tensor
        Sum of squared errors normalized per batch, but not per pixel or
        channel.
    """
    n_samples = recon_data.size(0)
    squared_error = F.mse_loss(recon_data, data, reduction="sum")
    return squared_error / n_samples
def _lsum_loss(w_transform):
    """
    Penalty on the sum of the lowpass filter.

    Returns half the squared difference between the sum of
    ``w_transform.h0`` and sqrt(2).
    """
    deviation = w_transform.h0.sum() - np.sqrt(2)
    return .5 * deviation ** 2
def _hsum_loss(w_transform):
    """
    Penalty on the sum of the highpass filter.

    The highpass filter is obtained by passing ``w_transform.h0`` to
    ``low_to_high``; the result is half the square of its sum.
    """
    highpass = low_to_high(w_transform.h0)
    return .5 * highpass.sum() ** 2
def _L2norm_loss(w_transform):
    """
    Penalty on the L2 norm of the lowpass filter.

    Returns half the squared difference between the sum of squares of
    ``w_transform.h0`` and 1.
    """
    squared_norm = (w_transform.h0 ** 2).sum()
    return .5 * (squared_norm - 1) ** 2
def _CMF_loss(w_transform):
    """
    Calculate conjugate mirror filter condition

    With H the n-point DFT of the lowpass filter, the loss is
    0.5 * sum_k (|H[k]|^2 + |H[k + n/2]|^2 - 2)^2 for k = 0..n/2-1.

    Parameters
    ----------
    w_transform : wavelet object
        Object exposing the lowpass filter as ``h0``. The filter is indexed
        as ``[0, 0, :]`` and its length is read from dimension 2.

    Returns
    -------
    loss : torch.Tensor
        Real-valued scalar.

    Raises
    ------
    AssertionError
        If the filter length is odd.
    """
    h0 = w_transform.h0
    n = h0.size(2)
    assert n % 2 == 0, "length of lowpass filter should be even"
    if callable(torch.fft):
        # Legacy PyTorch, where ``torch.fft`` is still a function: it takes
        # a real tensor whose last dimension holds (real, imag) parts.
        h_f = torch.fft(torch.stack((h0, torch.zeros_like(h0)), dim=3), 1)
        mod = (h_f ** 2).sum(dim=3)
    else:
        # ``torch.fft`` is a module here. The second positional argument of
        # ``torch.fft.fft`` is the signal length ``n``, not the number of
        # signal dimensions, so the transform axis is passed by keyword.
        h_f = torch.fft.fft(h0, dim=2)
        mod = h_f.real ** 2 + h_f.imag ** 2
    # Bins k and k + n/2 correspond to frequencies omega and omega + pi.
    cmf_identity = mod[0, 0, :n // 2] + mod[0, 0, n // 2:]
    loss = .5 * torch.sum((cmf_identity - 2) ** 2)
    return loss
def _conv_loss(w_transform):
    """
    Convolution constraint on the lowpass filter.

    ``w_transform.h0`` is passed to ``F.conv1d`` as both input and kernel
    (stride 2, padding equal to the filter length n). The loss is half the
    squared distance between that output and a tensor that is 1 at index
    n // 2 and 0 elsewhere.

    Raises
    ------
    AssertionError
        If the filter length is odd.
    """
    lowpass = w_transform.h0
    n = lowpass.size(2)
    assert n % 2 == 0, "length of lowpass filter should be even"
    response = F.conv1d(lowpass, lowpass, stride=2, padding=n)
    target = torch.zeros_like(response)
    target[0, 0, n // 2] = 1
    return .5 * torch.sum((response - target) ** 2)
def _L1_wave_loss(coeffs):
    """
    L1 norm of the wavelet coefficients, normalized by batch size.

    The batch size is read from dimension 0 of the first entry of
    ``coeffs``; the L1 value itself comes from ``tuple_L1Loss``.
    """
    n_samples = coeffs[0].size(0)
    return tuple_L1Loss(coeffs) / n_samples
def _L1_attribution_loss(attributions):
    """
    L1 norm of the attributions, normalized by batch size.

    The batch size is read from dimension 0 of the first entry of
    ``attributions``; the L1 value itself comes from ``tuple_L1Loss``.
    """
    n_samples = attributions[0].size(0)
    return tuple_L1Loss(attributions) / n_samples
def _penalty_high_freq(w_transform):
    """
    Penalize the high-frequency content of the lowpass filter h0.

    With H the n-point DFT of ``w_transform.h0``, the loss is half the sum
    of |H[k]|^4 over the bins k = floor(n/4) + 1 .. ceil(3n/4) - 1, i.e.
    the band centered on bin n/2.

    Parameters
    ----------
    w_transform : wavelet object
        Object exposing the lowpass filter as ``h0``. The filter is indexed
        as ``[0, 0, :]`` and its length is read from dimension 2.

    Returns
    -------
    loss : torch.Tensor
        Real-valued scalar.
    """
    h0 = w_transform.h0
    n = h0.size(2)
    if callable(torch.fft):
        # Legacy PyTorch, where ``torch.fft`` is still a function: it takes
        # a real tensor whose last dimension holds (real, imag) parts.
        h_f = torch.fft(torch.stack((h0, torch.zeros_like(h0)), dim=3), 1)
        mod = (h_f ** 2).sum(dim=3)
    else:
        # ``torch.fft`` is a module here and can no longer be called
        # directly; transform along the filter axis instead.
        h_f = torch.fft.fft(h0, dim=2)
        mod = h_f.real ** 2 + h_f.imag ** 2
    left = int(np.floor(n / 4) + 1)
    right = int(np.ceil(3 * n / 4) - 1)
    h0_hf = mod[0, 0, left:right + 1]
    # Squared 2-norm written as a plain sum of squares, which avoids taking
    # a square root only to square it again.
    loss = 0.5 * torch.sum(h0_hf ** 2)
    return loss
def tuple_L1Loss(x):
    """
    Sum of absolute values over every tensor in ``x``, divided by the
    number of tensors.

    Parameters
    ----------
    x : sequence of torch.Tensor

    Returns
    -------
    torch.Tensor
        Scalar tensor.
    """
    count = len(x)
    total = sum(torch.sum(abs(entry)) for entry in x)
    return total / count
def tuple_L2Loss(x):
    """
    Sum of squares over every tensor in ``x``, divided by the number of
    tensors.

    Parameters
    ----------
    x : sequence of torch.Tensor

    Returns
    -------
    torch.Tensor
        Scalar tensor.
    """
    count = len(x)
    total = sum(torch.sum(entry ** 2) for entry in x)
    return total / count
Functions
def get_loss_f(**kwargs_parse)
-
Return the loss function given the argparse arguments.
Expand source code
def get_loss_f(**kwargs_parse):
    """Return the loss function given the argparse arguments.

    Parameters
    ----------
    kwargs_parse : dict
        Keyword arguments holding the loss hyperparameters. The keys
        ``lamlSum``, ``lamhSum``, ``lamL2norm``, ``lamCMF``, ``lamConv``,
        ``lamL1wave`` and ``lamL1attr`` are required. ``lamHighfreq`` is
        optional and defaults to 0.0. Any other key is ignored.

    Returns
    -------
    Loss
        Loss object configured with the given hyperparameters.

    Raises
    ------
    KeyError
        If one of the required keys is missing.
    """
    required = ("lamlSum", "lamhSum", "lamL2norm", "lamCMF", "lamConv",
                "lamL1wave", "lamL1attr")
    weights = {name: kwargs_parse[name] for name in required}
    # Optional, so callers that do not supply it keep Loss's own default
    # (0.0, i.e. the high-frequency penalty stays disabled).
    weights["lamHighfreq"] = kwargs_parse.get("lamHighfreq", 0.0)
    return Loss(**weights)
def tuple_L1Loss(x)
-
Expand source code
def tuple_L1Loss(x):
    """
    Sum of absolute values over every tensor in ``x``, divided by the
    number of tensors.

    Parameters
    ----------
    x : sequence of torch.Tensor

    Returns
    -------
    torch.Tensor
        Scalar tensor.
    """
    count = len(x)
    total = sum(torch.sum(abs(entry)) for entry in x)
    return total / count
def tuple_L2Loss(x)
-
Expand source code
def tuple_L2Loss(x):
    """
    Sum of squares over every tensor in ``x``, divided by the number of
    tensors.

    Parameters
    ----------
    x : sequence of torch.Tensor

    Returns
    -------
    torch.Tensor
        Scalar tensor.
    """
    count = len(x)
    total = sum(torch.sum(entry ** 2) for entry in x)
    return total / count
Classes
class Loss (lamlSum=1.0, lamhSum=1.0, lamL2norm=1.0, lamCMF=1.0, lamConv=1.0, lamL1wave=1.0, lamL1attr=1.0, lamHighfreq=0.0)
-
Class for calculating loss functions
Parameters
lamlSum
:float
- Hyperparameter for penalizing sum of lowpass filter
lamhSum
:float
- Hyperparameter for penalizing sum of highpass filter
lamL2norm
:float
- Hyperparameter to enforce unit norm of lowpass filter
lamCMF
:float
- Hyperparameter to enforce conjugate mirror filter
lamConv
:float
- Hyperparameter to enforce convolution constraint
lamL1wave
:float
- Hyperparameter for penalizing L1 norm of wavelet coeffs
lamL1attr
:float
- Hyperparameter for penalizing L1 norm of attributions
lamHighfreq
:float
- Hyperparameter for penalizing high-frequency content of lowpass filter
Expand source code
class Loss:
    """Weighted sum of a reconstruction loss and wavelet-filter penalties.

    Calling an instance evaluates every enabled term, stores each term on
    the instance (``rec_loss``, ``lsum_loss``, ``hsum_loss``,
    ``L2norm_loss``, ``CMF_loss``, ``conv_loss``, ``L1wave_loss``,
    ``L1attr_loss``, ``highfreq_loss``) and returns the weighted total.
    """

    def __init__(self, lamlSum=1., lamhSum=1., lamL2norm=1., lamCMF=1., lamConv=1., lamL1wave=1., lamL1attr=1.,
                 lamHighfreq=0.0):
        """
        Parameters
        ----------
        lamlSum : float
            Weight of the penalty on the sum of the lowpass filter
        lamhSum : float
            Weight of the penalty on the sum of the highpass filter
        lamL2norm : float
            Weight of the unit-norm constraint on the lowpass filter
        lamCMF : float
            Weight of the conjugate mirror filter constraint
        lamConv : float
            Weight of the convolution constraint
        lamL1wave : float
            Weight of the L1 norm of the wavelet coefficients
        lamL1attr : float
            Weight of the L1 norm of the attributions
        lamHighfreq : float
            Weight of the penalty on the high frequencies of the lowpass
            filter (disabled by default)
        """
        # penalties on the filter sums
        self.lamlSum, self.lamhSum = lamlSum, lamhSum
        # wavelet-filter constraints
        self.lamL2norm, self.lamCMF, self.lamConv = lamL2norm, lamCMF, lamConv
        # sparsity penalties
        self.lamL1wave, self.lamL1attr = lamL1wave, lamL1attr
        # high-frequency penalty on h0
        self.lamHighfreq = lamHighfreq

    def __call__(self, w_transform, data, recon_data, data_t, attributions=None):
        """
        Evaluate the total loss and cache the individual terms.

        Parameters
        ----------
        w_transform : wavelet object
            Object exposing the lowpass filter as ``h0``.
        data : torch.Tensor
            Input data (e.g. batch of images). Shape : (batch_size, n_chan,
            height, width).
        recon_data : torch.Tensor
            Reconstructed data. Shape : (batch_size, n_chan, height, width).
        data_t : list of torch.Tensor
            Input data after wavelet transform.
        attributions : sequence of torch.Tensor, optional
            Input attribution scores. When None, the attribution term is
            skipped.

        Return
        ------
        loss : torch.Tensor
        """
        # A term whose weight is not positive is never evaluated and is
        # recorded as the plain number 0.
        self.rec_loss = _reconstruction_loss(data, recon_data)
        self.lsum_loss = _lsum_loss(w_transform) if self.lamlSum > 0 else 0
        self.hsum_loss = _hsum_loss(w_transform) if self.lamhSum > 0 else 0
        self.L2norm_loss = _L2norm_loss(w_transform) if self.lamL2norm > 0 else 0
        self.CMF_loss = _CMF_loss(w_transform) if self.lamCMF > 0 else 0
        self.conv_loss = _conv_loss(w_transform) if self.lamConv > 0 else 0
        self.L1wave_loss = _L1_wave_loss(data_t) if self.lamL1wave > 0 else 0
        if self.lamL1attr > 0 and attributions is not None:
            self.L1attr_loss = _L1_attribution_loss(attributions)
        else:
            self.L1attr_loss = 0
        self.highfreq_loss = _penalty_high_freq(w_transform) if self.lamHighfreq > 0 else 0

        weighted_terms = (
            (self.lamlSum, self.lsum_loss),
            (self.lamhSum, self.hsum_loss),
            (self.lamL2norm, self.L2norm_loss),
            (self.lamCMF, self.CMF_loss),
            (self.lamConv, self.conv_loss),
            (self.lamL1wave, self.L1wave_loss),
            (self.lamL1attr, self.L1attr_loss),
            (self.lamHighfreq, self.highfreq_loss),
        )
        # Out-of-place accumulation so the cached rec_loss is not modified.
        total = self.rec_loss
        for weight, term in weighted_terms:
            total = total + weight * term
        return total