# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
#
# Copyright 2021 The NiPreps Developers <nipreps@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We support and encourage derived works from this project, please read
# about our expectations at
#
# https://www.nipreps.org/community/licensing/
#
"""ITK files handling."""
import os
from mimetypes import guess_type
from tempfile import TemporaryDirectory
import nibabel as nb
import nitransforms as nt
import numpy as np
from nipype import logging
from nipype.interfaces.base import (
BaseInterfaceInputSpec,
File,
InputMultiObject,
OutputMultiObject,
SimpleInterface,
TraitedSpec,
isdefined,
traits,
)
from nipype.utils.filemanip import fname_presuffix
from .fixes import _FixTraitApplyTransformsInputSpec
LOGGER = logging.getLogger("nipype.interface")
class _MCFLIRT2ITKInputSpec(BaseInterfaceInputSpec):
    """Input specification of the MCFLIRT-to-ITK conversion interface."""

    # FSL-format affine matrices; each one is converted independently.
    in_files = InputMultiObject(
        File(exists=True),
        mandatory=True,
        desc="list of MAT files from MCFLIRT",
    )
    # The two images below define the reference and moving spaces used to
    # interpret the FSL matrices.
    in_reference = File(
        exists=True,
        mandatory=True,
        desc="input image for spatial reference",
    )
    in_source = File(
        exists=True,
        mandatory=True,
        desc="input image for spatial source",
    )
    # Deprecated: setting it only triggers a warning; ``nohash`` keeps it out
    # of nipype's input hash.
    num_threads = traits.Int(
        nohash=True,
        desc="number of parallel processes",
    )
class _MCFLIRT2ITKOutputSpec(TraitedSpec):
    """Output specification of the MCFLIRT-to-ITK conversion interface."""

    # Path of the text file that holds every converted affine.
    out_file = File(
        desc="the output ITKTransform file",
    )
class MCFLIRT2ITK(SimpleInterface):
    """Convert a list of MAT files from MCFLIRT into an ITK Transform file."""

    input_spec = _MCFLIRT2ITKInputSpec
    output_spec = _MCFLIRT2ITKOutputSpec

    def _run_interface(self, runtime):
        """Convert every FSL matrix and write them all to one ITK text file.

        Each file in ``in_files`` is loaded with ``nitransforms.linear.load``
        (``fmt="fsl"``), using ``in_reference`` as the reference image and
        ``in_source`` as the moving image.  The loaded matrices are stacked
        along a new leading axis and written as an ``ITKLinearTransformArray``
        to ``mat2itk.txt`` inside ``runtime.cwd``.  That path is stored in the
        ``out_file`` output.
        """
        if isdefined(self.inputs.num_threads):
            # ``num_threads`` is not used below: the conversion runs serially.
            LOGGER.warning("Multithreading is deprecated. Remove the num_threads input.")

        source = nb.load(self.inputs.in_source)
        reference = nb.load(self.inputs.in_reference)
        affines = [
            nt.linear.load(mat, fmt="fsl", reference=reference, moving=source)
            for mat in self.inputs.in_files
        ]

        # ``from_ras`` receives the matrices of all loaded affines at once.
        affarray = nt.io.itk.ITKLinearTransformArray.from_ras(
            np.stack([a.matrix for a in affines], axis=0),
        )
        self._results["out_file"] = os.path.join(runtime.cwd, "mat2itk.txt")
        affarray.to_filename(self._results["out_file"])
        return runtime
class _MultiApplyTransformsInputSpec(_FixTraitApplyTransformsInputSpec):
    """Input specification for applying ANTs transforms to a list of volumes.

    ``input_image`` is (re)declared as a multi-object input so that several
    files can be passed at once; the remaining traits come from the parent
    spec or are declared below.
    """

    input_image = InputMultiObject(
        File(exists=True),
        mandatory=True,
        desc=(
            "input time-series as a list of volumes after splitting"
            " through the fourth dimension"
        ),
    )
    # Requested degree of parallelism (defaults to 1); ``nohash`` keeps it
    # out of nipype's input hash.
    num_threads = traits.Int(
        1,
        usedefault=True,
        nohash=True,
        desc="number of parallel processes",
    )
    # Enabled by default.
    save_cmd = traits.Bool(
        True,
        usedefault=True,
        desc="write a log of command lines that were applied",
    )
    # Disabled by default.
    copy_dtype = traits.Bool(
        False,
        usedefault=True,
        desc="copy dtype from inputs to outputs",
    )
class _MultiApplyTransformsOutputSpec(TraitedSpec):
    """Output specification for applying ANTs transforms to a list of volumes."""

    # The previous description ("the output ITKTransform file") was copied
    # from the MCFLIRT2ITK output spec; these outputs are resampled images.
    out_files = OutputMultiObject(
        File(), desc="the output images after applying the transforms"
    )
    log_cmdline = File(desc="a list of command lines used to apply transforms")
def _applytfms(args):
    """
    Apply ANTs' antsApplyTransforms to a single input image.

    All inputs are zipped in one tuple to make it digestible by
    multiprocessing's map.

    Parameters
    ----------
    args : tuple
        ``(in_file, in_xform, ifargs, index, newpath)``:

        * ``in_file``: image to resample.
        * ``in_xform``: value passed as ``transforms`` to ``ApplyTransforms``.
        * ``ifargs``: mapping of extra keyword arguments for ``ApplyTransforms``.
          It may contain a ``copy_dtype`` flag, which is consumed here and not
          forwarded.  It is never modified.
        * ``index``: integer used in the output suffix (``_xform-%05d``).
        * ``newpath``: directory where the output is written.

    Returns
    -------
    tuple
        ``(out_file, cmdline)``: the resampled image path and the command line
        that produced it.
    """
    import nibabel as nb
    from nipype.utils.filemanip import fname_presuffix
    from niworkflows.interfaces.fixes import FixHeaderApplyTransforms as ApplyTransforms

    in_file, in_xform, ifargs, index, newpath = args
    out_file = fname_presuffix(
        in_file, suffix="_xform-%05d" % index, newpath=newpath, use_ext=True
    )

    # Work on a shallow copy: the same ``ifargs`` dict may be shared by every
    # call of a ``map``.  Popping from it in place would make ``copy_dtype``
    # take effect only for the first image processed (and race under threads).
    ifargs = dict(ifargs)
    copy_dtype = ifargs.pop("copy_dtype", False)
    xfm = ApplyTransforms(
        input_image=in_file, transforms=in_xform, output_image=out_file, **ifargs
    )
    xfm.terminal_output = "allatonce"
    xfm.resource_monitor = False
    runtime = xfm.run().runtime

    if copy_dtype:
        nii = nb.load(out_file, mmap=False)
        in_dtype = nb.load(in_file).get_data_dtype()

        # Overwrite only iff dtypes don't match
        if in_dtype != nii.get_data_dtype():
            nii.set_data_dtype(in_dtype)
            nii.to_filename(out_file)

    return (out_file, runtime.cmdline)
def _arrange_xfms(transforms, num_files, tmp_folder):
"""
Convenience method to arrange the list of transforms that should be applied
to each input file
"""
base_xform = ["#Insight Transform File V1.0", "#Transform 0"]
# Initialize the transforms matrix
xfms_T = []
for i, tf_file in enumerate(transforms):
if tf_file == "identity":
xfms_T.append([tf_file] * num_files)
continue
# If it is a deformation field, copy to the tfs_matrix directly
if guess_type(tf_file)[0] != "text/plain":
xfms_T.append([tf_file] * num_files)
continue
with open(tf_file) as tf_fh:
tfdata = tf_fh.read().strip()
# If it is not an ITK transform file, copy to the tfs_matrix directly
if not tfdata.startswith("#Insight Transform File"):
xfms_T.append([tf_file] * num_files)
continue
# Count number of transforms in ITK transform file
nxforms = tfdata.count("#Transform")
# Remove first line
tfdata = tfdata.split("\n")[1:]
# If it is a ITK transform file with only 1 xform, copy to the tfs_matrix directly
if nxforms == 1:
xfms_T.append([tf_file] * num_files)
continue
if nxforms != num_files:
raise RuntimeError(
"Number of transforms (%d) found in the ITK file does not match"
" the number of input image files (%d)." % (nxforms, num_files)
)
# At this point splitting transforms will be necessary, generate a base name
out_base = fname_presuffix(
tf_file, suffix="_pos-%03d_xfm-{:05d}" % i, newpath=tmp_folder.name
).format
# Split combined ITK transforms file
split_xfms = []
for xform_i in range(nxforms):
# Find start token to extract
startidx = tfdata.index("#Transform %d" % xform_i)
next_xform = base_xform + tfdata[startidx + 1:startidx + 4] + [""]
xfm_file = out_base(xform_i)
with open(xfm_file, "w") as out_xfm:
out_xfm.write("\n".join(next_xform))
split_xfms.append(xfm_file)
xfms_T.append(split_xfms)
# Transpose back (only Python 3)
return list(map(list, zip(*xfms_T)))