"""Tests for the ``TomoBasicToNXtomo`` task (source code for ``ewokstomo.tests.test_tomobasictonxtomo``)."""

from pathlib import Path

import h5py
import numpy as np
import pint
import pytest
import warnings

from ewokstomo.tasks.nxtomomill import H5ToNx
from ewokstomo.tasks.nxtomo_utils import (
    build_nxtomo_from_inputs,
    _resolve_detector_flips,
)
from ewokstomo.tasks.tomobasictonxtomo import TomoBasicToNXtomo

from pint.errors import UnitStrippedWarning

# Install a global "ignore" filter for pint's UnitStrippedWarning as soon as
# this module is imported.
warnings.filterwarnings("ignore", category=UnitStrippedWarning)

# Unit registry used by the helpers and tests below to parse unit strings and
# to name the units that values are compared in.
_ureg = pint.get_application_registry()
# Directory named "data" located next to this test module.
DATA_ROOT = Path(__file__).resolve().parent / "data"
# Name of the collection directory under DATA_ROOT / "RAW_DATA" (see _get_raw_h5).
RAW_COLLECTION = "TestEwoksTomo"


def _get_raw_h5(file_name: str) -> str:
    """Return, as a string, the path of the raw HDF5 file of dataset ``file_name``.

    The path has the form
    ``DATA_ROOT/RAW_DATA/RAW_COLLECTION/<file_name>/<file_name>.h5``; the file
    is not checked for existence.
    """
    dataset_dir = DATA_ROOT / "RAW_DATA" / RAW_COLLECTION / file_name
    h5_file = dataset_dir / f"{file_name}.h5"
    return str(h5_file)


def _run_task(task_cls, **inputs):
    """Build ``task_cls`` with the keyword arguments as its ``inputs`` and run it.

    The keyword arguments are handed over as one dictionary through the
    ``inputs=`` parameter of the constructor; whatever ``execute()`` returns is
    discarded.
    """
    task_cls(inputs=inputs).execute()


def _read_quantity(h5: h5py.File, path: str) -> pint.Quantity | np.ndarray | None:
    """Read the dataset at ``path`` and attach the unit stored in its attributes.

    Returns ``None`` when ``path`` is absent from ``h5``. The unit is taken
    from the ``units`` attribute, or from ``unit`` when ``units`` is absent
    (bytes values are decoded). When no unit attribute exists, or the unit
    cannot be applied to the data, the plain array is returned instead of a
    quantity.
    """
    if path not in h5:
        return None
    dataset = h5[path]
    values = np.asarray(dataset[()])

    attributes = dataset.attrs
    # "units" takes precedence over "unit" when both are present.
    unit_key = next((name for name in ("units", "unit") if name in attributes), None)
    unit = attributes[unit_key] if unit_key is not None else None
    if isinstance(unit, bytes):
        unit = unit.decode()
    if unit is None:
        return values

    try:
        quantity = values * _ureg(unit)
    except Exception:
        # Best effort: a unit that cannot be parsed/applied degrades to raw values.
        return values
    return quantity


def _read_raw(h5: h5py.File, path: str):
    """Read the value stored at ``path`` without any unit handling.

    Returns ``None`` when ``path`` is absent. Bytes-like values are decoded to
    ``str``. Arrays of dtype kind ``S`` or ``O`` are flattened and converted
    element-wise to ``str``: a single element yields one string, anything else
    yields a list of strings. Every other value is returned as read.
    """
    if path not in h5:
        return None

    def _as_text(item):
        # Bytes-like items are decoded; everything else goes through str().
        if isinstance(item, (bytes, bytearray)):
            return item.decode()
        return str(item)

    value = h5[path][()]
    if isinstance(value, (bytes, bytearray)):
        return value.decode()

    array = np.asarray(value)
    if array.dtype.kind not in ("S", "O"):
        return value

    items = array.reshape(-1)
    if items.size == 1:
        return _as_text(items[0])
    return [_as_text(item) for item in items]


def _assert_value(
    h5: h5py.File,
    path: str,
    expected,
    unit: pint.Unit | None = None,
    rtol: float = 1e-6,
    atol: float = 1e-8,
):
    """Assert that the dataset at ``path`` equals ``expected``.

    Without ``unit`` the value is read with ``_read_raw``; with ``unit`` it is
    read with ``_read_quantity`` and, if a quantity came back, converted to
    ``unit`` before comparing magnitudes. String-like data (dtype kind ``S``,
    ``U`` or ``O`` on either side) is compared exactly after conversion to
    ``str``, treating a scalar and a one-element array as the same shape.
    Numeric data is compared with ``np.allclose`` using ``rtol`` and ``atol``.

    Raises ``AssertionError`` when the dataset is missing or the values differ.
    """
    reader = _read_raw if unit is None else _read_quantity
    actual = reader(h5, path)
    assert actual is not None, f"Missing {path}"

    if unit is not None and isinstance(actual, pint.Quantity):
        actual = actual.to(unit).magnitude
    got = np.asarray(actual)
    want = np.asarray(expected)

    text_kinds = ("S", "U", "O")
    if got.dtype.kind not in text_kinds and want.dtype.kind not in text_kinds:
        assert np.allclose(got, want, rtol=rtol, atol=atol)
        return

    got_text = got.astype(str)
    want_text = want.astype(str)
    # A scalar on one side and a one-element array on the other count as equal shapes.
    if got_text.shape == () and want_text.shape == (1,):
        want_text = want_text.reshape(())
    elif want_text.shape == () and got_text.shape == (1,):
        got_text = got_text.reshape(())
    assert np.array_equal(got_text, want_text)


def _extract_links(ds: "h5py.Dataset"):
    """Return ``(file_name, dataset_path)`` pairs describing where ``ds`` data lives.

    * Virtual dataset: one pair per virtual source (source file name and
      source dataset name).
    * Dataset with external raw storage: one pair per external file. Such
      storage has no source dataset, so the second item is ``None``.
    * Otherwise: a single pair made of the dataset's own file name and name.
    """
    if getattr(ds, "is_virtual", False):
        return [(src.file_name, src.dset_name) for src in ds.virtual_sources()]
    if ds.external:
        # h5py reports external storage as (name, offset, size) triples, so
        # unpacking four fields raised ValueError, and the second field is a
        # byte offset rather than a dataset path. Only the file name is kept.
        return [(entry[0], None) for entry in ds.external]
    return [(ds.file.filename, ds.name)]


def _norm_path(path: str) -> str:
    """Return ``path`` as the string form of ``Path(path).resolve()``."""
    resolved = Path(path).resolve()
    return str(resolved)


def _reference_tomobasic_inputs(nx_path: str, raw_h5: str) -> dict:
    """Return the reference keyword inputs passed to ``TomoBasicToNXtomo`` in these tests.

    ``nx_path`` is stored under ``"nx_path"`` and ``raw_h5`` is used as the
    source file of each of the four detector datasets. Every per-frame array
    has 20 entries, matching the 2 + 2 + 11 + 5 frames declared in
    ``detector_data_shapes``. A new dictionary with new arrays is built on
    each call, so callers may modify the result freely.
    """
    frames_per_dataset = (2, 2, 11, 5)
    n_frames = sum(frames_per_dataset)  # 20 frames overall

    image_key_control = np.array(
        [2] * 2 + [1] * 2 + [0] * 11 + [-1] * 5,
        dtype=np.int64,
    )
    rotation_angle_deg = np.array(
        [
            -0.0, -0.0, -0.0, -0.0, -0.0,
            -36.000000036, -72.000000072, -108.000000108, -144.000000144,
            -180.00000018, -216.000000216, -252.000000252, -288.000000288,
            -324.000000324, -360.00000036,
            -360.00000036, -270.00000027, -180.00000018, -90.00000009, -0.0,
        ],
        dtype=float,
    )
    current_a = np.array(
        [
            0.19122, 0.19122,
            0.19116999999999998, 0.19116999999999998,
            0.19111000000000003, 0.19108000000000003,
            0.19103, 0.19099000000000002,
            0.19096000000000002, 0.19091999999999998,
            0.19089, 0.19085,
            0.19082, 0.19077000000000002,
            0.19072999999999998, 0.19072999999999998,
            0.19066, 0.19061000000000003,
            0.19056, 0.1905,
        ],
        dtype=float,
    )

    return {
        "nx_path": nx_path,
        "detector_data_file_paths": [raw_h5] * len(frames_per_dataset),
        "detector_data_h5_url": [
            "/2.1/instrument/edgetwinmic/image",
            "/3.1/instrument/edgetwinmic/image",
            "/4.1/instrument/edgetwinmic/image",
            "/5.1/instrument/edgetwinmic/image",
        ],
        "detector_data_shapes": [(count, 16, 16) for count in frames_per_dataset],
        "detector_data_dtype": "uint16",
        "image_key_control": image_key_control,
        "rotation_angle_deg": rotation_angle_deg,
        "sample_name": "TestEwoksTomo",
        "energy_kev": 26.0,
        "title": "tomo:fullturn",
        "start_time": "2025-03-20T17:27:51.826549+01:00",
        "end_time": "2025-03-20T17:29:18.558287+01:00",
        "estimated_cor": None,
        "detector_data_axes": ["-z", "-y"],
        "detector_x_pixel_size_um": 0.6693859615384615,
        "detector_y_pixel_size_um": 0.6693859615384615,
        "sample_x_pixel_size_um": 0.669,
        "sample_y_pixel_size_um": 0.669,
        "sample_detector_distance_mm": 30.0,
        "source_sample_distance_mm": -52000.0,
        "field_of_view": "Full",
        "instrument_name": "ESRF-bm05",
        "propagation_distance_mm": 29.982702287142033,
        "count_time_s": np.full(n_frames, 0.01, dtype=float),
        "x_translation_mm": np.zeros(n_frames, dtype=float),
        "y_translation_mm": np.full(n_frames, -118.0, dtype=float),
        "z_translation_mm": np.zeros(n_frames, dtype=float),
        "current_a": current_a,
        "sequence_number": np.arange(n_frames, dtype=np.uint32),
    }


def _expected_detector_flip_angles(inputs: dict) -> tuple[int, int]:
    """Return the ``(rx, ry)`` angles expected for ``inputs["detector_data_axes"]``.

    The axes (``None`` when the key is absent) are resolved into left-right
    and up-down flip flags by ``_resolve_detector_flips``. An up-down flip
    gives ``rx = 180`` and a left-right flip gives ``ry = 180``; otherwise the
    angle is 0.
    """
    axes = inputs.get("detector_data_axes")
    lr_flip, ud_flip = _resolve_detector_flips(axes)
    rx = 180 if ud_flip else 0
    ry = 180 if lr_flip else 0
    return (rx, ry)


def _detector_data_axes_from_nx(nx_path: Path) -> list[str]:
    """Derive a ``detector_data_axes`` pair from the transformations stored in ``nx_path``.

    Reads the ``rx`` and ``ry`` detector transformation values of entry
    ``entry0000`` as integers. The first axis is ``"z"`` when ``rx`` is 180 and
    ``"-z"`` otherwise; the second is ``"-y"`` when ``ry`` is 180 and ``"y"``
    otherwise.
    """
    rx_path = "/entry0000/instrument/detector/transformations/rx"
    ry_path = "/entry0000/instrument/detector/transformations/ry"
    with h5py.File(nx_path, "r") as h5f:
        rx, ry = (int(np.asarray(h5f[path][()])) for path in (rx_path, ry_path))
    first_axis = "z" if rx == 180 else "-z"
    second_axis = "-y" if ry == 180 else "y"
    return [first_axis, second_axis]


def _assert_matches_nxtomomill_ground_truth(nx_ground_truth: Path, nx_tested: Path):
    """Assert that ``nx_tested`` agrees with ``nx_ground_truth`` on the shared datasets.

    Two things are checked on entry ``entry0000``:

    * the detector data of both files links to the same source files, in the
      same order (compared after ``_norm_path`` normalisation);
    * every dataset listed below exists in both files with the same shape and
      the same content. String-like data is compared exactly after conversion
      to ``str``; numeric data uses ``np.allclose`` with ``rtol=1e-7`` and
      ``atol=1e-10``.

    Raises ``AssertionError`` on the first mismatch.
    """
    detector_data_path = "/entry0000/instrument/detector/data"
    text_kinds = ("S", "U", "O")
    common_dataset_paths = [
        "/entry0000/control/data",
        "/entry0000/definition",
        "/entry0000/instrument/beam/incident_energy",
        "/entry0000/instrument/detector/count_time",
        "/entry0000/instrument/detector/data",
        "/entry0000/instrument/detector/distance",
        "/entry0000/instrument/detector/field_of_view",
        "/entry0000/instrument/detector/image_key",
        "/entry0000/instrument/detector/image_key_control",
        "/entry0000/instrument/detector/sequence_number",
        "/entry0000/instrument/detector/tomo_n",
        "/entry0000/instrument/detector/x_pixel_size",
        "/entry0000/instrument/detector/y_pixel_size",
        "/entry0000/instrument/detector/transformations/gravity",
        "/entry0000/instrument/detector/transformations/rx",
        "/entry0000/instrument/detector/transformations/ry",
        "/entry0000/instrument/name",
        "/entry0000/instrument/source/distance",
        "/entry0000/instrument/source/name",
        "/entry0000/instrument/source/probe",
        "/entry0000/instrument/source/type",
        "/entry0000/sample/name",
        "/entry0000/sample/propagation_distance",
        "/entry0000/sample/rotation_angle",
        "/entry0000/sample/x_pixel_size",
        "/entry0000/sample/x_translation",
        "/entry0000/sample/y_pixel_size",
        "/entry0000/sample/y_translation",
        "/entry0000/sample/z_translation",
        "/entry0000/end_time",
        "/entry0000/start_time",
    ]
    with (
        h5py.File(nx_ground_truth, "r") as reference,
        h5py.File(nx_tested, "r") as tested,
    ):
        # Both outputs must point at the same raw files for the detector frames.
        reference_links = _extract_links(reference[detector_data_path])
        tested_links = _extract_links(tested[detector_data_path])
        reference_files = [_norm_path(file_name) for file_name, _ in reference_links]
        tested_files = [_norm_path(file_name) for file_name, _ in tested_links]
        assert reference_files == tested_files

        for dataset_path in common_dataset_paths:
            assert dataset_path in reference
            assert dataset_path in tested
            reference_value = np.asarray(reference[dataset_path][()])
            tested_value = np.asarray(tested[dataset_path][()])
            assert reference_value.shape == tested_value.shape

            is_numeric = (
                reference_value.dtype.kind not in text_kinds
                and tested_value.dtype.kind not in text_kinds
            )
            if is_numeric:
                assert np.allclose(
                    reference_value, tested_value, rtol=1e-7, atol=1e-10
                )
                continue
            assert np.array_equal(
                reference_value.astype(str), tested_value.astype(str)
            )


def test_tomobasictonxtomo_matches_nxtomomill_ground_truth(tmp_path):
    """TomoBasicToNXtomo output must match the file produced by H5ToNx.

    The same raw Bliss file is converted once with ``H5ToNx`` (the ground
    truth) and once with ``TomoBasicToNXtomo`` fed with the reference inputs,
    using the detector axes derived from the ground-truth file.
    """
    raw_h5 = _get_raw_h5("TestEwoksTomo_0010")
    nx_ground_truth = tmp_path / "nxtomomill_ground_truth.nx"
    nx_tomobasic = tmp_path / "tomobasic.nx"

    H5ToNx(
        inputs={
            "bliss_hdf5_path": raw_h5,
            "nx_path": str(nx_ground_truth),
        }
    ).execute()

    inputs = _reference_tomobasic_inputs(nx_path=str(nx_tomobasic), raw_h5=raw_h5)
    # Use the flips chosen by the ground-truth conversion so rx/ry are comparable.
    inputs["detector_data_axes"] = _detector_data_axes_from_nx(nx_ground_truth)
    _run_task(
        TomoBasicToNXtomo,
        **inputs,
    )

    _assert_matches_nxtomomill_ground_truth(nx_ground_truth, nx_tomobasic)


def test_build_from_inputs_matches_reference(tmp_path):
    """Every reference input must be found, with its unit, in the written NXtomo file."""
    raw_h5 = _get_raw_h5("TestEwoksTomo_0010")
    nx_basic = tmp_path / "basic.nx"
    inputs = _reference_tomobasic_inputs(nx_path=str(nx_basic), raw_h5=raw_h5)

    _run_task(TomoBasicToNXtomo, **inputs)

    expected_rx, expected_ry = _expected_detector_flip_angles(inputs)

    with h5py.File(nx_basic, "r") as out:
        # The detector data must link to the four raw scans, in order.
        data_ds = out["/entry0000/instrument/detector/data"]
        links = _extract_links(data_ds)
        assert [(_norm_path(p), d) for p, d in links] == [
            (_norm_path(raw_h5), "/2.1/instrument/edgetwinmic/image"),
            (_norm_path(raw_h5), "/3.1/instrument/edgetwinmic/image"),
            (_norm_path(raw_h5), "/4.1/instrument/edgetwinmic/image"),
            (_norm_path(raw_h5), "/5.1/instrument/edgetwinmic/image"),
        ]

        if "/entry0000/instrument/detector/image_key_control" in out:
            _assert_value(
                out,
                "/entry0000/instrument/detector/image_key_control",
                inputs["image_key_control"],
            )
        if "/entry0000/instrument/detector/image_key" in out:
            # image_key is expected to equal image_key_control with -1 replaced by 0.
            expected_image_key = np.array(inputs["image_key_control"], dtype=int)
            expected_image_key[expected_image_key == -1] = 0
            _assert_value(
                out,
                "/entry0000/instrument/detector/image_key",
                expected_image_key,
            )

        _assert_value(
            out,
            "/entry0000/sample/rotation_angle",
            inputs["rotation_angle_deg"],
            unit=_ureg.degree,
        )
        _assert_value(out, "/entry0000/sample/name", inputs["sample_name"])
        _assert_value(
            out,
            "/entry0000/instrument/beam/incident_energy",
            inputs["energy_kev"],
            unit=_ureg.keV,
        )
        _assert_value(out, "/entry0000/start_time", inputs["start_time"])
        _assert_value(out, "/entry0000/end_time", inputs["end_time"])

        # These entries must not be written for this set of inputs.
        assert "/entry0000/bliss_original_files" not in out
        assert (
            "/entry0000/instrument/detector/x_rotation_axis_pixel_position" not in out
        )

        _assert_value(
            out,
            "/entry0000/instrument/detector/x_pixel_size",
            inputs["detector_x_pixel_size_um"],
            unit=_ureg.micrometer,
        )
        _assert_value(
            out,
            "/entry0000/instrument/detector/y_pixel_size",
            inputs["detector_y_pixel_size_um"],
            unit=_ureg.micrometer,
        )
        _assert_value(
            out,
            "/entry0000/sample/x_pixel_size",
            inputs["sample_x_pixel_size_um"],
            unit=_ureg.micrometer,
        )
        _assert_value(
            out,
            "/entry0000/sample/y_pixel_size",
            inputs["sample_y_pixel_size_um"],
            unit=_ureg.micrometer,
        )
        _assert_value(
            out,
            "/entry0000/instrument/detector/distance",
            inputs["sample_detector_distance_mm"],
            unit=_ureg.millimeter,
        )
        # The source distance is expected to be stored as a negative value.
        _assert_value(
            out,
            "/entry0000/instrument/source/distance",
            -abs(inputs["source_sample_distance_mm"]),
            unit=_ureg.millimeter,
        )
        _assert_value(
            out,
            "/entry0000/instrument/detector/field_of_view",
            inputs["field_of_view"],
        )
        # tomo_n is expected to be the number of frames whose control key is 0.
        _assert_value(
            out,
            "/entry0000/instrument/detector/tomo_n",
            np.count_nonzero(inputs["image_key_control"] == 0),
        )
        _assert_value(
            out,
            "/entry0000/instrument/name",
            inputs["instrument_name"],
        )
        _assert_value(
            out,
            "/entry0000/sample/propagation_distance",
            inputs["propagation_distance_mm"],
            unit=_ureg.millimeter,
        )
        _assert_value(
            out,
            "/entry0000/instrument/detector/count_time",
            inputs["count_time_s"],
            unit=_ureg.second,
        )
        _assert_value(
            out,
            "/entry0000/instrument/detector/transformations/rx",
            expected_rx,
        )
        _assert_value(
            out,
            "/entry0000/instrument/detector/transformations/ry",
            expected_ry,
        )
        _assert_value(
            out,
            "/entry0000/sample/x_translation",
            inputs["x_translation_mm"],
            unit=_ureg.millimeter,
        )
        _assert_value(
            out,
            "/entry0000/sample/y_translation",
            inputs["y_translation_mm"],
            unit=_ureg.millimeter,
        )
        _assert_value(
            out,
            "/entry0000/sample/z_translation",
            inputs["z_translation_mm"],
            unit=_ureg.millimeter,
        )
        _assert_value(
            out,
            "/entry0000/control/data",
            inputs["current_a"],
            unit=_ureg.ampere,
        )
        _assert_value(
            out,
            "/entry0000/instrument/detector/sequence_number",
            inputs["sequence_number"],
        )


def test_build_from_inputs_rejects_h5_url_entries(tmp_path):
    """``detector_data_h5_url`` entries of the form ``<file>::<path>`` must be refused."""
    raw_h5 = _get_raw_h5("TestEwoksTomo_0010")
    nx_basic = tmp_path / "basic_url_input.nx"
    inputs = _reference_tomobasic_inputs(nx_path=str(nx_basic), raw_h5=raw_h5)
    # Prefix every dataset path with the file name, turning it into a URL-like string.
    inputs["detector_data_h5_url"] = [
        f"{raw_h5}::{path}" for path in inputs["detector_data_h5_url"]
    ]

    with pytest.raises(
        RuntimeError, match="detector_data_h5_url must contain only HDF5 dataset paths"
    ):
        _run_task(TomoBasicToNXtomo, **inputs)


def test_build_from_inputs_succeeds_on_missing_source_dataset(tmp_path):
    """The task must succeed even when source files and a source dataset do not exist."""
    nx_basic = tmp_path / "basic_invalid_input.nx"
    inputs = _reference_tomobasic_inputs(nx_path=str(nx_basic), raw_h5="unused.h5")
    inputs["detector_data_file_paths"] = [
        str(tmp_path / f"missing_{index}.h5") for index in range(4)
    ]
    inputs["detector_data_h5_url"][0] = "/does/not/exist"

    _run_task(TomoBasicToNXtomo, **inputs)

    with h5py.File(nx_basic, "r") as out:
        data_ds = out["/entry0000/instrument/detector/data"]
        assert data_ds.is_virtual
        links = _extract_links(data_ds)
        # The bogus dataset path is recorded as given.
        assert links[0][1] == "/does/not/exist"


def test_build_from_inputs_virtual_sources_without_opening_sources(tmp_path):
    """The virtual detector dataset must be built from the declared shapes and dtype.

    The source files do not exist, so the layout can only come from
    ``detector_data_shapes`` and ``detector_data_dtype``; the task must not
    create the source files either.
    """
    nx_basic = tmp_path / "basic_virtual_source.nx"
    inputs = _reference_tomobasic_inputs(nx_path=str(nx_basic), raw_h5="unused.h5")
    inputs["detector_data_file_paths"] = [
        str(tmp_path / f"missing_{index}.h5") for index in range(4)
    ]
    inputs["detector_data_h5_url"] = [f"/entry/data_{index}" for index in range(4)]
    inputs["detector_data_shapes"] = [(5, 6, 7)] * 4
    inputs["detector_data_dtype"] = "uint16"

    _run_task(TomoBasicToNXtomo, **inputs)

    with h5py.File(nx_basic, "r") as out:
        data_ds = out["/entry0000/instrument/detector/data"]
        assert data_ds.is_virtual
        # Four sources of 5 frames each are stacked along the first axis.
        assert data_ds.shape == (20, 6, 7)
        assert data_ds.dtype == np.dtype("uint16")
        links = _extract_links(data_ds)
        assert len(links) == 4
        assert [d for _, d in links] == [f"/entry/data_{index}" for index in range(4)]
        for index, (file_name, _) in enumerate(links):
            assert file_name.endswith(f"missing_{index}.h5")

    for index in range(4):
        assert not (tmp_path / f"missing_{index}.h5").exists()


def test_build_from_inputs_supports_detector_data_axes(tmp_path):
    """Axes given as a bytes array ``[-z, -y]`` must give ``rx = 0`` and ``ry = 180``."""
    raw_h5 = _get_raw_h5("TestEwoksTomo_0010")
    nx_basic = tmp_path / "basic_detector_data_axes.nx"
    inputs = _reference_tomobasic_inputs(nx_path=str(nx_basic), raw_h5=raw_h5)
    inputs["detector_data_axes"] = np.array([b"-z", b"-y"])

    _run_task(TomoBasicToNXtomo, **inputs)

    with h5py.File(nx_basic, "r") as out:
        _assert_value(
            out,
            "/entry0000/instrument/detector/transformations/rx",
            0,
        )
        _assert_value(
            out,
            "/entry0000/instrument/detector/transformations/ry",
            180,
        )


def test_build_from_inputs_detector_data_axes_no_flip(tmp_path):
    """Axes given as a bytes array ``[-z, y]`` must give ``rx = 0`` and ``ry = 0``."""
    raw_h5 = _get_raw_h5("TestEwoksTomo_0010")
    nx_basic = tmp_path / "basic_detector_data_axes_no_flip.nx"
    inputs = _reference_tomobasic_inputs(nx_path=str(nx_basic), raw_h5=raw_h5)
    inputs["detector_data_axes"] = np.array([b"-z", b"y"])

    _run_task(TomoBasicToNXtomo, **inputs)

    with h5py.File(nx_basic, "r") as out:
        _assert_value(
            out,
            "/entry0000/instrument/detector/transformations/rx",
            0,
        )
        _assert_value(
            out,
            "/entry0000/instrument/detector/transformations/ry",
            0,
        )


def test_build_from_inputs_coerces_integer_translation_inputs_to_float64(tmp_path):
    """A one-element integer ``y_translation_mm`` must be stored as one float64 per frame."""
    raw_h5 = _get_raw_h5("TestEwoksTomo_0010")
    nx_basic = tmp_path / "basic_integer_translation.nx"
    inputs = _reference_tomobasic_inputs(nx_path=str(nx_basic), raw_h5=raw_h5)
    inputs["y_translation_mm"] = [-118]

    _run_task(TomoBasicToNXtomo, **inputs)

    with h5py.File(nx_basic, "r") as out:
        y_translation = out["/entry0000/sample/y_translation"]
        assert y_translation.dtype == np.dtype("float64")
        # The single value is expected once per frame of image_key_control.
        np.testing.assert_allclose(
            np.asarray(y_translation[()]),
            np.full(inputs["image_key_control"].size, -118.0),
        )


def test_build_from_inputs_sets_energy_without_deprecation_warning(tmp_path):
    """``build_nxtomo_from_inputs`` must set the energy without any DeprecationWarning."""
    raw_h5 = _get_raw_h5("TestEwoksTomo_0010")
    inputs = _reference_tomobasic_inputs(
        nx_path=str(tmp_path / "unused.nx"), raw_h5=raw_h5
    )

    with warnings.catch_warnings(record=True) as caught:
        # Record every warning, including those normally hidden by filters.
        warnings.simplefilter("always")
        nx_obj = build_nxtomo_from_inputs(
            **{key: value for key, value in inputs.items() if key != "nx_path"}
        )

    assert not [
        warning for warning in caught if issubclass(warning.category, DeprecationWarning)
    ]
    assert nx_obj.energy is not None
    assert nx_obj.energy.to(_ureg.keV).magnitude == pytest.approx(inputs["energy_kev"])

    # The beam object is optional; when present it must carry the same energy.
    beam = getattr(nx_obj.instrument, "beam", None)
    if beam is not None:
        assert beam.incident_energy is not None
        assert beam.incident_energy.to(_ureg.keV).magnitude == pytest.approx(
            inputs["energy_kev"]
        )