# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
#
# Copyright 2021 The NiPreps Developers <nipreps@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We support and encourage derived works from this project, please read
# about our expectations at
#
# https://www.nipreps.org/community/licensing/
#
"""ReportCapableInterfaces for registration tools."""
import os
from looseversion import LooseVersion
from nipype.utils.filemanip import fname_presuffix
from nipype.interfaces.base import (
traits,
isdefined,
File,
)
from nipype.interfaces.mixins import reporting
from nipype.interfaces import freesurfer as fs
from nipype.interfaces import fsl
from nipype.interfaces.ants import registration, resampling
from ... import NIWORKFLOWS_LOG
from . import base as nrb
from ..norm import (
_SpatialNormalizationInputSpec,
_SpatialNormalizationOutputSpec,
SpatialNormalization,
)
from ..fixes import (
FixHeaderApplyTransforms as ApplyTransforms,
FixHeaderRegistration as Registration,
)
class _SpatialNormalizationInputSpecRPT(
    nrb._SVGReportCapableInputSpec, _SpatialNormalizationInputSpec
):
    """Inputs of ``SpatialNormalization`` combined with the SVG-report inputs."""
class _SpatialNormalizationOutputSpecRPT(
    reporting.ReportCapableOutputSpec, _SpatialNormalizationOutputSpec
):
    """Outputs of ``SpatialNormalization`` combined with the report outputs."""
[docs]
class SpatialNormalizationRPT(nrb.RegistrationRC, SpatialNormalization):
    """Report-capable variant of :class:`SpatialNormalization`."""

    input_spec = _SpatialNormalizationInputSpecRPT
    output_spec = _SpatialNormalizationOutputSpecRPT

    def _post_run_hook(self, runtime):
        """Select the images shown in the report, then delegate to the parent hook.

        The fixed image (and, when present, its mask) are taken from the
        arguments prepared for the internal ANTs registration; the moving
        image is the ``warped_image`` output of this interface.

        Parameters
        ----------
        runtime
            Runtime object handed over by the interface machinery; it is
            forwarded unchanged to ``aggregate_outputs`` and the parent hook.

        Returns
        -------
        Whatever the parent ``_post_run_hook`` returns.

        """
        # We need to dig into the internal ants.Registration interface.
        # Fetch the arguments a single time and reuse them: the original
        # code rebuilt them for every individual lookup.
        ants_args = self._get_ants_args()

        fixed_image = ants_args["fixed_image"]
        if isinstance(fixed_image, (list, tuple)):
            fixed_image = fixed_image[0]  # get first item if list
        self._fixed_image = fixed_image

        fixed_image_mask = ants_args.get("fixed_image_mask")
        if fixed_image_mask is not None:
            self._fixed_image_mask = fixed_image_mask

        self._moving_image = self.aggregate_outputs(runtime=runtime).warped_image
        NIWORKFLOWS_LOG.info(
            "Report - setting fixed (%s) and moving (%s) images",
            self._fixed_image,
            self._moving_image,
        )

        return super()._post_run_hook(runtime)
class _ANTSRegistrationInputSpecRPT(
    nrb._SVGReportCapableInputSpec, registration.RegistrationInputSpec
):
    """ANTs ``Registration`` inputs combined with the SVG-report inputs."""
class _ANTSRegistrationOutputSpecRPT(
    reporting.ReportCapableOutputSpec, registration.RegistrationOutputSpec
):
    """ANTs ``Registration`` outputs combined with the report outputs."""
[docs]
class ANTSRegistrationRPT(nrb.RegistrationRC, Registration):
    """Report-capable variant of the header-fixing ANTs ``Registration``."""

    input_spec = _ANTSRegistrationInputSpecRPT
    output_spec = _ANTSRegistrationOutputSpecRPT

    def _post_run_hook(self, runtime):
        """Use the first fixed image and the warped output for the report."""
        fixed_images = self.inputs.fixed_image
        self._fixed_image = fixed_images[0]

        results = self.aggregate_outputs(runtime=runtime)
        self._moving_image = results.warped_image

        NIWORKFLOWS_LOG.info(
            "Report - setting fixed (%s) and moving (%s) images",
            self._fixed_image,
            self._moving_image,
        )
        return super()._post_run_hook(runtime)
class _ANTSApplyTransformsInputSpecRPT(
    nrb._SVGReportCapableInputSpec, resampling.ApplyTransformsInputSpec
):
    """ANTs ``ApplyTransforms`` inputs combined with the SVG-report inputs."""
class _ANTSApplyTransformsOutputSpecRPT(
    reporting.ReportCapableOutputSpec, resampling.ApplyTransformsOutputSpec
):
    """ANTs ``ApplyTransforms`` outputs combined with the report outputs."""
class _ApplyTOPUPInputSpecRPT(
    nrb._SVGReportCapableInputSpec, fsl.epi.ApplyTOPUPInputSpec
):
    """FSL ``ApplyTOPUP`` inputs plus SVG-report inputs and a ``wm_seg`` file."""

    # Read by ApplyTOPUPRPT._post_run_hook to set the report contour.
    wm_seg = File(
        argstr="-wmseg %s",
        desc="reference white matter segmentation mask",
    )
class _ApplyTOPUPOutputSpecRPT(
    reporting.ReportCapableOutputSpec, fsl.epi.ApplyTOPUPOutputSpec
):
    """FSL ``ApplyTOPUP`` outputs combined with the report outputs."""
[docs]
class ApplyTOPUPRPT(nrb.RegistrationRC, fsl.ApplyTOPUP):
    """Report-capable variant of FSL's ``ApplyTOPUP``."""

    input_spec = _ApplyTOPUPInputSpecRPT
    output_spec = _ApplyTOPUPOutputSpecRPT

    def _post_run_hook(self, runtime):
        """Compare volume 0 of the corrected output with volume 0 of the input.

        The corrected image is labelled "after" and the first input file is
        labelled "before"; ``wm_seg`` is used as contour when it is defined.
        """
        from nilearn.image import index_img

        self._fixed_image_label = "after"
        self._moving_image_label = "before"

        corrected = self.aggregate_outputs(runtime=runtime).out_corrected
        self._fixed_image = index_img(corrected, 0)

        first_input = self.inputs.in_files[0]
        self._moving_image = index_img(first_input, 0)

        if isdefined(self.inputs.wm_seg):
            self._contour = self.inputs.wm_seg
        else:
            self._contour = None

        NIWORKFLOWS_LOG.info(
            "Report - setting corrected (%s) and warped (%s) images",
            self._fixed_image,
            self._moving_image,
        )
        return super()._post_run_hook(runtime)
class _FUGUEInputSpecRPT(
    nrb._SVGReportCapableInputSpec, fsl.preprocess.FUGUEInputSpec
):
    """FSL ``FUGUE`` inputs plus SVG-report inputs and a ``wm_seg`` file."""

    # Read by FUGUERPT._post_run_hook to set the report contour.
    wm_seg = File(
        argstr="-wmseg %s",
        desc="reference white matter segmentation mask",
    )
class _FUGUEOutputSpecRPT(
    reporting.ReportCapableOutputSpec, fsl.preprocess.FUGUEOutputSpec
):
    """FSL ``FUGUE`` outputs combined with the report outputs."""
[docs]
class FUGUERPT(nrb.RegistrationRC, fsl.FUGUE):
    """Report-capable variant of FSL's ``FUGUE``."""

    input_spec = _FUGUEInputSpecRPT
    output_spec = _FUGUEOutputSpecRPT

    def _post_run_hook(self, runtime):
        """Compare the unwarped output ("after") with the input file ("before")."""
        self._fixed_image_label = "after"
        self._moving_image_label = "before"

        results = self.aggregate_outputs(runtime=runtime)
        self._fixed_image = results.unwarped_file
        self._moving_image = self.inputs.in_file

        # Only draw a contour when a white-matter segmentation was supplied.
        if isdefined(self.inputs.wm_seg):
            self._contour = self.inputs.wm_seg
        else:
            self._contour = None

        NIWORKFLOWS_LOG.info(
            "Report - setting corrected (%s) and warped (%s) images",
            self._fixed_image,
            self._moving_image,
        )
        return super()._post_run_hook(runtime)
class _FLIRTInputSpecRPT(
    nrb._SVGReportCapableInputSpec, fsl.preprocess.FLIRTInputSpec
):
    """FSL ``FLIRT`` inputs combined with the SVG-report inputs."""
class _FLIRTOutputSpecRPT(
    reporting.ReportCapableOutputSpec, fsl.preprocess.FLIRTOutputSpec
):
    """FSL ``FLIRT`` outputs combined with the report outputs."""
[docs]
class FLIRTRPT(nrb.RegistrationRC, fsl.FLIRT):
    """Report-capable variant of FSL's ``FLIRT``."""

    input_spec = _FLIRTInputSpecRPT
    output_spec = _FLIRTOutputSpecRPT

    def _post_run_hook(self, runtime):
        """Use the reference as fixed image and ``out_file`` as moving image."""
        self._fixed_image = self.inputs.reference

        results = self.aggregate_outputs(runtime=runtime)
        self._moving_image = results.out_file

        # Only draw a contour when a white-matter segmentation was supplied.
        if isdefined(self.inputs.wm_seg):
            self._contour = self.inputs.wm_seg
        else:
            self._contour = None

        NIWORKFLOWS_LOG.info(
            "Report - setting fixed (%s) and moving (%s) images",
            self._fixed_image,
            self._moving_image,
        )
        return super()._post_run_hook(runtime)
class _ApplyXFMInputSpecRPT(
    nrb._SVGReportCapableInputSpec, fsl.preprocess.ApplyXFMInputSpec
):
    """FSL ``ApplyXFM`` inputs combined with the SVG-report inputs."""
[docs]
class ApplyXFMRPT(FLIRTRPT, fsl.ApplyXFM):
    """Report-capable variant of FSL's ``ApplyXFM``, built on ``FLIRTRPT``."""

    # The output spec is shared with FLIRTRPT; only the inputs differ.
    input_spec = _ApplyXFMInputSpecRPT
    output_spec = _FLIRTOutputSpecRPT
# Choose the BBRegister input spec from the reported FreeSurfer version:
# versions strictly between 0.0.0 and 6.0.0 use ``BBRegisterInputSpec``;
# anything else (including exactly 0.0.0) uses ``BBRegisterInputSpec6``.
_BBRegisterInputSpec = (
    fs.preprocess.BBRegisterInputSpec
    if LooseVersion("0.0.0") < fs.Info.looseversion() < LooseVersion("6.0.0")
    else fs.preprocess.BBRegisterInputSpec6
)
class _BBRegisterInputSpecRPT(nrb._SVGReportCapableInputSpec, _BBRegisterInputSpec):
    """``BBRegister`` inputs with SVG-report inputs and ``out_lta_file`` on by default."""

    # Redeclared with default=True, usedefault=True.
    # BBRegisterRPT._post_run_hook reads the resulting ``out_lta_file`` output.
    out_lta_file = traits.Either(
        traits.Bool,
        File,
        argstr="--lta %s",
        desc="write the transformation matrix in LTA format",
        min_ver="5.2.0",
        default=True,
        usedefault=True,
    )
class _BBRegisterOutputSpecRPT(
    reporting.ReportCapableOutputSpec, fs.preprocess.BBRegisterOutputSpec
):
    """FreeSurfer ``BBRegister`` outputs combined with the report outputs."""
[docs]
class BBRegisterRPT(nrb.RegistrationRC, fs.BBRegister):
    """Report-capable variant of FreeSurfer's ``BBRegister``."""

    input_spec = _BBRegisterInputSpecRPT
    output_spec = _BBRegisterOutputSpecRPT

    def _post_run_hook(self, runtime):
        """Resample the source onto the subject's brainmask and report on it.

        The fixed image is ``brainmask.mgz`` from the subject's ``mri``
        folder, the moving image is the source file after applying the
        estimated LTA transform, and ``ribbon.mgz`` is used as contour.
        """
        results = self.aggregate_outputs(runtime=runtime)
        subject_mri = os.path.join(
            self.inputs.subjects_dir, self.inputs.subject_id, "mri"
        )
        brainmask = os.path.join(subject_mri, "brainmask.mgz")

        # Apply transform for simplicity
        resampler = fs.ApplyVolTransform(
            source_file=self.inputs.source_file,
            target_file=brainmask,
            lta_file=results.out_lta_file,
            interp="nearest",
        )
        resampled = resampler.run()

        self._fixed_image = brainmask
        self._moving_image = resampled.outputs.transformed_file
        self._contour = os.path.join(subject_mri, "ribbon.mgz")
        NIWORKFLOWS_LOG.info(
            "Report - setting fixed (%s) and moving (%s) images",
            self._fixed_image,
            self._moving_image,
        )

        return super()._post_run_hook(runtime)
class _MRICoregInputSpecRPT(
    nrb._SVGReportCapableInputSpec, fs.registration.MRICoregInputSpec
):
    """FreeSurfer ``MRICoreg`` inputs combined with the SVG-report inputs."""
class _MRICoregOutputSpecRPT(
    reporting.ReportCapableOutputSpec, fs.registration.MRICoregOutputSpec
):
    """FreeSurfer ``MRICoreg`` outputs combined with the report outputs."""
[docs]
class MRICoregRPT(nrb.RegistrationRC, fs.MRICoreg):
    """Report-capable variant of FreeSurfer's ``MRICoreg``."""

    input_spec = _MRICoregInputSpecRPT
    output_spec = _MRICoregOutputSpecRPT

    def _post_run_hook(self, runtime):
        """Resample the source onto the registration target and report on it.

        The target is ``reference_file`` when defined, otherwise the
        subject's ``brainmask.mgz``. A ``ribbon.mgz`` contour is only set
        when ``subject_id`` is defined.
        """
        results = self.aggregate_outputs(runtime=runtime)

        # The subject's ``mri`` folder is only known when a subject was given.
        subject_mri = (
            os.path.join(self.inputs.subjects_dir, self.inputs.subject_id, "mri")
            if isdefined(self.inputs.subject_id)
            else None
        )

        target = (
            self.inputs.reference_file
            if isdefined(self.inputs.reference_file)
            else os.path.join(subject_mri, "brainmask.mgz")
        )

        # Apply transform for simplicity
        resampler = fs.ApplyVolTransform(
            source_file=self.inputs.source_file,
            target_file=target,
            lta_file=results.out_lta_file,
            interp="nearest",
        )
        resampled = resampler.run()

        self._fixed_image = target
        self._moving_image = resampled.outputs.transformed_file
        if subject_mri is not None:
            self._contour = os.path.join(subject_mri, "ribbon.mgz")
        NIWORKFLOWS_LOG.info(
            "Report - setting fixed (%s) and moving (%s) images",
            self._fixed_image,
            self._moving_image,
        )

        return super()._post_run_hook(runtime)
class _SimpleBeforeAfterInputSpecRPT(nrb._SVGReportCapableInputSpec):
    """Inputs for before/after reports: two images, labels and an optional contour."""

    # The two images being compared (both must exist).
    before = File(exists=True, mandatory=True, desc="file before")
    after = File(exists=True, mandatory=True, desc="file after")

    # Optional segmentation, used as the report contour when defined.
    wm_seg = File(desc="reference white matter segmentation mask")

    # Labels attached to each image in the report.
    before_label = traits.Str("before", usedefault=True)
    after_label = traits.Str("after", usedefault=True)

    dismiss_affine = traits.Bool(
        False,
        usedefault=True,
        desc="rotate image(s) to cardinal axes",
    )
[docs]
class SimpleBeforeAfterRPT(nrb.RegistrationRC, nrb.ReportingInterface):
    """Report-only interface comparing a ``before`` and an ``after`` image."""

    input_spec = _SimpleBeforeAfterInputSpecRPT

    def _post_run_hook(self, runtime):
        """Configure the report from the inputs; there is no inner interface to run.

        The ``after`` input becomes the fixed image and the ``before`` input
        the moving image, each with its configured label. ``wm_seg`` is used
        as contour when defined, and ``dismiss_affine`` is forwarded as is.

        Parameters
        ----------
        runtime
            Runtime object, forwarded unchanged to the parent hook.

        Returns
        -------
        Whatever the parent ``_post_run_hook`` returns.

        """
        self._fixed_image_label = self.inputs.after_label
        self._moving_image_label = self.inputs.before_label
        self._fixed_image = self.inputs.after
        self._moving_image = self.inputs.before
        self._contour = self.inputs.wm_seg if isdefined(self.inputs.wm_seg) else None
        self._dismiss_affine = self.inputs.dismiss_affine

        # The message names "before" first, so the moving (before) image must
        # be the first argument; the original passed them the other way round.
        NIWORKFLOWS_LOG.info(
            "Report - setting before (%s) and after (%s) images",
            self._moving_image,
            self._fixed_image,
        )

        return super()._post_run_hook(runtime)
class _ResampleBeforeAfterInputSpecRPT(_SimpleBeforeAfterInputSpecRPT):
    """Before/after inputs plus ``base``, the image whose grid is kept."""

    # ResampleBeforeAfterRPT resamples the *other* image onto this one.
    base = traits.Enum(
        "before",
        "after",
        usedefault=True,
        mandatory=True,
    )
[docs]
class ResampleBeforeAfterRPT(SimpleBeforeAfterRPT):
    """Before/after report in which one image is resampled onto the other."""

    input_spec = _ResampleBeforeAfterInputSpecRPT

    def _post_run_hook(self, runtime):
        """Resample one image onto the ``base`` image and generate the report.

        When ``base`` is ``"before"`` the ``after`` image is resampled onto
        the ``before`` image; otherwise ``before`` is resampled onto
        ``after``. The resampled volume is written to ``runtime.cwd`` with a
        ``_resampled`` suffix and removed again once the report hook returns
        (or fails).

        Parameters
        ----------
        runtime
            Runtime object; its ``cwd`` receives the temporary file and it is
            forwarded to the report-generating hook.

        Returns
        -------
        The runtime object returned by the report-generating hook.

        """
        from nilearn import image as nli

        # Settings the parent class would normally apply; they are set here
        # because the parent's hook is deliberately bypassed below.
        self._fixed_image_label = self.inputs.after_label
        self._moving_image_label = self.inputs.before_label
        self._contour = self.inputs.wm_seg if isdefined(self.inputs.wm_seg) else None
        self._dismiss_affine = self.inputs.dismiss_affine

        self._fixed_image = self.inputs.after
        self._moving_image = self.inputs.before
        if self.inputs.base == "before":
            resampled_after = nli.resample_to_img(self._fixed_image, self._moving_image)
            fname = fname_presuffix(
                self._fixed_image, suffix="_resampled", newpath=runtime.cwd
            )
            resampled_after.to_filename(fname)
            self._fixed_image = fname
        else:
            resampled_before = nli.resample_to_img(
                self._moving_image, self._fixed_image
            )
            fname = fname_presuffix(
                self._moving_image, suffix="_resampled", newpath=runtime.cwd
            )
            resampled_before.to_filename(fname)
            self._moving_image = fname

        # "before" is named first in the message, so pass the moving image first.
        NIWORKFLOWS_LOG.info(
            "Report - setting before (%s) and after (%s) images",
            self._moving_image,
            self._fixed_image,
        )

        try:
            # Skip SimpleBeforeAfterRPT._post_run_hook: it resets the fixed and
            # moving images to the raw inputs, which would silently discard
            # the resampled file written above.
            runtime = super(SimpleBeforeAfterRPT, self)._post_run_hook(runtime)
            NIWORKFLOWS_LOG.info("Successfully created report (%s)", self._out_report)
        finally:
            # Always remove the temporary resampled file, even on failure.
            os.unlink(fname)

        return runtime