# Source code for niworkflows.interfaces.reportlets.registration

# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
#
# Copyright 2021 The NiPreps Developers <nipreps@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We support and encourage derived works from this project, please read
# about our expectations at
#
#     https://www.nipreps.org/community/licensing/
#
"""ReportCapableInterfaces for registration tools."""

import os

from looseversion import LooseVersion
from nipype.interfaces import freesurfer as fs
from nipype.interfaces import fsl
from nipype.interfaces.ants import registration, resampling
from nipype.interfaces.base import (
    File,
    isdefined,
    traits,
)
from nipype.interfaces.mixins import reporting
from nipype.utils.filemanip import fname_presuffix

from ... import NIWORKFLOWS_LOG
from ..fixes import (
    FixHeaderApplyTransforms as ApplyTransforms,
)
from ..fixes import (
    FixHeaderRegistration as Registration,
)
from ..norm import (
    SpatialNormalization,
    _SpatialNormalizationInputSpec,
    _SpatialNormalizationOutputSpec,
)
from . import base as nrb


class _SpatialNormalizationInputSpecRPT(
    nrb._SVGReportCapableInputSpec, _SpatialNormalizationInputSpec
):
    """Merge ``nrb._SVGReportCapableInputSpec`` with ``_SpatialNormalizationInputSpec``.

    No traits are added here; the class only combines its two bases.
    """


class _SpatialNormalizationOutputSpecRPT(
    reporting.ReportCapableOutputSpec, _SpatialNormalizationOutputSpec
):
    """Merge ``reporting.ReportCapableOutputSpec`` with ``_SpatialNormalizationOutputSpec``.

    No traits are added here; the class only combines its two bases.
    """


class SpatialNormalizationRPT(nrb.RegistrationRC, SpatialNormalization):
    """Report-capable variant of ``SpatialNormalization``."""

    input_spec = _SpatialNormalizationInputSpecRPT
    output_spec = _SpatialNormalizationOutputSpecRPT

    def _post_run_hook(self, runtime):
        """Select the images shown in the report, then defer to the parent hook.

        Parameters
        ----------
        runtime
            Runtime object handed over by the interface machinery; it is
            forwarded to ``aggregate_outputs`` and to the parent hook.

        Returns
        -------
        Whatever the parent class' ``_post_run_hook`` returns.
        """
        # We need to dig into the internal ants.Registration interface.
        # The arguments are resolved a single time and reused below, instead
        # of calling ``_get_ants_args()`` once per lookup.
        ants_args = self._get_ants_args()

        fixed_image = ants_args['fixed_image']
        if isinstance(fixed_image, (list, tuple)):
            fixed_image = fixed_image[0]  # get first item if list
        self._fixed_image = fixed_image

        # Only set the mask attribute when a mask was actually provided, so
        # any value already present on the instance is left untouched.
        fixed_image_mask = ants_args.get('fixed_image_mask')
        if fixed_image_mask is not None:
            self._fixed_image_mask = fixed_image_mask

        self._moving_image = self.aggregate_outputs(runtime=runtime).warped_image
        NIWORKFLOWS_LOG.info(
            'Report - setting fixed (%s) and moving (%s) images',
            self._fixed_image,
            self._moving_image,
        )

        return super()._post_run_hook(runtime)
class _ANTSRegistrationInputSpecRPT(
    nrb._SVGReportCapableInputSpec, registration.RegistrationInputSpec
):
    """Merge ``nrb._SVGReportCapableInputSpec`` with ``registration.RegistrationInputSpec``.

    No traits are added here; the class only combines its two bases.
    """


class _ANTSRegistrationOutputSpecRPT(
    reporting.ReportCapableOutputSpec, registration.RegistrationOutputSpec
):
    """Merge ``reporting.ReportCapableOutputSpec`` with ``registration.RegistrationOutputSpec``.

    No traits are added here; the class only combines its two bases.
    """
class ANTSRegistrationRPT(nrb.RegistrationRC, Registration):
    """Report-capable variant of the header-fixing ANTs ``Registration`` interface."""

    input_spec = _ANTSRegistrationInputSpecRPT
    output_spec = _ANTSRegistrationOutputSpecRPT

    def _post_run_hook(self, runtime):
        """Select the report images, log them and defer to the parent hook.

        The fixed image is the first entry of the ``fixed_image`` input; the
        moving image is the ``warped_image`` output of this run.
        """
        first_fixed = self.inputs.fixed_image[0]
        self._fixed_image = first_fixed

        results = self.aggregate_outputs(runtime=runtime)
        self._moving_image = results.warped_image

        NIWORKFLOWS_LOG.info(
            'Report - setting fixed (%s) and moving (%s) images',
            self._fixed_image,
            self._moving_image,
        )
        return super()._post_run_hook(runtime)
class _ANTSApplyTransformsInputSpecRPT(
    nrb._SVGReportCapableInputSpec, resampling.ApplyTransformsInputSpec
):
    """Merge ``nrb._SVGReportCapableInputSpec`` with ``resampling.ApplyTransformsInputSpec``.

    No traits are added here; the class only combines its two bases.
    """


class _ANTSApplyTransformsOutputSpecRPT(
    reporting.ReportCapableOutputSpec, resampling.ApplyTransformsOutputSpec
):
    """Merge ``reporting.ReportCapableOutputSpec`` with ``resampling.ApplyTransformsOutputSpec``.

    No traits are added here; the class only combines its two bases.
    """
class ANTSApplyTransformsRPT(nrb.RegistrationRC, ApplyTransforms):
    """Report-capable variant of the header-fixing ANTs ``ApplyTransforms`` interface."""

    input_spec = _ANTSApplyTransformsInputSpecRPT
    output_spec = _ANTSApplyTransformsOutputSpecRPT

    def _post_run_hook(self, runtime):
        """Select the report images, log them and defer to the parent hook.

        The fixed image is the ``reference_image`` input; the moving image is
        the ``output_image`` output of this run.
        """
        reference = self.inputs.reference_image
        self._fixed_image = reference

        results = self.aggregate_outputs(runtime=runtime)
        self._moving_image = results.output_image

        NIWORKFLOWS_LOG.info(
            'Report - setting fixed (%s) and moving (%s) images',
            self._fixed_image,
            self._moving_image,
        )
        return super()._post_run_hook(runtime)
class _ApplyTOPUPInputSpecRPT(nrb._SVGReportCapableInputSpec, fsl.epi.ApplyTOPUPInputSpec):
    """Inputs of ``fsl.epi.ApplyTOPUPInputSpec`` plus the SVG-report inputs and ``wm_seg``."""

    # Optional segmentation; the interface's post-run hook uses it as the contour.
    wm_seg = File(
        argstr='-wmseg %s',
        desc='reference white matter segmentation mask',
    )


class _ApplyTOPUPOutputSpecRPT(reporting.ReportCapableOutputSpec, fsl.epi.ApplyTOPUPOutputSpec):
    """Merge ``reporting.ReportCapableOutputSpec`` with ``fsl.epi.ApplyTOPUPOutputSpec``.

    No traits are added here; the class only combines its two bases.
    """
class ApplyTOPUPRPT(nrb.RegistrationRC, fsl.ApplyTOPUP):
    """Report-capable variant of ``fsl.ApplyTOPUP``."""

    input_spec = _ApplyTOPUPInputSpecRPT
    output_spec = _ApplyTOPUPOutputSpecRPT

    def _post_run_hook(self, runtime):
        """Select before/after images for the report and defer to the parent hook.

        The corrected output is labelled ``'after'`` and the first input file
        ``'before'``; in both cases only the image at index 0 is used.
        """
        from nilearn.image import index_img

        self._fixed_image_label = 'after'
        self._moving_image_label = 'before'

        corrected = self.aggregate_outputs(runtime=runtime).out_corrected
        self._fixed_image = index_img(corrected, 0)

        first_input = self.inputs.in_files[0]
        self._moving_image = index_img(first_input, 0)

        # The contour is optional: fall back to None when wm_seg is undefined.
        if isdefined(self.inputs.wm_seg):
            self._contour = self.inputs.wm_seg
        else:
            self._contour = None

        NIWORKFLOWS_LOG.info(
            'Report - setting corrected (%s) and warped (%s) images',
            self._fixed_image,
            self._moving_image,
        )
        return super()._post_run_hook(runtime)
class _FUGUEInputSpecRPT(nrb._SVGReportCapableInputSpec, fsl.preprocess.FUGUEInputSpec):
    """Inputs of ``fsl.preprocess.FUGUEInputSpec`` plus the SVG-report inputs and ``wm_seg``."""

    # Optional segmentation; the interface's post-run hook uses it as the contour.
    wm_seg = File(
        argstr='-wmseg %s',
        desc='reference white matter segmentation mask',
    )


class _FUGUEOutputSpecRPT(reporting.ReportCapableOutputSpec, fsl.preprocess.FUGUEOutputSpec):
    """Merge ``reporting.ReportCapableOutputSpec`` with ``fsl.preprocess.FUGUEOutputSpec``.

    No traits are added here; the class only combines its two bases.
    """
class FUGUERPT(nrb.RegistrationRC, fsl.FUGUE):
    """Report-capable variant of ``fsl.FUGUE``."""

    input_spec = _FUGUEInputSpecRPT
    output_spec = _FUGUEOutputSpecRPT

    def _post_run_hook(self, runtime):
        """Select before/after images for the report and defer to the parent hook.

        The ``unwarped_file`` output is labelled ``'after'`` and the ``in_file``
        input ``'before'``.
        """
        self._fixed_image_label = 'after'
        self._moving_image_label = 'before'

        results = self.aggregate_outputs(runtime=runtime)
        self._fixed_image = results.unwarped_file
        self._moving_image = self.inputs.in_file

        # The contour is optional: fall back to None when wm_seg is undefined.
        if isdefined(self.inputs.wm_seg):
            self._contour = self.inputs.wm_seg
        else:
            self._contour = None

        NIWORKFLOWS_LOG.info(
            'Report - setting corrected (%s) and warped (%s) images',
            self._fixed_image,
            self._moving_image,
        )
        return super()._post_run_hook(runtime)
class _FLIRTInputSpecRPT(nrb._SVGReportCapableInputSpec, fsl.preprocess.FLIRTInputSpec):
    """Merge ``nrb._SVGReportCapableInputSpec`` with ``fsl.preprocess.FLIRTInputSpec``.

    No traits are added here; the class only combines its two bases.
    """


class _FLIRTOutputSpecRPT(reporting.ReportCapableOutputSpec, fsl.preprocess.FLIRTOutputSpec):
    """Merge ``reporting.ReportCapableOutputSpec`` with ``fsl.preprocess.FLIRTOutputSpec``.

    No traits are added here; the class only combines its two bases.
    """
class FLIRTRPT(nrb.RegistrationRC, fsl.FLIRT):
    """Report-capable variant of ``fsl.FLIRT``."""

    input_spec = _FLIRTInputSpecRPT
    output_spec = _FLIRTOutputSpecRPT

    def _post_run_hook(self, runtime):
        """Select the report images, log them and defer to the parent hook.

        The fixed image is the ``reference`` input; the moving image is the
        ``out_file`` output of this run.
        """
        self._fixed_image = self.inputs.reference

        results = self.aggregate_outputs(runtime=runtime)
        self._moving_image = results.out_file

        # The contour is optional: fall back to None when wm_seg is undefined.
        if isdefined(self.inputs.wm_seg):
            self._contour = self.inputs.wm_seg
        else:
            self._contour = None

        NIWORKFLOWS_LOG.info(
            'Report - setting fixed (%s) and moving (%s) images',
            self._fixed_image,
            self._moving_image,
        )
        return super()._post_run_hook(runtime)
class _ApplyXFMInputSpecRPT(nrb._SVGReportCapableInputSpec, fsl.preprocess.ApplyXFMInputSpec):
    """Merge ``nrb._SVGReportCapableInputSpec`` with ``fsl.preprocess.ApplyXFMInputSpec``.

    No traits are added here; the class only combines its two bases.
    """
class ApplyXFMRPT(FLIRTRPT, fsl.ApplyXFM):
    """``fsl.ApplyXFM`` combined with ``FLIRTRPT``.

    Only the input specification differs from ``FLIRTRPT``; the output
    specification is shared, and no methods are overridden here.
    """

    input_spec = _ApplyXFMInputSpecRPT
    output_spec = _FLIRTOutputSpecRPT
# Pick the BBRegister input spec at import time: FreeSurfer versions strictly
# between 0.0.0 and 6.0.0 use ``BBRegisterInputSpec``, anything else uses
# ``BBRegisterInputSpec6``.
_BBRegisterInputSpec = (
    fs.preprocess.BBRegisterInputSpec
    if LooseVersion('0.0.0') < fs.Info.looseversion() < LooseVersion('6.0.0')
    else fs.preprocess.BBRegisterInputSpec6
)


class _BBRegisterInputSpecRPT(nrb._SVGReportCapableInputSpec, _BBRegisterInputSpec):
    """Version-appropriate BBRegister inputs plus the SVG-report inputs."""

    # Declared here so the trait carries default=True, usedefault=True
    out_lta_file = traits.Either(
        traits.Bool,
        File,
        default=True,
        usedefault=True,
        argstr='--lta %s',
        min_ver='5.2.0',
        desc='write the transformation matrix in LTA format',
    )


class _BBRegisterOutputSpecRPT(
    reporting.ReportCapableOutputSpec, fs.preprocess.BBRegisterOutputSpec
):
    """Merge ``reporting.ReportCapableOutputSpec`` with ``fs.preprocess.BBRegisterOutputSpec``.

    No traits are added here; the class only combines its two bases.
    """
class BBRegisterRPT(nrb.RegistrationRC, fs.BBRegister):
    """Report-capable variant of ``fs.BBRegister``."""

    input_spec = _BBRegisterInputSpecRPT
    output_spec = _BBRegisterOutputSpecRPT

    def _post_run_hook(self, runtime):
        """Resample the source onto the subject's brainmask and set up the report.

        The fixed image is ``<subjects_dir>/<subject_id>/mri/brainmask.mgz``,
        the moving image is the source file after applying the estimated LTA
        transform, and the contour is ``ribbon.mgz`` from the same folder.
        """
        results = self.aggregate_outputs(runtime=runtime)

        subject_mri = os.path.join(self.inputs.subjects_dir, self.inputs.subject_id, 'mri')
        brainmask = os.path.join(subject_mri, 'brainmask.mgz')

        # Apply transform for simplicity
        resampler = fs.ApplyVolTransform(
            source_file=self.inputs.source_file,
            target_file=brainmask,
            lta_file=results.out_lta_file,
            interp='nearest',
        )
        resampled = resampler.run()

        self._fixed_image = brainmask
        self._moving_image = resampled.outputs.transformed_file
        self._contour = os.path.join(subject_mri, 'ribbon.mgz')

        NIWORKFLOWS_LOG.info(
            'Report - setting fixed (%s) and moving (%s) images',
            self._fixed_image,
            self._moving_image,
        )
        return super()._post_run_hook(runtime)
class _MRICoregInputSpecRPT(nrb._SVGReportCapableInputSpec, fs.registration.MRICoregInputSpec):
    """Merge ``nrb._SVGReportCapableInputSpec`` with ``fs.registration.MRICoregInputSpec``.

    No traits are added here; the class only combines its two bases.
    """


class _MRICoregOutputSpecRPT(
    reporting.ReportCapableOutputSpec, fs.registration.MRICoregOutputSpec
):
    """Merge ``reporting.ReportCapableOutputSpec`` with ``fs.registration.MRICoregOutputSpec``.

    No traits are added here; the class only combines its two bases.
    """
class MRICoregRPT(nrb.RegistrationRC, fs.MRICoreg):
    """Report-capable variant of ``fs.MRICoreg``."""

    input_spec = _MRICoregInputSpecRPT
    output_spec = _MRICoregOutputSpecRPT

    def _post_run_hook(self, runtime):
        """Resample the source onto the target and set up the report.

        The target is the ``reference_file`` input when defined, otherwise
        ``brainmask.mgz`` from the subject's ``mri`` folder. The ``ribbon.mgz``
        contour is only set when ``subject_id`` is defined.
        """
        results = self.aggregate_outputs(runtime=runtime)

        # The subject's mri folder is only known when subject_id is defined.
        subject_mri = (
            os.path.join(self.inputs.subjects_dir, self.inputs.subject_id, 'mri')
            if isdefined(self.inputs.subject_id)
            else None
        )

        reference = (
            self.inputs.reference_file
            if isdefined(self.inputs.reference_file)
            else os.path.join(subject_mri, 'brainmask.mgz')
        )

        # Apply transform for simplicity
        resampler = fs.ApplyVolTransform(
            source_file=self.inputs.source_file,
            target_file=reference,
            lta_file=results.out_lta_file,
            interp='nearest',
        )
        resampled = resampler.run()

        self._fixed_image = reference
        self._moving_image = resampled.outputs.transformed_file
        if subject_mri is not None:
            self._contour = os.path.join(subject_mri, 'ribbon.mgz')

        NIWORKFLOWS_LOG.info(
            'Report - setting fixed (%s) and moving (%s) images',
            self._fixed_image,
            self._moving_image,
        )
        return super()._post_run_hook(runtime)
class _SimpleBeforeAfterInputSpecRPT(nrb._SVGReportCapableInputSpec):
    """Inputs for a before/after report that wraps no inner interface."""

    # The two images being compared; both must exist on disk.
    before = File(exists=True, mandatory=True, desc='file before')
    after = File(exists=True, mandatory=True, desc='file after')

    # Optional segmentation used as the contour when defined.
    wm_seg = File(desc='reference white matter segmentation mask')

    # Labels attached to each image in the report.
    before_label = traits.Str('before', usedefault=True)
    after_label = traits.Str('after', usedefault=True)

    dismiss_affine = traits.Bool(
        False,
        usedefault=True,
        desc='rotate image(s) to cardinal axes',
    )
class SimpleBeforeAfterRPT(nrb.RegistrationRC, nrb.ReportingInterface):
    """Before/after report built directly from two input files."""

    input_spec = _SimpleBeforeAfterInputSpecRPT

    def _post_run_hook(self, runtime):
        """Set up the report from the inputs; there is no inner interface to run.

        The ``after`` input becomes the fixed image and the ``before`` input
        the moving image, each with its user-provided label. The contour is
        the ``wm_seg`` input when defined, ``None`` otherwise.

        Returns
        -------
        Whatever the parent class' ``_post_run_hook`` returns.
        """
        self._fixed_image_label = self.inputs.after_label
        self._moving_image_label = self.inputs.before_label
        self._fixed_image = self.inputs.after
        self._moving_image = self.inputs.before
        self._contour = self.inputs.wm_seg if isdefined(self.inputs.wm_seg) else None
        self._dismiss_affine = self.inputs.dismiss_affine

        # The moving image is "before" and the fixed image is "after": pass
        # them in that order so the values match the wording of the message.
        NIWORKFLOWS_LOG.info(
            'Report - setting before (%s) and after (%s) images',
            self._moving_image,
            self._fixed_image,
        )

        return super()._post_run_hook(runtime)
class _ResampleBeforeAfterInputSpecRPT(_SimpleBeforeAfterInputSpecRPT):
    """Before/after inputs plus the choice of which image is the resampling reference."""

    # Names the image that is kept as is; the other one is resampled onto it.
    base = traits.Enum(
        'before',
        'after',
        usedefault=True,
        mandatory=True,
    )
class ResampleBeforeAfterRPT(SimpleBeforeAfterRPT):
    """Before/after report in which one image is first resampled onto the other."""

    input_spec = _ResampleBeforeAfterInputSpecRPT

    def _post_run_hook(self, runtime):
        """Resample one image onto the other, build the report and clean up.

        When ``base`` is ``'before'`` the ``after`` image is resampled onto
        ``before``; otherwise ``before`` is resampled onto ``after``. The
        resampled image is written to ``runtime.cwd`` with a ``_resampled``
        suffix and removed again once the parent hook has finished.

        Returns
        -------
        The runtime object returned by the parent hook.
        """
        from nilearn import image as nli

        self._fixed_image_label = self.inputs.after_label
        self._moving_image_label = self.inputs.before_label
        self._dismiss_affine = self.inputs.dismiss_affine
        self._fixed_image = self.inputs.after
        self._moving_image = self.inputs.before

        if self.inputs.base == 'before':
            resampled_after = nli.resample_to_img(self._fixed_image, self._moving_image)
            fname = fname_presuffix(self._fixed_image, suffix='_resampled', newpath=runtime.cwd)
            resampled_after.to_filename(fname)
            self._fixed_image = fname
        else:
            resampled_before = nli.resample_to_img(self._moving_image, self._fixed_image)
            fname = fname_presuffix(self._moving_image, suffix='_resampled', newpath=runtime.cwd)
            resampled_before.to_filename(fname)
            self._moving_image = fname

        self._contour = self.inputs.wm_seg if isdefined(self.inputs.wm_seg) else None

        # The moving image is "before" and the fixed image is "after": pass
        # them in that order so the values match the wording of the message.
        NIWORKFLOWS_LOG.info(
            'Report - setting before (%s) and after (%s) images',
            self._moving_image,
            self._fixed_image,
        )

        try:
            # Deliberately skip SimpleBeforeAfterRPT._post_run_hook: it assigns
            # the fixed/moving images straight from the inputs, which would
            # discard the resampled file prepared above. The labels and the
            # dismiss_affine flag it would have set are assigned explicitly
            # at the top of this method instead.
            runtime = super(SimpleBeforeAfterRPT, self)._post_run_hook(runtime)
            NIWORKFLOWS_LOG.info('Successfully created report (%s)', self._out_report)
        finally:
            # Remove the temporary resampled image even if the report failed.
            os.unlink(fname)

        return runtime