Transforms can be used to center image data or to align two images¶
In [ ]:
import shapeworks as sw
import numpy as np
import pyvista as pv
In [ ]:
# Use PyVista's 'static' Jupyter backend so plots render as non-interactive
# images inside the notebook.
pv.set_jupyter_backend('static')
In [ ]:
# Root directory of the ShapeWorks example data, relative to the notebook's
# working directory. NOTE(review): assumes a ShapeWorks checkout one level up
# from where the notebook runs — adjust for your setup.
DATA = "../ShapeWorks/Examples/Python/Data"
Center of mass transforms¶
Centers the contents of a given image on its center of mass, using the specified resampling (interpolation) method
In [ ]:
# An ellipsoid segmentation from the unprepped example set, used for the
# center-of-mass demo.
filename = f"{DATA}/ellipsoid/Ellipsoids_UnPrepped/seg.ellipsoid_14.nrrd"
In [ ]:
# Load the segmentation and report its geometric center next to its center
# of mass, to show how far off-center the shape starts out.
img = sw.Image(filename)
for label, getter in (("center: ", img.center),
                      ("center of mass: ", img.centerOfMass)):
    print(label, getter())
In [ ]:
# Convert to a VTK image and take three orthogonal slices through the fixed
# point (25, 25, 25). slices[0], [1], [2] are presumably the planes normal to
# x, y and z respectively, matching the 'yz', 'zx' and 'xy' camera views used
# when they are plotted.
slices = sw.sw2vtkImage(img).slice_orthogonal(x=25,y=25,z=25)
In [ ]:
# One row of three views of the uncentered image. Each pair maps an index
# into `slices` to the camera plane that slice is viewed from.
p = pv.Plotter(shape=(1,3), border=False)
for col, (slice_idx, plane) in enumerate(((2, 'xy'), (0, 'yz'), (1, 'zx'))):
    p.subplot(0, col)
    if col == 0:
        p.add_text("before centering", position='lower_left')
    p.add_mesh(slices[slice_idx], style='points', show_scalar_bar=False)
    p.show_grid()
    p.camera_position = plane
p.show()
Create and apply the transform, noting that the center of mass is now roughly at the center of the image¶
In [ ]:
# Build a transform that moves the image's center of mass to the image center.
# It is only created here; it is applied in the following cells.
xform = img.createCenterOfMassTransform()
In [ ]:
# Apply the center-of-mass transform. `img` is resampled in place (it is not
# reassigned), so its center of mass should now lie roughly at the image
# center. Print both values so that claim can actually be checked, in the
# same format as the "before" cell.
img.applyTransform(xform)
print("center: ", img.center())
print("center of mass: ", img.centerOfMass())
visualization of different resampling methods¶
In [ ]:
# linear interpolation
# Reload the original image (the previous cell resampled `img` in place), then
# apply the same center-of-mass transform with linear interpolation and keep
# a VTK copy of the result for plotting.
img = sw.Image(filename)
img.applyTransform(xform, sw.InterpolationType.Linear)
lin = sw.sw2vtkImage(img)
In [ ]:
# closest point resampling
# Reload the original image again and apply the same transform, this time
# with nearest-neighbor interpolation, for comparison with `lin`.
img = sw.Image(filename)
img.applyTransform(xform, sw.InterpolationType.NearestNeighbor)
nn = sw.sw2vtkImage(img)
In [ ]:
# Take the same three orthogonal slices through (25, 25, 25) from both
# resampled results.
lin_slices, nn_slices = (
    vol.slice_orthogonal(x=25, y=25, z=25) for vol in (lin, nn)
)
In [ ]:
# Two rows (linear vs. nearest-neighbor resampling) of the same three views;
# each view pairs a slice index with the camera plane it is seen from.
p = pv.Plotter(shape=(2,3), border=False)
views = ((2, 'xy'), (0, 'yz'), (1, 'zx'))
for row, (label, row_slices) in enumerate((("linear", lin_slices),
                                           ("nearest neighbor", nn_slices))):
    for col, (slice_idx, plane) in enumerate(views):
        p.subplot(row, col)
        if col == 0:
            p.add_text(label)
        p.add_mesh(row_slices[slice_idx], style='points', show_scalar_bar=False)
        p.show_grid()
        p.camera_position = plane
p.show()
In [ ]:
Rigid registration transforms¶
Aligns two images using their distance transforms: the distance transforms are converted to meshes, a rigid alignment (rotation and translation) is computed using iterative closest point (ICP), and one image is then resampled so that it is aligned with the other
NOTE: images used to compute the transform must be distance transforms
In [ ]:
# The two ellipsoid segmentations to be rigidly aligned with each other.
filename1 = f"{DATA}/ellipsoid/Ellipsoids_UnPrepped/seg.ellipsoid_17.nrrd"
filename2 = f"{DATA}/ellipsoid/Ellipsoids_UnPrepped/seg.ellipsoid_19.nrrd"
Load the two segmentations and visualize them before alignment¶
In [ ]:
# Load both segmentations.
img1, img2 = sw.Image(filename1), sw.Image(filename2)
In [ ]:
# Orthogonal slices through (25, 25, 25) of each image before alignment; these
# are kept for the before/after overlays later on.
img1_slices_orig, img2_slices_orig = (
    sw.sw2vtkImage(im).slice_orthogonal(x=25, y=25, z=25) for im in (img1, img2)
)
In [ ]:
# One row per image (before alignment), three orthogonal views each.
p = pv.Plotter(shape=(2,3), border=False)
views = ((2, 'xy'), (0, 'yz'), (1, 'zx'))
for row, (label, row_slices) in enumerate((("Image 1", img1_slices_orig),
                                           ("Image 2", img2_slices_orig))):
    for col, (slice_idx, plane) in enumerate(views):
        p.subplot(row, col)
        if col == 0:
            p.add_text(label)
        p.add_mesh(row_slices[slice_idx], style='points', show_scalar_bar=False)
        p.show_grid()
        p.camera_position = plane
p.show()
In [ ]:
create distance transforms to be used to compute the images' alignment¶
In [ ]:
# Compute a distance transform of each segmentation; the rigid alignment below
# is computed from these rather than from the raw segmentations.
# NOTE(review): 1.0 is presumably the iso-value of the surface — confirm
# against the ShapeWorks API. computeDT also appears to modify img1/img2 in
# place, which is why the originals are reloaded from disk before the
# transforms are applied.
dt1 = img1.computeDT(1.0)
dt2 = img2.computeDT(1.0)
create image alignment transforms¶
In [ ]:
# Rigid registration transforms computed from the distance transforms:
# xform_1_to_2 aligns image 1 to image 2, and xform_2_to_1 the reverse.
xform_1_to_2 = dt1.createRigidRegistrationTransform(dt2)
xform_2_to_1 = dt2.createRigidRegistrationTransform(dt1)
apply the transforms to the original images¶
In [ ]:
# Reload the untouched segmentations from disk: the distance-transform step
# may have overwritten img1/img2 (TODO confirm computeDT works in place).
img1, img2 = sw.Image(filename1), sw.Image(filename2)
In [ ]:
# Resample each original segmentation with its registration transform so it
# lines up with the other image; applyTransform updates img1/img2 directly.
img1.applyTransform(xform_1_to_2)
img2.applyTransform(xform_2_to_1)
visualize the results¶
Notice that the alignment is more than a translation of the centers of mass: it also rotates each image to match the other's orientation
In [ ]:
# Orthogonal slices through (25, 25, 25) of each image after alignment.
img1_slices, img2_slices = (
    sw.sw2vtkImage(im).slice_orthogonal(x=25, y=25, z=25) for im in (img1, img2)
)
In [ ]:
# 3x3 grid of overlays. Each row lists (slices, opacity) layers drawn in order:
# the transformed image over the other original, the reverse, and finally
# both transformed images together.
p = pv.Plotter(shape=(3,3), border=False)
views = ((2, 'xy'), (0, 'yz'), (1, 'zx'))
rows = (
    ("1 -> 2", ((img1_slices, 0.75), (img2_slices_orig, 0.25))),
    ("2 -> 1", ((img1_slices_orig, 0.25), (img2_slices, 0.75))),
    ("both", ((img1_slices, 0.5), (img2_slices, 0.5))),
)
for row, (label, layers) in enumerate(rows):
    for col, (slice_idx, plane) in enumerate(views):
        p.subplot(row, col)
        if col == 0:
            p.add_text(label)
        for layer_slices, alpha in layers:
            p.add_mesh(layer_slices[slice_idx], style='points',
                       show_scalar_bar=False, opacity=alpha)
        p.show_grid()
        p.camera_position = plane
p.show()
In [ ]: