Code example


This code example is a Jupyter notebook that uses a Script of Scripts (SoS) workflow. It reproduces the interactive figures of the highlighted paper by first author Nan-kuei Chen and second author Pei-Hsin Wu.

The calculations are written in Julia (fetched from this repository), and the interactive figures are produced in Python 3.8.6 with the plotting library Plotly.

Julia code cell

The main program and the calculations are in the code cell below; you can try a different value of patchsize.

If you wish to see the code cell, try clicking the button to the right of the empty spot below!

%use Julia 

###### The main program starts here; no need to modify the code below (although you could use different value of patchsize)
push!(LOAD_PATH,"juliafunction");
using PyPlot
using Read_NIfTI1
using myFun
using Distributed
@everywhere using FFTW
@everywhere using LinearAlgebra
@everywhere using SharedArrays

"""
    julia_main(magnitude_nii_file, phase_nii_file, partiallyUnwrapped_nii_file,
               criticalROI_nii_file, ChooseThisSlice, criticalROILabel;
               patchsize=7, output_file=output_nii_file)

Re-estimate the phase of one slice inside a labelled "critical" ROI and write the updated
partially unwrapped volume to a NIfTI file.

All steps operate on slice `ChooseThisSlice`:
1. Build the complex image `magnitude .* exp.(im .* phase)`.
2. For every pixel, keep only a `patchsize`×`patchsize` neighbourhood (clipped at the image
   edges), apply `ifftshift(ifft(ifftshift(...)))` and record the position of the magnitude
   peak. The peak's offset from the array centre is converted to a local phase gradient in
   rad/pixel.
3. Pass the partially unwrapped slice and its non-zero mask to the project helper
   `closest_point_estimation` (presumably fills the zero voxels — confirm in `myFun`).
4. Inside the one-pixel-padded bounding box of the voxels equal to `criticalROILabel`, call
   the project helper `twoDimIntegration`. It receives the ROI mask, the map from step 3
   (boundary condition), a magnitude-based weight map and the gradients. Its result is
   pasted back into the slice.
5. Multiply the slice by the non-zero support of the partially unwrapped map and store it
   in a Float32 copy of the partially unwrapped volume. Write the header, then the raw
   data, to `output_file`.

# Arguments
- `magnitude_nii_file`, `phase_nii_file`: NIfTI volumes combined into the complex image.
- `partiallyUnwrapped_nii_file`: partially unwrapped phase volume; voxels equal to 0 are
  treated as outside the mask.
- `criticalROI_nii_file`: label volume; voxels equal to `criticalROILabel` are re-integrated.
- `ChooseThisSlice`: index along the third dimension of the slice to process.
- `criticalROILabel`: label value selecting the critical ROI.

# Keywords
- `patchsize=7`: neighbourhood width for the local peak search. Even values are rounded up
  to the next odd value; a value of 1 still yields a 3×3 neighbourhood.
- `output_file`: output path; defaults to the global `output_nii_file`.

# Throws
- `ArgumentError` if `patchsize < 1`.

Returns `nothing`.
"""
function julia_main(magnitude_nii_file::String, phase_nii_file::String,
                    partiallyUnwrapped_nii_file::String, criticalROI_nii_file::String,
                    ChooseThisSlice::Int64, criticalROILabel::Int64;
                    patchsize::Integer=7,
                    output_file::AbstractString=output_nii_file)
    patchsize >= 1 || throw(ArgumentError("patchsize must be positive, got $patchsize"))

    # --- complex image of the selected slice -----------------------------------------
    headerinfo1 = load_nii_header(magnitude_nii_file);
    data001 = load_nii_data(magnitude_nii_file, headerinfo1);
    headerinfo2 = load_nii_header(phase_nii_file);
    data002 = load_nii_data(phase_nii_file, headerinfo2);
    imagedata_noise = data001 .* exp.(complex(0,1)*data002);
    imgsn = imagedata_noise[:,:,ChooseThisSlice];

    # Index (row, column) of the largest-magnitude sample; on ties the first match in
    # column-major order is returned. Defined with @everywhere so the processes running
    # the @distributed loop below also have it.
    @everywhere function sidm(kdata::Array{Complex{Float32},2})
        kabs = abs.(kdata);
        pxpy = findall(kabs .== maximum(kabs));
        px = pxpy[1][1];
        py = pxpy[1][2];
        return (px,py);
    end

    datasize1,datasize2 = size(imgsn)

    # SharedArrays so that every worker writes its rows into the same result maps.
    pxmap = SharedArray{Float64,2}((datasize1,datasize2));
    pymap = SharedArray{Float64,2}((datasize1,datasize2));

    # Force an odd window width (even values are rounded up); ps1 is the half-width.
    oddpatch = isodd(patchsize) ? patchsize : patchsize + 1
    ps1 = max(div((oddpatch-1),2),1);

    # --- per-pixel peak position of the windowed, transformed image ------------------
    @time @sync @distributed for cntx = 1:datasize1
        @inbounds @fastmath @simd for cnty = 1:datasize2
            tmp1 = zeros(ComplexF32,(datasize1,datasize2));
            startingx = max(1,cntx-ps1);
            startingy = max(1,cnty-ps1);
            endingx = min(cntx+ps1,datasize1);
            endingy = min(cnty+ps1,datasize2);
            tmp1[startingx:endingx,startingy:endingy] = imgsn[startingx:endingx,startingy:endingy];
            tmp2 = ifftshift(ifft(ifftshift(tmp1)));
            px,py = sidm(tmp2);
            pxmap[cntx,cnty]=px;
            pymap[cntx,cnty]=py;
        end
    end

    # Peak offset from the array centre -> local phase gradient in rad/pixel.
    pxmap2a = pxmap.-(datasize1/2);
    pymap2a = pymap.-(datasize2/2);
    pxmap_radppixel = -pxmap2a*2*π/datasize1;
    pymap_radppixel = -pymap2a*2*π/datasize2;

    # --- partially unwrapped map and its support mask --------------------------------
    headerinfo4 = load_nii_header(partiallyUnwrapped_nii_file);
    preludephasemap_all = load_nii_data(partiallyUnwrapped_nii_file, headerinfo4);
    preludephasemap = preludephasemap_all[:,:,ChooseThisSlice];

    # 1 where the partially unwrapped map is non-zero, 0 where it is exactly zero.
    preludemask = Float64.(preludephasemap .!= 0);
    preludephasemap_cpe = closest_point_estimation(preludephasemap,preludemask);

    # --- critical ROI -----------------------------------------------------------------
    headerinfo5 = load_nii_header(criticalROI_nii_file);
    criticalROI_all = load_nii_data(criticalROI_nii_file, headerinfo5);
    criticalROI = criticalROI_all[:,:,ChooseThisSlice];

    L2 = findall(criticalROI.==criticalROILabel);

    if !isempty(L2)
        # Bounding box of the ROI padded by one pixel so that its border lies outside the
        # ROI. It is clamped to the image so that an ROI touching the image edge no longer
        # throws a BoundsError.
        # NOTE(review): in that edge case the box border overlaps the ROI itself.
        xmin, xmax = extrema(c[1] for c in L2)
        ymin, ymax = extrema(c[2] for c in L2)
        roiX = max(xmin - 1, 1):min(xmax + 1, datasize1)
        roiY = max(ymin - 1, 1):min(ymax + 1, datasize2)

        boundaryConditionMap = convert(Array{Float64},preludephasemap_cpe[roiX,roiY]);

        # 0 inside the critical ROI, 1 elsewhere. Comparing against criticalROILabel keeps
        # the mask binary. The previous `1 - label value` form gave negative weights for
        # any label other than 1.
        mask = Float64.(criticalROI[roiX,roiY] .!= criticalROILabel);

        # Magnitude outside the ROI, normalised to max 1; the box border gets weight 2.
        snr_ref = abs.(imgsn[roiX,roiY]).*mask;
        λmap = snr_ref ./maximum(snr_ref);
        λmap[:,1].=2.;
        λmap[:,end].=2.;
        λmap[1,:].=2.;
        λmap[end,:].=2.;
        Mgv = pxmap_radppixel[roiX,roiY];
        Mgh = pymap_radppixel[roiX,roiY];

        M_recovered = twoDimIntegration(mask,boundaryConditionMap,λmap,Mgv,Mgh);

        newPhaseMap = copy(preludephasemap_cpe);
        newPhaseMap[roiX,roiY]= M_recovered;
        newPhaseMapMask = newPhaseMap.*preludemask;
    else
        newPhaseMap = preludephasemap;
        newPhaseMapMask = newPhaseMap.*preludemask;
    end

    # --- write output -----------------------------------------------------------------
    # NIfTI-1 datatype code 16 is FLOAT32, matching the Float32 data written below.
    # NOTE(review): the header's bitpix field is not updated here — confirm it is 32.
    headerinfo4["datatype"] = Int16(16);
    output_all = convert(Array{Float32,3}, preludephasemap_all);
    output_all[:,:,ChooseThisSlice] = convert(Array{Float32,2}, newPhaseMapMask);
    write_nii_header(output_file, headerinfo4);
    # Append the raw voxel data after the header; the do-block always closes the file.
    open(output_file,"a") do fid
        write(fid, output_all);
    end

    # Debug views (need PyPlot):
#     figure(1,figsize=(10,5));imshow(reverse(permutedims(vcat(preludephasemap,newPhaseMapMask),[2,1]),dims=1),cmap="hot",interpolation="none"); axis("off");
#     figure(2,figsize=(10,5));imshow(reverse(permutedims(vcat(preludephasemap_cpe,newPhaseMap),[2,1]),dims=1),cmap="hot",interpolation="none"); axis("off");

    return nothing
end
# julia_main(magnitude_nii_file,phase_nii_file,partiallyUnwrapped_nii_file,criticalROI_nii_file,ChooseThisSlice,criticalROILabel);

Python code cell

Now let's plot the “output_nii_file” heatmaps, with a slider for navigating through the slices.

You can zoom in and out if you wish to explore the phenomenon!

%use Python3

###### Plotting "output_nii_file" heatmap with slider for navigating through slices
import scipy.io as sio
import plotly.graph_objs as go
from ipywidgets import interactive, HBox, widgets, interact
import numpy as np
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
from IPython.core.display import display, HTML
import nibabel as nib
import os
import numpy as np 

# Enable inline Plotly rendering in the notebook (plotly.js loaded from the CDN).
init_notebook_mode(connected=True)
# Hide the "Edit chart" link and the mode bar in every figure.
config = {'showLink': False, 'displayModeBar': False}

# The data files live in this directory on the hosted Jupyter image. When running
# locally that directory normally does not exist; stay in the current working
# directory (which should then contain the .nii files) instead of crashing with
# FileNotFoundError.
DATA_DIR = "/home/jovyan/work/Phase_book/03/data"
if os.path.isdir(DATA_DIR):
    os.chdir(DATA_DIR)


# Load the NIfTI volumes: the partially unwrapped (Prelude) map and the nine
# outputs output1.nii ... output9.nii.
nii_left = nib.load('partiallyUnwrappedMap.nii').get_fdata()
nii_list = [nib.load(f"output{i}.nii").get_fdata() for i in range(1,10)]

# For 0-based slice i, stack slice i of output{i+1}.nii followed by slice i of the
# partially unwrapped map along axis 0.
# NOTE(review): presumably output{N}.nii holds the refined result for slice N —
# verify against how the output files were produced.
newmap = [np.concatenate((nii[:, :, i], nii_left[:, :, i]), axis=0)
          for i, nii in enumerate(nii_list)]

# Number of slices available to the slider. len() avoids stacking every slice
# into a temporary 3-D array just to read its first dimension.
slices_z = len(newmap)

# Build one heatmap trace per slice (rotated by np.rot90(..., 3)). Every trace
# starts hidden; the slider switches between them.
heatmap_style = dict(
    xtype="scaled",
    ytype="scaled",
    colorscale="hot",
    name="comparison heatmap",
    colorbar=dict(title=dict(text="B<sub>0</sub> (Hz)")),
)
data = [
    go.Heatmap(z=np.rot90(newmap[k], 3), visible=False, **heatmap_style)
    for k in range(slices_z)
]

# Show the first slice initially.
data[0].visible = True

# Slider steps: step k restyles the 'visible' attribute of all traces so that
# only trace k is shown; labels are 1-based slice numbers.
steps = [
    dict(
        method='restyle',
        args=['visible', [trace == k for trace in range(slices_z)]],
        label=str(k + 1),
    )
    for k in range(slices_z)
]

sliders = [
    dict(
        active=0,
        currentvalue={'prefix': "Current slice is: <b>"},
        pad={"t": 50, "b": 10},
        steps=steps,
    )
]

# Figure layout. The title is one string padded with tab characters so that
# "Prelude" sits over the left half of the image and "Filtered" over the right
# half. If "Filtered" is cut off at this width, delete some of the tabs.
layout = go.Layout(
    title = "\t \t \t \t \t \t \t \t \t Prelude \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t \t Filtered",
    width=780,
    height=640,
    margin=go.layout.Margin(
        l=50,
        r=50,
        b=60,
        t=35,
    ),
    showlegend = False,
    autosize = False,
    sliders=sliders,
    xaxis = dict(showgrid = False,
                 showticklabels= False),
    yaxis = dict(showgrid = False,
                 showticklabels = False),
)

# Write the figure to fig.html without launching a web browser (it is embedded
# inline right below), then display that file in the notebook.
fig = dict(data=data, layout=layout)
plot(fig, filename='fig.html', config=config, auto_open=False)

display(HTML(filename='fig.html'))
# For local offline use, render directly in the notebook instead:
# iplot(fig, config=config)