Source code for ewokstomo.tasks.nxtomo_utils
from __future__ import annotations
import logging
import warnings
import h5py
import numpy as np
import pint
from nxtomo import NXtomo
from nxtomo.nxobject.nxdetector import ImageKey
from pint.errors import UnitStrippedWarning
# Module logger; all diagnostics below go through the `_warn` helper.
_logger = logging.getLogger(__name__)
# Shared pint registry: every quantity built in this module uses its units.
_ureg = pint.get_application_registry()
# NOTE: the warnings filter list is global, so this silences pint's
# UnitStrippedWarning for the whole process once this module is imported,
# not only for code in this file.
warnings.simplefilter("ignore", category=UnitStrippedWarning)
def _warn(msg: str, *args):
    """Log *msg* at WARNING level on the module logger.

    *msg* is a ``%``-style format string; *args* are passed through so the
    logging framework formats them lazily.

    ``stacklevel=2`` makes the log record's ``funcName``/``lineno`` point at
    the code that called ``_warn`` instead of at this one-line wrapper.
    """
    _logger.warning(msg, *args, stacklevel=2)
def _as_array(val, dtype=None):
    """Convert *val* to a NumPy array, optionally cast to *dtype*.

    ``None`` is passed through unchanged. The cast is best effort: if
    ``astype`` raises, a warning is logged and the uncast array is returned.
    """
    if val is None:
        return None
    result = np.asarray(val)
    if dtype is None:
        return result
    try:
        return result.astype(dtype)
    except Exception:
        _warn("Failed to cast %s to %s", val, dtype)
        return result
def _ensure_length(arr, length: int, name: str):
    """Return *arr* as a 1-D array holding exactly *length* elements.

    * ``None`` is returned unchanged.
    * A 0-d (scalar) array becomes *length* copies of its value.
    * Any other array is flattened. If its element count differs from
      *length*, a warning naming *name* is logged and :func:`numpy.resize`
      repeats or truncates the data to fit.
    """
    if arr is None:
        return None
    if not arr.shape:
        return np.full(length, arr.item())
    flat = arr.reshape(-1)
    if flat.size == length:
        return flat
    _warn("Size mismatch for %s (%s vs %s), resizing", name, flat.size, length)
    return np.resize(flat, length)
def _as_quantity(val, unit: pint.Unit, length: int | None = None, name: str = ""):
    """Return *val* as a float64 array multiplied by *unit*, or ``None``.

    ``None`` input gives ``None``. When *length* is given, the array is first
    coerced to that many elements with ``_ensure_length``, which uses *name*
    (or ``"quantity"`` when *name* is empty) in its mismatch warning.
    """
    values = _as_array(val, dtype=np.float64)
    if values is None:
        return None
    if length is None:
        return values * unit
    label = name if name else "quantity"
    return _ensure_length(values, length, label) * unit
def _decode_data_axis(value) -> str:
    """Normalise one detector data-axis label to a whitespace-stripped ``str``.

    ``bytes``/``np.bytes_`` are decoded (default codec, UTF-8) and ``np.str_``
    is converted to a plain ``str``.

    Raises:
        ValueError: if *value* is not string-like.
    """
    if isinstance(value, (bytes, np.bytes_)):
        text = value.decode()
    elif isinstance(value, np.str_):
        text = str(value)
    else:
        text = value
    if isinstance(text, str):
        return text.strip()
    raise ValueError(f"Unsupported detector data axis value: {value!r}")
def _resolve_detector_flips(detector_data_axes=None) -> tuple[bool, bool]:
    """Translate detector data-axis labels into ``(lr_flipped, ud_flipped)``.

    *detector_data_axes* holds two labels ordered ``(ud, lr)``:

    * up-down: ``"-z"`` -> not flipped, ``"z"`` -> flipped;
    * left-right: ``"y"`` -> not flipped, ``"-y"`` -> flipped.

    Labels may be ``str`` or ``bytes`` and are whitespace-stripped. ``None``
    means neither axis is flipped. The returned tuple is ordered
    ``(lr, ud)``, the reverse of the input order.

    Raises:
        ValueError: if there are not exactly two labels, a label is not
            string-like, or a label is not one of the accepted values.
    """
    if detector_data_axes is None:
        return False, False

    labels = np.asarray(detector_data_axes, dtype=object).reshape(-1)
    if labels.size != 2:
        raise ValueError("detector_data_axes must contain exactly 2 values (ud, lr)")
    ud_label, lr_label = [_decode_data_axis(label) for label in labels.tolist()]

    ud_table = {"-z": False, "z": True}
    try:
        ud_flipped = ud_table[ud_label]
    except KeyError as err:
        raise ValueError(
            f"Unsupported detector up-down axis: {ud_label!r}"
        ) from err

    lr_table = {"y": False, "-y": True}
    try:
        lr_flipped = lr_table[lr_label]
    except KeyError as err:
        raise ValueError(
            f"Unsupported detector left-right axis: {lr_label!r}"
        ) from err

    return lr_flipped, ud_flipped
[docs]
def build_nxtomo_from_inputs(
    *,
    energy_kev=None,
    title=None,
    start_time=None,
    end_time=None,
    group_size=None,
    estimated_cor=None,
    detector_data_axes=None,
    detector_x_pixel_size_um=None,
    detector_y_pixel_size_um=None,
    sample_x_pixel_size_um=None,
    sample_y_pixel_size_um=None,
    sample_detector_distance_mm=None,
    source_sample_distance_mm=None,
    field_of_view=None,
    tomo_n=None,
    instrument_name=None,
    sample_name=None,
    propagation_distance_mm=None,
    detector_data_file_paths=None,
    detector_data_h5_url=None,
    detector_data_shapes=None,
    detector_data_dtype=None,
    image_key_control=None,
    count_time_s=None,
    rotation_angle_deg=None,
    x_translation_mm=None,
    y_translation_mm=None,
    z_translation_mm=None,
    current_a=None,
    sequence_number=None,
) -> NXtomo:
    """Build an in-memory :class:`NXtomo` from flat acquisition metadata.

    Every argument is keyword-only. Some are required at runtime:
    ``image_key_control``, ``sample_name`` and ``rotation_angle_deg``, plus the
    four detector-data arguments (``detector_data_file_paths``,
    ``detector_data_h5_url``, ``detector_data_shapes``, ``detector_data_dtype``).

    Numeric inputs are plain numbers or arrays. The name suffix gives their
    unit: ``_kev`` keV, ``_um`` micrometre, ``_mm`` millimetre, ``_s`` second,
    ``_deg`` degree, ``_a`` ampere.

    Per-frame inputs are coerced to one value per entry of
    ``image_key_control``: ``count_time_s``, ``rotation_angle_deg``,
    ``x/y/z_translation_mm``, ``current_a`` and ``sequence_number``. A length
    mismatch is logged and fixed with :func:`numpy.resize`.

    Detector frames are not read. ``detector.data`` references them through
    one :class:`h5py.VirtualSource` per ``(file, dataset)`` pair.

    Args:
        estimated_cor: best effort; if converting it with ``float`` or
            assigning it fails, a warning is logged and it is skipped.
        detector_data_axes: optional ``(ud, lr)`` axis labels that set the
            detector flip transformations.
        sample_detector_distance_mm: stored as its absolute value.
        source_sample_distance_mm: stored as minus its absolute value.
        tomo_n: when ``None``, set to the number of entries of
            ``image_key_control`` equal to ``ImageKey.PROJECTION``.
        detector_data_file_paths: HDF5 file paths (``str``/``bytes``/path-like),
            parallel to ``detector_data_h5_url``.
        detector_data_h5_url: in-file dataset paths (``str``/``bytes``).
            ``file::path`` style URLs are rejected.
        detector_data_shapes: one ``(frames, ny, nx)`` triple per source.
        detector_data_dtype: a single dtype for all sources, or one per source.
        sequence_number: defaults to ``0 .. n_frames - 1``.

    Returns:
        The populated :class:`NXtomo`. Nothing is written to disk.

    Raises:
        ValueError: a required input is missing, the detector-data inputs are
            inconsistent, or ``detector_data_axes`` is malformed.
    """
    image_keys = _as_array(image_key_control, dtype=np.int64)
    if image_keys is None:
        raise ValueError("image_key_control is required")
    image_keys = image_keys.reshape(-1)
    # Number of frames: every per-frame array is aligned to this length.
    frames = image_keys.size
    if sample_name is None:
        raise ValueError("sample_name is required")
    if rotation_angle_deg is None:
        raise ValueError("rotation_angle_deg is required")

    nx = NXtomo()
    nx.energy = _as_quantity(energy_kev, _ureg.keV)
    nx.title = title
    nx.start_time = start_time
    nx.end_time = end_time
    nx.group_size = group_size
    det = nx.instrument.detector
    sample = nx.sample

    if estimated_cor is not None:
        try:
            det.x_rotation_axis_pixel_position = float(estimated_cor)
        except Exception:
            # Deliberately best effort: a bad COR estimate must not abort the build.
            _warn("Bad COR %s", estimated_cor)

    lr_flipped, ud_flipped = _resolve_detector_flips(detector_data_axes)
    det.set_transformation_from_lr_flipped(lr_flipped)
    det.set_transformation_from_ud_flipped(ud_flipped)

    det.x_pixel_size = _as_quantity(detector_x_pixel_size_um, _ureg.micrometer)
    det.y_pixel_size = _as_quantity(detector_y_pixel_size_um, _ureg.micrometer)
    sample.x_pixel_size = _as_quantity(sample_x_pixel_size_um, _ureg.micrometer)
    sample.y_pixel_size = _as_quantity(sample_y_pixel_size_um, _ureg.micrometer)

    # Sign convention: the sample-detector distance is stored positive, the
    # source-sample distance negative, whatever sign the caller passed.
    det.distance = _as_quantity(sample_detector_distance_mm, _ureg.millimeter)
    if det.distance is not None:
        det.distance = abs(det.distance)
    src_dist = _as_quantity(source_sample_distance_mm, _ureg.millimeter)
    if src_dist is not None:
        nx.instrument.source.distance = -abs(src_dist)

    det.field_of_view = field_of_view
    det.tomo_n = tomo_n
    nx.instrument.name = instrument_name
    sample.name = sample_name
    sample.propagation_distance = _as_quantity(
        propagation_distance_mm, _ureg.millimeter
    )

    det.data = _detector_virtual_sources(
        detector_data_file_paths,
        detector_data_h5_url,
        detector_data_shapes,
        detector_data_dtype,
        frames,
    )
    det.image_key_control = image_keys

    det.count_time = _as_quantity(count_time_s, _ureg.second, frames, "count_time")
    if det.count_time is not None:
        det.count_time = det.count_time.to(_ureg.second)

    if sequence_number is None:
        det.sequence_number = np.arange(frames, dtype=np.uint32)
    else:
        # Align to `frames` like every other per-frame array. The input is
        # flattened first, so np.resize keeps the uint32 dtype.
        det.sequence_number = _ensure_length(
            _as_array(sequence_number, dtype=np.uint32).reshape(-1),
            frames,
            "sequence_number",
        )

    if det.tomo_n is None:
        det.tomo_n = int(np.count_nonzero(image_keys == ImageKey.PROJECTION.value))

    sample.rotation_angle = _as_quantity(
        rotation_angle_deg, _ureg.degree, frames, "rotation_angle"
    )
    sample.x_translation = _as_quantity(
        x_translation_mm, _ureg.millimeter, frames, "x_translation"
    )
    sample.y_translation = _as_quantity(
        y_translation_mm, _ureg.millimeter, frames, "y_translation"
    )
    sample.z_translation = _as_quantity(
        z_translation_mm, _ureg.millimeter, frames, "z_translation"
    )

    nx.control.data = _as_quantity(current_a, _ureg.ampere, frames, "current")
    if nx.control.data is not None:
        nx.control.data = nx.control.data.to(_ureg.ampere)
    return nx


def _decode_text(value) -> str:
    """Return *value* as ``str``, decoding ``bytes``/``np.bytes_`` (UTF-8).

    A plain ``str(b"x")`` would give ``"b'x'"``, which is never a usable file
    or dataset path. This mirrors the bytes handling of ``_decode_data_axis``.
    """
    if isinstance(value, (bytes, np.bytes_)):
        return value.decode()
    return str(value)


def _detector_shapes(detector_data_shapes, n_sources: int) -> list[tuple[int, int, int]]:
    """Validate ``detector_data_shapes``: one positive ``(frames, ny, nx)`` per source."""
    raw_shapes = np.asarray(detector_data_shapes, dtype=object).tolist()
    if len(raw_shapes) != n_sources:
        raise ValueError(
            "detector_data_shapes size mismatch with detector_data_file_paths and detector_data_h5_url"
        )
    shapes: list[tuple[int, int, int]] = []
    for index, shape in enumerate(raw_shapes):
        arr = np.asarray(shape, dtype=np.int64).reshape(-1)
        if arr.size != 3:
            raise ValueError(
                f"detector_data_shapes[{index}] must have exactly 3 integers (frames, ny, nx)"
            )
        if np.any(arr <= 0):
            raise ValueError(
                f"detector_data_shapes[{index}] must contain positive integers, got {tuple(arr.tolist())}"
            )
        shapes.append((int(arr[0]), int(arr[1]), int(arr[2])))
    return shapes


def _detector_dtypes(detector_data_dtype, n_sources: int) -> list[np.dtype]:
    """Expand ``detector_data_dtype`` to one ``np.dtype`` per source.

    A single dtype is broadcast to all sources; otherwise exactly one dtype
    per source is required.
    """
    raw_dtypes = np.asarray(detector_data_dtype, dtype=object).reshape(-1).tolist()
    if len(raw_dtypes) == 1:
        return [np.dtype(raw_dtypes[0])] * n_sources
    if len(raw_dtypes) != n_sources:
        raise ValueError(
            "detector_data_dtype size mismatch with detector_data_file_paths and detector_data_h5_url"
        )
    return [np.dtype(dtype) for dtype in raw_dtypes]


def _detector_virtual_sources(
    detector_data_file_paths,
    detector_data_h5_url,
    detector_data_shapes,
    detector_data_dtype,
    frames: int,
) -> tuple:
    """Validate the detector-data inputs and wrap them as ``h5py.VirtualSource``s.

    Logs a warning when the total number of frames across all sources differs
    from *frames* (the length of ``image_key_control``).

    Raises:
        ValueError: an input is missing, the parallel inputs disagree in
            length, a dataset path is empty or URL-like (contains ``::``), or
            a shape is not three positive integers.
    """
    if detector_data_file_paths is None or detector_data_h5_url is None:
        raise ValueError("detector_data_file_paths and detector_data_h5_url required")
    # Neither value is None here, so _as_array cannot return None.
    file_paths = _as_array(detector_data_file_paths).reshape(-1)
    raw_data_paths = _as_array(detector_data_h5_url).reshape(-1)
    if file_paths.size != raw_data_paths.size:
        raise ValueError(
            "detector_data_file_paths and detector_data_h5_url size mismatch"
        )

    data_paths: list[str] = []
    for index, data_ref in enumerate(raw_data_paths):
        data_path = _decode_text(data_ref)
        if not data_path:
            raise ValueError(
                f"Invalid detector data reference at index {index}: missing dataset path"
            )
        if "::" in data_path:
            raise ValueError(
                "detector_data_h5_url must contain only HDF5 dataset paths. "
                f"Found URL-like value at index {index}: {data_path}"
            )
        data_paths.append(data_path)

    if detector_data_shapes is None:
        raise ValueError("detector_data_shapes is required")
    if detector_data_dtype is None:
        raise ValueError("detector_data_dtype is required")
    shapes = _detector_shapes(detector_data_shapes, file_paths.size)
    dtypes = _detector_dtypes(detector_data_dtype, file_paths.size)

    total_frames = sum(shape[0] for shape in shapes)
    if total_frames != frames:
        _warn(
            "Detector data provides %s frames but image_key_control has %s entries",
            total_frames,
            frames,
        )

    return tuple(
        h5py.VirtualSource(
            _decode_text(file_path),
            data_path,
            shape=shape,
            dtype=dtype,
        )
        for file_path, data_path, shape, dtype in zip(
            file_paths, data_paths, shapes, dtypes
        )
    )
# Public API: only the NXtomo builder is exported; the helpers are private.
__all__ = [
    "build_nxtomo_from_inputs",
]