Module awave.utils.visualize

Expand source code
import matplotlib.pyplot as plt
import torch
from matplotlib import gridspec
from skimage.transform import rescale


def cshow(im):
    '''Display ``im`` on the current axes with a fixed magma color scale.

    The color limits are fixed at [-0.05, 0.15], and the axis lines, ticks and
    labels are hidden.

    Params
    ------
    im: array-like
        image accepted by ``plt.imshow``
    '''
    color_limits = dict(vmin=-0.05, vmax=0.15)
    plt.imshow(im, cmap='magma', **color_limits)
    plt.axis('off')


def plot_2dreconstruct(im, recon, vmin=-0.05, vmax=0.15, max_cols=10):
    '''Plot originals, reconstructions and residuals for a batch of images.

    Row 0 shows ``im``, row 1 ``recon`` and row 2 the residual ``im - recon``.
    Each column is one sample of the batch (at most ``max_cols``), and each
    panel shows only the first channel of that sample (``x[c][0]``).

    Params
    ------
    im: torch.Tensor or numpy.ndarray
        batch of images indexed as ``im[sample][channel]`` (presumably
        (batch, channel, H, W) -- TODO confirm against callers)
    recon: torch.Tensor or numpy.ndarray
        reconstructions with the same layout as ``im``
    vmin, vmax: float
        color limits shared by every panel (defaults match the previously
        hard-coded values)
    max_cols: int
        maximum number of samples (columns) to draw

    Raises
    ------
    ValueError
        if ``im`` contains no samples
    '''
    # Detach tensors from autograd and move them to the CPU so matplotlib can
    # read them. Each input is handled on its own so mixed tensor/array pairs
    # work.
    if isinstance(im, torch.Tensor):
        im = im.detach().cpu()
    if isinstance(recon, torch.Tensor):
        recon = recon.detach().cpu()
    res = im - recon
    pl = [im, recon, res]

    R = 3
    # len() works for both tensors and numpy arrays (ndarray.size is an int,
    # so the former ``im.size(0)`` crashed on numpy input).
    C = min(len(im), max_cols)
    if C == 0:
        raise ValueError('im must contain at least one sample')
    plt.figure(figsize=(C + 1, R + 1), dpi=200)
    # Zero spacing between panels, with a half-cell margin around the grid.
    gs = gridspec.GridSpec(R, C,
                           wspace=0.0, hspace=0.0,
                           top=1. - 0.5 / (R + 1), bottom=0.5 / (R + 1),
                           left=0.5 / (C + 1), right=1 - 0.5 / (C + 1))

    for r in range(R):
        for c in range(C):
            ax = plt.subplot(gs[r, c])
            ax.imshow(pl[r][c][0], cmap='magma', vmax=vmax, vmin=vmin)
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.tick_params(
                axis='both',
                which='both',
                bottom=False,
                top=False,
                left=False,
                right=False,
                labelbottom=False)
    plt.show()


def plot_2dfilts(filts: list, scale=2, share_min_max=True, figsize=(1, 1)):
    '''Plot 2D filters in a two-column grid.

    Four filters give the familiar 2x2 layout. Any other count gets
    ceil(len(filts) / 2) rows, instead of raising IndexError (fewer than four)
    or silently dropping filters (more than four).

    Params
    ------
    filts: list
        list of 2D filters; each is upsampled with
        ``skimage.transform.rescale(f, scale, mode='constant')`` before display
    scale: float
        upsampling factor passed to ``rescale``
    share_min_max: bool
        if True, every panel uses the global min/max over all filters as its
        gray-scale limits; otherwise each panel is auto-scaled
    figsize: tuple
        figure size

    Raises
    ------
    ValueError
        if ``filts`` is empty
    '''
    if len(filts) == 0:
        raise ValueError('filts must contain at least one filter')
    # Take the global limits directly from the data. The old +/-1e4 sentinels
    # were wrong whenever all values lay beyond them.
    v_min = min(f.min() for f in filts)
    v_max = max(f.max() for f in filts)
    clim = {'vmin': v_min, 'vmax': v_max} if share_min_max else {}

    ncols = 2
    nrows = (len(filts) + ncols - 1) // ncols
    fig = plt.figure(figsize=figsize, dpi=200)
    gs = gridspec.GridSpec(ncols=ncols, nrows=nrows, figure=fig)

    # Fill the grid row by row, matching the original r/c loop order.
    for i, filt in enumerate(filts):
        r, c = divmod(i, ncols)
        ax = plt.subplot(gs[r, c])
        ax.imshow(rescale(filt, scale, mode='constant'), cmap='gray', **clim)
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.tick_params(
            axis='both',
            which='both',
            bottom=False,
            top=False,
            left=False,
            right=False,
            labelbottom=False)
    plt.tight_layout()
    plt.show()


def plot_1dreconstruct(data, recon, max_cols=10):
    '''Plot originals, reconstructions and residuals for a batch of 1D signals.

    The figure has three rows, labelled 'Original', 'Reconstruction' and
    'Residual'. There is one column per sample (at most ``max_cols``), and each
    panel plots the first channel of that sample (``x[c][0]``).

    Params
    ------
    data: torch.Tensor or numpy.ndarray
        batch of signals indexed as ``data[sample][channel]`` (presumably
        (batch, channel, length) -- TODO confirm against callers)
    recon: torch.Tensor or numpy.ndarray
        reconstructions with the same layout as ``data``
    max_cols: int
        maximum number of samples (columns) to draw
    '''
    # Detach tensors from autograd and move them to the CPU so matplotlib can
    # read them. Each input is handled on its own so mixed tensor/array pairs
    # work.
    if isinstance(data, torch.Tensor):
        data = data.detach().cpu()
    if isinstance(recon, torch.Tensor):
        recon = recon.detach().cpu()
    res = data - recon
    pl = [data, recon, res]
    # float() accepts both 0-d tensors and numpy scalars; torch.max() would
    # reject numpy input.
    vmax = float(data.max())
    vmin = float(data.min())

    R = 3
    # len() works for both tensors and numpy arrays.
    C = min(len(data), max_cols)
    plt.figure(figsize=(C + 1, R + 1), dpi=200)
    # Zero spacing between panels, with a half-cell margin around the grid.
    gs = gridspec.GridSpec(R, C,
                           wspace=0.0, hspace=0.0,
                           top=1. - 0.5 / (R + 1), bottom=0.5 / (R + 1),
                           left=0.5 / (C + 1), right=1 - 0.5 / (C + 1))

    labs = ['Original', 'Reconstruction', 'Residual']
    for r in range(R):
        for c in range(C):
            ax = plt.subplot(gs[r, c])
            ax.plot(pl[r][c][0])
            # Every row (including the residual) uses the original data's
            # y-range, so all rows are drawn on the same scale.
            # NOTE(review): the top limit has no margin (unlike the -1 at the
            # bottom), so peaks touch the upper edge -- confirm intended.
            ax.set_ylim((vmin - 1, vmax))
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.tick_params(
                axis='both',
                which='both',
                bottom=False,
                top=False,
                left=False,
                right=False,
                labelbottom=False)
            if c == 0:
                ax.set_ylabel(labs[r])
    plt.show()


def plot_1dfilts(filts: list, is_title=False, figsize=(10, 10)):
    '''Plot 1D filters side by side on a shared y-range.

    Params
    ------
    filts: list
        list of 1D filters (each must support ``.min()``/``.max()`` and be
        accepted by ``plt.plot``)
    is_title: bool
        if True, title the first two panels 'lowpass' and 'highpass'; any
        further panels are left untitled instead of raising IndexError
    figsize: tuple
        figure size
    '''
    # Take the shared limits directly from the data (no +/-1e4 sentinels).
    # The default only applies to an empty list, where nothing is plotted.
    v_min = min((f.min() for f in filts), default=0.0)
    v_max = max((f.max() for f in filts), default=0.0)
    titles = ['lowpass', 'highpass']
    n_filts = len(filts)

    plt.figure(figsize=figsize, dpi=200)
    for i, filt in enumerate(filts):
        plt.subplot(1, n_filts, i + 1)
        plt.plot(filt)
        # Pad the shared range by 1 on each side.
        plt.ylim((v_min - 1, v_max + 1))
        plt.axis('off')
        if is_title and i < len(titles):
            plt.title(titles[i])
    plt.show()


def plot_wavefun(waves: tuple, is_title=False, figsize=(10, 10), flip_wavelet=False):
    '''Plot the scaling function and the wavelet function side by side.

    Params
    ------
    waves: tuple
        tuple of scaling and wavelet functions. ``waves[0]`` is drawn in the
        left panel and ``waves[1]`` in the right, both against ``waves[-1]``
        as the x-axis (presumably a shared sample grid -- TODO confirm)
    is_title: bool
        if exactly ``True``, title the panels 'scaling' and 'wavelet'
    figsize: tuple
        figure size
    flip_wavelet: bool
        NOTE(review): accepted but currently not used by this function
    '''
    panel_titles = ('scaling', 'wavelet')
    plt.figure(figsize=figsize, dpi=300)
    grid = waves[-1]
    for panel, title in enumerate(panel_titles):
        plt.subplot(1, 2, panel + 1)
        plt.plot(grid, waves[panel])
        plt.axis('off')
        if is_title is True:
            plt.title(title)
    plt.show()

Functions

def cshow(im)
Expand source code
def cshow(im):
    '''Display ``im`` on the current axes with a fixed magma color scale.

    The color limits are fixed at [-0.05, 0.15], and the axis lines, ticks and
    labels are hidden.

    Params
    ------
    im: array-like
        image accepted by ``plt.imshow``
    '''
    color_limits = dict(vmin=-0.05, vmax=0.15)
    plt.imshow(im, cmap='magma', **color_limits)
    plt.axis('off')
def plot_1dfilts(filts, is_title=False, figsize=(10, 10))

Plot filters in the list.

Params


filts : list
list of filters
figsize : tuple
figure size
Expand source code
def plot_1dfilts(filts: list, is_title=False, figsize=(10, 10)):
    '''Plot 1D filters side by side on a shared y-range.

    Params
    ------
    filts: list
        list of 1D filters (each must support ``.min()``/``.max()`` and be
        accepted by ``plt.plot``)
    is_title: bool
        if True, title the first two panels 'lowpass' and 'highpass'; any
        further panels are left untitled instead of raising IndexError
    figsize: tuple
        figure size
    '''
    # Take the shared limits directly from the data (no +/-1e4 sentinels).
    # The default only applies to an empty list, where nothing is plotted.
    v_min = min((f.min() for f in filts), default=0.0)
    v_max = max((f.max() for f in filts), default=0.0)
    titles = ['lowpass', 'highpass']
    n_filts = len(filts)

    plt.figure(figsize=figsize, dpi=200)
    for i, filt in enumerate(filts):
        plt.subplot(1, n_filts, i + 1)
        plt.plot(filt)
        # Pad the shared range by 1 on each side.
        plt.ylim((v_min - 1, v_max + 1))
        plt.axis('off')
        if is_title and i < len(titles):
            plt.title(titles[i])
    plt.show()
def plot_1dreconstruct(data, recon)
Expand source code
def plot_1dreconstruct(data, recon, max_cols=10):
    '''Plot originals, reconstructions and residuals for a batch of 1D signals.

    The figure has three rows, labelled 'Original', 'Reconstruction' and
    'Residual'. There is one column per sample (at most ``max_cols``), and each
    panel plots the first channel of that sample (``x[c][0]``).

    Params
    ------
    data: torch.Tensor or numpy.ndarray
        batch of signals indexed as ``data[sample][channel]`` (presumably
        (batch, channel, length) -- TODO confirm against callers)
    recon: torch.Tensor or numpy.ndarray
        reconstructions with the same layout as ``data``
    max_cols: int
        maximum number of samples (columns) to draw
    '''
    # Detach tensors from autograd and move them to the CPU so matplotlib can
    # read them. Each input is handled on its own so mixed tensor/array pairs
    # work.
    if isinstance(data, torch.Tensor):
        data = data.detach().cpu()
    if isinstance(recon, torch.Tensor):
        recon = recon.detach().cpu()
    res = data - recon
    pl = [data, recon, res]
    # float() accepts both 0-d tensors and numpy scalars; torch.max() would
    # reject numpy input.
    vmax = float(data.max())
    vmin = float(data.min())

    R = 3
    # len() works for both tensors and numpy arrays.
    C = min(len(data), max_cols)
    plt.figure(figsize=(C + 1, R + 1), dpi=200)
    # Zero spacing between panels, with a half-cell margin around the grid.
    gs = gridspec.GridSpec(R, C,
                           wspace=0.0, hspace=0.0,
                           top=1. - 0.5 / (R + 1), bottom=0.5 / (R + 1),
                           left=0.5 / (C + 1), right=1 - 0.5 / (C + 1))

    labs = ['Original', 'Reconstruction', 'Residual']
    for r in range(R):
        for c in range(C):
            ax = plt.subplot(gs[r, c])
            ax.plot(pl[r][c][0])
            # Every row (including the residual) uses the original data's
            # y-range, so all rows are drawn on the same scale.
            # NOTE(review): the top limit has no margin (unlike the -1 at the
            # bottom), so peaks touch the upper edge -- confirm intended.
            ax.set_ylim((vmin - 1, vmax))
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.tick_params(
                axis='both',
                which='both',
                bottom=False,
                top=False,
                left=False,
                right=False,
                labelbottom=False)
            if c == 0:
                ax.set_ylabel(labs[r])
    plt.show()
def plot_2dfilts(filts, scale=2, share_min_max=True, figsize=(1, 1))

Plot filters in the list.

Params


filts : list
list of filters
figsize : tuple
figure size
Expand source code
def plot_2dfilts(filts: list, scale=2, share_min_max=True, figsize=(1, 1)):
    '''Plot 2D filters in a two-column grid.

    Four filters give the familiar 2x2 layout. Any other count gets
    ceil(len(filts) / 2) rows, instead of raising IndexError (fewer than four)
    or silently dropping filters (more than four).

    Params
    ------
    filts: list
        list of 2D filters; each is upsampled with
        ``skimage.transform.rescale(f, scale, mode='constant')`` before display
    scale: float
        upsampling factor passed to ``rescale``
    share_min_max: bool
        if True, every panel uses the global min/max over all filters as its
        gray-scale limits; otherwise each panel is auto-scaled
    figsize: tuple
        figure size

    Raises
    ------
    ValueError
        if ``filts`` is empty
    '''
    if len(filts) == 0:
        raise ValueError('filts must contain at least one filter')
    # Take the global limits directly from the data. The old +/-1e4 sentinels
    # were wrong whenever all values lay beyond them.
    v_min = min(f.min() for f in filts)
    v_max = max(f.max() for f in filts)
    clim = {'vmin': v_min, 'vmax': v_max} if share_min_max else {}

    ncols = 2
    nrows = (len(filts) + ncols - 1) // ncols
    fig = plt.figure(figsize=figsize, dpi=200)
    gs = gridspec.GridSpec(ncols=ncols, nrows=nrows, figure=fig)

    # Fill the grid row by row, matching the original r/c loop order.
    for i, filt in enumerate(filts):
        r, c = divmod(i, ncols)
        ax = plt.subplot(gs[r, c])
        ax.imshow(rescale(filt, scale, mode='constant'), cmap='gray', **clim)
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.tick_params(
            axis='both',
            which='both',
            bottom=False,
            top=False,
            left=False,
            right=False,
            labelbottom=False)
    plt.tight_layout()
    plt.show()
def plot_2dreconstruct(im, recon)
Expand source code
def plot_2dreconstruct(im, recon, vmin=-0.05, vmax=0.15, max_cols=10):
    '''Plot originals, reconstructions and residuals for a batch of images.

    Row 0 shows ``im``, row 1 ``recon`` and row 2 the residual ``im - recon``.
    Each column is one sample of the batch (at most ``max_cols``), and each
    panel shows only the first channel of that sample (``x[c][0]``).

    Params
    ------
    im: torch.Tensor or numpy.ndarray
        batch of images indexed as ``im[sample][channel]`` (presumably
        (batch, channel, H, W) -- TODO confirm against callers)
    recon: torch.Tensor or numpy.ndarray
        reconstructions with the same layout as ``im``
    vmin, vmax: float
        color limits shared by every panel (defaults match the previously
        hard-coded values)
    max_cols: int
        maximum number of samples (columns) to draw

    Raises
    ------
    ValueError
        if ``im`` contains no samples
    '''
    # Detach tensors from autograd and move them to the CPU so matplotlib can
    # read them. Each input is handled on its own so mixed tensor/array pairs
    # work.
    if isinstance(im, torch.Tensor):
        im = im.detach().cpu()
    if isinstance(recon, torch.Tensor):
        recon = recon.detach().cpu()
    res = im - recon
    pl = [im, recon, res]

    R = 3
    # len() works for both tensors and numpy arrays (ndarray.size is an int,
    # so the former ``im.size(0)`` crashed on numpy input).
    C = min(len(im), max_cols)
    if C == 0:
        raise ValueError('im must contain at least one sample')
    plt.figure(figsize=(C + 1, R + 1), dpi=200)
    # Zero spacing between panels, with a half-cell margin around the grid.
    gs = gridspec.GridSpec(R, C,
                           wspace=0.0, hspace=0.0,
                           top=1. - 0.5 / (R + 1), bottom=0.5 / (R + 1),
                           left=0.5 / (C + 1), right=1 - 0.5 / (C + 1))

    for r in range(R):
        for c in range(C):
            ax = plt.subplot(gs[r, c])
            ax.imshow(pl[r][c][0], cmap='magma', vmax=vmax, vmin=vmin)
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.tick_params(
                axis='both',
                which='both',
                bottom=False,
                top=False,
                left=False,
                right=False,
                labelbottom=False)
    plt.show()
def plot_wavefun(waves, is_title=False, figsize=(10, 10), flip_wavelet=False)

Plot the scaling and wavelet functions side by side.

Params


waves : tuple
tuple of scaling and wavelet functions
figsize : tuple
figure size
Expand source code
def plot_wavefun(waves: tuple, is_title=False, figsize=(10, 10), flip_wavelet=False):
    '''Plot the scaling function and the wavelet function side by side.

    Params
    ------
    waves: tuple
        tuple of scaling and wavelet functions. ``waves[0]`` is drawn in the
        left panel and ``waves[1]`` in the right, both against ``waves[-1]``
        as the x-axis (presumably a shared sample grid -- TODO confirm)
    is_title: bool
        if exactly ``True``, title the panels 'scaling' and 'wavelet'
    figsize: tuple
        figure size
    flip_wavelet: bool
        NOTE(review): accepted but currently not used by this function
    '''
    panel_titles = ('scaling', 'wavelet')
    plt.figure(figsize=figsize, dpi=300)
    grid = waves[-1]
    for panel, title in enumerate(panel_titles):
        plt.subplot(1, 2, panel + 1)
        plt.plot(grid, waves[panel])
        plt.axis('off')
        if is_title is True:
            plt.title(title)
    plt.show()