Image Reconstruction in Optical Tomography


Toast python module tutorial: A simple parameter reconstruction example

This example shows how to write a simple DOT inverse solver for reconstructing optical parameters from boundary data in Python using the Toast module. To run this code yourself, you need Python with the numpy, scipy and matplotlib modules, as well as the Toast Python module.

To save you typing, the full python script for this example can be downloaded here.

Step 0: Preliminaries

We need to import a few standard modules:

import os

import numpy as np

from numpy import matrix

from scipy import sparse

from scipy.sparse import linalg

from numpy.random import rand

import matplotlib.pyplot as plt

Load the toast module:

exec(open(os.getenv("TOASTDIR") + "/ptoast_install.py").read())  # Python 3 replacement for execfile()

import toast

Note that the ptoast_install script simply sets up the search paths for the local installation of the toast python module. If you installed the toast module in python's default module location, you can skip this line.

Step 1: Defining a few helper functions

Project a photon density distribution to boundary measurements. This function takes an array of column vectors of complex nodal photon densities phi and maps them to boundary measurements, using measurement operators mvec. The measurements are returned as log amplitude and phase data:

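def projection(phi, mvec):
    gamma = mvec.transpose() * phi            # matrix product (mvec is a sparse matrix): one column per source
    gamma = np.reshape(gamma, (-1, 1), 'F')   # stack the source columns into a single column vector
    lgamma = np.log(gamma)                    # complex log: log|gamma| + i*arg(gamma)
    lnamp = lgamma.real
    phase = lgamma.imag
    return np.concatenate((lnamp, phase))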
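The following sketch mirrors projection() with dense NumPy arrays, so it can be run without Toast. The arrays phi and mvec are small made-up examples, and projection_dense is a hypothetical name. Because they are plain ndarrays, the matrix product is written with @ (in the tutorial, mvec presumably comes from Toast as a scipy sparse matrix, for which * already means a matrix product). It shows how the column-major ('F') reshape orders the data: all detectors for the first source, then all detectors for the second source, and so on.

```python
import numpy as np


def projection_dense(phi, mvec):
    """Dense-array version of projection(): map complex nodal fields phi
    (nlen x nq) to boundary data using measurement vectors mvec (nlen x nm).
    Returns [log amplitude; phase] as one column of length 2*nm*nq."""
    gamma = mvec.T @ phi                            # nm x nq complex measurements
    gamma = np.reshape(gamma, (-1, 1), order='F')   # stack source by source
    lgamma = np.log(gamma)                          # log|gamma| + i*arg(gamma)
    return np.concatenate((lgamma.real, lgamma.imag))


# Toy example: 5 nodes, 2 sources, 3 detectors (made-up numbers).
rng = np.random.default_rng(0)
phi = rng.random((5, 2)) + 1j * rng.random((5, 2))
mvec = rng.random((5, 3))

y = projection_dense(phi, mvec)
nqm = 2 * 3                      # number of source-detector pairs
lnamp, phase = y[:nqm], y[nqm:]  # first half: log amplitude, second half: phase
```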

Compute the objective function for an array of boundary data proj computed from a set of trial parameters, given measurement data data and scaling parameters sd:

def objective(proj, data, sd):
    err_data = np.sum(np.power((data-proj)/sd, 2))
    return err_data
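Written out, objective() evaluates the weighted least-squares misfit below, where y is data, F(x) is proj, σ is sd, and n_qm is the number of source-detector pairs, so that the sum runs over both the log amplitude and the phase entries:

```latex
\Psi(\mathbf{x}) = \sum_{i=1}^{2 n_{qm}} \left( \frac{y_i - F_i(\mathbf{x})}{\sigma_i} \right)^2
```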

Compute the objective function for a vector logx of log-transformed parameters. This function is called by the line search in the iterative loop:

def objective_ls(logx):
    x = np.exp(logx)              # undo log transform
    slen = x.shape[0] // 2        # x contains 2 parameter sets (integer division for slicing)
    scmua = x[0:slen]
    sckap = x[slen:2*slen]
    smua = scmua/cm               # absorption coefficient
    skap = sckap/cm               # diffusion coefficient kappa
    smus = 1/(3*skap) - smua      # scattering coefficient, from kappa = 1/(3*(mua+mus))
    mua = basis_inv.Map('S->M', smua)
    mus = basis_inv.Map('S->M', smus)
    phi = mesh_inv.Fields(None, qvec, mua, mus, ref, freq)
    p = projection(phi, mvec)
    return objective(p, data, sd)
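objective_ls() inverts the parameter transform used in Step 3, where the solution vector holds log(cm*mua) and log(cm*kappa) with kappa = 1/(3*(mua+mus)). A minimal NumPy round-trip check of that transform is sketched below. The helper names params_to_logx and logx_to_params are hypothetical, and the basis mappings (basis_inv.Map) are left out, so the parameters are treated as plain vectors.

```python
import numpy as np

C0 = 0.3        # speed of light in vacuum [mm/ps], as in the tutorial
REFIND = 1.4    # refractive index, as in the tutorial
CM = C0 / REFIND


def params_to_logx(mua, mus, cm=CM):
    """Hypothetical helper: pack (mua, mus) into log([cm*mua; cm*kappa])."""
    kap = 1.0 / (3.0 * (mua + mus))
    return np.log(np.concatenate((cm * mua, cm * kap)))


def logx_to_params(logx, cm=CM):
    """Hypothetical helper: the inverse transform performed in objective_ls()."""
    x = np.exp(logx)
    slen = x.shape[0] // 2           # x holds two parameter sets
    smua = x[:slen] / cm
    skap = x[slen:] / cm
    smus = 1.0 / (3.0 * skap) - smua
    return smua, smus


mua = np.array([0.01, 0.025, 0.05])
mus = np.array([1.0, 2.0, 4.5])
logx = params_to_logx(mua, mus)
back_mua, back_mus = logx_to_params(logx)
```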

For assessing the reconstruction results, we evaluate the image error as the relative squared L2 distance between a reconstructed and a target parameter image:

def imerr(im1, im2):
    im1 = np.reshape(im1, (-1, 1))   # flatten to a column vector
    im2 = np.reshape(im2, (-1, 1))
    err = np.sum(np.power(im1-im2, 2))/np.sum(np.power(im2, 2))   # normalised by the reference image im2
    return err
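A quick check of what imerr() returns, using a made-up 2x2 image. The block repeats the corrected imerr() so that it runs on its own. Note that the error is normalised by the second argument, so the target image has to be passed second:

```python
import numpy as np


def imerr(im1, im2):
    """Relative squared L2 error of im1 with respect to the reference im2."""
    im1 = np.reshape(im1, (-1, 1))
    im2 = np.reshape(im2, (-1, 1))
    return np.sum((im1 - im2) ** 2) / np.sum(im2 ** 2)


tgt = np.array([[1.0, 2.0], [3.0, 4.0]])
print(imerr(tgt, tgt))                 # 0.0
print(imerr(2 * tgt, tgt))             # 1.0
print(imerr(np.zeros_like(tgt), tgt))  # 1.0
print(imerr(tgt, 2 * tgt))             # 0.25 (argument order matters)
```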

Step 2: Generating target data

For this example, we generate the target data directly with the forward solver. We use a high-resolution mesh for this, while the following reconstruction will be performed on a lower-resolution mesh to avoid an inverse crime. In addition, the simulated target data can be contaminated with noise. For the target data we use an inhomogeneous parameter distribution. The details of the python forward solver can be found here.

# Set the file paths

meshdir = os.path.expandvars("$TOASTDIR/test/2D/meshes/")

meshfile1 = meshdir + "ellips_tri10.msh" # mesh for target data generation

meshfile2 = meshdir + "circle25_32.msh" # mesh for reconstruction

qmfile = meshdir + "circle25_32x32.qm"

muafile = meshdir + "tgt_mua_ellips_tri10.nim"

musfile = meshdir + "tgt_mus_ellips_tri10.nim"

 

# A few general parameters

c0 = 0.3 # speed of light in vacuum [mm/ps]

refind = 1.4 # refractive index in medium (homogeneous)

cm = c0/refind  # speed of light in medium [mm/ps]

 

# ---------------------------------------------------

# Generate target data

mesh_fwd = toast.Mesh(meshfile1)

mesh_fwd.ReadQM(qmfile)

qvec = mesh_fwd.Qvec(type='Neumann',shape='Gaussian',width=2)

mvec = mesh_fwd.Mvec(shape='Gaussian',width=2)

nlen = mesh_fwd.NodeCount()

nqm = qvec.shape[1] * mvec.shape[1]

ndat = nqm*2

 

# Target parameters

mua = mesh_fwd.ReadNim(muafile)

mus = mesh_fwd.ReadNim(musfile)

ref = np.ones(nlen) * refind  # nodal refractive index (same 1-D shape as in Step 3)

freq = 100  # modulation frequency [MHz]

 

# Target ranges (for display)

mua_min = 0.015 # np.min(mua)

mua_max = 0.055 # np.max(mua)

mus_min = 1 # np.min(mus)

mus_max = 4.5 # np.max(mus)

 

# Solve forward problem

phi = mesh_fwd.Fields(None, qvec, mua, mus, ref, freq)

data = projection(phi, mvec)

lnamp_tgt = data[0:nqm]

phase_tgt = data[nqm:nqm*2]

 

# Map target parameters to images for display

grd = np.array([100, 100])

basis_fwd = toast.Raster(mesh_fwd, grd)

bmua_tgt = np.reshape(basis_fwd.Map('M->B', mua), grd)

bmus_tgt = np.reshape(basis_fwd.Map('M->B', mus), grd)
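The code above uses noise-free data. To contaminate the simulated data with noise, as mentioned at the start of this step, one simple option is sketched below. The noise model (zero-mean Gaussian noise with a standard deviation proportional to the magnitude of each log amplitude and phase value) and the helper name add_relative_noise are assumptions for illustration, not part of Toast.

```python
import numpy as np


def add_relative_noise(data, rel_sigma, seed=None):
    """Hypothetical helper: add zero-mean Gaussian noise whose standard
    deviation is rel_sigma times the magnitude of each data value."""
    rng = np.random.default_rng(seed)
    noise = rng.standard_normal(data.shape) * rel_sigma * np.abs(data)
    return data + noise


# Made-up column vector: two log amplitude values followed by two phase values.
data_clean = np.array([[-5.0], [-6.0], [0.1], [0.2]])
noisy = add_relative_noise(data_clean, 0.01, seed=1)  # 1% relative noise
```

In the script, such a call would go right after data = projection(phi, mvec) and before lnamp_tgt and phase_tgt are extracted.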

Step 3: Setting up the inverse problem

We first load the FEM mesh to be used during the reconstruction, and construct the source and measurement operators:

mesh_inv = toast.Mesh(meshfile2)

mesh_inv.ReadQM(qmfile)

qvec = mesh_inv.Qvec(type='Neumann', shape='Gaussian', width=2)

mvec = mesh_inv.Mvec(shape='Gaussian', width=2)

nlen = mesh_inv.NodeCount()

The initial parameter estimates are set to homogeneous distributions:

mua = np.ones(nlen) * 0.025

mus = np.ones(nlen) * 2

kap = 1/(3*(mua+mus))

ref = np.ones(nlen) * refind

freq = 100

We also create a basis mapper object that maps parameter and field distributions between the mesh basis of the forward solver and the regular pixel basis of the inverse problem:

basis_inv = toast.Raster(mesh_inv, grd)
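As far as we understand the Toast raster conventions, 'M' denotes the nodal mesh basis, 'B' the full regular pixel grid and 'S' the solution basis, which keeps only the grid pixels that overlap the mesh. The NumPy sketch below mimics just the B/S relationship with a boolean mask on a made-up 4x4 grid and a circular domain. The helpers b_to_s and s_to_b are hypothetical and say nothing about how toast.Raster actually implements its mappings, including the interpolation between mesh and grid.

```python
import numpy as np

# Toy 4x4 raster ('B' basis) over a circular domain. Pixels whose centres lie
# outside the circle are treated as having no mesh support (assumption).
n = 4
yy, xx = np.mgrid[0:n, 0:n]
centre = (n - 1) / 2.0
support = (xx - centre) ** 2 + (yy - centre) ** 2 <= (n / 2.0) ** 2


def b_to_s(b_img, mask=support):
    """Keep only the pixels with mesh support (analogue of 'B->S')."""
    return b_img[mask]


def s_to_b(s_vec, mask=support):
    """Scatter solution coefficients back into the full grid, with zeros
    outside the support (analogue of 'S->B')."""
    b_img = np.zeros(mask.shape)
    b_img[mask] = s_vec
    return b_img


s = b_to_s(np.arange(16.0).reshape(n, n))  # the 4 corner pixels are dropped
```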

The forward operator is applied to obtain boundary data from these parameter estimates:

phi = mesh_inv.Fields(None, qvec, mua, mus, ref, freq)

proj = projection(phi, mvec)

lnamp = proj[0:nqm]

phase = proj[nqm:nqm*2]

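We use the differences between the target data and these initial projections to define a data scaling vector. Each data type (log amplitude and phase) is scaled by the norm of its own initial residual, so that both contribute to the objective on a comparable scale: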

sd_lnamp = np.ones(lnamp.shape) * np.linalg.norm(lnamp_tgt-lnamp)

sd_phase = np.ones(phase.shape) * np.linalg.norm(phase_tgt-phase)

sd = np.concatenate((sd_lnamp, sd_phase))
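Because every entry of sd_lnamp (or sd_phase) equals the norm of the corresponding initial residual, each data type contributes exactly 1 to the objective at the starting point, so the initial objective value should come out as 2 up to rounding. The sketch below checks this with made-up residuals standing in for lnamp_tgt-lnamp and phase_tgt-phase:

```python
import numpy as np

rng = np.random.default_rng(42)
nqm = 6
# Made-up residuals with very different magnitudes for the two data types.
res_lnamp = rng.normal(scale=2.0, size=(nqm, 1))
res_phase = rng.normal(scale=0.05, size=(nqm, 1))

# Scaling vector built exactly as in the tutorial.
sd_lnamp = np.ones(res_lnamp.shape) * np.linalg.norm(res_lnamp)
sd_phase = np.ones(res_phase.shape) * np.linalg.norm(res_phase)
sd = np.concatenate((sd_lnamp, sd_phase))
residual = np.concatenate((res_lnamp, res_phase))

err0 = np.sum(np.power(residual / sd, 2))
print(err0)  # 2.0 up to rounding
```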

We now map the parameters into the basis of the inverse solver:

bmua = basis_inv.Map('M->B', mua)

bmus = basis_inv.Map('M->B', mus)

bkap = basis_inv.Map('M->B', kap)

bcmua = bmua * cm

bckap = bkap * cm

scmua = basis_inv.Map('B->S', bcmua)

sckap = basis_inv.Map('B->S', bckap)

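We then construct the solution vector from the scaled absorption and diffusion coefficients and take its logarithm; working with log x keeps the parameters positive during the reconstruction: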

x = np.asmatrix(np.concatenate((scmua, sckap))).transpose()

logx = np.log(x)

Calculate the initial objective value:

err0 = objective(proj, data, sd)

err = err0

errp = 1e10

erri = np.array([err])

errmua = np.array([imerr(bmua, bmua_tgt)])

errmus = np.array([imerr(bmus, bmus_tgt)])

Step 4: The iterative solver loop

We now construct the iterative nonlinear conjugate gradient (NCG) solver to find parameter distributions that minimise the objective function. This has the form:

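itrmax = 100
itr = 1
step = 1.0
resetCG = 10   # CG restart interval used in the loop body (value assumed; not given in the original)
plt.ion()      # interactive mode, so that the figure can be redrawn inside the loop
hfig = plt.figure()

while itr <= itrmax:
    errp = err
    # ... (body of conjugate gradient solver)
    itr += 1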

where itrmax is an iteration limit, itr is the iteration counter, and step is the initial step length for the line search.

The body of the NCG contains the following components:

1. Construction of the Jacobian matrix (later we will discuss methods to avoid the explicit storage of the Jacobian):

dphi = mesh_inv.Fields(None, qvec, mua, mus, ref, freq)

aphi = mesh_inv.Fields(None, mvec, mua, mus, ref, freq)

proj = np.reshape(mvec.transpose() * dphi, (-1, 1), 'F')

J = mesh_inv.Jacobian(basis_inv.Handle(), dphi, aphi, proj)
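Here J holds the derivatives of the 2*nqm boundary data with respect to the 2*slen unknowns, and Toast builds it from the direct fields dphi and the adjoint fields aphi. A common sanity check for any Jacobian is a comparison with finite differences. The sketch below does this for a made-up toy forward model with a known analytic Jacobian; fd_jacobian, toy_forward and toy_jacobian are hypothetical names, and applying this check to the full Toast problem would be costly (two forward solves per parameter).

```python
import numpy as np


def fd_jacobian(forward, x, h=1e-6):
    """Central-difference Jacobian of forward(x) (1-D input, 1-D output)."""
    y0 = forward(x)
    J = np.zeros((y0.size, x.size))
    for j in range(x.size):
        e = np.zeros_like(x)
        e[j] = h
        J[:, j] = (forward(x + e) - forward(x - e)) / (2 * h)
    return J


def toy_forward(x):
    """Made-up forward model with a known analytic Jacobian."""
    return np.array([x[0] * x[1], np.sin(x[0]), x[1] ** 2])


def toy_jacobian(x):
    """Analytic Jacobian of toy_forward: rows are data, columns are parameters."""
    return np.array([[x[1], x[0]],
                     [np.cos(x[0]), 0.0],
                     [0.0, 2 * x[1]]])


x = np.array([0.3, 1.2])
print(np.max(np.abs(fd_jacobian(toy_forward, x) - toy_jacobian(x))))
```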

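The gradient of the objective function is obtained from the Jacobian. The code below computes $r = J^\top \Delta y$ with the scaled data difference $\Delta y = 2(y - F(x))/\sigma^2$ (element-wise), where y is the measured data (data), F(x) the current projection (proj) and σ the scaling vector (sd). Since the gradient of the objective is $-J^\top \Delta y$, r is the negative gradient, i.e. the steepest-descent direction: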

proj = np.concatenate((np.log(proj).real, np.log(proj).imag))   # log amplitude and phase
r = matrix(J).transpose() * (2*(data-proj)/sd**2)                # r = J^T * dy (negative gradient)
r = np.multiply(r, x)   # scale with x to incorporate log transform
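The multiplication by x in the last line is the chain rule for the log transform: if u = log x, the derivative with respect to u is x times the derivative with respect to x, and the same factor applies to the negative gradient r. The sketch below verifies this numerically for a made-up quadratic objective, using central differences in u:

```python
import numpy as np

a = np.array([1.0, 2.0, 0.5])


def f(x):
    """Made-up smooth objective."""
    return np.sum(a * (x - 1.0) ** 2)


def grad_f(x):
    """Its gradient with respect to x."""
    return 2.0 * a * (x - 1.0)


logx = np.log(np.array([0.5, 2.0, 3.0]))
x = np.exp(logx)

# Chain rule: d f(exp(u)) / du = x * df/dx
grad_logx = x * grad_f(x)

# Central differences in the log variable u, one unit vector at a time
h = 1e-6
fd = np.array([(f(np.exp(logx + h * e)) - f(np.exp(logx - h * e))) / (2 * h)
               for e in np.eye(3)])
print(np.allclose(grad_logx, fd, atol=1e-6))  # True
```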

We implement a Polak-Ribière conjugate gradient method to obtain an update direction. The step length is obtained from an inexact one-dimensional line search:

if itr > 1:
    delta_old = delta_new
    delta_mid = np.dot(r.transpose(), s)   # s still holds the previous preconditioned residual

s = r  # replace this with preconditioner

if itr == 1:
    d = s
    delta_new = np.dot(r.transpose(), d)
    delta0 = delta_new
else:
    delta_new = np.dot(r.transpose(), s)
    beta = (delta_new-delta_mid) / delta_old
    if itr % resetCG == 0 or beta <= 0:    # resetCG (restart interval) must be set before the loop
        d = s                              # restart with steepest descent
    else:
        d = s + d*beta

delta_d = np.dot(d.transpose(), d)
step, err = toast.Linesearch(logx, d, step, err, objective_ls)

Direction and step size are used to update the estimate of the parameter vector:

logx = logx + d*step

x = np.exp(logx)
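To see the Polak-Ribière update in isolation, the sketch below applies the same loop structure (delta_old, delta_mid, delta_new, beta, periodic restarts, identity preconditioner s = r) to a small made-up quadratic problem. Since toast.Linesearch is not available outside Toast, a simple Armijo backtracking line search is swapped in; it is an assumption and not Toast's line search. A descent-direction safeguard is also added. The names ncg_pr and backtracking are hypothetical.

```python
import numpy as np


def backtracking(f, x, d, step, fx, slope, c1=1e-4, shrink=0.5, max_tries=50):
    """Armijo backtracking line search (stand-in for toast.Linesearch).
    slope = r.d > 0 is minus the directional derivative of f along d."""
    for _ in range(max_tries):
        f_new = f(x + step * d)
        if f_new <= fx - c1 * step * slope:  # sufficient decrease
            return step, f_new
        step *= shrink
    return 0.0, fx  # no acceptable step found


def ncg_pr(f, neg_grad, x, itrmax=200, reset_cg=10, tol=1e-10):
    """Polak-Ribiere nonlinear CG with restarts; neg_grad(x) = -grad f(x)."""
    err = f(x)
    step = 1.0
    s = d = None
    delta_new = 0.0
    for itr in range(1, itrmax + 1):
        r = neg_grad(x)
        if np.linalg.norm(r) < tol:
            break
        if itr > 1:
            delta_old = delta_new
            delta_mid = r @ s            # s is the previous preconditioned residual
        s = r                            # identity preconditioner
        delta_new = r @ s
        if itr == 1:
            d = s
        else:
            beta = (delta_new - delta_mid) / delta_old
            if itr % reset_cg == 0 or beta <= 0:
                d = s                    # restart with steepest descent
            else:
                d = s + beta * d
        if r @ d <= 0:                   # safeguard: keep d a descent direction
            d = s
        step, err = backtracking(f, x, d, 2.0 * step, err, r @ d)
        if step == 0.0:
            break
        x = x + step * d                 # parameter update
    return x, err


# Made-up quadratic f(x) = 0.5 x^T A x - b^T x; its minimiser solves A x = b.
A = np.array([[4.0, 1.0], [1.0, 3.0]])
b = np.array([1.0, 2.0])
x_opt, f_opt = ncg_pr(lambda x: 0.5 * x @ A @ x - b @ x,
                      lambda x: b - A @ x,
                      np.zeros(2))
print(x_opt)  # close to [1/11, 7/11]
```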

We can now map x back to the nodal absorption and scattering parameter distributions:

slen = x.shape[0] // 2
scmua = x[0:slen]
sckap = x[slen:2*slen]
smua = scmua/cm
skap = sckap/cm
smus = 1/(3*skap) - smua
mua = basis_inv.Map('S->M', smua)
mus = basis_inv.Map('S->M', smus)

We can also map the solutions to images for display and error analysis during the reconstruction:

# Map solution to images

bmua = np.reshape(basis_inv.Map('S->B', smua), grd)

bmus = np.reshape(basis_inv.Map('S->B', smus), grd)

 

# Compute solution errors

erri = np.concatenate((erri, [err]))

errmua = np.concatenate((errmua, [imerr(bmua, bmua_tgt)]))

errmus = np.concatenate((errmus, [imerr(bmus, bmus_tgt)]))

print("Iteration " + str(itr) + ", objective " + str(err))

 

# Plot images and error graphs

plt.clf()

hfig.suptitle("Iteration "+str(itr))

 

ax1 = hfig.add_subplot(231)

im = ax1.imshow(bmua_tgt, vmin=mua_min, vmax=mua_max)

im.axes.get_xaxis().set_visible(False)

im.axes.get_yaxis().set_visible(False)

ax1.set_title("mua target")

plt.colorbar(im)

 

ax2 = hfig.add_subplot(232)

im = ax2.imshow(bmus_tgt, vmin=mus_min, vmax=mus_max)

im.axes.get_xaxis().set_visible(False)

im.axes.get_yaxis().set_visible(False)

ax2.set_title("mus target")

plt.colorbar(im)

 

ax3 = hfig.add_subplot(234)

im = ax3.imshow(bmua, vmin=mua_min, vmax=mua_max)

im.axes.get_xaxis().set_visible(False)

im.axes.get_yaxis().set_visible(False)

ax3.set_title("mua recon")

plt.colorbar(im)

 

ax4 = hfig.add_subplot(235)

im = ax4.imshow(bmus, vmin=mus_min, vmax=mus_max)

im.axes.get_xaxis().set_visible(False)

im.axes.get_yaxis().set_visible(False)

ax4.set_title("mus recon")

plt.colorbar(im)

 

ax5 = hfig.add_subplot(233)

im = ax5.semilogy(erri)

ax5.set_title("objective function")

plt.xlabel("iteration")

 

ax6 = hfig.add_subplot(236)

im = ax6.semilogy(errmua)

im = ax6.semilogy(errmus)

ax6.set_title("rel. image error")

plt.xlabel("iteration")

plt.pause(0.05)