Module awave.trim.attributions
Expand source code
import numpy as np
import torch
import acd
from copy import deepcopy
import sys
from awave.trim.util import *
from numpy.fft import *
from torch import nn
from captum.attr import *
from awave.trim.trim import *
sys.path.append('../..')
def get_attributions(x_t: torch.Tensor,
mt,
class_num=1,
attr_methods = ['IG', 'DeepLift', 'SHAP', 'CD', 'InputXGradient'],
device='cuda'):
'''Returns all scores in a dict assuming mt works with both grads + CD
Params
------
mt: model
class_num: target class
'''
x_t = x_t.to(device)
x_t.requires_grad = True
mt = mt.to(device)
mt.eval()
results = {}
if 'CD' in attr_methods:
attr_funcs = [IntegratedGradients, DeepLift, GradientShap, None, InputXGradient]
else:
attr_funcs = [IntegratedGradients, DeepLift, GradientShap, InputXGradient]
for name, func in zip(attr_methods, attr_funcs):
if name == 'CD':
with torch.no_grad():
sweep_dim = 1
tiles = acd.tiling_2d.gen_tiles(x_t[0,0,...,0], fill=0, method='cd', sweep_dim=sweep_dim)
if x_t.shape[-1] == 2: # check for imaginary representations
tiles = np.repeat(np.expand_dims(tiles, axis=-1), repeats=2, axis=3).squeeze()
tiles = torch.Tensor(tiles).unsqueeze(1)
attributions = acd.get_scores_2d(mt, method='cd', ims=tiles, im_torch=x_t)[..., class_num].T.reshape(-1,28,28).squeeze()
# attributions = score_funcs.get_scores_2d(mt, method='cd', ims=tiles, im_torch=x_t)[..., class_num].T.reshape(-1,28,28)
else:
baseline = torch.zeros(x_t.shape).to(device)
attributer = func(mt)
if name in ['InputXGradient']:
attributions = attributer.attribute(deepcopy(x_t), target=class_num)
else:
attributions = attributer.attribute(deepcopy(x_t), deepcopy(baseline), target=class_num)
attributions = attributions.cpu().detach().numpy().squeeze()
if x_t.shape[-1] == 2: # check for imaginary representations
attributions = mag(attributions)
results[name] = attributions
return results
Functions
def get_attributions(x_t, mt, class_num=1, attr_methods=['IG', 'DeepLift', 'SHAP', 'CD', 'InputXGradient'], device='cuda')
-
Returns all scores in a dict assuming mt works with both grads + CD
Params
mt
:model
class_num
:target
class
Expand source code
def get_attributions(x_t: torch.Tensor, mt, class_num=1, attr_methods = ['IG', 'DeepLift', 'SHAP', 'CD', 'InputXGradient'], device='cuda'): '''Returns all scores in a dict assuming mt works with both grads + CD Params ------ mt: model class_num: target class ''' x_t = x_t.to(device) x_t.requires_grad = True mt = mt.to(device) mt.eval() results = {} if 'CD' in attr_methods: attr_funcs = [IntegratedGradients, DeepLift, GradientShap, None, InputXGradient] else: attr_funcs = [IntegratedGradients, DeepLift, GradientShap, InputXGradient] for name, func in zip(attr_methods, attr_funcs): if name == 'CD': with torch.no_grad(): sweep_dim = 1 tiles = acd.tiling_2d.gen_tiles(x_t[0,0,...,0], fill=0, method='cd', sweep_dim=sweep_dim) if x_t.shape[-1] == 2: # check for imaginary representations tiles = np.repeat(np.expand_dims(tiles, axis=-1), repeats=2, axis=3).squeeze() tiles = torch.Tensor(tiles).unsqueeze(1) attributions = acd.get_scores_2d(mt, method='cd', ims=tiles, im_torch=x_t)[..., class_num].T.reshape(-1,28,28).squeeze() # attributions = score_funcs.get_scores_2d(mt, method='cd', ims=tiles, im_torch=x_t)[..., class_num].T.reshape(-1,28,28) else: baseline = torch.zeros(x_t.shape).to(device) attributer = func(mt) if name in ['InputXGradient']: attributions = attributer.attribute(deepcopy(x_t), target=class_num) else: attributions = attributer.attribute(deepcopy(x_t), deepcopy(baseline), target=class_num) attributions = attributions.cpu().detach().numpy().squeeze() if x_t.shape[-1] == 2: # check for imaginary representations attributions = mag(attributions) results[name] = attributions return results