Module awave.transform1d
Expand source code
import torch
import torch.nn as nn
from awave.utils import lowlevel
from awave.transform import AbstractWT
from awave.utils.misc import init_filter, low_to_high
class DWT1d(AbstractWT):
'''Class of 1d wavelet transform
Params
------
J: int
number of levels of decomposition
wave: str
which wavelet to use.
can be:
1) a string to pass to pywt.Wavelet constructor
2) a pywt.Wavelet class
3) a tuple of numpy arrays, either (h0, h1) or (h0_col, h1_col, h0_row, h1_row)
mode: str
'zero', 'symmetric', 'reflect' or 'periodization'. The padding scheme
'''
def __init__(self, wave='db3', mode='zero', J=5, init_factor=1, noise_factor=0, const_factor=0, device='cpu'):
super().__init__()
h0, _ = lowlevel.load_wavelet(wave)
# initialize
h0 = init_filter(h0, init_factor, noise_factor, const_factor)
# parameterize
self.h0 = nn.Parameter(h0, requires_grad=True)
self.h0 = self.h0.to(device)
self.J = J
self.mode = mode
self.wt_type = 'DWT1d'
self.device = device
def forward(self, x):
""" Forward pass of the DWT.
Args:
x (tensor): Input of shape :math:`(N, C_{in}, L_{in})`
Returns:
(yl, yh)
tuple of lowpass (yl) and bandpass (yh) coefficients.
yh is a list of length J with the first entry
being the finest scale coefficients.
"""
assert x.ndim == 3, "Can only handle 3d inputs (N, C, L)"
highs = ()
x0 = x
mode = lowlevel.mode_to_int(self.mode)
h1 = low_to_high(self.h0)
# Do a multilevel transform
for j in range(self.J):
x0, x1 = lowlevel.AFB1D.forward(x0, self.h0, h1, mode)
highs += (x1,)
return (x0,) + highs
def inverse(self, coeffs):
"""
Args:
coeffs (yl, yh): tuple of lowpass and bandpass coefficients, should
match the format returned by DWT1DForward.
Returns:
Reconstructed input of shape :math:`(N, C_{in}, L_{in})`
Note:
Can have None for any of the highpass scales and will treat the
values as zeros (not in an efficient way though).
"""
coeffs = list(coeffs)
x0 = coeffs.pop(0)
highs = coeffs
assert x0.ndim == 3, "Can only handle 3d inputs (N, C, L)"
mode = lowlevel.mode_to_int(self.mode)
h1 = low_to_high(self.h0)
# Do a multilevel inverse transform
for x1 in highs[::-1]:
if x1 is None:
x1 = torch.zeros_like(x0)
# 'Unpad' added signal
if x0.shape[-1] > x1.shape[-1]:
x0 = x0[..., :-1]
x0 = lowlevel.SFB1D.forward(x0, x1, self.h0, h1, mode)
return x0
Classes
class DWT1d (wave='db3', mode='zero', J=5, init_factor=1, noise_factor=0, const_factor=0, device='cpu')
-
Class of 1d wavelet transform Params
J
:int
- number of levels of decomposition
wave
:str
- which wavelet to use. can be: 1) a string to pass to pywt.Wavelet constructor 2) a pywt.Wavelet class 3) a tuple of numpy arrays, either (h0, h1) or (h0_col, h1_col, h0_row, h1_row)
mode
:str
- 'zero', 'symmetric', 'reflect' or 'periodization'. The padding scheme
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Expand source code
class DWT1d(AbstractWT): '''Class of 1d wavelet transform Params ------ J: int number of levels of decomposition wave: str which wavelet to use. can be: 1) a string to pass to pywt.Wavelet constructor 2) a pywt.Wavelet class 3) a tuple of numpy arrays, either (h0, h1) or (h0_col, h1_col, h0_row, h1_row) mode: str 'zero', 'symmetric', 'reflect' or 'periodization'. The padding scheme ''' def __init__(self, wave='db3', mode='zero', J=5, init_factor=1, noise_factor=0, const_factor=0, device='cpu'): super().__init__() h0, _ = lowlevel.load_wavelet(wave) # initialize h0 = init_filter(h0, init_factor, noise_factor, const_factor) # parameterize self.h0 = nn.Parameter(h0, requires_grad=True) self.h0 = self.h0.to(device) self.J = J self.mode = mode self.wt_type = 'DWT1d' self.device = device def forward(self, x): """ Forward pass of the DWT. Args: x (tensor): Input of shape :math:`(N, C_{in}, L_{in})` Returns: (yl, yh) tuple of lowpass (yl) and bandpass (yh) coefficients. yh is a list of length J with the first entry being the finest scale coefficients. """ assert x.ndim == 3, "Can only handle 3d inputs (N, C, L)" highs = () x0 = x mode = lowlevel.mode_to_int(self.mode) h1 = low_to_high(self.h0) # Do a multilevel transform for j in range(self.J): x0, x1 = lowlevel.AFB1D.forward(x0, self.h0, h1, mode) highs += (x1,) return (x0,) + highs def inverse(self, coeffs): """ Args: coeffs (yl, yh): tuple of lowpass and bandpass coefficients, should match the format returned by DWT1DForward. Returns: Reconstructed input of shape :math:`(N, C_{in}, L_{in})` Note: Can have None for any of the highpass scales and will treat the values as zeros (not in an efficient way though). """ coeffs = list(coeffs) x0 = coeffs.pop(0) highs = coeffs assert x0.ndim == 3, "Can only handle 3d inputs (N, C, L)" mode = lowlevel.mode_to_int(self.mode) h1 = low_to_high(self.h0) # Do a multilevel inverse transform for x1 in highs[::-1]: if x1 is None: x1 = torch.zeros_like(x0) # 'Unpad' added signal if x0.shape[-1] > x1.shape[-1]: x0 = x0[..., :-1] x0 = lowlevel.SFB1D.forward(x0, x1, self.h0, h1, mode) return x0
Ancestors
- AbstractWT
- torch.nn.modules.module.Module
Methods
def forward(self, x)
-
Forward pass of the DWT.
Args
x
:tensor
- Input of shape :math:
(N, C_{in}, L_{in})
Returns
(yl, yh) tuple of lowpass (yl) and bandpass (yh) coefficients. yh is a list of length J with the first entry being the finest scale coefficients.
Expand source code
def forward(self, x): """ Forward pass of the DWT. Args: x (tensor): Input of shape :math:`(N, C_{in}, L_{in})` Returns: (yl, yh) tuple of lowpass (yl) and bandpass (yh) coefficients. yh is a list of length J with the first entry being the finest scale coefficients. """ assert x.ndim == 3, "Can only handle 3d inputs (N, C, L)" highs = () x0 = x mode = lowlevel.mode_to_int(self.mode) h1 = low_to_high(self.h0) # Do a multilevel transform for j in range(self.J): x0, x1 = lowlevel.AFB1D.forward(x0, self.h0, h1, mode) highs += (x1,) return (x0,) + highs
def inverse(self, coeffs)
-
Args
coeffs
:yl
,yh
- tuple of lowpass and bandpass coefficients, should match the format returned by DWT1DForward.
Returns
Reconstructed input of shape :math:
(N, C_{in}, L_{in})
Note
Can have None for any of the highpass scales and will treat the values as zeros (not in an efficient way though).
Expand source code
def inverse(self, coeffs): """ Args: coeffs (yl, yh): tuple of lowpass and bandpass coefficients, should match the format returned by DWT1DForward. Returns: Reconstructed input of shape :math:`(N, C_{in}, L_{in})` Note: Can have None for any of the highpass scales and will treat the values as zeros (not in an efficient way though). """ coeffs = list(coeffs) x0 = coeffs.pop(0) highs = coeffs assert x0.ndim == 3, "Can only handle 3d inputs (N, C, L)" mode = lowlevel.mode_to_int(self.mode) h1 = low_to_high(self.h0) # Do a multilevel inverse transform for x1 in highs[::-1]: if x1 is None: x1 = torch.zeros_like(x0) # 'Unpad' added signal if x0.shape[-1] > x1.shape[-1]: x0 = x0[..., :-1] x0 = lowlevel.SFB1D.forward(x0, x1, self.h0, h1, mode) return x0
Inherited members