Skip to content

Fix data-consistency module #9

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,41 @@ def apply_mask(data, mask_func, seed=None):
return data * mask, mask


def fft2c(x, dim=(-2, -1)):
""" Centered 2D Fast Fourier Transform

Args:
x (torch.Tensor): Complex valued input data containing at least 3
dimensions: dimensions -2 & -1 are spatial dimensions. All other
dimensions are assumed to be batch dimensions.

dim (tuple): Dimensions to apply the FFT along. Default is (-2, -1)

Returns:
torch.Tensor: The FFT of the input.
"""
x = torch.fft.ifftshift(x, dim=dim)
x = torch.fft.fft2(x, dim=dim)
return torch.fft.fftshift(x, dim=dim)


def ifft2c(x, dim=(-2, -1)):
""" Centered 2D Inverse Fast Fourier Transform

Args:
x (torch.Tensor): Complex valued input data containing at least 3
dimensions: dimensions -2 & -1 are spatial dimensions. All other
dimensions are assumed to be batch dimensions.
dim (tuple): Dimensions to apply the IFFT along. Default is (-2, -1)

Returns:
torch.Tensor: The IFFT of the input.
"""
x = torch.fft.ifftshift(x, dim=dim)
x = torch.fft.ifft2(x, dim=dim)
return torch.fft.fftshift(x, dim=dim)


def fft2(data, normalized=True):
"""
Apply centered 2 dimensional Fast Fourier Transform.
Expand Down
47 changes: 26 additions & 21 deletions models/Recurrent_Transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,17 @@
from torch.nn import functional as F
import numpy as np

class DataConsistencyInKspace(nn.Module):
""" Create data consistency operator

Warning: note that FFT2 (by the default of torch.fft) is applied to the last 2 axes of the input.
This method detects if the input tensor is 4-dim (2D data) or 5-dim (3D data)
and applies FFT2 to the (nx, ny) axis.

"""
class DataConsistencyInKspace(nn.Module):
""" Data consistency layer in k-space. """

def __init__(self):
super(DataConsistencyInKspace, self).__init__()

def forward(self, *input, **kwargs):
return self.perform(*input)

def data_consistency(self,k, k0, mask):
def data_consistency(self, k, k0, mask):
"""
k - input in k-space
k0 - initially sampled elements in k-space
Expand All @@ -36,23 +31,33 @@ def data_consistency(self,k, k0, mask):
return out

def perform(self, x, k0, mask):
""" Forward pass to enforce data consistency in k-space.

Args:
x (torch.Tensor): Input image in spatial domain (batch_size, 2, height, width).
k0 (torch.Tensor): Measured k-space data (batch_size, 2, height, width).
mask (torch.Tensor): Binary mask indicating sampled k-space locations (batch_size, 1, height, width).

Returns:
torch.Tensor: Corrected image with the same shape as input.
"""
x - input in image domain, of shape (n, 2, nx, ny[, nt])
k0 - initially sampled elements in k-space
mask - corresponding nonzero location
"""
x = x.permute(0, 2, 3, 1)
k0 = k0.permute(0, 2, 3, 1)
mask = mask.permute(0, 2, 3, 1)
x_cx = torch.complex(x[:, 0], x[:, 1]).unsqueeze(1)
k0_cx = torch.complex(k0[:, 0], k0[:, 1]).unsqueeze(1)

k = transforms.fft2(x)
# Fourier transform
x_kspace = transforms.fft2c(x_cx)

out = self.data_consistency(k, k0, mask)
x_res = transforms.ifft2(out)
# Fill in k-space
x_kspace = self.data_consistency(x_kspace, k0_cx, mask)

x_res = x_res.permute(0, 3, 1, 2)
# Inverse Fourier transform
out = transforms.ifft2c(x_kspace)

# Stack real and imaginary parts
out = torch.cat((out.real, out.imag), dim=1)

return out

return x_res

class RFB(nn.Module):
"""
Expand Down