import os
import re
from pathlib import Path
import warnings
from typing import Any, Literal
import h5py
import numpy as np
from PIL import Image
from ewokscore import Task
from ewokscore.model import BaseInputModel
from ewokscore.model import BaseOutputModel
from esrf_pathlib import ESRFPath
from pydantic import Field
from tomoscan.esrf.scan.nxtomoscan import NXtomoScan
from nabu.preproc.flatfield import FlatField
[docs]
class TomoPath(ESRFPath, tomo=1, fallback_depth=0):  # type: ignore[call-arg]
    """ESRF path flavour used to resolve gallery locations for tomography data.

    The class keywords ``tomo=1`` and ``fallback_depth=0`` are handed to
    ``ESRFPath``'s subclass machinery; their exact meaning is defined by
    ``esrf_pathlib`` (NOTE(review): presumably they select the tomography
    path schema and disable fallback lookups - confirm against the library).
    """
# Keyword arguments forwarded to Pillow's ``Image.save`` for each supported
# output format. Keys are lowercase file extensions without the leading dot;
# "jpg" and "jpeg" carry the same settings in two independent dictionaries.
SAVE_KWARGS: dict[str, dict[str, Any]] = {
    "png": dict(compress_level=6, optimize=True),
    "jpg": dict(quality=95, subsampling=0, optimize=True),
    "jpeg": dict(quality=95, subsampling=0, optimize=True),
    "webp": dict(quality=95, method=6),
}
def _auto_intensity_bounds(
images: np.ndarray | list[np.ndarray] | tuple[np.ndarray, ...],
lower_pct: float = 0.01,
upper_pct: float = 99.99,
) -> tuple[float, float]:
"""Compute robust lower/upper bounds for scaling to 8-bit."""
if isinstance(images, np.ndarray):
finite = np.asarray(images, dtype=np.float32)
finite = finite[np.isfinite(finite)]
else:
finite_chunks = []
for image in images:
image = np.asarray(image, dtype=np.float32)
finite_chunks.append(image[np.isfinite(image)])
finite_chunks = [chunk for chunk in finite_chunks if chunk.size]
finite = np.concatenate(finite_chunks) if finite_chunks else np.array([])
if finite.size == 0:
return 0.0, 255.0
upper_candidates = finite[finite < 1e9]
if upper_candidates.size == 0:
upper_candidates = finite
lower = float(np.percentile(finite, lower_pct))
upper = float(np.percentile(upper_candidates, upper_pct))
if not np.isfinite(lower) or not np.isfinite(upper):
return 0.0, 255.0
if lower == upper:
upper = lower + 1.0
return lower, upper
[docs]
def clean_angle_key(angle_key):
    """Convert an angle key such as ``'90.00000009(1)'`` to a float.

    Parameters
    ----------
    angle_key:
        Either a string, optionally carrying parenthesised suffixes like
        ``'(1)'`` which are removed before conversion, or a number.

    Returns
    -------
    float
        ``float`` instances are returned unchanged; any other non-string
        value (``int``, numpy scalar, ...) is converted with ``float()``.

    Raises
    ------
    ValueError
        If the cleaned string cannot be parsed as a float.
    """
    if isinstance(angle_key, float):
        return angle_key  # already clean
    if not isinstance(angle_key, str):
        # Numeric but not a Python float (e.g. int or np.float32): the regex
        # below only accepts text, so convert directly.
        return float(angle_key)
    cleaned = re.sub(r"\(.*?\)", "", angle_key)  # remove '(1)' etc.
    return float(cleaned)
def _save_kwargs_for_format(fmt: str) -> dict[str, Any]:
    """Look up the Pillow save options registered for an image format.

    ``fmt`` is matched case-insensitively and may carry leading dots
    (``".PNG"`` and ``"png"`` are equivalent). Formats without an entry in
    ``SAVE_KWARGS`` get an empty dictionary, i.e. Pillow's defaults.
    """
    key = fmt.lower().lstrip(".")
    if key in SAVE_KWARGS:
        return SAVE_KWARGS[key]
    return {}
def _resize_preserve_aspect(img: Image.Image, target_size: int) -> Image.Image:
    """Shrink ``img`` so that its longest side is at most ``target_size``.

    Images that already fit are returned as-is (same object, never upscaled).
    Otherwise both sides are scaled by the same factor, rounded, and kept at
    a minimum of one pixel; resampling uses the Lanczos filter.
    """
    width, height = img.size
    longest_side = max(width, height)
    if longest_side <= target_size:
        return img

    ratio = target_size / longest_side
    scaled_size = tuple(
        max(1, int(round(side * ratio))) for side in (width, height)
    )
    return img.resize(scaled_size, resample=Image.Resampling.LANCZOS)
def _projection_index_from_data_url(data_url) -> int | None:
"""Extract projection index from a silx DataUrl-like object."""
data_slice = data_url.data_slice()
if data_slice is None:
return None
if np.isscalar(data_slice):
return int(data_slice)
first_item = data_slice[:1]
return int(first_item[0]) if first_item and np.isscalar(first_item[0]) else None
def _normalize_projections_by_beam_intensity(
projections: list[np.ndarray] | np.ndarray,
beam_intensities: list[float] | np.ndarray,
) -> np.ndarray:
"""Scale projections with mean intensity divided by each beam intensity."""
normalized = np.asarray(projections, dtype=np.float32).copy()
intensities = np.asarray(beam_intensities, dtype=np.float32)
valid = np.isfinite(intensities) & (intensities > 0.0)
if np.any(valid):
mean_intensity = float(np.mean(intensities[valid]))
normalized[valid] *= mean_intensity / intensities[valid, None, None]
return normalized
[docs]
class ProjectionsGalleryOutputModel(BaseOutputModel):
    # Outputs published by the projections gallery task. Both fields are
    # required. (Documented with comments rather than a docstring so the
    # generated model schema stays exactly as before.)
    processed_data_dir: str = Field(
        default=...,
        description="Directory containing the processed data.",
    )
    gallery_path: str = Field(
        default=...,
        description="Path to the created gallery directory.",
    )
[docs]
class SlicesGalleryOutputModel(BaseOutputModel):
    # Outputs published by the slices gallery task. All fields are required.
    # (Documented with comments rather than a docstring so the generated
    # model schema stays exactly as before.)
    processed_data_dir: str = Field(
        default=...,
        description="Directory containing the processed data.",
    )
    gallery_path: str = Field(
        default=...,
        description="Path to the created gallery directory.",
    )
    gallery_image_path: str = Field(
        default=...,
        description="Path to the created gallery image.",
    )
def _prepare_gallery_image(
    image: np.ndarray,
    bounds: tuple[float, float],
    target_size: int,
) -> Image.Image:
    """Turn a 2D array into an 8-bit grayscale Pillow image.

    ``image`` must be 2D, or 3D with a leading axis of length one (which is
    dropped). Values are clamped to ``bounds`` and mapped linearly to 0-255;
    NaN and -inf become the lower bound and +inf the upper bound. The result
    is downsized to fit within ``target_size`` pixels, keeping aspect ratio.

    Raises ``ValueError`` for any other array rank.
    """
    has_singleton_lead = image.ndim == 3 and image.shape[0] == 1
    if has_singleton_lead:
        image = image.reshape(image.shape[1:])
    if image.ndim != 2:
        raise ValueError(f"Only 2D grayscale images are handled. Got {image.shape}")

    low, high = bounds

    # Non-finite pixels are mapped onto the bounds first so they do not
    # poison the rescaling below.
    pixels = np.nan_to_num(image, nan=low, posinf=high, neginf=low)

    # Clamp, shift to zero, then stretch to the 8-bit range.
    pixels = np.clip(pixels, low, high) - low
    if high != low:
        pixels = pixels * (255.0 / (high - low))

    eight_bit = np.clip(pixels, 0, 255).astype(np.uint8)
    return _resize_preserve_aspect(Image.fromarray(eight_bit, mode="L"), target_size)
def _save_to_gallery(
    output_file_name: str | Path,
    image: np.ndarray,
    bounds: tuple[float, float],
    overwrite: bool = True,
    image_size: int = 1000,
    output_format: str | None = None,
) -> None:
    """Render ``image`` and write two files next to each other.

    - ``<stem>_large<suffix>``: the image limited to ``image_size`` pixels on
      its longest side;
    - ``output_file_name`` itself: a preview limited to 200 pixels.

    Save options are chosen from ``output_format`` or, when that is empty,
    from the file suffix, falling back to ``"png"``.

    Raises ``OSError`` when ``overwrite`` is false and either file exists.
    """
    preview_path = Path(output_file_name)
    large_img = _prepare_gallery_image(image, bounds, image_size)

    fmt = output_format or preview_path.suffix.lstrip(".") or "png"
    save_kwargs = _save_kwargs_for_format(str(fmt))

    large_path = preview_path.with_name(
        f"{preview_path.stem}_large{preview_path.suffix}"
    )
    if not overwrite and any(p.exists() for p in (preview_path, large_path)):
        raise OSError(f"File already exists ({preview_path})")

    large_img.save(str(large_path), **save_kwargs)
    # The preview is derived from the already-rendered large image.
    preview_img = _resize_preserve_aspect(large_img, target_size=200)
    preview_img.save(str(preview_path), **save_kwargs)
[docs]
class VolumeGalleryOutputModel(BaseOutputModel):
    # Outputs published by the volume gallery task. All fields are required.
    # (Documented with comments rather than a docstring so the generated
    # model schema stays exactly as before.)
    processed_data_dir: str = Field(
        default=...,
        description="Directory containing the processed data.",
    )
    gallery_path: str = Field(
        default=...,
        description="Path to the created gallery directory.",
    )
    gallery_image_paths: list[str] = Field(
        default=...,
        description="Paths to the created gallery images.",
    )
[docs]
class BuildProjectionsGallery(  # type: ignore[call-arg]
    Task,
    input_model=ProjectionsGalleryInputModel,
    output_model=ProjectionsGalleryOutputModel,
):
    """Create gallery images and an animated GIF from NXtomo projections.

    Projections are sampled at a regular angular step, flat-field corrected
    with nabu, rescaled by the beam monitor intensity and written to the
    ``gallery`` directory next to the NXtomo file.
    """

    def run(self):
        """Write one image pair per selected angle, then the GIF preview."""
        self.gallery_output_format = str(self.inputs.output_format).lower()
        self.gallery_overwrite = self.inputs.overwrite
        self.gallery_image_size = int(self.inputs.image_size)
        angle_step = self.inputs.angle_step

        # Use the directory of the NXtomo file as the processed data directory.
        nx_path = Path(self.inputs.nx_path)
        processed_data_dir = nx_path.parent
        gallery_dir = self.get_gallery_dir(processed_data_dir)
        os.makedirs(gallery_dir, exist_ok=True)

        # Open the NXtomoScan object.
        self.nxtomoscan = NXtomoScan(str(nx_path), entry="entry0000")
        angles, corrected_projections, bounds = self._get_gallery_projections(
            angle_step=angle_step,
            bounds=self.inputs.bounds,
        )
        for angle, projection in zip(angles, corrected_projections):
            gallery_file_path = self.get_gallery_file_path(gallery_dir, nx_path, angle)
            Path(gallery_file_path).parent.mkdir(parents=True, exist_ok=True)
            # Process the image and save it in the gallery.
            _save_to_gallery(
                gallery_file_path,
                projection,
                bounds,
                overwrite=self.gallery_overwrite,
                image_size=self.gallery_image_size,
                output_format=self.gallery_output_format,
            )

        # The GIF gets the user-supplied bounds (possibly not a tuple), so
        # automatic bounds are recomputed from the GIF's own frames.
        self._save_projections_gif(gallery_dir, nx_path, self.inputs.bounds)
        self.outputs.processed_data_dir = str(processed_data_dir)
        self.outputs.gallery_path = str(gallery_dir)

    @staticmethod
    def _read_last_reduced_frame(
        file_path: str, data_path: str, kind: str
    ) -> dict[int, np.ndarray]:
        """Return ``{index: frame}`` for the last dataset of an HDF5 group.

        Dataset names are expected to be integer strings. Only the last name
        in the group's iteration order is read.

        NOTE(review): earlier entries are ignored, as they always were (the
        previous loop overwrote its result on every iteration). Handing all
        entries to nabu would presumably also require the real radio indices
        - confirm before changing.

        Raises ``ValueError`` when the group holds no dataset.
        """
        with h5py.File(file_path, "r") as h5f:
            group = h5f[data_path]
            names = list(group)
            if not names:
                raise ValueError(
                    f"No reduced {kind} found in {file_path}::{data_path}"
                )
            last_name = names[-1]
            return {int(last_name): group[last_name][()]}

    def get_flats_from_h5(
        self, reduced_flat_path: str, data_path: str = "entry0000/flats"
    ) -> dict[int, np.ndarray]:
        """Load the reduced flat (last entry of ``data_path``) from an HDF5 file."""
        return self._read_last_reduced_frame(reduced_flat_path, data_path, "flats")

    def get_darks_from_h5(
        self, reduced_dark_path: str, data_path: str = "entry0000/darks"
    ) -> dict[int, np.ndarray]:
        """Load the reduced dark (last entry of ``data_path``) from an HDF5 file."""
        return self._read_last_reduced_frame(reduced_dark_path, data_path, "darks")

    def apply_flat_field_correction(self, projections):
        """Apply nabu's flat-field correction to a stack of 2D projections.

        Reduced darks and flats are read from ``inputs.reduced_darks_path``
        and ``inputs.reduced_flats_path``. Returns what
        ``FlatField.normalize_radios`` returns for the float32 stack.
        """
        projections = np.asarray(projections, dtype=np.float32)
        reduced_darks = self.get_darks_from_h5(self.inputs.reduced_darks_path)
        reduced_flats = self.get_flats_from_h5(self.inputs.reduced_flats_path)
        x, y = projections[0].shape
        radios_shape = (len(projections), x, y)
        flat_field = FlatField(
            radios_shape=radios_shape, flats=reduced_flats, darks=reduced_darks
        )
        # Silence nabu's divide warnings for this call only; installing the
        # filter globally would leak into the rest of the process and pile up
        # on every invocation.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message=".*encountered in divide",
                category=RuntimeWarning,
                module="nabu.preproc.flatfield",
            )
            corrected_projections = flat_field.normalize_radios(projections)
        return corrected_projections

    def get_gallery_dir(self, processed_data_dir: Path) -> Path:
        """Return the ``gallery`` sub-directory of ``processed_data_dir``."""
        return processed_data_dir / "gallery"

    def get_gallery_file_path(self, gallery_dir, nx_path: Path, angle: float) -> str:
        """Return ``<gallery_dir>/<nx stem>_<angle, 2 decimals>deg.<format>``."""
        filename = f"{nx_path.stem}_{angle:.2f}deg.{self.gallery_output_format}"
        # ``Path()`` makes a plain string directory acceptable as well.
        return str(Path(gallery_dir) / filename)

    def get_proj_from_data_url(self, data_url) -> np.ndarray:
        """Load the data referenced by a DataUrl object as float32."""
        with h5py.File(data_url.file_path(), "r") as h5f:
            data = h5f[data_url.data_path()]
            data_slice = data_url.data_slice()
            if data_slice is not None:
                return data[data_slice].astype(np.float32)
            return data[()].astype(np.float32)

    def _get_gallery_projections(
        self,
        angle_step: float,
        bounds: tuple[float, float] | None,
    ) -> tuple[list[float], np.ndarray, tuple[float, float]]:
        """Load, correct and normalize the projections used for gallery outputs.

        ``bounds`` is returned unchanged when it is a tuple; otherwise robust
        bounds are derived from the corrected projections.
        """
        angles, projections, beam_intensities = self.get_projections_by_angle_step(
            angle_step
        )
        corrected_projections = self.apply_flat_field_correction(projections)
        corrected_projections = _normalize_projections_by_beam_intensity(
            corrected_projections, beam_intensities
        )
        if not isinstance(bounds, tuple):
            bounds = _auto_intensity_bounds(corrected_projections)
        return angles, corrected_projections, bounds

    def get_projections_by_angle_step(
        self, angle_step: float = 90
    ) -> tuple[list[float], list[np.ndarray], np.ndarray]:
        """Select projections at regularly spaced angles.

        Target angles start at the smallest available angle and advance by
        ``angle_step``; for each target the closest available projection is
        taken, each projection at most once.

        Returns the selected angles, the matching float32 projections and
        their beam intensities.

        Raises ``ValueError`` if ``angle_step`` is not strictly positive or
        the scan exposes no projection angle.
        """
        if not angle_step > 0:
            raise ValueError(f"angle_step must be > 0. Got {angle_step}")

        # Get all angles.
        angles_dict = self.nxtomoscan.get_proj_angle_url()
        angles_dict = {clean_angle_key(k): v for k, v in angles_dict.items()}
        all_angles = np.array(list(angles_dict.keys()))
        if all_angles.size == 0:
            raise ValueError("No projection angles found in the scan")

        # Determine regularly spaced target angles within the available range.
        min_angle = np.min(all_angles)
        max_angle = np.max(all_angles)
        target_angles = np.arange(min_angle, max_angle + angle_step, angle_step)

        # For each target angle, find the closest available one.
        selected_angles = []
        used_indices = set()
        for target in target_angles:
            idx = np.argmin(np.abs(all_angles - target))
            if idx not in used_indices:  # avoid duplicates
                used_indices.add(idx)
                selected_angles.append(all_angles[idx])

        selected_data_urls = [angles_dict[angle] for angle in selected_angles]
        selected_projections = [
            self.get_proj_from_data_url(data_url) for data_url in selected_data_urls
        ]
        beam_intensities = self.get_beam_intensities(selected_data_urls)
        return selected_angles, selected_projections, beam_intensities

    def get_beam_intensities(self, data_urls: list[Any]) -> np.ndarray:
        """Return beam monitor intensities aligned with the selected projections.

        Values come from ``entry0000/control/data`` of the NXtomo file.
        Entries stay NaN when the monitor data cannot be read, or when a
        projection index is unknown or out of range.
        """
        intensities = np.full(len(data_urls), np.nan, dtype=np.float32)
        nx_path = Path(self.inputs.nx_path)
        try:
            with h5py.File(str(nx_path), "r") as h5f:
                control_data = np.asarray(
                    h5f["entry0000/control/data"][()], dtype=np.float32
                )
        except Exception:
            # Deliberate best effort: without monitor data the projections
            # are simply left un-normalized.
            return intensities

        for i, data_url in enumerate(data_urls):
            proj_idx = _projection_index_from_data_url(data_url)
            if proj_idx is None:
                continue
            if 0 <= proj_idx < control_data.size:
                intensities[i] = control_data[proj_idx]
        return intensities

    def _save_projections_gif(
        self, gallery_dir: Path, nx_path: Path, bounds: tuple[float, float] | None
    ) -> None:
        """Write ``<nx stem>.gif``, an endlessly looping preview of the scan.

        Frames are taken every 10 degrees, limited to 200 pixels, and shown
        for ``round(1000 / 12)`` milliseconds each.

        Raises ``OSError`` when overwriting is disabled and the GIF exists.
        """
        gif_angle_step = 10.0
        gif_size = 200
        gif_fps = 12
        gif_duration = 1000

        angles, corrected_projections, bounds = self._get_gallery_projections(
            angle_step=gif_angle_step,
            bounds=bounds,
        )
        if not angles:
            return
        frames = [
            _prepare_gallery_image(projection, bounds, target_size=gif_size)
            for projection in corrected_projections
        ]
        if not frames:
            return

        gif_path = gallery_dir / f"{nx_path.stem}.gif"
        # ``gallery_overwrite`` is only set by ``run()``; default to
        # overwriting when this method is used on its own.
        overwrite = getattr(self, "gallery_overwrite", True)
        if not overwrite and gif_path.exists():
            raise OSError(f"File already exists ({gif_path})")
        frames[0].save(
            gif_path,
            save_all=True,
            append_images=frames[1:],
            duration=int(round(gif_duration / gif_fps)),
            loop=0,
        )
[docs]
class BuildSlicesGallery(  # type: ignore[call-arg]
    Task,
    input_model=SlicesGalleryInputModel,
    output_model=SlicesGalleryOutputModel,
):
    """Create two gallery images from a reconstructed slice.

    The image suffixed ``_large`` is downsampled if needed so neither
    dimension exceeds the configured ``image_size``; the second image is a
    preview that fits within 200 pixels. Aspect ratio is preserved for both.
    """

    def run(self):
        """Read the slice, normalize/downsample, and save to <processed>/gallery."""
        fmt = str(self.inputs.output_format)
        overwrite = bool(self.inputs.overwrite)
        image_size = int(self.inputs.image_size)
        bounds = self.inputs.bounds

        slice_path = Path(self.inputs.reconstructed_slice_path)
        if not slice_path.exists():
            raise FileNotFoundError(f"Reconstructed slice not found: {slice_path}")

        processed_data_dir = slice_path.parent
        gallery_dir = Path(self.get_gallery_dir(processed_data_dir))
        os.makedirs(gallery_dir, exist_ok=True)

        arr = self._load_slice(slice_path)
        if not isinstance(bounds, tuple):
            bounds = _auto_intensity_bounds(arr)

        # ``get_gallery_file_path`` already returns the complete path. Joining
        # it onto ``gallery_dir`` again only worked for absolute paths; with a
        # relative slice path the directory part was duplicated.
        out_path = Path(self.get_gallery_file_path(gallery_dir, slice_path, fmt))
        _save_to_gallery(
            out_path,
            arr,
            bounds,
            overwrite=overwrite,
            image_size=image_size,
            output_format=fmt,
        )

        self.outputs.processed_data_dir = str(processed_data_dir)
        self.outputs.gallery_path = str(gallery_dir)
        self.outputs.gallery_image_path = str(out_path)

    def get_gallery_dir(self, processed_data_dir: Path | str) -> str:
        """Return the fixed gallery directory path."""
        return str(Path(processed_data_dir) / "gallery")

    def get_gallery_file_path(self, gallery_dir, reconstructed_slice_path, fmt) -> str:
        """Return the gallery image path for a reconstructed slice.

        The file name is ``<slice stem>.<fmt>`` with everything before the
        first ``_absorption_`` or ``_phase_`` marker removed, so the name
        starts with the reconstruction descriptor. ``gallery_dir`` may be a
        ``str`` or a ``Path``.
        """
        slice_path = Path(reconstructed_slice_path)
        # Drop the dataset prefix and keep the reconstruction descriptor.
        filename = re.sub(
            r"^.*?_(absorption|phase)_",
            r"\1_",
            f"{slice_path.stem}.{fmt}",
            count=1,
        )
        # ``get_gallery_dir`` returns a string, so coerce before joining.
        return str(Path(gallery_dir) / filename)

    @staticmethod
    def _load_slice(img_path: Path) -> np.ndarray:
        """Load a 2D float32 slice (HDF5 at entry0000/reconstruction/results/data, EDF, or image).

        The reader is chosen from the file suffix: ``.h5``/``.hdf5`` through
        h5py (singleton axes are squeezed), ``.edf`` through fabio, anything
        else through Pillow. Colour images with 3 or 4 channels are averaged
        over their first three channels.

        Raises ``RuntimeError`` when an EDF file is given but fabio cannot be
        imported.
        """
        ext = img_path.suffix.lower()
        if ext in (".h5", ".hdf5"):
            with h5py.File(img_path, "r") as h5in:
                img = h5in["entry0000/reconstruction/results/data"][:]
            return np.squeeze(img).astype(np.float32)

        if ext == ".edf":
            try:
                import fabio  # type: ignore
            except Exception as exc:
                raise RuntimeError(
                    "EDF support requires 'fabio' (pip install fabio)."
                ) from exc
            return fabio.open(str(img_path)).data.astype(np.float32)

        with Image.open(img_path) as im:
            arr = np.array(im, dtype=np.float32)
        if arr.ndim == 3 and arr.shape[-1] in (3, 4):
            arr = arr[..., :3].mean(axis=-1)
        return arr
[docs]
class BuildVolumeGallery(  # type: ignore[call-arg]
    Task,
    input_model=VolumeGalleryInputModel,
    output_model=VolumeGalleryOutputModel,
):
    """Create gallery images from a reconstructed 3D volume.

    For each axis (X, Y, Z), extract slices at 1/4, 2/4 and 3/4 of the volume
    extent and save them to the gallery directory.
    """

    def run(self):
        """Extract the gallery slices and write them to the gallery directory.

        Raises ``FileNotFoundError`` when the volume file does not exist and
        ``ValueError`` when it is not a TIFF file.
        """
        fmt = str(self.inputs.output_format)
        overwrite = bool(self.inputs.overwrite)
        image_size = int(self.inputs.image_size)
        bounds = self.inputs.bounds

        volume_path = Path(self.inputs.reconstructed_volume_path)
        if not volume_path.exists():
            raise FileNotFoundError(f"Reconstructed volume not found: {volume_path}")

        # The gallery location and file-name template depend on which path
        # template the volume's directory matches.
        volume_container = TomoPath(volume_path.parent)
        if volume_container.template_name == "volumes_custom_type_file":
            gallery_dir = Path(volume_container.volumes_custom_type_gallery_path)
            gallery_template_name = "volumes_custom_type_preview"
        else:
            gallery_dir = Path(volume_container.volumes_gallery_path)
            gallery_template_name = "volumes_preview"
        processed_data_dir = volume_path.parent
        os.makedirs(gallery_dir, exist_ok=True)

        directions = ("xy", "xz", "yz")
        slices_by_axis, axis_indices = self._load_gallery_slices(volume_path)

        # For thin volumes several fractions map to the same slice index.
        # Keep each index once (order preserved): saving it again would
        # rewrite the same file and, with ``overwrite=False``, fail on the
        # file this very run has just created.
        wanted = {
            direction: list(dict.fromkeys(axis_indices[direction]))
            for direction in directions
        }

        gallery_slices = [
            slices_by_axis[direction][slice_index]
            for direction in directions
            for slice_index in wanted[direction]
        ]
        if not isinstance(bounds, tuple):
            bounds = _auto_intensity_bounds(gallery_slices)

        # Drop the dataset prefix and keep the reconstruction descriptor.
        prefix_pattern = re.compile(r"^.*?_(absorption|phase)_")

        gallery_paths = []
        for direction in directions:
            for slice_index in wanted[direction]:
                slice_2d = slices_by_axis[direction][slice_index]
                out_path = Path(
                    volume_container.replace_fields(
                        template_name=gallery_template_name,
                        tomo_slicing_direction=direction,
                        tomo_slice_number=slice_index,
                        thumbnail_file_type=fmt,
                    )
                )
                out_path = out_path.with_name(
                    prefix_pattern.sub(r"\1_", out_path.name, count=1)
                )
                out_path.parent.mkdir(parents=True, exist_ok=True)
                _save_to_gallery(
                    out_path,
                    slice_2d,
                    bounds,
                    overwrite=overwrite,
                    image_size=image_size,
                    output_format=fmt,
                )
                gallery_paths.append(str(out_path))

        self.outputs.processed_data_dir = str(processed_data_dir)
        self.outputs.gallery_path = str(gallery_dir)
        self.outputs.gallery_image_paths = gallery_paths

    @staticmethod
    def _as_grayscale(frame: Image.Image) -> np.ndarray:
        """Return the current frame as a 2D float32 array.

        Frames with 3 or 4 channels are averaged over their first three
        channels. Raises ``ValueError`` if the result is not 2D.
        """
        arr = np.array(frame, dtype=np.float32)
        if arr.ndim == 3 and arr.shape[-1] in (3, 4):
            arr = arr[..., :3].mean(axis=-1)
        if arr.ndim != 2:
            raise ValueError(
                f"Only 2D grayscale TIFF frames are handled. Got frame shape={arr.shape}"
            )
        return arr

    @classmethod
    def _load_gallery_slices(
        cls, volume_path: Path
    ) -> tuple[dict[str, dict[int, np.ndarray]], dict[str, list[int]]]:
        """Read the gallery slices of a multi-frame TIFF in a single pass.

        Returns ``(slices, indices)``. ``indices`` maps each direction
        ("xy", "xz", "yz") to the slice indices at 1/4, 1/2 and 3/4 of the
        corresponding extent (duplicates possible for small extents), and
        ``slices`` maps each direction to ``{index: 2D float32 array}``.

        Frames are the XY planes. XZ and YZ slices are assembled row by row
        while iterating over the frames, so the volume is never held in
        memory as a whole.

        Raises ``ValueError`` for non-TIFF files or files without frames.
        """
        if volume_path.suffix.lower() not in (".tif", ".tiff"):
            raise ValueError(
                f"Unsupported volume format for gallery: {volume_path}. Only TIFF3D is supported."
            )

        fractions = (0.25, 0.5, 0.75)

        def idxs(n: int) -> list[int]:
            # Clamp to the valid index range [0, n - 1].
            return [min(n - 1, max(0, int(n * frac))) for frac in fractions]

        with Image.open(volume_path) as im:
            z_size = int(getattr(im, "n_frames", 1))
            if z_size <= 0:
                raise ValueError(f"No frames found in TIFF file: {volume_path}")

            # The first frame defines the in-plane dimensions.
            im.seek(0)
            y_size, x_size = cls._as_grayscale(im).shape

            axis_indices = {"xy": idxs(z_size), "xz": idxs(y_size), "yz": idxs(x_size)}
            unique_indices = {k: sorted(set(v)) for k, v in axis_indices.items()}

            xy_slices: dict[int, np.ndarray] = {}
            xz_slices = {
                y_idx: np.empty((z_size, x_size), dtype=np.float32)
                for y_idx in unique_indices["xz"]
            }
            yz_slices = {
                x_idx: np.empty((z_size, y_size), dtype=np.float32)
                for x_idx in unique_indices["yz"]
            }

            for z_idx in range(z_size):
                im.seek(z_idx)
                frame_arr = cls._as_grayscale(im)
                if z_idx in unique_indices["xy"]:
                    xy_slices[z_idx] = frame_arr.copy()
                for y_idx in unique_indices["xz"]:
                    xz_slices[y_idx][z_idx, :] = frame_arr[y_idx, :]
                for x_idx in unique_indices["yz"]:
                    yz_slices[x_idx][z_idx, :] = frame_arr[:, x_idx]

        return {"xy": xy_slices, "xz": xz_slices, "yz": yz_slices}, axis_indices