# Source code for ewokstomo.tasks.nxtomo_utils

from __future__ import annotations

import logging
import warnings

import h5py
import numpy as np
import pint
from nxtomo import NXtomo
from nxtomo.nxobject.nxdetector import ImageKey

from pint.errors import UnitStrippedWarning

# Module-level logger; `_warn` below routes every warning through it.
_logger = logging.getLogger(__name__)
# pint application registry, used to attach units to the numeric inputs.
_ureg = pint.get_application_registry()

# NOTE: executed at import time -- adds an "ignore" entry for pint's
# UnitStrippedWarning to the global warnings filter.
warnings.filterwarnings("ignore", category=UnitStrippedWarning)


def _warn(msg: str, *args):
    """Log ``msg`` at WARNING level on the module logger.

    ``args`` are handed to the logging call unchanged, so %-style
    placeholders in ``msg`` are interpolated by the logging machinery.
    """
    _logger.log(logging.WARNING, msg, *args)


def _as_array(val, dtype=None):
    """Convert ``val`` to a NumPy array, optionally casting it to ``dtype``.

    Returns ``None`` when ``val`` is ``None``. The cast is best-effort: if
    ``astype`` raises, a warning is logged and the array is returned with the
    dtype NumPy inferred for it.
    """
    if val is None:
        return None

    converted = np.asarray(val)
    if dtype is None:
        return converted

    try:
        return converted.astype(dtype)
    except Exception:
        # Deliberately tolerant: keep the un-cast data rather than failing.
        _warn("Failed to cast %s to %s", val, dtype)
        return converted


def _ensure_length(arr, length: int, name: str):
    """Return ``arr`` as a 1-D array holding exactly ``length`` entries.

    * ``None`` is passed through unchanged.
    * A 0-d array is expanded by repeating its single value ``length`` times.
    * Any other array is flattened; when its size differs from ``length`` a
      warning mentioning ``name`` is logged and ``np.resize`` is applied.
    """
    if arr is None:
        return None

    if arr.shape == ():
        # Scalar input: one value shared by every entry.
        return np.full(length, arr.item())

    values = arr.reshape(-1)
    if values.size == length:
        return values

    _warn("Size mismatch for %s (%s vs %s), resizing", name, values.size, length)
    return np.resize(values, length)


def _as_quantity(val, unit: pint.Unit, length: int | None = None, name: str = ""):
    """Turn ``val`` into an array multiplied by the pint ``unit``.

    ``val`` is converted with ``_as_array`` (float64 requested); ``None``
    yields ``None``. When ``length`` is given, the magnitudes are first
    normalised to that many entries via ``_ensure_length``; ``name`` labels
    the quantity in the size-mismatch warning and falls back to
    ``"quantity"`` when empty.
    """
    magnitudes = _as_array(val, dtype=np.float64)
    if magnitudes is None:
        return None
    if length is None:
        return magnitudes * unit

    label = name if name else "quantity"
    return _ensure_length(magnitudes, length, label) * unit


def _decode_data_axis(value) -> str:
    """Normalise one detector data-axis label to a stripped built-in ``str``.

    Accepts ``bytes`` / ``np.bytes_`` (decoded with the default codec) and
    ``str`` / ``np.str_``.

    Raises:
        ValueError: if ``value`` is of any other type.
    """
    text = value.decode() if isinstance(value, (bytes, np.bytes_)) else value
    if isinstance(text, np.str_):
        # Downgrade the NumPy scalar to a plain Python string.
        text = str(text)
    if isinstance(text, str):
        return text.strip()
    raise ValueError(f"Unsupported detector data axis value: {value!r}")


def _resolve_detector_flips(detector_data_axes=None) -> tuple[bool, bool]:
    """Translate the detector data-axis labels into flip flags.

    ``detector_data_axes`` must flatten to exactly two labels, given in
    (up-down, left-right) order. Accepted labels are ``"-z"`` (not flipped)
    or ``"z"`` (flipped) for up-down, and ``"y"`` (not flipped) or ``"-y"``
    (flipped) for left-right. ``None`` means no flip on either axis.

    Returns:
        ``(lr_flipped, ud_flipped)`` -- note that the order is the reverse of
        the input order.

    Raises:
        ValueError: wrong number of labels, or an unsupported label.
    """
    if detector_data_axes is None:
        return False, False

    flat_axes = np.asarray(detector_data_axes, dtype=object).reshape(-1)
    if flat_axes.size != 2:
        raise ValueError("detector_data_axes must contain exactly 2 values (ud, lr)")

    ud_axis, lr_axis = (_decode_data_axis(axis) for axis in flat_axes.tolist())

    ud_table = {"-z": False, "z": True}
    lr_table = {"y": False, "-y": True}

    try:
        ud_flipped = ud_table[ud_axis]
    except KeyError as exc:
        raise ValueError(f"Unsupported detector up-down axis: {ud_axis!r}") from exc

    try:
        lr_flipped = lr_table[lr_axis]
    except KeyError as exc:
        raise ValueError(
            f"Unsupported detector left-right axis: {lr_axis!r}"
        ) from exc

    return lr_flipped, ud_flipped


def build_nxtomo_from_inputs(
    *,
    energy_kev=None,
    title=None,
    start_time=None,
    end_time=None,
    group_size=None,
    estimated_cor=None,
    detector_data_axes=None,
    detector_x_pixel_size_um=None,
    detector_y_pixel_size_um=None,
    sample_x_pixel_size_um=None,
    sample_y_pixel_size_um=None,
    sample_detector_distance_mm=None,
    source_sample_distance_mm=None,
    field_of_view=None,
    tomo_n=None,
    instrument_name=None,
    sample_name=None,
    propagation_distance_mm=None,
    detector_data_file_paths=None,
    detector_data_h5_url=None,
    detector_data_shapes=None,
    detector_data_dtype=None,
    image_key_control=None,
    count_time_s=None,
    rotation_angle_deg=None,
    x_translation_mm=None,
    y_translation_mm=None,
    z_translation_mm=None,
    current_a=None,
    sequence_number=None,
) -> NXtomo:
    """Assemble an ``NXtomo`` object from plain keyword inputs.

    All parameters are keyword-only. Numeric inputs are wrapped into pint
    quantities with the unit named by the parameter suffix: ``_kev`` (keV),
    ``_um`` (micrometer), ``_mm`` (millimeter), ``_s`` (second), ``_deg``
    (degree) and ``_a`` (ampere).

    Required inputs:
        image_key_control: per-frame image keys (cast to int64); its
            flattened size defines the number of frames.
        sample_name: stored as the sample name.
        rotation_angle_deg: per-frame rotation angles.
        detector_data_file_paths, detector_data_h5_url: equally sized
            collections of file paths and HDF5 dataset paths. Dataset paths
            must be non-empty and must not contain ``"::"``.
        detector_data_shapes: one ``(frames, ny, nx)`` triple of positive
            integers per file.
        detector_data_dtype: a single dtype (applied to every file) or one
            dtype per file.

    Optional inputs:
        energy_kev, title, start_time, end_time, group_size, field_of_view,
        tomo_n, instrument_name, propagation_distance_mm and the pixel-size
        parameters are stored as given (with units where applicable).
        estimated_cor: converted with ``float``; on failure a warning is
            logged and the value is skipped.
        detector_data_axes: two labels in (up-down, left-right) order, used
            to set the detector flip transformations.
        sample_detector_distance_mm: stored as an absolute value.
        source_sample_distance_mm: stored as a negated absolute value.
        count_time_s, x/y/z_translation_mm, current_a: per-frame values;
            like ``rotation_angle_deg`` they are normalised to one entry per
            frame (scalars are repeated, mismatched sizes are resized with a
            warning).
        sequence_number: cast to uint32 and flattened; defaults to
            ``0 .. frames - 1``.

    If ``tomo_n`` is still ``None`` on the detector after assignment, it is
    set to the number of frames whose image key equals
    ``ImageKey.PROJECTION``.

    Returns:
        The populated ``NXtomo`` instance. Detector data is a tuple of
        ``h5py.VirtualSource`` objects, one per input file.

    Raises:
        ValueError: a required input is missing, collection sizes disagree,
            a dataset path or shape is invalid, or ``detector_data_axes`` is
            not supported.
    """
    image_keys = _as_array(image_key_control, dtype=np.int64)
    if image_keys is None:
        raise ValueError("image_key_control is required")
    # Flatten once; the flat view is reused for the detector and for tomo_n.
    flat_image_keys = image_keys.reshape(-1)
    frames = flat_image_keys.size

    if sample_name is None:
        raise ValueError("sample_name is required")
    if rotation_angle_deg is None:
        raise ValueError("rotation_angle_deg is required")

    # --- top-level metadata -------------------------------------------------
    nxtomo = NXtomo()
    nxtomo.energy = _as_quantity(energy_kev, _ureg.keV)
    nxtomo.title = title
    nxtomo.start_time = start_time
    nxtomo.end_time = end_time
    nxtomo.group_size = group_size

    detector = nxtomo.instrument.detector
    sample = nxtomo.sample

    if estimated_cor is not None:
        # Best-effort: an unusable centre of rotation is logged, not fatal.
        try:
            detector.x_rotation_axis_pixel_position = float(estimated_cor)
        except Exception:
            _warn("Bad COR %s", estimated_cor)

    # --- detector / sample geometry -----------------------------------------
    lr_flipped, ud_flipped = _resolve_detector_flips(detector_data_axes)
    detector.set_transformation_from_lr_flipped(lr_flipped)
    detector.set_transformation_from_ud_flipped(ud_flipped)

    detector.x_pixel_size = _as_quantity(detector_x_pixel_size_um, _ureg.micrometer)
    detector.y_pixel_size = _as_quantity(detector_y_pixel_size_um, _ureg.micrometer)
    sample.x_pixel_size = _as_quantity(sample_x_pixel_size_um, _ureg.micrometer)
    sample.y_pixel_size = _as_quantity(sample_y_pixel_size_um, _ureg.micrometer)

    detector.distance = _as_quantity(sample_detector_distance_mm, _ureg.millimeter)
    if detector.distance is not None:
        # The detector distance is always stored as a positive value.
        detector.distance = abs(detector.distance)

    source_distance = _as_quantity(source_sample_distance_mm, _ureg.millimeter)
    if source_distance is not None:
        # The source distance is always stored as a negative value.
        nxtomo.instrument.source.distance = -abs(source_distance)

    detector.field_of_view = field_of_view
    detector.tomo_n = tomo_n
    nxtomo.instrument.name = instrument_name
    sample.name = sample_name
    sample.propagation_distance = _as_quantity(
        propagation_distance_mm, _ureg.millimeter
    )

    # --- detector data references --------------------------------------------
    if detector_data_file_paths is None or detector_data_h5_url is None:
        raise ValueError("detector_data_file_paths and detector_data_h5_url required")
    # Both inputs are known to be non-None here, so _as_array cannot return None.
    file_paths = _as_array(detector_data_file_paths).reshape(-1)
    data_paths = _as_array(detector_data_h5_url).reshape(-1)
    if file_paths.size != data_paths.size:
        raise ValueError(
            "detector_data_file_paths and detector_data_h5_url size mismatch"
        )

    dataset_paths: list[str] = []
    for index, data_ref in enumerate(np.asarray(data_paths, dtype=object)):
        data_path = str(data_ref)
        if not data_path:
            raise ValueError(
                f"Invalid detector data reference at index {index}: missing dataset path"
            )
        if "::" in data_path:
            # "::" marks a URL-style reference; only bare dataset paths are accepted.
            raise ValueError(
                "detector_data_h5_url must contain only HDF5 dataset paths. "
                f"Found URL-like value at index {index}: {data_path}"
            )
        dataset_paths.append(data_path)

    if detector_data_shapes is None:
        raise ValueError("detector_data_shapes is required")
    if detector_data_dtype is None:
        raise ValueError("detector_data_dtype is required")

    raw_shapes = np.asarray(detector_data_shapes, dtype=object).tolist()
    if len(raw_shapes) != file_paths.size:
        raise ValueError(
            "detector_data_shapes size mismatch with detector_data_file_paths and detector_data_h5_url"
        )

    shapes: list[tuple[int, int, int]] = []
    for index, shape in enumerate(raw_shapes):
        arr = np.asarray(shape, dtype=np.int64).reshape(-1)
        if arr.size != 3:
            raise ValueError(
                f"detector_data_shapes[{index}] must have exactly 3 integers (frames, ny, nx)"
            )
        if np.any(arr <= 0):
            raise ValueError(
                f"detector_data_shapes[{index}] must contain positive integers, got {tuple(arr.tolist())}"
            )
        shapes.append((int(arr[0]), int(arr[1]), int(arr[2])))

    raw_dtypes = np.asarray(detector_data_dtype, dtype=object).reshape(-1).tolist()
    if len(raw_dtypes) == 1:
        # A single dtype applies to every file.
        dtypes = [np.dtype(raw_dtypes[0])] * file_paths.size
    else:
        if len(raw_dtypes) != file_paths.size:
            raise ValueError(
                "detector_data_dtype size mismatch with detector_data_file_paths and detector_data_h5_url"
            )
        dtypes = [np.dtype(dtype) for dtype in raw_dtypes]

    detector.data = tuple(
        h5py.VirtualSource(
            str(file_path),
            str(data_path),
            shape=shape,
            dtype=dtype,
        )
        for file_path, data_path, shape, dtype in zip(
            file_paths, dataset_paths, shapes, dtypes
        )
    )

    # --- per-frame detector metadata -----------------------------------------
    detector.image_key_control = flat_image_keys
    detector.count_time = _as_quantity(
        count_time_s, _ureg.second, frames, "count_time"
    )
    if detector.count_time is not None:
        detector.count_time = detector.count_time.to(_ureg.second)

    if sequence_number is None:
        detector.sequence_number = np.arange(frames, dtype=np.uint32)
    else:
        detector.sequence_number = _as_array(
            sequence_number, dtype=np.uint32
        ).reshape(-1)

    if detector.tomo_n is None:
        # Fall back to counting the projection frames.
        detector.tomo_n = int(
            np.count_nonzero(flat_image_keys == ImageKey.PROJECTION.value)
        )

    # --- per-frame sample positions and control signal -----------------------
    sample.rotation_angle = _as_quantity(
        rotation_angle_deg, _ureg.degree, frames, "rotation_angle"
    )
    sample.x_translation = _as_quantity(
        x_translation_mm, _ureg.millimeter, frames, "x_translation"
    )
    sample.y_translation = _as_quantity(
        y_translation_mm, _ureg.millimeter, frames, "y_translation"
    )
    sample.z_translation = _as_quantity(
        z_translation_mm, _ureg.millimeter, frames, "z_translation"
    )

    nxtomo.control.data = _as_quantity(current_a, _ureg.ampere, frames, "current")
    if nxtomo.control.data is not None:
        nxtomo.control.data = nxtomo.control.data.to(_ureg.ampere)

    return nxtomo
# Public API of this module: only the NXtomo builder is exported.
__all__ = ["build_nxtomo_from_inputs"]