Module src.interpret

Expand source code
import sys
sys.path.append('../../hierarchical-dnn-interpretations') # if pip install doesn't work
import acd
from acd.scores import cd_propagate
import numpy as np
import seaborn as sns
import matplotlib.colors
import matplotlib.pyplot as plt
import torch
import viz


def calc_cd_score(xtrack_t, xfeats_t, start, stop, model):
    """Contextual-decomposition (CD) score for the ``start``/``stop`` segment.

    Runs ``cd_propagate.propagate_lstm`` over ``model.lstm`` and then
    ``cd_propagate.propagate_conv_linear`` over ``model.fc``. Returns the
    "relevant" half of the resulting ``(rel, irrel)`` decomposition.

    Params
    ------
    xtrack_t
        Track tensor. A trailing dimension of size 1 is appended before the
        LSTM (presumably (batch, seq_len) -> (batch, seq_len, 1); TODO confirm).
    xfeats_t
        Unused; kept so existing call sites keep working.
    start, stop
        Segment bounds, passed unchanged to ``propagate_lstm``.
    model
        Module exposing ``lstm`` and ``fc`` sub-modules.

    Returns
    -------
    numpy.ndarray holding the relevant contribution at the output of ``model.fc``.
    """
    # Keep the linear head under no_grad as well: otherwise model.fc's
    # parameters make the result track gradients and build a useless graph.
    with torch.no_grad():
        rel, irrel = cd_propagate.propagate_lstm(
            xtrack_t.unsqueeze(-1), model.lstm,
            start=start, stop=stop, my_device='cpu')
        rel = rel.squeeze(1)
        irrel = irrel.squeeze(1)
        rel, irrel = cd_propagate.propagate_conv_linear(rel, irrel, model.fc)
    # detach()/cpu() instead of the discouraged .data accessor
    return rel.detach().cpu().numpy()

def plot_segs(track_segs, cd_scores, xtrack,
              pred=None, y=None, vabs=None, cbar=True, xticks=True, yticks=True):
    '''Plot a single segmentation plot.

    Every segment ``(s, e)`` in ``track_segs`` covers indices ``s .. e-1`` of
    ``xtrack``. Its points are scattered and coloured by the matching entry of
    ``cd_scores`` using ``viz.cmap`` on the symmetric range ``[-vabs, vabs]``.
    Segments with more than one point are also joined by a semi-transparent
    line drawn in that same colour.

    Params
    ------
    track_segs
        Sequence of ``(start, end)`` pairs, ``end`` exclusive.
    cd_scores
        One score per segment, indexed in step with ``track_segs``.
    xtrack
        Values to plot, sliced by index.
    pred, y
        When ``pred`` is given, a "Pred: ..., y: ..." title is added.
    vabs
        Half-width of the colour range; defaults to ``max(|cd_scores|)``.
    cbar, xticks, yticks
        Whether to draw the colour bar and the x / y tick labels.

    Returns
    -------
    The colour bar object, or None when ``cbar`` is False.
    '''
    if vabs is None:
        vabs = np.max(np.abs(cd_scores))
    norm = matplotlib.colors.Normalize(vmin=-vabs, vmax=vabs)

    for idx, (start, end) in enumerate(track_segs):
        score = cd_scores[idx]
        length = end - start
        positions = np.arange(start, end)
        values = xtrack[start: end]
        if length > 1:
            rgba = viz.cmap(norm(score))
            # peel off singleton wrappers (presumably when the score is a
            # size-1 array) until a plain RGBA colour remains
            while len(rgba) == 1:
                rgba = rgba[0]
            plt.plot(positions, values, zorder=0, lw=2, color=rgba, alpha=0.5)
            # one colour value per scattered point
            score = [score] * length
        plt.scatter(positions, values,
                    c=score, cmap=viz.cmap, vmin=-vabs, vmax=vabs, s=6)

    if pred is not None:
        plt.title(f"Pred: {pred: .1f}, y: {y}", fontsize=24)

    colorbar = None
    if cbar:
        colorbar = plt.colorbar()
        colorbar.outline.set_visible(False)

    for show, set_ticks in ((xticks, plt.xticks), (yticks, plt.yticks)):
        if not show:
            set_ticks([])
    return colorbar
    
    
    
def max_abs_sum_seg(scores_list, min_length: int=1):
    """Split a sequence into segments maximising the summed |segment score|.

    Dynamic programme over prefixes. ``res[i]`` is the best total of
    ``abs(scores_list[s][e])`` over all ways of cutting positions ``0..i``
    into contiguous segments that are each at least ``min_length`` long.

    Params
    ------
    scores_list
        Square table (list of lists or 2-D array). ``scores_list[i][j]`` is
        the score for the segment from i to j (inclusive). Only entries with
        ``i <= j`` are read, and the table is NOT modified.
    min_length
        Minimum allowable length for a segment

    Returns
    -------
    res
        ``res[i]`` is the best total for the prefix ``0..i``. It is ``-inf``
        when that prefix cannot be split into segments of at least
        ``min_length``.
    paths
        ``paths[i]`` lists the start index of each segment in the best split
        of prefix ``0..i``. ``paths[-1]`` is ``[]`` (the empty prefix).
    """
    n = len(scores_list[0]) if len(scores_list) else 0
    paths = {-1: []}
    if n == 0:
        return [], paths

    # Work on a private copy. Writing abs values / penalties back into the
    # caller's table corrupted it for any later call (abs(penalty) > 0).
    # -inf forbids short segments exactly, whatever the score magnitudes.
    scores = [
        [abs(scores_list[s][e]) if e - s >= min_length - 1 else -np.inf
         for e in range(n)]
        for s in range(n)
    ]

    res = [0] * n
    res[0] = scores[0][0]
    paths[0] = [0]
    for i in range(1, n):
        # cand[j]: best split of 0..j-1 (empty when j == 0) + final segment j..i
        cand = [(res[j - 1] if j > 0 else 0) + scores[j][i] for j in range(i + 1)]
        seg_start = int(np.argmax(cand))
        res[i] = cand[seg_start]
        paths[i] = paths[seg_start - 1] + [seg_start]
    return res, paths

Functions

def calc_cd_score(xtrack_t, xfeats_t, start, stop, model)
Expand source code
def calc_cd_score(xtrack_t, xfeats_t, start, stop, model):
    """Contextual-decomposition (CD) score for the ``start``/``stop`` segment.

    Runs ``cd_propagate.propagate_lstm`` over ``model.lstm`` and then
    ``cd_propagate.propagate_conv_linear`` over ``model.fc``. Returns the
    "relevant" half of the resulting ``(rel, irrel)`` decomposition.

    Params
    ------
    xtrack_t
        Track tensor. A trailing dimension of size 1 is appended before the
        LSTM (presumably (batch, seq_len) -> (batch, seq_len, 1); TODO confirm).
    xfeats_t
        Unused; kept so existing call sites keep working.
    start, stop
        Segment bounds, passed unchanged to ``propagate_lstm``.
    model
        Module exposing ``lstm`` and ``fc`` sub-modules.

    Returns
    -------
    numpy.ndarray holding the relevant contribution at the output of ``model.fc``.
    """
    # Keep the linear head under no_grad as well: otherwise model.fc's
    # parameters make the result track gradients and build a useless graph.
    with torch.no_grad():
        rel, irrel = cd_propagate.propagate_lstm(
            xtrack_t.unsqueeze(-1), model.lstm,
            start=start, stop=stop, my_device='cpu')
        rel = rel.squeeze(1)
        irrel = irrel.squeeze(1)
        rel, irrel = cd_propagate.propagate_conv_linear(rel, irrel, model.fc)
    # detach()/cpu() instead of the discouraged .data accessor
    return rel.detach().cpu().numpy()
def max_abs_sum_seg(scores_list, min_length=1)

scores_list[i][j] is the score for the segment from i to j (inclusive).

Params


min_length
Minimum allowable length for a segment
Expand source code
def max_abs_sum_seg(scores_list, min_length: int=1):
    """Split a sequence into segments maximising the summed |segment score|.

    Dynamic programme over prefixes. ``res[i]`` is the best total of
    ``abs(scores_list[s][e])`` over all ways of cutting positions ``0..i``
    into contiguous segments that are each at least ``min_length`` long.

    Params
    ------
    scores_list
        Square table (list of lists or 2-D array). ``scores_list[i][j]`` is
        the score for the segment from i to j (inclusive). Only entries with
        ``i <= j`` are read, and the table is NOT modified.
    min_length
        Minimum allowable length for a segment

    Returns
    -------
    res
        ``res[i]`` is the best total for the prefix ``0..i``. It is ``-inf``
        when that prefix cannot be split into segments of at least
        ``min_length``.
    paths
        ``paths[i]`` lists the start index of each segment in the best split
        of prefix ``0..i``. ``paths[-1]`` is ``[]`` (the empty prefix).
    """
    n = len(scores_list[0]) if len(scores_list) else 0
    paths = {-1: []}
    if n == 0:
        return [], paths

    # Work on a private copy. Writing abs values / penalties back into the
    # caller's table corrupted it for any later call (abs(penalty) > 0).
    # -inf forbids short segments exactly, whatever the score magnitudes.
    scores = [
        [abs(scores_list[s][e]) if e - s >= min_length - 1 else -np.inf
         for e in range(n)]
        for s in range(n)
    ]

    res = [0] * n
    res[0] = scores[0][0]
    paths[0] = [0]
    for i in range(1, n):
        # cand[j]: best split of 0..j-1 (empty when j == 0) + final segment j..i
        cand = [(res[j - 1] if j > 0 else 0) + scores[j][i] for j in range(i + 1)]
        seg_start = int(np.argmax(cand))
        res[i] = cand[seg_start]
        paths[i] = paths[seg_start - 1] + [seg_start]
    return res, paths
def plot_segs(track_segs, cd_scores, xtrack, pred=None, y=None, vabs=None, cbar=True, xticks=True, yticks=True)

Plot a single segmentation plot

Expand source code
def plot_segs(track_segs, cd_scores, xtrack,
              pred=None, y=None, vabs=None, cbar=True, xticks=True, yticks=True):
    '''Plot a single segmentation plot.

    Every segment ``(s, e)`` in ``track_segs`` covers indices ``s .. e-1`` of
    ``xtrack``. Its points are scattered and coloured by the matching entry of
    ``cd_scores`` using ``viz.cmap`` on the symmetric range ``[-vabs, vabs]``.
    Segments with more than one point are also joined by a semi-transparent
    line drawn in that same colour.

    Params
    ------
    track_segs
        Sequence of ``(start, end)`` pairs, ``end`` exclusive.
    cd_scores
        One score per segment, indexed in step with ``track_segs``.
    xtrack
        Values to plot, sliced by index.
    pred, y
        When ``pred`` is given, a "Pred: ..., y: ..." title is added.
    vabs
        Half-width of the colour range; defaults to ``max(|cd_scores|)``.
    cbar, xticks, yticks
        Whether to draw the colour bar and the x / y tick labels.

    Returns
    -------
    The colour bar object, or None when ``cbar`` is False.
    '''
    if vabs is None:
        vabs = np.max(np.abs(cd_scores))
    norm = matplotlib.colors.Normalize(vmin=-vabs, vmax=vabs)

    for idx, (start, end) in enumerate(track_segs):
        score = cd_scores[idx]
        length = end - start
        positions = np.arange(start, end)
        values = xtrack[start: end]
        if length > 1:
            rgba = viz.cmap(norm(score))
            # peel off singleton wrappers (presumably when the score is a
            # size-1 array) until a plain RGBA colour remains
            while len(rgba) == 1:
                rgba = rgba[0]
            plt.plot(positions, values, zorder=0, lw=2, color=rgba, alpha=0.5)
            # one colour value per scattered point
            score = [score] * length
        plt.scatter(positions, values,
                    c=score, cmap=viz.cmap, vmin=-vabs, vmax=vabs, s=6)

    if pred is not None:
        plt.title(f"Pred: {pred: .1f}, y: {y}", fontsize=24)

    colorbar = None
    if cbar:
        colorbar = plt.colorbar()
        colorbar.outline.set_visible(False)

    for show, set_ticks in ((xticks, plt.xticks), (yticks, plt.yticks)):
        if not show:
            set_ticks([])
    return colorbar