import pickle as pkl
import sys
import config
import data
from os.path import join as oj
import matplotlib.gridspec as grd
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib_venn import venn2
from sklearn import decomposition
from sklearn import metrics
from sklearn.covariance import EllipticEnvelope
from sklearn.ensemble import IsolationForest
from sklearn.neighbors import LocalOutlierFactor
from sklearn.svm import OneClassSVM
from sklearn.utils.multiclass import unique_labels
import os
import matplotlib.ticker as mtick
from config import DIR_FIGS
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import cm
from matplotlib.colors import ListedColormap
import dvu
# DIR_FILE = os.path.dirname(os.path.realpath(__file__)) # directory of this file
# DIR_FIGS = oj(DIR_FILE, '../reports/figs')
from skimage.external.tifffile import imread
from import imread
# Shared hex colour constants for the plotting helpers in this module.
# `cr` is used for the clathrin ("X") traces and `cg` for the auxilin ("Y") traces below.
cb2 = '#66ccff'
cb = '#1f77b4'
co = '#ff7f0e'
cr = '#cc0000'
cp = '#cc3399'
cy = '#d8b365'
cg = '#5ab4ac'
# Two-colour linear colormap built from RGB tuples scaled to [0, 1].
# NOTE(review): the call is incomplete as listed -- the colormap-name argument and the
# closing parenthesis are not visible here; confirm against the full file.
cmap = LinearSegmentedColormap.from_list(
colors=[(205/255, 85/255, 51/255),
(50/255, 129/255, 168/255)]
def savefig(s: str, png=False):
    """Save the current matplotlib figure into ``DIR_FIGS`` as ``fig_<s>.pdf``.

    Params
    ------
    s
        suffix of the output file name (files are named ``'fig_' + s``)
    png
        when truthy, a 300-dpi PNG is written next to the PDF as well
    """
    stem = oj(DIR_FIGS, 'fig_' + s)
    # a tight bounding box crops the surrounding whitespace in both outputs
    plt.savefig(stem + '.pdf', bbox_inches='tight')
    if not png:
        return
    plt.savefig(stem + '.png', dpi=300, bbox_inches='tight')
def fix_feat_name(s: str) -> str:
    """Turn a raw feature name into a human-readable label.

    Underscores become spaces, every ``'X'`` is replaced by ``'Clath'`` and the
    result goes through ``str.capitalize`` (first character upper-cased, all
    remaining characters lower-cased).

    Params
    ------
    s
        raw feature name, e.g. ``'X_max'``

    Returns
    -------
    str
        the prettified label, e.g. ``'Clath max'``
    """
    return s.replace('_', ' ').replace('X', 'Clath').capitalize()
def plot_confusion_matrix(y_true, y_pred, classes,
                          normalize=False,
                          title=None,
                          cmap=cmap):
    """
    Plot the confusion matrix of `y_pred` against `y_true` as an annotated heatmap.
    Normalization can be applied by setting `normalize=True`.

    Params
    ------
    y_true, y_pred
        true and predicted labels; they are cast to int and used to index `classes`
    classes: np.ndarray(Str)
        classes=np.array(['aux-', 'aux+'])
    normalize
        if True, every row is divided by its sum (rows with a zero sum stay zero)
    title
        accepted for backward compatibility; no title is drawn
    cmap
        colormap handed to `plt.imshow` (defaults to the module-level `cmap`)

    Returns
    -------
    ax
        the matplotlib axes holding the heatmap
    """
    plt.figure(dpi=300)

    # named `conf_mat` rather than `cm` so the module-level `matplotlib.cm` import is not shadowed
    conf_mat = metrics.confusion_matrix(y_true, y_pred)

    # Only use the labels that appear in the data
    # (builtin `int` is used: the `np.int` alias no longer exists in current NumPy)
    present = unique_labels(np.asarray(y_true).astype(int), np.asarray(y_pred).astype(int))
    classes = np.asarray(classes)[present]

    if normalize:
        row_sums = conf_mat.sum(axis=1)[:, np.newaxis].astype('float')
        row_sums[row_sums == 0] = 1  # avoid NaN for classes that never occur in y_true
        conf_mat = conf_mat.astype('float') / row_sums

    plt.imshow(conf_mat, interpolation='nearest', cmap=cmap)
    ax = plt.gca()

    # We want to show all ticks and label them with the respective list entries
    ax.set(xticks=np.arange(conf_mat.shape[1]),
           yticks=np.arange(conf_mat.shape[0]),
           xticklabels=classes, yticklabels=classes,
           ylabel='True label',
           xlabel='Predicted label')

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Annotate every cell; the text colour flips at half the maximum so it stays readable.
    fmt = '.2f' if normalize else 'd'
    thresh = conf_mat.max() / 2.
    for i, j in np.ndindex(conf_mat.shape):
        ax.text(j, i, format(conf_mat[i, j], fmt),
                ha="center", va="center",
                color="white" if conf_mat[i, j] > thresh else "black")
    return ax
def highlight_max(data, color='#0e5c99'):
    '''
    highlight the maximum in a Series or DataFrame

    Params
    ------
    data
        pandas Series (1-d) or DataFrame (2-d)
    color
        CSS colour used in the emitted ``background-color`` declaration

    Returns
    -------
    list of str for 1-d input, otherwise a DataFrame with the index/columns of
    `data`; entries are ``'background-color: <color>'`` where the value equals
    the maximum and ``''`` elsewhere.
    '''
    attr = 'background-color: {}'.format(color)
    if data.ndim == 1:  # Series from .apply(axis=0) or axis=1
        is_max = data == data.max()
        return [attr if flag else '' for flag in is_max]
    # from .apply(axis=None): compare against the maximum over the whole frame
    is_max = data == data.max().max()
    return pd.DataFrame(np.where(is_max, attr, ''),
                        index=data.index, columns=data.columns)
# visualize biggest errs
# Plot a grid of per-track traces ("X" labelled 'clath', "Y" labelled 'aux', optionally
# "Z" labelled 'dyn'); when sorting by residuals, rows with the largest
# |Y_test - preds_proba| come first. Returns the (possibly re-ordered) frame `dft`.
# NOTE(review): this listing looks incomplete -- the signature is cut off after `ylim`
# (sort_by_residuals, num_to_plot, lifetime_max, width_mult, plot_x/plot_y/plot_z,
# show_track_num, show_track_pid, xlim_constant, plot_axhline, aux_thresh, yticks,
# yticklabels, text_labels, text_label_size are used but not declared here), the docstring
# is never closed, and several `else:` / statement lines appear to be missing.
def viz_biggest_errs(df, idxs_cv, idxs, Y_test, preds, preds_proba,
ylim: tuple=None,
'''Visualize X and Y where the top examples are the most wrong / least confident
idxs_cv: integer ndarray
which idxs are not part of the test set (usually just 0, 1, 2, ...)
idxs: boolean ndarray
subset of points to plot
DIFF = 0 # use this to ensure values are all positive
# deal with idxs
if idxs is not None:
Y_test = Y_test[idxs]
preds = preds[idxs]
preds_proba = preds_proba[idxs]
if idxs_cv is None:
idxs_cv = np.arange(df.shape[0])
# positional selection by idxs_cv first, then the `idxs` subset
df = df.iloc[idxs_cv][idxs]
# get args to sort by
if sort_by_residuals:
residuals = np.abs(Y_test - preds_proba)
# descending order: largest residual first
args = np.argsort(residuals)[::-1]
dft = df.iloc[args]
# NOTE(review): as listed this assignment is unconditional and would discard the sorted
# frame; an `else:` line is presumably missing just above.
dft = df
if lifetime_max is None:
lifetime_max = np.max(dft.lifetime.values)
if num_to_plot is None:
num_to_plot = dft.shape[0]
# grid of R rows by C columns; R * C may be smaller than num_to_plot
R = int(np.sqrt(num_to_plot))
C = num_to_plot // R # + 1
plt.figure(figsize=(C * width_mult, R * 2.5), dpi=200)
i = 0
for r in range(R):
for c in range(C):
if i < dft.shape[0]:
row = dft.iloc[i]
ax = plt.subplot(R, C, i + 1)
# show nums on tracks
# NOTE(review): both ax.text(...) calls below are cut off, and the f-string expression
# in the second one is empty as listed.
if show_track_num:
ax.text(.5, .9, f'{i}', #
elif show_track_pid:
ax.text(.5, .9, f'{}', #
# plt.axis('off')
# time step: 1.5 when cell_num contains '1.5s'; `interval = 1` presumably belongs to a
# missing else-branch
if '1.5s' in row['cell_num']:
interval = 1.5
interval = 1
if plot_x:
plt.plot(interval * np.arange(len(row["X"])), np.array(row["X"]) + DIFF, color=cr, label='clath', lw=2) # could do X_extended
if plot_y:
plt.plot(interval * np.arange(len(row["Y"])), np.array(row["Y"]) + DIFF, color=cg, label='aux', lw=2)
if plot_z:
plt.plot(interval * np.arange(len(row["Z"])), np.array(row["Z"]) + DIFF, color='gray', label='dyn')
if xlim_constant:
plt.xlim([-1, lifetime_max])
if plot_axhline:
plt.axhline(aux_thresh, color='gray', alpha=0.5, lw=2)
if ylim is not None:
plt.ylim((ylim[0] + DIFF, ylim[1] + DIFF))
# NOTE(review): the bodies of the next two conditions are not visible in this listing
if not r == R - 1:
if not c == 0:
elif yticks is not None:
plt.yticks(yticks, labels=yticklabels)
i += 1
# text labels are placed at the last sample of each trace of `row`
if text_labels:
plt.text(len(row["X"]), row["X"][-1] + DIFF, 'Clathrin', color=cr,
fontsize=text_label_size, fontweight='bold')
plt.text(len(row["Y"]), row["Y"][-1] + DIFF, 'Auxilin', color=cg,
fontsize=text_label_size, fontweight='bold')
if plot_z:
plt.text(len(row["Z"]), row["Z"][-1] + DIFF, 'Dynamin',
fontsize=text_label_size, color='gray', fontweight='bold')
return dft
def viz_errs_2d(df, idxs_test, preds, Y_test, key1='x_pos', key2='y_pos', X=None, plot_correct=True):
    '''visualize distribution of errs wrt to 2 dimensions

    Params
    ------
    df
        dataframe providing the columns `key1` and `key2`
    idxs_test
        positional indexes (used with `.iloc`) of the evaluated rows
    preds, Y_test
        predicted and true labels; they are compared to 0 / 1 and to each other
    key1, key2
        columns plotted on the horizontal / vertical axis
    X
        unused; kept so existing callers keep working
    plot_correct
        if True, correctly classified points are drawn too (circles);
        misclassified points are always drawn (crosses)
    '''
    x_pos = df[key1].iloc[idxs_test]
    y_pos = df[key2].iloc[idxs_test]
    ms = 4

    if plot_correct:
        correct = preds == Y_test
        true_pos = correct & (preds == 1)
        true_neg = correct & (preds == 0)
        plt.plot(x_pos[true_pos], y_pos[true_pos], 'o',
                 color=cb, alpha=0.4, label='true pos', ms=ms, markeredgewidth=0)
        plt.plot(x_pos[true_neg], y_pos[true_neg], 'o',
                 color=cr, alpha=0.4, label='true neg', ms=ms, markeredgewidth=0)

    # prediction above the truth -> false positive, below -> false negative
    false_pos = preds > Y_test
    false_neg = preds < Y_test
    plt.plot(x_pos[false_pos], y_pos[false_pos], 'x', color=cb,
             alpha=0.4, label='false pos', ms=ms, markeredgewidth=1)
    plt.plot(x_pos[false_neg], y_pos[false_neg], 'x', color=cr,
             alpha=0.4, label='false neg', ms=ms, markeredgewidth=1)
def viz_errs_1d(X_test, preds, preds_proba, Y_test, norms, key='lifetime'):
    '''visualize errs based on lifetime

    Params
    ------
    X_test
        table holding the column `key`
    preds, preds_proba, Y_test
        predicted labels, predicted probabilities and true labels
    norms
        ``norms[key]['std']`` and ``norms[key]['mu']`` are used to rescale `key`
    key
        column plotted on the horizontal axis
    '''
    correct = preds == Y_test
    # rescale the column: value * std + mu
    lifetime = X_test[key] * norms[key]['std'] + norms[key]['mu']

    # (mask, marker, colour, label): circles = predicted 1 / over-prediction,
    # crosses = predicted 0 / under-prediction
    groups = [
        (correct & (preds == 1), 'o', cb, 'true pos'),
        (correct & (preds == 0), 'x', cb, 'true neg'),
        (preds > Y_test, 'o', cr, 'false pos'),
        (preds < Y_test, 'x', cr, 'false neg'),
    ]
    for mask, marker, color, label in groups:
        plt.plot(lifetime[mask], preds_proba[mask], marker,
                 color=color, alpha=0.5, label=label)
    plt.ylabel('predicted probability')
def plot_above_threshold(x1, y1, b1, x2, y2, b2, ax, color, lsty):
    """Draw the segment (x1, y1)-(x2, y2), opaque where it is on/above a threshold line.

    The threshold runs linearly from (x1, b1) to (x2, b2). The part of the
    segment with y >= threshold is drawn with alpha=1, the part below it with
    alpha=.1. If the segment crosses the threshold it is split at the crossing
    point and drawn as two pieces.

    Params
    ------
    x1, y1, x2, y2
        end points of the segment
    b1, b2
        threshold value at x1 and at x2
    ax
        object with a matplotlib-style ``plot`` method
    color, lsty
        colour and line style passed to ``ax.plot``
    """
    def _draw(xs, ys, alpha):
        ax.plot(xs, ys, linestyle=lsty, color=color, alpha=alpha)

    start_above, start_below = y1 >= b1, y1 < b1
    end_above, end_below = y2 >= b2, y2 < b2

    if start_above and end_above:
        _draw([x1, x2], [y1, y2], 1)
    elif start_below and end_below:
        _draw([x1, x2], [y1, y2], .1)
    elif (start_above and end_below) or (start_below and end_above):
        # Slopes are only needed (and only well defined) when the segment
        # really crosses the threshold, so they are computed here rather than
        # up front: a vertical segment (x1 == x2) that does not cross no
        # longer triggers a division by zero.
        sl1 = (y2 - y1) / (x2 - x1)
        sl2 = (b2 - b1) / (x2 - x1)
        crosspoint_x = x1 + (y1 - b1) / (sl2 - sl1)
        crosspoint_y = y1 + sl1 * (y1 - b1) / (sl2 - sl1)
        first_alpha, second_alpha = (1, .1) if start_above else (.1, 1)
        _draw([x1, crosspoint_x], [y1, crosspoint_y], first_alpha)
        _draw([crosspoint_x, x2], [crosspoint_y, y2], second_alpha)
    # if none of the comparisons hold (e.g. NaN input) nothing is drawn
def plot_background(interval, bg, trace, color, ax):
    """Shade the band between 0 and ``2 * bg`` and draw `trace` relative to it.

    The trace is drawn segment by segment through `plot_above_threshold`, using
    ``2 * bg`` as the threshold. The first five and the last five segments are
    dashed, the rest solid.

    Params
    ------
    interval
        multiplier applied to sample indexes to get x positions
    bg
        sequence of background values, needs at least ``len(trace)`` entries
    trace
        sequence of trace values
    color
        colour for the shaded band and the trace
    ax
        matplotlib axes to draw on
    """
    threshold = 2 * np.array(bg)
    ax.fill_between(interval * np.arange(len(bg)),
                    [0] * len(bg),
                    threshold,
                    alpha=.1, color=color)

    y = np.array(trace)
    lt = len(trace)
    for f in range(lt - 1):
        lsty = '--' if f < 5 or f >= lt - 5 else '-'
        plot_above_threshold(x1=interval * f, y1=y[f], b1=threshold[f],
                             x2=interval * (f + 1), y2=y[f + 1], b2=threshold[f + 1],
                             ax=ax, color=color, lsty=lsty)
# Draw an R x C grid of per-track intensity curves taken from the first R * C rows of `df`.
# With num_axes == 1 the curves share one y-axis via plt.plot; otherwise clathrin (X) is drawn
# on `ax`, auxilin (Y) on the twin axis `twin1` and the Z channel on `twin2`.
# NOTE(review): this listing appears incomplete -- the signature is cut off (names such as
# `xlim`, `ylim` and `axes_invisible` are used but never declared here), the docstring is not
# closed, and a number of `else:` lines and call continuations seem to be missing. Confirm
# against the full file before relying on the control flow shown.
def plot_curves(df, extra_key=None, extra_key_label=None,
hline=True, R=5, C=8,
fig=None, ylim_constant=False, background=False, ylim_cla=None,
ylim_aux=None, ylim_dyn=None,
xlim_constant=True, legend=True, plot_x=True,
yticks=None, yticklabels=None, num_axes=3, show_track_pid=False,
'''Plot time-series curves from df
DIFF = 0
if fig is None:
plt.figure(figsize=(16, 10), dpi=200, facecolor='white')
# longest lifetime among the plotted tracks, used for a common x-range
lifetime_max = np.max(df.lifetime.values[:R * C])
df = df.iloc[range(R * C)]
for i in range(R * C):
if i < df.shape[0]:
ax = plt.subplot(R, C, i + 1)
row = df.iloc[i]
# time step: 1.5 when cell_num contains '1.5s'; `interval = 1` presumably belongs to a
# missing else-branch
if '1.5s' in row['cell_num']:
interval = 1.5
interval = 1
if num_axes == 1:
# dashed line over the whole extended trace, solid line over samples 5 .. len-5
if plot_x:
plt.plot(interval * np.arange(len(row.X_extended)), np.array(row.X_extended) + DIFF, linestyle='--', color=cr)
plt.plot(interval * np.arange(len(row.Y_extended)), np.array(row.Y_extended) + DIFF, linestyle='--', color=cg)
plt.plot(interval * np.arange(5, len(row.X_extended)-5), np.array(row.X_extended)[5:(-5)] + DIFF, color=cr, label='Clathrin')
plt.plot(interval * np.arange(5, len(row.Y_extended)-5), np.array(row.Y_extended)[5:(-5)] + DIFF, color=cg, label='Auxilin')
#plt.plot(interval * np.arange(5), np.array(row.X_extended)[-5:] + DIFF, linestyle='--', color=cr, label='Clathrin')
#plt.plot(interval * np.arange(5), np.array(row.Y_extended)[-5:] + DIFF, linestyle='--', color=cg, label='Auxilin')
if background:
ax.plot(interval * np.arange(len(row.X_extended)), np.array(row.X_c_extended),
color=cr, linewidth=.8)
# band of +/- X_std_extended around the trace (the call is cut off in this listing)
ax.fill_between(interval * np.arange(len(row.X_extended)),
np.array(row.X_extended) - np.array(row.X_std_extended),
np.array(row.X_extended) + np.array(row.X_std_extended),
# NOTE(review): hard-coded horizontal reference level; its meaning is not stated here
if hline:
plt.axhline(642.3754691658837, color='gray', alpha=0.5)
if extra_key is not None:
if extra_key_label is None:
if extra_key == 'Z':
extra_key_label = 'Dynamin'
extra_key_label = extra_key
plt.plot(interval * np.arange(len(row[extra_key])), np.array(row[extra_key]) + DIFF, linestyle='--', color='gray')
plt.plot(interval * np.arange(5, len(row[extra_key])-5), np.array(row[extra_key])[5:(-5)] + DIFF, color='gray', label=extra_key_label)
if xlim_constant:
if xlim is None:
plt.xlim([-1, lifetime_max + 1])
if ylim_constant:
if ylim is None:
plt.ylim([-10, max(max(df.X_max), max(df.Y_max)) + 1])
plt.ylim(ylim[0] + DIFF, ylim[1] + DIFF)
if yticks is not None:
plt.yticks(yticks, labels=yticklabels)
# NOTE(review): from here on the code uses twin axes; presumably this is the branch for
# num_axes != 1 (the `else:` line is not visible)
twin1 = ax.twinx()
if num_axes == 3:
twin2 = ax.twinx()
# shift the third y-axis outwards so it does not overlap twin1's axis
twin2.spines['right'].set_position(("axes", 1.2))
twin2 = twin1
if show_track_pid:
ax.text(.5, .9, f'{}', #
if plot_x:
p1, = ax.plot(interval * np.arange(len(row.X_extended)), np.array(row.X_extended) + DIFF, linestyle='--', color=cr, alpha=.1)
if i == 0:
ax.text(x=interval * len(row.X_extended),
if background:
plot_background(interval, row.X_sigma_extended, row.X_extended, color=cr, ax=ax)
ax.plot(interval * np.arange(5, len(row.X_extended)-5), np.array(row.X_extended)[5:(-5)] + DIFF, color=cr)
if i == 0 and legend:
ax.fill_between(interval * np.arange(len(row.X_extended)),
np.array(row.X_extended) - np.array(row.X_std_extended),
np.array(row.X_extended) + np.array(row.X_std_extended),
p2, = twin1.plot(interval * np.arange(len(row.Y_extended)), np.array(row.Y_extended) + DIFF, linestyle='--', color=cg, alpha=.1)
if background:
plot_background(interval, row.Y_sigma_extended, row.Y_extended, color=cg, ax=twin1
twin1.plot(interval * np.arange(5, len(row.Y_extended)-5), np.array(row.Y_extended)[5:(-5)] + DIFF, color=cg, label='EGFP-Aux1-GAK-F6')
if i == 0 and legend:
twin1.fill_between(interval * np.arange(len(row.Y_extended)),
np.array(row.Y_extended) - np.array(row.Y_std_extended),
np.array(row.Y_extended) + np.array(row.Y_std_extended),
if i == 0:
twin1.text(x=interval * len(row.Y_extended),
#plt.plot(interval * np.arange(5), np.array(row.X_extended)[-5:] + DIFF, linestyle='--', color=cr, label='Clathrin')
#plt.plot(interval * np.arange(5), np.array(row.Y_extended)[-5:] + DIFF, linestyle='--', color=cg, label='Auxilin')
if hline:
ax.axhline(642.3754691658837, color='gray', alpha=0.5)
if extra_key is not None:
if extra_key_label is None:
if extra_key == 'Z':
extra_key_label = 'Dynamin'
extra_key_label = extra_key
# NOTE(review): this part reads row.Z_extended directly rather than row[extra_key]
p3, = twin2.plot(interval * np.arange(len(row.Z_extended)), np.array(row.Z_extended) + DIFF, linestyle='--', color='gray', alpha=.1)
if background:
plot_background(interval, row.Z_sigma_extended, row.Z_extended, color='gray', ax=twin2)
twin2.plot(interval * np.arange(5, len(row.Z_extended)-5), np.array(row.Z_extended)[5:(-5)] + DIFF, color='gray')
twin2.fill_between(interval * np.arange(len(row.Z_extended)),
np.array(row.Z_extended) - np.array(row.Z_std_extended),
np.array(row.Z_extended) + np.array(row.Z_std_extended),
if i == 0:
twin2.text(x=interval * len(row.Z_extended),
#if i == 0 and legend:
# dvu.line_legend()
# tick styling: each y-axis is coloured like the curve it belongs to
tkw = dict(size=4, width=1.5)
ax.tick_params(axis='y', colors=cr, labelsize=6, **tkw)
if num_axes == 3:
twin2.tick_params(axis='y', colors=p3.get_color(), labelsize=6, **tkw)
if ylim_constant:
#p1_ylim = ax.get_ylim()
p2_ylim = twin1.get_ylim()
p3_ylim = twin2.get_ylim()
ylim_min = min(p2_ylim[0], p3_ylim[0])
twin1.set_ylim((ylim_min, 2*p2_ylim[1]))
twin2.set_ylim((ylim_min, 3*p3_ylim[1]))
# extend the lower limits to a multiple of the upper limit (-2x on ax, -0.5x on twin1)
p1_ylim = ax.get_ylim()
ax.set_ylim((- 2 * p1_ylim[1], p1_ylim[1]))
p2_ylim = twin1.get_ylim()
twin1.set_ylim((- 0.5 * p2_ylim[1], p2_ylim[1]))
#p3_ylim = twin2.get_ylim()
#twin1.set_ylim((- 0.5 * p3_ylim[1], p3_ylim[1]))
twin1.tick_params(axis='y', colors=cg, labelsize=6, **tkw)
ax.tick_params(axis='x', **tkw)
# NOTE(review): `yticks` (a parameter) is overwritten here with the current tick positions
yticks = ax.get_yticks()
#if len(yticks) > 5:
# keep only the non-negative ticks on the clathrin axis
ax.set_yticks([yt for yt in yticks if yt >= 0])
yticks = twin1.get_yticks()
#if len(yticks) > 5:
# keep every second tick, starting with the second one
twin1.set_yticks(yticks[np.arange(1, len(yticks), 2)])
if num_axes == 3:
yticks = twin2.get_yticks()
if len(yticks) > 5:
twin2.set_yticks(yticks[np.arange(1, len(yticks), 2)])
if xlim_constant:
if xlim is None:
plt.xlim([-1, lifetime_max + 16])
# NOTE(review): the bodies of the two conditions below are not visible in this listing
if axes_invisible:
# plt.axi('off')
# plt.legend()
if fig is None:
def viz_errs_outliers_venn(X_test, preds, Y_test, num_feats_reduced=5):
    '''Compare outliers to errors in venn-diagram

    For four outlier detectors (one 2 x 2 subplot each) a two-set Venn diagram
    is drawn: points flagged as outliers vs. points where `preds != Y_test`.

    Params
    ------
    X_test
        table from which the feature columns (`data.get_feature_names`) are taken
    preds, Y_test
        predicted and true labels
    num_feats_reduced
        number of PCA components computed before outlier detection;
        None runs the detectors on the raw features
    '''
    feat_names = data.get_feature_names(X_test)
    X_feat = X_test[feat_names]

    if num_feats_reduced is not None:
        pca = decomposition.PCA(n_components=num_feats_reduced)
        X_reduced = pca.fit_transform(X_feat)
    else:
        # without this branch the PCA output was unconditionally overwritten
        X_reduced = X_feat

    detectors = [
        ('isolation forest', IsolationForest(n_estimators=10, warm_start=True)),  # fit 10 trees
        ('local outlier factor', LocalOutlierFactor(novelty=True)),
        ('elliptic envelop', EllipticEnvelope()),
        ('one-class svm', OneClassSVM()),
    ]
    is_err = np.asarray(preds != Y_test)

    R, C = 2, 2
    plt.figure(figsize=(6, 5), dpi=200)
    for i, (title, clf) in enumerate(detectors):
        plt.subplot(R, C, i + 1)
        # the detectors must be fitted before predict() can be called
        clf.fit(X_reduced)
        is_outlier = clf.predict(X_reduced) == -1  # sklearn marks outliers with -1
        idxs = np.arange(is_outlier.size)
        venn2([set(idxs[is_outlier]), set(idxs[is_err])], set_labels=['outliers', 'errors'])
        plt.title(title)
def plot_pcs(pca, X):
    '''Pretty plot of pcs with explained var bars

    Layout: a bar chart of the explained variance (in % of the total over the
    fitted components) on top, the component matrix as a heatmap centred at 0
    below it, and a colorbar to the right of the heatmap.

    Params
    ------
    pca: sklearn PCA class after being fitted
    X: unused; kept so existing callers keep working
    '''
    plt.figure(figsize=(6, 9), dpi=200)

    # extract out relevant pars
    comps = pca.components_.transpose()
    var_norm = pca.explained_variance_ / np.sum(pca.explained_variance_) * 100

    # create a 2 X 2 grid
    gs = grd.GridSpec(2, 2, height_ratios=[2, 10],
                      width_ratios=[12, 1], wspace=0.1, hspace=0)

    # plot explained variance (one bar per component)
    ax2 = plt.subplot(gs[0])
    ax2.bar(np.arange(comps.shape[1]), var_norm,
            color='gray', width=0.8)
    plt.title('Explained variance (%)')
    ax2.set_yticks([0, max(var_norm)])
    plt.xlim((-0.5, comps.shape[1] - 0.5))

    # plot pcs
    ax = plt.subplot(gs[2])
    vmaxabs = np.max(np.abs(comps))
    p = ax.imshow(comps, interpolation='None', aspect='auto',
                  cmap=sns.diverging_palette(10, 240, as_cmap=True, center='light'),
                  vmin=-vmaxabs, vmax=vmaxabs)  # center at 0
    plt.xlabel('PCA component number')

    # make colorbar
    colorAx = plt.subplot(gs[3])
    plt.colorbar(p, cax=colorAx)
def print_metadata(acc=None, metadata_file=oj(config.DIR_PROCESSED, 'metadata_clath_aux+gak_a7d2.pkl')):
    '''Print summary counts stored in a pickled metadata dict.

    Params
    ------
    acc
        optional accuracy on the "hard" tracks; printed when not None
    metadata_file
        path of the pickle file holding the metadata dict
    '''
    # NOTE(review): unpickling can execute arbitrary code -- only point this at trusted files.
    with open(metadata_file, 'rb') as f:
        m = pkl.load(f)

    print(
        f'valid:\t\t{m["num_aux_pos_valid"]:>4.0f} aux+ / {m["num_tracks_valid"]:>4.0f} ({m["num_aux_pos_valid"] / m["num_tracks_valid"]:.3f})')
    # NOTE(review): both numbers on this line come from the same key -- confirm this is intended
    print(f'hotspots:\t{m["num_hotspots_valid"]:>4.0f} aux+ / {m["num_hotspots_valid"]:>4.0f}')
    print(
        f'short:\t\t{m["num_short"] - m["num_short"] * m["acc_short"]:>4.0f} aux+ / {m["num_short"]:>4.0f} ({m["acc_short"]:.3f})')
    print(f'long:\t\t{m["num_long"] * m["acc_long"]:>4.0f} aux+ / {m["num_long"]:>4.0f} ({m["acc_long"]:.3f})')
    print(
        f'hard:\t\t{m["num_aux_pos_hard"]:>4.0f} aux+ / {m["num_tracks_hard"]:>4.0f} ({m["num_aux_pos_hard"] / m["num_tracks_hard"]:.3f})')

    if acc is not None:
        print(f'hard acc:\t\t\t {acc:.3f}')

    print('\nlifetime threshes', m['thresh_short'], m['thresh_long'])
def jointplot_grouped(col_x: str, col_y: str, col_k: str, df,
                      k_is_color=False, scatter_alpha=.5, add_global_hists: bool = False, ms=None):
    '''Jointplot of hists + densities

    One scatter layer and one pair of marginal distributions is drawn per
    group of `df.groupby(col_k)`.

    Params
    ------
    col_x
        name of X var
    col_y
        name of Y var
    col_k
        name of variable to group/color by
    df
        dataframe holding the three columns
    k_is_color
        if True, the group key itself is used as the colour of that group
    scatter_alpha
        alpha of the scatter points
    add_global_hists
        whether to plot the global hist as well
    ms
        unused; kept so existing callers keep working
    '''

    def colored_scatter(x, y, c=None):
        # JointGrid.plot_joint calls the plotting function with its own data;
        # the closure ignores that and always plots this group's x / y.
        def scatter(*args, **kwargs):
            args = (x, y)
            if c is not None:
                kwargs['c'] = c
            kwargs['marker'] = '.'
            kwargs['alpha'] = scatter_alpha
            plt.scatter(*args, **kwargs)

        return scatter

    g = sns.JointGrid(
        x=col_x,
        y=col_y,
        data=df
    )
    color = None
    legends = []
    for name, df_group in df.groupby(col_k):
        legends.append(name)
        if k_is_color:
            color = name
        g.plot_joint(
            colored_scatter(df_group[col_x], df_group[col_y], color),
        )
        # NOTE(review): seaborn marks distplot as deprecated; kept so the output stays unchanged
        sns.distplot(
            df_group[col_x].values,
            ax=g.ax_marg_x,
            color=color,
        )
        sns.distplot(
            df_group[col_y].values,
            ax=g.ax_marg_y,
            color=color,
            vertical=True
        )
    if add_global_hists:
        sns.distplot(
            df[col_x].values,
            ax=g.ax_marg_x,
            color='grey'
        )
        sns.distplot(
            df[col_y].values.ravel(),
            ax=g.ax_marg_y,
            color='grey',
            vertical=True
        )
    plt.legend(legends)
# 2d decision boundary
# Builds a num_pts x num_pts grid over the range of two columns of `df`, standardizes the
# grid coordinates with `norms[...]['mu']` / `norms[...]['std']`, and calls `m.predict`.
# The function is marked as unfinished by its author; nothing is plotted or returned here.
# NOTE(review): the docstring below is not closed in this listing.
def plot_decision_boundary(X_col, Y_col, m, df, norms, num_pts=100):
'''still not finished...
x = df[X_col]
y = df[Y_col]
x = np.linspace(x.min(), x.max(), num_pts)
y = np.linspace(y.min(), y.max(), num_pts)
# normalize
xv, yv = np.meshgrid(x, y, indexing='ij')
x = xv.flatten()
y = yv.flatten()
x = (x - norms[X_col]['mu']) / (norms[X_col]['std'])
y = (y - norms[Y_col]['mu']) / (norms[Y_col]['std'])
# NOTE(review): hstack concatenates the two flat arrays end to end, so reshape(-1, 2) pairs
# consecutive x values (then consecutive y values) instead of (x_i, y_i); np.column_stack
# would produce one (x, y) pair per row.
X = np.hstack((x, y)).reshape(-1, 2)
# NOTE(review): this overwrites the grid built above, and `results_individual` is not
# defined anywhere in this function -- as listed this line raises NameError.
X = df[results_individual['feat_names_selected']]
preds = m.predict(X)
def cumulative_acc_plot_hard(preds_proba, preds, y_full_cv):
    '''Plot cumulative accuracy with points ordered from most to least confident.

    Points are sorted by decreasing ``|preds_proba - 0.5|``; the running mean
    of ``preds == y_full_cv`` over that ordering is drawn together with the
    sorted predicted probabilities.

    Params
    ------
    preds_proba
        predicted probabilities
    preds
        predicted labels
    y_full_cv
        true labels
    '''
    # plain arrays, so that indexing with `order` below is positional
    preds_proba = np.asarray(preds_proba)
    correct = np.asarray(preds) == np.asarray(y_full_cv)

    order = np.argsort(np.abs(preds_proba - 0.5))[::-1]
    correct = correct[order]
    n = correct.size
    accs = np.cumsum(correct) / np.arange(1, n + 1)

    plt.figure(dpi=500)
    plt.plot(preds_proba[order], '.', ms=0.5, label='predicted prob', color=cb)
    plt.plot(accs, label='cumulative acc', color=cr)
    plt.yticks(np.arange(-0.05, 1.05, 0.1))
    plt.xlabel('num pts included')
    plt.grid(alpha=0.2)
    plt.legend()
def cumulative_acc_plot_all(df, pred_proba_key='preds_proba', pred_key='preds',
                            outcome_def='y_consec_thresh',
                            plot_vert_line_for_high_lifetimes=False, show=True):
    '''Plot cumulative accuracy over all tracks against an all-negative baseline.

    Two curves are drawn:
    - 'Predicting all abortive': tracks sorted by lifetime, counted correct when
      `outcome_def` is 0.
    - 'LSTM': the rows with `df.short` set come first (sorted by lifetime, counted
      correct when `outcome_def` is 0), followed by the remaining rows ordered by
      increasing ``|df[pred_key]|`` and counted correct when
      ``(df[pred_key] > 0) == df[outcome_def]``.
    A vertical line marks the number of short tracks.

    Params
    ------
    df
        dataframe with columns `lifetime`, `short`, `pred_key` and `outcome_def`
    pred_key
        column holding the model output (thresholded at 0)
    outcome_def
        column holding the binary outcome
    pred_proba_key, plot_vert_line_for_high_lifetimes, show
        unused; kept so existing callers keep working
    '''
    def _running_acc(correct):
        return np.cumsum(correct) / np.arange(1, correct.size + 1)

    plt.figure(dpi=500)
    ax = plt.subplot(111)

    # full (no model)
    argsf = np.argsort(df.lifetime.values)
    accsf = (1 - df[outcome_def]).values[argsf]
    n = df.shape[0]
    plt.plot(_running_acc(accsf), label='Predicting all abortive', color='gray')
    print('accsf', np.sum(accsf))

    # short
    ds = df[df.short]
    argss = np.argsort(ds.lifetime.values)
    accss = (1 - ds[outcome_def]).values[argss]
    ns = ds.shape[0]

    # hard
    dh = df[~df.short]
    argsh = np.argsort(np.abs(dh[pred_key].values))
    accsh = ((dh[pred_key].values > 0) == dh[outcome_def].values)[argsh]

    # put things together
    accs = np.hstack((accss, accsh))
    print(accsf.shape, accss.shape, accsh.shape, accs.shape)
    plt.plot(_running_acc(accs), label='LSTM', color=cb)
    plt.axvline(ns, lw=2.5, color='black')

    plt.xlabel('Percentage of tracks included (sorted by uncertainty)')
    plt.ylabel('Accuracy')
    # Exactly six tick positions for the six labels 0% .. 100%. The previous
    # arange(0, n + 1, n // 5) produced a different number of ticks whenever
    # n was not a multiple of 5 and a zero step for n < 5.
    ax.xaxis.set_ticks([int(x) for x in np.linspace(0, n, 6)])
    ax.xaxis.set_ticklabels([str(int(x)) + '%' for x in np.arange(0, 101, 100/5)])
    plt.legend(fontsize='x-large', frameon=False, labelcolor='linecolor')
    plt.grid(alpha=0.2)
    plt.tight_layout()
def plot_example(ex):
    '''Plot the two traces of one example on the current axes.

    ex - row of the dataframe; ``ex['X']`` is drawn in red (label 'clathrin')
    and ``ex['Y']`` in green (label 'auxilin')
    '''
    for key, color, label in (('X', 'red', 'clathrin'), ('Y', 'green', 'auxilin')):
        plt.plot(ex[key], color=color, label=label)
def get_videos(cell_dir: str):
    '''Loads in X and Y for one cell

    Params
    ------
    cell_dir
        Path to directory for one cell; it must contain the sub-folders
        'TagRFP', 'EGFP' and 'JF646'

    Returns
    -------
    videos
        dict keyed by 'cla' / 'aux' / 'dyn' holding what `imread` returned for
        the file in the matching sub-folder whose name contains 'tif'
        (a key is absent when its folder has no such file)
    '''
    fname = {'cla': 'TagRFP', 'aux': 'EGFP', 'dyn': 'JF646'}
    videos = {}
    for m, subdir in fname.items():
        channel_dir = oj(cell_dir, subdir)
        # sorted() makes the result deterministic when a folder holds several
        # matching files: the last name in sort order wins
        for name in sorted(os.listdir(channel_dir)):
            if 'tif' in name:
                videos[m] = imread(oj(channel_dir, name))
    return videos
def get_all_dynamin_videos(cells, upper_dir=None):
    '''Load the videos of several cells via `get_videos`.

    Params
    ------
    cells
        iterable of cell identifiers; the last six characters of each one are
        replaced by 'Cell1_1.5s' to form the cell's folder name
    upper_dir
        directory containing the per-cell folders; defaults to the original
        hard-coded data location

    Returns
    -------
    all_videos
        dict mapping each entry of `cells` to the dict returned by `get_videos`
    '''
    if upper_dir is None:
        upper_dir = oj('/scratch/users/vision/data/abc_data/dynamin_data_with_ims/',
                       'CLTA-TagRFP EGFP-Aux1-GAK-F6 Dyn2-Halo-E1-JF646')
    return {
        cell_num: get_videos(oj(upper_dir, cell_num[:-6] + 'Cell1_1.5s'))
        for cell_num in cells
    }
def get_dynamin_data_videos(df, pids, add_px=2, apply_norm=True):
    """
    extract videos of dynamin traces

    For every requested track, a (2 * add_px + 1)-pixel square crop around the
    track position is cut out of each channel for every frame of the track,
    plus five padding frames before and after it.

    Params:
    ------
    df: pd.DataFrame
        dataframe; the columns cell_num, x_pos_seq, y_pos_seq and t are read
    pids: list
        list of pids to plot
    add_px: int
        number of additional pixels in each direction
        add_px=1 means 3*3 pixels around the center, add_px=2 means 5*5, etc.
    apply_norm: bool
        if True, crops are rescaled per channel with the [min, max] in `norm`

    Returns:
    ------
    videos: dict
        videos[pid][channel] is a list of 2-d crops, channel in 'cla', 'aux', 'dyn'
    norm: dict
        norm[channel] = [min, max] over all crops of that channel (before rescaling)

    Raises:
    ------
    IndexError
        if a requested pid does not occur in `df`
    """
    channels = ('cla', 'aux', 'dyn')
    pad = 5          # padding frames added before and after every track
    interval = 1.5   # time step used to turn `t` into a frame index

    # NOTE(review): the expression selecting the pid column was garbled in the source;
    # a column named `pid` is assumed here -- confirm against the dataframe schema.
    all_pids = df['pid'].values
    indices = np.array([np.where(all_pids == pid)[0][0] for pid in pids])
    df = df.iloc[indices]
    cells = set(df.cell_num.values)
    raw_videos = get_all_dynamin_videos(cells)

    videos = {}
    for i in range(len(df)):
        cell_num = df.cell_num.iloc[i]
        fr, h, w = raw_videos[cell_num]['cla'].shape
        pid = df['pid'].values[i]

        # list() so the padding below is a concatenation even for array input
        x_pos = list(df.x_pos_seq.iloc[i])
        y_pos = list(df.y_pos_seq.iloc[i])
        x_pos = [x_pos[0]] * pad + x_pos + [x_pos[-1]] * pad
        y_pos = [y_pos[0]] * pad + y_pos + [y_pos[-1]] * pad
        t = df.t.iloc[i] - pad * interval
        lt = min(len(x_pos), len(y_pos))
        first_frame = int(t / interval)

        videos[pid] = {m: [] for m in channels}
        for j in range(lt):
            # clamp the frame index: a track starting within the first `pad` frames
            # would otherwise yield a negative index and silently wrap to the end
            frame = min(max(first_frame + j, 0), fr - 1)
            crop_y_pos = np.clip(np.arange(int(y_pos[j]) - add_px, int(y_pos[j]) + add_px + 1), 0, h - 1)
            # columns are clamped with the width (the height was used here before)
            crop_x_pos = np.clip(np.arange(int(x_pos[j]) - add_px, int(x_pos[j]) + add_px + 1), 0, w - 1)
            for m in channels:
                image = raw_videos[cell_num][m][frame, :, :]
                videos[pid][m].append(image[crop_y_pos, :][:, crop_x_pos])

    # normalization: per-channel min / max over all crops
    norm = {}
    for m in channels:
        norm[m] = [1e9, -1e9]  # min, max
        for pid in videos:
            for crop in videos[pid][m]:
                norm[m][0] = min(norm[m][0], crop.min())
                norm[m][1] = max(norm[m][1], crop.max())

    if apply_norm:
        for pid in videos:
            for m in channels:
                lo, hi = norm[m]
                span = hi - lo if hi > lo else 1  # constant channel: avoid dividing by zero
                videos[pid][m] = [(crop - lo) / span for crop in videos[pid][m]]
    return videos, norm
# Builds kymograph data for the requested tracks: for every frame of a track, the intensity
# at horizontal offsets -add_px .. +add_px around the track position is read from the
# clathrin and auxilin videos and rescaled with that frame's min / max. The per-track
# arrays are then laid out side by side and mapped to RGB with the 'Reds' / 'Greens'
# colormaps. Returns (cla_traces, aux_traces, rgb_image).
# NOTE(review): the docstring delimiters are missing in this listing, and the expression
# inside np.where(...) below is garbled (its left-hand side is missing).
def plot_kymographs(df, pids, add_px=2):
plot kymographs of dynamin traces
df: pd.DataFrame
pids: list
list of pids to plot
add_px: int
number of additional pixels in each direction
add_px=1 means 3*3 pixels around the center, add_px=2 means 5*5, etc.
cla_traces: np.array
clathrin traces from raw images
aux_traces: np.array
auxilin traces from raw images
rgb_image: 3d np.array
3d array (RGB values) of kymographs
indices = np.array([np.where( == pid)[0][0] for pid in pids])
df = df.iloc[indices]
cells = set(df.cell_num.values)
raw_videos = get_all_dynamin_videos(cells)
# NOTE(review): `viridis` is only referenced from a commented-out line further down
viridis = cm.get_cmap('viridis', 12)
reds = cm.get_cmap('Reds', 12) # red palette for clathrin
greens = cm.get_cmap('Greens', 12) # green palette for auxilin
# number of rows of every kymograph: longest track plus 2
lmax = max([len(df.x_pos_seq.iloc[i]) for i in range(len(df))]) + 2
width = 2 * add_px + 1
cla_traces, aux_traces = {}, {}
for i in range(len(df)):
cla_traces[i], aux_traces[i] = np.zeros((lmax, width)), np.zeros((lmax, width))
#xmean = X[df.cell_num.iloc[i]].mean()
cell_num = df.cell_num.iloc[i]
x_pos, y_pos = df.x_pos_seq.iloc[i], df.y_pos_seq.iloc[i]
t, lt = df.t.iloc[i], min(len(x_pos), len(y_pos))
for k in range(-add_px, add_px + 1):
for j in range(lt):
# frame index is int(t/1.5) + j; the row range below covers only the row int(y_pos[j])
video = raw_videos[cell_num]['cla']
cla_traces[i][j, k + add_px] = max(video[int(t/1.5) + j, \
range(int(y_pos[j]) - 0, int(y_pos[j]) + 0 + 1), \
int(x_pos[j] + k)])
# rescale with the min / max of the whole frame
vmin, vmax = video[int(t/1.5) + j,:,:].min(), video[int(t/1.5) + j,:,:].max()
cla_traces[i][j, k + add_px] = (cla_traces[i][j, k + add_px] - vmin)/(vmax - vmin)
video = raw_videos[cell_num]['aux']
aux_traces[i][j, k + add_px] = max(video[int(t/1.5) + j, \
range(int(y_pos[j]) - 0, int(y_pos[j]) + 0 + 1), \
int(x_pos[j] + k)])
vmin, vmax = video[int(t/1.5) + j,:,:].min(), video[int(t/1.5) + j,:,:].max()
aux_traces[i][j, k + add_px] = (aux_traces[i][j, k + add_px] - vmin)/(vmax - vmin)
# every track gets 3 * width columns: clathrin, auxilin, and `width` columns left at zero
ncol = 3 * width * len(df)
cla_sparse = np.zeros((lmax, ncol))
aux_sparse = np.zeros((lmax, ncol))
for i in range(len(df)):
start_index = 3 * width * i
cla_sparse[:, (start_index):(start_index + width)] = cla_traces[i]
aux_sparse[:, (start_index + width):(start_index + 2 * width)] = aux_traces[i]
# per cell: red shade where the clathrin value is strictly inside (0, 1), else green shade
# where the auxilin value is strictly inside (0, 1), else white; the outer loop runs over
# columns, so the first axis of the result has length ncol and the second lmax
rgb_image = np.array([[list(reds(cla_sparse[i][j])[:3])
#[1, 1 - cla_sparse[i][j], 1 - cla_sparse[i][j]] \
if 0 < cla_sparse[i][j] < 1 \
else \
list(greens(aux_sparse[i][j])[:3]) \
#[1 - aux_sparse[i][j], 1, 1 - aux_sparse[i][j]] \
if 0 < aux_sparse[i][j] < 1 \
#else list(viridis(np.random.choice(background, 1)[0])[:3])\
else (1, 1, 1)
for i in range(lmax)] \
for j in range(ncol)])
#cla_sparse = np.transpose(cla_sparse)
#aux_sparse = np.transpose(aux_sparse)
return cla_traces, aux_traces, rgb_image
def cumulative_acc_plot_all(df, pred_proba_key='preds_proba', pred_key='preds', outcome_def='y_consec_thresh', plot_vert_line_for_high_lifetimes=False, show=True)
def cumulative_acc_plot_all(df, pred_proba_key='preds_proba', pred_key='preds', outcome_def='y_consec_thresh', plot_vert_line_for_high_lifetimes=False, show=True): plt.figure(dpi=500) ax = plt.subplot(111) # full (no model) argsf = np.argsort(df.lifetime.values) accsf = (1 - df[outcome_def]).values[argsf] n = df.shape[0] plt.plot(np.cumsum(accsf) / np.arange(1, accsf.size + 1), label='Predicting all abortive', color='gray') print('accsf', np.sum(accsf)) # short ds = df[df.short] argss = np.argsort(ds.lifetime.values) accss = (1 - ds[outcome_def]).values[argss] ns = ds.shape[0] # hard dh = df[~df.short] argsh = np.argsort(np.abs(dh[pred_key])) #[::-1] accsh = ((dh[pred_key].values > 0) == dh[outcome_def].values)[argsh] # put things together accs = np.hstack((accss, accsh)) print(accsf.shape, accss.shape, accsh.shape, accs.shape) plt.plot(np.cumsum(accs) / np.arange(1, accs.size + 1), label='LSTM', color=cb) print(accs) plt.axvline(ns, lw=2.5, color='black') # dvu.line_legend() plt.xlabel('Percentage of tracks included (sorted by uncertainty)') plt.ylabel('Accuracy') ax.xaxis.set_ticks([int(x) for x in np.arange(0, n + 1, n//5)]) ax.xaxis.set_ticklabels([str(int(x)) + '%' for x in np.arange(0, 101, 100/5)]) plt.legend(fontsize='x-large', frameon=False, labelcolor='linecolor') plt.grid(alpha=0.2) plt.tight_layout()
def cumulative_acc_plot_hard(preds_proba, preds, y_full_cv)
def cumulative_acc_plot_hard(preds_proba, preds, y_full_cv): args = np.argsort(np.abs(preds_proba - 0.5))[::-1] accs = (preds == y_full_cv)[args] n = accs.size accs = np.cumsum(accs) / np.arange(1, n + 1) plt.figure(dpi=500) plt.plot(preds_proba[args], '.', ms=0.5, label='predicted prob', color=cb) plt.plot(accs, label='cumulative acc', color=cr) plt.yticks(np.arange(-0.05, 1.05, 0.1)) plt.xlabel('num pts included') plt.grid(alpha=0.2) plt.legend()
def fix_feat_name(s)
def fix_feat_name(s): return s.replace('_', ' ').replace('X', 'Clath').capitalize()
def get_all_dynamin_videos(cells)
def get_all_dynamin_videos(cells): all_videos = {} upper_dir = oj('/scratch/users/vision/data/abc_data/dynamin_data_with_ims/', 'CLTA-TagRFP EGFP-Aux1-GAK-F6 Dyn2-Halo-E1-JF646') for cell_num in cells: cell_dir = cell_num[:-6] + 'Cell1_1.5s' full_dir = oj(upper_dir, cell_dir) all_videos[cell_num] = get_videos(full_dir) return all_videos
def get_dynamin_data_videos(df, pids, add_px=2, apply_norm=True)
extract videos of dynamin traces
df: pd.DataFrame dataframe
pids: list list of pids to plot
add_px: int number of additional pixels in each direction add_px=1 means 3*3 pixels around the center, add_px=2 means 5*5, etc.
def get_dynamin_data_videos(df, pids, add_px=2, apply_norm=True): """ extract videos of dynamin traces Params: ------ df: pd.DataFrame dataframe pids: list list of pids to plot add_px: int number of additional pixels in each direction add_px=1 means 3*3 pixels around the center, add_px=2 means 5*5, etc. Returns: ------ videos """ indices = np.array([np.where( == pid)[0][0] for pid in pids]) df = df.iloc[indices] cells = set(df.cell_num.values) raw_videos = get_all_dynamin_videos(cells) videos = {} for i in range(len(df)): cell_num = df.cell_num.iloc[i] fr, h, w = raw_videos[cell_num]['cla'].shape pid =[i] videos[pid] = {} x_pos, y_pos = df.x_pos_seq.iloc[i], df.y_pos_seq.iloc[i] x_pos = [x_pos[0]]*5 + x_pos + [x_pos[-1]]*5 y_pos = [y_pos[0]]*5 + y_pos + [y_pos[-1]]*5 t, lt = df.t.iloc[i] - 5*1.5, min(len(x_pos), len(y_pos)) for m in ['cla', 'aux', 'dyn']: videos[pid][m] = [] for j in range(lt): for m in ['cla', 'aux', 'dyn']: crop_y_pos = np.maximum(0, np.minimum(h - 1, np.arange(int(y_pos[j]) - add_px, int(y_pos[j]) + add_px + 1))) crop_x_pos = np.maximum(0, np.minimum(h - 1, np.arange(int(x_pos[j]) - add_px, int(x_pos[j]) + add_px + 1))) videos[pid][m].append(raw_videos[cell_num][m][int(t/1.5) + j,:,:] \ [crop_y_pos, :] \ [:, crop_x_pos]) # normalize by the min/max intensities #vmin, vmax = raw_videos[cell_num][m][int(t/1.5) + j,:,:].mean(), raw_videos[cell_num][m][int(t/1.5) + j,:,:].max() #videos[pid][m][-1] = (videos[pid][m][-1] - vmin)/(vmax - vmin) # normalization norm = {} for m in ['cla', 'aux', 'dyn']: norm[m] = [1e9, -1e9] # min, max for pid in videos: for j in range(len(videos[pid][m])): norm[m][0] = min(norm[m][0], videos[pid][m][j].min()) norm[m][1] = max(norm[m][1], videos[pid][m][j].max()) if apply_norm: for pid in videos: for m in ['cla', 'aux', 'dyn']: for j in range(len(videos[pid][m])): videos[pid][m][j] = (videos[pid][m][j] - norm[m][0])/(norm[m][1] - norm[m][0]) return videos, norm
def get_videos(cell_dir)
Loads in X and Y for one cell
cell_dir - Path to directory for one cell
def get_videos(cell_dir: str): '''Loads in X and Y for one cell Params ------ cell_dir Path to directory for one cell Returns ------- videos ''' fname = {'cla': 'TagRFP', 'aux': 'EGFP', 'dyn': 'JF646'} videos = {} for m in fname: for name in os.listdir(oj(cell_dir, fname[m])): #print(f"filename: {name}") if 'tif' in name: videos[m] = imread(oj(cell_dir, fname[m], name)) return videos
def highlight_max(data, color='#0e5c99')
highlight the maximum in a Series or DataFrame
def highlight_max(data, color='#0e5c99'): ''' highlight the maximum in a Series or DataFrame ''' attr = 'background-color: {}'.format(color) if data.ndim == 1: # Series from .apply(axis=0) or axis=1 is_max = data == data.max() return [attr if v else '' for v in is_max] else: # from .apply(axis=None) is_max = data == data.max().max() return pd.DataFrame(np.where(is_max, attr, ''), index=data.index, columns=data.columns)
def jointplot_grouped(col_x, col_y, col_k, df, k_is_color=False, scatter_alpha=0.5, add_global_hists=False, ms=None)
Jointplot of hists + densities Params
col_x - name of X var
col_y - name of Y var
col_k - name of variable to group/color by
add_global_hists - whether to plot the global hist as well
def jointplot_grouped(col_x: str, col_y: str, col_k: str, df, k_is_color=False, scatter_alpha=.5, add_global_hists: bool = False, ms=None): '''Jointplot of hists + densities Params ------ col_x name of X var col_y name of Y var col_k name of variable to group/color by add_global_hists whether to plot the global hist as well ''' def colored_scatter(x, y, c=None): def scatter(*args, **kwargs): args = (x, y) if c is not None: kwargs['c'] = c kwargs['marker'] = '.' kwargs['alpha'] = scatter_alpha plt.scatter(*args, **kwargs) return scatter g = sns.JointGrid( x=col_x, y=col_y, data=df ) color = None legends = [] for name, df_group in df.groupby(col_k): legends.append(name) if k_is_color: color = name g.plot_joint( colored_scatter(df_group[col_x], df_group[col_y], color), ) sns.distplot( df_group[col_x].values, ax=g.ax_marg_x, color=color, ) sns.distplot( df_group[col_y].values, ax=g.ax_marg_y, color=color, vertical=True ) if add_global_hists: sns.distplot( df[col_x].values, ax=g.ax_marg_x, color='grey' ) sns.distplot( df[col_y].values.ravel(), ax=g.ax_marg_y, color='grey', vertical=True ) plt.legend(legends)
def plot_above_threshold(x1, y1, b1, x2, y2, b2, ax, color, lsty)
def plot_above_threshold(x1, y1, b1, x2, y2, b2, ax, color, lsty):
    '''Draw the segment (x1, y1) -> (x2, y2), opaque where it lies on or above
    the threshold segment (x1, b1) -> (x2, b2) and faint (alpha .1) below it

    When the segment crosses the threshold it is split at the crossing point
    and the two pieces are drawn with different alpha.

    Params
    ------
    x1, y1, x2, y2
        end points of the segment to draw
    b1, b2
        threshold values at x1 and x2
    ax
        object whose `plot` method is used for drawing
    color, lsty
        color and linestyle handed to `ax.plot`
    '''
    slope_trace = (y2 - y1)/(x2 - x1)
    slope_thresh = (b2 - b1)/(x2 - x1)

    def draw(xs, ys, alpha):
        ax.plot(xs, ys, linestyle=lsty, color=color, alpha=alpha)

    starts_above, starts_below = y1 >= b1, y1 < b1
    ends_above, ends_below = y2 >= b2, y2 < b2

    if starts_above and ends_above:
        draw([x1, x2], [y1, y2], 1)
        return
    if starts_below and ends_below:
        draw([x1, x2], [y1, y2], .1)
        return

    if starts_above and ends_below:
        alphas = (1, .1)
    elif starts_below and ends_above:
        alphas = (.1, 1)
    else:
        # neither comparison holds for some end point: nothing is drawn
        return

    # intersection of the drawn segment with the threshold segment
    cross_x = x1 + (y1 - b1)/(slope_thresh - slope_trace)
    cross_y = y1 + slope_trace * (y1 - b1)/(slope_thresh - slope_trace)
    draw([x1, cross_x], [y1, cross_y], alphas[0])
    draw([cross_x, x2], [cross_y, y2], alphas[1])
def plot_background(interval, bg, trace, color, ax)
def plot_background(interval, bg, trace, color, ax):
    '''Shade the band between 0 and twice the background and draw the trace on
    top of it, faint wherever the trace is below twice the background

    Params
    ------
    interval
        factor converting a frame index into an x coordinate
    bg
        background values, one per frame
    trace
        trace values, one per frame
    color
        color used for both the band and the trace
    ax
        axes to draw on
    '''
    n_bg = len(bg)
    threshold = 2 * np.array(bg)
    ax.fill_between(interval * np.arange(n_bg), [0] * n_bg, threshold, alpha=.1, color=color)

    n_frames = len(trace)
    frames = np.arange(n_frames)
    values = np.array(trace)
    for f in range(n_frames - 1):
        # segments starting in the first / last five frames are dashed
        near_edge = f < 5 or f >= n_frames - 5
        plot_above_threshold(x1=interval*frames[f], y1=values[f], b1=threshold[f],
                             x2=interval*frames[f+1], y2=values[f+1], b2=threshold[f+1],
                             ax=ax, color=color, lsty='--' if near_edge else '-')
def plot_confusion_matrix(y_true, y_pred, classes, normalize=False, title=None, cmap=<matplotlib.colors.LinearSegmentedColormap object>)
This function plots the confusion matrix. Normalization can be applied by setting
`normalize=True`. Params:
classes: np.ndarray(Str) - class names, e.g. classes=np.array(['aux-', 'aux+'])
def plot_confusion_matrix(y_true, y_pred, classes,
                          normalize=False,
                          title=None,
                          cmap=cmap):
    """
    This function plots the confusion matrix as an annotated heatmap.
    Normalization can be applied by setting `normalize=True`.

    Params
    ------
    y_true, y_pred: np.ndarray
        true and predicted labels; they are cast to int and used as indexes
        into `classes`
    classes: np.ndarray(Str)
        classes=np.array(['aux-', 'aux+'])
    normalize: bool
        divide every row of the matrix by its sum
    title: str
        accepted for backwards compatibility; no title is drawn
    cmap
        colormap handed to `plt.imshow` (defaults to the module-level `cmap`)

    Returns
    -------
    ax
        the axes the matrix was drawn on
    """
    # NOTE(review): the `cmap=cmap` default and the `int` arguments of astype were
    # missing from the extracted source and have been reconstructed - confirm
    plt.figure(dpi=300)

    # Compute confusion matrix
    conf_mat = metrics.confusion_matrix(y_true, y_pred)
    # Only use the labels that appear in the data
    classes = classes[unique_labels(y_true.astype(int), y_pred.astype(int))]
    if normalize:
        conf_mat = conf_mat.astype('float') / conf_mat.sum(axis=1)[:, np.newaxis]

    plt.imshow(conf_mat, interpolation='nearest', cmap=cmap)
    ax = plt.gca()
    # We want to show all ticks...
    ax.set(xticks=np.arange(conf_mat.shape[1]),
           yticks=np.arange(conf_mat.shape[0]),
           # ... and label them with the respective list entries
           xticklabels=classes, yticklabels=classes,
           ylabel='True label',
           xlabel='Predicted label')

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Annotate every cell; white text on cells above half of the maximum
    fmt = '.2f' if normalize else 'd'
    thresh = conf_mat.max() / 2.
    for i in range(conf_mat.shape[0]):
        for j in range(conf_mat.shape[1]):
            ax.text(j, i, format(conf_mat[i, j], fmt),
                    ha="center", va="center",
                    color="white" if conf_mat[i, j] > thresh else "black")
    return ax
def plot_curves(df, extra_key=None, extra_key_label=None, hline=True, R=5, C=8, xlim=None, fig=None, ylim_constant=False, background=False, ylim_cla=None, ylim_aux=None, ylim_dyn=None, xlim_constant=True, legend=True, plot_x=True, yticks=None, yticklabels=None, num_axes=3, show_track_pid=False, axes_invisible=False)
Plot time-series curves from df
def plot_decision_boundary(X_col, Y_col, m, df, norms, num_pts=100)
still not finished…
def plot_decision_boundary(X_col, Y_col, m, df, norms, num_pts=100):
    '''still not finished...

    Builds a num_pts x num_pts grid spanning the observed range of two columns,
    normalizes it and calls `m.predict`; nothing is plotted or returned yet.

    Params
    ------
    X_col, Y_col
        names of the two columns of df spanning the grid
    m
        model exposing `predict`
    df
        dataframe holding the columns
    norms
        dict with norms[col]['mu'] and norms[col]['std'] used for normalizing
    num_pts
        number of grid points along each axis
    '''
    xs = np.linspace(df[X_col].min(), df[X_col].max(), num_pts)
    ys = np.linspace(df[Y_col].min(), df[Y_col].max(), num_pts)
    xv, yv = np.meshgrid(xs, ys, indexing='ij')

    # normalize
    x = (xv.flatten() - norms[X_col]['mu']) / (norms[X_col]['std'])
    y = (yv.flatten() - norms[Y_col]['mu']) / (norms[Y_col]['std'])

    # one (x, y) pair per row; hstack + reshape would pair neighbouring x values instead
    X = np.column_stack((x, y))
    print(X.shape)

    # NOTE(review): `results_individual` is not defined in this function, and this
    # line discards the grid built above - the function is unfinished
    X = df[results_individual['feat_names_selected']]
    preds = m.predict(X)
def plot_example(ex)
ex - row of the dataframe
def plot_example(ex):
    '''Plot the 'X' (clathrin) and 'Y' (auxilin) series of one example

    Params
    ------
    ex
        row of the dataframe
    '''
    plt.figure(dpi=200)
    series = (('X', 'red', 'clathrin'),
              ('Y', 'green', 'auxilin'))
    for key, color, label in series:
        plt.plot(ex[key], color=color, label=label)
    plt.xlabel('Time')
    plt.ylabel('Amplitude')
    plt.legend()
def plot_kymographs(df, pids, add_px=2)
plot kymographs of dynamin traces
df: pd.DataFrame dataframe
pids: list list of pids to plot
add_px: int number of additional pixels in each direction; add_px=1 means 3*3 pixels around the center, add_px=2 means 5*5, etc.
cla_traces: np.array clathrin traces from raw images aux_traces: np.array auxilin traces from raw images rgb_image: 3d np.array 3d array (RGB values) of kymographs
def plot_kymographs(df, pids, add_px=2):
    """
    Build kymographs of the clathrin and auxilin channels for the given tracks
    (nothing is drawn; the arrays are returned)

    Params:
    ------
    df: pd.DataFrame
        dataframe; the columns pid, cell_num, x_pos_seq, y_pos_seq and t are read
    pids: list
        list of pids to plot
    add_px: int
        number of additional pixels in each direction
        add_px=1 means 3*3 pixels around the center, add_px=2 means 5*5, etc.

    Returns:
    ------
    cla_traces: dict of np.array
        clathrin traces from raw images, one (lmax, 2 * add_px + 1) array per track,
        each pixel min-max scaled by the intensity range of its frame
    aux_traces: dict of np.array
        auxilin traces from raw images, same layout
    rgb_image: 3d np.array
        array of shape (ncol, lmax, 3) holding RGB values of the kymographs; every
        track takes 3 * width columns (clathrin, auxilin, blank)
    """
    # NOTE(review): the left operand of this comparison was missing from the
    # extracted source; `df.pid.values` is a reconstruction - confirm
    indices = np.array([np.where(df.pid.values == pid)[0][0] for pid in pids])
    df = df.iloc[indices]
    raw_videos = get_all_dynamin_videos(set(df.cell_num.values))

    # plt.get_cmap replaces cm.get_cmap, which newer matplotlib releases removed
    reds = plt.get_cmap('Reds', 12)  # red palette for clathrin
    greens = plt.get_cmap('Greens', 12)  # green palette for auxilin

    n_tracks = len(df)
    lmax = max(len(seq) for seq in df.x_pos_seq) + 2
    width = 2 * add_px + 1

    cla_traces, aux_traces = {}, {}
    for i in range(n_tracks):
        cla_traces[i], aux_traces[i] = np.zeros((lmax, width)), np.zeros((lmax, width))
        videos = raw_videos[df.cell_num.iloc[i]]
        x_pos, y_pos = df.x_pos_seq.iloc[i], df.y_pos_seq.iloc[i]
        # NOTE(review): 1.5 presumably is the frame interval in seconds - confirm
        first_frame = int(df.t.iloc[i] / 1.5)
        for j in range(min(len(x_pos), len(y_pos))):
            row = int(y_pos[j])
            for channel, traces in (('cla', cla_traces[i]), ('aux', aux_traces[i])):
                image = videos[channel][first_frame + j]
                # the intensity range of the frame is the same for every pixel of it
                vmin, vmax = image.min(), image.max()
                for k in range(-add_px, add_px + 1):
                    pixel = image[row, int(x_pos[j] + k)]
                    traces[j, k + add_px] = (pixel - vmin) / (vmax - vmin)

    # lay the tracks side by side: clathrin columns, auxilin columns, blank columns
    ncol = 3 * width * n_tracks
    cla_sparse = np.zeros((lmax, ncol))
    aux_sparse = np.zeros((lmax, ncol))
    for i in range(n_tracks):
        start = 3 * width * i
        cla_sparse[:, start:start + width] = cla_traces[i]
        aux_sparse[:, start + width:start + 2 * width] = aux_traces[i]

    def pixel_rgb(cla_val, aux_val):
        # values of exactly 0 or 1 (and the blank columns) are rendered white
        if 0 < cla_val < 1:
            return list(reds(cla_val)[:3])
        if 0 < aux_val < 1:
            return list(greens(aux_val)[:3])
        return (1, 1, 1)

    rgb_image = np.array([[pixel_rgb(cla_sparse[i][j], aux_sparse[i][j])
                           for i in range(lmax)]
                          for j in range(ncol)])
    return cla_traces, aux_traces, rgb_image
def plot_pcs(pca, X)
Pretty plot of pcs with explained var bars. Params: pca - sklearn PCA class after being fitted
def plot_pcs(pca, X):
    '''Pretty plot of pcs with explained var bars

    Params
    ------
    pca: sklearn PCA class after being fitted
    X
        iterating over it must yield the feature names used as row labels
        (e.g. a dataframe, which yields its column names)
    '''
    plt.figure(figsize=(6, 9), dpi=200)

    # extract out relevant pars
    comps = pca.components_.transpose()
    var_norm = pca.explained_variance_ / np.sum(pca.explained_variance_) * 100

    # create a 2 X 2 grid
    gs = grd.GridSpec(2, 2, height_ratios=[2, 10],
                      width_ratios=[12, 1], wspace=0.1, hspace=0)

    # plot explained variance
    # NOTE(review): the bar call was garbled in the extracted source and has been
    # reconstructed from the surviving arguments - confirm
    ax2 = plt.subplot(gs[0])
    ax2.bar(np.arange(0, comps.shape[1]), var_norm,
            color='gray', width=0.8)
    plt.title('Explained variance (%)')
    ax2.spines['right'].set_visible(False)
    ax2.spines['top'].set_visible(False)
    ax2.yaxis.set_ticks_position('left')
    ax2.set_yticks([0, max(var_norm)])
    plt.xlim((-0.5, comps.shape[1] - 0.5))

    # plot pcs
    ax = plt.subplot(gs[2])
    vmaxabs = np.max(np.abs(comps))
    p = ax.imshow(comps, interpolation='None', aspect='auto',
                  cmap=sns.diverging_palette(10, 240, as_cmap=True, center='light'),
                  vmin=-vmaxabs, vmax=vmaxabs)  # center at 0
    plt.xlabel('PCA component number')
    # fix the tick positions first so that each label lands on its own row
    feat_names = list(X)
    ax.set_yticks(range(len(feat_names)))
    ax.set_yticklabels(feat_names)

    # make colorbar
    colorAx = plt.subplot(gs[3])
    plt.colorbar(p, cax=colorAx)
def print_metadata(acc=None, metadata_file='/accounts/projects/vision/chandan/auxilin-prediction/src/../data/processed/metadata_clath_aux+gak_a7d2.pkl')
def print_metadata(acc=None, metadata_file=oj(config.DIR_PROCESSED, 'metadata_clath_aux+gak_a7d2.pkl')):
    '''Print the track counts stored in a pickled metadata dict

    Params
    ------
    acc: float
        if given, it is printed as the accuracy on the 'hard' tracks
    metadata_file: str
        path to the pickled metadata dict
    '''
    # NOTE: pickle can execute arbitrary code on load - only point this at trusted files
    with open(metadata_file, 'rb') as f:
        m = pkl.load(f)

    print(
        f'valid:\t\t{m["num_aux_pos_valid"]:>4.0f} aux+ / {m["num_tracks_valid"]:>4.0f} ({m["num_aux_pos_valid"] / m["num_tracks_valid"]:.3f})')
    print('----------------------------------------')
    print(f'hotspots:\t{m["num_hotspots_valid"]:>4.0f} aux+ / {m["num_hotspots_valid"]:>4.0f}')
    print(
        f'short:\t\t{m["num_short"] - m["num_short"] * m["acc_short"]:>4.0f} aux+ / {m["num_short"]:>4.0f} ({m["acc_short"]:.3f})')
    print(f'long:\t\t{m["num_long"] * m["acc_long"]:>4.0f} aux+ / {m["num_long"]:>4.0f} ({m["acc_long"]:.3f})')
    print(
        f'hard:\t\t{m["num_aux_pos_hard"]:>4.0f} aux+ / {m["num_tracks_hard"]:>4.0f} ({m["num_aux_pos_hard"] / m["num_tracks_hard"]:.3f})')

    if acc is not None:
        print('----------------------------------------')
        print(f'hard acc:\t\t\t {acc:.3f}')
    print('\nlifetime threshes', m['thresh_short'], m['thresh_long'])
def savefig(s, png=False)
def savefig(s: str, png=False):
    '''Save the current figure as DIR_FIGS/fig_<s>.pdf

    Params
    ------
    s
        name of the figure (without the 'fig_' prefix and the extension)
    png
        also save a 300 dpi png next to the pdf
    '''
    stem = oj(DIR_FIGS, 'fig_' + s)
    plt.savefig(stem + '.pdf', bbox_inches='tight')
    if not png:
        return
    plt.savefig(stem + '.png', dpi=300, bbox_inches='tight')
def viz_biggest_errs(df, idxs_cv, idxs, Y_test, preds, preds_proba, num_to_plot=20, aux_thresh=642, show_track_num=True, show_track_pid=False, sort_by_residuals=True, width_mult=3, plot_x=True, plot_y=True, plot_z=False, plot_axhline=True, xlim_constant=True, ylim=None, yticks=None, yticklabels=None, lifetime_max=None, text_labels=False, text_label_size=25)
Visualize X and Y where the top examples are the most wrong / least confident. Params:
idxs_cv: integer ndarray - which idxs are not part of the test set (usually just 0, 1, 2, …)
idxs: boolean ndarray - subset of points to plot
def viz_biggest_errs(df, idxs_cv, idxs, Y_test, preds, preds_proba, num_to_plot=20, aux_thresh=642,
                     show_track_num=True, show_track_pid=False, sort_by_residuals=True,
                     width_mult=3, plot_x=True, plot_y=True, plot_z=False, plot_axhline=True,
                     xlim_constant=True, ylim: tuple = None, yticks=None, yticklabels=None,
                     lifetime_max=None, text_labels=False, text_label_size=25):
    '''Visualize X and Y where the top examples are the most wrong / least confident

    Params
    ------
    df
        dataframe with one track per row; the columns X, Y, (Z), cell_num and
        lifetime are read
    idxs_cv: integer ndarray
        which idxs are not part of the test set (usually just 0, 1, 2, ...)
    idxs: boolean ndarray
        subset of points to plot
    Y_test, preds, preds_proba
        labels, predictions and predicted probabilities of the tracks
    num_to_plot: int
        number of tracks to show (None shows all of them)
    aux_thresh
        height of the horizontal line drawn when plot_axhline is set
    show_track_num, show_track_pid
        write the rank (or else the pid) of the track into its subplot
    sort_by_residuals
        order tracks by |Y_test - preds_proba|, largest first
    width_mult
        width of one subplot
    plot_x, plot_y, plot_z
        which of the series X (clath), Y (aux), Z (dyn) to draw
    xlim_constant
        give all subplots the x range [-1, lifetime_max]
    ylim, yticks, yticklabels
        optional y range / ticks applied to the subplots
    lifetime_max
        defaults to the largest lifetime among the shown tracks
    text_labels, text_label_size
        write the channel names next to the end of the traces of the track
        drawn last

    Returns
    -------
    dft
        the (possibly sorted) dataframe the plotted tracks were taken from
    '''
    # NOTE(review): indentation was lost in the extracted source; the nesting of the
    # `idxs` handling and of the `text_labels` part is reconstructed - confirm

    DIFF = 0  # use this to ensure values are all positive

    # deal with idxs
    if idxs is not None:
        Y_test = Y_test[idxs]
        preds = preds[idxs]
        preds_proba = preds_proba[idxs]
        if idxs_cv is None:
            idxs_cv = np.arange(df.shape[0])
        df = df.iloc[idxs_cv][idxs]

    # get args to sort by
    if sort_by_residuals:
        residuals = np.abs(Y_test - preds_proba)
        args = np.argsort(residuals)[::-1]
        dft = df.iloc[args]
    else:
        dft = df
    if lifetime_max is None:
        lifetime_max = np.max(dft.lifetime.values)
    if num_to_plot is None:
        num_to_plot = dft.shape[0]
    R = int(np.sqrt(num_to_plot))
    C = num_to_plot // R  # + 1
    plt.figure(figsize=(C * width_mult, R * 2.5), dpi=200)

    row = None  # stays None when no track gets drawn
    i = 0
    for r in range(R):
        for c in range(C):
            if i < dft.shape[0]:
                row = dft.iloc[i]
                ax = plt.subplot(R, C, i + 1)

                # show nums on tracks
                if show_track_num:
                    ax.text(.5, .9, f'{i}',
                            transform=ax.transAxes)
                elif show_track_pid:
                    # NOTE(review): the expression inside this f-string was missing
                    # from the extracted source; `row["pid"]` is a reconstruction
                    ax.text(.5, .9, f'{row["pid"]}',
                            transform=ax.transAxes)

                # cells whose name contains '1.5s' use an x step of 1.5 instead of 1
                if '1.5s' in row['cell_num']:
                    interval = 1.5
                else:
                    interval = 1
                if plot_x:
                    plt.plot(interval * np.arange(len(row["X"])), np.array(row["X"]) + DIFF,
                             color=cr, label='clath', lw=2)  # could do X_extended
                if plot_y:
                    plt.plot(interval * np.arange(len(row["Y"])), np.array(row["Y"]) + DIFF,
                             color=cg, label='aux', lw=2)
                if plot_z:
                    plt.plot(interval * np.arange(len(row["Z"])), np.array(row["Z"]) + DIFF,
                             color='gray', label='dyn')
                if xlim_constant:
                    plt.xlim([-1, lifetime_max])
                if plot_axhline:
                    plt.axhline(aux_thresh, color='gray', alpha=0.5, lw=2)
                if ylim is not None:
                    plt.ylim((ylim[0] + DIFF, ylim[1] + DIFF))
                # only the bottom row keeps x ticks, only the first column y ticks
                if not r == R - 1:
                    plt.xticks([])
                if not c == 0:
                    plt.yticks([])
                elif yticks is not None:
                    plt.yticks(yticks, labels=yticklabels)
            i += 1

    if text_labels and row is not None:
        plt.text(len(row["X"]), row["X"][-1] + DIFF, 'Clathrin', color=cr,
                 fontsize=text_label_size, fontweight='bold')
        plt.text(len(row["Y"]), row["Y"][-1] + DIFF, 'Auxilin', color=cg,
                 fontsize=text_label_size, fontweight='bold')
        if plot_z:
            plt.text(len(row["Z"]), row["Z"][-1] + DIFF, 'Dynamin',
                     fontsize=text_label_size, color='gray', fontweight='bold')

    plt.tight_layout()
    return dft
def viz_errs_1d(X_test, preds, preds_proba, Y_test, norms, key='lifetime')
visualize errs based on lifetime
def viz_errs_1d(X_test, preds, preds_proba, Y_test, norms, key='lifetime'):
    '''Plot the predicted probability against one (un-normalized) feature,
    with a separate marker / color for true pos, true neg, false pos, false neg

    Params
    ------
    X_test
        normalized features; X_test[key] is read
    preds, preds_proba, Y_test
        predictions, predicted probabilities and labels
    norms
        norms[key]['std'] and norms[key]['mu'] are used to undo the normalization
    key
        name of the feature shown on the x axis
    '''
    plt.figure(dpi=200)
    correct = preds == Y_test
    feature = X_test[key] * norms[key]['std'] + norms[key]['mu']

    groups = (
        (correct & (preds == 1), 'o', cb, 'true pos'),
        (correct & (preds == 0), 'x', cb, 'true neg'),
        (preds > Y_test, 'o', cr, 'false pos'),
        (preds < Y_test, 'x', cr, 'false neg'),
    )
    for mask, marker, color, label in groups:
        plt.plot(feature[mask], preds_proba[mask], marker, color=color, alpha=0.5, label=label)

    plt.xlabel(key)
    plt.ylabel('predicted probability')
    plt.legend()
def viz_errs_2d(df, idxs_test, preds, Y_test, key1='x_pos', key2='y_pos', X=None, plot_correct=True)
visualize distribution of errs wrt to 2 dimensions
def viz_errs_2d(df, idxs_test, preds, Y_test, key1='x_pos', key2='y_pos', X=None, plot_correct=True):
    '''Scatter the test points over two columns of df, marking false pos /
    false neg with crosses and (optionally) correct predictions with dots

    Params
    ------
    df
        dataframe holding the columns key1 and key2
    idxs_test
        positional indexes of the test points within df
    preds, Y_test
        predictions and labels of the test points
    key1, key2
        columns shown on the x and y axis
    X
        not used by this function
    plot_correct
        whether to draw the correctly predicted points as well
    '''
    xs = df[key1].iloc[idxs_test]
    ys = df[key2].iloc[idxs_test]

    plt.figure(dpi=200)
    ms = 4
    correct = preds == Y_test
    groups = []
    if plot_correct:
        groups.append((correct & (preds == 1), 'o', cb, 'true pos', 0))
        groups.append((correct & (preds == 0), 'o', cr, 'true neg', 0))
    groups.append((preds > Y_test, 'x', cb, 'false pos', 1))
    groups.append((preds < Y_test, 'x', cr, 'false neg', 1))

    for mask, marker, color, label, edge in groups:
        plt.plot(xs[mask], ys[mask], marker, color=color, alpha=0.4,
                 label=label, ms=ms, markeredgewidth=edge)

    plt.legend()
    plt.xlabel(key1)
    plt.ylabel(key2)
    plt.tight_layout()
def viz_errs_outliers_venn(X_test, preds, Y_test, num_feats_reduced=5)
Compare outliers to errors in venn-diagram
def viz_errs_outliers_venn(X_test, preds, Y_test, num_feats_reduced=5):
    '''Compare outliers to errors in venn-diagram

    Four outlier detectors are fitted on the (optionally PCA-reduced) features;
    for each one a venn diagram of the points it flags as outliers and the
    points with a wrong prediction is drawn on a 2 x 2 grid.

    Params
    ------
    X_test
        dataframe; the columns returned by data.get_feature_names are used
    preds, Y_test
        predictions and labels, compared element-wise to find the errors
    num_feats_reduced
        number of PCA components to reduce the features to (None: no reduction)
    '''
    feat_names = data.get_feature_names(X_test)
    X_feat = X_test[feat_names]

    if num_feats_reduced is not None:
        pca = decomposition.PCA(n_components=num_feats_reduced)
        X_reduced = pca.fit_transform(X_feat)
    else:
        X_reduced = X_feat

    detectors = [
        ('isolation forest', IsolationForest(n_estimators=10, warm_start=True)),
        ('local outlier factor', LocalOutlierFactor(novelty=True)),
        ('elliptic envelop', EllipticEnvelope()),
        ('one-class svm', OneClassSVM()),
    ]
    is_err = preds != Y_test

    R, C = 2, 2
    plt.figure(figsize=(6, 5), dpi=200)
    for i, (title, clf) in enumerate(detectors):
        plt.subplot(R, C, i + 1)
        plt.title(title)
        # an estimator has to be fitted before predict can be called on it
        clf.fit(X_reduced)
        is_outlier = clf.predict(X_reduced) == -1  # predict labels outliers as -1
        idxs = np.arange(is_outlier.size)
        venn2([set(idxs[is_outlier]), set(idxs[is_err])], set_labels=['outliers', 'errors'])