Fit T2* values using Python

For a project I needed T2* values for different tissues. These can be estimated from T2*-weighted images acquired at different echo times (TEs). Some free tools exist, like the Mipav-based CBS-tools and the web-based MRI Toolbox, but the Mipav toolkit is rather unstable (Java, ...) and needs everything to be in a very specific format. And although I think web-based neuroimaging analysis is very cool, the MRI Toolbox is not very flexible yet (you have to enter the measurements manually). Therefore, I set out to find out what this T2(*) fitting actually does. Basically, it fits the following signal decay function: $$S(TE) = S_0 e^{-TE/T2(*)}$$ So, for example:
In [3]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
In [17]:
te = np.linspace(0, 100)
t2 = 40
S0 = 100

plt.plot(te, S0 * np.exp(-te/t2))
plt.scatter(t2, S0 * np.exp(-1), s=65, color=sns.color_palette()[0])

plt.plot([0, t2], [S0*np.exp(-1), S0*np.exp(-1)], ls='--', color=sns.color_palette()[0])
plt.plot([t2, t2], [0, S0*np.exp(-1)], ls='--', color=sns.color_palette()[0])

plt.xlim(0, 100)
plt.ylim(0, 100)

plt.xlabel('Time (ms)')
plt.ylabel('Signal S')
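The dashed lines in the plot mark what T2* means: at TE = T2*, the signal has decayed to \(S_0/e\), about 36.8% of \(S_0\). A quick numerical check of that property, as a small sketch (the helper name `signal` is only for illustration):

```python
import numpy as np

def signal(te, s0, t2):
    """Mono-exponential decay S(TE) = S0 * exp(-TE / T2*)."""
    return s0 * np.exp(-np.asarray(te, dtype=float) / t2)

s0, t2 = 100.0, 40.0

# Fraction of S0 left after exactly one T2* interval
ratio = signal(t2, s0, t2) / s0
print(round(float(ratio), 4))  # 0.3679

# Every additional T2* interval multiplies the signal by another 1/e
print(signal([0.0, 40.0, 80.0], s0, t2))
```

The ratio does not depend on the chosen \(S_0\) or T2*, which is why T2* can be read off the curve as the TE where the signal reaches \(S_0/e\).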
If we want to fit such an exponential function, we can make it linear by taking the logarithm: $$S(TE) = S_0 e^{-TE/T2(*)}$$ $$\log(S(TE)) = \log(S_0) - \frac{TE}{T2(*)}$$ $$\log(S(TE)) = \log(S_0) - TE \frac{1}{T2(*)}$$ This can be rewritten as a standard linear equation $$y = \beta_1 x_1 + \beta_2 x_2$$ where \(\beta_1 = \log(S_0)\), \(x_1 = 1\), \(\beta_2 = \frac{1}{T2(*)}\), \(x_2 = -TE\) and \(y = \log(S(TE))\).
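To make the mapping concrete, take the values used in the simulation that follows (\(S_0 = 100\), \(T2(*) = 40\) ms), with \(\log\) being the natural logarithm (as in `np.log`). Then \(\beta_1 = \log(100) \approx 4.6052\) and \(\beta_2 = 1/40 = 0.025\), so for the first echo at TE = 10 ms: $$y = \beta_1 \cdot 1 + \beta_2 \cdot (-10) \approx 4.6052 - 0.25 = 4.3552$$ and for the last echo at TE = 30 ms: $$y = \beta_1 \cdot 1 + \beta_2 \cdot (-30) \approx 4.6052 - 0.75 = 3.8552$$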
Since we know \(x_1\) (1), \(x_2\) (-TE), and \(y\) (the log of the measurement), this can be solved with a standard least-squares solver. Let's try it on some simulated data:
In [88]:
tes = np.array([10., 15., 20., 25., 30.])
t2 = 40.
S0 = 100.

measurements = S0 * np.exp(-tes/t2)
measurements
Out[88]:
array([ 77.88007831,  68.72892788,  60.65306597,  53.52614285,  47.23665527])
In [89]:
y = np.log(measurements)
x1 = np.ones_like(tes)
x2 = -tes

x = np.concatenate((x1[:, np.newaxis], x2[:, np.newaxis]), 1)
display(y)
display(x)
array([ 4.35517019,  4.23017019,  4.10517019,  3.98017019,  3.85517019])
array([[  1., -10.],
       [  1., -15.],
       [  1., -20.],
       [  1., -25.],
       [  1., -30.]])
In log space, the exponential function is linear:
In [90]:
plt.scatter(x2, y)
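Because the points lie on a straight line, the slope between any two neighbouring echoes is the same: \(-1/T2(*)\), so every 5 ms of TE lowers \(\log(S)\) by \(5/40 = 0.125\). A self-contained check, recomputing the simulated values from above:

```python
import numpy as np

tes = np.array([10., 15., 20., 25., 30.])
t2, s0 = 40., 100.
y = np.log(s0 * np.exp(-tes / t2))

# Slope between consecutive echoes: d log(S) / d TE = -1 / T2*
slopes = np.diff(y) / np.diff(tes)
print(slopes)  # every entry is -1/40 = -0.025, up to floating-point rounding
```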
And we can solve it...
In [91]:
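# Import the linalg submodule explicitly; plain `import scipy` does not load it in every SciPy version
import scipy.linalg
beta, _, _, _ = scipy.linalg.lstsq(x, y)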

s0_ = np.exp(beta[0])
t2_ = 1/beta[1]

s0_, t2_
Out[91]:
(99.999999999999602, 40.000000000000128)
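With noise-free data, the log-linear fit recovers the parameters exactly. Real measurements are noisy, and the log transform changes how that noise is weighted: additive noise of constant size becomes relatively larger in \(\log(S)\) where the signal is low. A sketch comparing the log-linear fit (written here with `np.polyfit`, a straight line in TE with slope \(-1/T2(*)\) and intercept \(\log(S_0)\)) with a nonlinear fit of the exponential itself using `scipy.optimize.curve_fit`. The noise level (sigma = 2) and the random seed are arbitrary assumptions:

```python
import numpy as np
from scipy.optimize import curve_fit

def decay(te, s0, t2):
    """Mono-exponential decay S(TE) = S0 * exp(-TE / T2*)."""
    return s0 * np.exp(-te / t2)

tes = np.array([10., 15., 20., 25., 30.])
true_s0, true_t2 = 100., 40.

# Assumed noise model: additive Gaussian noise, sigma = 2, fixed seed
rng = np.random.default_rng(0)
noisy = decay(tes, true_s0, true_t2) + rng.normal(0.0, 2.0, size=tes.shape)

# Log-linear fit: log S = log S0 - TE / T2*, a straight line in TE
slope, intercept = np.polyfit(tes, np.log(noisy), 1)
s0_loglin, t2_loglin = np.exp(intercept), -1.0 / slope

# Nonlinear least squares on the exponential itself,
# initialised with the log-linear estimate
(s0_nonlin, t2_nonlin), _ = curve_fit(decay, tes, noisy, p0=[s0_loglin, t2_loglin])

print('log-linear: S0 = %.1f, T2* = %.1f ms' % (s0_loglin, t2_loglin))
print('nonlinear:  S0 = %.1f, T2* = %.1f ms' % (s0_nonlin, t2_nonlin))
```

Starting `curve_fit` from the log-linear estimate is a common way to give the nonlinear solver a sensible initial guess; the log-linear solution remains the cheaper option when fitting many voxels at once.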

Real data

Now for real MRI data we can write a small function. It writes the resulting parameter estimates to NIfTI files and returns the filenames, so it can easily be used in something like Nipype.
In [92]:
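def find_t2(nifti_images, tes):
    # Imports live inside the function so it stays self-contained,
    # which Nipype requires when wrapping it as a Function node
    import os
    import nibabel as nb
    import numpy as np
    import scipy.linalg

    # Stack the echoes along the first axis: shape (n_echoes, x, y, z)
    data = np.array([nb.load(fn).get_fdata() for fn in nifti_images])
    spatial_shape = data.shape[1:]

    # Clip intensities below 1 before taking the log: essentially the same
    # as setting negative log-values to 0, but without log(0) warnings
    data = np.log(np.clip(data, 1, None))

    # Design matrix with columns [1, -TE], one row per echo
    tes = np.asarray(tes, dtype=float)
    x = np.column_stack((np.ones_like(tes), -tes))

    # lstsq only accepts a 1-D or 2-D right-hand side, so flatten the
    # voxels into columns and solve all of them in a single call
    beta, _, _, _ = scipy.linalg.lstsq(x, data.reshape(len(tes), -1))

    s0_ = np.exp(beta[0]).reshape(spatial_shape)
    with np.errstate(divide='ignore'):
        # Voxels without any decay (beta_2 == 0) get an infinite T2*
        t2_ = (1. / beta[1]).reshape(spatial_shape)

    affine = nb.load(nifti_images[0]).affine
    fn_s0 = os.path.abspath('s0.nii.gz')
    fn_t2 = os.path.abspath('t2.nii.gz')

    nb.save(nb.Nifti1Image(s0_, affine), fn_s0)
    nb.save(nb.Nifti1Image(t2_, affine), fn_t2)

    return fn_s0, fn_t2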
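The heart of `find_t2` is a single least-squares call for all voxels. `scipy.linalg.lstsq` only accepts a 1-D or 2-D right-hand side (one column per problem), so the 4-D stack of echoes has to be flattened to shape (n_echoes, n_voxels) and the estimates reshaped back afterwards. A sketch of just that step on a small synthetic volume, so no NIfTI files are needed (`fit_t2_voxelwise` is a hypothetical helper, not part of any library; the echo times are the ones used below):

```python
import numpy as np
import scipy.linalg

def fit_t2_voxelwise(data, tes):
    """Log-linear T2* fit for every voxel at once.

    data: array of shape (n_echoes, ...) with positive signal intensities
    tes:  echo times, one per echo
    Returns (s0, t2) maps with the spatial shape of data[0].
    """
    tes = np.asarray(tes, dtype=float)
    # Design matrix: one row per echo, columns [1, -TE]
    x = np.column_stack((np.ones_like(tes), -tes))
    # Flatten voxels into columns: shape (n_echoes, n_voxels)
    y = np.log(data.reshape(len(tes), -1))
    beta, _, _, _ = scipy.linalg.lstsq(x, y)
    s0 = np.exp(beta[0]).reshape(data.shape[1:])
    t2 = (1.0 / beta[1]).reshape(data.shape[1:])
    return s0, t2

# Synthetic noise-free 4x4x2 volume whose T2* varies from 20 to 60 ms
tes = [11.22, 20.39, 29.57]
true_t2 = np.linspace(20, 60, 32).reshape(4, 4, 2)
true_s0 = np.full((4, 4, 2), 150.0)
data = np.array([true_s0 * np.exp(-te / true_t2) for te in tes])

s0, t2 = fit_t2_voxelwise(data, tes)
print(np.allclose(t2, true_t2), np.allclose(s0, true_s0))  # True True
```

Solving all voxels in one call avoids a Python loop over hundreds of thousands of voxels, since the design matrix is the same everywhere and only the right-hand side changes.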
And this actually works:
In [93]:
fn_s0, fn_t2 = find_t2(['/home/gdholla1/data/t2star/flash/BE4T_11.22.nii.gz',
                        '/home/gdholla1/data/t2star/flash/BE4T_20.39.nii.gz',
                        '/home/gdholla1/data/t2star/flash/BE4T_29.57.nii.gz'],
                       [11.22, 20.39, 29.57])
In [102]:
import nibabel as nb
import numpy as np
import matplotlib.pyplot as plt

# Load the T2* map
t2 = nb.load(fn_t2).get_fdata()

# mask out the very noisy stuff
t2 = np.ma.masked_outside(t2, 0, 60)

# Nice figure size
plt.figure(figsize=(20, 7))

# Show T2* values for three different slices
for i, z in enumerate([40, 64, 104]):
    plt.subplot(1, 3, i+1)
    plt.title('Z = %s' % z, fontsize=25)
    plt.imshow(t2[:, :, z].T, origin='lower', cmap=plt.cm.hot, vmin=0, vmax=60)
    # Turn off any axis grid (e.g. from a seaborn style) so it does not overlay the image
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])

cbar = plt.colorbar()
cbar.ax.tick_params(labelsize=20) 

plt.tight_layout()