# Source code for ewokstomo.tests.test_tomobasictonxtomo
from pathlib import Path
import h5py
import numpy as np
import pint
import pytest
import warnings
from ewokstomo.tasks.nxtomomill import H5ToNx
from ewokstomo.tasks.nxtomo_utils import (
build_nxtomo_from_inputs,
_resolve_detector_flips,
)
from ewokstomo.tasks.tomobasictonxtomo import TomoBasicToNXtomo
from pint.errors import UnitStrippedWarning
# Ignore pint's UnitStrippedWarning for everything executed through this module.
warnings.filterwarnings("ignore", category=UnitStrippedWarning)
# pint application registry, used by the helpers below to attach and convert units.
_ureg = pint.get_application_registry()
# "data" directory located next to this test module.
DATA_ROOT = Path(__file__).resolve().parent / "data"
# Collection folder name inserted under RAW_DATA by _get_raw_h5.
RAW_COLLECTION = "TestEwoksTomo"
def _get_raw_h5(file_name: str) -> str:
    """Return, as a string, the path of the raw HDF5 file named *file_name*.

    The path is ``DATA_ROOT/RAW_DATA/RAW_COLLECTION/<file_name>/<file_name>.h5``.
    """
    scan_dir = DATA_ROOT.joinpath("RAW_DATA", RAW_COLLECTION, file_name)
    return str(scan_dir / f"{file_name}.h5")
def _run_task(task_cls, **inputs):
    """Instantiate *task_cls* with ``inputs=<keyword arguments>`` and execute it."""
    task_cls(inputs=inputs).execute()
def _read_quantity(h5: h5py.File, path: str) -> pint.Quantity | np.ndarray | None:
    """Read the dataset at *path* and attach the unit stored in its attributes.

    Returns ``None`` when *path* is absent from *h5*. The unit is taken from
    the ``units`` attribute, or from ``unit`` when ``units`` is absent. When
    no unit attribute exists, or when building the quantity fails, the plain
    array is returned instead.
    """
    if path not in h5:
        return None

    dataset = h5[path]
    values = np.asarray(dataset[()])

    attrs = dataset.attrs
    unit_key = next((name for name in ("units", "unit") if name in attrs), None)
    if unit_key is None:
        return values

    unit_name = attrs[unit_key]
    if isinstance(unit_name, bytes):
        unit_name = unit_name.decode()
    if unit_name is None:
        return values

    try:
        quantity = values * _ureg(unit_name)
    except Exception:
        # Best effort: fall back to the raw values when the unit is unusable.
        return values
    return quantity
def _read_raw(h5: h5py.File, path: str):
    """Read the dataset at *path* without unit handling, decoding text values.

    Returns ``None`` when *path* is absent. Bytes values are decoded to
    ``str``. Arrays of bytes/object dtype become a single ``str`` when they
    hold exactly one element and a flat list of ``str`` otherwise. Any other
    value is returned as read.
    """

    def _as_text(item):
        # Bytes-like items are decoded; everything else goes through str().
        if isinstance(item, (bytes, bytearray)):
            return item.decode()
        return str(item)

    if path not in h5:
        return None

    value = h5[path][()]
    if isinstance(value, (bytes, bytearray)):
        return value.decode()

    as_array = np.asarray(value)
    if as_array.dtype.kind not in ("S", "O"):
        return value

    items = as_array.reshape(-1)
    if items.size == 1:
        return _as_text(items[0])
    return [_as_text(item) for item in items]
def _assert_value(
    h5: h5py.File,
    path: str,
    expected,
    unit: pint.Unit | None = None,
    rtol: float = 1e-6,
    atol: float = 1e-8,
):
    """Assert that the dataset at *path* in *h5* matches *expected*.

    Without *unit* the value is read with ``_read_raw``; with *unit* it is
    read with ``_read_quantity`` and, when a quantity comes back, converted
    to *unit* before comparing magnitudes. Text-like values (bytes, str or
    object dtype on either side) are compared for equality as strings, with
    a scalar and a one-element array treated as the same shape. Everything
    else is compared with ``np.allclose`` using *rtol* and *atol*.
    """
    reader = _read_raw if unit is None else _read_quantity
    actual = reader(h5, path)
    assert actual is not None, f"Missing {path}"

    if unit is not None and isinstance(actual, pint.Quantity):
        actual = actual.to(unit).magnitude

    got = np.asarray(actual)
    want = np.asarray(expected)

    text_kinds = ("S", "U", "O")
    if got.dtype.kind not in text_kinds and want.dtype.kind not in text_kinds:
        assert np.allclose(got, want, rtol=rtol, atol=atol)
        return

    got_text = got.astype(str)
    want_text = want.astype(str)
    # Let a 0-d value match a one-element array on the other side.
    if got_text.shape == () and want_text.shape == (1,):
        want_text = want_text.reshape(())
    elif want_text.shape == () and got_text.shape == (1,):
        got_text = got_text.reshape(())
    assert np.array_equal(got_text, want_text)
def _extract_links(ds: h5py.Dataset):
    """Return ``(file, location)`` pairs describing where *ds* stores its data.

    * Virtual dataset: one ``(file_name, dset_name)`` pair per virtual source.
    * External storage: one pair per external entry, made of the entry's
      first two fields. h5py documents these entries as
      ``(name, offset, size)`` tuples, so the pair is ``(name, offset)``.
    * Otherwise: a single pair with the dataset's own file name and HDF5 name.
    """
    if getattr(ds, "is_virtual", False):
        return [(src.file_name, src.dset_name) for src in ds.virtual_sources()]

    external_entries = ds.external
    if external_entries:
        # Index instead of unpacking a fixed number of fields: h5py yields
        # 3-tuples here, which a 4-name unpacking would reject with ValueError.
        return [(entry[0], entry[1]) for entry in external_entries]

    return [(ds.file.filename, ds.name)]
def _norm_path(path: str) -> str:
    """Return *path* resolved through ``Path.resolve()``, as a string."""
    resolved = Path(path).resolve()
    return str(resolved)
def _reference_tomobasic_inputs(nx_path: str, raw_h5: str) -> dict:
    """Build the reference input dictionary passed to ``TomoBasicToNXtomo``.

    *nx_path* is stored under ``"nx_path"`` and *raw_h5* is repeated once per
    detector data block (four blocks, 2 + 2 + 11 + 5 = 20 frames in total).
    Every call returns freshly created containers, so callers may modify
    the result in place.
    """
    n_frames = 20

    # Frame counts of the four detector data blocks.
    block_frames = (2, 2, 11, 5)
    image_keys = [2] * 2 + [1] * 2 + [0] * 11 + [-1] * 5

    rotation_angles = [-0.0] * 5 + [
        -36.000000036,
        -72.000000072,
        -108.000000108,
        -144.000000144,
        -180.00000018,
        -216.000000216,
        -252.000000252,
        -288.000000288,
        -324.000000324,
        -360.00000036,
        -360.00000036,
        -270.00000027,
        -180.00000018,
        -90.00000009,
        -0.0,
    ]

    currents = [
        0.19122,
        0.19122,
        0.19116999999999998,
        0.19116999999999998,
        0.19111000000000003,
        0.19108000000000003,
        0.19103,
        0.19099000000000002,
        0.19096000000000002,
        0.19091999999999998,
        0.19089,
        0.19085,
        0.19082,
        0.19077000000000002,
        0.19072999999999998,
        0.19072999999999998,
        0.19066,
        0.19061000000000003,
        0.19056,
        0.1905,
    ]

    return {
        "nx_path": nx_path,
        "detector_data_file_paths": [raw_h5] * len(block_frames),
        "detector_data_h5_url": [
            "/2.1/instrument/edgetwinmic/image",
            "/3.1/instrument/edgetwinmic/image",
            "/4.1/instrument/edgetwinmic/image",
            "/5.1/instrument/edgetwinmic/image",
        ],
        "detector_data_shapes": [(count, 16, 16) for count in block_frames],
        "detector_data_dtype": "uint16",
        "image_key_control": np.array(image_keys, dtype=np.int64),
        "rotation_angle_deg": np.array(rotation_angles, dtype=float),
        "sample_name": "TestEwoksTomo",
        "energy_kev": 26.0,
        "title": "tomo:fullturn",
        "start_time": "2025-03-20T17:27:51.826549+01:00",
        "end_time": "2025-03-20T17:29:18.558287+01:00",
        "estimated_cor": None,
        "detector_data_axes": ["-z", "-y"],
        "detector_x_pixel_size_um": 0.6693859615384615,
        "detector_y_pixel_size_um": 0.6693859615384615,
        "sample_x_pixel_size_um": 0.669,
        "sample_y_pixel_size_um": 0.669,
        "sample_detector_distance_mm": 30.0,
        "source_sample_distance_mm": -52000.0,
        "field_of_view": "Full",
        "instrument_name": "ESRF-bm05",
        "propagation_distance_mm": 29.982702287142033,
        "count_time_s": np.full(n_frames, 0.01, dtype=float),
        "x_translation_mm": np.zeros(n_frames, dtype=float),
        "y_translation_mm": np.full(n_frames, -118.0, dtype=float),
        "z_translation_mm": np.zeros(n_frames, dtype=float),
        "current_a": np.array(currents, dtype=float),
        "sequence_number": np.arange(n_frames, dtype=np.uint32),
    }
def _expected_detector_flip_angles(inputs: dict) -> tuple[int, int]:
    """Return the expected ``(rx, ry)`` angles for ``inputs["detector_data_axes"]``.

    The two flags returned by ``_resolve_detector_flips`` are unpacked as
    ``(lr_flip, ud_flip)``. ``rx`` is 180 when ``ud_flip`` is truthy and
    ``ry`` is 180 when ``lr_flip`` is truthy; otherwise each is 0.
    """
    flips = _resolve_detector_flips(inputs.get("detector_data_axes"))
    lr_flip, ud_flip = flips
    rx_angle = 180 if ud_flip else 0
    ry_angle = 180 if lr_flip else 0
    return (rx_angle, ry_angle)
def _detector_data_axes_from_nx(nx_path: Path) -> list[str]:
    """Derive a ``detector_data_axes`` list from the rx/ry values stored in *nx_path*.

    The first entry is ``"z"`` when ``rx`` equals 180 and ``"-z"`` otherwise;
    the second is ``"-y"`` when ``ry`` equals 180 and ``"y"`` otherwise.
    """
    angle_paths = (
        "/entry0000/instrument/detector/transformations/rx",
        "/entry0000/instrument/detector/transformations/ry",
    )
    with h5py.File(nx_path, "r") as h5f:
        # Both values are read while the file is still open.
        rx, ry = [int(np.asarray(h5f[angle_path][()])) for angle_path in angle_paths]

    first_axis = "z" if rx == 180 else "-z"
    second_axis = "-y" if ry == 180 else "y"
    return [first_axis, second_axis]
def _assert_matches_nxtomomill_ground_truth(nx_ground_truth: Path, nx_tested: Path):
    """Assert that *nx_tested* agrees with *nx_ground_truth* on the common datasets.

    The files referenced by the detector data links must be the same (after
    path normalisation), and every dataset listed below must exist in both
    files with equal shape and equal content. Text-like values are compared
    as strings; numeric values with ``np.allclose(rtol=1e-7, atol=1e-10)``.
    """
    detector_data = "/entry0000/instrument/detector/data"
    dataset_paths = (
        "/entry0000/control/data",
        "/entry0000/definition",
        "/entry0000/instrument/beam/incident_energy",
        "/entry0000/instrument/detector/count_time",
        "/entry0000/instrument/detector/data",
        "/entry0000/instrument/detector/distance",
        "/entry0000/instrument/detector/field_of_view",
        "/entry0000/instrument/detector/image_key",
        "/entry0000/instrument/detector/image_key_control",
        "/entry0000/instrument/detector/sequence_number",
        "/entry0000/instrument/detector/tomo_n",
        "/entry0000/instrument/detector/x_pixel_size",
        "/entry0000/instrument/detector/y_pixel_size",
        "/entry0000/instrument/detector/transformations/gravity",
        "/entry0000/instrument/detector/transformations/rx",
        "/entry0000/instrument/detector/transformations/ry",
        "/entry0000/instrument/name",
        "/entry0000/instrument/source/distance",
        "/entry0000/instrument/source/name",
        "/entry0000/instrument/source/probe",
        "/entry0000/instrument/source/type",
        "/entry0000/sample/name",
        "/entry0000/sample/propagation_distance",
        "/entry0000/sample/rotation_angle",
        "/entry0000/sample/x_pixel_size",
        "/entry0000/sample/x_translation",
        "/entry0000/sample/y_pixel_size",
        "/entry0000/sample/y_translation",
        "/entry0000/sample/z_translation",
        "/entry0000/end_time",
        "/entry0000/start_time",
    )
    text_kinds = ("S", "U", "O")

    with (
        h5py.File(nx_ground_truth, "r") as reference,
        h5py.File(nx_tested, "r") as tested,
    ):
        reference_links = _extract_links(reference[detector_data])
        tested_links = _extract_links(tested[detector_data])
        # Only the linked files are compared here, not the dataset locations.
        reference_files = [_norm_path(file_name) for file_name, _ in reference_links]
        tested_files = [_norm_path(file_name) for file_name, _ in tested_links]
        assert reference_files == tested_files

        for dataset_path in dataset_paths:
            assert dataset_path in reference
            assert dataset_path in tested

            reference_value = np.asarray(reference[dataset_path][()])
            tested_value = np.asarray(tested[dataset_path][()])
            assert reference_value.shape == tested_value.shape

            numeric = (
                reference_value.dtype.kind not in text_kinds
                and tested_value.dtype.kind not in text_kinds
            )
            if numeric:
                assert np.allclose(
                    reference_value, tested_value, rtol=1e-7, atol=1e-10
                )
            else:
                assert np.array_equal(
                    reference_value.astype(str), tested_value.astype(str)
                )
[docs]
def test_tomobasictonxtomo_matches_nxtomomill_ground_truth(tmp_path):
    """TomoBasicToNXtomo output must match the file written by H5ToNx.

    The ground-truth file is produced first; its rx/ry values then set the
    ``detector_data_axes`` input of the task under test.
    """
    raw_h5 = _get_raw_h5("TestEwoksTomo_0010")
    nx_ground_truth = tmp_path / "nxtomomill_ground_truth.nx"
    nx_tomobasic = tmp_path / "tomobasic.nx"

    ground_truth_inputs = {
        "bliss_hdf5_path": raw_h5,
        "nx_path": str(nx_ground_truth),
    }
    H5ToNx(inputs=ground_truth_inputs).execute()

    inputs = _reference_tomobasic_inputs(nx_path=str(nx_tomobasic), raw_h5=raw_h5)
    inputs["detector_data_axes"] = _detector_data_axes_from_nx(nx_ground_truth)
    _run_task(TomoBasicToNXtomo, **inputs)

    _assert_matches_nxtomomill_ground_truth(nx_ground_truth, nx_tomobasic)
[docs]
def test_build_from_inputs_matches_reference(tmp_path):
    """Every reference input must be found, with its unit, in the written file.

    The checks are grouped in two ordered tables of
    ``(dataset path, expected value, unit)``; a ``None`` unit means the
    value is compared without unit handling.
    """
    raw_h5 = _get_raw_h5("TestEwoksTomo_0010")
    nx_basic = tmp_path / "basic.nx"
    inputs = _reference_tomobasic_inputs(nx_path=str(nx_basic), raw_h5=raw_h5)
    _run_task(TomoBasicToNXtomo, **inputs)
    expected_rx, expected_ry = _expected_detector_flip_angles(inputs)

    detector = "/entry0000/instrument/detector"
    sample = "/entry0000/sample"

    with h5py.File(nx_basic, "r") as out:
        links = _extract_links(out["/entry0000/instrument/detector/data"])
        resolved_raw = _norm_path(raw_h5)
        assert [(_norm_path(file_name), dset) for file_name, dset in links] == [
            (resolved_raw, "/2.1/instrument/edgetwinmic/image"),
            (resolved_raw, "/3.1/instrument/edgetwinmic/image"),
            (resolved_raw, "/4.1/instrument/edgetwinmic/image"),
            (resolved_raw, "/5.1/instrument/edgetwinmic/image"),
        ]

        # The two image-key datasets are only checked when present.
        if "/entry0000/instrument/detector/image_key_control" in out:
            _assert_value(
                out,
                "/entry0000/instrument/detector/image_key_control",
                inputs["image_key_control"],
            )
        if "/entry0000/instrument/detector/image_key" in out:
            # image_key is expected to hold 0 wherever the control key is -1.
            expected_image_key = np.array(inputs["image_key_control"], dtype=int)
            expected_image_key[expected_image_key == -1] = 0
            _assert_value(
                out,
                "/entry0000/instrument/detector/image_key",
                expected_image_key,
            )

        first_checks = [
            (f"{sample}/rotation_angle", inputs["rotation_angle_deg"], _ureg.degree),
            (f"{sample}/name", inputs["sample_name"], None),
            (
                "/entry0000/instrument/beam/incident_energy",
                inputs["energy_kev"],
                _ureg.keV,
            ),
            ("/entry0000/start_time", inputs["start_time"], None),
            ("/entry0000/end_time", inputs["end_time"], None),
        ]
        for dataset_path, expected, unit in first_checks:
            _assert_value(out, dataset_path, expected, unit=unit)

        assert "/entry0000/bliss_original_files" not in out
        assert (
            "/entry0000/instrument/detector/x_rotation_axis_pixel_position" not in out
        )

        second_checks = [
            (
                f"{detector}/x_pixel_size",
                inputs["detector_x_pixel_size_um"],
                _ureg.micrometer,
            ),
            (
                f"{detector}/y_pixel_size",
                inputs["detector_y_pixel_size_um"],
                _ureg.micrometer,
            ),
            (
                f"{sample}/x_pixel_size",
                inputs["sample_x_pixel_size_um"],
                _ureg.micrometer,
            ),
            (
                f"{sample}/y_pixel_size",
                inputs["sample_y_pixel_size_um"],
                _ureg.micrometer,
            ),
            (
                f"{detector}/distance",
                inputs["sample_detector_distance_mm"],
                _ureg.millimeter,
            ),
            (
                "/entry0000/instrument/source/distance",
                -abs(inputs["source_sample_distance_mm"]),
                _ureg.millimeter,
            ),
            (f"{detector}/field_of_view", inputs["field_of_view"], None),
            (
                f"{detector}/tomo_n",
                np.count_nonzero(inputs["image_key_control"] == 0),
                None,
            ),
            ("/entry0000/instrument/name", inputs["instrument_name"], None),
            (
                f"{sample}/propagation_distance",
                inputs["propagation_distance_mm"],
                _ureg.millimeter,
            ),
            (f"{detector}/count_time", inputs["count_time_s"], _ureg.second),
            (f"{detector}/transformations/rx", expected_rx, None),
            (f"{detector}/transformations/ry", expected_ry, None),
            (f"{sample}/x_translation", inputs["x_translation_mm"], _ureg.millimeter),
            (f"{sample}/y_translation", inputs["y_translation_mm"], _ureg.millimeter),
            (f"{sample}/z_translation", inputs["z_translation_mm"], _ureg.millimeter),
            ("/entry0000/control/data", inputs["current_a"], _ureg.ampere),
            (f"{detector}/sequence_number", inputs["sequence_number"], None),
        ]
        for dataset_path, expected, unit in second_checks:
            _assert_value(out, dataset_path, expected, unit=unit)
[docs]
def test_build_from_inputs_rejects_h5_url_entries(tmp_path):
    """``<file>::<dataset>`` entries in ``detector_data_h5_url`` must raise RuntimeError."""
    raw_h5 = _get_raw_h5("TestEwoksTomo_0010")
    nx_basic = tmp_path / "basic_url_input.nx"
    inputs = _reference_tomobasic_inputs(nx_path=str(nx_basic), raw_h5=raw_h5)

    # Prefix each dataset path with the file name.
    dataset_paths = inputs["detector_data_h5_url"]
    inputs["detector_data_h5_url"] = [
        f"{raw_h5}::{dataset_path}" for dataset_path in dataset_paths
    ]

    expected_message = "detector_data_h5_url must contain only HDF5 dataset paths"
    with pytest.raises(RuntimeError, match=expected_message):
        _run_task(TomoBasicToNXtomo, **inputs)
[docs]
def test_build_from_inputs_succeeds_on_missing_source_dataset(tmp_path):
    """The task must succeed when source files and a source dataset do not exist.

    The written detector data is expected to be a virtual dataset whose
    first source still points at the non-existent dataset path.
    """
    nx_basic = tmp_path / "basic_invalid_input.nx"
    inputs = _reference_tomobasic_inputs(nx_path=str(nx_basic), raw_h5="unused.h5")

    missing_files = [tmp_path / f"missing_{index}.h5" for index in range(4)]
    inputs["detector_data_file_paths"] = [str(path) for path in missing_files]
    inputs["detector_data_h5_url"][0] = "/does/not/exist"

    _run_task(TomoBasicToNXtomo, **inputs)

    with h5py.File(nx_basic, "r") as out:
        data_ds = out["/entry0000/instrument/detector/data"]
        assert data_ds.is_virtual
        first_link = _extract_links(data_ds)[0]
        assert first_link[1] == "/does/not/exist"
[docs]
def test_build_from_inputs_virtual_sources_without_opening_sources(tmp_path):
    """The virtual dataset must be built from declared shapes and dtype only.

    The source files never exist: the test checks the resulting shape, dtype
    and links, and that the task did not create the source files.
    """
    n_sources = 4
    nx_basic = tmp_path / "basic_virtual_source.nx"
    inputs = _reference_tomobasic_inputs(nx_path=str(nx_basic), raw_h5="unused.h5")

    source_files = [tmp_path / f"missing_{index}.h5" for index in range(n_sources)]
    source_datasets = [f"/entry/data_{index}" for index in range(n_sources)]
    inputs["detector_data_file_paths"] = [str(path) for path in source_files]
    inputs["detector_data_h5_url"] = list(source_datasets)
    inputs["detector_data_shapes"] = [(5, 6, 7) for _ in range(n_sources)]
    inputs["detector_data_dtype"] = "uint16"

    _run_task(TomoBasicToNXtomo, **inputs)

    with h5py.File(nx_basic, "r") as out:
        data_ds = out["/entry0000/instrument/detector/data"]
        assert data_ds.is_virtual
        # Four blocks of five frames are stacked along the first axis.
        assert data_ds.shape == (20, 6, 7)
        assert data_ds.dtype == np.dtype("uint16")

        links = _extract_links(data_ds)
        assert len(links) == 4
        assert [dset for _, dset in links] == source_datasets
        for index, link in enumerate(links):
            assert link[0].endswith(f"missing_{index}.h5")

    for source_file in source_files:
        assert not source_file.exists()
[docs]
def test_build_from_inputs_supports_detector_data_axes(tmp_path):
    """A bytes array ``[b"-z", b"-y"]`` must give rx = 0 and ry = 180."""
    raw_h5 = _get_raw_h5("TestEwoksTomo_0010")
    nx_basic = tmp_path / "basic_detector_data_axes.nx"
    inputs = _reference_tomobasic_inputs(nx_path=str(nx_basic), raw_h5=raw_h5)
    inputs["detector_data_axes"] = np.array([b"-z", b"-y"])

    _run_task(TomoBasicToNXtomo, **inputs)

    expected_angles = (
        ("/entry0000/instrument/detector/transformations/rx", 0),
        ("/entry0000/instrument/detector/transformations/ry", 180),
    )
    with h5py.File(nx_basic, "r") as out:
        for dataset_path, angle in expected_angles:
            _assert_value(out, dataset_path, angle)
[docs]
def test_build_from_inputs_detector_data_axes_no_flip(tmp_path):
    """A bytes array ``[b"-z", b"y"]`` must give rx = 0 and ry = 0."""
    raw_h5 = _get_raw_h5("TestEwoksTomo_0010")
    nx_basic = tmp_path / "basic_detector_data_axes_no_flip.nx"
    inputs = _reference_tomobasic_inputs(nx_path=str(nx_basic), raw_h5=raw_h5)
    inputs["detector_data_axes"] = np.array([b"-z", b"y"])

    _run_task(TomoBasicToNXtomo, **inputs)

    expected_angles = (
        ("/entry0000/instrument/detector/transformations/rx", 0),
        ("/entry0000/instrument/detector/transformations/ry", 0),
    )
    with h5py.File(nx_basic, "r") as out:
        for dataset_path, angle in expected_angles:
            _assert_value(out, dataset_path, angle)
[docs]
def test_build_from_inputs_coerces_integer_translation_inputs_to_float64(tmp_path):
    """A one-element integer ``y_translation_mm`` must be written as float64.

    The stored dataset is expected to hold -118.0 for every frame.
    """
    raw_h5 = _get_raw_h5("TestEwoksTomo_0010")
    nx_basic = tmp_path / "basic_integer_translation.nx"
    inputs = _reference_tomobasic_inputs(nx_path=str(nx_basic), raw_h5=raw_h5)
    inputs["y_translation_mm"] = [-118]

    _run_task(TomoBasicToNXtomo, **inputs)

    n_frames = inputs["image_key_control"].size
    with h5py.File(nx_basic, "r") as out:
        stored = out["/entry0000/sample/y_translation"]
        assert stored.dtype == np.dtype("float64")
        np.testing.assert_allclose(
            np.asarray(stored[()]),
            np.full(n_frames, -118.0),
        )
[docs]
def test_build_from_inputs_sets_energy_without_deprecation_warning(tmp_path):
    """``build_nxtomo_from_inputs`` must set the energy with no DeprecationWarning.

    The builder receives every reference input except ``nx_path``. The
    energy on the returned object, and on its instrument beam when one
    exists, must equal ``energy_kev`` once converted to keV.
    """
    raw_h5 = _get_raw_h5("TestEwoksTomo_0010")
    inputs = _reference_tomobasic_inputs(
        nx_path=str(tmp_path / "unused.nx"), raw_h5=raw_h5
    )
    builder_kwargs = dict(inputs)
    builder_kwargs.pop("nx_path", None)

    with warnings.catch_warnings(record=True) as caught:
        # Record every warning, including those filtered out by default.
        warnings.simplefilter("always")
        nx_obj = build_nxtomo_from_inputs(**builder_kwargs)

    assert not any(
        issubclass(entry.category, DeprecationWarning) for entry in caught
    )

    expected_energy = pytest.approx(inputs["energy_kev"])
    assert nx_obj.energy is not None
    assert nx_obj.energy.to(_ureg.keV).magnitude == expected_energy

    beam = getattr(nx_obj.instrument, "beam", None)
    if beam is None:
        return
    assert beam.incident_energy is not None
    assert beam.incident_energy.to(_ureg.keV).magnitude == pytest.approx(
        inputs["energy_kev"]
    )