Source code for ewokstomo.tasks.buildgallery

import os
import re
from pathlib import Path
import warnings
from typing import Any, Literal

import h5py
import numpy as np
from PIL import Image
from ewokscore import Task
from ewokscore.model import BaseInputModel
from ewokscore.model import BaseOutputModel
from esrf_pathlib import ESRFPath
from pydantic import Field

from tomoscan.esrf.scan.nxtomoscan import NXtomoScan
from nabu.preproc.flatfield import FlatField


class TomoPath(ESRFPath, tomo=1, fallback_depth=0):  # type: ignore[call-arg]
    """``ESRFPath`` subclass used by this module to resolve gallery locations.

    ``tomo=1`` and ``fallback_depth=0`` are passed to ``ESRFPath`` as class
    keyword arguments; their meaning is defined by ``esrf_pathlib``.
    NOTE(review): presumably they select the tomography path templates and
    disable template fallback -- confirm against the esrf_pathlib docs.
    """

    pass
SAVE_KWARGS: dict[str, dict[str, Any]] = { "png": {"compress_level": 6, "optimize": True}, "jpg": {"quality": 95, "subsampling": 0, "optimize": True}, "jpeg": {"quality": 95, "subsampling": 0, "optimize": True}, "webp": {"quality": 95, "method": 6}, } def _auto_intensity_bounds( images: np.ndarray | list[np.ndarray] | tuple[np.ndarray, ...], lower_pct: float = 0.01, upper_pct: float = 99.99, ) -> tuple[float, float]: """Compute robust lower/upper bounds for scaling to 8-bit.""" if isinstance(images, np.ndarray): finite = np.asarray(images, dtype=np.float32) finite = finite[np.isfinite(finite)] else: finite_chunks = [] for image in images: image = np.asarray(image, dtype=np.float32) finite_chunks.append(image[np.isfinite(image)]) finite_chunks = [chunk for chunk in finite_chunks if chunk.size] finite = np.concatenate(finite_chunks) if finite_chunks else np.array([]) if finite.size == 0: return 0.0, 255.0 upper_candidates = finite[finite < 1e9] if upper_candidates.size == 0: upper_candidates = finite lower = float(np.percentile(finite, lower_pct)) upper = float(np.percentile(upper_candidates, upper_pct)) if not np.isfinite(lower) or not np.isfinite(upper): return 0.0, 255.0 if lower == upper: upper = lower + 1.0 return lower, upper
def clean_angle_key(angle_key):
    """Convert an angle key such as ``'90.00000009(1)'`` to a float.

    Args:
        angle_key: Either a number, or a string holding a number optionally
            followed by a parenthesised suffix (e.g. ``"(1)"``).

    Returns:
        The angle as a float. A ``float`` input is returned unchanged.

    Raises:
        ValueError: If the text left after removing the parenthesised parts
            is not a valid float literal.
    """
    if isinstance(angle_key, float):
        return angle_key  # already clean
    if not isinstance(angle_key, str):
        # Integers and other numeric scalars have no suffix to strip; passing
        # them to ``re.sub`` would raise TypeError.
        return float(angle_key)
    cleaned = re.sub(r"\(.*?\)", "", angle_key)  # remove '(1)' etc.
    return float(cleaned)
def _save_kwargs_for_format(fmt: str) -> dict[str, Any]:
    """Look up the Pillow save options registered for an image format.

    The lookup is case-insensitive and ignores leading dots. Unknown formats
    yield an empty dict.
    """
    key = fmt.lower().lstrip(".")
    return SAVE_KWARGS.get(key, {})


def _resize_preserve_aspect(img: Image.Image, target_size: int) -> Image.Image:
    """Shrink ``img`` so its longest side is at most ``target_size`` pixels.

    Images that already fit are returned as-is. Otherwise both sides are
    scaled by the same factor (never below 1 pixel) using LANCZOS resampling.
    """
    longest_side = max(img.size)
    if longest_side <= target_size:
        return img

    factor = target_size / longest_side
    width, height = img.size
    fitted_size = (
        max(1, int(round(width * factor))),
        max(1, int(round(height * factor))),
    )
    return img.resize(fitted_size, resample=Image.Resampling.LANCZOS)


def _projection_index_from_data_url(data_url) -> int | None:
    """Return the projection index encoded in a DataUrl-like object's slice.

    ``data_url.data_slice()`` may be ``None`` (no index), a scalar (the index
    itself) or a sequence whose first element is used when it is a scalar.
    ``None`` is returned in every other case.
    """
    selection = data_url.data_slice()
    if selection is None:
        return None
    if np.isscalar(selection):
        return int(selection)

    head = selection[:1]
    if head and np.isscalar(head[0]):
        return int(head[0])
    return None


def _normalize_projections_by_beam_intensity(
    projections: list[np.ndarray] | np.ndarray,
    beam_intensities: list[float] | np.ndarray,
) -> np.ndarray:
    """Rescale each projection by ``mean(valid intensities) / its intensity``.

    Only projections whose intensity is finite and strictly positive are
    rescaled; the others are copied unchanged. The result is always a new
    float32 array.
    """
    scaled = np.asarray(projections, dtype=np.float32).copy()
    monitor = np.asarray(beam_intensities, dtype=np.float32)

    usable = np.isfinite(monitor) & (monitor > 0.0)
    if not np.any(usable):
        return scaled

    reference = float(np.mean(monitor[usable]))
    # ``[usable, None, None]`` adds two axes so the per-projection gain
    # broadcasts over each 2D frame.
    scaled[usable] *= reference / monitor[usable, None, None]
    return scaled
class ProjectionsGalleryInputModel(BaseInputModel):
    """Input parameters of the ``BuildProjectionsGallery`` task."""

    # NX file read by the task through its "entry0000" entry.
    nx_path: str = Field(..., description="Path to the input NX file.")
    # Read by ``get_darks_from_h5`` (default group: "entry0000/darks").
    reduced_darks_path: str = Field(
        ..., description="Path to the reduced dark frames HDF5 file."
    )
    # Read by ``get_flats_from_h5`` (default group: "entry0000/flats").
    reduced_flats_path: str = Field(
        ..., description="Path to the reduced flat frames HDF5 file."
    )
    # When this is not a tuple, bounds are derived from the corrected
    # projections with ``_auto_intensity_bounds``.
    bounds: tuple[float, float] | None = Field(
        None,
        description=(
            "Intensity bounds (min, max) for image normalization. "
            "If not provided, robust defaults are computed automatically."
        ),
    )
    # Target angles are spaced by this step between the smallest and largest
    # available angle; the closest available projection is used for each.
    angle_step: float = Field(
        90.0,
        description=(
            "Angular step in degrees for selecting projections to include in the gallery."
        ),
    )
    output_format: Literal["jpg", "png", "jpeg", "webp"] = Field(
        "jpg", description="Image format for gallery images (e.g., 'jpg', 'png')."
    )
    # When False, an OSError is raised if a target image already exists.
    overwrite: bool = Field(
        True, description="Whether to overwrite existing gallery images."
    )
    # Applies to the "_large" image; the preview is always limited to 200 px.
    image_size: int = Field(
        1000,
        description=(
            "Maximum size (in pixels) for the largest dimension of gallery images. "
            "Images larger than this will be downsampled."
        ),
    )
class ProjectionsGalleryOutputModel(BaseOutputModel):
    """Outputs of the ``BuildProjectionsGallery`` task."""

    # Set by the task to the parent directory of the input NX file.
    processed_data_dir: str = Field(
        ..., description="Directory containing the processed data."
    )
    # Directory the gallery images and the GIF were written to.
    gallery_path: str = Field(..., description="Path to the created gallery directory.")
class SlicesGalleryInputModel(BaseInputModel):
    """Input parameters of the ``BuildSlicesGallery`` task."""

    # ``BuildSlicesGallery._load_slice`` handles ".h5"/".hdf5" (dataset
    # "entry0000/reconstruction/results/data"), ".edf" (via fabio) and any
    # other file Pillow can open.
    reconstructed_slice_path: str = Field(
        ..., description="Path to the reconstructed slice file."
    )
    # When this is not a tuple, bounds are derived from the loaded slice with
    # ``_auto_intensity_bounds``.
    bounds: tuple[float, float] | None = Field(
        None,
        description=(
            "Intensity bounds (min, max) for image normalization. "
            "If not provided, robust defaults are computed automatically."
        ),
    )
    output_format: Literal["jpg", "png", "jpeg", "webp"] = Field(
        "jpg", description="Image format for gallery images (e.g., 'jpg', 'png')."
    )
    # When False, an OSError is raised if a target image already exists.
    overwrite: bool = Field(
        True, description="Whether to overwrite existing gallery images."
    )
    # Applies to the "_large" image; the preview is always limited to 200 px.
    image_size: int = Field(
        1000,
        description=(
            "Maximum size (in pixels) for the largest dimension of gallery images. "
            "Images larger than this will be downsampled."
        ),
    )
class SlicesGalleryOutputModel(BaseOutputModel):
    """Outputs of the ``BuildSlicesGallery`` task."""

    # Set by the task to the parent directory of the reconstructed slice.
    processed_data_dir: str = Field(
        ..., description="Directory containing the processed data."
    )
    # The "gallery" sub-directory of ``processed_data_dir``.
    gallery_path: str = Field(..., description="Path to the created gallery directory.")
    # Path of the preview image; the "_large" variant is written next to it.
    gallery_image_path: str = Field(
        ..., description="Path to the created gallery image."
    )
def _prepare_gallery_image(
    image: np.ndarray,
    bounds: tuple[float, float],
    target_size: int,
) -> Image.Image:
    """Turn a 2D array into an 8-bit grayscale Pillow image.

    Args:
        image: 2D array, or 3D array whose first axis has length 1.
        bounds: ``(lower, upper)`` intensities mapped to 0 and 255.
        target_size: Maximum length, in pixels, of the longest image side.

    Returns:
        A mode "L" image, downsized when larger than ``target_size``.

    Raises:
        ValueError: If ``image`` is neither 2D nor a single-frame 3D stack.
    """
    is_single_frame_stack = image.ndim == 3 and image.shape[0] == 1
    if is_single_frame_stack:
        image = image.reshape(image.shape[1:])
    elif image.ndim != 2:
        raise ValueError(f"Only 2D grayscale images are handled. Got {image.shape}")

    low, high = bounds

    # Non-finite pixels are mapped onto the bounds first so they do not
    # produce an all-black output.
    cleaned = np.nan_to_num(image, nan=low, posinf=high, neginf=low)

    # Clamp to [low, high], then stretch that range over 0..255.
    shifted = np.clip(cleaned, low, high) - low
    if high != low:
        shifted = shifted * (255.0 / (high - low))

    pixels = np.clip(shifted, 0, 255).astype(np.uint8)
    return _resize_preserve_aspect(Image.fromarray(pixels, mode="L"), target_size)


def _save_to_gallery(
    output_file_name: str | Path,
    image: np.ndarray,
    bounds: tuple[float, float],
    overwrite: bool = True,
    image_size: int = 1000,
    output_format: str | None = None,
) -> None:
    """Write a gallery image and its preview to disk.

    Two files are produced:

    - ``<stem>_large<suffix>``: the image limited to ``image_size`` pixels;
    - ``output_file_name``: a preview limited to 200 pixels.

    Save options are selected from ``output_format``, falling back to the
    file suffix and finally to "png".

    Raises:
        OSError: If ``overwrite`` is False and either file already exists.
    """
    preview_path = Path(output_file_name)
    large_image = _prepare_gallery_image(image, bounds, image_size)

    fmt = output_format or preview_path.suffix.lstrip(".") or "png"
    save_options = _save_kwargs_for_format(str(fmt))

    large_path = preview_path.with_name(
        f"{preview_path.stem}_large{preview_path.suffix}"
    )

    if not overwrite and any(path.exists() for path in (preview_path, large_path)):
        raise OSError(f"File already exists ({preview_path})")

    large_image.save(str(large_path), **save_options)
    preview = _resize_preserve_aspect(large_image, target_size=200)
    preview.save(str(preview_path), **save_options)
class VolumeGalleryInputModel(BaseInputModel):
    """Input parameters of the ``BuildVolumeGallery`` task."""

    # ``BuildVolumeGallery._load_gallery_slices`` only accepts files with a
    # ".tif" or ".tiff" suffix.
    reconstructed_volume_path: str = Field(
        ..., description="Path to the reconstructed 3D TIFF volume file."
    )
    # When this is not a tuple, bounds are derived from all selected slices
    # with ``_auto_intensity_bounds``.
    bounds: tuple[float, float] | None = Field(
        None,
        description=(
            "Intensity bounds (min, max) for image normalization. "
            "If not provided, robust defaults are computed automatically."
        ),
    )
    output_format: Literal["jpg", "png", "jpeg", "webp"] = Field(
        "jpg", description="Image format for gallery images (e.g., 'jpg', 'png')."
    )
    # When False, an OSError is raised if a target image already exists.
    overwrite: bool = Field(
        True, description="Whether to overwrite existing gallery images."
    )
    # Applies to the "_large" image; the preview is always limited to 200 px.
    image_size: int = Field(
        1000,
        description=(
            "Maximum size (in pixels) for the largest dimension of gallery images. "
            "Images larger than this will be downsampled."
        ),
    )
class VolumeGalleryOutputModel(BaseOutputModel):
    """Outputs of the ``BuildVolumeGallery`` task."""

    # Set by the task to the parent directory of the reconstructed volume.
    processed_data_dir: str = Field(
        ..., description="Directory containing the processed data."
    )
    # Gallery directory resolved through ``TomoPath``.
    gallery_path: str = Field(..., description="Path to the created gallery directory.")
    # One preview path per saved slice, in "xy", "xz", "yz" order.
    gallery_image_paths: list[str] = Field(
        ..., description="Paths to the created gallery images."
    )
class BuildProjectionsGallery(  # type: ignore[call-arg]
    Task,
    input_model=ProjectionsGalleryInputModel,
    output_model=ProjectionsGalleryOutputModel,
):
    """Build gallery images and an animated GIF from NXtomo projections.

    Projections are selected at regular angular steps, flat-field corrected,
    normalized by the beam monitor intensity, scaled to 8 bits and saved.

    NOTE(review): ``run`` calls ``self.get_gallery_dir`` and
    ``self.get_gallery_file_path``, which are not defined in the visible part
    of this class -- confirm where they are provided.
    """

    def run(self):
        """Create the gallery images and the GIF for the input NX file.

        Sets the ``processed_data_dir`` output to the NX file's parent
        directory and ``gallery_path`` to the gallery directory.
        """
        self.gallery_output_format = str(self.inputs.output_format).lower()
        self.gallery_overwrite = self.inputs.overwrite
        self.gallery_image_size = int(self.inputs.image_size)
        angle_step = self.inputs.angle_step

        # Use the directory of the NX file as the processed data directory.
        nx_path = Path(self.inputs.nx_path)
        processed_data_dir = nx_path.parent
        gallery_dir = self.get_gallery_dir(processed_data_dir)
        os.makedirs(gallery_dir, exist_ok=True)

        # Open the NXtomoScan object.
        self.nxtomoscan = NXtomoScan(str(nx_path), entry="entry0000")

        angles, corrected_projections, bounds = self._get_gallery_projections(
            angle_step=angle_step,
            bounds=self.inputs.bounds,
        )

        for angle, projection in zip(angles, corrected_projections):
            gallery_file_path = self.get_gallery_file_path(gallery_dir, nx_path, angle)
            Path(gallery_file_path).parent.mkdir(parents=True, exist_ok=True)
            # Process the image and save it in the gallery.
            _save_to_gallery(
                gallery_file_path,
                projection,
                bounds,
                overwrite=self.gallery_overwrite,
                image_size=self.gallery_image_size,
                output_format=self.gallery_output_format,
            )

        # The GIF receives the user-provided bounds (possibly None), so its
        # automatic bounds are computed on its own set of projections.
        self._save_projections_gif(gallery_dir, nx_path, self.inputs.bounds)

        self.outputs.processed_data_dir = str(processed_data_dir)
        self.outputs.gallery_path = str(gallery_dir)

    @staticmethod
    def _load_last_reduced_frame(
        file_path: str, data_path: str
    ) -> dict[int, np.ndarray]:
        """Read the last dataset of an HDF5 group as ``{int(name): data}``.

        Only the last entry, in the group's iteration order, is read.

        Raises:
            ValueError: If the group is empty, or if the dataset name cannot
                be converted to an integer.
        """
        with h5py.File(file_path, "r") as h5f:
            group = h5f[data_path]
            names = list(group)
            if not names:
                raise ValueError(
                    f"No reduced frames found under '{data_path}' in {file_path}"
                )
            last_name = names[-1]
            return {int(last_name): group[last_name][()]}

    def get_flats_from_h5(
        self, reduced_flat_path: str, data_path: str = "entry0000/flats"
    ) -> dict[int, np.ndarray]:
        """Load the reduced flat frame from an HDF5 file.

        Args:
            reduced_flat_path: Path to the HDF5 file holding the reduced flats.
            data_path: HDF5 group whose datasets are named by frame index.

        Returns:
            A single-entry dict ``{index: frame}`` holding the last dataset
            of the group.

        Raises:
            ValueError: If the group contains no dataset.
        """
        return self._load_last_reduced_frame(reduced_flat_path, data_path)

    def get_darks_from_h5(
        self, reduced_dark_path: str, data_path: str = "entry0000/darks"
    ) -> dict[int, np.ndarray]:
        """Load the reduced dark frame from an HDF5 file.

        Args:
            reduced_dark_path: Path to the HDF5 file holding the reduced darks.
            data_path: HDF5 group whose datasets are named by frame index.

        Returns:
            A single-entry dict ``{index: frame}`` holding the last dataset
            of the group.

        Raises:
            ValueError: If the group contains no dataset.
        """
        return self._load_last_reduced_frame(reduced_dark_path, data_path)

    def apply_flat_field_correction(self, projections):
        """Apply flat-field correction to a stack of projections.

        Args:
            projections: Sequence of equally shaped 2D projections.

        Returns:
            The value returned by ``FlatField.normalize_radios`` for the
            float32 projection stack.

        Raises:
            ValueError: If ``projections`` is empty or is not a stack of 2D
                frames.
        """
        projections = np.asarray(projections, dtype=np.float32)
        if projections.ndim != 3 or len(projections) == 0:
            raise ValueError(
                "Expected a non-empty stack of 2D projections. "
                f"Got shape {projections.shape}"
            )

        reduced_darks = self.get_darks_from_h5(self.inputs.reduced_darks_path)
        reduced_flats = self.get_flats_from_h5(self.inputs.reduced_flats_path)

        x, y = projections[0].shape
        radios_shape = (len(projections), x, y)
        flat_field = FlatField(
            radios_shape=radios_shape, flats=reduced_flats, darks=reduced_darks
        )

        # Scope the filter to this call so the process-wide warning filters
        # are left untouched.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message=".*encountered in divide",
                category=RuntimeWarning,
                module="nabu.preproc.flatfield",
            )
            corrected_projections = flat_field.normalize_radios(projections)
        return corrected_projections

    def get_proj_from_data_url(self, data_url) -> np.ndarray:
        """Load the data referenced by a DataUrl object as float32.

        The URL's slice is applied when it has one; otherwise the whole
        dataset is read.
        """
        data_slice = data_url.data_slice()
        with h5py.File(data_url.file_path(), "r") as h5f:
            data = h5f[data_url.data_path()]
            if data_slice is not None:
                return data[data_slice].astype(np.float32)
            return data[()].astype(np.float32)

    def _get_gallery_projections(
        self,
        angle_step: float,
        bounds: tuple[float, float] | None,
    ) -> tuple[list[float], np.ndarray, tuple[float, float]]:
        """Load, correct and normalize the projections used for gallery outputs.

        Returns:
            ``(angles, corrected_projections, bounds)``. ``bounds`` is the
            given tuple, or automatic bounds computed on the corrected
            projections when the given value is not a tuple.
        """
        angles, projections, beam_intensities = self.get_projections_by_angle_step(
            angle_step
        )
        corrected_projections = self.apply_flat_field_correction(projections)
        corrected_projections = _normalize_projections_by_beam_intensity(
            corrected_projections, beam_intensities
        )
        if not isinstance(bounds, tuple):
            bounds = _auto_intensity_bounds(corrected_projections)
        return angles, corrected_projections, bounds

    def get_projections_by_angle_step(
        self, angle_step: float = 90
    ) -> tuple[list[float], list[np.ndarray], np.ndarray]:
        """Select projections at regularly spaced angles.

        Target angles start at the smallest available angle and advance by
        ``angle_step``. For each target the closest available projection is
        selected, and each projection is selected at most once.

        Args:
            angle_step: Angular step, in degrees, between target angles.

        Returns:
            ``(selected_angles, selected_projections, beam_intensities)``,
            all aligned on the same order.

        Raises:
            ValueError: If ``angle_step`` is not strictly positive or the
                scan exposes no projection angle.
        """
        if not angle_step > 0:
            raise ValueError(f"angle_step must be > 0. Got {angle_step}")

        # Map every available angle (as a float) to its data URL.
        angles_dict = self.nxtomoscan.get_proj_angle_url()
        angles_dict = {clean_angle_key(k): v for k, v in angles_dict.items()}
        if not angles_dict:
            raise ValueError("No projection angles found in the scan")
        all_angles = np.array(list(angles_dict))

        # Determine regularly spaced target angles within the available range.
        min_angle = np.min(all_angles)
        max_angle = np.max(all_angles)
        target_angles = np.arange(min_angle, max_angle + angle_step, angle_step)

        # For each target angle, find the closest available one.
        selected_angles = []
        used_indices = set()
        for target in target_angles:
            idx = np.argmin(np.abs(all_angles - target))
            if idx not in used_indices:  # avoid duplicates
                used_indices.add(idx)
                selected_angles.append(all_angles[idx])

        selected_data_urls = [angles_dict[angle] for angle in selected_angles]
        selected_projections = [
            self.get_proj_from_data_url(data_url) for data_url in selected_data_urls
        ]
        beam_intensities = self.get_beam_intensities(selected_data_urls)
        return selected_angles, selected_projections, beam_intensities

    def get_beam_intensities(self, data_urls: list[Any]) -> np.ndarray:
        """Return beam monitor intensities aligned with the selected projections.

        Values are read from ``entry0000/control/data`` of the input NX file.
        An entry is NaN when the monitor data cannot be read, the URL carries
        no projection index, or the index is out of range.
        """
        intensities = np.full(len(data_urls), np.nan, dtype=np.float32)
        nx_path = Path(self.inputs.nx_path)
        try:
            with h5py.File(str(nx_path), "r") as h5f:
                control_data = np.asarray(
                    h5f["entry0000/control/data"][()], dtype=np.float32
                )
        except Exception:
            # Best effort: without monitor data all intensities stay NaN and
            # the projections are left un-normalized.
            return intensities

        for i, data_url in enumerate(data_urls):
            proj_idx = _projection_index_from_data_url(data_url)
            if proj_idx is None:
                continue
            if 0 <= proj_idx < control_data.size:
                intensities[i] = control_data[proj_idx]
        return intensities

    def _save_projections_gif(
        self, gallery_dir: Path, nx_path: Path, bounds: tuple[float, float] | None
    ) -> None:
        """Save an animated GIF of projections taken every 10 degrees.

        The file is written to ``<gallery_dir>/<nx file stem>.gif`` with
        frames limited to 200 pixels. Nothing is written when no projection
        is selected.

        Raises:
            OSError: If overwriting is disabled and the GIF already exists.
        """
        gif_angle_step = 10.0
        gif_size = 200
        gif_fps = 12
        gif_duration = 1000  # milliseconds per second; divided by gif_fps below

        angles, corrected_projections, bounds = self._get_gallery_projections(
            angle_step=gif_angle_step,
            bounds=bounds,
        )
        if not angles:
            return

        frames = [
            _prepare_gallery_image(projection, bounds, target_size=gif_size)
            for projection in corrected_projections
        ]
        if not frames:
            return

        gif_path = gallery_dir / f"{nx_path.stem}.gif"
        # ``gallery_overwrite`` is set by ``run``; default to True otherwise.
        overwrite = getattr(self, "gallery_overwrite", True)
        if not overwrite and gif_path.exists():
            raise OSError(f"File already exists ({gif_path})")

        frames[0].save(
            gif_path,
            save_all=True,
            append_images=frames[1:],
            duration=int(round(gif_duration / gif_fps)),
            loop=0,
        )
class BuildSlicesGallery(  # type: ignore[call-arg]
    Task,
    input_model=SlicesGalleryInputModel,
    output_model=SlicesGalleryOutputModel,
):
    """Create two gallery images from a reconstructed slice.

    A ``_large`` image limited to ``image_size`` pixels (default 1000) and a
    preview limited to 200 pixels are written; both keep the aspect ratio of
    the slice.
    """

    def run(self):
        """Load the slice, scale it to 8 bits and save it to ``<slice dir>/gallery``."""
        inputs = self.inputs
        fmt = str(inputs.output_format)
        overwrite = bool(inputs.overwrite)
        image_size = int(inputs.image_size)
        bounds = inputs.bounds

        slice_path = Path(inputs.reconstructed_slice_path)
        if not slice_path.exists():
            raise FileNotFoundError(f"Reconstructed slice not found: {slice_path}")

        processed_data_dir = slice_path.parent
        gallery_dir = Path(processed_data_dir) / "gallery"
        os.makedirs(gallery_dir, exist_ok=True)

        slice_data = self._load_slice(slice_path)
        if not isinstance(bounds, tuple):
            # No explicit bounds: derive them from the slice itself.
            bounds = _auto_intensity_bounds(slice_data)

        out_path = gallery_dir / self.get_gallery_file_path(
            gallery_dir, slice_path, fmt
        )
        _save_to_gallery(
            out_path,
            slice_data,
            bounds,
            overwrite=overwrite,
            image_size=image_size,
            output_format=fmt,
        )

        outputs = self.outputs
        outputs.processed_data_dir = str(processed_data_dir)
        outputs.gallery_path = str(gallery_dir)
        outputs.gallery_image_path = str(out_path)

    @staticmethod
    def _load_slice(img_path: Path) -> np.ndarray:
        """Load a slice as a float32 array.

        Supported inputs, selected by file suffix:

        - ``.h5`` / ``.hdf5``: dataset ``entry0000/reconstruction/results/data``
          with length-1 axes removed;
        - ``.edf``: read with ``fabio`` (imported on demand);
        - anything else: opened with Pillow; images with 3 or 4 channels are
          reduced to the mean of their first three channels.

        Raises:
            RuntimeError: If an EDF file is given and ``fabio`` cannot be
                imported.
        """
        suffix = img_path.suffix.lower()

        if suffix in (".h5", ".hdf5"):
            with h5py.File(img_path, "r") as h5in:
                raw = h5in["entry0000/reconstruction/results/data"][:]
            return np.squeeze(raw).astype(np.float32)

        if suffix == ".edf":
            try:
                import fabio  # type: ignore
            except Exception as exc:
                raise RuntimeError(
                    "EDF support requires 'fabio' (pip install fabio)."
                ) from exc
            return fabio.open(str(img_path)).data.astype(np.float32)

        with Image.open(img_path) as im:
            pixels = np.array(im, dtype=np.float32)
        has_color_channels = pixels.ndim == 3 and pixels.shape[-1] in (3, 4)
        if has_color_channels:
            pixels = pixels[..., :3].mean(axis=-1)
        return pixels
class BuildVolumeGallery(  # type: ignore[call-arg]
    Task,
    input_model=VolumeGalleryInputModel,
    output_model=VolumeGalleryOutputModel,
):
    """Create gallery images from a reconstructed 3D TIFF volume.

    For each slicing direction ("xy", "xz", "yz"), slices located at 1/4, 2/4
    and 3/4 of the corresponding volume extent are saved to the gallery
    directory resolved through ``TomoPath``.
    """

    def run(self):
        """Extract the gallery slices of the volume and save them as images."""
        fmt = str(self.inputs.output_format)
        overwrite = bool(self.inputs.overwrite)
        image_size = int(self.inputs.image_size)
        bounds = self.inputs.bounds

        volume_path = Path(self.inputs.reconstructed_volume_path)
        if not volume_path.exists():
            raise FileNotFoundError(f"Reconstructed volume not found: {volume_path}")

        # The path template matched by the volume directory decides which
        # gallery location and preview template are used.
        volume_container = TomoPath(volume_path.parent)
        if volume_container.template_name == "volumes_custom_type_file":
            gallery_dir = Path(volume_container.volumes_custom_type_gallery_path)
            gallery_template_name = "volumes_custom_type_preview"
        else:
            gallery_dir = Path(volume_container.volumes_gallery_path)
            gallery_template_name = "volumes_preview"

        processed_data_dir = volume_path.parent
        os.makedirs(gallery_dir, exist_ok=True)

        slices_by_axis, axis_indices = self._load_gallery_slices(volume_path)

        # (direction, index) pairs in the order the images are written.
        selection = [
            (direction, index)
            for direction in ("xy", "xz", "yz")
            for index in axis_indices[direction]
        ]
        gallery_slices = [
            slices_by_axis[direction][index] for direction, index in selection
        ]
        if not isinstance(bounds, tuple):
            # Shared bounds so all previews use the same gray scale.
            bounds = _auto_intensity_bounds(gallery_slices)

        gallery_paths = []
        for (direction, index), slice_2d in zip(selection, gallery_slices):
            templated_path = Path(
                volume_container.replace_fields(
                    template_name=gallery_template_name,
                    tomo_slicing_direction=direction,
                    tomo_slice_number=index,
                    thumbnail_file_type=fmt,
                )
            )
            # Drop the dataset prefix and keep the reconstruction descriptor.
            short_name = re.sub(
                r"^.*?_(absorption|phase)_",
                r"\1_",
                templated_path.name,
                count=1,
            )
            out_path = templated_path.with_name(short_name)
            out_path.parent.mkdir(parents=True, exist_ok=True)
            _save_to_gallery(
                out_path,
                slice_2d,
                bounds,
                overwrite=overwrite,
                image_size=image_size,
                output_format=fmt,
            )
            gallery_paths.append(str(out_path))

        self.outputs.processed_data_dir = str(processed_data_dir)
        self.outputs.gallery_path = str(gallery_dir)
        self.outputs.gallery_image_paths = gallery_paths

    @staticmethod
    def _as_grayscale(frame: Image.Image) -> np.ndarray:
        """Convert a Pillow frame to a 2D float32 array.

        Frames with 3 or 4 channels are reduced to the mean of their first
        three channels.

        Raises:
            ValueError: If the resulting array is not 2D.
        """
        pixels = np.array(frame, dtype=np.float32)
        if pixels.ndim == 3 and pixels.shape[-1] in (3, 4):
            pixels = pixels[..., :3].mean(axis=-1)
        if pixels.ndim != 2:
            raise ValueError(
                f"Only 2D grayscale TIFF frames are handled. Got frame shape={pixels.shape}"
            )
        return pixels

    @classmethod
    def _load_gallery_slices(
        cls, volume_path: Path
    ) -> tuple[dict[str, dict[int, np.ndarray]], dict[str, list[int]]]:
        """Read the gallery slices of a multi-frame TIFF in a single pass.

        Returns:
            ``(slices_by_direction, axis_indices)`` where
            ``slices_by_direction[direction][index]`` is a 2D float32 array
            and ``axis_indices[direction]`` lists the slice indices located
            at 1/4, 2/4 and 3/4 of the extent (duplicates are possible for
            very small extents).

        Raises:
            ValueError: If the file suffix is not ``.tif``/``.tiff``, the file
                reports no frame, or a frame is not 2D grayscale.
        """
        if volume_path.suffix.lower() not in (".tif", ".tiff"):
            raise ValueError(
                f"Unsupported volume format for gallery: {volume_path}. Only TIFF3D is supported."
            )

        fractions = (0.25, 0.5, 0.75)

        def idxs(n: int) -> list[int]:
            # Clamp each fractional position to a valid index of the axis.
            positions = []
            for frac in fractions:
                positions.append(min(n - 1, max(0, int(n * frac))))
            return positions

        with Image.open(volume_path) as im:
            z_size = int(getattr(im, "n_frames", 1))
            if z_size <= 0:
                raise ValueError(f"No frames found in TIFF file: {volume_path}")

            # The first frame gives the in-plane dimensions.
            im.seek(0)
            y_size, x_size = cls._as_grayscale(im).shape

            axis_indices = {
                "xy": idxs(z_size),
                "xz": idxs(y_size),
                "yz": idxs(x_size),
            }
            wanted_z = sorted(set(axis_indices["xy"]))
            wanted_y = sorted(set(axis_indices["xz"]))
            wanted_x = sorted(set(axis_indices["yz"]))

            xy_slices: dict[int, np.ndarray] = {}
            xz_slices = {}
            for y_idx in wanted_y:
                xz_slices[y_idx] = np.empty((z_size, x_size), dtype=np.float32)
            yz_slices = {}
            for x_idx in wanted_x:
                yz_slices[x_idx] = np.empty((z_size, y_size), dtype=np.float32)

            # Each frame contributes one row to every xz / yz slice.
            for z_idx in range(z_size):
                im.seek(z_idx)
                frame_arr = cls._as_grayscale(im)
                if z_idx in wanted_z:
                    xy_slices[z_idx] = frame_arr.copy()
                for y_idx, xz_slice in xz_slices.items():
                    xz_slice[z_idx, :] = frame_arr[y_idx, :]
                for x_idx, yz_slice in yz_slices.items():
                    yz_slice[z_idx, :] = frame_arr[:, x_idx]

        return {"xy": xy_slices, "xz": xz_slices, "yz": yz_slices}, axis_indices