Module src.data
import os
import sys
from copy import deepcopy
from os.path import join as oj
import mat4py
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
try:
    from skimage.external.tifffile import imread
except ImportError:
    from skimage.io import imread
pd.options.mode.chained_assignment = None  # default='warn'; caution: this silences pandas' SettingWithCopyWarning
import pickle as pkl
from viz import *
import math
import config
import features
import outcomes
import load_tracking
from tqdm import tqdm
import train_reg
def load_dfs_for_lstm(dsets=['clath_aux+gak_new'],
                      splits=['test'],
                      meta=None,
                      length=40,
                      normalize=True,
                      lifetime_threshold=15,
                      filter_short=True,
                      padding='end'):
    '''Load dataframes preprocessed so that they are ready for LSTM models.
    '''
    dfs = {}
    for dset in tqdm(dsets):
        for split in splits:
            df = get_data(dset=dset)
            df = df[~df.hotspots]
            if filter_short and lifetime_threshold == 15:
                # df = df[~(df.short | df.long)]
                # df = df[df.valid]
                df = df[df.lifetime > 15]  # only keep hard tracks
            elif lifetime_threshold != 15:
                df = df[df.lifetime > lifetime_threshold]  # only keep hard tracks
            else:
                df = df[~df.hotspots]
            df = df[df.cell_num.isin(config.DSETS[dset][split])]  # select train/test etc.
            feat_names = ['X_same_length_normalized'] + select_final_feats(get_feature_names(df))

            # downsample tracks
            df['X_same_length'] = [features.downsample(df.iloc[i]['X'], length, padding=padding)
                                   for i in range(len(df))]  # downsampling
            df['X_same_length_extended'] = [features.downsample(df.iloc[i]['X_extended'], length, padding=padding)
                                            for i in range(len(df))]  # downsampling

            # normalize tracks
            df = features.normalize_track(df, track='X_same_length', by_time_point=False)
            df = features.normalize_track(df, track='X_same_length_extended', by_time_point=False)

            # regression response
            df = outcomes.add_sig_mean(df, resp_tracks=['Y'])
            df = outcomes.add_aux_dyn_outcome(df)
            df['X_max_orig'] = deepcopy(df['X_max'].values)

            # remove extraneous feats
            # df = df[feat_names + meta]
            # df = df.dropna()

            # normalize features
            if normalize:
                for feat in feat_names:
                    if 'X_same_length' not in feat:
                        df = features.normalize_feature(df, feat)
            dfs[(dset, split)] = deepcopy(df)
    return dfs, feat_names
def get_data(dset='clath_aux+gak_a7d2', use_processed=True, save_processed=True,
             processed_file=oj(config.DIR_PROCESSED, 'df.pkl'),
             metadata_file=oj(config.DIR_PROCESSED, 'metadata.pkl'),
             use_processed_dicts=True,
             compute_dictionary_learning=False,
             outcome_def='y_consec_thresh',
             pixel_data: bool = False,
             video_data: bool = False,
             acc_thresh=0.95,
             previous_meta_file: str = None):
    '''
    Params
    ------
    use_processed: bool, optional
        determines whether to load df from cached pkl
    save_processed: bool, optional
        if not using processed, determines whether to save the df
    use_processed_dicts: bool, optional
        if False, recalculate the dictionary features
    previous_meta_file: str, optional
        filename for metadata.pkl file saved by previous preprocessing;
        the thresholds for lifetime are taken from this file
    '''
    # get things based on dset
    DSET = config.DSETS[dset]
    LABELS = config.LABELS[dset]
    processed_file = processed_file[:-4] + '_' + dset + '.pkl'
    metadata_file = metadata_file[:-4] + '_' + dset + '.pkl'

    if use_processed and os.path.exists(processed_file):
        return pd.read_pickle(processed_file)
    else:
        print('loading + preprocessing data...')
        metadata = {}

        # load tracks
        print('\tloading tracks...')
        df = load_tracking.get_tracks(data_dir=DSET['data_dir'],
                                      split=DSET,
                                      pixel_data=pixel_data,
                                      video_data=video_data,
                                      dset=dset)  # note: different Xs can be different shapes
        # df = df.fillna(df.median())  # this only does anything for the dynamin tracks, where x_pos is sometimes NaN
        # print('num nans', df.isna().sum())
        df['pid'] = np.arange(df.shape[0])  # assign each track a unique id
        df['valid'] = True  # all tracks start as valid

        # set testing tracks to not valid
        if DSET['test'] is not None:
            df['valid'][df.cell_num.isin(DSET['test'])] = False
        metadata['num_tracks'] = df.valid.sum()
        # print('training', df.valid.sum())

        # preprocess data
        print('\tpreprocessing data...')
        df = remove_invalid_tracks(df)  # use catIdx
        # print('valid', df.valid.sum())
        df = features.add_basic_features(df)
        df = outcomes.add_outcomes(df, LABELS=LABELS)
        metadata['num_tracks_valid'] = df.valid.sum()
        metadata['num_aux_pos_valid'] = df[df.valid][outcome_def].sum()
        metadata['num_hotspots_valid'] = df[df.valid]['hotspots'].sum()
        df['valid'][df.hotspots] = False
        df, meta_lifetime = process_tracks_by_lifetime(df, outcome_def=outcome_def,
                                                       plot=False, acc_thresh=acc_thresh,
                                                       previous_meta_file=previous_meta_file)
        df['valid'][df.short] = False
        df['valid'][df.long] = False
        metadata.update(meta_lifetime)
        metadata['num_tracks_hard'] = df['valid'].sum()
        metadata['num_aux_pos_hard'] = int(df[df.valid == 1][outcome_def].sum())

        # add features
        print('\tadding features...')
        df = features.add_dasc_features(df)
        if compute_dictionary_learning:
            df = features.add_dict_features(df, use_processed=use_processed_dicts)
        # df = features.add_smoothed_tracks(df)
        # df = features.add_pcs(df)
        # df = features.add_trend_filtering(df)
        # df = features.add_binary_features(df, outcome_def=outcome_def)
        if save_processed:
            print('\tsaving...')
            pkl.dump(metadata, open(metadata_file, 'wb'))
            df.to_pickle(processed_file)
    return df
def remove_invalid_tracks(df, keep=[1, 2]):
    '''Remove certain types of tracks based on catIdx.
    Only keep catIdx = 1 and 2:
        1-4 (non-complex trajectory - no merges and splits)
            1 - valid
            2 - signal occasionally drops out
            3 - cut - starts / ends
            4 - multiple - at the same place (continues throughout)
        5-8 (there is merging or splitting)
    '''
    return df[df.catIdx.isin(keep)]
def process_tracks_by_lifetime(df: pd.DataFrame, outcome_def: str,
                               plot=False, acc_thresh=0.95, previous_meta_file=None):
    '''Calculate the accuracy obtainable by simply predicting the majority class
    as a function of lifetime, and flag tracks outside the proper lifetime range
    (only looks at training cells).
    '''
    vals = df[df.valid == 1][['lifetime', outcome_def]]
    R, C = 1, 3
    lifetimes = np.unique(vals['lifetime'])

    # cumulative accuracy for different thresholds
    accs_cum_lower = np.array([1 - np.mean(vals[outcome_def][vals['lifetime'] <= l]) for l in lifetimes])
    accs_cum_higher = np.array([np.mean(vals[outcome_def][vals['lifetime'] >= l]) for l in lifetimes]).flatten()

    if previous_meta_file is None:
        try:
            idx_thresh = np.nonzero(accs_cum_lower >= acc_thresh)[0][-1]  # last index where cumulative accuracy >= acc_thresh
            thresh_lower = lifetimes[idx_thresh]
        except IndexError:
            idx_thresh = 0
            thresh_lower = lifetimes[idx_thresh] - 1
        try:
            idx_thresh_2 = np.nonzero(accs_cum_higher >= acc_thresh)[0][0]
            thresh_higher = lifetimes[idx_thresh_2]
        except IndexError:
            idx_thresh_2 = lifetimes.size - 1
            thresh_higher = lifetimes[idx_thresh_2] + 1
    else:
        previous_meta = pkl.load(open(previous_meta_file, 'rb'))
        thresh_lower = previous_meta['thresh_short']
        thresh_higher = previous_meta['thresh_long']

    # flag tracks with lifetimes outside the proper range
    df['short'] = df['lifetime'] <= thresh_lower
    df['long'] = df['lifetime'] >= thresh_higher
    n = vals.shape[0]
    n_short = np.sum(df['short'])
    n_long = np.sum(df['long'])
    acc_short = 1 - np.mean(vals[outcome_def][vals['lifetime'] <= thresh_lower])
    acc_long = np.mean(vals[outcome_def][vals['lifetime'] >= thresh_higher])
    metadata = {'num_short': n_short, 'num_long': n_long, 'acc_short': acc_short,
                'acc_long': acc_long, 'thresh_short': thresh_lower, 'thresh_long': thresh_higher}

    if plot:
        plt.figure(figsize=(12, 4), dpi=200)
        plt.subplot(R, C, 1)
        outcome = df[outcome_def]
        plt.hist(df['lifetime'][outcome == 1], label='aux+', alpha=1, color=cb, bins=25)
        plt.hist(df['lifetime'][outcome == 0], label='aux-', alpha=0.7, color=cr, bins=25)
        plt.xlabel('lifetime')
        plt.ylabel('count')
        plt.legend()

        plt.subplot(R, C, 2)
        plt.plot(lifetimes, accs_cum_lower, color=cr)
        # plt.axvline(thresh_lower)
        plt.axvspan(0, thresh_lower, alpha=0.2, color=cr)
        plt.ylabel('fraction of negative events')
        plt.xlabel(f'lifetime <= value\nshaded includes {n_short / n * 100:0.0f}% of pts')

        plt.subplot(R, C, 3)
        plt.plot(lifetimes, accs_cum_higher, color=cb)
        plt.axvspan(thresh_higher, max(lifetimes), alpha=0.2, color=cb)
        plt.ylabel('fraction of positive events')
        plt.xlabel(f'lifetime >= value\nshaded includes {n_long / n * 100:0.0f}% of pts')
        plt.tight_layout()
    return df, metadata
def get_feature_names(df):
    '''Returns the names of the features (all of which are scalar).
    Removes metadata + time-series columns + outcomes.
    '''
    ks = list(df.keys())
    feat_names = [
        k for k in ks
        if not k.startswith('y')
        and not k.startswith('Y')
        and not k.startswith('Z')
        and not k.startswith('pixel')
        # and not k.startswith('pc_')
        and k not in ['catIdx', 'cell_num', 'pid', 'valid',  # metadata
                      'X', 'X_pvals', 'x_pos', 'X_starts', 'X_ends', 'X_extended',  # curves
                      'short', 'long', 'hotspots', 'sig_idxs',  # should be weeded out
                      'X_max_around_Y_peak', 'X_max_after_Y_peak',  # redundant with X_max / fall
                      'X_max_diff', 'X_peak_idx',  # unlikely to be useful
                      't', 'x_pos_seq', 'y_pos_seq',  # curves
                      'X_smooth_spl', 'X_smooth_spl_dx', 'X_smooth_spl_d2x',  # curves
                      'X_quantiles',
                      ]
    ]
    return feat_names
def select_final_feats(feat_names, binarize=False):
    '''Filter feature names down to the final scalar feature set, dropping
    sparse-coding, NMF, PCA, extended-track, video-derived, and other excluded columns.
    If binarize is True, keep only the 'binary' features; otherwise drop them.
    '''
    feat_names = [x for x in feat_names
                  if not x.startswith('sc_')  # sparse coding
                  and not x.startswith('nmf_')  # nmf
                  and x not in ['center_max', 'left_max', 'right_max', 'up_max', 'down_max',
                                'X_max_around_Y_peak', 'X_max_after_Y_peak', 'X_max_diff_after_Y_peak']
                  and not x.startswith('pc_')
                  and 'extended' not in x
                  and not x == 'slope_end'
                  and '_tf_smooth' not in x
                  and 'local' not in x
                  and 'last' not in x
                  and 'video' not in x
                  and not x == 'X_quantiles'
                  # and 'X_peak' not in x
                  # and 'slope' not in x
                  # and x not in ['fall_final', 'fall_slope', 'fall_imp', 'fall']
                  ]
    if binarize:
        feat_names = [x for x in feat_names if 'binary' in x]
    else:
        feat_names = [x for x in feat_names if 'binary' not in x]
    return feat_names
if __name__ == '__main__':
    # process original data (and save out lifetime thresholds)
    dset_orig = 'clath_aux+gak_a7d2'
    df = get_data(dset=dset_orig)  # save out orig

    # process new data (using lifetime thresholds from original data)
    outcome_def = 'y_consec_sig'
    # for dset in ['clath_aux_dynamin']:
    for dset in config.DSETS.keys():
        df = get_data(dset=dset, previous_meta_file=None)
        # df = get_data(dset=dset, previous_meta_file=f'{config.DIR_PROCESSED}/metadata_{dset_orig}.pkl')
        print(dset, 'num cells', len(df['cell_num'].unique()), 'num tracks', df.shape[0], 'num aux+',
              df[outcome_def].sum(), 'aux+ fraction', (df[outcome_def].sum() / df.shape[0]).round(3),
              'valid', df.valid.sum(), 'valid aux+', df[df.valid][outcome_def].sum(), 'valid aux+ fraction',
              (df[df.valid][outcome_def].sum() / df.valid.sum()).round(3))
Functions
def get_data(dset='clath_aux+gak_a7d2', use_processed=True, save_processed=True, processed_file='/accounts/projects/vision/chandan/auxilin-prediction/src/../data/processed/df.pkl', metadata_file='/accounts/projects/vision/chandan/auxilin-prediction/src/../data/processed/metadata.pkl', use_processed_dicts=True, compute_dictionary_learning=False, outcome_def='y_consec_thresh', pixel_data=False, video_data=False, acc_thresh=0.95, previous_meta_file=None)
Params
------
use_processed : bool, optional
    determines whether to load df from cached pkl
save_processed : bool, optional
    if not using processed, determines whether to save the df
use_processed_dicts : bool, optional
    if False, recalculate the dictionary features
previous_meta_file : str, optional
    filename for metadata.pkl file saved by previous preprocessing;
    the thresholds for lifetime are taken from this file
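A minimal usage sketch (assuming the module is importable, e.g. as data from within src, and that config.DSETS contains the requested dataset; the outcome column name follows the default outcome_def):

    import data

    # loads the cached pickle if present, otherwise preprocesses the raw tracks and caches the result
    df = data.get_data(dset='clath_aux+gak_a7d2', use_processed=True)
    print('num tracks', df.shape[0], 'num aux+', df['y_consec_thresh'].sum())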
def get_feature_names(df)
Returns the names of the features (all of which are scalar). Removes metadata + time-series columns + outcomes.
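For example, given a processed dataframe from get_data, the scalar feature columns can be listed like this (a sketch; the exact names depend on the dataset and on which feature-adding steps ran):

    df = get_data(dset='clath_aux+gak_a7d2')
    feat_names = get_feature_names(df)
    print(len(feat_names), feat_names[:5])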
def load_dfs_for_lstm(dsets=['clath_aux+gak_new'], splits=['test'], meta=None, length=40, normalize=True, lifetime_threshold=15, filter_short=True, padding='end')
Load dataframes preprocessed so that they are ready for LSTM models.
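A sketch of how the returned dictionary might be consumed (dataset/split names mirror the defaults; 'X_same_length_normalized' is assumed to be the column produced by features.normalize_track, as the feature list in the source suggests):

    dfs, feat_names = load_dfs_for_lstm(dsets=['clath_aux+gak_new'], splits=['test'], length=40)
    df_test = dfs[('clath_aux+gak_new', 'test')]
    X = np.vstack(df_test['X_same_length_normalized'].values)  # one fixed-length track per row
    print(X.shape)  # (n_tracks, 40)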
def process_tracks_by_lifetime(df, outcome_def, plot=False, acc_thresh=0.95, previous_meta_file=None)
Calculate the accuracy obtainable by simply predicting the majority class as a function of lifetime, and flag tracks outside the proper lifetime range (only looks at training cells).
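A sketch of the call as it appears inside get_data, plus one way the resulting short/long flags might be used afterwards (the follow-up filtering mirrors a commented-out line in load_dfs_for_lstm):

    df, meta_lifetime = process_tracks_by_lifetime(df, outcome_def='y_consec_thresh',
                                                   plot=False, acc_thresh=0.95,
                                                   previous_meta_file=None)
    print(meta_lifetime['thresh_short'], meta_lifetime['thresh_long'])
    df_hard = df[~(df.short | df.long)]  # keep only the "hard" tracks with intermediate lifetimes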
def remove_invalid_tracks(df, keep=[1, 2])
Remove certain types of tracks based on catIdx; only keep catIdx = 1 and 2.
1-4 (non-complex trajectory - no merges and splits): 1 - valid; 2 - signal occasionally drops out; 3 - cut - starts / ends; 4 - multiple - at the same place (continues throughout).
5-8: there is merging or splitting.
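A tiny illustrative sketch on a hypothetical dataframe (catIdx values follow the legend above):

    df_toy = pd.DataFrame({'catIdx': [1, 2, 3, 5], 'lifetime': [10, 20, 30, 40]})
    print(remove_invalid_tracks(df_toy).catIdx.tolist())  # [1, 2]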
def select_final_feats(feat_names, binarize=False)
Filter feature names down to the final scalar feature set, dropping sparse-coding (sc_), NMF (nmf_), PCA (pc_), extended-track, video-derived, and several other excluded columns; if binarize is True, keep only the 'binary' features, otherwise drop them.
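An illustrative sketch with mostly hypothetical feature names ('X_max' and 'lifetime' appear elsewhere in this module; the rest are made up to exercise the filters):

    feats = ['X_max', 'lifetime', 'sc_3', 'nmf_1', 'pc_2', 'X_max_extended', 'binary_fall']
    print(select_final_feats(feats))                 # ['X_max', 'lifetime']
    print(select_final_feats(feats, binarize=True))  # ['binary_fall']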