Module awave.trim.transforms_torch
Expand source code
from copy import deepcopy
import numpy as np
import torch
from numpy.fft import *
def bandpass_filter(im: torch.Tensor, band_center=0.3, band_width_lower=0.1, band_width_upper=0.1):
    '''Bandpass filter the image, keeping only an annulus of spatial frequencies

    Frequencies along each spatial axis are rescaled so that the largest
    magnitude on that axis is 1. A Fourier coefficient is kept when its radial
    distance from the zero frequency lies in the half-open interval
    [band_center - band_width_lower, band_center + band_width_upper);
    every other coefficient is zeroed before transforming back.

    Params
    ------
    im: torch.Tensor
        B, C, H, W (the last two dims are filtered; H and W may differ)
    band_center: float
        center of the pass band, in normalized frequency units
    band_width_lower: float
        how far the pass band extends below band_center
    band_width_upper: float
        how far the pass band extends above band_center

    Returns
    -------
    im_bandpass: torch.Tensor
        B, C, H, W
    '''
    # Normalized frequency of every FFT bin, one array per spatial axis.
    # Kept in float64 numpy so the band-edge comparisons do not depend on the
    # precision of the image tensor.
    axis_freqs = []
    for n in im.shape[-2:]:
        freqs = np.fft.fftfreq(n)
        peak = np.max(np.abs(freqs))
        if peak > 0:  # a length-1 axis only contains the zero frequency
            freqs = freqs / peak
        axis_freqs.append(freqs)
    freq_rows, freq_cols = axis_freqs

    # Radial distance of each (row, col) bin from the zero frequency.
    dist = np.sqrt(freq_rows[:, None] ** 2 + freq_cols[None, :] ** 2)
    in_band = (dist >= band_center - band_width_lower) & (dist < band_center + band_width_upper)

    # torch.fft.fft2 / ifft2 replace the removed torch.fft(x, 2) / torch.ifft(x, 2)
    # functions. The mask is laid out in the unshifted FFT bin order, so no
    # fftshift/ifftshift round trip is needed.
    im_f = torch.fft.fft2(im)
    # Put the mask on the same device as the image (not just the default GPU).
    mask_bandpass = torch.from_numpy(in_band).to(device=im.device, dtype=im_f.real.dtype)
    im_bandpass = torch.fft.ifft2(im_f * mask_bandpass).real
    return im_bandpass
def transform_bandpass(im: torch.Tensor, band_center=0.3, band_width_lower=0.1, band_width_upper=0.1):
    '''Subtract the bandpassed component from the image

    Params
    ------
    im: torch.Tensor
        image batch handed on to bandpass_filter
    band_center, band_width_lower, band_width_upper
        band specification, forwarded unchanged to bandpass_filter

    Returns
    -------
    im minus bandpass_filter(im, ...)
    '''
    in_band = bandpass_filter(im, band_center, band_width_lower, band_width_upper)
    return im - in_band
def tensor_t_augment(im: torch.Tensor, t):
    '''Stack a batch with its transformed version along the batch dim

    Params
    ------
    im: torch.Tensor
        B, C, H, W
    t: callable
        transformation applied to im; its output is concatenated after the
        untouched copy of im along dim 0

    Returns
    -------
    im: torch.Tensor
        2*B, C, H, W (when t preserves the shape of im)
    '''
    # Snapshot the input before calling t, in case t modifies im in place.
    # detach().clone() yields an independent leaf tensor just like deepcopy
    # does, but also works for non-leaf tensors, which deepcopy rejects.
    im_copy = im.detach().clone().requires_grad_(im.requires_grad)
    im_p = t(im)
    return torch.cat((im_copy, im_p), dim=0)
def wavelet_filter(im: torch.Tensor, t, transform_i, idx=2, p=0.5):
    '''Zero a fixed 5x5 window of one detail band, then invert the transform

    Params
    ------
    im : torch.Tensor
    t : callable, forward transform; the coefficients edited are t(im)[1][0]
    transform_i : callable, applied to the edited output of t
    idx : detail coefficients ('LH':0, 'HL':1, 'HH':2)
    p : prop to perturb coeffs (not read by the current code; the window is
        always zeroed completely)

    Returns
    -------
    transform_i applied to the edited coefficients
    '''
    coeffs = t(im)
    detail = coeffs[1][0]
    # rows/cols 6..10 of band idx in channel 0 are wiped in place for every batch item
    detail[:, 0, idx, 6:11, 6:11] = 0
    return transform_i(coeffs)
'''The code below is from https://github.com/tomrunia/PyTorchSteerablePyramid
'''
def roll_n(X, axis, n):
    '''Cyclically move the first n entries along axis to the end

    Params
    ------
    X : torch.Tensor
    axis : int
        dimension to roll; negative values count from the last dimension
    n : int
        number of leading entries along axis that are moved to the back

    Returns
    -------
    torch.Tensor with the shape of X: X[n:] followed by X[:n] along axis

    Raises
    ------
    IndexError
        if axis is not a valid dimension of X
    '''
    n_dims = X.dim()
    if not -n_dims <= axis < n_dims:
        raise IndexError('axis {} is out of range for a {}-dimensional tensor'.format(axis, n_dims))
    # Normalize so the position comparison below also matches negative axes;
    # otherwise neither slice would be restricted and X would be concatenated
    # with itself.
    axis = axis % n_dims
    f_idx = tuple(slice(0, n) if i == axis else slice(None) for i in range(n_dims))
    b_idx = tuple(slice(n, None) if i == axis else slice(None) for i in range(n_dims))
    return torch.cat([X[b_idx], X[f_idx]], axis)
def batch_fftshift2d(x):
    '''Roll every non-batch axis of a (real, imag)-stacked tensor by half its length

    x carries its real and imaginary parts along a trailing axis of size 2.
    Every axis except the first and that trailing one is rolled with roll_n
    by half its length, rounded up for odd lengths.

    Returns
    -------
    torch.Tensor with real and imaginary parts re-stacked on the last axis
    '''
    re_part, im_part = torch.unbind(x, -1)
    for axis in range(1, re_part.dim()):
        length = re_part.size(axis)
        n_shift = length // 2 + length % 2  # one extra step for odd-sized axes
        re_part = roll_n(re_part, axis=axis, n=n_shift)
        im_part = roll_n(im_part, axis=axis, n=n_shift)
    return torch.stack((re_part, im_part), -1)  # last dim=2 (real&imag)
def batch_ifftshift2d(x):
    '''Roll every non-batch axis of a (real, imag)-stacked tensor by half its length, rounded down

    x carries its real and imaginary parts along a trailing axis of size 2.
    Axes are visited from the last spatial axis down to axis 1 and each is
    rolled with roll_n by the floor of half its length.

    Returns
    -------
    torch.Tensor with real and imaginary parts re-stacked on the last axis
    '''
    re_part, im_part = torch.unbind(x, -1)
    for axis in reversed(range(1, re_part.dim())):
        n_shift = re_part.size(axis) // 2  # both parts share the same shape
        re_part = roll_n(re_part, axis=axis, n=n_shift)
        im_part = roll_n(im_part, axis=axis, n=n_shift)
    return torch.stack((re_part, im_part), -1)  # last dim=2 (real&imag)
Functions
def bandpass_filter(im, band_center=0.3, band_width_lower=0.1, band_width_upper=0.1)
-
Bandpass filter the image (assumes the image is square)
Returns
im_bandpass : torch.Tensor
    B, C, H, W
Expand source code
def bandpass_filter(im: torch.Tensor, band_center=0.3, band_width_lower=0.1, band_width_upper=0.1):
    '''Bandpass filter the image, keeping only an annulus of spatial frequencies

    Frequencies along each spatial axis are rescaled so that the largest
    magnitude on that axis is 1. A Fourier coefficient is kept when its radial
    distance from the zero frequency lies in the half-open interval
    [band_center - band_width_lower, band_center + band_width_upper);
    every other coefficient is zeroed before transforming back.

    Params
    ------
    im: torch.Tensor
        B, C, H, W (the last two dims are filtered; H and W may differ)
    band_center: float
        center of the pass band, in normalized frequency units
    band_width_lower: float
        how far the pass band extends below band_center
    band_width_upper: float
        how far the pass band extends above band_center

    Returns
    -------
    im_bandpass: torch.Tensor
        B, C, H, W
    '''
    # Normalized frequency of every FFT bin, one array per spatial axis.
    # Kept in float64 numpy so the band-edge comparisons do not depend on the
    # precision of the image tensor.
    axis_freqs = []
    for n in im.shape[-2:]:
        freqs = np.fft.fftfreq(n)
        peak = np.max(np.abs(freqs))
        if peak > 0:  # a length-1 axis only contains the zero frequency
            freqs = freqs / peak
        axis_freqs.append(freqs)
    freq_rows, freq_cols = axis_freqs

    # Radial distance of each (row, col) bin from the zero frequency.
    dist = np.sqrt(freq_rows[:, None] ** 2 + freq_cols[None, :] ** 2)
    in_band = (dist >= band_center - band_width_lower) & (dist < band_center + band_width_upper)

    # torch.fft.fft2 / ifft2 replace the removed torch.fft(x, 2) / torch.ifft(x, 2)
    # functions. The mask is laid out in the unshifted FFT bin order, so no
    # fftshift/ifftshift round trip is needed.
    im_f = torch.fft.fft2(im)
    # Put the mask on the same device as the image (not just the default GPU).
    mask_bandpass = torch.from_numpy(in_band).to(device=im.device, dtype=im_f.real.dtype)
    im_bandpass = torch.fft.ifft2(im_f * mask_bandpass).real
    return im_bandpass
def batch_fftshift2d(x)
-
Expand source code
def batch_fftshift2d(x):
    '''Roll every non-batch axis of a (real, imag)-stacked tensor by half its length

    x carries its real and imaginary parts along a trailing axis of size 2.
    Every axis except the first and that trailing one is rolled with roll_n
    by half its length, rounded up for odd lengths.

    Returns
    -------
    torch.Tensor with real and imaginary parts re-stacked on the last axis
    '''
    re_part, im_part = torch.unbind(x, -1)
    for axis in range(1, re_part.dim()):
        length = re_part.size(axis)
        n_shift = length // 2 + length % 2  # one extra step for odd-sized axes
        re_part = roll_n(re_part, axis=axis, n=n_shift)
        im_part = roll_n(im_part, axis=axis, n=n_shift)
    return torch.stack((re_part, im_part), -1)  # last dim=2 (real&imag)
def batch_ifftshift2d(x)
-
Expand source code
def batch_ifftshift2d(x):
    '''Roll every non-batch axis of a (real, imag)-stacked tensor by half its length, rounded down

    x carries its real and imaginary parts along a trailing axis of size 2.
    Axes are visited from the last spatial axis down to axis 1 and each is
    rolled with roll_n by the floor of half its length.

    Returns
    -------
    torch.Tensor with real and imaginary parts re-stacked on the last axis
    '''
    re_part, im_part = torch.unbind(x, -1)
    for axis in reversed(range(1, re_part.dim())):
        n_shift = re_part.size(axis) // 2  # both parts share the same shape
        re_part = roll_n(re_part, axis=axis, n=n_shift)
        im_part = roll_n(im_part, axis=axis, n=n_shift)
    return torch.stack((re_part, im_part), -1)  # last dim=2 (real&imag)
def roll_n(X, axis, n)
-
Expand source code
def roll_n(X, axis, n):
    '''Cyclically move the first n entries along axis to the end

    Params
    ------
    X : torch.Tensor
    axis : int
        dimension to roll; negative values count from the last dimension
    n : int
        number of leading entries along axis that are moved to the back

    Returns
    -------
    torch.Tensor with the shape of X: X[n:] followed by X[:n] along axis

    Raises
    ------
    IndexError
        if axis is not a valid dimension of X
    '''
    n_dims = X.dim()
    if not -n_dims <= axis < n_dims:
        raise IndexError('axis {} is out of range for a {}-dimensional tensor'.format(axis, n_dims))
    # Normalize so the position comparison below also matches negative axes;
    # otherwise neither slice would be restricted and X would be concatenated
    # with itself.
    axis = axis % n_dims
    f_idx = tuple(slice(0, n) if i == axis else slice(None) for i in range(n_dims))
    b_idx = tuple(slice(n, None) if i == axis else slice(None) for i in range(n_dims))
    return torch.cat([X[b_idx], X[f_idx]], axis)
def tensor_t_augment(im, t)
-
Returns
im : torch.Tensor
    2*B, C, H, W
Expand source code
def tensor_t_augment(im: torch.Tensor, t):
    '''Stack a batch with its transformed version along the batch dim

    Params
    ------
    im: torch.Tensor
        B, C, H, W
    t: callable
        transformation applied to im; its output is concatenated after the
        untouched copy of im along dim 0

    Returns
    -------
    im: torch.Tensor
        2*B, C, H, W (when t preserves the shape of im)
    '''
    # Snapshot the input before calling t, in case t modifies im in place.
    # detach().clone() yields an independent leaf tensor just like deepcopy
    # does, but also works for non-leaf tensors, which deepcopy rejects.
    im_copy = im.detach().clone().requires_grad_(im.requires_grad)
    im_p = t(im)
    return torch.cat((im_copy, im_p), dim=0)
def transform_bandpass(im, band_center=0.3, band_width_lower=0.1, band_width_upper=0.1)
-
Expand source code
def transform_bandpass(im: torch.Tensor, band_center=0.3, band_width_lower=0.1, band_width_upper=0.1):
    '''Subtract the bandpassed component from the image

    Params
    ------
    im: torch.Tensor
        image batch handed on to bandpass_filter
    band_center, band_width_lower, band_width_upper
        band specification, forwarded unchanged to bandpass_filter

    Returns
    -------
    im minus bandpass_filter(im, ...)
    '''
    in_band = bandpass_filter(im, band_center, band_width_lower, band_width_upper)
    return im - in_band
def wavelet_filter(im, t, transform_i, idx=2, p=0.5)
-
Filter center of highpass wavelet coeffs
Params
- im : torch.Tensor
- idx : detail coefficients ('LH': 0, 'HL': 1, 'HH': 2)
- p : prop to perturb coeffs
Expand source code
def wavelet_filter(im: torch.Tensor, t, transform_i, idx=2, p=0.5):
    '''Zero a fixed 5x5 window of one detail band, then invert the transform

    Params
    ------
    im : torch.Tensor
    t : callable, forward transform; the coefficients edited are t(im)[1][0]
    transform_i : callable, applied to the edited output of t
    idx : detail coefficients ('LH':0, 'HL':1, 'HH':2)
    p : prop to perturb coeffs (not read by the current code; the window is
        always zeroed completely)

    Returns
    -------
    transform_i applied to the edited coefficients
    '''
    coeffs = t(im)
    detail = coeffs[1][0]
    # rows/cols 6..10 of band idx in channel 0 are wiped in place for every batch item
    detail[:, 0, idx, 6:11, 6:11] = 0
    return transform_i(coeffs)