Search
Code Example

This code example is a Jupyter notebook that uses a Script of Scripts (SoS) workflow. It reproduces the interactive figures for the highlighted paper by Nan-kuei Chen (first author) and Pei-Hsin Wu (second author).

The calculations are written in Julia (code fetched from this repository), and the interactive figures are produced in Python 3.6 with the plotting library Plotly.

%use Julia

###### define the parameters below
# Input volumes (relative paths, resolved against the working directory)
magnitude_nii_file          = "data/magnitude.nii"
phase_nii_file              = "data/phase.nii"
partiallyUnwrapped_nii_file = "data/partiallyUnwrappedMap.nii"
criticalROI_nii_file        = "data/criticalROI.nii"
# Slice to process (e.g. slice #5) and the integer label of the critical ROI
# whose phase values get unwrapped (e.g. label 1)
ChooseThisSlice  = convert(Int64, 5)
criticalROILabel = convert(Int64, 1)
# Where the processed volume is written
output_nii_file = "data/output.nii";
%use Julia 

###### The main program starts here; no need to modify the code below (although you could use a different value of patchsize)
push!(LOAD_PATH,"juliafunction");
using PyPlot
using Read_NIfTI1
using myFun
using Distributed
@everywhere using FFTW
@everywhere using LinearAlgebra
@everywhere using SharedArrays

# Return the (row, column) index of the largest-magnitude sample of `kdata`.
# Ties go to the first maximum in column-major order (the `argmax` rule).
# Defined on every process because the @distributed loop in
# `_local_kspace_peak_maps` calls it from the workers; defining it once at top
# level avoids re-evaluating a method definition inside a running function.
@everywhere function sidm(kdata::Array{Complex{Float32},2})
    peak = argmax(abs.(kdata))
    return (peak[1], peak[2])
end

"""
    _local_kspace_peak_maps(imgsn, patchsize) -> (pxmap, pymap)

For every pixel of the complex image `imgsn`, keep only a window of width
`patchsize` centred on that pixel (zero elsewhere, clipped at the image edge).
Apply the centred inverse FFT `ifftshift(ifft(ifftshift(window)))` and record the
row/column index of the magnitude peak via `sidm`.

Even `patchsize` values are rounded up to the next odd number. The half-width is
never smaller than 1. Rows are distributed over the workers with `@distributed`
into `SharedArray{Float64,2}` results, and the loop's run time is printed by `@time`.
"""
function _local_kspace_peak_maps(imgsn::AbstractMatrix, patchsize::Integer)
    nx, ny = size(imgsn)
    pxmap = SharedArray{Float64,2}((nx, ny))
    pymap = SharedArray{Float64,2}((nx, ny))

    oddsize = isodd(patchsize) ? patchsize : patchsize + 1
    halfwidth = max(div(oddsize - 1, 2), 1)

    @time @sync @distributed for cntx = 1:nx
        # Scratch buffer reused for every pixel of this row; re-zeroed each time.
        window = zeros(ComplexF32, nx, ny)
        for cnty = 1:ny
            xs = max(1, cntx - halfwidth):min(cntx + halfwidth, nx)
            ys = max(1, cnty - halfwidth):min(cnty + halfwidth, ny)
            fill!(window, 0)
            window[xs, ys] = imgsn[xs, ys]
            px, py = sidm(ifftshift(ifft(ifftshift(window))))
            pxmap[cntx, cnty] = px
            pymap[cntx, cnty] = py
        end
    end
    return pxmap, pymap
end

"""
    _unwrap_critical_roi(imgsn, criticalROI, criticalROILabel, preludephasemap,
                         preludephasemap_cpe, preludemask, gradx, grady)

Return the phase map of one slice multiplied by `preludemask`.

If no pixel of `criticalROI` equals `criticalROILabel`, this is simply
`preludephasemap .* preludemask`. Otherwise a copy of `preludephasemap_cpe` is
made, and its ROI bounding box (with a one-pixel margin, clamped to the image) is
replaced by the result of `twoDimIntegration`.
"""
function _unwrap_critical_roi(imgsn, criticalROI, criticalROILabel, preludephasemap,
                              preludephasemap_cpe, preludemask, gradx, grady)
    roi = findall(criticalROI .== criticalROILabel)
    isempty(roi) && return preludephasemap .* preludemask

    nx, ny = size(criticalROI)
    xmin, xmax = extrema(c[1] for c in roi)
    ymin, ymax = extrema(c[2] for c in roi)
    # One-pixel margin around the ROI. Clamping keeps an ROI that touches the
    # image edge from indexing row/column 0 or N+1 (a BoundsError before).
    xr = max(xmin - 1, 1):min(xmax + 1, nx)
    yr = max(ymin - 1, 1):min(ymax + 1, ny)

    boundaryConditionMap = convert(Array{Float64}, preludephasemap_cpe[xr, yr])
    # 1 outside the selected ROI, 0 inside it. Compared against the label so that
    # labels other than 1, or extra labels in the file, cannot yield values
    # outside {0, 1}.
    mask = Float64.(criticalROI[xr, yr] .!= criticalROILabel)
    snr_ref = abs.(imgsn[xr, yr]) .* mask
    λmap = snr_ref ./ maximum(snr_ref)
    # The outer ring of the window gets weight 2. NOTE(review): presumably this
    # pins it to boundaryConditionMap inside twoDimIntegration — confirm in myFun.
    λmap[[1, end], :] .= 2.0
    λmap[:, [1, end]] .= 2.0

    M_recovered = twoDimIntegration(mask, boundaryConditionMap, λmap,
                                    gradx[xr, yr], grady[xr, yr])

    newPhaseMap = copy(preludephasemap_cpe)
    newPhaseMap[xr, yr] = M_recovered
    return newPhaseMap .* preludemask
end

"""
    _write_float32_nii!(path, header, volume)

Set `header["datatype"]` to 16 (the NIfTI-1 FLOAT32 code, matching the Float32
`volume`). Then write the header with `write_nii_header` and append the raw
`volume` bytes to the same file. The file handle is closed even if `write` throws.

NOTE(review): `bitpix` (if the header carries it) is not updated here — confirm
that Read_NIfTI1's writer does not need it to be 32.
"""
function _write_float32_nii!(path::AbstractString, header, volume::AbstractArray{Float32})
    header["datatype"] = Int16(16)
    write_nii_header(path, header)
    open(path, "a") do fid
        write(fid, volume)
    end
    return nothing
end

"""
    julia_main(magnitude_nii_file, phase_nii_file, partiallyUnwrapped_nii_file,
               criticalROI_nii_file, ChooseThisSlice, criticalROILabel;
               output_file = output_nii_file, patchsize = 7)

Re-process the pixels labelled `criticalROILabel` in slice `ChooseThisSlice` of a
partially unwrapped phase volume, and write the full volume (with that slice
replaced) as Float32 NIfTI.

# Steps
1. Build the complex image `magnitude .* exp(i*phase)` of the chosen slice.
2. For every pixel, find the k-space peak of a local window and convert its
   offset from the k-space centre with `-(p - N/2) * 2π / N`. Per the variable
   names this is radians per pixel; the two maps are passed to
   `twoDimIntegration`.
3. Treat zeros in the partially unwrapped slice as masked, fill them with
   `closest_point_estimation`, replace the critical-ROI window, then re-apply
   the mask.
4. Write the result to `output_file`.

# Keywords
- `output_file`: destination path. It defaults to the global `output_nii_file`
  evaluated at call time, which is where earlier versions always wrote.
- `patchsize`: width of the local window (even values are rounded up to odd).

Returns `nothing`.
"""
function julia_main(magnitude_nii_file::AbstractString, phase_nii_file::AbstractString,
                    partiallyUnwrapped_nii_file::AbstractString,
                    criticalROI_nii_file::AbstractString,
                    ChooseThisSlice::Integer, criticalROILabel::Integer;
                    output_file::AbstractString = output_nii_file,
                    patchsize::Integer = 7)
    magnitude = load_nii_data(magnitude_nii_file, load_nii_header(magnitude_nii_file))
    phase = load_nii_data(phase_nii_file, load_nii_header(phase_nii_file))
    imgsn = magnitude[:, :, ChooseThisSlice] .*
            exp.(complex(0, 1) * phase[:, :, ChooseThisSlice])
    nx, ny = size(imgsn)

    pxmap, pymap = _local_kspace_peak_maps(imgsn, patchsize)
    pxmap_radppixel = -(pxmap .- nx / 2) * 2π / nx
    pymap_radppixel = -(pymap .- ny / 2) * 2π / ny

    prelude_header = load_nii_header(partiallyUnwrapped_nii_file)
    preludephasemap_all = load_nii_data(partiallyUnwrapped_nii_file, prelude_header)
    preludephasemap = preludephasemap_all[:, :, ChooseThisSlice]

    # Zero in the partially unwrapped map marks "no value".
    preludemask = ones(nx, ny)
    preludemask[preludephasemap .== 0] .= 0
    preludephasemap_cpe = closest_point_estimation(preludephasemap, preludemask)

    criticalROI = load_nii_data(criticalROI_nii_file,
                                load_nii_header(criticalROI_nii_file))[:, :, ChooseThisSlice]

    newPhaseMapMask = _unwrap_critical_roi(imgsn, criticalROI, criticalROILabel,
                                           preludephasemap, preludephasemap_cpe,
                                           preludemask, pxmap_radppixel, pymap_radppixel)

    output_all = convert(Array{Float32,3}, preludephasemap_all)
    output_all[:, :, ChooseThisSlice] = convert(Array{Float32,2}, newPhaseMapMask)
    _write_float32_nii!(output_file, prelude_header, output_all)
    return nothing
end
# julia_main(magnitude_nii_file,phase_nii_file,partiallyUnwrapped_nii_file,criticalROI_nii_file,ChooseThisSlice,criticalROILabel);
%use Julia 
###### Save multiple slices using julia_main function
# julia_main writes to the path held in the global `output_nii_file`, so each
# iteration must update that global. `global` makes this explicit. Without it,
# the assignment only reaches the global under the notebook's soft-scope rules;
# run as a script, a plain assignment in a top-level loop creates a new local
# instead, and every slice would overwrite the same file.
# Slices 1-9 are processed; the Python viewer below reads output1.nii … output9.nii.
for i in 1:9
    println("#", i, " Slice:")
    global output_nii_file = "data/output" * string(i) * ".nii"
    julia_main(magnitude_nii_file, phase_nii_file, partiallyUnwrapped_nii_file,
               criticalROI_nii_file, Int64(i), criticalROILabel)
end
%use Python3

###### Plotting "output_nii_file" heatmap with slider for navigating through slices
import scipy.io as sio
import plotly.graph_objs as go
from ipywidgets import interactive, HBox, widgets, interact
import numpy as np
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
from IPython.core.display import display, HTML
import nibabel as nib
import os
import numpy as np 

init_notebook_mode(connected=True)
config = {'showLink': False, 'displayModeBar': False}

# The Julia cells write every file under "data/" relative to the notebook.
# The absolute path below is the data directory of the hosted (non-local)
# environment this notebook was written for. When it does not exist (e.g. a
# local checkout), fall back to the relative "data" directory instead of
# crashing in os.chdir.
hosted_data_dir = "/home/jovyan/work/PhaseUnwrapping_book/content/03/data"
if os.path.isdir(hosted_data_dir):
    os.chdir(hosted_data_dir)
elif os.path.isdir("data"):
    os.chdir("data")

# Partially unwrapped volume, plus the per-slice outputs of the Julia loop.
# output{k}.nii is the partially unwrapped volume with only slice k re-processed.
nii_left = nib.load('partiallyUnwrappedMap.nii').get_fdata()
nii_list = [nib.load(f"output{i}.nii").get_fdata() for i in range(1, 10)]

# Pair each output file with its own re-processed slice and join it with the
# matching partially unwrapped slice along axis 0.
newmap = [np.concatenate((nii[:, :, i], nii_left[:, :, i]), axis=0)
          for i, nii in enumerate(nii_list)]

# Number of slices to show (len() does not require equally shaped slices).
slices_z = len(newmap)

# One heatmap trace per slice, all hidden initially. np.rot90(..., 3) swaps the
# two array axes, so the two maps joined along axis 0 end up next to each other.
# NOTE(review): the values come from phase maps; confirm the "Hz" unit in the
# colorbar title.
data = [
    go.Heatmap(
        z=np.rot90(newmap[idx], 3),
        visible=False,
        xtype="scaled",
        ytype="scaled",
        colorscale="hot",
        name="comparison heatmap",
        colorbar=dict(title=dict(text="B<sub>0</sub> (Hz)")),
    )
    for idx in range(slices_z)
]

# Show the first slice when the figure opens
data[0].visible = True

# One slider step per slice: selecting step k restyles the figure so that only
# trace k is visible.
steps = [
    dict(
        method='restyle',
        args=['visible', [trace_idx == idx for trace_idx in range(slices_z)]],
        label=str(idx + 1),
    )
    for idx in range(slices_z)
]

slider = dict(
    active=0,
    currentvalue={'prefix': "Current slice is: <b>"},
    pad={"t": 50, "b": 10},
    steps=steps,
)
sliders = [slider]

# Figure layout: labels above the two halves, hidden axes, slice slider.
layout = go.Layout(
    title = "\t \t \t \t \t \t \t \t \t Prelude \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t Filtered", 
    width=780,
    height=640,
    margin=go.layout.Margin(
        l=50,
        r=50,
        b=60,
        t=35,
    ),
    showlegend=False,
    autosize=False,
    sliders=sliders,
    xaxis=dict(showgrid=False,
               showticklabels=False),
    yaxis=dict(showgrid=False,
               showticklabels=False),
)

# Write a standalone HTML file and embed it in the notebook output.
# auto_open=False stops plotly from also trying to launch a web browser, which is
# redundant with the inline display below. The file is passed to HTML() by
# keyword instead of relying on IPython guessing that the string is a path.
fig = dict(data=data, layout=layout)
plot(fig, filename='fig.html', config=config, auto_open=False)

display(HTML(filename='fig.html'))
# Local offline use
# iplot(fig,config=config)