Module awave.utils.warmstart

Expand source code
import os

import numpy as np
import torch

opj = os.path.join
import pickle as pkl

device = 'cuda' if torch.cuda.is_available() else 'cpu'

from awave.transform1d import DWT1d
from awave.transform2d import DWT2d


def warm_start(p, out_dir):
    '''Build a wavelet model, initialized from a previously saved run when possible.

    Every ``*.pkl`` file in ``out_dir`` is read as a results dict providing the
    keys ``'lamL1attr'`` and ``'lamL1wave'``; every ``*.pth`` file is loaded as
    a state dict into a freshly constructed model. Results and checkpoints are
    paired by their position in sorted filename order.

    Selection of the starting model:
      * ``p.lamL1attr == 0``: among stored runs with ``lamL1attr == 0``, the one
        with the largest ``lamL1wave``.
      * otherwise: among stored runs with ``lamL1wave == p.lamL1wave``, the one
        with the largest ``lamL1attr``.
    If several stored runs tie, the first one in sorted filename order is used.

    Params
    ------
    p
        parameter object; reads ``wt_type`` ('DWT1d' or 'DWT2d'), ``wave``,
        ``mode``, ``J``, ``init_factor``, ``noise_factor``, ``lamL1attr`` and
        ``lamL1wave``
    out_dir: str
        directory holding the saved ``.pkl`` results and ``.pth`` checkpoints

    Returns
    -------
    model
        a ``DWT1d`` / ``DWT2d`` instance on ``device``; freshly constructed
        when ``out_dir`` contains no saved results or checkpoints

    Raises
    ------
    ValueError
        if ``p.wt_type`` is not supported, if the number of results and
        checkpoints differ, or if no stored run qualifies as a starting point
    '''
    print('\twarm starting...')

    # Fail early with a clear message; previously an unknown type surfaced
    # later as an UnboundLocalError on `model` / `m`.
    if p.wt_type not in ('DWT1d', 'DWT2d'):
        raise ValueError("unsupported wt_type {!r}; expected 'DWT1d' or 'DWT2d'".format(p.wt_type))

    def new_model():
        # Single place that knows how to construct a model from `p`.
        wt_class = DWT1d if p.wt_type == 'DWT1d' else DWT2d
        return wt_class(wave=p.wave, mode=p.mode, J=p.J, init_factor=p.init_factor,
                        noise_factor=p.noise_factor).to(device)

    lamL1attr = []
    lamL1wave = []
    models = []
    for fname in sorted(os.listdir(out_dir)):
        path = os.path.join(out_dir, fname)
        if fname.endswith('.pkl'):
            # NOTE: unpickling can execute arbitrary code -- out_dir must only
            # contain files written by a trusted run.
            with open(path, 'rb') as f:
                result = pkl.load(f)
            lamL1attr.append(result['lamL1attr'])
            lamL1wave.append(result['lamL1wave'])
        elif fname.endswith('.pth'):
            m = new_model()
            # map_location lets a checkpoint saved on GPU load on a CPU-only host.
            m.load_state_dict(torch.load(path, map_location=device))
            models.append(m)

    # Nothing usable on disk (empty directory or only unrelated files).
    if not models and not lamL1attr:
        return new_model()

    # Results and checkpoints are matched by position, so unequal counts mean
    # the pairing cannot be trusted.
    if len(models) != len(lamL1attr):
        raise ValueError('found {} result file(s) but {} checkpoint(s) in {!r}; cannot pair them'.format(
            len(lamL1attr), len(models), out_dir))

    lamL1attr = np.array(lamL1attr)
    lamL1wave = np.array(lamL1wave)
    if p.lamL1attr == 0:
        candidates = lamL1attr == 0
        if candidates.any():
            candidates = candidates & (lamL1wave == np.max(lamL1wave[candidates]))
    else:
        candidates = lamL1wave == p.lamL1wave
        if candidates.any():
            candidates = candidates & (lamL1attr == np.max(lamL1attr[candidates]))

    matches = np.flatnonzero(candidates)
    if matches.size == 0:
        raise ValueError('no stored model in {!r} is a valid warm start for lamL1wave={} and lamL1attr={}'.format(
            out_dir, p.lamL1wave, p.lamL1attr))
    # Duplicate runs with identical lambdas are interchangeable; take the first.
    idx = int(matches[0])

    model = models[idx]
    print('initialized at the model with lamL1wave={:.5f} and lamL1attr={:.5f}'.format(lamL1wave[idx],
                                                                                       lamL1attr[idx]))
    return model

Functions

def warm_start(p, out_dir)

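Load the results saved in `out_dir` and return a model initialized from them; if `out_dir` is empty, return a newly constructed model.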

Expand source code
def warm_start(p, out_dir):
    '''Build a wavelet model, initialized from a previously saved run when possible.

    Every ``*.pkl`` file in ``out_dir`` is read as a results dict providing the
    keys ``'lamL1attr'`` and ``'lamL1wave'``; every ``*.pth`` file is loaded as
    a state dict into a freshly constructed model. Results and checkpoints are
    paired by their position in sorted filename order.

    Selection of the starting model:
      * ``p.lamL1attr == 0``: among stored runs with ``lamL1attr == 0``, the one
        with the largest ``lamL1wave``.
      * otherwise: among stored runs with ``lamL1wave == p.lamL1wave``, the one
        with the largest ``lamL1attr``.
    If several stored runs tie, the first one in sorted filename order is used.

    Params
    ------
    p
        parameter object; reads ``wt_type`` ('DWT1d' or 'DWT2d'), ``wave``,
        ``mode``, ``J``, ``init_factor``, ``noise_factor``, ``lamL1attr`` and
        ``lamL1wave``
    out_dir: str
        directory holding the saved ``.pkl`` results and ``.pth`` checkpoints

    Returns
    -------
    model
        a ``DWT1d`` / ``DWT2d`` instance on ``device``; freshly constructed
        when ``out_dir`` contains no saved results or checkpoints

    Raises
    ------
    ValueError
        if ``p.wt_type`` is not supported, if the number of results and
        checkpoints differ, or if no stored run qualifies as a starting point
    '''
    print('\twarm starting...')

    # Fail early with a clear message; previously an unknown type surfaced
    # later as an UnboundLocalError on `model` / `m`.
    if p.wt_type not in ('DWT1d', 'DWT2d'):
        raise ValueError("unsupported wt_type {!r}; expected 'DWT1d' or 'DWT2d'".format(p.wt_type))

    def new_model():
        # Single place that knows how to construct a model from `p`.
        wt_class = DWT1d if p.wt_type == 'DWT1d' else DWT2d
        return wt_class(wave=p.wave, mode=p.mode, J=p.J, init_factor=p.init_factor,
                        noise_factor=p.noise_factor).to(device)

    lamL1attr = []
    lamL1wave = []
    models = []
    for fname in sorted(os.listdir(out_dir)):
        path = os.path.join(out_dir, fname)
        if fname.endswith('.pkl'):
            # NOTE: unpickling can execute arbitrary code -- out_dir must only
            # contain files written by a trusted run.
            with open(path, 'rb') as f:
                result = pkl.load(f)
            lamL1attr.append(result['lamL1attr'])
            lamL1wave.append(result['lamL1wave'])
        elif fname.endswith('.pth'):
            m = new_model()
            # map_location lets a checkpoint saved on GPU load on a CPU-only host.
            m.load_state_dict(torch.load(path, map_location=device))
            models.append(m)

    # Nothing usable on disk (empty directory or only unrelated files).
    if not models and not lamL1attr:
        return new_model()

    # Results and checkpoints are matched by position, so unequal counts mean
    # the pairing cannot be trusted.
    if len(models) != len(lamL1attr):
        raise ValueError('found {} result file(s) but {} checkpoint(s) in {!r}; cannot pair them'.format(
            len(lamL1attr), len(models), out_dir))

    lamL1attr = np.array(lamL1attr)
    lamL1wave = np.array(lamL1wave)
    if p.lamL1attr == 0:
        candidates = lamL1attr == 0
        if candidates.any():
            candidates = candidates & (lamL1wave == np.max(lamL1wave[candidates]))
    else:
        candidates = lamL1wave == p.lamL1wave
        if candidates.any():
            candidates = candidates & (lamL1attr == np.max(lamL1attr[candidates]))

    matches = np.flatnonzero(candidates)
    if matches.size == 0:
        raise ValueError('no stored model in {!r} is a valid warm start for lamL1wave={} and lamL1attr={}'.format(
            out_dir, p.lamL1wave, p.lamL1attr))
    # Duplicate runs with identical lambdas are interchangeable; take the first.
    idx = int(matches[0])

    model = models[idx]
    print('initialized at the model with lamL1wave={:.5f} and lamL1attr={:.5f}'.format(lamL1wave[idx],
                                                                                       lamL1attr[idx]))
    return model