import h5py
import numpy as np
import pytest
from pathlib import Path
from unittest.mock import patch
import shutil
from blissdata.redis_engine.scan import ScanState
from ewokstomo.tasks.online.reconstruct_slice import OnlineReconstructSlice
from ewokstomo.tasks.online.reducedarkflat import OnlineReduceDarkFlat
from ewokstomo.tests.test_reducedarkflat import get_data_dir, get_raw_data_dir
from ewokstomo.tests.online.mock import FakeScan, FakeScanWithMotor
from ewokstomo.tasks.online.reconstruct_slice import get_slice_index
import time
[docs]
@pytest.fixture
def tmp_dataset_path(tmp_path) -> Path:
    """Copy the processed TestEwoksTomo_0010 dataset into ``tmp_path`` and scrub outputs.

    After copying, the dataset directory is cleaned so tests start from a
    known state:

    * ``projections/`` exists (created if missing) and holds no
      ``*_darks.hdf5`` / ``*_flats.hdf5`` files or ``gallery`` entry;
    * ``slices/`` and ``references/`` are removed if present.

    Returns the path of the copied dataset directory.
    """
    dataset_name = "TestEwoksTomo_0010"
    source = get_data_dir(dataset_name)
    target = tmp_path / dataset_name
    shutil.copytree(source, target)

    projections = target / "projections"
    projections.mkdir(exist_ok=True)

    slices = target / "slices"
    if slices.exists():
        shutil.rmtree(slices)

    # Drop previously generated darks/flats and the gallery folder.
    for pattern in ("*_darks.hdf5", "*_flats.hdf5", "gallery"):
        for entry in projections.glob(pattern):
            if not entry.is_dir():
                entry.unlink()
                continue
            shutil.rmtree(entry)

    references = target / "references"
    if references.exists():
        shutil.rmtree(references)

    return target
[docs]
def test_get_slice_index():
    """Integer and named slice specifiers resolve to the expected z index."""
    # (requested slice, number of z rows, expected index)
    cases = [
        (3, 10, 3),
        ("middle", 11, 5),
        ("first", 10, 0),
        ("last", 10, 9),
    ]
    for requested, n_rows, expected in cases:
        assert get_slice_index(requested, n_z=n_rows) == expected

    # An unknown keyword must be rejected rather than silently mapped.
    with pytest.raises(ValueError, match="Invalid slice index"):
        get_slice_index("banana", n_z=10)
[docs]
def test_get_image_stream_name_no_image_stream():
    """``_get_image_stream_name`` raises when the scan exposes no image stream."""

    class _ScanWithoutImage:
        # Only a non-image stream is available.
        streams = {"foo": None}

    # Bypass __init__: only the stream-name lookup is exercised.
    task = OnlineReconstructSlice.__new__(OnlineReconstructSlice)
    with pytest.raises(ValueError, match="No image stream found"):
        task._get_image_stream_name(_ScanWithoutImage())
[docs]
def test_get_motor_stream_name_no_motor():
    """``_get_motor_stream_name`` raises when no stream matches the motor."""

    class _ScanWithoutMotor:
        # Only a detector image stream is available; nothing for "rot".
        streams = {"det:image": None}

    # Bypass __init__: only the stream-name lookup is exercised.
    task = OnlineReconstructSlice.__new__(OnlineReconstructSlice)
    with pytest.raises(ValueError, match="No motor stream found"):
        task._get_motor_stream_name(_ScanWithoutMotor(), motor_name="rot")
[docs]
def test_get_projections_stream_timeout(monkeypatch):
    """Raise TimeoutError when a stopped scan cannot supply the requested frames.

    The fake scan is already ``STOPPED`` and its ``det:image`` stream is
    empty, so asking for projections 0-2 must fail with
    "Not enough frames received".
    """

    class _StoppedEmptyScan:
        # Deliberately not named ``FakeScan``: that name is imported at module
        # level from ``ewokstomo.tests.online.mock`` and must not be shadowed.
        state = ScanState.STOPPED
        streams = {"det:image": []}

        def update(self, block=False):
            pass

    task = OnlineReconstructSlice.__new__(OnlineReconstructSlice)

    # Make any sleep in the task instantaneous. NOTE(review): this only takes
    # effect if the task calls ``time.sleep`` via the ``time`` module.
    monkeypatch.setattr(time, "sleep", lambda *_: None)

    with pytest.raises(TimeoutError, match="Not enough frames received"):
        task._get_projections_from_stream(
            scan=_StoppedEmptyScan(),
            projections_indices=[0, 1, 2],
        )
[docs]
@pytest.fixture
def TestEwoksTomo_0010_dataset(tmp_path) -> Path:
    """Copy the test dataset to a temporary directory.

    The processed dataset directory is copied as a whole, the raw
    ``TestEwoksTomo_0010.h5`` master file is copied into it, and the copied
    directory is returned.
    """
    scan_name = "TestEwoksTomo_0010"
    master_name = f"{scan_name}.h5"
    processed_source = get_data_dir(scan_name)
    raw_source = get_raw_data_dir(scan_name)

    destination = tmp_path / scan_name
    shutil.copytree(processed_source, destination)
    shutil.copy(raw_source / master_name, destination / master_name)
    return destination
[docs]
@pytest.fixture
def test_data(TestEwoksTomo_0010_dataset):
    """
    Load raw frames from TestEwoksTomo_0010 dataset.

    Returns a dict with:

    * ``dark_frames`` / ``flat_frames`` / ``projection_frames``: lists of 2-D
      float32 frames read from scans 2.1, 3.1 and 4.1 of the master file;
    * ``angles``: evenly spaced projection angles in radians over a full
      turn (``[0, 2*pi)``), one per projection;
    * fixed acquisition parameters (pixel size, delta/beta, distance, energy);
    * ``n_projections``, ``n_z`` / ``n_x`` (frame rows / columns) and the
      dataset directory.
    """
    detector = "edgetwinmic"
    h5py_file = TestEwoksTomo_0010_dataset / "TestEwoksTomo_0010.h5"

    def _read_frames(h5file, scan_number):
        # Read the whole detector stack with a single h5py call (instead of
        # one read per frame), then split it into per-frame float32 arrays.
        stack = np.asarray(
            h5file[f"{scan_number}/measurement/{detector}"][()],
            dtype=np.float32,
        )
        return list(stack)

    with h5py.File(h5py_file, "r") as f:
        dark_frames = _read_frames(f, "2.1")
        flat_frames = _read_frames(f, "3.1")
        projection_frames = _read_frames(f, "4.1")

    # Create angles for projections (full 360-degree tomography)
    n_projections = len(projection_frames)
    angles = np.linspace(0, 2 * np.pi, n_projections, endpoint=False, dtype=np.float32)
    # Frame dimensions: rows (z) x columns (x)
    n_z, n_x = projection_frames[0].shape
    return {
        "dark_frames": dark_frames,
        "flat_frames": flat_frames,
        "projection_frames": projection_frames,
        "angles": angles,
        "pixel_size_m": 0.0000075,
        "delta_beta": 100.0,
        "distance_m": 500.0,
        "energy_keV": 17.0,
        "n_projections": n_projections,
        "n_z": n_z,
        "n_x": n_x,
        "dataset_dir": TestEwoksTomo_0010_dataset,
    }
[docs]
def _reduce_darkflat_with_mocked_redis(fake_scan, inputs):
    """Run ``OnlineReduceDarkFlat`` on *fake_scan* with its Redis access mocked.

    ``BeaconData`` and ``DataStore`` are patched in
    ``ewokstomo.tasks.online.reducedarkflat`` so the data store hands back
    *fake_scan* for whatever scan key the task requests.
    """
    with (
        patch("ewokstomo.tasks.online.reducedarkflat.BeaconData") as MockBeacon,
        patch("ewokstomo.tasks.online.reducedarkflat.DataStore") as MockStore,
    ):
        MockBeacon.return_value.get_redis_data_db.return_value = "redis://fake"
        MockStore.return_value.load_scan.return_value = fake_scan
        OnlineReduceDarkFlat(inputs=inputs).execute()


def test_full_workflow_online_reconstruction(test_data, tmp_path):
    """
    Test the complete online tomography workflow with multiple batches and phase retrieval:
    1. Mock dark scan and reduce darks
    2. Mock flat scan and reduce flats
    3. Mock projection scan and reconstruct in batches with phase retrieval
    """
    data = test_data
    output_dir = tmp_path / "output"
    output_dir.mkdir()

    # Step 1: Mock dark scan and reduce darks
    dark_output = output_dir / "reduced_darks.h5"
    _reduce_darkflat_with_mocked_redis(
        FakeScan(data["dark_frames"], title="dark"),
        {
            "scan_key": "dark_scan",
            "index": 0,
            "reduction_method": "mean",
            "output_file_path": str(dark_output),
        },
    )
    assert dark_output.is_file(), "Reduced darks file was not created"
    with h5py.File(dark_output, "r") as f:
        assert "entry0000/darks" in f, "Reduced darks dataset not found"
        reduced_dark = f["entry0000/darks"][()]
    assert reduced_dark.shape == data["dark_frames"][0].shape

    # Step 2: Mock flat scan and reduce flats
    flat_output = output_dir / "reduced_flats.h5"
    _reduce_darkflat_with_mocked_redis(
        FakeScan(data["flat_frames"], title="flat"),
        {
            "scan_key": "flat_scan",
            "index": len(data["dark_frames"]),  # Offset index for Nabu
            "reduction_method": "median",
            "output_file_path": str(flat_output),
        },
    )
    assert flat_output.is_file(), "Reduced flats file was not created"
    with h5py.File(flat_output, "r") as f:
        assert "entry0000/flats" in f, "Reduced flats dataset not found"
        reduced_flat = f["entry0000/flats"][()]
    assert reduced_flat.shape == data["flat_frames"][0].shape

    # Step 3: Mock projection scan and reconstruct with multiple batches
    fake_projection_scan = FakeScanWithMotor(
        arrays=data["projection_frames"],
        angles=list(data["angles"]),
        title="projections",
        rotation_motor="rot",
    )
    # Use a batch size that will create multiple batches; one output file is
    # expected per (possibly partial) batch, hence the ceiling division.
    batch_size = max(10, data["n_projections"] // 4)
    expected_batches = (data["n_projections"] + batch_size - 1) // batch_size
    recon_output = output_dir / "reconstruction"
    with (
        patch("ewokstomo.tasks.online.reconstruct_slice.BeaconData") as MockBeacon,
        patch("ewokstomo.tasks.online.reconstruct_slice.DataStore") as MockStore,
    ):
        MockBeacon.return_value.get_redis_data_db.return_value = "redis://fake"
        MockStore.return_value.load_scan.return_value = fake_projection_scan
        task = OnlineReconstructSlice(
            inputs={
                "scan_key": "projection_scan",
                "output_path": str(recon_output),
                "rotation_motor": "rot",
                "total_nb_projection": data["n_projections"],
                "center_of_rotation": data["n_x"] / 2.0,
                "batch_size": batch_size,
                "pixel_size_m": data["pixel_size_m"],
                "distance_m": data["distance_m"],
                "energy_keV": data["energy_keV"],
                "reduced_dark_path": str(dark_output),
                "reduced_flat_path": str(flat_output),
                "delta_beta": data["delta_beta"],
                "halftomo": False,  # Full tomography
                "padding_mode": "edges",
            }
        )
        task.execute()

    # Verify reconstruction output directory exists
    assert recon_output.is_dir(), (
        f"Reconstruction output directory not found: {recon_output}"
    )
    # Verify multiple batch output files were created
    output_files = sorted(recon_output.glob("reconstructed_slice_*.h5"))
    assert len(output_files) == expected_batches, (
        f"Expected {expected_batches} output files, found {len(output_files)}"
    )
    # Verify each batch file contains a reconstructed slice dataset
    for output_file in output_files:
        with h5py.File(output_file, "r") as f:
            assert "reconstructed_slice" in f, (
                f"Reconstructed slice dataset not found in {output_file.name}"
            )
    print(f"\n✓ Successfully processed {len(output_files)} batches")
[docs]
def test_run_stops_on_reconstruct_batch_runtime_error(monkeypatch, tmp_path):
    """A RuntimeError from ``reconstruct_batch`` stops the run without propagating.

    Ten projections with a batch size of five give two batches. The first
    ``reconstruct_batch`` call raises, so ``execute`` must return normally
    and no second batch may be attempted.
    """
    task = OnlineReconstructSlice(
        inputs={
            "scan_key": "scan",
            "output_path": str(tmp_path),
            "rotation_motor": "rot",
            "total_nb_projection": 10,
            "center_of_rotation": 4.0,
            "batch_size": 5,
            "pixel_size_m": 1.0,
            "distance_m": 1.0,
            "energy_keV": 1.0,
            "reduced_dark_path": "dark.h5",
            "reduced_flat_path": "flat.h5",
        }
    )
    fake_scan = type(
        "Scan",
        (),
        {"streams": {}, "state": ScanState.STARTED},
    )()
    # --- infrastructure mocks ---
    monkeypatch.setattr(
        "ewokstomo.tasks.online.reconstruct_slice.wait_for_scan_state",
        lambda *args, **kwargs: None,
    )
    monkeypatch.setattr(
        "ewokstomo.tasks.online.reconstruct_slice.BeaconData",
        lambda: type("B", (), {"get_redis_data_db": lambda self: "redis://fake"})(),
    )
    monkeypatch.setattr(
        "ewokstomo.tasks.online.reconstruct_slice.DataStore",
        lambda *_: type("DS", (), {"load_scan": lambda self, key: fake_scan})(),
    )
    # --- critical mock: bypass flatfield file loading ---
    monkeypatch.setattr(
        "ewokstomo.tasks.online.reconstruct_slice.FlatFieldCorrection.load_reduced_data_from_file",
        lambda *args, **kwargs: None,
    )

    # --- force RuntimeError inside loop, recording each invocation ---
    calls = []

    def _failing_reconstruct_batch(*args, **kwargs):
        calls.append((args, kwargs))
        raise RuntimeError("boom")

    monkeypatch.setattr(task, "reconstruct_batch", _failing_reconstruct_batch)

    # Must NOT raise
    task.execute()

    # Guard against a vacuous pass: the failing batch must really have been
    # reached, and the run must have stopped right after it.
    assert len(calls) == 1, (
        f"reconstruct_batch was called {len(calls)} time(s), expected exactly 1"
    )
[docs]
def test_reconstruct_batch_phase_retrieval_failure(monkeypatch, tmp_path):
    """An exception raised by phase retrieval propagates out of ``reconstruct_batch``.

    Stream access and flat-field correction are stubbed. The patched
    ``apply_phase_retrieval`` raises "phase fail", and the test expects that
    error to surface to the caller.
    """
    # Bypass __init__: the collaborators used by reconstruct_batch are set up by hand.
    task = OnlineReconstructSlice.__new__(OnlineReconstructSlice)
    # Stubbed flat-field processor; its apply_correction returns None.
    task.flatfield_processor = type(
        "FF",
        (),
        {"apply_correction": lambda *args, **kwargs: None},
    )()
    monkeypatch.setattr(
        task,
        "_get_projections_from_stream",
        lambda *args, **kwargs: np.ones((3, 4, 8), dtype=np.float32),
    )
    monkeypatch.setattr(
        task,
        "_get_angles_from_stream",
        lambda *args, **kwargs: np.zeros(3, dtype=np.float32),
    )

    def _failing_phase_retrieval(*args, **kwargs):
        raise Exception("phase fail")

    # Phase retrieval fails
    monkeypatch.setattr(
        "ewokstomo.tasks.online.reconstruct_slice.apply_phase_retrieval",
        _failing_phase_retrieval,
    )
    with pytest.raises(Exception, match="phase fail"):
        task.reconstruct_batch(
            scan=None,
            output_path=tmp_path / "out.h5",
            center_of_rotation=4.0,
            projections_indices=[0, 1, 2],
            distance_m=1.0,
            energy_keV=1.0,
            pixel_size_m=1.0,
            rotation_motor="rot",
        )
[docs]
def test_reconstruct_batch_fbp_failure(monkeypatch, tmp_path):
    """An exception raised by the FBP reconstruction propagates out of ``reconstruct_batch``.

    Stream access and flat-field correction are stubbed, and phase retrieval
    is patched to return its first positional argument unchanged. The patched
    ``fbp_reconstruction_slice`` raises "fbp fail", and the test expects that
    error to surface to the caller.
    """
    # Bypass __init__: the collaborators used by reconstruct_batch are set up by hand.
    task = OnlineReconstructSlice.__new__(OnlineReconstructSlice)
    # Stubbed flat-field processor; its apply_correction returns None.
    task.flatfield_processor = type(
        "FF",
        (),
        {"apply_correction": lambda *args, **kwargs: None},
    )()
    monkeypatch.setattr(
        task,
        "_get_projections_from_stream",
        lambda *args, **kwargs: np.ones((3, 4, 8), dtype=np.float32),
    )
    monkeypatch.setattr(
        task,
        "_get_angles_from_stream",
        lambda *args, **kwargs: np.zeros(3, dtype=np.float32),
    )
    # Phase retrieval passes its first positional argument through unchanged.
    monkeypatch.setattr(
        "ewokstomo.tasks.online.reconstruct_slice.apply_phase_retrieval",
        lambda *args, **kwargs: args[0],
    )

    def _failing_fbp(*args, **kwargs):
        raise Exception("fbp fail")

    # FBP fails
    monkeypatch.setattr(
        "ewokstomo.tasks.online.reconstruct_slice.fbp_reconstruction_slice",
        _failing_fbp,
    )
    with pytest.raises(Exception, match="fbp fail"):
        task.reconstruct_batch(
            scan=None,
            output_path=tmp_path / "out.h5",
            center_of_rotation=4.0,
            projections_indices=[0, 1, 2],
            distance_m=1.0,
            energy_keV=1.0,
            pixel_size_m=1.0,
            rotation_motor="rot",
        )