Source code for ewokstomo.tasks.dataportalupload

from __future__ import annotations
from typing import Any
import logging

from ewokscore import Task
from ewokscore.model import BaseInputModel
from ewokscore.model import BaseOutputModel
from pydantic import Field
from esrf_pathlib import ESRFPath
from pyicat_plus.client.main import IcatClient
from pyicat_plus.client import defaults
from icat_esrf_definitions.models import IcatDatasetParameters
from esrf_ontologies import technique

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Reconstruction option names whose Data Portal metadata key does not follow
# the default ``TOMOReconstruction_<option>`` naming.
RECON_KEY_MAP = {
    "angles": "TOMOReconstruction_angles_file",
}

# Dataset names (compared lower-cased) that are also recorded as the workflow type.
WORKFLOW_TYPES = ("slices", "projections", "volumes")
# Metadata key under which the workflow type is stored.
WORKFLOW_TYPE_METADATA_KEY = "Workflow_type"
# Field names reported by ``IcatDatasetParameters``; metadata keys outside this
# set are dropped before an upload.
ICAT_METADATA_FIELD_NAMES = set(IcatDatasetParameters.icat_field_names())


def _remove_invalid_icat_metadata_keys(metadata: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of ``metadata`` restricted to keys known to ICAT.

    Keys that are not in ``ICAT_METADATA_FIELD_NAMES`` are left out of the
    returned dictionary and reported, sorted, in a single warning. The input
    mapping itself is never modified.
    """
    accepted: dict[str, Any] = {}
    rejected = []
    # One pass: partition the entries while preserving their original order.
    for name, value in dict(metadata).items():
        if name in ICAT_METADATA_FIELD_NAMES:
            accepted[name] = value
        else:
            rejected.append(name)

    if rejected:
        rejected.sort()
        logger.warning(
            "ICAT metadata keys are not compatible and will not be sent: %s",
            rejected,
        )
    return accepted


def _build_icat_payload(
    folder_path: str,
    metadata_in: dict[str, Any] | None,
    dataset_in: str | None = None,
) -> dict[str, Any]:
    """Build the keyword arguments for an ICAT ``store_processed_data`` call.

    :param folder_path: path of the processed dataset folder.
    :param metadata_in: optional metadata; it is copied, never modified.
    :param dataset_in: dataset name overriding the one parsed from the path.
    :returns: dictionary with the keys ``beamline``, ``proposal``, ``dataset``,
        ``path``, ``raw`` and ``metadata``.
    :raises ValueError: when the path schema is unknown or the path is not a
        processed-data path.
    """
    esrf_path = ESRFPath(folder_path)

    # Guard clauses: only recognised processed-data paths can be uploaded.
    if esrf_path.schema_name is None:
        raise ValueError(f"Unknown ESRF path schema: {folder_path}")
    if esrf_path.data_type.value != "processed":
        raise ValueError(f"Not a PROCESSED_DATA path: {folder_path}")

    sample_name = esrf_path.collection
    path_dataset = esrf_path.dataset
    dataset = dataset_in if dataset_in is not None else path_dataset

    # The sample name defaults to the collection unless the caller provided one.
    metadata = {} if metadata_in is None else dict(metadata_in)
    metadata.setdefault("Sample_name", sample_name)

    # A dataset named after a workflow type also sets the workflow type.
    workflow_type = str(dataset).lower()
    if workflow_type in WORKFLOW_TYPES:
        metadata[WORKFLOW_TYPE_METADATA_KEY] = workflow_type

    payload: dict[str, Any] = {
        "beamline": esrf_path.beamline,
        "proposal": esrf_path.proposal,
        "dataset": dataset,
        "path": str(esrf_path),
        "raw": [str(esrf_path.raw_dataset_path)],
        "metadata": metadata,
    }
    return payload


def _build_dataportal_metadata(
    processing_options: dict[str, Any] | None,
) -> dict[str, Any]:
    """Convert processing options into Data Portal metadata.

    The ``"reconstruction"`` and ``"phase"`` sections of ``processing_options``
    are flattened into ``TOMOReconstruction_*`` and ``TOMOReconstructionPhase_*``
    keys. ``None`` values and strings that are empty after stripping are
    skipped, as are a fixed set of excluded option names.

    Special cases:

    * ``reconstruction.voxel_size_cm`` must be a 3-item list/tuple ordered
      ``(z, y, x)``; it is converted from cm to um and split into three keys.
    * ``phase.distance_m`` is converted from m to mm.

    Values that cannot be converted to ``float`` in those two cases are
    skipped with a warning instead of raising.

    The technique metadata is "XRCT" when no phase method is set (missing,
    empty or ``"none"``), otherwise "XPCT".

    :param processing_options: options dictionary; anything that is not a
        ``dict`` is treated as empty.
    :returns: the metadata dictionary.
    """
    if not isinstance(processing_options, dict):
        processing_options = {}

    metadata: dict[str, Any] = {WORKFLOW_TYPE_METADATA_KEY: "slices"}

    reconstruction_values = processing_options.get("reconstruction", {})
    excluded_recon_keys = {
        "cor_estimated_auto",
        "crop_filtered_data",
        "hbp_legs",
        "hbp_reduction_steps",
        "iterations",
        "outer_circle_value",
        "position",
    }
    if isinstance(reconstruction_values, dict):
        for key, value in reconstruction_values.items():
            if key in excluded_recon_keys:
                continue
            value = _normalize_option_value(value)
            if value is None:
                continue

            if key != "voxel_size_cm":
                target = RECON_KEY_MAP.get(key, f"TOMOReconstruction_{key}")
                metadata[target] = value
                continue

            if not (isinstance(value, (list, tuple)) and len(value) == 3):
                logger.warning(
                    "Ignoring reconstruction voxel_size_cm=%r: "
                    "expected three values",
                    value,
                )
                continue
            try:
                # Values are ordered (z, y, x); 1e4 converts cm -> um.
                voxel_size_z, voxel_size_y, voxel_size_x = (
                    float(item) * 1e4 for item in value
                )
            except (TypeError, ValueError):
                logger.warning(
                    "Ignoring reconstruction voxel_size_cm=%r: "
                    "values are not numeric",
                    value,
                )
                continue
            metadata["TOMOReconstruction_voxel_size_x"] = voxel_size_x
            metadata["TOMOReconstruction_voxel_size_y"] = voxel_size_y
            metadata["TOMOReconstruction_voxel_size_z"] = voxel_size_z

    phase_values = processing_options.get("phase", {})

    phase_method = None
    if isinstance(phase_values, dict):
        excluded_phase_keys = {
            "distance_cm",
            "energy_kev",
            "pixel_size_microns",
            "pixel_size_m",
        }
        phase_method = phase_values.get("method")
        if isinstance(phase_method, str):
            phase_method = phase_method.strip()
            # "none" (any case) means that no phase retrieval was applied.
            if phase_method.lower() == "none":
                phase_method = ""

        for key, value in phase_values.items():
            if key in excluded_phase_keys:
                continue
            value = _normalize_option_value(value)
            if value is None:
                continue

            if key != "distance_m":
                metadata[f"TOMOReconstructionPhase_{key}"] = value
                continue

            try:
                # float() also accepts numeric strings, which a bare
                # multiplication would reject with a TypeError.
                distance_mm = float(value) * 1000.0  # m -> mm
            except (TypeError, ValueError):
                logger.warning(
                    "Ignoring phase distance_m=%r: value is not numeric",
                    value,
                )
                continue
            metadata["TOMOReconstructionPhase_detector_sample_distance"] = (
                distance_mm
            )

    if phase_method in (None, ""):
        technique_metadata = technique.get_technique_metadata("XRCT")
    else:
        technique_metadata = technique.get_technique_metadata("XPCT")
    metadata.update(technique_metadata.get_dataset_metadata())

    return metadata


def _normalize_option_value(value: Any) -> Any:
    """Strip string values; return ``None`` when there is nothing to keep.

    ``None`` and strings that are empty after stripping both map to ``None``
    so that callers can skip them with a single check. Any other value is
    returned unchanged.
    """
    if isinstance(value, str):
        value = value.strip()
        if value == "":
            return None
    return value


# Inputs of the BuildDataPortalMetadata task.
class BuildDataPortalMetadataInputModel(BaseInputModel):
    # Optional: defaults to None when no processing options are supplied.
    processing_options: dict[str, Any] | None = Field(
        None,
        description="Nabu processing options dictionary to convert into Data Portal metadata.",
    )
# Outputs of the BuildDataPortalMetadata task.
class BuildDataPortalMetadataOutputModel(BaseOutputModel):
    # Required field (no default).
    dataportal_metadata: dict[str, Any] = Field(
        ...,
        description="Generated Data Portal metadata dictionary.",
    )
# Inputs of the DataPortalUpload task.
class DataPortalUploadInputModel(BaseInputModel):
    # Required: the only input without a default.
    process_folder_path: str = Field(
        ...,
        description="Path to the processed dataset folder to upload.",
    )
    metadata: dict[str, Any] | None = Field(
        None,
        description="Optional metadata dictionary to include in the upload.",
    )
    dry_run: bool = Field(
        False,
        description="If True, simulate the upload without performing it.",
    )
    dataset: str | None = Field(
        None,
        description="Optional dataset name to use for the upload.",
    )
class BuildDataPortalMetadata(  # type: ignore[call-arg]
    Task,
    input_model=BuildDataPortalMetadataInputModel,
    output_model=BuildDataPortalMetadataOutputModel,
):
    """(ESRF-only) Convert Nabu processing options into Data Portal metadata."""

    def run(self):
        """Build the metadata and store it in ``outputs.dataportal_metadata``.

        Anything other than a ``dict`` is reported with a warning and replaced
        by ``None`` before the conversion.
        """
        options = self.inputs.processing_options
        if isinstance(options, dict):
            metadata = _build_dataportal_metadata(options)
        else:
            logger.warning("Invalid processing_options provided; using empty metadata")
            metadata = _build_dataportal_metadata(None)
        self.outputs.dataportal_metadata = metadata
class DataPortalUpload(  # type: ignore[call-arg]
    Task,
    input_model=DataPortalUploadInputModel,
):
    """(ESRF-only) Upload a processed dataset folder to the Data Portal using pyicat_plus."""

    # Factory for the ICAT client; kept as a class attribute so that it can be
    # replaced without touching ``run``.
    icat_client_factory = staticmethod(
        lambda: IcatClient(metadata_urls=defaults.METADATA_BROKERS)
    )

    def run(self):
        """Register the processed dataset folder, or log what would be sent.

        The upload is best effort: no exception leaves this method. A
        ``ValueError`` is logged as a skipped upload, any other exception as an
        error together with its traceback. On success ``self.icat_status`` is
        set to ``"stored"``; with ``dry_run`` nothing is sent and the payload
        is only logged.
        """
        folder_path = self.inputs.process_folder_path
        metadata_in = self.inputs.metadata
        dataset_in = self.inputs.dataset
        dry_run = self.inputs.dry_run

        try:
            payload = _build_icat_payload(folder_path, metadata_in, dataset_in)
            payload["metadata"] = _remove_invalid_icat_metadata_keys(
                payload["metadata"]
            )

            if dry_run:
                logger.info(
                    "Dry-run: would store_processed_data "
                    "proposal=%s beamline=%s dataset=%s path=%s raw=%s metadata=%s",
                    payload["proposal"],
                    payload["beamline"],
                    payload["dataset"],
                    payload["path"],
                    payload["raw"],
                    payload["metadata"],
                    extra={"dp_payload": payload},
                )
                return

            client = self.icat_client_factory()
            try:
                client.store_processed_data(**payload)
                self.icat_status = "stored"
            finally:
                # Always release the client; a failing disconnect must not
                # mask the outcome of the upload itself.
                try:
                    client.disconnect()
                except Exception:
                    logger.warning(
                        "Failed to disconnect ICAT client", exc_info=True
                    )
        except ValueError as e:
            logger.warning("DataPortalUpload skipped: %s", e)
        except Exception as e:
            # Keep the traceback: the message alone rarely identifies the
            # failing call.
            logger.warning("Error in DataPortalUpload: %s", e, exc_info=True)