Module awave.trim.trim

Expand source code
import sys

sys.path.append('..')

import numpy as np  # used by lay_from_w below
import torch
from torch import nn

from .util import *

device = 'cuda' if torch.cuda.is_available() else 'cpu'


class TrimModel(nn.Module):
    '''Prepends a transformation onto the network (with optional normalization after the transform)
    Params
    ------
    model: nn.Module
        model after all the transformations
    inv_transform: nn.Module
        the inverse transform
    norm: nn.Module (Norm_Layer)
        normalization to apply after the inverse transform
    reshape: nn.Module
        reshape to apply after the normalization
    use_residuals: bool, optional
        whether or not to apply the residuals after the transformation 
        (for transformations which are not perfectly invertible)
    use_logits: bool, optional
        whether to use the logits (if the model has them) or the forward function
    n_components: int
        right now this setup is kind of weird - if you want to pass a residual,
        pass x as a 1d vector whose last entries contain the residual [x, residual]
    '''

    def __init__(self, model, inv_transform, norm=None, reshape=None,
                 use_residuals=False, use_logits=False):
        super(TrimModel, self).__init__()
        self.inv_transform = inv_transform
        self.norm = norm
        self.reshape = reshape
        self.model = model
        self.use_residuals = use_residuals
        self.use_logits = use_logits

    def forward(self, s, x_orig=None):
        '''
        Params
        ------
        s: torch.Tensor
            This should be the input in the transformed space which we want to interpret
            (batch_size, C, H, W) for images
            (batch_size, C, seq_length) for audio
        '''
        # untransform the input
        x = self.inv_transform(s)

        # take residuals into account
        if self.use_residuals:
            assert x_orig is not None, "if using residuals, must also pass untransformed original image!"
            res = x_orig - x.detach()
            x = x + res

        # normalize
        if self.norm is not None:
            x = self.norm(x)

        # reshape
        if self.reshape is not None:
            x = self.reshape(x)

        # pass through the main model
        if self.use_logits:
            x = self.model.logits(x)
        else:
            x = self.model.forward(x)
        return x


def lay_from_w(D: np.ndarray):
    '''Creates a linear layer given a weight matrix
    Params
    ------
    D
        weight matrix (in_features, out_features)
    '''
    lay = nn.Linear(in_features=D.shape[0], out_features=D.shape[1], bias=False)
    lay.weight.data = torch.tensor(D.astype(np.float32)).T
    return lay


class NormLayer(nn.Module):
    '''Normalizes images (assumes only 1 channel)
    image = (image - mean) / std
    '''

    def __init__(self, mu=0.1307, std=0.3081):
        super(NormLayer, self).__init__()
        self.mean = mu
        self.std = std

    def forward(self, x):
        return (x - self.mean) / self.std


def modularize(f):
    '''Turns any function into a torch module
    '''

    class Transform(nn.Module):
        def __init__(self, f):
            super(Transform, self).__init__()
            self.f = f

        def forward(self, x):
            return self.f(x)

    return Transform(f)


class ReshapeLayer(nn.Module):
    '''Returns a torch module which reshapes an input to a desired shape
    Params
    ------
    shape: tuple
        shape excluding batch size
    '''

    def __init__(self, shape):
        super(ReshapeLayer, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.reshape(x.shape[0], *self.shape)


class DecoderEncoder(nn.Module):
    '''Prepends decoder onto encoder
    '''

    def __init__(self, model, use_residuals=False):
        super(DecoderEncoder, self).__init__()
        self.encoder = model.encoder
        self.decoder = model.decoder
        self.use_residuals = use_residuals

    def forward(self, s, x_orig=None):
        '''
        Params
        ------
        s: torch.Tensor
            This should be the input in the transformed space which we want to interpret
            (batch_size, C, H, W) for images
            (batch_size, C, seq_length) for audio
        '''
        x = self.decoder(s)

        if self.use_residuals:
            assert x_orig is not None, "if using residuals, must also pass untransformed original image!"
            res = (x_orig - x).detach()
            x = x + res
        x = self.encoder(x)[0]
        return x

Functions

def lay_from_w(D)

Creates a linear layer given a weight matrix

Params

D : np.ndarray
weight matrix (in_features, out_features)
Expand source code
def lay_from_w(D: np.ndarray):
    '''Creates a linear layer given a weight matrix
    Params
    ------
    D
        weight matrix (in_features, out_features)
    '''
    lay = nn.Linear(in_features=D.shape[0], out_features=D.shape[1], bias=False)
    lay.weight.data = torch.tensor(D.astype(np.float32)).T
    return lay
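
A minimal usage sketch (the matrix D below is illustrative, e.g. a PCA or dictionary basis). Since nn.Linear stores its weight as (out_features, in_features), lay_from_w transposes D, so the resulting layer computes x @ D:

import numpy as np
import torch

from awave.trim.trim import lay_from_w

D = np.random.randn(10, 3)        # hypothetical (in_features, out_features) basis
lay = lay_from_w(D)               # nn.Linear(10, 3, bias=False)

x = torch.randn(4, 10)            # batch of 4 inputs
out = lay(x)                      # shape (4, 3)
assert torch.allclose(out, x @ torch.tensor(D, dtype=torch.float32), atol=1e-5)
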
def modularize(f)

Turns any function into a torch module

Expand source code
def modularize(f):
    '''Turns any function into a torch module
    '''

    class Transform(nn.Module):
        def __init__(self, f):
            super(Transform, self).__init__()
            self.f = f

        def forward(self, x):
            return self.f(x)

    return Transform(f)
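
For example, a plain function (here just a stand-in doubling function, but in practice e.g. an inverse transform) can be wrapped so it can be passed anywhere an nn.Module is expected, such as the inv_transform argument of TrimModel:

import torch

from awave.trim.trim import modularize

double = modularize(lambda x: 2 * x)   # wrap the function in an nn.Module
x = torch.arange(3.0)
print(double(x))                       # tensor([0., 2., 4.])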

Classes

class DecoderEncoder (model, use_residuals=False)

Prepends decoder onto encoder

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Expand source code
class DecoderEncoder(nn.Module):
    '''Prepends decoder onto encoder
    '''

    def __init__(self, model, use_residuals=False):
        super(DecoderEncoder, self).__init__()
        self.encoder = model.encoder
        self.decoder = model.decoder
        self.use_residuals = use_residuals

    def forward(self, s, x_orig=None):
        '''
        Params
        ------
        s: torch.Tensor
            This should be the input in the transformed space which we want to interpret
            (batch_size, C, H, W) for images
            (batch_size, C, seq_length) for audio
        '''
        x = self.decoder(s)

        if self.use_residuals:
            assert x_orig is not None, "if using residuals, must also pass untransformed original image!"
            res = (x_orig - x).detach()
            x = x + res
        x = self.encoder(x)[0]
        return x

Ancestors

  • torch.nn.modules.module.Module

Methods

def forward(self, s, x_orig=None)

Params

s : torch.Tensor
This should be the input in the transformed space which we want to interpret: (batch_size, C, H, W) for images, (batch_size, C, seq_length) for audio
Expand source code
def forward(self, s, x_orig=None):
    '''
    Params
    ------
    s: torch.Tensor
        This should be the input in the transformed space which we want to interpret
        (batch_size, C, H, W) for images
        (batch_size, C, seq_length) for audio
    '''
    x = self.decoder(s)

    if self.use_residuals:
        assert x_orig is not None, "if using residuals, must also pass untransformed original image!"
        res = (x_orig - x).detach()
        x = x + res
    x = self.encoder(x)[0]
    return x
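
A usage sketch with a hypothetical autoencoder; the only assumptions are that the wrapped model exposes .encoder and .decoder attributes and that the encoder returns a tuple whose first element is the latent code (which is why forward indexes [0]):

import torch
from torch import nn

from awave.trim.trim import DecoderEncoder

class TupleEncoder(nn.Module):
    '''Hypothetical encoder returning (latent, extra), matching encoder(x)[0] above.'''
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(784, 32)

    def forward(self, x):
        return self.lin(x), None

class TinyAutoencoder(nn.Module):
    '''Hypothetical autoencoder exposing .encoder and .decoder attributes.'''
    def __init__(self):
        super().__init__()
        self.encoder = TupleEncoder()
        self.decoder = nn.Linear(32, 784)

wrapper = DecoderEncoder(TinyAutoencoder(), use_residuals=False)
s = torch.randn(8, 32)    # latent codes
out = wrapper(s)          # decode -> re-encode; shape (8, 32)
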
class NormLayer (mu=0.1307, std=0.3081)

Normalizes images (assumes only 1 channel): image = (image - mean) / std

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Expand source code
class NormLayer(nn.Module):
    '''Normalizes images (assumes only 1 channel)
    image = (image - mean) / std
    '''

    def __init__(self, mu=0.1307, std=0.3081):
        super(NormLayer, self).__init__()
        self.mean = mu
        self.std = std

    def forward(self, x):
        return (x - self.mean) / self.std

Ancestors

  • torch.nn.modules.module.Module

Methods

def forward(self, x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

Expand source code
def forward(self, x):
    return (x - self.mean) / self.std
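
The defaults are the commonly used MNIST mean/std; a quick usage sketch:

import torch

from awave.trim.trim import NormLayer

norm = NormLayer(mu=0.1307, std=0.3081)   # MNIST statistics (the defaults)
x = torch.rand(16, 1, 28, 28)             # single-channel images in [0, 1]
x_normed = norm(x)                        # (x - 0.1307) / 0.3081, same shape
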
class ReshapeLayer (shape)

Returns a torch module which reshapes an input to a desired shape

Params

shape : tuple
shape excluding batch size

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Expand source code
class ReshapeLayer(nn.Module):
    '''Returns a torch module which reshapes an input to a desired shape
    Params
    ------
    shape: tuple
        shape excluding batch size
    '''

    def __init__(self, shape):
        super(ReshapeLayer, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.reshape(x.shape[0], *self.shape)

Ancestors

  • torch.nn.modules.module.Module

Methods

def forward(self, x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

Expand source code
def forward(self, x):
    return x.reshape(x.shape[0], *self.shape)
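
A quick sketch reshaping flattened MNIST-sized vectors back to image shape (the sizes are illustrative):

import torch

from awave.trim.trim import ReshapeLayer

reshape = ReshapeLayer(shape=(1, 28, 28))   # target shape, excluding the batch dimension
x = torch.randn(16, 784)                    # flat vectors
print(reshape(x).shape)                     # torch.Size([16, 1, 28, 28])
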
class TrimModel (model, inv_transform, norm=None, reshape=None, use_residuals=False, use_logits=False)

Prepends a transformation onto the network (with optional normalization after the transform)

Params

model : nn.Module
model after all the transformations
inv_transform : nn.Module
the inverse transform
norm : nn.Module (Norm_Layer)
normalization to apply after the inverse transform
reshape : nn.Module
reshape to apply after the normalization
use_residuals : bool, optional
whether or not to apply the residuals after the transformation (for transformations which are not perfectly invertible)
use_logits : bool, optional
whether to use the logits (if the model has them) or the forward function
n_components : int
right now this setup is kind of weird - if you want to pass a residual, pass x as a 1d vector whose last entries contain the residual [x, residual]

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Expand source code
class TrimModel(nn.Module):
    '''Prepends a transformation onto the network (with optional normalization after the transform)
    Params
    ------
    model: nn.Module
        model after all the transformations
    inv_transform: nn.Module
        the inverse transform
    norm: nn.Module (Norm_Layer)
        normalization to apply after the inverse transform
    reshape: nn.Module
        reshape to apply after the normalization
    use_residuals: bool, optional
        whether or not to apply the residuals after the transformation 
        (for transformations which are not perfectly invertible)
    use_logits: bool, optional
        whether to use the logits (if the model has them) or the forward function
    n_components: int
        right now this setup is kind of weird - if you want to pass a residual,
        pass x as a 1d vector whose last entries contain the residual [x, residual]
    '''

    def __init__(self, model, inv_transform, norm=None, reshape=None,
                 use_residuals=False, use_logits=False):
        super(TrimModel, self).__init__()
        self.inv_transform = inv_transform
        self.norm = norm
        self.reshape = reshape
        self.model = model
        self.use_residuals = use_residuals
        self.use_logits = use_logits

    def forward(self, s, x_orig=None):
        '''
        Params
        ------
        s: torch.Tensor
            This should be the input in the transformed space which we want to interpret
            (batch_size, C, H, W) for images
            (batch_size, C, seq_length) for audio
        '''
        # untransform the input
        x = self.inv_transform(s)

        # take residuals into account
        if self.use_residuals:
            assert x_orig is not None, "if using residuals, must also pass untransformed original image!"
            res = x_orig - x.detach()
            x = x + res

        # normalize
        if self.norm is not None:
            x = self.norm(x)

        # reshape
        if self.reshape is not None:
            x = self.reshape(x)

        # pass through the main model
        if self.use_logits:
            x = self.model.logits(x)
        else:
            x = self.model.forward(x)
        return x

Ancestors

  • torch.nn.modules.module.Module

Methods

def forward(self, s, x_orig=None)

Params

s : torch.Tensor
This should be the input in the transformed space which we want to interpret: (batch_size, C, H, W) for images, (batch_size, C, seq_length) for audio
Expand source code
def forward(self, s, x_orig=None):
    '''
    Params
    ------
    s: torch.Tensor
        This should be the input in the transformed space which we want to interpret
        (batch_size, C, H, W) for images
        (batch_size, C, seq_length) for audio
    '''
    # untransform the input
    x = self.inv_transform(s)

    # take residuals into account
    if self.use_residuals:
        assert x_orig is not None, "if using residuals, must also pass untransformed original image!"
        res = x_orig - x.detach()
        x = x + res

    # normalize
    if self.norm is not None:
        x = self.norm(x)

    # reshape
    if self.reshape is not None:
        x = self.reshape(x)

    # pass through the main model
    if self.use_logits:
        x = self.model.logits(x)
    else:
        x = self.model.forward(x)
    return x
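
A sketch of the full wrapper, with a hypothetical classifier and a hypothetical (not perfectly invertible) inverse transform standing in for the real ones. With use_residuals=True, the input fed to the model equals x_orig exactly while gradients still flow through inv_transform, so attributions can be computed with respect to s in the transformed space:

import torch
from torch import nn

from awave.trim.trim import NormLayer, TrimModel, modularize

# stand-ins for a real inverse wavelet transform and a real classifier
inv_transform = modularize(lambda s: s.reshape(s.shape[0], -1))  # latent -> flat signal
classifier = nn.Linear(784, 10)                                  # model applied after the transform

trim = TrimModel(
    model=classifier,
    inv_transform=inv_transform,
    norm=NormLayer(mu=0.1307, std=0.3081),
    reshape=None,           # classifier already takes flat input
    use_residuals=True,     # correct for imperfect invertibility
)

s = torch.randn(8, 1, 28, 28, requires_grad=True)   # input in the transformed space
x_orig = inv_transform(s).detach() + 0.01           # pretend original that the transform misses slightly
logits = trim(s, x_orig=x_orig)                     # shape (8, 10)

# gradients w.r.t. s can now be used to interpret the classifier in the transformed space
logits.sum().backward()
print(s.grad.shape)                                 # torch.Size([8, 1, 28, 28])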