from __future__ import annotations
from typing import Any
import logging
from ewokscore import Task
from ewokscore.model import BaseInputModel
from ewokscore.model import BaseOutputModel
from pydantic import Field
from esrf_pathlib import ESRFPath
from pyicat_plus.client.main import IcatClient
from pyicat_plus.client import defaults
from icat_esrf_definitions.models import IcatDatasetParameters
from esrf_ontologies import technique
logger = logging.getLogger(__name__)

# Reconstruction option keys whose ICAT field name does not follow the
# default ``TOMOReconstruction_<key>`` naming used for all other keys.
RECON_KEY_MAP = {"angles": "TOMOReconstruction_angles_file"}

# Dataset names that double as the workflow type recorded in the metadata,
# stored under WORKFLOW_TYPE_METADATA_KEY.
WORKFLOW_TYPES = ("slices", "projections", "volumes")
WORKFLOW_TYPE_METADATA_KEY = "Workflow_type"

# Whitelist of metadata field names accepted by ICAT for a dataset; any other
# key is stripped before upload.
ICAT_METADATA_FIELD_NAMES = set(IcatDatasetParameters.icat_field_names())
def _remove_invalid_icat_metadata_keys(metadata: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of *metadata* restricted to keys accepted by ICAT.

    Keys missing from ``ICAT_METADATA_FIELD_NAMES`` are dropped and reported,
    sorted, in a single warning. The caller's mapping is left untouched.
    """
    source = dict(metadata)
    rejected = sorted(
        name for name in source if name not in ICAT_METADATA_FIELD_NAMES
    )
    if rejected:
        logger.warning(
            "ICAT metadata keys are not compatible and will not be sent: %s",
            rejected,
        )
    # Insertion order of the surviving keys is preserved.
    return {
        name: value
        for name, value in source.items()
        if name in ICAT_METADATA_FIELD_NAMES
    }
def _build_icat_payload(
    folder_path: str,
    metadata_in: dict[str, Any] | None,
    dataset_in: str | None = None,
) -> dict[str, Any]:
    """Build the keyword arguments for the ICAT ``store_processed_data`` call.

    :param folder_path: ESRF PROCESSED_DATA folder to register.
    :param metadata_in: Extra dataset metadata. It is copied, never modified.
        ``Sample_name`` defaults to the collection parsed from the path.
    :param dataset_in: Dataset name; when ``None`` the dataset parsed from
        the path is used.
    :returns: Dict with ``beamline``, ``proposal``, ``dataset``, ``path``,
        ``raw`` (one-element list) and ``metadata`` entries.
    :raises ValueError: If the path matches no known ESRF schema or is not a
        PROCESSED_DATA path.
    """
    esrf_path = ESRFPath(folder_path)
    if esrf_path.schema_name is None:
        raise ValueError(f"Unknown ESRF path schema: {folder_path}")
    if esrf_path.data_type.value != "processed":
        raise ValueError(f"Not a PROCESSED_DATA path: {folder_path}")

    # Both components are read from the path even when dataset_in overrides one.
    collection_name = esrf_path.collection
    parsed_dataset = esrf_path.dataset
    dataset = parsed_dataset if dataset_in is None else dataset_in

    metadata = {} if metadata_in is None else dict(metadata_in)
    metadata.setdefault("Sample_name", collection_name)

    # A dataset named after a workflow type forces that type in the metadata.
    workflow_type = str(dataset).lower()
    if workflow_type in WORKFLOW_TYPES:
        metadata[WORKFLOW_TYPE_METADATA_KEY] = workflow_type

    beamline = esrf_path.beamline
    proposal = esrf_path.proposal
    path_text = str(esrf_path)
    raw_paths = [str(esrf_path.raw_dataset_path)]
    return dict(
        beamline=beamline,
        proposal=proposal,
        dataset=dataset,
        path=path_text,
        raw=raw_paths,
        metadata=metadata,
    )
# Reconstruction options that are never forwarded to the Data Portal.
_EXCLUDED_RECON_KEYS = frozenset(
    {
        "cor_estimated_auto",
        "crop_filtered_data",
        "hbp_legs",
        "hbp_reduction_steps",
        "iterations",
        "outer_circle_value",
        "position",
    }
)
# Phase options that are never forwarded to the Data Portal.
_EXCLUDED_PHASE_KEYS = frozenset(
    {
        "distance_cm",
        "energy_kev",
        "pixel_size_microns",
        "pixel_size_m",
    }
)


def _clean_option_value(value: Any) -> Any:
    """Strip string values; return ``None`` when the option should be skipped."""
    if isinstance(value, str):
        value = value.strip()
        if not value:
            return None
    return value


def _voxel_size_metadata(value: Any) -> dict[str, float]:
    """Convert a ``(z, y, x)`` voxel size in cm to per-axis metadata in um.

    Returns an empty dict when *value* is not a 3-element list/tuple of
    float-convertible items.
    """
    if not isinstance(value, (list, tuple)) or len(value) != 3:
        return {}
    try:
        size_z, size_y, size_x = [float(v) * 1e4 for v in value]  # cm -> um
    except (TypeError, ValueError):
        return {}
    return {
        "TOMOReconstruction_voxel_size_x": size_x,
        "TOMOReconstruction_voxel_size_y": size_y,
        "TOMOReconstruction_voxel_size_z": size_z,
    }


def _build_dataportal_metadata(
    processing_options: dict[str, Any] | None,
) -> dict[str, Any]:
    """Convert processing options into Data Portal metadata.

    :param processing_options: Mapping that may contain ``"reconstruction"``
        and ``"phase"`` sub-dicts. Anything that is not a dict is treated as
        empty.
    :returns: Metadata dict with these entries:
        - ``Workflow_type`` is always ``"slices"``.
        - Reconstruction options are stored as ``TOMOReconstruction_*``.
        - Phase options are stored as ``TOMOReconstructionPhase_*``.
        - The dataset metadata of the selected technique is merged in.

    Details:
        - Excluded keys are skipped, as are ``None`` values and blank strings.
        - Unconvertible voxel sizes and detector distances are skipped rather
          than raising.
        - The technique is XPCT when a phase method is set, otherwise XRCT.
    """
    if not isinstance(processing_options, dict):
        processing_options = {}
    metadata: dict[str, Any] = {WORKFLOW_TYPE_METADATA_KEY: "slices"}

    reconstruction_values = processing_options.get("reconstruction", {})
    if isinstance(reconstruction_values, dict):
        for key, raw_value in reconstruction_values.items():
            if key in _EXCLUDED_RECON_KEYS:
                continue
            value = _clean_option_value(raw_value)
            if value is None:
                continue
            if key == "voxel_size_cm":
                metadata.update(_voxel_size_metadata(value))
            else:
                target = RECON_KEY_MAP.get(key, f"TOMOReconstruction_{key}")
                metadata[target] = value

    phase_values = processing_options.get("phase", {})
    phase_method = None
    if isinstance(phase_values, dict):
        phase_method = phase_values.get("method")
        if isinstance(phase_method, str):
            phase_method = phase_method.strip()
            # A literal "none" means no phase retrieval was applied.
            if phase_method.lower() == "none":
                phase_method = ""
        for key, raw_value in phase_values.items():
            if key in _EXCLUDED_PHASE_KEYS:
                continue
            value = _clean_option_value(raw_value)
            if value is None:
                continue
            if key == "distance_m":
                # Convert explicitly: string/list values previously raised
                # TypeError here and aborted the whole metadata build.
                try:
                    distance_mm = float(value) * 1000.0  # m -> mm
                except (TypeError, ValueError):
                    continue
                metadata["TOMOReconstructionPhase_detector_sample_distance"] = (
                    distance_mm
                )
            else:
                metadata[f"TOMOReconstructionPhase_{key}"] = value

    technique_name = "XRCT" if phase_method in (None, "") else "XPCT"
    technique_metadata = technique.get_technique_metadata(technique_name)
    metadata.update(technique_metadata.get_dataset_metadata())
    return metadata
class DataPortalUpload(  # type: ignore[call-arg]
    Task,
    input_model=DataPortalUploadInputModel,
):
    """(ESRF-only) Upload a processed dataset folder to the Data Portal using pyicat_plus."""

    # Factory for the ICAT client used by ``run``; a class attribute so it can
    # be overridden (e.g. by a subclass) without touching ``run``.
    icat_client_factory = staticmethod(
        lambda: IcatClient(metadata_urls=defaults.METADATA_BROKERS)
    )

    def run(self):
        """Register the processed folder in ICAT, or only log the payload on dry run.

        The upload is best-effort and no exception escapes:
            - A ``ValueError`` while building the payload (unknown ESRF schema
              or not a PROCESSED_DATA path) is logged as a skip.
            - Any other failure is logged as an error together with its
              traceback.

        ``self.icat_status`` is set to ``"stored"`` only after
        ``store_processed_data`` succeeds.
        """
        folder_path = self.inputs.process_folder_path
        metadata_in = self.inputs.metadata
        dataset_in = self.inputs.dataset
        dry_run = self.inputs.dry_run

        # Only payload-building errors count as "skipped". A ValueError raised
        # by the upload itself must be reported as a real failure.
        try:
            payload = _build_icat_payload(folder_path, metadata_in, dataset_in)
        except ValueError as e:
            logger.warning("DataPortalUpload skipped: %s", e)
            return
        except Exception as e:
            logger.warning("Error in DataPortalUpload: %s", e, exc_info=True)
            return

        try:
            payload["metadata"] = _remove_invalid_icat_metadata_keys(
                payload["metadata"]
            )
            if dry_run:
                logger.info(
                    "Dry-run: would store_processed_data "
                    "proposal=%s beamline=%s dataset=%s path=%s raw=%s metadata=%s",
                    payload["proposal"],
                    payload["beamline"],
                    payload["dataset"],
                    payload["path"],
                    payload["raw"],
                    payload["metadata"],
                    extra={"dp_payload": payload},
                )
                return
            client = self.icat_client_factory()
            try:
                client.store_processed_data(**payload)
                self.icat_status = "stored"
            finally:
                # Always release the client, even when storing failed.
                try:
                    client.disconnect()
                except Exception:
                    logger.warning("Failed to disconnect ICAT client", exc_info=True)
        except Exception as e:
            logger.warning("Error in DataPortalUpload: %s", e, exc_info=True)