Module src.viz
Expand source code
import pickle as pkl
import sys
sys.path.append('..')
import config
import data
from os.path import join as oj
import matplotlib.gridspec as grd
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib_venn import venn2
from sklearn import decomposition
from sklearn import metrics
from sklearn.covariance import EllipticEnvelope
from sklearn.ensemble import IsolationForest
from sklearn.neighbors import LocalOutlierFactor
from sklearn.svm import OneClassSVM
from sklearn.utils.multiclass import unique_labels
import os
import matplotlib.ticker as mtick
from config import DIR_FIGS
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import cm
from matplotlib.colors import ListedColormap
import dvu
# DIR_FILE = os.path.dirname(os.path.realpath(__file__)) # directory of this file
# DIR_FIGS = oj(DIR_FILE, '../reports/figs')
# Prefer the tifffile reader bundled under skimage.external; that module is not
# present in every scikit-image release, so fall back to the generic reader.
# Only ImportError is caught so that unrelated failures are not silently hidden.
try:
    from skimage.external.tifffile import imread
except ImportError:
    from skimage.io import imread
# Hex color constants shared by the plotting helpers in this module.
cb2 = '#66ccff'
cb = '#1f77b4'
co = '#ff7f0e'
cr = '#cc0000'  # used further down for the clathrin (X) traces
cp = '#cc3399'
cy = '#d8b365'
cg = '#5ab4ac'  # used further down for the auxilin (Y) traces
# Three-stop colormap named 'orange-blue' that passes through light gray.
cmap = LinearSegmentedColormap.from_list(
name='orange-blue',
colors=[(205/255, 85/255, 51/255),
'lightgray',
(50/255, 129/255, 168/255)]
)
def savefig(s: str, png=False):
    """Save the current matplotlib figure into DIR_FIGS as 'fig_<s>.pdf'.

    Params
    ------
    s: str
        suffix appended to 'fig_' to form the file name
    png: bool
        if True, also write 'fig_<s>.png' at 300 dpi
    """
    stem = oj(DIR_FIGS, 'fig_' + s)
    plt.savefig(stem + '.pdf', bbox_inches='tight')
    if not png:
        return
    plt.savefig(stem + '.png', dpi=300, bbox_inches='tight')
def fix_feat_name(s):
    """Turn a raw feature name into a display label.

    Underscores become spaces, every 'X' becomes 'Clath', and the result is
    capitalized (first character upper-case, the rest lower-case).
    """
    spaced = s.replace('_', ' ')
    renamed = spaced.replace('X', 'Clath')
    return renamed.capitalize()
def plot_confusion_matrix(y_true, y_pred, classes,
                          normalize=False,
                          title=None,
                          cmap=plt.cm.Blues):
    """
    Plot the confusion matrix of `y_pred` against `y_true` on a new figure.
    Normalization can be applied by setting `normalize=True`.

    Params
    ------
    y_true, y_pred: np.ndarray
        true / predicted labels; they are cast to int to index into `classes`
    classes: np.ndarray(Str)
        class names indexed by label, e.g. classes=np.array(['aux-', 'aux+'])
    normalize: bool
        if True, every row (true label) is divided by its row sum
    title: str
        accepted for backward compatibility; no title is drawn on the figure
    cmap
        colormap handed to imshow

    Returns
    -------
    ax
        the matplotlib axes that hold the matrix
    """
    plt.figure(dpi=300)

    # Compute confusion matrix (named conf_mat so that the module-level
    # matplotlib `cm` import is not shadowed inside this function)
    conf_mat = metrics.confusion_matrix(y_true, y_pred)

    # Only use the labels that appear in the data.
    # The np.int alias no longer exists in current NumPy; builtin int is equivalent.
    classes = classes[unique_labels(y_true.astype(int), y_pred.astype(int))]
    if normalize:
        conf_mat = conf_mat.astype('float') / conf_mat.sum(axis=1)[:, np.newaxis]

    plt.imshow(conf_mat, interpolation='nearest', cmap=cmap)
    ax = plt.gca()

    # Show all ticks and label them with the respective class names
    ax.set(xticks=np.arange(conf_mat.shape[1]),
           yticks=np.arange(conf_mat.shape[0]),
           xticklabels=classes, yticklabels=classes,
           ylabel='True label',
           xlabel='Predicted label')

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Annotate every cell; white text on dark cells, black text on light cells.
    fmt = '.2f' if normalize else 'd'
    thresh = conf_mat.max() / 2.
    for i in range(conf_mat.shape[0]):
        for j in range(conf_mat.shape[1]):
            ax.text(j, i, format(conf_mat[i, j], fmt),
                    ha="center", va="center",
                    color="white" if conf_mat[i, j] > thresh else "black")
    return ax
def highlight_max(data, color='#0e5c99'):
    '''
    Build CSS styles that mark the maximum of a Series or DataFrame.

    Returns a list of style strings for 1-d input, and a DataFrame of style
    strings (same index and columns as `data`) otherwise. Cells that are not
    the maximum get an empty string.
    '''
    css = f'background-color: {color}'
    if data.ndim != 1:  # whole DataFrame, from .apply(axis=None)
        mask = data == data.max().max()
        return pd.DataFrame(np.where(mask, css, ''),
                            index=data.index, columns=data.columns)
    # Series, from .apply(axis=0) or axis=1
    return [css if flag else '' for flag in data == data.max()]
# visualize biggest errs
def viz_biggest_errs(df, idxs_cv, idxs, Y_test, preds, preds_proba,
num_to_plot=20,
aux_thresh=642,
show_track_num=True,
show_track_pid=False,
sort_by_residuals=True,
width_mult=3,
plot_x=True,
plot_y=True,
plot_z=False,
plot_axhline=True,
xlim_constant=True,
ylim: tuple=None,
yticks=None,
yticklabels=None,
lifetime_max=None,
text_labels=False,
text_label_size=25):
'''Visualize X and Y where the top examples are the most wrong / least confident
Draws a grid of subplots, one track (row of df) per subplot, and returns the
(possibly re-sorted) dataframe that was plotted.
Params
------
idxs_cv: integer ndarray
which idxs are not part of the test set (usually just 0, 1, 2, ...)
idxs: boolean ndarray
subset of points to plot
num_to_plot: int
number of subplots; None means one per row of the selected dataframe
aux_thresh: float
y-value of the horizontal line drawn when plot_axhline is True
show_track_num, show_track_pid: bool
annotate each subplot with its running index, or else with row.pid
sort_by_residuals: bool
if True, rows are ordered by |Y_test - preds_proba|, largest first
width_mult: float
figure width per grid column
plot_x, plot_y, plot_z: bool
whether to draw row["X"], row["Y"], row["Z"]
xlim_constant: bool
if True, every subplot uses x-limits [-1, lifetime_max]
ylim: tuple
optional shared y-limits
yticks, yticklabels
optional y ticks / labels passed to plt.yticks
lifetime_max: float
defaults to the largest value of the lifetime column
text_labels: bool
if True, add 'Clathrin' / 'Auxilin' (and 'Dynamin' when plot_z) text labels
text_label_size: int
font size of those text labels
Returns
-------
dft: pd.DataFrame
the dataframe whose rows were plotted
'''
DIFF = 0 # use this to ensure values are all positive
# deal with idxs
# NOTE(review): indentation was lost in this listing, so it is unclear which of
# the following statements sit inside `if idxs is not None` - confirm.
if idxs is not None:
Y_test = Y_test[idxs]
preds = preds[idxs]
preds_proba = preds_proba[idxs]
if idxs_cv is None:
idxs_cv = np.arange(df.shape[0])
df = df.iloc[idxs_cv][idxs]
# get args to sort by
if sort_by_residuals:
residuals = np.abs(Y_test - preds_proba)
# descending order: largest residual first
args = np.argsort(residuals)[::-1]
dft = df.iloc[args]
else:
dft = df
if lifetime_max is None:
lifetime_max = np.max(dft.lifetime.values)
if num_to_plot is None:
num_to_plot = dft.shape[0]
# grid shape: R rows by C columns; R * C can be smaller than num_to_plot
R = int(np.sqrt(num_to_plot))
C = num_to_plot // R # + 1
plt.figure(figsize=(C * width_mult, R * 2.5), dpi=200)
i = 0
for r in range(R):
for c in range(C):
if i < dft.shape[0]:
row = dft.iloc[i]
ax = plt.subplot(R, C, i + 1)
# show nums on tracks
if show_track_num:
ax.text(.5, .9, f'{i}', # row.pid
horizontalalignment='right',
transform=ax.transAxes)
elif show_track_pid:
ax.text(.5, .9, f'{row.pid}', # row.pid
horizontalalignment='right',
transform=ax.transAxes)
# plt.axis('off')
# time step between frames: 1.5 when cell_num contains '1.5s', otherwise 1
if '1.5s' in row['cell_num']:
interval = 1.5
else:
interval = 1
if plot_x:
plt.plot(interval * np.arange(len(row["X"])), np.array(row["X"]) + DIFF, color=cr, label='clath', lw=2) # could do X_extended
if plot_y:
plt.plot(interval * np.arange(len(row["Y"])), np.array(row["Y"]) + DIFF, color=cg, label='aux', lw=2)
if plot_z:
plt.plot(interval * np.arange(len(row["Z"])), np.array(row["Z"]) + DIFF, color='gray', label='dyn')
if xlim_constant:
plt.xlim([-1, lifetime_max])
if plot_axhline:
plt.axhline(aux_thresh, color='gray', alpha=0.5, lw=2)
#plt.yscale('log')
if ylim is not None:
plt.ylim((ylim[0] + DIFF, ylim[1] + DIFF))
# x ticks are kept only on the bottom grid row, y ticks only on the first column
if not r == R - 1:
plt.xticks([])
if not c == 0:
plt.yticks([])
elif yticks is not None:
plt.yticks(yticks, labels=yticklabels)
i += 1
# NOTE(review): `row` here is whatever row was assigned last in the loop above;
# presumably this runs once after the loops - confirm against the original indentation.
# The text x-position uses len(...) without the `interval` factor used for the curves.
if text_labels:
plt.text(len(row["X"]), row["X"][-1] + DIFF, 'Clathrin', color=cr,
fontsize=text_label_size, fontweight='bold')
plt.text(len(row["Y"]), row["Y"][-1] + DIFF, 'Auxilin', color=cg,
fontsize=text_label_size, fontweight='bold')
if plot_z:
plt.text(len(row["Z"]), row["Z"][-1] + DIFF, 'Dynamin',
fontsize=text_label_size, color='gray', fontweight='bold')
plt.tight_layout()
return dft
def viz_errs_2d(df, idxs_test, preds, Y_test, key1='x_pos', key2='y_pos', X=None, plot_correct=True):
    '''Scatter test points in the (key1, key2) plane, split by prediction outcome

    Params
    ------
    df
        dataframe holding the columns key1 and key2
    idxs_test
        positional indexes of the test points within df
    preds, Y_test
        predicted and true labels of the test points
    X
        unused
    plot_correct: bool
        if True, correctly classified points are drawn as well
    '''
    xs = df[key1].iloc[idxs_test]
    ys = df[key2].iloc[idxs_test]
    plt.figure(dpi=200)
    marker_size = 4

    # each entry: (mask, marker, color, legend label, marker edge width)
    groups = []
    if plot_correct:
        hit = preds == Y_test
        groups.append((hit & (preds == 1), 'o', cb, 'true pos', 0))
        groups.append((hit & (preds == 0), 'o', cr, 'true neg', 0))
    groups.append((preds > Y_test, 'x', cb, 'false pos', 1))
    groups.append((preds < Y_test, 'x', cr, 'false neg', 1))

    for mask, marker, color, label, edge_width in groups:
        plt.plot(xs[mask], ys[mask], marker, color=color,
                 alpha=0.4, label=label, ms=marker_size, markeredgewidth=edge_width)

    plt.legend()
    plt.xlabel(key1)
    plt.ylabel(key2)
    plt.tight_layout()
def viz_errs_1d(X_test, preds, preds_proba, Y_test, norms, key='lifetime'):
    '''Plot predicted probability against one (un-normalized) feature, split by outcome

    The feature X_test[key] is mapped back to its original scale with
    norms[key]['std'] and norms[key]['mu'] before plotting. The figure is shown
    with plt.show().
    '''
    plt.figure(dpi=200)
    stats = norms[key]
    feature = X_test[key] * stats['std'] + stats['mu']
    hit = preds == Y_test

    # each entry: (mask, marker, color, legend label)
    series = [
        (hit & (preds == 1), 'o', cb, 'true pos'),
        (hit & (preds == 0), 'x', cb, 'true neg'),
        (preds > Y_test, 'o', cr, 'false pos'),
        (preds < Y_test, 'x', cr, 'false neg'),
    ]
    for mask, marker, color, label in series:
        plt.plot(feature[mask], preds_proba[mask], marker,
                 color=color, alpha=0.5, label=label)

    plt.xlabel(key)
    plt.ylabel('predicted probability')
    plt.legend()
    plt.show()
def plot_above_threshold(x1, y1, b1, x2, y2, b2, ax, color, lsty):
    """Draw the segment (x1, y1)-(x2, y2), opaque where it lies above a baseline.

    The baseline runs from (x1, b1) to (x2, b2). Parts of the segment at or
    above the baseline are drawn with alpha=1, parts below with alpha=.1. When
    the segment crosses the baseline it is split at the crossing point.

    Params
    ------
    ax
        object with a matplotlib-style plot(xs, ys, linestyle=, color=, alpha=)
    color, lsty
        color and linestyle forwarded to ax.plot
    """
    slope_y = (y2 - y1)/(x2 - x1)
    slope_b = (b2 - b1)/(x2 - x1)

    start_hi, start_lo = y1 >= b1, y1 < b1
    end_hi, end_lo = y2 >= b2, y2 < b2

    def draw(xs, ys, alpha):
        ax.plot(xs, ys, linestyle=lsty, color=color, alpha=alpha)

    # no crossing: one segment, fully opaque or fully faded
    if start_hi and end_hi:
        draw([x1, x2], [y1, y2], 1)
        return
    if start_lo and end_lo:
        draw([x1, x2], [y1, y2], .1)
        return

    # crossing: pick the alpha of the first and second half
    if start_hi and end_lo:
        first_alpha, second_alpha = 1, .1
    elif start_lo and end_hi:
        first_alpha, second_alpha = .1, 1
    else:
        return

    gap = y1 - b1
    denom = slope_b - slope_y
    crosspoint_x = x1 + gap/denom
    crosspoint_y = y1 + slope_y * gap/denom
    draw([x1, crosspoint_x], [y1, crosspoint_y], first_alpha)
    draw([crosspoint_x, x2], [crosspoint_y, y2], second_alpha)
def plot_background(interval, bg, trace, color, ax):
    """Shade the band [0, 2 * bg] and draw `trace` segment by segment on `ax`.

    Each pair of consecutive trace points is drawn with plot_above_threshold
    against the baseline 2 * bg. The first five and last five segments are
    dashed, the rest solid. Time on the x-axis is frame index times `interval`.
    """
    n_bg = len(bg)
    ax.fill_between(interval * np.arange(n_bg),
                    [0] * n_bg,
                    2 * np.array(bg),
                    alpha=.1,
                    color=color)

    n_frames = len(trace)
    frames = np.arange(n_frames)
    values = np.array(trace)
    baseline = 2 * np.array(bg)

    for f in range(n_frames - 1):
        near_edge = f < 5 or f >= n_frames - 5
        plot_above_threshold(x1=interval*frames[f],
                             y1=values[f],
                             b1=baseline[f],
                             x2=interval*frames[f+1],
                             y2=values[f+1],
                             b2=baseline[f+1],
                             ax=ax,
                             color=color,
                             lsty='--' if near_edge else '-')
def plot_curves(df, extra_key=None, extra_key_label=None,
hline=True, R=5, C=8,
xlim=None,
fig=None, ylim_constant=False, background=False, ylim_cla=None,
ylim_aux=None, ylim_dyn=None,
xlim_constant=True, legend=True, plot_x=True,
yticks=None, yticklabels=None, num_axes=3, show_track_pid=False,
axes_invisible=False):
'''Plot time-series curves from df
Draws the first R * C rows of df on an R x C grid of subplots, one track per
subplot. Reads the row fields X_extended / Y_extended (plus their *_std_extended
and, with background=True, *_c_extended or *_sigma_extended variants).
Params
------
df: pd.DataFrame
one track per row; needs the columns lifetime and cell_num
extra_key, extra_key_label
optional third curve and its legend label ('Dynamin' is used for key 'Z')
hline: bool
draw a horizontal line at 642.3754691658837
R, C: int
grid rows / columns
xlim
explicit x-limits used when xlim_constant is True
fig
if None, a new figure is created and plt.show() is called at the end
num_axes: int
1 draws everything on a single y-axis; otherwise twin y-axes are created
(a third, offset one when num_axes == 3)
ylim_cla, ylim_aux, ylim_dyn
y-limits of the three axes, applied when ylim_constant is True (twin-axis mode)
show_track_pid: bool
annotate each subplot with row.pid (twin-axis mode)
axes_invisible: bool
hide ticks and recolor twin spines white (twin-axis mode)
NOTE(review): indentation was lost in this listing, so the exact nesting of the
branches below cannot be confirmed from here.
'''
DIFF = 0
if fig is None:
plt.figure(figsize=(16, 10), dpi=200, facecolor='white')
lifetime_max = np.max(df.lifetime.values[:R * C])
# NOTE(review): positional selection of R * C rows; presumably fails when df has
# fewer rows than R * C - confirm.
df = df.iloc[range(R * C)]
for i in range(R * C):
if i < df.shape[0]:
ax = plt.subplot(R, C, i + 1)
row = df.iloc[i]
# time step between frames: 1.5 when cell_num contains '1.5s', otherwise 1
if '1.5s' in row['cell_num']:
interval = 1.5
else:
interval = 1
# ---- single-axis mode ----
if num_axes == 1:
if plot_x:
# full extended trace dashed, then the part without the first/last 5 points solid
plt.plot(interval * np.arange(len(row.X_extended)), np.array(row.X_extended) + DIFF, linestyle='--', color=cr)
plt.plot(interval * np.arange(len(row.Y_extended)), np.array(row.Y_extended) + DIFF, linestyle='--', color=cg)
plt.plot(interval * np.arange(5, len(row.X_extended)-5), np.array(row.X_extended)[5:(-5)] + DIFF, color=cr, label='Clathrin')
plt.plot(interval * np.arange(5, len(row.Y_extended)-5), np.array(row.Y_extended)[5:(-5)] + DIFF, color=cg, label='Auxilin')
#plt.plot(interval * np.arange(5), np.array(row.X_extended)[-5:] + DIFF, linestyle='--', color=cr, label='Clathrin')
#plt.plot(interval * np.arange(5), np.array(row.Y_extended)[-5:] + DIFF, linestyle='--', color=cg, label='Auxilin')
if background:
ax.plot(interval * np.arange(len(row.X_extended)), np.array(row.X_c_extended),
color=cr, linewidth=.8)
ax.fill_between(interval * np.arange(len(row.X_extended)),
np.array(row.X_extended) - np.array(row.X_std_extended),
np.array(row.X_extended) + np.array(row.X_std_extended),
alpha=.2,
color=cr
)
if hline:
plt.axhline(642.3754691658837, color='gray', alpha=0.5)
if extra_key is not None:
if extra_key_label is None:
if extra_key == 'Z':
extra_key_label = 'Dynamin'
else:
extra_key_label = extra_key
plt.plot(interval * np.arange(len(row[extra_key])), np.array(row[extra_key]) + DIFF, linestyle='--', color='gray')
plt.plot(interval * np.arange(5, len(row[extra_key])-5), np.array(row[extra_key])[5:(-5)] + DIFF, color='gray', label=extra_key_label)
if xlim_constant:
if xlim is None:
plt.xlim([-1, lifetime_max + 1])
else:
print(xlim)
plt.xlim(xlim)
if ylim_constant:
# NOTE(review): `ylim` is neither a parameter nor a local assigned in this
# function as listed; this looks like a NameError unless a global exists - confirm.
if ylim is None:
plt.ylim([-10, max(max(df.X_max), max(df.Y_max)) + 1])
else:
plt.ylim(ylim[0] + DIFF, ylim[1] + DIFF)
if yticks is not None:
plt.yticks(yticks, labels=yticklabels)
# ---- twin-axis mode ----
else:
ax.spines['right'].set_visible(True)
twin1 = ax.twinx()
if num_axes == 3:
# third y-axis, with its right spine moved outwards
twin2 = ax.twinx()
twin2.spines['right'].set_visible(True)
twin2.spines['right'].set_position(("axes", 1.2))
else:
twin2 = twin1
if show_track_pid:
ax.text(.5, .9, f'{row.pid}', # row.pid
horizontalalignment='right',
transform=ax.transAxes)
if plot_x:
p1, = ax.plot(interval * np.arange(len(row.X_extended)), np.array(row.X_extended) + DIFF, linestyle='--', color=cr, alpha=.1)
# channel name is written next to the curve on the first subplot only
if i == 0:
ax.text(x=interval * len(row.X_extended),
y=np.array(row.X_extended)[-1],
s='CLTA-TagRFP',
color=cr,
size=8)
if background:
plot_background(interval, row.X_sigma_extended, row.X_extended, color=cr, ax=ax)
else:
ax.plot(interval * np.arange(5, len(row.X_extended)-5), np.array(row.X_extended)[5:(-5)] + DIFF, color=cr)
if i == 0 and legend:
dvu.line_legend()
ax.fill_between(interval * np.arange(len(row.X_extended)),
np.array(row.X_extended) - np.array(row.X_std_extended),
np.array(row.X_extended) + np.array(row.X_std_extended),
alpha=.2,
color=cr
)
p2, = twin1.plot(interval * np.arange(len(row.Y_extended)), np.array(row.Y_extended) + DIFF, linestyle='--', color=cg, alpha=.1)
if background:
plot_background(interval, row.Y_sigma_extended, row.Y_extended, color=cg, ax=twin1
)
else:
twin1.plot(interval * np.arange(5, len(row.Y_extended)-5), np.array(row.Y_extended)[5:(-5)] + DIFF, color=cg, label='EGFP-Aux1-GAK-F6')
if i == 0 and legend:
dvu.line_legend()
twin1.fill_between(interval * np.arange(len(row.Y_extended)),
np.array(row.Y_extended) - np.array(row.Y_std_extended),
np.array(row.Y_extended) + np.array(row.Y_std_extended),
alpha=.2,
color=cg
)
if i == 0:
twin1.text(x=interval * len(row.Y_extended),
y=np.array(row.Y_extended)[-1],
s='EGFP-Aux1-GAK-F6',
color=cg,
size=8)
#plt.plot(interval * np.arange(5), np.array(row.X_extended)[-5:] + DIFF, linestyle='--', color=cr, label='Clathrin')
#plt.plot(interval * np.arange(5), np.array(row.Y_extended)[-5:] + DIFF, linestyle='--', color=cg, label='Auxilin')
if hline:
ax.axhline(642.3754691658837, color='gray', alpha=0.5)
if extra_key is not None:
if extra_key_label is None:
if extra_key == 'Z':
extra_key_label = 'Dynamin'
else:
extra_key_label = extra_key
# NOTE(review): this branch reads row.Z_extended whatever the value of extra_key is.
p3, = twin2.plot(interval * np.arange(len(row.Z_extended)), np.array(row.Z_extended) + DIFF, linestyle='--', color='gray', alpha=.1)
if background:
plot_background(interval, row.Z_sigma_extended, row.Z_extended, color='gray', ax=twin2)
else:
twin2.plot(interval * np.arange(5, len(row.Z_extended)-5), np.array(row.Z_extended)[5:(-5)] + DIFF, color='gray')
twin2.fill_between(interval * np.arange(len(row.Z_extended)),
np.array(row.Z_extended) - np.array(row.Z_std_extended),
np.array(row.Z_extended) + np.array(row.Z_std_extended),
alpha=.1,
color='gray'
)
if i == 0:
twin2.text(x=interval * len(row.Z_extended),
y=np.array(row.Z_extended)[-1]-500,
s='Dyn2-Halo-E1-JF646',
color='gray',
size=8)
#if i == 0 and legend:
# dvu.line_legend()
# color spines and tick marks to match the curve each y-axis belongs to
tkw = dict(size=4, width=1.5)
ax.spines['right'].set_color(cg)
ax.tick_params(axis='y', colors=cr, labelsize=6, **tkw)
twin1.spines['left'].set_color(cr)
ax.spines['left'].set_color(cr)
#twin1.spines['left'].set_color(cg)
if num_axes == 3:
# NOTE(review): p3 appears to be assigned only in the `extra_key is not None`
# branch above; presumably this needs extra_key to be set - confirm.
twin2.spines['left'].set_color(cr)
twin2.spines['right'].set_color(p3.get_color())
twin2.tick_params(axis='y', colors=p3.get_color(), labelsize=6, **tkw)
if ylim_constant:
ax.set_ylim(ylim_cla)
twin1.set_ylim(ylim_aux)
twin2.set_ylim(ylim_dyn)
else:
# rescale the y-limits of the axes so that the curves occupy different
# vertical bands of the same subplot
#p1_ylim = ax.get_ylim()
p2_ylim = twin1.get_ylim()
p3_ylim = twin2.get_ylim()
ylim_min = min(p2_ylim[0], p3_ylim[0])
twin1.set_ylim((ylim_min, 2*p2_ylim[1]))
twin2.set_ylim((ylim_min, 3*p3_ylim[1]))
p1_ylim = ax.get_ylim()
ax.set_ylim((- 2 * p1_ylim[1], p1_ylim[1]))
p2_ylim = twin1.get_ylim()
twin1.set_ylim((- 0.5 * p2_ylim[1], p2_ylim[1]))
#p3_ylim = twin2.get_ylim()
#twin1.set_ylim((- 0.5 * p3_ylim[1], p3_ylim[1]))
twin1.tick_params(axis='y', colors=cg, labelsize=6, **tkw)
ax.tick_params(axis='x', **tkw)
# NOTE(review): from here on the `yticks` parameter is overwritten with the
# current tick positions of the axes.
yticks = ax.get_yticks()
#if len(yticks) > 5:
# keep only the non-negative ticks on the main axis
ax.set_yticks([yt for yt in yticks if yt >= 0])
yticks = twin1.get_yticks()
#if len(yticks) > 5:
# keep every second tick on the twin axis
twin1.set_yticks(yticks[np.arange(1, len(yticks), 2)])
if num_axes == 3:
yticks = twin2.get_yticks()
if len(yticks) > 5:
twin2.set_yticks(yticks[np.arange(1, len(yticks), 2)])
if xlim_constant:
if xlim is None:
plt.xlim([-1, lifetime_max + 16])
else:
print(xlim)
plt.xlim(xlim)
if axes_invisible:
ax.yaxis.set_visible(False)
ax.xaxis.set_visible(False)
twin1.yaxis.set_visible(False)
twin1.spines['left'].set_color('w')
twin2.spines['left'].set_color('w')
twin2.yaxis.set_visible(False)
twin1.spines['right'].set_color('w')
twin2.spines['right'].set_color('w')
#plt.yscale('log')
# plt.axi('off')
# plt.legend()
plt.tight_layout()
if fig is None:
plt.show()
def viz_errs_outliers_venn(X_test, preds, Y_test, num_feats_reduced=5):
    '''Venn diagrams of detected outliers versus prediction errors

    One panel per outlier detector (2 x 2 grid). Each detector is fitted on
    the features of X_test, optionally reduced to num_feats_reduced PCA
    components, and its outliers are compared with the points where
    preds != Y_test.
    '''
    feat_names = data.get_feature_names(X_test)
    X_feat = X_test[feat_names]
    if num_feats_reduced is None:
        X_reduced = X_feat
    else:
        X_reduced = decomposition.PCA(n_components=num_feats_reduced).fit_transform(X_feat)

    n_rows, n_cols = 2, 2
    # (panel title, factory for a fresh detector)
    detectors = [
        ('isolation forest', lambda: IsolationForest(n_estimators=10, warm_start=True)),
        ('local outlier factor', lambda: LocalOutlierFactor(novelty=True)),
        ('elliptic envelop', lambda: EllipticEnvelope()),
        ('one-class svm', lambda: OneClassSVM()),
    ]

    plt.figure(figsize=(6, 5), dpi=200)
    for panel, (panel_title, make_detector) in enumerate(detectors, start=1):
        plt.subplot(n_rows, n_cols, panel)
        plt.title(panel_title)
        detector = make_detector()
        detector.fit(X_reduced)
        # predict() marks outliers with -1
        outlier_mask = detector.predict(X_reduced) == -1
        err_mask = preds != Y_test
        positions = np.arange(outlier_mask.size)
        venn2([set(positions[outlier_mask]), set(positions[err_mask])],
              set_labels=['outliers', 'errors'])
def plot_pcs(pca, X):
    '''Pretty plot of pcs with explained var bars

    Params
    ------
    pca: sklearn PCA class after being fitted
    X
        iterable whose items label the rows of the component matrix
        (list(X) is used; for a DataFrame that yields the column names)
    '''
    plt.figure(figsize=(6, 9), dpi=200)

    # extract out relevant pars
    comps = pca.components_.transpose()
    var_norm = pca.explained_variance_ / np.sum(pca.explained_variance_) * 100

    # create a 2 X 2 grid
    gs = grd.GridSpec(2, 2, height_ratios=[2, 10],
                      width_ratios=[12, 1], wspace=0.1, hspace=0)

    # plot explained variance
    ax2 = plt.subplot(gs[0])
    ax2.bar(np.arange(0, comps.shape[1]), var_norm,
            color='gray', width=0.8)
    plt.title('Explained variance (%)')
    ax2.spines['right'].set_visible(False)
    ax2.spines['top'].set_visible(False)
    ax2.yaxis.set_ticks_position('left')
    ax2.set_yticks([0, max(var_norm)])
    plt.xlim((-0.5, comps.shape[1] - 0.5))

    # plot pcs
    ax = plt.subplot(gs[2])
    vmaxabs = np.max(np.abs(comps))
    p = ax.imshow(comps, interpolation='None', aspect='auto',
                  cmap=sns.diverging_palette(10, 240, as_cmap=True, center='light'),
                  vmin=-vmaxabs, vmax=vmaxabs)  # center at 0
    plt.xlabel('PCA component number')

    # Fix the tick positions first, then attach the labels: setting labels
    # before positions ties them to whatever ticks the default locator chose.
    feat_labels = list(X)
    ax.set_yticks(range(len(feat_labels)))
    ax.set_yticklabels(feat_labels)

    # make colorbar (result not stored, so the module-level color `cb` is not shadowed)
    colorAx = plt.subplot(gs[3])
    plt.colorbar(p, cax=colorAx)
    plt.show()
def print_metadata(acc=None, metadata_file=oj(config.DIR_PROCESSED, 'metadata_clath_aux+gak_a7d2.pkl')):
    """Print track counts and accuracies stored in a pickled metadata dict.

    Params
    ------
    acc: float
        optional accuracy on the hard tracks; printed when given
    metadata_file: str
        path to the pickle file holding the metadata dictionary
    """
    # NOTE: unpickling can execute arbitrary code - only load trusted files.
    # The context manager closes the file handle, which was previously leaked.
    with open(metadata_file, 'rb') as f:
        m = pkl.load(f)

    print(
        f'valid:\t\t{m["num_aux_pos_valid"]:>4.0f} aux+ / {m["num_tracks_valid"]:>4.0f} ({m["num_aux_pos_valid"] / m["num_tracks_valid"]:.3f})')
    print('----------------------------------------')
    # NOTE(review): both numbers on this line come from num_hotspots_valid - confirm intended
    print(f'hotspots:\t{m["num_hotspots_valid"]:>4.0f} aux+ / {m["num_hotspots_valid"]:>4.0f}')
    print(
        f'short:\t\t{m["num_short"] - m["num_short"] * m["acc_short"]:>4.0f} aux+ / {m["num_short"]:>4.0f} ({m["acc_short"]:.3f})')
    print(f'long:\t\t{m["num_long"] * m["acc_long"]:>4.0f} aux+ / {m["num_long"]:>4.0f} ({m["acc_long"]:.3f})')
    print(
        f'hard:\t\t{m["num_aux_pos_hard"]:>4.0f} aux+ / {m["num_tracks_hard"]:>4.0f} ({m["num_aux_pos_hard"] / m["num_tracks_hard"]:.3f})')
    if acc is not None:
        print('----------------------------------------')
        print(f'hard acc:\t\t\t {acc:.3f}')
    print('\nlifetime threshes', m['thresh_short'], m['thresh_long'])
def jointplot_grouped(col_x: str, col_y: str, col_k: str, df,
                      k_is_color=False, scatter_alpha=.5, add_global_hists: bool = False, ms=None):
    '''Jointplot with one scatter + marginal distributions per group

    Params
    ------
    col_x
        name of X var
    col_y
        name of Y var
    col_k
        name of variable to group/color by
    df
        dataframe holding the three columns
    k_is_color
        if True, the group name itself is used as the color
    scatter_alpha
        alpha of the scatter points
    add_global_hists
        whether to plot the global hist as well
    ms
        unused
    '''

    def make_scatter(xs, ys, c=None):
        # The returned callable ignores whatever positional data it is handed
        # and always scatters this group's xs / ys.
        def draw(*_ignored, **kwargs):
            if c is not None:
                kwargs['c'] = c
            kwargs.update(marker='.', alpha=scatter_alpha)
            plt.scatter(xs, ys, **kwargs)

        return draw

    grid = sns.JointGrid(x=col_x, y=col_y, data=df)
    color = None
    group_names = []
    for name, df_group in df.groupby(col_k):
        group_names.append(name)
        if k_is_color:
            color = name
        grid.plot_joint(make_scatter(df_group[col_x], df_group[col_y], color))
        sns.distplot(df_group[col_x].values, ax=grid.ax_marg_x, color=color)
        sns.distplot(df_group[col_y].values, ax=grid.ax_marg_y, color=color, vertical=True)

    if add_global_hists:
        sns.distplot(df[col_x].values, ax=grid.ax_marg_x, color='grey')
        sns.distplot(df[col_y].values.ravel(), ax=grid.ax_marg_y, color='grey', vertical=True)

    plt.legend(group_names)
# 2d decision boundary
def plot_decision_boundary(X_col, Y_col, m, df, norms, num_pts=100):
'''still not finished...
Builds a num_pts x num_pts grid over the ranges of df[X_col] and df[Y_col],
normalizes it with norms[col]['mu'] / norms[col]['std'], and calls m.predict.
Nothing is plotted or returned yet.
'''
x = df[X_col]
y = df[Y_col]
# evenly spaced values between the min and max of each column
x = np.linspace(x.min(), x.max(), num_pts)
y = np.linspace(y.min(), y.max(), num_pts)
# normalize
xv, yv = np.meshgrid(x, y, indexing='ij')
x = xv.flatten()
y = yv.flatten()
x = (x - norms[X_col]['mu']) / (norms[X_col]['std'])
y = (y - norms[Y_col]['mu']) / (norms[Y_col]['std'])
# NOTE(review): hstack of two 1-d arrays concatenates them end to end, so the
# reshape pairs consecutive x values rather than (x_i, y_i) - confirm intent.
X = np.hstack((x, y)).reshape(-1, 2)
print(X.shape)
# NOTE(review): `results_individual` is not defined anywhere in the visible code,
# and this assignment discards the grid built above - confirm.
X = df[results_individual['feat_names_selected']]
preds = m.predict(X)
def cumulative_acc_plot_hard(preds_proba, preds, y_full_cv):
    """Plot cumulative accuracy as points are added from most to least confident.

    Points are ordered by |preds_proba - 0.5|, largest first. The predicted
    probabilities (in that order) and the running accuracy of `preds` against
    `y_full_cv` are drawn on a new figure, which is then shown.
    """
    order = np.argsort(np.abs(preds_proba - 0.5))[::-1]
    correct = (preds == y_full_cv)[order]
    running_acc = np.cumsum(correct) / np.arange(1, correct.size + 1)

    plt.figure(dpi=500)
    plt.plot(preds_proba[order], '.', ms=0.5, label='predicted prob', color=cb)
    plt.plot(running_acc, label='cumulative acc', color=cr)
    plt.yticks(np.arange(-0.05, 1.05, 0.1))
    plt.xlabel('num pts included')
    plt.grid(alpha=0.2)
    plt.legend()
    plt.show()
def cumulative_acc_plot_all(df, pred_proba_key='preds_proba', pred_key='preds',
                            outcome_def='y_consec_thresh',
                            plot_vert_line_for_high_lifetimes=False, show=True):
    """Plot cumulative accuracy over all tracks against an all-negative baseline.

    Params
    ------
    df: pd.DataFrame
        needs the columns lifetime, short (boolean), `pred_key` and `outcome_def`
    pred_key: str
        column of model outputs; a value > 0 counts as a positive prediction
    outcome_def: str
        column holding the 0/1 outcome
    pred_proba_key, plot_vert_line_for_high_lifetimes, show
        unused; kept so existing callers keep working
    """
    plt.figure(dpi=500)
    ax = plt.subplot(111)
    n = df.shape[0]

    # full (no model): predict 0 for every track, tracks ordered by lifetime
    argsf = np.argsort(df.lifetime.values)
    accsf = (1 - df[outcome_def]).values[argsf]
    plt.plot(np.cumsum(accsf) / np.arange(1, accsf.size + 1), label='Predicting all abortive', color='gray')
    print('accsf', np.sum(accsf))

    # short tracks: predicted 0, ordered by lifetime
    ds = df[df.short]
    argss = np.argsort(ds.lifetime.values)
    accss = (1 - ds[outcome_def]).values[argss]
    ns = ds.shape[0]

    # hard tracks: ordered by |model output| (ascending)
    dh = df[~df.short]
    argsh = np.argsort(np.abs(dh[pred_key]))
    accsh = ((dh[pred_key].values > 0) == dh[outcome_def].values)[argsh]

    # put things together
    accs = np.hstack((accss, accsh))
    print(accsf.shape, accss.shape, accsh.shape, accs.shape)
    plt.plot(np.cumsum(accs) / np.arange(1, accs.size + 1), label='LSTM', color=cb)
    print(accs)
    # vertical line where the short tracks end
    plt.axvline(ns, lw=2.5, color='black')

    plt.xlabel('Percentage of tracks included (sorted by uncertainty)')
    plt.ylabel('Accuracy')
    # Always exactly six tick positions for the six labels. The previous
    # np.arange(0, n + 1, n//5) produced more than six ticks whenever n was not
    # a multiple of 5, and a zero step for n < 5.
    ax.xaxis.set_ticks([int(round(pos)) for pos in np.linspace(0, n, 6)])
    ax.xaxis.set_ticklabels([str(pct) + '%' for pct in range(0, 101, 20)])
    plt.legend(fontsize='x-large', frameon=False, labelcolor='linecolor')
    plt.grid(alpha=0.2)
    plt.tight_layout()
def plot_example(ex):
    '''Plot the X (clathrin) and Y (auxilin) series of one example on a new figure

    ex - row of the dataframe
    '''
    plt.figure(dpi=200)
    for key, color, label in (('X', 'red', 'clathrin'), ('Y', 'green', 'auxilin')):
        plt.plot(ex[key], color=color, label=label)
    plt.xlabel('Time')
    plt.ylabel('Amplitude')
    plt.legend()
def get_videos(cell_dir: str):
    '''Load the image stack of every channel for one cell

    Params
    ------
    cell_dir
        Path to directory for one cell; expected to contain the
        sub-directories 'TagRFP', 'EGFP' and 'JF646'

    Returns
    -------
    videos
        dict mapping 'cla' / 'aux' / 'dyn' to the stack read from the file
        whose name contains 'tif' (the last such file listed wins)
    '''
    channel_dirs = {'cla': 'TagRFP', 'aux': 'EGFP', 'dyn': 'JF646'}
    videos = {}
    for channel, subdir in channel_dirs.items():
        channel_path = oj(cell_dir, subdir)
        for file_name in os.listdir(channel_path):
            if 'tif' not in file_name:
                continue
            videos[channel] = imread(oj(channel_path, file_name))
    return videos
def get_all_dynamin_videos(cells):
    """Load the videos of every cell in `cells` from the dynamin data directory.

    The directory of a cell is its name with the last six characters replaced
    by 'Cell1_1.5s'. Returns a dict mapping each cell name to the result of
    get_videos for that directory.
    """
    upper_dir = oj('/scratch/users/vision/data/abc_data/dynamin_data_with_ims/',
                   'CLTA-TagRFP EGFP-Aux1-GAK-F6 Dyn2-Halo-E1-JF646')
    return {
        cell_num: get_videos(oj(upper_dir, cell_num[:-6] + 'Cell1_1.5s'))
        for cell_num in cells
    }
def get_dynamin_data_videos(df, pids, add_px=2, apply_norm=True):
    """
    extract videos of dynamin traces

    Params:
    ------
    df: pd.DataFrame
        dataframe
    pids: list
        list of pids to plot
    add_px: int
        number of additional pixels in each direction
        add_px=1 means 3*3 pixels around the center, add_px=2 means 5*5, etc.
    apply_norm: bool
        if True, crops are rescaled with the per-channel min / max over all crops

    Returns:
    ------
    videos: dict
        videos[pid][channel] is a list with one 2d crop per frame,
        channel being 'cla', 'aux' or 'dyn'
    norm: dict
        norm[channel] = [min, max] over all extracted crops (before rescaling)
    """
    channels = ['cla', 'aux', 'dyn']
    indices = np.array([np.where(df.pid.values == pid)[0][0] for pid in pids])
    df = df.iloc[indices]
    cells = set(df.cell_num.values)
    raw_videos = get_all_dynamin_videos(cells)

    videos = {}
    for i in range(len(df)):
        cell_num = df.cell_num.iloc[i]
        # stacks are indexed [frame, y, x] below, so axis 1 is height and axis 2 is width
        _, h, w = raw_videos[cell_num]['cla'].shape
        pid = df.pid.iloc[i]
        x_pos, y_pos = df.x_pos_seq.iloc[i], df.y_pos_seq.iloc[i]
        # pad the track with 5 extra frames before and after, holding the end positions
        x_pos = [x_pos[0]]*5 + x_pos + [x_pos[-1]]*5
        y_pos = [y_pos[0]]*5 + y_pos + [y_pos[-1]]*5
        # start time moved back by the 5 padded frames (1.5 per frame)
        # NOTE(review): for tracks starting within the first 5 frames the frame
        # index below becomes negative and wraps around - confirm this is acceptable.
        t, lt = df.t.iloc[i] - 5*1.5, min(len(x_pos), len(y_pos))
        videos[pid] = {m: [] for m in channels}

        for j in range(lt):
            # crop window, clamped to the image: rows by the height, columns by
            # the width (the column clamp previously used the height as well)
            crop_y_pos = np.maximum(0, np.minimum(h - 1, np.arange(int(y_pos[j]) - add_px, int(y_pos[j]) + add_px + 1)))
            crop_x_pos = np.maximum(0, np.minimum(w - 1, np.arange(int(x_pos[j]) - add_px, int(x_pos[j]) + add_px + 1)))
            frame = int(t/1.5) + j
            for m in channels:
                videos[pid][m].append(raw_videos[cell_num][m][frame, :, :][crop_y_pos, :][:, crop_x_pos])

    # normalization: per-channel min / max over every crop of every track
    norm = {}
    for m in channels:
        norm[m] = [1e9, -1e9]  # min, max
        for pid in videos:
            for crop in videos[pid][m]:
                norm[m][0] = min(norm[m][0], crop.min())
                norm[m][1] = max(norm[m][1], crop.max())

    if apply_norm:
        for pid in videos:
            for m in channels:
                for j in range(len(videos[pid][m])):
                    videos[pid][m][j] = (videos[pid][m][j] - norm[m][0])/(norm[m][1] - norm[m][0])
    return videos, norm
def plot_kymographs(df, pids, add_px=2):
    """
    plot kymographs of dynamin traces

    Params:
    ------
    df: pd.DataFrame
        dataframe
    pids: list
        list of pids to plot
    add_px: int
        number of additional pixels in each direction
        add_px=1 means 3*3 pixels around the center, add_px=2 means 5*5, etc.

    Returns:
    ------
    cla_traces: dict of np.array
        clathrin traces from raw images, keyed by position in `pids`;
        each array has one row per frame and 2 * add_px + 1 columns
    aux_traces: dict of np.array
        auxilin traces from raw images, same layout
    rgb_image: 3d np.array
        3d array (RGB values) of kymographs; the first axis runs over the
        kymograph columns, the second over frames
    """
    indices = np.array([np.where(df.pid.values == pid)[0][0] for pid in pids])
    df = df.iloc[indices]
    cells = set(df.cell_num.values)
    raw_videos = get_all_dynamin_videos(cells)

    # matplotlib.cm.get_cmap has been removed from current matplotlib;
    # pyplot.get_cmap accepts the same (name, lut) arguments.
    reds = plt.get_cmap('Reds', 12)  # red palette for clathrin
    greens = plt.get_cmap('Greens', 12)  # green palette for auxilin

    lmax = max([len(df.x_pos_seq.iloc[i]) for i in range(len(df))]) + 2
    width = 2 * add_px + 1
    cla_traces, aux_traces = {}, {}
    for i in range(len(df)):
        cla_traces[i], aux_traces[i] = np.zeros((lmax, width)), np.zeros((lmax, width))
        cell_num = df.cell_num.iloc[i]
        x_pos, y_pos = df.x_pos_seq.iloc[i], df.y_pos_seq.iloc[i]
        t, lt = df.t.iloc[i], min(len(x_pos), len(y_pos))
        for channel, traces in (('cla', cla_traces[i]), ('aux', aux_traces[i])):
            video = raw_videos[cell_num][channel]
            for j in range(lt):
                frame = video[int(t/1.5) + j, :, :]
                # min / max of the whole frame, computed once per frame
                vmin, vmax = frame.min(), frame.max()
                y_idx = int(y_pos[j])
                for k in range(-add_px, add_px + 1):
                    # pixel on the track's row, k columns away from the track center,
                    # rescaled by the frame's min / max
                    raw = float(frame[y_idx, int(x_pos[j] + k)])
                    traces[j, k + add_px] = (raw - vmin)/(vmax - vmin)

    # lay the traces out side by side: per track one clathrin strip, one
    # auxilin strip and one empty strip, each `width` columns wide
    ncol = 3 * width * len(df)
    cla_sparse = np.zeros((lmax, ncol))
    aux_sparse = np.zeros((lmax, ncol))
    for i in range(len(df)):
        start_index = 3 * width * i
        cla_sparse[:, (start_index):(start_index + width)] = cla_traces[i]
        aux_sparse[:, (start_index + width):(start_index + 2 * width)] = aux_traces[i]

    def pixel_rgb(row, col):
        # values strictly between 0 and 1 are colored, everything else stays white
        cla_val = cla_sparse[row][col]
        if 0 < cla_val < 1:
            return list(reds(cla_val)[:3])
        aux_val = aux_sparse[row][col]
        if 0 < aux_val < 1:
            return list(greens(aux_val)[:3])
        return (1, 1, 1)

    rgb_image = np.array([[pixel_rgb(i, j) for i in range(lmax)]
                          for j in range(ncol)])
    return cla_traces, aux_traces, rgb_image
Functions
def cumulative_acc_plot_all(df, pred_proba_key='preds_proba', pred_key='preds', outcome_def='y_consec_thresh', plot_vert_line_for_high_lifetimes=False, show=True)
-
Expand source code
def cumulative_acc_plot_all(df, pred_proba_key='preds_proba', pred_key='preds', outcome_def='y_consec_thresh', plot_vert_line_for_high_lifetimes=False, show=True):
    """Plot cumulative accuracy over all tracks against an all-negative baseline.

    Params
    ------
    df: pd.DataFrame
        needs the columns lifetime, short (boolean), `pred_key` and `outcome_def`
    pred_key: str
        column of model outputs; a value > 0 counts as a positive prediction
    outcome_def: str
        column holding the 0/1 outcome
    pred_proba_key, plot_vert_line_for_high_lifetimes, show
        unused; kept so existing callers keep working
    """
    plt.figure(dpi=500)
    ax = plt.subplot(111)
    n = df.shape[0]

    # full (no model): predict 0 for every track, tracks ordered by lifetime
    argsf = np.argsort(df.lifetime.values)
    accsf = (1 - df[outcome_def]).values[argsf]
    plt.plot(np.cumsum(accsf) / np.arange(1, accsf.size + 1), label='Predicting all abortive', color='gray')
    print('accsf', np.sum(accsf))

    # short tracks: predicted 0, ordered by lifetime
    ds = df[df.short]
    argss = np.argsort(ds.lifetime.values)
    accss = (1 - ds[outcome_def]).values[argss]
    ns = ds.shape[0]

    # hard tracks: ordered by |model output| (ascending)
    dh = df[~df.short]
    argsh = np.argsort(np.abs(dh[pred_key]))
    accsh = ((dh[pred_key].values > 0) == dh[outcome_def].values)[argsh]

    # put things together
    accs = np.hstack((accss, accsh))
    print(accsf.shape, accss.shape, accsh.shape, accs.shape)
    plt.plot(np.cumsum(accs) / np.arange(1, accs.size + 1), label='LSTM', color=cb)
    print(accs)
    # vertical line where the short tracks end
    plt.axvline(ns, lw=2.5, color='black')

    plt.xlabel('Percentage of tracks included (sorted by uncertainty)')
    plt.ylabel('Accuracy')
    # Always exactly six tick positions for the six labels. The previous
    # np.arange(0, n + 1, n//5) produced more than six ticks whenever n was not
    # a multiple of 5, and a zero step for n < 5.
    ax.xaxis.set_ticks([int(round(pos)) for pos in np.linspace(0, n, 6)])
    ax.xaxis.set_ticklabels([str(pct) + '%' for pct in range(0, 101, 20)])
    plt.legend(fontsize='x-large', frameon=False, labelcolor='linecolor')
    plt.grid(alpha=0.2)
    plt.tight_layout()
def cumulative_acc_plot_hard(preds_proba, preds, y_full_cv)
-
Expand source code
def cumulative_acc_plot_hard(preds_proba, preds, y_full_cv):
    """Plot cumulative accuracy as points are added from most to least confident.

    Points are ordered by |preds_proba - 0.5|, largest first. The predicted
    probabilities (in that order) and the running accuracy of `preds` against
    `y_full_cv` are drawn on a new figure, which is then shown.
    """
    order = np.argsort(np.abs(preds_proba - 0.5))[::-1]
    correct = (preds == y_full_cv)[order]
    running_acc = np.cumsum(correct) / np.arange(1, correct.size + 1)

    plt.figure(dpi=500)
    plt.plot(preds_proba[order], '.', ms=0.5, label='predicted prob', color=cb)
    plt.plot(running_acc, label='cumulative acc', color=cr)
    plt.yticks(np.arange(-0.05, 1.05, 0.1))
    plt.xlabel('num pts included')
    plt.grid(alpha=0.2)
    plt.legend()
    plt.show()
def fix_feat_name(s)
-
Expand source code
def fix_feat_name(s):
    """Turn a raw feature name into a display label.

    Underscores become spaces, every 'X' becomes 'Clath', and the result is
    capitalized (first character upper-case, the rest lower-case).
    """
    spaced = s.replace('_', ' ')
    renamed = spaced.replace('X', 'Clath')
    return renamed.capitalize()
def get_all_dynamin_videos(cells)
-
Expand source code
def get_all_dynamin_videos(cells):
    """Load the videos of every cell in `cells` from the dynamin data directory.

    The directory of a cell is its name with the last six characters replaced
    by 'Cell1_1.5s'. Returns a dict mapping each cell name to the result of
    get_videos for that directory.
    """
    upper_dir = oj('/scratch/users/vision/data/abc_data/dynamin_data_with_ims/',
                   'CLTA-TagRFP EGFP-Aux1-GAK-F6 Dyn2-Halo-E1-JF646')
    return {
        cell_num: get_videos(oj(upper_dir, cell_num[:-6] + 'Cell1_1.5s'))
        for cell_num in cells
    }
def get_dynamin_data_videos(df, pids, add_px=2, apply_norm=True)
-
extract videos of dynamin traces
Params:
df: pd.DataFrame dataframe
pids: list list of pids to plot
add_px: int number of additional pixels in each direction add_px=1 means 3*3 pixels around the center, add_px=2 means 5*5, etc.
Returns:
videos
Expand source code
def get_dynamin_data_videos(df, pids, add_px=2, apply_norm=True):
    """
    extract videos of dynamin traces

    For every requested track, a small window following the track position is
    cropped from each channel ('cla', 'aux', 'dyn') of the raw video of its
    cell. The position sequence is padded with 5 copies of its first and last
    entry, and the start time is shifted back by 5 * 1.5 accordingly.

    Params:
    ------
    df: pd.DataFrame
        dataframe with the columns pid, cell_num, x_pos_seq, y_pos_seq
        (python lists) and t
    pids: list
        list of pids to plot
    add_px: int
        number of additional pixels in each direction
        add_px=1 means 3*3 pixels around the center, add_px=2 means 5*5, etc.
    apply_norm: bool
        whether to rescale every crop with the per-channel min/max taken over
        all extracted crops

    Returns:
    ------
    videos: dict
        videos[pid][channel] is a list with one cropped frame per time step
    norm: dict
        norm[channel] = [min, max] over all crops of that channel
        (computed before rescaling)
    """
    channels = ['cla', 'aux', 'dyn']
    indices = np.array([np.where(df.pid.values == pid)[0][0] for pid in pids])
    df = df.iloc[indices]
    cells = set(df.cell_num.values)
    raw_videos = get_all_dynamin_videos(cells)

    videos = {}
    for i in range(len(df)):
        cell_num = df.cell_num.iloc[i]
        _, h, w = raw_videos[cell_num]['cla'].shape
        pid = df.pid.iloc[i]
        x_pos, y_pos = df.x_pos_seq.iloc[i], df.y_pos_seq.iloc[i]
        # pad with 5 frames before / after the track
        x_pos = [x_pos[0]]*5 + x_pos + [x_pos[-1]]*5
        y_pos = [y_pos[0]]*5 + y_pos + [y_pos[-1]]*5
        t, lt = df.t.iloc[i] - 5*1.5, min(len(x_pos), len(y_pos))
        # NOTE(review): t / 1.5 presumably converts seconds to a frame index for
        # 1.5 s videos; a track starting in the first 5 frames gives a negative
        # index here - confirm this cannot happen
        first_frame = int(t/1.5)

        videos[pid] = {m: [] for m in channels}
        for j in range(lt):
            # the crop window depends only on the position, not on the channel
            y_c, x_c = int(y_pos[j]), int(x_pos[j])
            crop_y_pos = np.clip(np.arange(y_c - add_px, y_c + add_px + 1), 0, h - 1)
            # columns are clamped to the frame width (was clamped with the height)
            crop_x_pos = np.clip(np.arange(x_c - add_px, x_c + add_px + 1), 0, w - 1)
            for m in channels:
                frame = raw_videos[cell_num][m][first_frame + j, :, :]
                videos[pid][m].append(frame[crop_y_pos, :][:, crop_x_pos])

    # normalization: per-channel min / max over all crops
    norm = {}
    for m in channels:
        lo, hi = 1e9, -1e9
        for pid in videos:
            for frame in videos[pid][m]:
                lo = min(lo, frame.min())
                hi = max(hi, frame.max())
        norm[m] = [lo, hi]

    if apply_norm:
        for pid in videos:
            for m in channels:
                lo, hi = norm[m]
                videos[pid][m] = [(frame - lo)/(hi - lo) for frame in videos[pid][m]]
    return videos, norm
def get_videos(cell_dir)
-
Loads in X and Y for one cell
Params
cell_dir
- Path to directory for one cell
Returns
videos
Expand source code
def get_videos(cell_dir: str):
    '''Read the image stack of each channel for one cell

    Each channel lives in its own sub-directory of `cell_dir` ('TagRFP',
    'EGFP', 'JF646'). Every file whose name contains 'tif' is read; if a
    sub-directory holds several such files, the last one listed is kept.

    Params
    ------
    cell_dir
        Path to directory for one cell

    Returns
    -------
    videos
        dict keyed by 'cla', 'aux', 'dyn' (a key is absent when no matching
        file was found)
    '''
    channel_dirs = {'cla': 'TagRFP', 'aux': 'EGFP', 'dyn': 'JF646'}
    videos = {}
    for channel, subdir in channel_dirs.items():
        for entry in os.listdir(oj(cell_dir, subdir)):
            if 'tif' not in entry:
                continue
            videos[channel] = imread(oj(cell_dir, subdir, entry))
    return videos
def highlight_max(data, color='#0e5c99')
-
highlight the maximum in a Series or DataFrame
Expand source code
def highlight_max(data, color='#0e5c99'):
    '''Build CSS background styles marking the maximum of a Series or DataFrame

    Params
    ------
    data: pd.Series or pd.DataFrame
    color: str
        CSS color used for the highlighted entries

    Returns
    -------
    list of str for 1-d input, otherwise a DataFrame with the same index and
    columns as `data`; entries equal to the maximum hold the style string,
    all others hold ''
    '''
    css = f'background-color: {color}'
    if data.ndim != 1:
        # DataFrame (from .apply(axis=None)): compare against the overall max
        overall_max = data.max().max()
        styled = np.where(data == overall_max, css, '')
        return pd.DataFrame(styled, index=data.index, columns=data.columns)
    # Series (from .apply(axis=0) or axis=1)
    peak = data.max()
    return [css if hit else '' for hit in data == peak]
def jointplot_grouped(col_x, col_y, col_k, df, k_is_color=False, scatter_alpha=0.5, add_global_hists=False, ms=None)
-
Jointplot of hists + densities Params
col_x
- name of X var
col_y
- name of Y var
col_k
- name of variable to group/color by
add_global_hists
- whether to plot the global hist as well
Expand source code
def jointplot_grouped(col_x: str, col_y: str, col_k: str, df, k_is_color=False, scatter_alpha=.5, add_global_hists: bool = False, ms=None):
    '''Joint scatter plot with marginal distributions, one layer per group

    Params
    ------
    col_x
        name of X var
    col_y
        name of Y var
    col_k
        name of variable to group/color by
    df
        dataframe holding the three columns
    k_is_color
        if True the group value itself is used as the color
    scatter_alpha
        alpha of the scatter points
    add_global_hists
        whether to plot the global hist as well
    ms
        unused
    '''

    def colored_scatter(x, y, c=None):
        # plot_joint passes its own data; the closure ignores it and draws the group
        def scatter(*_, **kwargs):
            if c is not None:
                kwargs['c'] = c
            kwargs.update(marker='.', alpha=scatter_alpha)
            plt.scatter(x, y, **kwargs)

        return scatter

    grid = sns.JointGrid(x=col_x, y=col_y, data=df)

    color = None
    group_names = []
    for group_name, group in df.groupby(col_k):
        group_names.append(group_name)
        if k_is_color:
            color = group_name
        xs, ys = group[col_x], group[col_y]
        grid.plot_joint(colored_scatter(xs, ys, color))
        sns.distplot(xs.values, ax=grid.ax_marg_x, color=color)
        sns.distplot(ys.values, ax=grid.ax_marg_y, color=color, vertical=True)

    if add_global_hists:
        sns.distplot(df[col_x].values, ax=grid.ax_marg_x, color='grey')
        sns.distplot(df[col_y].values.ravel(), ax=grid.ax_marg_y, color='grey', vertical=True)

    plt.legend(group_names)
def plot_above_threshold(x1, y1, b1, x2, y2, b2, ax, color, lsty)
-
Expand source code
def plot_above_threshold(x1, y1, b1, x2, y2, b2, ax, color, lsty):
    '''Draw one trace segment, opaque where it is at/above the threshold

    The trace goes from (x1, y1) to (x2, y2) and the threshold from (x1, b1)
    to (x2, b2). Parts of the segment at or above the threshold are drawn with
    alpha=1, parts below with alpha=.1; a segment crossing the threshold is
    split at the intersection point.

    Params
    ------
    x1, y1, b1
        position, trace value and threshold value at the segment start
    x2, y2, b2
        position, trace value and threshold value at the segment end
    ax
        axis to draw on
    color, lsty
        color and linestyle of the line
    '''
    dx = x2 - x1
    slope_trace = (y2 - y1)/dx
    slope_thresh = (b2 - b1)/dx
    style = dict(linestyle=lsty, color=color)

    starts_above, ends_above = y1 >= b1, y2 >= b2
    starts_below, ends_below = y1 < b1, y2 < b2

    # no crossing: a single line
    if starts_above and ends_above:
        ax.plot([x1, x2], [y1, y2], alpha=1, **style)
        return
    if starts_below and ends_below:
        ax.plot([x1, x2], [y1, y2], alpha=.1, **style)
        return

    # crossing: pick the alpha of the part before / after the intersection
    if starts_above and ends_below:
        alpha_before, alpha_after = 1, .1
    elif starts_below and ends_above:
        alpha_before, alpha_after = .1, 1
    else:
        return

    cross_x = x1 + (y1 - b1)/(slope_thresh - slope_trace)
    cross_y = y1 + slope_trace * (y1 - b1)/(slope_thresh - slope_trace)
    ax.plot([x1, cross_x], [y1, cross_y], alpha=alpha_before, **style)
    ax.plot([cross_x, x2], [cross_y, y2], alpha=alpha_after, **style)
def plot_background(interval, bg, trace, color, ax)
-
Expand source code
def plot_background(interval, bg, trace, color, ax):
    '''Shade the region below 2 * bg and draw the trace relative to it

    The band between 0 and twice `bg` is filled, then the trace is drawn
    segment by segment with `plot_above_threshold`, using 2 * bg as the
    threshold. The first five and the last five segments are dashed.

    Params
    ------
    interval
        spacing between consecutive samples on the x axis
    bg
        sequence of background values, one per sample
    trace
        sequence of trace values
    color
        color of band and trace
    ax
        axis to draw on
    '''
    threshold = 2 * np.array(bg)
    num_bg = len(bg)
    ax.fill_between(interval * np.arange(num_bg), [0] * num_bg, threshold, alpha=.1, color=color)

    values = np.array(trace)
    num_frames = len(trace)
    frame_idx = np.arange(num_frames)
    for f in range(num_frames - 1):
        in_padding = f < 5 or f >= num_frames - 5
        plot_above_threshold(x1=interval * frame_idx[f],
                             y1=values[f],
                             b1=threshold[f],
                             x2=interval * frame_idx[f + 1],
                             y2=values[f + 1],
                             b2=threshold[f + 1],
                             ax=ax,
                             color=color,
                             lsty='--' if in_padding else '-')
def plot_confusion_matrix(y_true, y_pred, classes, normalize=False, title=None, cmap=<matplotlib.colors.LinearSegmentedColormap object>)
-
This function prints and plots the confusion matrix. Normalization can be applied by setting
normalize=True
. Params
classes
:np.ndarray
(Str
)- classes=np.array(['aux-', 'aux+'])
Expand source code
def plot_confusion_matrix(y_true, y_pred, classes,
                          normalize=False,
                          title=None,
                          cmap=plt.cm.Blues):
    """
    This function plots the confusion matrix on a new figure.
    Normalization can be applied by setting `normalize=True`.

    Params
    ------
    y_true: array-like
        true integer labels
    y_pred: array-like
        predicted integer labels
    classes: np.ndarray(Str)
        class names indexed by label, e.g.
        classes=np.array(['aux-', 'aux+'])
    normalize: bool
        divide every row by its sum and print 2 decimals instead of counts
    title: str
        currently not drawn; kept for backward compatibility
    cmap
        colormap passed to imshow

    Returns
    -------
    ax
        axis holding the plot
    """
    plt.figure(dpi=300)

    # Compute confusion matrix
    conf_mat = metrics.confusion_matrix(y_true, y_pred)
    # Only use the labels that appear in the data
    # (builtin int: the np.int alias was removed in NumPy 1.24)
    labels_present = unique_labels(np.asarray(y_true).astype(int),
                                   np.asarray(y_pred).astype(int))
    classes = classes[labels_present]
    if normalize:
        conf_mat = conf_mat.astype('float') / conf_mat.sum(axis=1)[:, np.newaxis]

    plt.imshow(conf_mat, interpolation='nearest', cmap=cmap)
    ax = plt.gca()
    # We want to show all ticks...
    ax.set(xticks=np.arange(conf_mat.shape[1]),
           yticks=np.arange(conf_mat.shape[0]),
           # ... and label them with the respective list entries
           xticklabels=classes, yticklabels=classes,
           ylabel='True label',
           xlabel='Predicted label')

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Loop over data dimensions and create text annotations.
    fmt = '.2f' if normalize else 'd'
    thresh = conf_mat.max() / 2.
    for i in range(conf_mat.shape[0]):
        for j in range(conf_mat.shape[1]):
            # white text on dark cells, black on light ones
            ax.text(j, i, format(conf_mat[i, j], fmt),
                    ha="center", va="center",
                    color="white" if conf_mat[i, j] > thresh else "black")
    return ax
def plot_curves(df, extra_key=None, extra_key_label=None, hline=True, R=5, C=8, xlim=None, fig=None, ylim_constant=False, background=False, ylim_cla=None, ylim_aux=None, ylim_dyn=None, xlim_constant=True, legend=True, plot_x=True, yticks=None, yticklabels=None, num_axes=3, show_track_pid=False, axes_invisible=False)
-
Plot time-series curves from df
Expand source code
def plot_curves(df, extra_key=None, extra_key_label=None, hline=True, R=5, C=8, xlim=None, fig=None,
                ylim_constant=False, background=False, ylim_cla=None, ylim_aux=None, ylim_dyn=None,
                xlim_constant=True, legend=True, plot_x=True, yticks=None, yticklabels=None,
                num_axes=3, show_track_pid=False, axes_invisible=False):
    '''Plot time-series curves from df

    The first R * C rows of df are drawn on an R x C grid of subplots. With
    num_axes == 1 all curves of a track share one y axis; otherwise clathrin
    is drawn on the main axis, auxilin on a twin axis and (for num_axes == 3)
    the extra curve on a second twin axis.

    Params
    ------
    df: pd.DataFrame
        one track per row; uses cell_num, lifetime, X_extended, Y_extended and,
        depending on the options, X_std_extended, Y_std_extended,
        Z_extended, Z_std_extended, X_c_extended, *_sigma_extended, X_max,
        Y_max, pid
    extra_key: str
        if not None an extra curve is drawn. With a single axis this is
        row[extra_key]; with several axes it is always row.Z_extended
    extra_key_label: str
        legend label of the extra curve (defaults to 'Dynamin' for 'Z',
        otherwise to extra_key)
    hline: bool
        draw a horizontal line at 642.3754691658837
    R, C: int
        number of rows / columns of the subplot grid
    xlim
        x limits used when xlim_constant is True (default: based on the
        longest lifetime among the plotted tracks)
    fig
        if None a new figure is created here and shown at the end
    ylim_constant: bool
        use fixed y limits. With several axes these are ylim_cla, ylim_aux,
        ylim_dyn; with a single axis ylim_cla is used for the shared axis
        (default there: [-10, max of X_max and Y_max + 1])
    background: bool
        draw the background level of each curve (see `plot_background`)
    xlim_constant: bool
        use the same x limits in all subplots
    legend: bool
        call dvu.line_legend() on the first subplot (several axes only)
    plot_x: bool
        whether to draw the clathrin curve
    yticks, yticklabels
        y ticks / labels for the single-axis layout
    num_axes: int
        1, 2 or 3 y axes per subplot
    show_track_pid: bool
        write the pid into each subplot (several axes only)
    axes_invisible: bool
        hide axes and spines (several axes only)
    '''
    DIFF = 0
    EXTRA_COLOR = 'gray'  # color of the extra (dynamin) curve and its axis
    if fig is None:
        plt.figure(figsize=(16, 10), dpi=200, facecolor='white')

    lifetime_max = np.max(df.lifetime.values[:R * C])
    # slice instead of indexing with range(R * C), which raised when df had
    # fewer rows than subplots
    df = df.iloc[:R * C]
    for i in range(R * C):
        if i < df.shape[0]:
            ax = plt.subplot(R, C, i + 1)
            row = df.iloc[i]
            if '1.5s' in row['cell_num']:
                interval = 1.5
            else:
                interval = 1

            if num_axes == 1:
                # dashed: full (extended) curve, solid: curve without the 5 padded frames
                if plot_x:
                    plt.plot(interval * np.arange(len(row.X_extended)), np.array(row.X_extended) + DIFF,
                             linestyle='--', color=cr)
                plt.plot(interval * np.arange(len(row.Y_extended)), np.array(row.Y_extended) + DIFF,
                         linestyle='--', color=cg)
                plt.plot(interval * np.arange(5, len(row.X_extended)-5), np.array(row.X_extended)[5:(-5)] + DIFF,
                         color=cr, label='Clathrin')
                plt.plot(interval * np.arange(5, len(row.Y_extended)-5), np.array(row.Y_extended)[5:(-5)] + DIFF,
                         color=cg, label='Auxilin')
                if background:
                    ax.plot(interval * np.arange(len(row.X_extended)), np.array(row.X_c_extended),
                            color=cr, linewidth=.8)
                    ax.fill_between(interval * np.arange(len(row.X_extended)),
                                    np.array(row.X_extended) - np.array(row.X_std_extended),
                                    np.array(row.X_extended) + np.array(row.X_std_extended),
                                    alpha=.2, color=cr)
                if hline:
                    plt.axhline(642.3754691658837, color='gray', alpha=0.5)
                if extra_key is not None:
                    if extra_key_label is None:
                        if extra_key == 'Z':
                            extra_key_label = 'Dynamin'
                        else:
                            extra_key_label = extra_key
                    plt.plot(interval * np.arange(len(row[extra_key])), np.array(row[extra_key]) + DIFF,
                             linestyle='--', color='gray')
                    plt.plot(interval * np.arange(5, len(row[extra_key])-5),
                             np.array(row[extra_key])[5:(-5)] + DIFF,
                             color='gray', label=extra_key_label)
                if xlim_constant:
                    if xlim is None:
                        plt.xlim([-1, lifetime_max + 1])
                    else:
                        print(xlim)
                        plt.xlim(xlim)
                if ylim_constant:
                    # the single shared axis reuses ylim_cla (this branch used
                    # to read an undefined name `ylim`)
                    if ylim_cla is None:
                        plt.ylim([-10, max(max(df.X_max), max(df.Y_max)) + 1])
                    else:
                        plt.ylim(ylim_cla[0] + DIFF, ylim_cla[1] + DIFF)
                if yticks is not None:
                    plt.yticks(yticks, labels=yticklabels)
            else:
                ax.spines['right'].set_visible(True)
                twin1 = ax.twinx()
                if num_axes == 3:
                    # third axis is offset to the right of the second one
                    twin2 = ax.twinx()
                    twin2.spines['right'].set_visible(True)
                    twin2.spines['right'].set_position(("axes", 1.2))
                else:
                    twin2 = twin1
                if show_track_pid:
                    ax.text(.5, .9, f'{row.pid}',
                            horizontalalignment='right',
                            transform=ax.transAxes)

                # clathrin on the main axis
                if plot_x:
                    ax.plot(interval * np.arange(len(row.X_extended)), np.array(row.X_extended) + DIFF,
                            linestyle='--', color=cr, alpha=.1)
                    if i == 0:
                        ax.text(x=interval * len(row.X_extended), y=np.array(row.X_extended)[-1],
                                s='CLTA-TagRFP', color=cr, size=8)
                    if background:
                        plot_background(interval, row.X_sigma_extended, row.X_extended, color=cr, ax=ax)
                    else:
                        ax.plot(interval * np.arange(5, len(row.X_extended)-5),
                                np.array(row.X_extended)[5:(-5)] + DIFF, color=cr)
                    if i == 0 and legend:
                        dvu.line_legend()
                    ax.fill_between(interval * np.arange(len(row.X_extended)),
                                    np.array(row.X_extended) - np.array(row.X_std_extended),
                                    np.array(row.X_extended) + np.array(row.X_std_extended),
                                    alpha=.2, color=cr)

                # auxilin on the first twin axis
                twin1.plot(interval * np.arange(len(row.Y_extended)), np.array(row.Y_extended) + DIFF,
                           linestyle='--', color=cg, alpha=.1)
                if background:
                    plot_background(interval, row.Y_sigma_extended, row.Y_extended, color=cg, ax=twin1)
                else:
                    twin1.plot(interval * np.arange(5, len(row.Y_extended)-5),
                               np.array(row.Y_extended)[5:(-5)] + DIFF,
                               color=cg, label='EGFP-Aux1-GAK-F6')
                if i == 0 and legend:
                    dvu.line_legend()
                twin1.fill_between(interval * np.arange(len(row.Y_extended)),
                                   np.array(row.Y_extended) - np.array(row.Y_std_extended),
                                   np.array(row.Y_extended) + np.array(row.Y_std_extended),
                                   alpha=.2, color=cg)
                if i == 0:
                    twin1.text(x=interval * len(row.Y_extended), y=np.array(row.Y_extended)[-1],
                               s='EGFP-Aux1-GAK-F6', color=cg, size=8)
                if hline:
                    ax.axhline(642.3754691658837, color='gray', alpha=0.5)

                # extra curve (always Z_extended here) on the second twin axis
                if extra_key is not None:
                    if extra_key_label is None:
                        if extra_key == 'Z':
                            extra_key_label = 'Dynamin'
                        else:
                            extra_key_label = extra_key
                    twin2.plot(interval * np.arange(len(row.Z_extended)), np.array(row.Z_extended) + DIFF,
                               linestyle='--', color=EXTRA_COLOR, alpha=.1)
                    if background:
                        plot_background(interval, row.Z_sigma_extended, row.Z_extended,
                                        color=EXTRA_COLOR, ax=twin2)
                    else:
                        twin2.plot(interval * np.arange(5, len(row.Z_extended)-5),
                                   np.array(row.Z_extended)[5:(-5)] + DIFF, color=EXTRA_COLOR)
                    twin2.fill_between(interval * np.arange(len(row.Z_extended)),
                                       np.array(row.Z_extended) - np.array(row.Z_std_extended),
                                       np.array(row.Z_extended) + np.array(row.Z_std_extended),
                                       alpha=.1, color=EXTRA_COLOR)
                    if i == 0:
                        twin2.text(x=interval * len(row.Z_extended), y=np.array(row.Z_extended)[-1]-500,
                                   s='Dyn2-Halo-E1-JF646', color=EXTRA_COLOR, size=8)

                # color spines / ticks like the curve they belong to
                tkw = dict(size=4, width=1.5)
                ax.spines['right'].set_color(cg)
                ax.tick_params(axis='y', colors=cr, labelsize=6, **tkw)
                twin1.spines['left'].set_color(cr)
                ax.spines['left'].set_color(cr)
                if num_axes == 3:
                    twin2.spines['left'].set_color(cr)
                    # fixed color instead of reading it from the plotted line,
                    # which does not exist when extra_key is None
                    twin2.spines['right'].set_color(EXTRA_COLOR)
                    twin2.tick_params(axis='y', colors=EXTRA_COLOR, labelsize=6, **tkw)

                if ylim_constant:
                    ax.set_ylim(ylim_cla)
                    twin1.set_ylim(ylim_aux)
                    twin2.set_ylim(ylim_dyn)
                else:
                    # stretch the automatic limits so that the curves of the
                    # different axes end up in different vertical bands
                    p2_ylim = twin1.get_ylim()
                    p3_ylim = twin2.get_ylim()
                    ylim_min = min(p2_ylim[0], p3_ylim[0])
                    twin1.set_ylim((ylim_min, 2*p2_ylim[1]))
                    twin2.set_ylim((ylim_min, 3*p3_ylim[1]))
                    p1_ylim = ax.get_ylim()
                    ax.set_ylim((- 2 * p1_ylim[1], p1_ylim[1]))
                    p2_ylim = twin1.get_ylim()
                    twin1.set_ylim((- 0.5 * p2_ylim[1], p2_ylim[1]))

                twin1.tick_params(axis='y', colors=cg, labelsize=6, **tkw)
                ax.tick_params(axis='x', **tkw)

                # thin out the ticks (local name so the `yticks` argument is not overwritten)
                cur_ticks = ax.get_yticks()
                ax.set_yticks([yt for yt in cur_ticks if yt >= 0])
                cur_ticks = twin1.get_yticks()
                twin1.set_yticks(cur_ticks[np.arange(1, len(cur_ticks), 2)])
                if num_axes == 3:
                    cur_ticks = twin2.get_yticks()
                    if len(cur_ticks) > 5:
                        twin2.set_yticks(cur_ticks[np.arange(1, len(cur_ticks), 2)])

                if xlim_constant:
                    if xlim is None:
                        # extra room on the right for the text labels
                        plt.xlim([-1, lifetime_max + 16])
                    else:
                        print(xlim)
                        plt.xlim(xlim)
                if axes_invisible:
                    ax.yaxis.set_visible(False)
                    ax.xaxis.set_visible(False)
                    twin1.yaxis.set_visible(False)
                    twin1.spines['left'].set_color('w')
                    twin2.spines['left'].set_color('w')
                    twin2.yaxis.set_visible(False)
                    twin1.spines['right'].set_color('w')
                    twin2.spines['right'].set_color('w')

    plt.tight_layout()
    if fig is None:
        plt.show()
def plot_decision_boundary(X_col, Y_col, m, df, norms, num_pts=100)
-
still not finished…
Expand source code
def plot_decision_boundary(X_col, Y_col, m, df, norms, num_pts=100):
    '''still not finished...

    Builds a num_pts x num_pts grid spanning the observed range of two
    columns, normalizes it with `norms`, and calls `m.predict`. Nothing is
    plotted or returned yet.

    Params
    ------
    X_col, Y_col: str
        names of the two columns of df spanning the grid
    m
        fitted model exposing `predict`
    df: pd.DataFrame
    norms: dict
        norms[col]['mu'] and norms[col]['std'] are used to normalize each column
    num_pts: int
        number of grid points per dimension
    '''
    x = df[X_col]
    y = df[Y_col]
    x = np.linspace(x.min(), x.max(), num_pts)
    y = np.linspace(y.min(), y.max(), num_pts)

    # build the grid, then normalize
    xv, yv = np.meshgrid(x, y, indexing='ij')
    x = xv.flatten()
    y = yv.flatten()
    x = (x - norms[X_col]['mu']) / (norms[X_col]['std'])
    y = (y - norms[Y_col]['mu']) / (norms[Y_col]['std'])
    # NOTE(review): hstack concatenates x and y end to end, so reshape(-1, 2)
    # pairs consecutive entries of x (then of y) instead of (x_i, y_i) -
    # np.column_stack was presumably intended; confirm
    X = np.hstack((x, y)).reshape(-1, 2)
    print(X.shape)
    # NOTE(review): the grid built above is discarded here, and
    # `results_individual` is not defined in this function - it must exist as
    # a module-level global for this line to run; TODO confirm
    X = df[results_individual['feat_names_selected']]
    preds = m.predict(X)  # currently unused
def plot_example(ex)
-
ex - row of the dataframe
Expand source code
def plot_example(ex):
    '''Plot the X (clathrin) and Y (auxilin) curves of one track on a new figure

    Params
    ------
    ex
        row of the dataframe
    '''
    plt.figure(dpi=200)
    curves = (('X', 'red', 'clathrin'),
              ('Y', 'green', 'auxilin'))
    for key, color, label in curves:
        plt.plot(ex[key], color=color, label=label)
    plt.xlabel('Time')
    plt.ylabel('Amplitude')
    plt.legend()
def plot_kymographs(df, pids, add_px=2)
-
plot kymographs of dynamin traces
Params:
df: pd.DataFrame dataframe
pids: list list of pids to plot
add_px: int number of additional pixels in each direction; add_px=1 means 3×3 pixels around the center, add_px=2 means 5×5, etc.
Returns:
cla_traces: np.array clathrin traces from raw images aux_traces: np.array auxilin traces from raw images rgb_image: 3d np.array 3d array (RGB values) of kymographs
Expand source code
def plot_kymographs(df, pids, add_px=2):
    """
    plot kymographs of dynamin traces

    For every requested track the clathrin and auxilin intensities around the
    track position are read from the raw videos, one row per frame, and
    rescaled with the min / max of the whole frame. The traces are then laid
    out side by side and converted to RGB values.

    Params:
    ------
    df: pd.DataFrame
        dataframe with the columns pid, cell_num, x_pos_seq, y_pos_seq and t
    pids: list
        list of pids to plot
    add_px: int
        number of additional pixels on each side of the track position along x;
        every frame contributes 2 * add_px + 1 pixels taken from the single
        row at the track's y position

    Returns:
    ------
    cla_traces: dict
        cla_traces[i] is an array of shape (lmax, 2 * add_px + 1) with the
        clathrin trace of the i-th requested pid
    aux_traces: dict
        same for auxilin
    rgb_image: 3d np.array
        RGB values of the kymographs, shape (ncol, lmax, 3) with
        ncol = 3 * (2 * add_px + 1) * number of tracks
    """
    indices = np.array([np.where(df.pid.values == pid)[0][0] for pid in pids])
    df = df.iloc[indices]
    cells = set(df.cell_num.values)
    raw_videos = get_all_dynamin_videos(cells)

    # NOTE(review): cm.get_cmap was deprecated in Matplotlib 3.7 and removed in
    # 3.9 - confirm the pinned Matplotlib version
    viridis = cm.get_cmap('viridis', 12)  # only referenced by commented-out code below
    reds = cm.get_cmap('Reds', 12) # red palette for clathrin
    greens = cm.get_cmap('Greens', 12) # green palette for auxilin

    # number of rows: longest track plus 2 (rows past the end of a track stay 0)
    lmax = max([len(df.x_pos_seq.iloc[i]) for i in range(len(df))]) + 2
    width = 2 * add_px + 1
    cla_traces, aux_traces = {}, {}
    for i in range(len(df)):
        cla_traces[i], aux_traces[i] = np.zeros((lmax, width)), np.zeros((lmax, width))
        #xmean = X[df.cell_num.iloc[i]].mean()
        cell_num = df.cell_num.iloc[i]
        x_pos, y_pos = df.x_pos_seq.iloc[i], df.y_pos_seq.iloc[i]
        # presumably t is in seconds and t / 1.5 is the frame index of 1.5 s videos - TODO confirm
        t, lt = df.t.iloc[i], min(len(x_pos), len(y_pos))
        for k in range(-add_px, add_px + 1):
            for j in range(lt):
                # clathrin: pixel at offset k from the track position, rescaled
                # with the min / max of the whole frame
                video = raw_videos[cell_num]['cla']
                cla_traces[i][j, k + add_px] = max(video[int(t/1.5) + j,
                                                         range(int(y_pos[j]) - 0, int(y_pos[j]) + 0 + 1),
                                                         int(x_pos[j] + k)])
                vmin, vmax = video[int(t/1.5) + j,:,:].min(), video[int(t/1.5) + j,:,:].max()
                cla_traces[i][j, k + add_px] = (cla_traces[i][j, k + add_px] - vmin)/(vmax - vmin)

                # auxilin: same procedure
                video = raw_videos[cell_num]['aux']
                aux_traces[i][j, k + add_px] = max(video[int(t/1.5) + j,
                                                         range(int(y_pos[j]) - 0, int(y_pos[j]) + 0 + 1),
                                                         int(x_pos[j] + k)])
                vmin, vmax = video[int(t/1.5) + j,:,:].min(), video[int(t/1.5) + j,:,:].max()
                aux_traces[i][j, k + add_px] = (aux_traces[i][j, k + add_px] - vmin)/(vmax - vmin)

    # every track gets 3 * width columns: clathrin, auxilin, and an empty spacer
    ncol = 3 * width * len(df)
    cla_sparse = np.zeros((lmax, ncol))
    aux_sparse = np.zeros((lmax, ncol))
    for i in range(len(df)):
        start_index = 3 * width * i
        cla_sparse[:, (start_index):(start_index + width)] = cla_traces[i]
        aux_sparse[:, (start_index + width):(start_index + 2 * width)] = aux_traces[i]

    # values strictly between 0 and 1 are colored (clathrin first, then
    # auxilin), everything else is white; the outer loop runs over the columns,
    # so the first axis of the result is the column index
    rgb_image = np.array([[list(reds(cla_sparse[i][j])[:3])
                           # alternative: [1, 1 - cla_sparse[i][j], 1 - cla_sparse[i][j]]
                           if 0 < cla_sparse[i][j] < 1
                           else
                           list(greens(aux_sparse[i][j])[:3])
                           # alternative: [1 - aux_sparse[i][j], 1, 1 - aux_sparse[i][j]]
                           if 0 < aux_sparse[i][j] < 1
                           # alternative: else list(viridis(np.random.choice(background, 1)[0])[:3])
                           else (1, 1, 1) for i in range(lmax)]
                          for j in range(ncol)])
    #cla_sparse = np.transpose(cla_sparse)
    #aux_sparse = np.transpose(aux_sparse)
    return cla_traces, aux_traces, rgb_image
def plot_pcs(pca, X)
-
Pretty plot of pcs with explained var bars Params
pca
:sklearn
PCA
class
after
being
fitted
Expand source code
def plot_pcs(pca, X):
    '''Pretty plot of pcs with explained var bars

    The top panel shows the explained variance (%) of every component as a
    bar, the bottom panel shows the component loadings as a heatmap centered
    at 0 with a colorbar. The figure is shown at the end.

    Params
    ------
    pca: sklearn PCA class after being fitted
    X
        data the pca was fitted on; `list(X)` provides the row labels of the
        heatmap (the column names for a DataFrame)
    '''
    plt.figure(figsize=(6, 9), dpi=200)

    # extract out relevant pars
    comps = pca.components_.transpose()
    var_norm = pca.explained_variance_ / np.sum(pca.explained_variance_) * 100

    # create a 2 X 2 grid
    gs = grd.GridSpec(2, 2, height_ratios=[2, 10],
                      width_ratios=[12, 1], wspace=0.1, hspace=0)

    # plot explained variance
    ax2 = plt.subplot(gs[0])
    ax2.bar(np.arange(0, comps.shape[1]), var_norm,
            color='gray', width=0.8)
    plt.title('Explained variance (%)')
    ax2.spines['right'].set_visible(False)
    ax2.spines['top'].set_visible(False)
    ax2.yaxis.set_ticks_position('left')
    ax2.set_yticks([0, max(var_norm)])
    plt.xlim((-0.5, comps.shape[1] - 0.5))

    # plot pcs
    ax = plt.subplot(gs[2])
    vmaxabs = np.max(np.abs(comps))
    p = ax.imshow(comps, interpolation='None', aspect='auto',
                  cmap=sns.diverging_palette(10, 240, as_cmap=True, center='light'),
                  vmin=-vmaxabs, vmax=vmaxabs)  # center at 0
    plt.xlabel('PCA component number')
    # fix the tick positions before assigning labels, so that every label is
    # tied to its row
    feat_names = list(X)
    ax.set_yticks(range(len(feat_names)))
    ax.set_yticklabels(feat_names)

    # make colorbar
    colorAx = plt.subplot(gs[3])
    plt.colorbar(p, cax=colorAx)
    plt.show()
def print_metadata(acc=None, metadata_file='/accounts/projects/vision/chandan/auxilin-prediction/src/../data/processed/metadata_clath_aux+gak_a7d2.pkl')
-
Expand source code
def print_metadata(acc=None, metadata_file=oj(config.DIR_PROCESSED, 'metadata_clath_aux+gak_a7d2.pkl')):
    '''Print a summary of the track counts stored in a metadata pickle

    Params
    ------
    acc: float
        if given, it is printed as the accuracy on the hard tracks
    metadata_file: str
        path to a pickled dict with the keys read below
    '''
    # NOTE: pickle can execute arbitrary code while loading - only point this
    # at trusted files
    with open(metadata_file, 'rb') as f:  # closes the file handle (was left open)
        m = pkl.load(f)

    print(
        f'valid:\t\t{m["num_aux_pos_valid"]:>4.0f} aux+ / {m["num_tracks_valid"]:>4.0f} ({m["num_aux_pos_valid"] / m["num_tracks_valid"]:.3f})')
    print('----------------------------------------')
    # NOTE(review): both numbers on this line come from num_hotspots_valid - confirm this is intended
    print(f'hotspots:\t{m["num_hotspots_valid"]:>4.0f} aux+ / {m["num_hotspots_valid"]:>4.0f}')
    print(
        f'short:\t\t{m["num_short"] - m["num_short"] * m["acc_short"]:>4.0f} aux+ / {m["num_short"]:>4.0f} ({m["acc_short"]:.3f})')
    print(f'long:\t\t{m["num_long"] * m["acc_long"]:>4.0f} aux+ / {m["num_long"]:>4.0f} ({m["acc_long"]:.3f})')
    print(
        f'hard:\t\t{m["num_aux_pos_hard"]:>4.0f} aux+ / {m["num_tracks_hard"]:>4.0f} ({m["num_aux_pos_hard"] / m["num_tracks_hard"]:.3f})')

    if acc is not None:
        print('----------------------------------------')
        print(f'hard acc:\t\t\t {acc:.3f}')
    print('\nlifetime threshes', m['thresh_short'], m['thresh_long'])
def savefig(s, png=False)
-
Expand source code
def savefig(s: str, png=False):
    '''Save the current figure as DIR_FIGS/fig_<s>.pdf, and optionally as .png

    Params
    ------
    s
        name of the figure (without the 'fig_' prefix and the extension)
    png
        also write a 300 dpi png next to the pdf
    '''
    stem = oj(DIR_FIGS, 'fig_' + s)
    plt.savefig(stem + '.pdf', bbox_inches='tight')
    if not png:
        return
    plt.savefig(stem + '.png', dpi=300, bbox_inches='tight')
def viz_biggest_errs(df, idxs_cv, idxs, Y_test, preds, preds_proba, num_to_plot=20, aux_thresh=642, show_track_num=True, show_track_pid=False, sort_by_residuals=True, width_mult=3, plot_x=True, plot_y=True, plot_z=False, plot_axhline=True, xlim_constant=True, ylim=None, yticks=None, yticklabels=None, lifetime_max=None, text_labels=False, text_label_size=25)
-
Visualize X and Y where the top examples are the most wrong / least confident Params
idxs_cv
:integer
ndarray
- which idxs are not part of the test set (usually just 0, 1, 2, …)
idxs
:boolean
ndarray
- subset of points to plot
Expand source code
def viz_biggest_errs(df, idxs_cv, idxs, Y_test, preds, preds_proba, num_to_plot=20, aux_thresh=642, show_track_num=True, show_track_pid=False, sort_by_residuals=True, width_mult=3, plot_x=True, plot_y=True, plot_z=False, plot_axhline=True, xlim_constant=True, ylim: tuple=None, yticks=None, yticklabels=None, lifetime_max=None, text_labels=False, text_label_size=25):
    '''Visualize X and Y where the top examples are the most wrong / least confident

    Tracks are drawn on a grid of R x C subplots with
    R = int(sqrt(num_to_plot)) and C = num_to_plot // R, so at most R * C
    (<= num_to_plot) tracks are shown.

    Params
    ------
    df: pd.DataFrame
        one track per row; uses the columns cell_num, lifetime, X, Y and
        optionally Z and pid
    idxs_cv: integer ndarray
        which idxs are not part of the test set (usually just 0, 1, 2, ...)
    idxs: boolean ndarray
        subset of points to plot
    Y_test, preds, preds_proba: np.ndarray
        true labels, predicted labels and predicted probabilities
        (`preds` is subset by idxs but not used otherwise)
    num_to_plot: int
        number of tracks to plot (None: all)
    aux_thresh: float
        height of the horizontal line drawn when plot_axhline is True
    show_track_num, show_track_pid: bool
        write the plot index (or else the pid) into each subplot
    sort_by_residuals: bool
        sort tracks by decreasing |Y_test - preds_proba|
    width_mult: float
        width of one subplot in the figure size
    plot_x, plot_y, plot_z: bool
        which curves to draw
    xlim_constant: bool
        use [-1, lifetime_max] as x limits everywhere
    ylim: tuple
        y limits
    yticks, yticklabels
        y ticks / labels of the first column
    lifetime_max: float
        defaults to the maximum lifetime among the selected tracks
    text_labels: bool
        write the curve names at the end of the curves
    text_label_size: int
        font size of these labels

    Returns
    -------
    dft: pd.DataFrame
        the selected (and possibly sorted) rows of df
    '''
    DIFF = 0 # use this to ensure values are all positive

    # deal with idxs
    if idxs is not None:
        Y_test = Y_test[idxs]
        preds = preds[idxs]
        preds_proba = preds_proba[idxs]
        if idxs_cv is None:
            idxs_cv = np.arange(df.shape[0])
        df = df.iloc[idxs_cv][idxs]

    # get args to sort by
    if sort_by_residuals:
        residuals = np.abs(Y_test - preds_proba)
        args = np.argsort(residuals)[::-1]  # largest residual first
        dft = df.iloc[args]
    else:
        dft = df
    if lifetime_max is None:
        lifetime_max = np.max(dft.lifetime.values)
    if num_to_plot is None:
        num_to_plot = dft.shape[0]

    # NOTE(review): R is 0 when num_to_plot is 0, which makes the division below fail
    R = int(np.sqrt(num_to_plot))
    C = num_to_plot // R # + 1
    plt.figure(figsize=(C * width_mult, R * 2.5), dpi=200)

    i = 0
    for r in range(R):
        for c in range(C):
            if i < dft.shape[0]:
                row = dft.iloc[i]
                ax = plt.subplot(R, C, i + 1)

                # show nums on tracks
                if show_track_num:
                    ax.text(.5, .9, f'{i}', # row.pid
                            horizontalalignment='right',
                            transform=ax.transAxes)
                elif show_track_pid:
                    ax.text(.5, .9, f'{row.pid}', # row.pid
                            horizontalalignment='right',
                            transform=ax.transAxes)
                # plt.axis('off')

                # spacing between samples on the x axis, chosen from the cell name
                if '1.5s' in row['cell_num']:
                    interval = 1.5
                else:
                    interval = 1
                if plot_x:
                    plt.plot(interval * np.arange(len(row["X"])), np.array(row["X"]) + DIFF, color=cr, label='clath', lw=2) # could do X_extended
                if plot_y:
                    plt.plot(interval * np.arange(len(row["Y"])), np.array(row["Y"]) + DIFF, color=cg, label='aux', lw=2)
                if plot_z:
                    plt.plot(interval * np.arange(len(row["Z"])), np.array(row["Z"]) + DIFF, color='gray', label='dyn')
                if xlim_constant:
                    plt.xlim([-1, lifetime_max])
                if plot_axhline:
                    plt.axhline(aux_thresh, color='gray', alpha=0.5, lw=2)
                #plt.yscale('log')
                if ylim is not None:
                    plt.ylim((ylim[0] + DIFF, ylim[1] + DIFF))

                # only the bottom row keeps x ticks, only the first column keeps y ticks
                if not r == R - 1:
                    plt.xticks([])
                if not c == 0:
                    plt.yticks([])
                elif yticks is not None:
                    plt.yticks(yticks, labels=yticklabels)
                i += 1

                # NOTE(review): the label x position is the number of samples,
                # not multiplied by `interval` like the curves - confirm
                if text_labels:
                    plt.text(len(row["X"]), row["X"][-1] + DIFF, 'Clathrin', color=cr, fontsize=text_label_size, fontweight='bold')
                    plt.text(len(row["Y"]), row["Y"][-1] + DIFF, 'Auxilin', color=cg, fontsize=text_label_size, fontweight='bold')
                    if plot_z:
                        plt.text(len(row["Z"]), row["Z"][-1] + DIFF, 'Dynamin', fontsize=text_label_size, color='gray', fontweight='bold')
    plt.tight_layout()
    return dft
def viz_errs_1d(X_test, preds, preds_proba, Y_test, norms, key='lifetime')
-
visualize errs based on lifetime
Expand source code
def viz_errs_1d(X_test, preds, preds_proba, Y_test, norms, key='lifetime'):
    '''Plot predicted probability against one (un-normalized) feature, split by outcome

    True positives / negatives are drawn in blue, false positives / negatives
    in red; predicted positives use 'o', predicted negatives 'x'... for the
    errors, false positives use 'o' and false negatives 'x'. The figure is
    shown at the end.

    Params
    ------
    X_test
        features; X_test[key] is un-normalized with norms[key]['std'] and
        norms[key]['mu']
    preds, preds_proba, Y_test: np.ndarray
        predicted labels, predicted probabilities, true labels
    norms: dict
    key: str
        feature shown on the x axis
    '''
    plt.figure(dpi=200)
    is_correct = preds == Y_test
    feat = X_test[key] * norms[key]['std'] + norms[key]['mu']

    # (mask, marker, color, label), drawn in this order
    layers = (
        (is_correct & (preds == 1), 'o', cb, 'true pos'),
        (is_correct & (preds == 0), 'x', cb, 'true neg'),
        (preds > Y_test, 'o', cr, 'false pos'),
        (preds < Y_test, 'x', cr, 'false neg'),
    )
    for mask, marker, color, label in layers:
        plt.plot(feat[mask], preds_proba[mask], marker, color=color, alpha=0.5, label=label)

    plt.xlabel(key)
    plt.ylabel('predicted probability')
    plt.legend()
    plt.show()
def viz_errs_2d(df, idxs_test, preds, Y_test, key1='x_pos', key2='y_pos', X=None, plot_correct=True)
-
visualize distribution of errs wrt to 2 dimensions
Expand source code
def viz_errs_2d(df, idxs_test, preds, Y_test, key1='x_pos', key2='y_pos', X=None, plot_correct=True):
    '''Scatter the test points in the plane of two columns, split by outcome

    Predicted positives are blue, predicted negatives... more precisely: true
    and false positives are blue, true and false negatives are red; correct
    points use 'o', errors use 'x'.

    Params
    ------
    df: pd.DataFrame
    idxs_test
        positions of the test points in df
    preds, Y_test: np.ndarray
        predicted and true labels of the test points
    key1, key2: str
        columns shown on the x / y axis
    X
        unused
    plot_correct: bool
        whether to draw the correctly classified points as well
    '''
    xs = df[key1].iloc[idxs_test]
    ys = df[key2].iloc[idxs_test]

    plt.figure(dpi=200)
    ms = 4

    # (mask, marker, color, label, edge width), drawn in this order
    layers = []
    if plot_correct:
        layers.append(((preds == Y_test) & (preds == 1), 'o', cb, 'true pos', 0))
        layers.append(((preds == Y_test) & (preds == 0), 'o', cr, 'true neg', 0))
    layers.append((preds > Y_test, 'x', cb, 'false pos', 1))
    layers.append((preds < Y_test, 'x', cr, 'false neg', 1))

    for mask, marker, color, label, edge_width in layers:
        plt.plot(xs[mask], ys[mask], marker, color=color, alpha=0.4,
                 label=label, ms=ms, markeredgewidth=edge_width)

    plt.legend()
    plt.xlabel(key1)
    plt.ylabel(key2)
    plt.tight_layout()
def viz_errs_outliers_venn(X_test, preds, Y_test, num_feats_reduced=5)
-
Compare outliers to errors in venn-diagram
Expand source code
def viz_errs_outliers_venn(X_test, preds, Y_test, num_feats_reduced=5):
    '''Compare outliers to errors in venn-diagrams

    Four outlier detectors are fitted on the (optionally PCA-reduced)
    features; for each of them a venn diagram of the points flagged as
    outliers versus the misclassified points is drawn on a 2 x 2 grid.

    Params
    ------
    X_test
        test data; the feature columns are selected with data.get_feature_names
    preds, Y_test: np.ndarray
        predicted and true labels
    num_feats_reduced: int
        number of PCA components to reduce the features to (None: no reduction)
    '''
    X_feat = X_test[data.get_feature_names(X_test)]
    if num_feats_reduced is None:
        X_reduced = X_feat
    else:
        reducer = decomposition.PCA(n_components=num_feats_reduced)
        X_reduced = reducer.fit_transform(X_feat)

    # (title, factory): detectors are created lazily, one per subplot
    detectors = [
        ('isolation forest', lambda: IsolationForest(n_estimators=10, warm_start=True)),
        ('local outlier factor', lambda: LocalOutlierFactor(novelty=True)),
        ('elliptic envelop', lambda: EllipticEnvelope()),
        ('one-class svm', lambda: OneClassSVM()),
    ]
    R, C = 2, 2
    plt.figure(figsize=(6, 5), dpi=200)
    for panel, (title, make_clf) in enumerate(detectors):
        plt.subplot(R, C, panel + 1)
        plt.title(title)
        clf = make_clf()
        clf.fit(X_reduced)
        is_outlier = clf.predict(X_reduced) == -1  # -1 marks outliers
        is_err = preds != Y_test
        idxs = np.arange(is_outlier.size)
        venn2([set(idxs[is_outlier]), set(idxs[is_err])],
              set_labels=['outliers', 'errors'])