Module awave.trim.funcs

Expand source code
import sys

import numpy as np
import torch

sys.path.append('..')


def prox_positive(x):
    return torch.nn.functional.threshold(x, 0, 0)


def prox_identity(x):
    return x


def prox_soft_threshold(x, lamb):
    return torch.sign(x) * torch.nn.functional.threshold(torch.abs(x) - lamb, 0, 0)


def prox_hard_threshold(x, k):
    # hard-threshold each row of x
    x = x.clone().detach().cpu()
    m = x.data.shape[1]
    a, _ = torch.abs(x).data.sort(dim=1, descending=True)
    thresh = torch.mm(a[:, k].unsqueeze(1), torch.Tensor(np.ones((1, m))))
    mask = torch.tensor((np.abs(x.data.cpu().numpy()) > thresh.cpu().numpy()) + 0., dtype=torch.float)
    return (x * mask).to(device)


def prox_normalization(x):
    '''
    x : (B,C,H,W) tensor
    '''
    norm = torch.norm(x, dim=(2, 3)).unsqueeze(2).unsqueeze(3)
    return x / norm

Functions

def prox_hard_threshold(x, k)
Expand source code
def prox_hard_threshold(x, k):
    # hard-threshold each row of x
    x = x.clone().detach().cpu()
    m = x.data.shape[1]
    a, _ = torch.abs(x).data.sort(dim=1, descending=True)
    thresh = torch.mm(a[:, k].unsqueeze(1), torch.Tensor(np.ones((1, m))))
    mask = torch.tensor((np.abs(x.data.cpu().numpy()) > thresh.cpu().numpy()) + 0., dtype=torch.float)
    return (x * mask).to(device)
def prox_identity(x)
Expand source code
def prox_identity(x):
    return x
def prox_normalization(x)

x : (B,C,H,W) tensor

Expand source code
def prox_normalization(x):
    '''
    x : (B,C,H,W) tensor
    '''
    norm = torch.norm(x, dim=(2, 3)).unsqueeze(2).unsqueeze(3)
    return x / norm
def prox_positive(x)
Expand source code
def prox_positive(x):
    return torch.nn.functional.threshold(x, 0, 0)
def prox_soft_threshold(x, lamb)
Expand source code
def prox_soft_threshold(x, lamb):
    return torch.sign(x) * torch.nn.functional.threshold(torch.abs(x) - lamb, 0, 0)