Module awave.utils.misc

Expand source code
import numpy as np
import pywt
import torch
import torch.nn.functional as F


def reflect(x, minx, maxx):
    """Reflect the values in *x* about the scalar bounds *minx* and *maxx*.

    A long linearly increasing series in *x* is folded into a triangular
    waveform that ramps up and down between *minx* and *maxx*.  When *x*
    holds integers and *minx*/*maxx* are (integers + 0.5), the ramps repeat
    their extreme samples.

    The result is cast back to the dtype of ``np.asanyarray(x)``.

    .. codeauthor:: Rich Wareham <rjw57@cantab.net>, Aug 2013
    .. codeauthor:: Nick Kingsbury, Cambridge University, January 1999.

    """
    arr = np.asanyarray(x)
    span = maxx - minx
    period = 2 * span
    # np.fmod keeps the sign of the dividend, so move negatives into [0, period).
    phase = np.fmod(arr - minx, period)
    phase = np.where(phase < 0, phase + period, phase)
    # The second half of every period is mirrored back down.
    folded = np.where(phase >= span, period - phase, phase)
    return np.array(folded + minx, dtype=arr.dtype)


def tuple_dim(x):
    """Return the total number of elements across a sequence of tensors.

    Params
    ------
    x: sequence of torch.Tensor (or any objects with a ``.shape``)

    Returns
    -------
    int
        Sum over ``x`` of the product of each item's shape.  A 0-d item
        counts as one element; an empty sequence gives 0.
    """
    # np.prod over an empty shape is 1, and dtype=np.int64 keeps the count
    # integral (building torch.tensor(()) for a 0-d input yielded float 1.0).
    return sum(int(np.prod(t.shape, dtype=np.int64)) for t in x)


def tuple_to_tensor(x):
    """Flatten a sequence of batched tensors into one 2-D tensor.

    Each ``x[j]`` is reshaped to ``(batch_size, -1)``, with the batch size
    taken from ``x[0].size(0)``.  The results are concatenated along dim 1.

    Params
    ------
    x: sequence of torch.Tensor
        tensors sharing the same leading (batch) dimension

    Returns
    -------
    y: torch.Tensor
        tensor of shape ``(batch_size, total flattened size)``
    list_of_size: list of int
        cumulative column offsets ``[0, n_0, n_0 + n_1, ...]``; item ``j``
        occupies ``y[:, list_of_size[j]:list_of_size[j + 1]]``
    """
    batch_size = x[0].size(0)
    flat = [t.reshape(batch_size, -1) for t in x]
    list_of_size = [0]
    for a in flat:
        list_of_size.append(list_of_size[-1] + a.shape[1])
    # One concatenation avoids re-copying a growing tensor on every step.
    # It also keeps the inputs' dtype; concatenating onto an empty float32
    # seed tensor could promote or convert it.
    y = torch.cat(flat, dim=1)
    return (y, list_of_size)


def tensor_to_tuple(y, d, list_of_size):
    """Split a flattened 2-D tensor back into a tuple of shaped tensors.

    This is the inverse of ``tuple_to_tensor``.

    Params
    ------
    y: torch.Tensor
        2-D tensor whose column ranges hold the flattened items side by side
    d: sequence
        templates; item ``j`` is reshaped to ``d[j].shape``
    list_of_size: list of int
        cumulative column offsets; item ``j`` is taken from
        ``y[:, list_of_size[j]:list_of_size[j + 1]]``
    """
    bounds = zip(list_of_size[:-1], list_of_size[1:])
    return tuple(
        y[:, start:stop].reshape(d[j].shape)
        for j, (start, stop) in enumerate(bounds)
    )


def init_filter(x, init_factor, noise_factor, const_factor):
    '''Scale a tensor and add Gaussian noise and a constant offset to it.

    Computes ``init_factor * x + noise_factor * noise + const_factor``
    elementwise.  ``noise`` is drawn from a standard normal distribution,
    independently for each element.

    Params
    ------
    x: torch.tensor
        input (e.g. an initial filter)
    init_factor: float
        multiplicative factor applied to the original filter
    noise_factor: float
        amount of noise added to original filter
    const_factor: float
        amount of constant added to original filter

    Returns
    -------
    torch.tensor
        tensor of the same shape and device as x (floating-point inputs keep
        their dtype)
    '''
    # Draw the noise on x's device: torch.randn(shape) always allocated on
    # the CPU in the default dtype, which breaks for CUDA inputs.
    noise_dtype = x.dtype if x.is_floating_point() else None
    noise = torch.randn(x.shape, dtype=noise_dtype, device=x.device)
    return init_factor * x + noise_factor * noise + const_factor


def pad_within(x, stride=2, start_row=0, start_col=0):
    '''Upsample a 2-D array by inserting zeros between its samples.

    Each input sample ``x[..., i, j]`` is placed at output position
    ``(stride * i + start_row, stride * j + start_col)``, and every other
    output entry is zero.  This is implemented as a grouped transposed
    convolution with a ``stride x stride`` kernel holding a single one.

    Params
    ------
    x: torch.tensor
        2-D tensor, or a 4-D tensor whose dim 1 is treated as channels
    stride: int
        upsampling factor along both spatial axes
    start_row, start_col: int
        offset of the input samples inside each ``stride x stride`` cell;
        must lie in ``[0, stride)``

    Returns
    -------
    torch.tensor
        upsampled tensor with singleton dimensions squeezed out

    Raises
    ------
    ValueError
        if ``start_row`` or ``start_col`` is outside ``[0, stride)``
    '''
    if not (0 <= start_row < stride and 0 <= start_col < stride):
        raise ValueError(
            f'start_row and start_col must lie in [0, {stride}), '
            f'got ({start_row}, {start_col})')
    w = x.new_zeros(stride, stride)
    # Index the kernel directly.  The former if/elif chain only knew the four
    # offsets of a 2x2 cell and silently fell back to (1, 1) for any other.
    w[start_row, start_col] = 1
    if len(x.shape) == 2:
        x = x[None, None]
    return F.conv_transpose2d(x, w.expand(x.size(1), 1, stride, stride),
                              stride=stride, groups=x.size(1)).squeeze()


def low_to_high(x):
    """Derive a highpass filter from a lowpass filter by alternating flip.

    The taps are time-reversed and every odd-indexed output tap is negated.
    For the expected (1, 1, n) input this gives
    ``y[0, 0, k] = (-1) ** k * x[0, 0, n - 1 - k]``.

    Params
    ------
    x: torch.tensor
        lowpass filter of shape (1, 1, n), where n is the filter length
    """
    length = x.size(2)
    # +1, -1, +1, ... as an integer tensor (same values as (-1) ** arange)
    signs = 1 - 2 * (torch.arange(length, device=x.device) % 2)
    reversed_taps = torch.flip(x, dims=(0, 2))
    return reversed_taps * signs


def get_wavefun(w_transform, level=5):
    '''Get wavelet function from wavelet object.

    Builds a ``pywt.Wavelet`` from the transform's lowpass filter ``h0`` and
    the highpass filter derived from it by ``low_to_high``.  It then
    evaluates the scaling and wavelet functions with ``Wavelet.wavefun``.

    Params
    ------
    w_transform: obj
        DWT1d or DWT2d object exposing a lowpass filter tensor ``h0``
        (presumably of shape (1, 1, n), as ``low_to_high`` requires)
    level: int
        refinement level passed to ``pywt.Wavelet.wavefun``

    Returns
    -------
    phi:
        scaling function (the decomposition one when pywt returns both)
    psi:
        wavelet function (the decomposition one when pywt returns both)
    x:
        grid on which phi and psi are sampled
    '''
    h0 = w_transform.h0
    h1 = low_to_high(h0)

    # The decomposition filters (first two entries of the filter bank) are the
    # time-reversed reconstruction filters (last two entries).
    h0 = list(h0.squeeze().detach().cpu().numpy())[::-1]
    h1 = list(h1.squeeze().detach().cpu().numpy())[::-1]

    my_filter_bank = (h0, h1, h0[::-1], h1[::-1])
    my_wavelet = pywt.Wavelet('My Wavelet', filter_bank=my_filter_bank)
    wave = my_wavelet.wavefun(level=level)
    # wavefun returns (phi, psi, x) for orthogonal wavelets and
    # (phi_d, psi_d, phi_r, psi_r, x) otherwise.  The grid is always the last
    # item, so take it from the end instead of hard-coding index 4.
    (phi, psi, x) = wave[0], wave[1], wave[-1]

    return phi, psi, x


def dist(wt1, wt2):
    """Return the distance between the wavelet functions of two transforms.

    The wavelet functions psi1 and psi2 (from ``get_wavefun``) are
    zero-padded at the end to a common length.  The result is the smallest
    Euclidean norm of ``roll(c, s) - psi2`` over every circular shift ``s``
    of both ``c = psi1`` and its time reversal.
    """
    _, psi1, _ = get_wavefun(wt1)
    _, psi2, _ = get_wavefun(wt2)

    # Zero-pad the shorter function at its end so both have equal length.
    len1, len2 = len(psi1), len(psi2)
    if len2 < len1:
        psi2 = np.pad(psi2, (0, len1 - len2), mode='constant', constant_values=(0,))
    elif len1 < len2:
        psi1 = np.pad(psi1, (0, len2 - len1), mode='constant', constant_values=(0,))

    # All circular shifts of psi1 first, then all shifts of its reversal.
    candidates = (psi1, psi1[::-1])
    return min(
        np.linalg.norm(np.roll(cand, shift) - psi2).item()
        for cand in candidates
        for shift in range(len(psi1))
    )


def get_1dfilts(w_transform):
    '''Get 1d filters from DWT1d object.

    Params
    ------
    w_transform: obj
        DWT1d object (``w_transform.wt_type == 'DWT1d'``) exposing a
        lowpass filter tensor ``h0``

    Returns
    -------
    (h0, h1): tuple of torch.tensor
        squeezed, detached CPU versions of the lowpass filter and of the
        highpass filter derived from it by ``low_to_high``

    Raises
    ------
    ValueError
        if ``w_transform.wt_type`` is not ``'DWT1d'``
    '''
    if w_transform.wt_type != 'DWT1d':
        raise ValueError('no such type of wavelet transform is supported')

    def to_cpu_1d(filt):
        # squeeze, detach and move to CPU; the zero-width constant pad leaves
        # the values unchanged
        return F.pad(filt.squeeze().detach().cpu(), pad=(0, 0), mode='constant', value=0)

    lowpass = to_cpu_1d(w_transform.h0)
    highpass = to_cpu_1d(low_to_high(w_transform.h0))
    return (lowpass, highpass)


def get_2dfilts(w_transform):
    '''Get 2d filters from DWT2d object.

    Two transform types are supported, selected by ``w_transform.wt_type``:

    * ``'DTCWT2d'`` (presumably the dual-tree complex wavelet transform):
      builds oriented 2-D filters from the 1-D filters stored on
      ``w_transform.xfm``.  ``h0o``/``h1o`` are used for the first level and
      ``h0a``/``h1a``/``h0b``/``h1b`` for the second level.
    * ``'DWT2d'``: builds the four separable 2-D filters from the lowpass
      filter ``w_transform.h0`` and the highpass filter derived from it by
      ``low_to_high``.

    Params
    ------
    w_transform: obj
        DWT2d or DTCWT2d object

    Returns
    -------
    For ``'DTCWT2d'``:
        ``((fl_filt_reals, fl_filt_imags), (sl_filt_reals, sl_filt_imags))``.
        Each entry is a list of six 2-D CPU tensors ordered
        15, 45, 75, 105, 135, 165; ``fl`` is the first level and ``sl`` the
        second level.  NOTE(review): the numbers look like subband
        orientations in degrees; confirm against the transform's convention.
    For ``'DWT2d'``:
        ``((h0, h1), (filt_ll, filt_lh, filt_hl, filt_hh))``, i.e. the 1-D
        filters and their 2-D outer products.

    Raises
    ------
    ValueError
        for any other ``wt_type``.

    Note: ``a.unsqueeze(0) * b.unsqueeze(1)`` is an outer product with
    element ``[i, j] = b[i] * a[j]``, so the first factor varies along the
    columns and the second along the rows.
    '''
    if w_transform.wt_type == 'DTCWT2d':
        h0o = w_transform.xfm.h0o.data
        h1o = w_transform.xfm.h1o.data
        h0a = w_transform.xfm.h0a.data
        h1a = w_transform.xfm.h1a.data
        h0b = w_transform.xfm.h0b.data
        h1b = w_transform.xfm.h1b.data

        # compute first level wavelet filters
        # '_r' appends one zero and '_i' prepends one, so each '_i' filter is
        # the matching '_r' filter delayed by one sample (same length).
        h0_r = F.pad(h0o.squeeze().detach().cpu(), pad=(0, 1), mode='constant', value=0)
        h0_i = F.pad(h0o.squeeze().detach().cpu(), pad=(1, 0), mode='constant', value=0)
        h1_r = F.pad(h1o.squeeze().detach().cpu(), pad=(0, 1), mode='constant', value=0)
        h1_i = F.pad(h1o.squeeze().detach().cpu(), pad=(1, 0), mode='constant', value=0)

        # LH group: h0 along the columns, h1 along the rows.  The sum and
        # difference of the two products give the filt_15*/filt_165* pair.
        lh_filt_r1 = h0_r.unsqueeze(0) * h1_r.unsqueeze(1) / np.sqrt(2)
        lh_filt_r2 = h0_i.unsqueeze(0) * h1_i.unsqueeze(1) / np.sqrt(2)
        lh_filt_i1 = h0_i.unsqueeze(0) * h1_r.unsqueeze(1) / np.sqrt(2)
        lh_filt_i2 = h0_r.unsqueeze(0) * h1_i.unsqueeze(1) / np.sqrt(2)
        filt_15r = lh_filt_r1 - lh_filt_r2
        filt_165r = lh_filt_r1 + lh_filt_r2
        filt_15i = lh_filt_i1 + lh_filt_i2
        filt_165i = lh_filt_i1 - lh_filt_i2

        # HH group: h1 along both axes (filt_45*/filt_135* pair).
        hh_filt_r1 = h1_r.unsqueeze(0) * h1_r.unsqueeze(1) / np.sqrt(2)
        hh_filt_r2 = h1_i.unsqueeze(0) * h1_i.unsqueeze(1) / np.sqrt(2)
        hh_filt_i1 = h1_i.unsqueeze(0) * h1_r.unsqueeze(1) / np.sqrt(2)
        hh_filt_i2 = h1_r.unsqueeze(0) * h1_i.unsqueeze(1) / np.sqrt(2)
        filt_45r = hh_filt_r1 - hh_filt_r2
        filt_135r = hh_filt_r1 + hh_filt_r2
        filt_45i = hh_filt_i1 + hh_filt_i2
        filt_135i = hh_filt_i1 - hh_filt_i2

        # HL group: h1 along the columns, h0 along the rows
        # (filt_75*/filt_105* pair).
        hl_filt_r1 = h1_r.unsqueeze(0) * h0_r.unsqueeze(1) / np.sqrt(2)
        hl_filt_r2 = h1_i.unsqueeze(0) * h0_i.unsqueeze(1) / np.sqrt(2)
        hl_filt_i1 = h1_i.unsqueeze(0) * h0_r.unsqueeze(1) / np.sqrt(2)
        hl_filt_i2 = h1_r.unsqueeze(0) * h0_i.unsqueeze(1) / np.sqrt(2)
        filt_75r = hl_filt_r1 - hl_filt_r2
        filt_105r = hl_filt_r1 + hl_filt_r2
        filt_75i = hl_filt_i1 + hl_filt_i2
        filt_105i = hl_filt_i1 - hl_filt_i2

        fl_filt_reals = [filt_15r, filt_45r, filt_75r, filt_105r, filt_135r, filt_165r]
        fl_filt_imags = [filt_15i, filt_45i, filt_75i, filt_105i, filt_135i, filt_165i]

        # compute second level wavelet filters
        h0_a = h0a.squeeze().detach().cpu()
        h0_b = h0b.squeeze().detach().cpu()
        h1_a = h1a.squeeze().detach().cpu()
        h1_b = h1b.squeeze().detach().cpu()

        # Each outer product of 'a'/'b' filters is upsampled by pad_within
        # (stride 2), which places its taps on one phase (start_row,
        # start_col) of every 2x2 cell.  The two products combined in each
        # sum/difference below sit on different phases.  Orientation
        # groups follow the same row/column layout as the first level.
        lh_filt_r1 = pad_within(h0_b.unsqueeze(0) * h1_a.unsqueeze(1), start_row=1, start_col=0) / np.sqrt(2)
        lh_filt_r2 = pad_within(h0_a.unsqueeze(0) * h1_b.unsqueeze(1), start_row=0, start_col=1) / np.sqrt(2)
        lh_filt_i1 = pad_within(h0_a.unsqueeze(0) * h1_a.unsqueeze(1), start_row=1, start_col=1) / np.sqrt(2)
        lh_filt_i2 = pad_within(h0_b.unsqueeze(0) * h1_b.unsqueeze(1), start_row=0, start_col=0) / np.sqrt(2)
        filt_15r = lh_filt_r1 - lh_filt_r2
        filt_165r = lh_filt_r1 + lh_filt_r2
        filt_15i = lh_filt_i1 + lh_filt_i2
        filt_165i = lh_filt_i1 - lh_filt_i2

        hh_filt_r1 = pad_within(h1_a.unsqueeze(0) * h1_a.unsqueeze(1), start_row=1, start_col=1) / np.sqrt(2)
        hh_filt_r2 = pad_within(h1_b.unsqueeze(0) * h1_b.unsqueeze(1), start_row=0, start_col=0) / np.sqrt(2)
        hh_filt_i1 = pad_within(h1_b.unsqueeze(0) * h1_a.unsqueeze(1), start_row=1, start_col=0) / np.sqrt(2)
        hh_filt_i2 = pad_within(h1_a.unsqueeze(0) * h1_b.unsqueeze(1), start_row=0, start_col=1) / np.sqrt(2)
        filt_45r = hh_filt_r1 - hh_filt_r2
        filt_135r = hh_filt_r1 + hh_filt_r2
        filt_45i = hh_filt_i1 + hh_filt_i2
        filt_135i = hh_filt_i1 - hh_filt_i2

        hl_filt_r1 = pad_within(h1_a.unsqueeze(0) * h0_b.unsqueeze(1), start_row=0, start_col=1) / np.sqrt(2)
        hl_filt_r2 = pad_within(h1_b.unsqueeze(0) * h0_a.unsqueeze(1), start_row=1, start_col=0) / np.sqrt(2)
        hl_filt_i1 = pad_within(h1_b.unsqueeze(0) * h0_b.unsqueeze(1), start_row=0, start_col=0) / np.sqrt(2)
        hl_filt_i2 = pad_within(h1_a.unsqueeze(0) * h0_a.unsqueeze(1), start_row=1, start_col=1) / np.sqrt(2)
        filt_75r = hl_filt_r1 - hl_filt_r2
        filt_105r = hl_filt_r1 + hl_filt_r2
        filt_75i = hl_filt_i1 + hl_filt_i2
        filt_105i = hl_filt_i1 - hl_filt_i2

        sl_filt_reals = [filt_15r, filt_45r, filt_75r, filt_105r, filt_135r, filt_165r]
        sl_filt_imags = [filt_15i, filt_45i, filt_75i, filt_105i, filt_135i, filt_165i]

        return (fl_filt_reals, fl_filt_imags), (sl_filt_reals, sl_filt_imags)

    elif w_transform.wt_type == 'DWT2d':
        h0 = w_transform.h0.squeeze().detach().cpu()
        h1 = low_to_high(w_transform.h0)
        h1 = h1.squeeze().detach().cpu()
        # zero-width constant pads: the filter values are unchanged
        h0 = F.pad(h0, pad=(0, 0), mode='constant', value=0)
        h1 = F.pad(h1, pad=(0, 0), mode='constant', value=0)

        # Separable 2-D filters as outer products; e.g. filt_lh has h0 along
        # the columns and h1 along the rows.
        filt_ll = h0.unsqueeze(0) * h0.unsqueeze(1)
        filt_lh = h0.unsqueeze(0) * h1.unsqueeze(1)
        filt_hl = h1.unsqueeze(0) * h0.unsqueeze(1)
        filt_hh = h1.unsqueeze(0) * h1.unsqueeze(1)

        return (h0, h1), (filt_ll, filt_lh, filt_hl, filt_hh)

    else:
        raise ValueError('no such type of wavelet transform is supported')

Functions

def dist(wt1, wt2)

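Compute the distance between two wavelets. It is the minimum Euclidean distance between their wavelet functions, taken over all circular shifts and the time reversal of the first one.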

Expand source code
def dist(wt1, wt2):
    """Return the distance between the wavelet functions of two transforms.

    The wavelet functions psi1 and psi2 (from ``get_wavefun``) are
    zero-padded at the end to a common length.  The result is the smallest
    Euclidean norm of ``roll(c, s) - psi2`` over every circular shift ``s``
    of both ``c = psi1`` and its time reversal.
    """
    _, psi1, _ = get_wavefun(wt1)
    _, psi2, _ = get_wavefun(wt2)

    # Zero-pad the shorter function at its end so both have equal length.
    len1, len2 = len(psi1), len(psi2)
    if len2 < len1:
        psi2 = np.pad(psi2, (0, len1 - len2), mode='constant', constant_values=(0,))
    elif len1 < len2:
        psi1 = np.pad(psi1, (0, len2 - len1), mode='constant', constant_values=(0,))

    # All circular shifts of psi1 first, then all shifts of its reversal.
    candidates = (psi1, psi1[::-1])
    return min(
        np.linalg.norm(np.roll(cand, shift) - psi2).item()
        for cand in candidates
        for shift in range(len(psi1))
    )
def get_1dfilts(w_transform)

Get the 1-D lowpass and highpass filters (h0, h1) from a DWT1d object.

Params


w_transform : obj
DWT1d object
Expand source code
def get_1dfilts(w_transform):
    '''Get 1d filters from DWT1d object.

    Params
    ------
    w_transform: obj
        DWT1d object (``w_transform.wt_type == 'DWT1d'``) exposing a
        lowpass filter tensor ``h0``

    Returns
    -------
    (h0, h1): tuple of torch.tensor
        squeezed, detached CPU versions of the lowpass filter and of the
        highpass filter derived from it by ``low_to_high``

    Raises
    ------
    ValueError
        if ``w_transform.wt_type`` is not ``'DWT1d'``
    '''
    if w_transform.wt_type != 'DWT1d':
        raise ValueError('no such type of wavelet transform is supported')

    def to_cpu_1d(filt):
        # squeeze, detach and move to CPU; the zero-width constant pad leaves
        # the values unchanged
        return F.pad(filt.squeeze().detach().cpu(), pad=(0, 0), mode='constant', value=0)

    lowpass = to_cpu_1d(w_transform.h0)
    highpass = to_cpu_1d(low_to_high(w_transform.h0))
    return (lowpass, highpass)
def get_2dfilts(w_transform)

Get the 2-D filters from a DWT2d or DTCWT2d object.

Params


w_transform : obj
DWT2d or DTCWT2d object
Expand source code
def get_2dfilts(w_transform):
    '''Get 2d filters from DWT2d object.

    Two transform types are supported, selected by ``w_transform.wt_type``:

    * ``'DTCWT2d'`` (presumably the dual-tree complex wavelet transform):
      builds oriented 2-D filters from the 1-D filters stored on
      ``w_transform.xfm``.  ``h0o``/``h1o`` are used for the first level and
      ``h0a``/``h1a``/``h0b``/``h1b`` for the second level.
    * ``'DWT2d'``: builds the four separable 2-D filters from the lowpass
      filter ``w_transform.h0`` and the highpass filter derived from it by
      ``low_to_high``.

    Params
    ------
    w_transform: obj
        DWT2d or DTCWT2d object

    Returns
    -------
    For ``'DTCWT2d'``:
        ``((fl_filt_reals, fl_filt_imags), (sl_filt_reals, sl_filt_imags))``.
        Each entry is a list of six 2-D CPU tensors ordered
        15, 45, 75, 105, 135, 165; ``fl`` is the first level and ``sl`` the
        second level.  NOTE(review): the numbers look like subband
        orientations in degrees; confirm against the transform's convention.
    For ``'DWT2d'``:
        ``((h0, h1), (filt_ll, filt_lh, filt_hl, filt_hh))``, i.e. the 1-D
        filters and their 2-D outer products.

    Raises
    ------
    ValueError
        for any other ``wt_type``.

    Note: ``a.unsqueeze(0) * b.unsqueeze(1)`` is an outer product with
    element ``[i, j] = b[i] * a[j]``, so the first factor varies along the
    columns and the second along the rows.
    '''
    if w_transform.wt_type == 'DTCWT2d':
        h0o = w_transform.xfm.h0o.data
        h1o = w_transform.xfm.h1o.data
        h0a = w_transform.xfm.h0a.data
        h1a = w_transform.xfm.h1a.data
        h0b = w_transform.xfm.h0b.data
        h1b = w_transform.xfm.h1b.data

        # compute first level wavelet filters
        # '_r' appends one zero and '_i' prepends one, so each '_i' filter is
        # the matching '_r' filter delayed by one sample (same length).
        h0_r = F.pad(h0o.squeeze().detach().cpu(), pad=(0, 1), mode='constant', value=0)
        h0_i = F.pad(h0o.squeeze().detach().cpu(), pad=(1, 0), mode='constant', value=0)
        h1_r = F.pad(h1o.squeeze().detach().cpu(), pad=(0, 1), mode='constant', value=0)
        h1_i = F.pad(h1o.squeeze().detach().cpu(), pad=(1, 0), mode='constant', value=0)

        # LH group: h0 along the columns, h1 along the rows.  The sum and
        # difference of the two products give the filt_15*/filt_165* pair.
        lh_filt_r1 = h0_r.unsqueeze(0) * h1_r.unsqueeze(1) / np.sqrt(2)
        lh_filt_r2 = h0_i.unsqueeze(0) * h1_i.unsqueeze(1) / np.sqrt(2)
        lh_filt_i1 = h0_i.unsqueeze(0) * h1_r.unsqueeze(1) / np.sqrt(2)
        lh_filt_i2 = h0_r.unsqueeze(0) * h1_i.unsqueeze(1) / np.sqrt(2)
        filt_15r = lh_filt_r1 - lh_filt_r2
        filt_165r = lh_filt_r1 + lh_filt_r2
        filt_15i = lh_filt_i1 + lh_filt_i2
        filt_165i = lh_filt_i1 - lh_filt_i2

        # HH group: h1 along both axes (filt_45*/filt_135* pair).
        hh_filt_r1 = h1_r.unsqueeze(0) * h1_r.unsqueeze(1) / np.sqrt(2)
        hh_filt_r2 = h1_i.unsqueeze(0) * h1_i.unsqueeze(1) / np.sqrt(2)
        hh_filt_i1 = h1_i.unsqueeze(0) * h1_r.unsqueeze(1) / np.sqrt(2)
        hh_filt_i2 = h1_r.unsqueeze(0) * h1_i.unsqueeze(1) / np.sqrt(2)
        filt_45r = hh_filt_r1 - hh_filt_r2
        filt_135r = hh_filt_r1 + hh_filt_r2
        filt_45i = hh_filt_i1 + hh_filt_i2
        filt_135i = hh_filt_i1 - hh_filt_i2

        # HL group: h1 along the columns, h0 along the rows
        # (filt_75*/filt_105* pair).
        hl_filt_r1 = h1_r.unsqueeze(0) * h0_r.unsqueeze(1) / np.sqrt(2)
        hl_filt_r2 = h1_i.unsqueeze(0) * h0_i.unsqueeze(1) / np.sqrt(2)
        hl_filt_i1 = h1_i.unsqueeze(0) * h0_r.unsqueeze(1) / np.sqrt(2)
        hl_filt_i2 = h1_r.unsqueeze(0) * h0_i.unsqueeze(1) / np.sqrt(2)
        filt_75r = hl_filt_r1 - hl_filt_r2
        filt_105r = hl_filt_r1 + hl_filt_r2
        filt_75i = hl_filt_i1 + hl_filt_i2
        filt_105i = hl_filt_i1 - hl_filt_i2

        fl_filt_reals = [filt_15r, filt_45r, filt_75r, filt_105r, filt_135r, filt_165r]
        fl_filt_imags = [filt_15i, filt_45i, filt_75i, filt_105i, filt_135i, filt_165i]

        # compute second level wavelet filters
        h0_a = h0a.squeeze().detach().cpu()
        h0_b = h0b.squeeze().detach().cpu()
        h1_a = h1a.squeeze().detach().cpu()
        h1_b = h1b.squeeze().detach().cpu()

        # Each outer product of 'a'/'b' filters is upsampled by pad_within
        # (stride 2), which places its taps on one phase (start_row,
        # start_col) of every 2x2 cell.  The two products combined in each
        # sum/difference below sit on different phases.  Orientation
        # groups follow the same row/column layout as the first level.
        lh_filt_r1 = pad_within(h0_b.unsqueeze(0) * h1_a.unsqueeze(1), start_row=1, start_col=0) / np.sqrt(2)
        lh_filt_r2 = pad_within(h0_a.unsqueeze(0) * h1_b.unsqueeze(1), start_row=0, start_col=1) / np.sqrt(2)
        lh_filt_i1 = pad_within(h0_a.unsqueeze(0) * h1_a.unsqueeze(1), start_row=1, start_col=1) / np.sqrt(2)
        lh_filt_i2 = pad_within(h0_b.unsqueeze(0) * h1_b.unsqueeze(1), start_row=0, start_col=0) / np.sqrt(2)
        filt_15r = lh_filt_r1 - lh_filt_r2
        filt_165r = lh_filt_r1 + lh_filt_r2
        filt_15i = lh_filt_i1 + lh_filt_i2
        filt_165i = lh_filt_i1 - lh_filt_i2

        hh_filt_r1 = pad_within(h1_a.unsqueeze(0) * h1_a.unsqueeze(1), start_row=1, start_col=1) / np.sqrt(2)
        hh_filt_r2 = pad_within(h1_b.unsqueeze(0) * h1_b.unsqueeze(1), start_row=0, start_col=0) / np.sqrt(2)
        hh_filt_i1 = pad_within(h1_b.unsqueeze(0) * h1_a.unsqueeze(1), start_row=1, start_col=0) / np.sqrt(2)
        hh_filt_i2 = pad_within(h1_a.unsqueeze(0) * h1_b.unsqueeze(1), start_row=0, start_col=1) / np.sqrt(2)
        filt_45r = hh_filt_r1 - hh_filt_r2
        filt_135r = hh_filt_r1 + hh_filt_r2
        filt_45i = hh_filt_i1 + hh_filt_i2
        filt_135i = hh_filt_i1 - hh_filt_i2

        hl_filt_r1 = pad_within(h1_a.unsqueeze(0) * h0_b.unsqueeze(1), start_row=0, start_col=1) / np.sqrt(2)
        hl_filt_r2 = pad_within(h1_b.unsqueeze(0) * h0_a.unsqueeze(1), start_row=1, start_col=0) / np.sqrt(2)
        hl_filt_i1 = pad_within(h1_b.unsqueeze(0) * h0_b.unsqueeze(1), start_row=0, start_col=0) / np.sqrt(2)
        hl_filt_i2 = pad_within(h1_a.unsqueeze(0) * h0_a.unsqueeze(1), start_row=1, start_col=1) / np.sqrt(2)
        filt_75r = hl_filt_r1 - hl_filt_r2
        filt_105r = hl_filt_r1 + hl_filt_r2
        filt_75i = hl_filt_i1 + hl_filt_i2
        filt_105i = hl_filt_i1 - hl_filt_i2

        sl_filt_reals = [filt_15r, filt_45r, filt_75r, filt_105r, filt_135r, filt_165r]
        sl_filt_imags = [filt_15i, filt_45i, filt_75i, filt_105i, filt_135i, filt_165i]

        return (fl_filt_reals, fl_filt_imags), (sl_filt_reals, sl_filt_imags)

    elif w_transform.wt_type == 'DWT2d':
        h0 = w_transform.h0.squeeze().detach().cpu()
        h1 = low_to_high(w_transform.h0)
        h1 = h1.squeeze().detach().cpu()
        # zero-width constant pads: the filter values are unchanged
        h0 = F.pad(h0, pad=(0, 0), mode='constant', value=0)
        h1 = F.pad(h1, pad=(0, 0), mode='constant', value=0)

        # Separable 2-D filters as outer products; e.g. filt_lh has h0 along
        # the columns and h1 along the rows.
        filt_ll = h0.unsqueeze(0) * h0.unsqueeze(1)
        filt_lh = h0.unsqueeze(0) * h1.unsqueeze(1)
        filt_hl = h1.unsqueeze(0) * h0.unsqueeze(1)
        filt_hh = h1.unsqueeze(0) * h1.unsqueeze(1)

        return (h0, h1), (filt_ll, filt_lh, filt_hl, filt_hh)

    else:
        raise ValueError('no such type of wavelet transform is supported')
def get_wavefun(w_transform, level=5)

Get the scaling function phi, the wavelet function psi and their sampling grid x from a wavelet transform object.

Params


w_transform : obj
DWT1d or DWT2d object
Expand source code
def get_wavefun(w_transform, level=5):
    '''Get wavelet function from wavelet object.

    Builds a ``pywt.Wavelet`` from the transform's lowpass filter ``h0`` and
    the highpass filter derived from it by ``low_to_high``.  It then
    evaluates the scaling and wavelet functions with ``Wavelet.wavefun``.

    Params
    ------
    w_transform: obj
        DWT1d or DWT2d object exposing a lowpass filter tensor ``h0``
        (presumably of shape (1, 1, n), as ``low_to_high`` requires)
    level: int
        refinement level passed to ``pywt.Wavelet.wavefun``

    Returns
    -------
    phi:
        scaling function (the decomposition one when pywt returns both)
    psi:
        wavelet function (the decomposition one when pywt returns both)
    x:
        grid on which phi and psi are sampled
    '''
    h0 = w_transform.h0
    h1 = low_to_high(h0)

    # The decomposition filters (first two entries of the filter bank) are the
    # time-reversed reconstruction filters (last two entries).
    h0 = list(h0.squeeze().detach().cpu().numpy())[::-1]
    h1 = list(h1.squeeze().detach().cpu().numpy())[::-1]

    my_filter_bank = (h0, h1, h0[::-1], h1[::-1])
    my_wavelet = pywt.Wavelet('My Wavelet', filter_bank=my_filter_bank)
    wave = my_wavelet.wavefun(level=level)
    # wavefun returns (phi, psi, x) for orthogonal wavelets and
    # (phi_d, psi_d, phi_r, psi_r, x) otherwise.  The grid is always the last
    # item, so take it from the end instead of hard-coding index 4.
    (phi, psi, x) = wave[0], wave[1], wave[-1]

    return phi, psi, x
def init_filter(x, init_factor, noise_factor, const_factor)

Scale a tensor, then add random noise and a constant offset to it.

Params


x : torch.tensor
input
init_factor : float
multiplicative factor applied to the original filter
noise_factor : float
amount of noise added to original filter
const_factor : float
amount of constant added to original filter
Expand source code
def init_filter(x, init_factor, noise_factor, const_factor):
    '''Scale a tensor and add Gaussian noise and a constant offset to it.

    Computes ``init_factor * x + noise_factor * noise + const_factor``
    elementwise.  ``noise`` is drawn from a standard normal distribution,
    independently for each element.

    Params
    ------
    x: torch.tensor
        input (e.g. an initial filter)
    init_factor: float
        multiplicative factor applied to the original filter
    noise_factor: float
        amount of noise added to original filter
    const_factor: float
        amount of constant added to original filter

    Returns
    -------
    torch.tensor
        tensor of the same shape and device as x (floating-point inputs keep
        their dtype)
    '''
    # Draw the noise on x's device: torch.randn(shape) always allocated on
    # the CPU in the default dtype, which breaks for CUDA inputs.
    noise_dtype = x.dtype if x.is_floating_point() else None
    noise = torch.randn(x.shape, dtype=noise_dtype, device=x.device)
    return init_factor * x + noise_factor * noise + const_factor
def low_to_high(x)

Convert a lowpass filter to a highpass filter. The input must have shape (1, 1, n), where n is the filter length.

Expand source code
def low_to_high(x):
    """Derive a highpass filter from a lowpass filter by alternating flip.

    The taps are time-reversed and every odd-indexed output tap is negated.
    For the expected (1, 1, n) input this gives
    ``y[0, 0, k] = (-1) ** k * x[0, 0, n - 1 - k]``.

    Params
    ------
    x: torch.tensor
        lowpass filter of shape (1, 1, n), where n is the filter length
    """
    length = x.size(2)
    # +1, -1, +1, ... as an integer tensor (same values as (-1) ** arange)
    signs = 1 - 2 * (torch.arange(length, device=x.device) % 2)
    reversed_taps = torch.flip(x, dims=(0, 2))
    return reversed_taps * signs
def pad_within(x, stride=2, start_row=0, start_col=0)
Expand source code
def pad_within(x, stride=2, start_row=0, start_col=0):
    '''Upsample a 2-D array by inserting zeros between its samples.

    Each input sample ``x[..., i, j]`` is placed at output position
    ``(stride * i + start_row, stride * j + start_col)``, and every other
    output entry is zero.  This is implemented as a grouped transposed
    convolution with a ``stride x stride`` kernel holding a single one.

    Params
    ------
    x: torch.tensor
        2-D tensor, or a 4-D tensor whose dim 1 is treated as channels
    stride: int
        upsampling factor along both spatial axes
    start_row, start_col: int
        offset of the input samples inside each ``stride x stride`` cell;
        must lie in ``[0, stride)``

    Returns
    -------
    torch.tensor
        upsampled tensor with singleton dimensions squeezed out

    Raises
    ------
    ValueError
        if ``start_row`` or ``start_col`` is outside ``[0, stride)``
    '''
    if not (0 <= start_row < stride and 0 <= start_col < stride):
        raise ValueError(
            f'start_row and start_col must lie in [0, {stride}), '
            f'got ({start_row}, {start_col})')
    w = x.new_zeros(stride, stride)
    # Index the kernel directly.  The former if/elif chain only knew the four
    # offsets of a 2x2 cell and silently fell back to (1, 1) for any other.
    w[start_row, start_col] = 1
    if len(x.shape) == 2:
        x = x[None, None]
    return F.conv_transpose2d(x, w.expand(x.size(1), 1, stride, stride),
                              stride=stride, groups=x.size(1)).squeeze()
def reflect(x, minx, maxx)

Reflect the values in matrix x about the scalar values minx and maxx. Hence a vector x containing a long linearly increasing series is converted into a waveform which ramps linearly up and down between minx and maxx. If x contains integers and minx and maxx are (integers + 0.5), the ramps will have repeated max and min samples.

Codeauthor: Rich Wareham rjw57@cantab.net, Aug 2013

Codeauthor: Nick Kingsbury, Cambridge University, January 1999.

Expand source code
def reflect(x, minx, maxx):
    """Reflect the values in *x* about the scalar bounds *minx* and *maxx*.

    A long linearly increasing series in *x* is folded into a triangular
    waveform that ramps up and down between *minx* and *maxx*.  When *x*
    holds integers and *minx*/*maxx* are (integers + 0.5), the ramps repeat
    their extreme samples.

    The result is cast back to the dtype of ``np.asanyarray(x)``.

    .. codeauthor:: Rich Wareham <rjw57@cantab.net>, Aug 2013
    .. codeauthor:: Nick Kingsbury, Cambridge University, January 1999.

    """
    arr = np.asanyarray(x)
    span = maxx - minx
    period = 2 * span
    # np.fmod keeps the sign of the dividend, so move negatives into [0, period).
    phase = np.fmod(arr - minx, period)
    phase = np.where(phase < 0, phase + period, phase)
    # The second half of every period is mirrored back down.
    folded = np.where(phase >= span, period - phase, phase)
    return np.array(folded + minx, dtype=arr.dtype)
def tensor_to_tuple(y, d, list_of_size)
Expand source code
def tensor_to_tuple(y, d, list_of_size):
    """Split a flattened 2-D tensor back into a tuple of shaped tensors.

    This is the inverse of ``tuple_to_tensor``.

    Params
    ------
    y: torch.Tensor
        2-D tensor whose column ranges hold the flattened items side by side
    d: sequence
        templates; item ``j`` is reshaped to ``d[j].shape``
    list_of_size: list of int
        cumulative column offsets; item ``j`` is taken from
        ``y[:, list_of_size[j]:list_of_size[j + 1]]``
    """
    bounds = zip(list_of_size[:-1], list_of_size[1:])
    return tuple(
        y[:, start:stop].reshape(d[j].shape)
        for j, (start, stop) in enumerate(bounds)
    )
def tuple_dim(x)
Expand source code
def tuple_dim(x):
    """Return the total number of elements across a sequence of tensors.

    Params
    ------
    x: sequence of torch.Tensor (or any objects with a ``.shape``)

    Returns
    -------
    int
        Sum over ``x`` of the product of each item's shape.  A 0-d item
        counts as one element; an empty sequence gives 0.
    """
    # np.prod over an empty shape is 1, and dtype=np.int64 keeps the count
    # integral (building torch.tensor(()) for a 0-d input yielded float 1.0).
    return sum(int(np.prod(t.shape, dtype=np.int64)) for t in x)
def tuple_to_tensor(x)
Expand source code
def tuple_to_tensor(x):
    """Flatten a sequence of batched tensors into one 2-D tensor.

    Each ``x[j]`` is reshaped to ``(batch_size, -1)``, with the batch size
    taken from ``x[0].size(0)``.  The results are concatenated along dim 1.

    Params
    ------
    x: sequence of torch.Tensor
        tensors sharing the same leading (batch) dimension

    Returns
    -------
    y: torch.Tensor
        tensor of shape ``(batch_size, total flattened size)``
    list_of_size: list of int
        cumulative column offsets ``[0, n_0, n_0 + n_1, ...]``; item ``j``
        occupies ``y[:, list_of_size[j]:list_of_size[j + 1]]``
    """
    batch_size = x[0].size(0)
    flat = [t.reshape(batch_size, -1) for t in x]
    list_of_size = [0]
    for a in flat:
        list_of_size.append(list_of_size[-1] + a.shape[1])
    # One concatenation avoids re-copying a growing tensor on every step.
    # It also keeps the inputs' dtype; concatenating onto an empty float32
    # seed tensor could promote or convert it.
    y = torch.cat(flat, dim=1)
    return (y, list_of_size)