diff --git a/pydeform/sitk_api.py b/pydeform/sitk_api.py index e9839b3..94cd6ed 100644 --- a/pydeform/sitk_api.py +++ b/pydeform/sitk_api.py @@ -1,43 +1,37 @@ +import numpy as np import SimpleITK as sitk import pydeform +import stk from . import interruptible -def _sitk2numpy(image): - R""" Utility to convert a SimpleITK image to numpy. - - .. note:: - A view of the ITK underlying data is returned. - The array will become invalid if the input - ITK object is garbage-collected. - - Parameters - ---------- - image: SimpleITK.Image - Input SimpleITK object. - - Returns - ------- - numpy.ndarray - Array view of the image data. +def _convert_image(sitk_image): + """ Converts a SimpleITK.Image to a pydeform.Volume + + Return: + pydeform.Volume if sitk_image is valid, otherwise None + """ - Tuple[float] - Origin of the image. + if sitk_image is None: + return None - Tuple[float] - Spacing of the image. + return pydeform.Volume( + sitk.GetArrayViewFromImage(sitk_image), + sitk_image.GetOrigin(), + sitk_image.GetSpacing(), + np.array(sitk_image.GetDirection()).reshape((3,3)) + ) - Tuple[float] - Cosine direction matrix of the image, a :math:`3 \times 3` - matrix flattened as a tuple (row-major). +def _convert_transform(transform): + """ Converts a SimpleITK.AffineTransform to a pydeform.AffineTransform """ + + translation = np.array(transform.GetTranslation()) + matrix = np.array(transform.GetMatrix()).reshape((3,3)) - """ - return ( - sitk.GetArrayViewFromImage(image), - image.GetOrigin(), - image.GetSpacing(), - image.GetDirection() - ) + # We need to include the fixed parameter (or center) in the offset + center = np.array(transform.GetCenter()) + offset = translation + center - matrix.dot(center) + return pydeform.AffineTransform(matrix, offset) def register( @@ -49,6 +43,7 @@ def register( fixed_landmarks=None, moving_landmarks=None, initial_displacement=None, + affine_transform=None, constraint_mask=None, constraint_values=None, settings=None, @@ -88,6 +83,9 @@ def register( initial_displacement: SimpleITK.Image Initial guess of the displacement field. + affine_transform: AffineTransform + Optional initial affine transformation + constraint_mask: SimpleITK.Image Boolean mask for the constraints on the displacement. Requires to provide `constraint_values`. @@ -139,61 +137,55 @@ def register( if not isinstance(moving_images, (list, tuple)): moving_images = [moving_images] - fixed_origin = fixed_images[0].GetOrigin() - fixed_spacing = fixed_images[0].GetSpacing() - fixed_direction = fixed_images[0].GetDirection() - moving_origin = moving_images[0].GetOrigin() - moving_spacing = moving_images[0].GetSpacing() - moving_direction = moving_images[0].GetDirection() - # Get numpy view of the input - fixed_images = [sitk.GetArrayViewFromImage(img) for img in fixed_images] - moving_images = [sitk.GetArrayViewFromImage(img) for img in moving_images] + fixed_images = [_convert_image(img) for img in fixed_images] + moving_images = [_convert_image(img) for img in moving_images] + + if None in fixed_images or None in moving_images: + raise RuntimeError('Cannot pass None as fixed or moving image') + + # A bit of magic since we can't pass None for arguments expecting stk.Volume + kwargs = {} if initial_displacement: - initial_displacement = sitk.GetArrayViewFromImage(initial_displacement) + kwargs['initial_displacement'] = _convert_image(initial_displacement) + if affine_transform: + if (not isinstance(affine_transform, sitk.AffineTransform) or + affine_transform.GetDimension() != 3): + raise ValueError( + 'Expected affine transform to be a 3D SimpleITK.AffineTransform' + ) + kwargs['affine_transform'] = _convert_transform(affine_transform) if constraint_mask: - constraint_mask = sitk.GetArrayViewFromImage(constraint_mask) + kwargs['constraint_mask'] = _convert_image(constraint_mask) if constraint_values: - constraint_values = sitk.GetArrayViewFromImage(constraint_values) + kwargs['constraint_values'] = _convert_image(constraint_values) if fixed_mask: - fixed_mask = sitk.GetArrayViewFromImage(fixed_mask) + kwargs['fixed_mask'] = _convert_image(fixed_mask) if moving_mask: - moving_mask = sitk.GetArrayViewFromImage(moving_mask) + kwargs['moving_mask'] = _convert_image(moving_mask) register = interruptible.register if subprocess else pydeform.register # Perform registration through the numpy API displacement = register(fixed_images=fixed_images, moving_images=moving_images, - fixed_origin=fixed_origin, - moving_origin=moving_origin, - fixed_spacing=fixed_spacing, - moving_spacing=moving_spacing, - fixed_direction=fixed_direction, - moving_direction=moving_direction, - fixed_mask=fixed_mask, - moving_mask=moving_mask, - fixed_landmarks=fixed_landmarks, - moving_landmarks=moving_landmarks, - initial_displacement=initial_displacement, - constraint_mask=constraint_mask, - constraint_values=constraint_values, settings=settings, log=log, log_level=log_level, silent=silent, num_threads=num_threads, use_gpu=use_gpu, + **kwargs ) # Convert the result to SimpleITK - displacement = sitk.GetImageFromArray(displacement) - displacement.SetOrigin(fixed_origin) - displacement.SetSpacing(fixed_spacing) - displacement.SetDirection(fixed_direction) + out = sitk.GetImageFromArray(np.array(displacement, copy=False), isVector=True) + out.SetOrigin(displacement.origin) + out.SetSpacing(displacement.spacing) + out.SetDirection(displacement.direction.astype(np.float64).flatten()) - return displacement + return out def transform(image, df, interp=sitk.sitkLinear): R""" Resample an image with a given displacement field. @@ -248,11 +240,63 @@ def jacobian(image): SimpleITK.Image Scalar image representing the Jacobian determinant. """ - result = pydeform.jacobian(*_sitk2numpy(image)) - result = sitk.GetImageFromArray(result) + result = pydeform.jacobian(_convert_image(image)) + result = sitk.GetImageFromArray(np.array(result, copy=False)) result.CopyInformation(image) return result +def regularize( + displacement, + precision = 0.5, + pyramid_levels = 6, + constraint_mask = None, + constraint_values = None): + """Regularize a given displacement field. + + Parameters + ---------- + displacement: SimpleITK.Image + Displacement field used to resample the image. + precision: float + Amount of precision. + pyramid_levels: int + Number of levels for the resolution pyramid + constraint_mask: SimpleITK.Image + Mask for constraining displacements in a specific area, i.e., restricting + any changes within the region. + constraint_values: SimpleITK.Image + Vector field specifying the displacements within the constrained regions. + + Returns + ------- + SimpleITK.Image + Scalar volume image containing the resulting displacement field. + """ + + if displacement is None: + raise ValueError('Expected a displacement field') + + displacement = _convert_image(displacement) + + kwargs = {} + if constraint_mask: + kwargs['constraint_mask'] = _convert_image(constraint_mask) + if constraint_values: + kwargs['constraint_values'] = _convert_image(constraint_values) + + displacement = pydeform.regularize( + displacement, + precision, + pyramid_levels, + **kwargs + ) + + # Convert the result to SimpleITK + out = sitk.GetImageFromArray(np.array(displacement, copy=False), isVector=True) + out.SetOrigin(displacement.origin) + out.SetSpacing(displacement.spacing) + out.SetDirection(displacement.direction.astype(np.float64).flatten()) + return out def divergence(image): R""" Compute the divergence of a 3D 3-vector field. @@ -283,8 +327,8 @@ def divergence(image): SimpleITK.Image Scalar image representing the divergence. """ - result = pydeform.divergence(*_sitk2numpy(image)) - result = sitk.GetImageFromArray(result) + result = stk.divergence(_convert_image(image)) + result = sitk.GetImageFromArray(np.array(result, copy=False)) result.CopyInformation(image) return result @@ -324,8 +368,8 @@ def rotor(image): SimpleITK.Image Vector image representing the rotor. """ - result = pydeform.rotor(*_sitk2numpy(image)) - result = sitk.GetImageFromArray(result) + result = stk.rotor(_convert_image(image)) + result = sitk.GetImageFromArray(np.array(result, copy=False), isVector=True) result.CopyInformation(image) return result @@ -346,8 +390,8 @@ def circulation_density(image): SimpleITK.Image Vector image representing the circulation density. """ - result = pydeform.circulation_density(*_sitk2numpy(image)) - result = sitk.GetImageFromArray(result) + result = stk.circulation_density(_convert_image(image)) + result = sitk.GetImageFromArray(np.array(result, copy=False), isVector=True) result.CopyInformation(image) return result diff --git a/test/test_api.py b/test/test_api.py index f5fc7bb..6bff435 100644 --- a/test/test_api.py +++ b/test/test_api.py @@ -179,6 +179,41 @@ def test_regularize(self): np.testing.assert_equal(np.array(out), np.array(constraints)) + def test_affine(self): + # Test affine initialization + + affine_transform = pydeform.AffineTransform( + np.array(( + (2, 0, 0), + (0, 3, 0), + (0, 0, 4) + )), + np.array((10, 10, 10)) + ) + + # Do a registration pass without actual iterations to see if affine transform is + # applied to the resulting displacement field + settings = { + 'max_iteration_count': 0 + } + + fixed = pydeform.Volume(np.zeros((10,10,10), dtype=np.float32)) + moving = pydeform.Volume(np.zeros((10,10,10), dtype=np.float32)) + + df = pydeform.register( + fixed, + moving, + settings=settings, + affine_transform=affine_transform + ) + + df = np.array(df, copy=False) + + # Ax + b -> A(1, 1, 1) + b -> (2, 3, 4) + (10, 10, 10) -> (12, 13, 14) + # u(x) = Ax + b - x + self.assertEqual(df[1,1,1,0], 11) + self.assertEqual(df[1,1,1,1], 12) + self.assertEqual(df[1,1,1,2], 13) if __name__ == '__main__': unittest.main() diff --git a/test/test_sitk_api.py b/test/test_sitk_api.py new file mode 100644 index 0000000..df9b755 --- /dev/null +++ b/test/test_sitk_api.py @@ -0,0 +1,162 @@ +import os +import random +import unittest + +import numpy as np +import SimpleITK as sitk + +from random import uniform +from numpy.random import rand, randint + +import pydeform.sitk_api as pydeform +import _stk as stk + +# Use a known, random seed for each assert when +# testing with random data. +def _set_seed(): + seed = int.from_bytes(os.urandom(4), byteorder="big") + np.random.seed(seed) + random.seed(seed) + return seed + + +def _gauss3(size=(200, 200, 200), mu=(100, 100, 100), sigma=20, gamma=1): + x = np.linspace(0, size[2], size[2]) + y = np.linspace(0, size[1], size[1]) + z = np.linspace(0, size[0], size[0]) + x, y, z = np.meshgrid(x, y, z, indexing='ij', sparse=True) + arr = gamma * np.exp(-((x-mu[2])/sigma)**2 - ((y-mu[1])/sigma)**2 - ((z-mu[0])/sigma)**2) + return sitk.GetImageFromArray(arr.astype(np.float32)) + + +def _show(image, origin=(0, 0, 0), spacing=(1, 1, 1), direction=(1, 0, 0, 0, 1, 0, 0, 0, 1)): + image = sitk.GetImageFromArray(image) + image.SetOrigin(origin) + image.SetSpacing(spacing) + image.SetDirection(direction) + sitk.Show(image) + + +def _jaccard(a, b): + a = sitk.GetArrayFromImage(a) + b = sitk.GetArrayFromImage(b) + return np.sum(np.logical_and(a, b)) / np.sum(np.logical_or(a, b)) + +class Test_SitkAPI(unittest.TestCase): + + def test_register(self): + + with self.assertRaises(RuntimeError): + pydeform.register(None, None) + + fixed = _gauss3((10, 10, 10)) + with self.assertRaises(RuntimeError): + pydeform.register(fixed, None) + + moving = _gauss3((10, 10, 10)) + with self.assertRaises(RuntimeError): + pydeform.register(None, moving) + + fixed = _gauss3((10, 10, 10)) + moving = [fixed, fixed] + with self.assertRaises(ValueError): + pydeform.register(fixed, moving) + + moving = _gauss3((10, 10, 10)) + fixed = [moving, moving] + with self.assertRaises(ValueError): + pydeform.register(fixed, moving) + + fixed = sitk.Cast(_gauss3(size=(40, 50, 60), mu=(20, 25, 30), sigma=8) > 0.3, + sitk.sitkFloat32) + moving = sitk.Cast(_gauss3(size=(40, 50, 60), mu=(30, 20, 25), sigma=8) > 0.3, + sitk.sitkFloat32) + + settings = { + 'regularization_weight': 0.05, + } + + d = pydeform.register(fixed, moving, settings=settings) + + res = pydeform.transform(moving, d, sitk.sitkNearestNeighbor) + + self.assertGreater(_jaccard(res > 0.1, fixed > 0.1), 0.97) + + def test_jacobian(self): + for _ in range(100): + seed = _set_seed() + + # Generate some random image data + pad = 1 + origin = [uniform(-5, 5) for i in range(3)] + spacing = [uniform(0.1, 5) for i in range(3)] + shape_no_pad = [randint(50, 80) for i in range(3)] + d = 5 * (2.0 * rand(*shape_no_pad, 3) - 1.0) + d = np.pad(d, 3 * [(pad, pad)] + [(0, 0)], 'constant') + + # SimpleITK oracle + d_sitk = sitk.GetImageFromArray(d, isVector=True) + d_sitk.SetOrigin(origin) + d_sitk.SetSpacing(spacing) + jacobian_sitk = sitk.DisplacementFieldJacobianDeterminant(d_sitk) + jacobian_sitk = sitk.GetArrayFromImage(jacobian_sitk) + + # Compute Jacobian + jacobian = pydeform.jacobian(d_sitk) + jacobian = sitk.GetArrayFromImage(jacobian) + + np.testing.assert_almost_equal(jacobian, jacobian_sitk, decimal=2, + err_msg='Mismatch between `jacobian` and sitk, seed %d' % seed) + + + def test_regularize(self): + df = sitk.GetImageFromArray(rand(10,10,10,3).astype(np.float32), isVector=True) + full_mask = sitk.GetImageFromArray(np.ones((10,10,10)).astype(np.uint8)) + + out = pydeform.regularize(df) + # Should not be identical + self.assertFalse(np.array_equal(np.array(out) , np.array(df))) + + # Should fully replicate constraint values + constraints = sitk.GetImageFromArray(rand(10,10,10,3).astype(np.float32)) + out = pydeform.regularize(df, constraint_mask=full_mask, constraint_values=constraints) + np.testing.assert_equal(sitk.GetArrayFromImage(out), sitk.GetArrayFromImage(constraints)) + + def test_affine(self): + # Test affine initialization + + affine_transform = sitk.AffineTransform(3) + affine_transform.SetTranslation((10,10,10)) + affine_transform.SetMatrix(( + 2, 0, 0, + 0, 3, 0, + 0, 0, 4 + )) + + # Do a registration pass without actual iterations to see if affine transform is + # applied to the resulting displacement field + settings = { + 'max_iteration_count': 0 + } + + fixed = sitk.GetImageFromArray(np.zeros((10,10,10), dtype=np.float32)) + moving = sitk.GetImageFromArray(np.zeros((10,10,10), dtype=np.float32)) + + df = pydeform.register( + fixed, + moving, + settings=settings, + affine_transform=affine_transform + ) + + df = sitk.GetArrayFromImage(df) + + # Ax + b -> A(1, 1, 1) + b -> (2, 3, 4) + (10, 10, 10) -> (12, 13, 14) + # u(x) = Ax + b - x + self.assertEqual(df[1,1,1,0], 11) + self.assertEqual(df[1,1,1,1], 12) + self.assertEqual(df[1,1,1,2], 13) + +if __name__ == '__main__': + unittest.main() +