diff --git a/pyproject.toml b/pyproject.toml index 1dd9dfc..5d6f7da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,9 +37,11 @@ classifiers = [ ] dependencies = [ + "arraybridge>=0.2.9", "numpy>=1.26.0", "portalocker>=2.8.0", # Cross-platform file locking "metaclass-registry", + "imageio>=2.37.0", "zarr>=2.18.0,<3.0", # Required for ZarrStorageBackend "ome-zarr>=0.11.0", # Required for OME-ZARR HCS compliance ] @@ -197,4 +199,4 @@ ignore = [ ] [tool.ruff.per-file-ignores] -"__init__.py" = ["F401"] # unused imports \ No newline at end of file +"__init__.py" = ["F401"] # unused imports diff --git a/src/polystore/__init__.py b/src/polystore/__init__.py index 5c38d68..123c449 100644 --- a/src/polystore/__init__.py +++ b/src/polystore/__init__.py @@ -26,10 +26,10 @@ get_backend, ) from .constants import Backend, MemoryType, TransportMode -from .disk import DiskStorageBackend +from .disk import DiskBackend, DiskStorageBackend from .filemanager import FileManager from .formats import FileFormat, DEFAULT_IMAGE_EXTENSIONS -from .memory import MemoryStorageBackend +from .memory import MemoryBackend, MemoryStorageBackend from .metadata_writer import ( AtomicMetadataWriter, MetadataWriteError, @@ -76,7 +76,9 @@ "register_cleanup_callback", "STORAGE_BACKENDS", "DiskStorageBackend", + "DiskBackend", "MemoryStorageBackend", + "MemoryBackend", "FileManager", "file_lock", "atomic_write_json", diff --git a/src/polystore/base.py b/src/polystore/base.py index 2b033fc..e18849e 100644 --- a/src/polystore/base.py +++ b/src/polystore/base.py @@ -546,15 +546,16 @@ def reset_memory_backend() -> None: # Clear files from existing memory backend while preserving directories memory_backend = storage_registry[Backend.MEMORY.value] - # DEBUG: Log what's in memory before clearing existing_keys = list(memory_backend._memory_store.keys()) - logger.info(f"🔍 VFS_CLEAR: Memory backend has {len(existing_keys)} entries BEFORE clear") - logger.info(f"🔍 VFS_CLEAR: First 10 keys: {existing_keys[:10]}") + logger.debug("Memory backend has %s entries before clear", len(existing_keys)) + logger.debug("First memory backend keys before clear: %s", existing_keys[:10]) memory_backend.clear_files_only() - # DEBUG: Log what's in memory after clearing remaining_keys = list(memory_backend._memory_store.keys()) - logger.info(f"🔍 VFS_CLEAR: Memory backend has {len(remaining_keys)} entries AFTER clear (directories only)") - logger.info(f"🔍 VFS_CLEAR: First 10 remaining keys: {remaining_keys[:10]}") + logger.debug( + "Memory backend has %s entries after clear (directories only)", + len(remaining_keys), + ) + logger.debug("First memory backend keys after clear: %s", remaining_keys[:10]) logger.info("Memory backend reset - files cleared, directories preserved") diff --git a/src/polystore/disk.py b/src/polystore/disk.py index 40c33d9..ca24e7c 100644 --- a/src/polystore/disk.py +++ b/src/polystore/disk.py @@ -9,6 +9,7 @@ import logging import os import shutil +import importlib from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Set, Union @@ -23,7 +24,7 @@ def optional_import(module_name): try: - return __import__(module_name) + return importlib.import_module(module_name) except ImportError: return None @@ -44,6 +45,7 @@ def optional_import(module_name): cupy = get_cupy() tf = get_tf() tifffile = optional_import("tifffile") +imageio = optional_import("imageio.v3") # Optional arraybridge integration for memory conversion try: @@ -99,6 +101,7 @@ def _register_formats(self): # Complex formats - use custom handlers (FileFormat.TIFF, tifffile, self._tiff_writer, self._tiff_reader), + (FileFormat.RASTER_IMAGE, imageio, self._image_writer, self._image_reader), (FileFormat.TEXT, True, self._text_writer, self._text_reader), (FileFormat.JSON, True, self._json_writer, self._json_reader), (FileFormat.CSV, True, self._csv_writer, self._csv_reader), @@ -164,6 +167,14 @@ def _tiff_reader(self, path): else: return tifffile.imread(str(path)) + def _image_writer(self, path, data, **kwargs): + """Write standard raster images using imageio.""" + imageio.imwrite(path, np.asarray(data)) + + def _image_reader(self, path): + """Read standard raster images using imageio.""" + return imageio.imread(path) + def _text_writer(self, path, data, **kwargs): """Write text data to file. Accepts and ignores extra kwargs for compatibility.""" path.write_text(str(data)) @@ -261,7 +272,7 @@ def load(self, file_path: Union[str, Path], **kwargs) -> Any: ext = disk_path.suffix.lower() if not self.format_registry.is_registered(ext): - raise ValueError(f"No writer registered for extension '{ext}'") + raise ValueError(f"No reader registered for extension '{ext}'") try: reader = self.format_registry.get_reader(ext) @@ -823,3 +834,6 @@ def _save_rois(self, rois: List, output_path: Path, images_dir: str = None, **kw logger.info(f"Saved {roi_count} ROIs to .roi.zip archive: {output_path}") return str(output_path) + + +DiskBackend = DiskStorageBackend diff --git a/src/polystore/fiji_stream.py b/src/polystore/fiji_stream.py index 4d52817..08132bc 100644 --- a/src/polystore/fiji_stream.py +++ b/src/polystore/fiji_stream.py @@ -31,12 +31,9 @@ class FijiStreamingBackend(StreamingBackend): """Fiji streaming backend with ZMQ publisher pattern (matches Napari architecture).""" _backend_type = Backend.FIJI_STREAM.value - # Configure ABC attributes VIEWER_TYPE = 'fiji' SHM_PREFIX = 'fiji_' - # __init__, _get_publisher, save, cleanup now inherited from ABC - def _prepare_rois_data(self, data: Any, file_path: Union[str, Path]) -> dict: """ Prepare ROIs data for transmission. @@ -90,6 +87,7 @@ def save_batch(self, data_list: List[Any], file_paths: List[Union[str, Path]], * source = kwargs.get('source', 'unknown_source') # Pre-built source value images_dir = kwargs.get('images_dir') # Source image subdirectory for ROI mapping plate_path = kwargs.get('plate_path') + component_metadata = kwargs.get('component_metadata') logger.info(f"🏷️ FIJI BACKEND: plate_path = {plate_path}") logger.info(f"🏷️ FIJI BACKEND: microscope_handler = {microscope_handler}") display_payload_extra = { @@ -108,6 +106,7 @@ def save_batch(self, data_list: List[Any], file_paths: List[Union[str, Path]], * display_config, self._prepare_batch_item, plate_path=plate_path, + component_metadata=component_metadata, component_names_kwargs={"log_prefix": "🏷️ FIJI BACKEND", "verbose": True}, display_payload_extra=display_payload_extra, message_extra=message_extra, diff --git a/src/polystore/formats.py b/src/polystore/formats.py index ddfb9a5..3643361 100644 --- a/src/polystore/formats.py +++ b/src/polystore/formats.py @@ -20,6 +20,7 @@ class FileFormat(Enum): # Image formats TIFF = "tiff" + RASTER_IMAGE = "raster_image" # Data formats CSV = "csv" @@ -44,6 +45,7 @@ def extensions(self): FileFormat.TENSORFLOW: [".tf"], FileFormat.ZARR: [".zarr"], FileFormat.TIFF: [".tif", ".tiff"], + FileFormat.RASTER_IMAGE: [".bmp", ".gif", ".jpeg", ".jpg", ".png"], FileFormat.CSV: [".csv"], FileFormat.JSON: [".json"], FileFormat.TEXT: [".txt"], @@ -51,7 +53,14 @@ def extensions(self): } # Default image extensions -DEFAULT_IMAGE_EXTENSIONS = {".tif", ".tiff", ".TIF", ".TIFF"} +DEFAULT_IMAGE_EXTENSIONS = { + extension + for extensions in ( + FILE_FORMAT_EXTENSIONS[FileFormat.TIFF], + FILE_FORMAT_EXTENSIONS[FileFormat.RASTER_IMAGE], + ) + for extension in extensions +} def get_format_from_extension(ext: str) -> FileFormat: diff --git a/src/polystore/memory.py b/src/polystore/memory.py index a59114f..872d581 100644 --- a/src/polystore/memory.py +++ b/src/polystore/memory.py @@ -139,6 +139,9 @@ def list_files( if self._memory_store[dir_key] is not None: raise NotADirectoryError(f"Path is not a directory: {directory}") + lowercase_extensions = ( + None if extensions is None else {extension.lower() for extension in extensions} + ) result = [] dir_prefix = dir_key + "/" if not dir_key.endswith("/") else dir_key @@ -159,7 +162,10 @@ def list_files( filename = Path(rel_path).name # If pattern is None, match all files if pattern is None or fnmatch(filename, pattern): - if not extensions or Path(filename).suffix in extensions: + if ( + lowercase_extensions is None + or Path(filename).suffix.lower() in lowercase_extensions + ): # Calculate depth for breadth-first sorting depth = rel_path.count('/') result.append((Path(path), depth)) @@ -651,3 +657,6 @@ def __init__(self, target: str): def __repr__(self): return f"" + + +MemoryBackend = MemoryStorageBackend diff --git a/src/polystore/napari_stream.py b/src/polystore/napari_stream.py index 630bcc8..d762cd6 100644 --- a/src/polystore/napari_stream.py +++ b/src/polystore/napari_stream.py @@ -20,7 +20,6 @@ import zmq from .constants import Backend, TransportMode -from .streaming_constants import StreamingDataType from .streaming import StreamingBackend from .roi_converters import NapariROIConverter from zmqruntime.transport import get_zmq_transport_url, coerce_transport_mode @@ -32,12 +31,9 @@ class NapariStreamingBackend(StreamingBackend): """Napari streaming backend with automatic registration.""" _backend_type = Backend.NAPARI_STREAM.value - # Configure ABC attributes VIEWER_TYPE = 'napari' SHM_PREFIX = 'napari_' - # __init__, _get_publisher, save, cleanup now inherited from ABC - def _prepare_shapes_data(self, data: Any, file_path: Union[str, Path]) -> dict: """ Prepare shapes data for transmission. @@ -57,7 +53,7 @@ def _prepare_shapes_data(self, data: Any, file_path: Union[str, Path]) -> dict: } def _prepare_batch_item(self, data: Any, file_path: Union[str, Path], data_type): - if data_type in (StreamingDataType.SHAPES, StreamingDataType.POINTS): + if data_type.uses_napari_vector_payload: item_data = self._prepare_shapes_data(data, file_path) data_type_value = data_type.value else: @@ -88,6 +84,7 @@ def save_batch(self, data_list: List[Any], file_paths: List[Union[str, Path]], * microscope_handler = kwargs['microscope_handler'] source = kwargs.get('source', 'unknown_source') # Pre-built source value plate_path = kwargs.get('plate_path') + component_metadata = kwargs.get('component_metadata') display_payload_extra = { "colormap": display_config.get_colormap_name(), "variable_size_handling": display_config.variable_size_handling.value @@ -103,6 +100,7 @@ def save_batch(self, data_list: List[Any], file_paths: List[Union[str, Path]], * display_config, self._prepare_batch_item, plate_path=plate_path, + component_metadata=component_metadata, display_payload_extra=display_payload_extra, ) diff --git a/src/polystore/roi.py b/src/polystore/roi.py index fb6bdb6..d841591 100644 --- a/src/polystore/roi.py +++ b/src/polystore/roi.py @@ -6,12 +6,14 @@ """ import logging +from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union import numpy as np +from metaclass_registry import AutoRegisterMeta from .constants import Backend @@ -27,8 +29,14 @@ class ShapeType(Enum): ELLIPSE = "ellipse" +class ROIShape(ABC): + """Nominal base for all ROI shape records.""" + + shape_type: ShapeType + + @dataclass(frozen=True) -class PolygonShape: +class PolygonShape(ROIShape): """Polygon ROI shape defined by vertex coordinates.""" coordinates: np.ndarray # Nx2 array of (y, x) coordinates shape_type: ShapeType = field(default=ShapeType.POLYGON, init=False) @@ -41,7 +49,7 @@ def __post_init__(self): @dataclass(frozen=True) -class PolylineShape: +class PolylineShape(ROIShape): """Polyline ROI shape defined by path coordinates (open path, not closed polygon).""" coordinates: np.ndarray # Nx2 array of (y, x) coordinates shape_type: ShapeType = field(default=ShapeType.POLYLINE, init=False) @@ -54,7 +62,7 @@ def __post_init__(self): @dataclass(frozen=True) -class MaskShape: +class MaskShape(ROIShape): """Binary mask ROI shape.""" mask: np.ndarray # 2D boolean array bbox: Tuple[int, int, int, int] # (min_y, min_x, max_y, max_x) @@ -68,7 +76,7 @@ def __post_init__(self): @dataclass(frozen=True) -class PointShape: +class PointShape(ROIShape): """Point ROI shape.""" y: float x: float @@ -76,7 +84,7 @@ class PointShape: @dataclass(frozen=True) -class EllipseShape: +class EllipseShape(ROIShape): """Ellipse ROI shape.""" center_y: float center_x: float @@ -95,14 +103,82 @@ def __post_init__(self): if not self.shapes: raise ValueError("ROI must have at least one shape") for shape in self.shapes: - if not hasattr(shape, "shape_type"): - raise ValueError(f"Shape {shape} must have shape_type attribute") + if not isinstance(shape, ROIShape): + raise ValueError(f"Shape {shape} must be an ROIShape") + + +class ROIJsonShapeDecoder(ABC, metaclass=AutoRegisterMeta): + """Decode one serialized ROI shape variant.""" + + __registry_key__ = "shape_type" + __skip_if_no_key__ = True + + shape_type: ClassVar[ShapeType | None] = None + + @classmethod + def for_serialized_shape(cls, shape_dict: Dict[str, Any]) -> "ROIJsonShapeDecoder | None": + shape_type = shape_dict.get("type") + try: + shape_key = ShapeType(shape_type) + except ValueError: + logger.warning(f"Unknown shape type: {shape_type}, skipping") + return None + return cls.__registry__[shape_key]() + + @abstractmethod + def decode(self, shape_dict: Dict[str, Any]) -> Any: + """Return the concrete ROI shape represented by ``shape_dict``.""" + + +class PolygonROIJsonShapeDecoder(ROIJsonShapeDecoder): + shape_type = ShapeType.POLYGON + + def decode(self, shape_dict: Dict[str, Any]) -> PolygonShape: + return PolygonShape(coordinates=np.array(shape_dict["coordinates"])) + + +class PolylineROIJsonShapeDecoder(ROIJsonShapeDecoder): + shape_type = ShapeType.POLYLINE + + def decode(self, shape_dict: Dict[str, Any]) -> PolylineShape: + return PolylineShape(coordinates=np.array(shape_dict["coordinates"])) + + +class MaskROIJsonShapeDecoder(ROIJsonShapeDecoder): + shape_type = ShapeType.MASK + + def decode(self, shape_dict: Dict[str, Any]) -> MaskShape: + return MaskShape( + mask=np.array(shape_dict["mask"], dtype=bool), + bbox=tuple(shape_dict["bbox"]), + ) + + +class PointROIJsonShapeDecoder(ROIJsonShapeDecoder): + shape_type = ShapeType.POINT + + def decode(self, shape_dict: Dict[str, Any]) -> PointShape: + return PointShape(y=shape_dict["y"], x=shape_dict["x"]) + + +class EllipseROIJsonShapeDecoder(ROIJsonShapeDecoder): + shape_type = ShapeType.ELLIPSE + + def decode(self, shape_dict: Dict[str, Any]) -> EllipseShape: + return EllipseShape( + center_y=shape_dict["center_y"], + center_x=shape_dict["center_x"], + radius_y=shape_dict["radius_y"], + radius_x=shape_dict["radius_x"], + ) def extract_rois_from_labeled_mask( labeled_mask: np.ndarray, min_area: int = 10, extract_contours: bool = True, + spatial_origin_yx: Optional[Tuple[int, int]] = None, + source_spatial_shape_yx: Optional[Tuple[int, int]] = None, ) -> List[ROI]: """Extract ROIs from a labeled segmentation mask.""" from skimage import measure @@ -117,19 +193,33 @@ def extract_rois_from_labeled_mask( regions = regionprops(labeled_mask) slices = find_objects(labeled_mask) + origin_y, origin_x = spatial_origin_yx or (0, 0) rois = [] for region in regions: if region.area < min_area: continue + min_y, min_x, max_y, max_x = region.bbox metadata = { "label": int(region.label), "area": float(region.area), "perimeter": float(region.perimeter), - "centroid": tuple(float(c) for c in region.centroid), - "bbox": tuple(int(b) for b in region.bbox), + "centroid": ( + float(region.centroid[0] + origin_y), + float(region.centroid[1] + origin_x), + ), + "bbox": ( + int(min_y + origin_y), + int(min_x + origin_x), + int(max_y + origin_y), + int(max_x + origin_x), + ), } + if source_spatial_shape_yx is not None: + metadata["source_spatial_shape_yx"] = tuple( + int(value) for value in source_spatial_shape_yx + ) shapes = [] if extract_contours: @@ -142,14 +232,14 @@ def extract_rois_from_labeled_mask( contours = measure.find_contours(padded_mask, level=0.5) offset_y = slice_y.start offset_x = slice_x.start - padding_offset = np.array([offset_y, offset_x]) - 1 + padding_offset = np.array([offset_y + origin_y, offset_x + origin_x]) - 1 for contour in contours: if len(contour) >= 3: contour_full = contour + padding_offset shapes.append(PolygonShape(coordinates=contour_full)) else: binary_mask = (labeled_mask == region.label) - shapes.append(MaskShape(mask=binary_mask, bbox=region.bbox)) + shapes.append(MaskShape(mask=binary_mask, bbox=metadata["bbox"])) if shapes: rois.append(ROI(shapes=shapes, metadata=metadata)) @@ -203,31 +293,9 @@ def load_rois_from_json(json_path: Path) -> List[ROI]: metadata = roi_dict.get("metadata", {}) shapes = [] for shape_dict in roi_dict.get("shapes", []): - shape_type = shape_dict.get("type") - - if shape_type == "polygon": - coordinates = np.array(shape_dict["coordinates"]) - shapes.append(PolygonShape(coordinates=coordinates)) - elif shape_type == "polyline": - coordinates = np.array(shape_dict["coordinates"]) - shapes.append(PolylineShape(coordinates=coordinates)) - elif shape_type == "mask": - mask = np.array(shape_dict["mask"], dtype=bool) - bbox = tuple(shape_dict["bbox"]) - shapes.append(MaskShape(mask=mask, bbox=bbox)) - elif shape_type == "point": - shapes.append(PointShape(y=shape_dict["y"], x=shape_dict["x"])) - elif shape_type == "ellipse": - shapes.append( - EllipseShape( - center_y=shape_dict["center_y"], - center_x=shape_dict["center_x"], - radius_y=shape_dict["radius_y"], - radius_x=shape_dict["radius_x"], - ) - ) - else: - logger.warning(f"Unknown shape type: {shape_type}, skipping") + decoder = ROIJsonShapeDecoder.for_serialized_shape(shape_dict) + if decoder is not None: + shapes.append(decoder.decode(shape_dict)) if shapes: rois.append(ROI(shapes=shapes, metadata=metadata)) diff --git a/src/polystore/roi_converters.py b/src/polystore/roi_converters.py index 46e8631..616e4c4 100644 --- a/src/polystore/roi_converters.py +++ b/src/polystore/roi_converters.py @@ -7,63 +7,184 @@ """ import logging -from typing import Any, Dict, List, Tuple +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum +from typing import Any, ClassVar, Dict, List, Tuple import numpy as np +from metaclass_registry import AutoRegisterMeta -from .roi import EllipseShape, PointShape, PolygonShape, PolylineShape, ROI -from .streaming_constants import NapariShapeType +from .roi import EllipseShape, PointShape, PolygonShape, PolylineShape, ROI, ShapeType logger = logging.getLogger(__name__) -class NapariROIConverter: - """Convert ROI objects to Napari shapes format.""" +@dataclass(frozen=True, slots=True) +class NapariShapeTypeAlias: + """Inert alias from Napari wire shape names to ROI shape types.""" + + alias: str + shape_type: ShapeType + + +NAPARI_SHAPE_TYPE_ALIASES = ( + NapariShapeTypeAlias("path", ShapeType.POLYLINE), + NapariShapeTypeAlias("points", ShapeType.POINT), +) + + +class NapariShapeConverter(ABC, metaclass=AutoRegisterMeta): + """Registered conversion behavior for one ROI shape type.""" + + __registry_key__ = "shape_type" + __skip_if_no_key__ = True + + shape_type: ClassVar[ShapeType | None] = None + + @classmethod + def for_shape_dict(cls, shape_dict: Dict[str, Any]) -> "NapariShapeConverter": + return cls.__registry__[_shape_type_from_napari(shape_dict["type"])]() + + def append_common_properties( + self, + metadata: Dict[str, Any], + properties: dict[str, list[Any]], + centroid: tuple[Any, Any], + *, + area: Any | None = None, + ) -> None: + properties["label"].append(metadata.get("label", "")) + properties["area"].append(metadata.get("area", 0) if area is None else area) + properties["centroid_y"].append(centroid[0]) + properties["centroid_x"].append(centroid[1]) + + @abstractmethod + def add_dimensions(self, shape_dict: Dict[str, Any], prepend_dims: np.ndarray) -> np.ndarray: + """Add dimensions to a 2D shape to make it nD.""" + + @abstractmethod + def append_napari_format( + self, + shape_dict: Dict[str, Any], + napari_shapes: list[np.ndarray], + shape_types: list[str], + properties: dict[str, list[Any]], + ) -> None: + """Append this shape to a Napari layer payload.""" + + +def _shape_type_from_napari(shape_type: object) -> ShapeType: + if isinstance(shape_type, ShapeType): + return shape_type + value = str(shape_type.value) if isinstance(shape_type, Enum) else str(shape_type) + for alias in NAPARI_SHAPE_TYPE_ALIASES: + if alias.alias == value: + return alias.shape_type + return ShapeType(value) + + +class CoordinateNapariShapeConverter(NapariShapeConverter): + """Shared converter for coordinate-list shapes.""" - _SHAPE_DIMENSION_HANDLERS = { - "polygon": lambda shape_dict, prepend_dims: np.hstack( - [np.tile(prepend_dims, (len(shape_dict["coordinates"]), 1)), np.array(shape_dict["coordinates"])] - ), - "polyline": lambda shape_dict, prepend_dims: np.hstack( - [np.tile(prepend_dims, (len(shape_dict["coordinates"]), 1)), np.array(shape_dict["coordinates"])] - ), - "ellipse": lambda shape_dict, prepend_dims: np.hstack( + napari_shape_type: ClassVar[str] + + def add_dimensions(self, shape_dict: Dict[str, Any], prepend_dims: np.ndarray) -> np.ndarray: + coordinates = np.array(shape_dict["coordinates"]) + return np.hstack([np.tile(prepend_dims, (len(coordinates), 1)), coordinates]) + + def append_napari_format( + self, + shape_dict: Dict[str, Any], + napari_shapes: list[np.ndarray], + shape_types: list[str], + properties: dict[str, list[Any]], + ) -> None: + metadata = shape_dict.get("metadata", {}) + napari_shapes.append(np.array(shape_dict["coordinates"])) + shape_types.append(self.napari_shape_type) + self.append_common_properties( + metadata, + properties, + metadata.get("centroid", (0, 0)), + ) + + +class PolygonNapariShapeConverter(CoordinateNapariShapeConverter): + shape_type = ShapeType.POLYGON + napari_shape_type = "polygon" + + +class PolylineNapariShapeConverter(CoordinateNapariShapeConverter): + shape_type = ShapeType.POLYLINE + napari_shape_type = "path" + + +class EllipseNapariShapeConverter(NapariShapeConverter): + shape_type = ShapeType.ELLIPSE + + def add_dimensions(self, shape_dict: Dict[str, Any], prepend_dims: np.ndarray) -> np.ndarray: + center = shape_dict["center"] + radii = shape_dict["radii"] + corners = np.array( [ - np.tile(prepend_dims, (4, 1)), - np.array( - [ - [ - shape_dict["center"][0] - shape_dict["radii"][0], - shape_dict["center"][1] - shape_dict["radii"][1], - ], - [ - shape_dict["center"][0] - shape_dict["radii"][0], - shape_dict["center"][1] + shape_dict["radii"][1], - ], - [ - shape_dict["center"][0] + shape_dict["radii"][0], - shape_dict["center"][1] + shape_dict["radii"][1], - ], - [ - shape_dict["center"][0] + shape_dict["radii"][0], - shape_dict["center"][1] - shape_dict["radii"][1], - ], - ] - ), + [center[0] - radii[0], center[1] - radii[1]], + [center[0] - radii[0], center[1] + radii[1]], + [center[0] + radii[0], center[1] + radii[1]], + [center[0] + radii[0], center[1] - radii[1]], ] - ), - "point": lambda shape_dict, prepend_dims: np.concatenate([prepend_dims, shape_dict["coordinates"]]).reshape(1, -1), - } + ) + return np.hstack([np.tile(prepend_dims, (4, 1)), corners]) + + def append_napari_format( + self, + shape_dict: Dict[str, Any], + napari_shapes: list[np.ndarray], + shape_types: list[str], + properties: dict[str, list[Any]], + ) -> None: + metadata = shape_dict.get("metadata", {}) + center = np.array(shape_dict["center"]) + radii = np.array(shape_dict["radii"]) + napari_shapes.append(np.array([center - radii, center + radii])) + shape_types.append("ellipse") + self.append_common_properties( + metadata, + properties, + metadata.get("centroid", (0, 0)), + ) + + +class PointNapariShapeConverter(NapariShapeConverter): + shape_type = ShapeType.POINT + + def add_dimensions(self, shape_dict: Dict[str, Any], prepend_dims: np.ndarray) -> np.ndarray: + return np.concatenate([prepend_dims, shape_dict["coordinates"]]).reshape(1, -1) + + def append_napari_format( + self, + shape_dict: Dict[str, Any], + napari_shapes: list[np.ndarray], + shape_types: list[str], + properties: dict[str, list[Any]], + ) -> None: + metadata = shape_dict.get("metadata", {}) + coordinates = shape_dict["coordinates"] + napari_shapes.append(np.array([coordinates])) + shape_types.append("point") + self.append_common_properties(metadata, properties, coordinates, area=0) + + +class NapariROIConverter: + """Convert ROI objects to Napari shapes format.""" @staticmethod def add_dimensions_to_shape(shape_dict: Dict[str, Any], prepend_dims: List[float]) -> np.ndarray: """Add dimensions to a 2D shape to make it nD.""" - shape_type = shape_dict["type"] - shape_type_enum = NapariShapeType(shape_type) if isinstance(shape_type, str) else shape_type - handler = NapariROIConverter._SHAPE_DIMENSION_HANDLERS.get(shape_type_enum.value) - if handler is None: - raise ValueError(f"Unsupported shape type: {shape_type}") - return handler(shape_dict, np.array(prepend_dims)) + return NapariShapeConverter.for_shape_dict(shape_dict).add_dimensions( + shape_dict, + np.array(prepend_dims), + ) @staticmethod def rois_to_shapes(rois: List[ROI]) -> List[Dict[str, Any]]: @@ -104,40 +225,12 @@ def shapes_to_napari_format(shapes_data: List[Dict]) -> Tuple[List[np.ndarray], properties = {"label": [], "area": [], "centroid_y": [], "centroid_x": []} for shape_dict in shapes_data: - shape_type = shape_dict.get("type") - metadata = shape_dict.get("metadata", {}) - - if shape_type == "polygon": - coords = np.array(shape_dict["coordinates"]) - napari_shapes.append(coords) - shape_types.append("polygon") - centroid = metadata.get("centroid", (0, 0)) - properties["label"].append(metadata.get("label", "")) - properties["area"].append(metadata.get("area", 0)) - properties["centroid_y"].append(centroid[0]) - properties["centroid_x"].append(centroid[1]) - - elif shape_type == "ellipse": - center = np.array(shape_dict["center"]) - radii = np.array(shape_dict["radii"]) - corners = np.array([center - radii, center + radii]) - napari_shapes.append(corners) - shape_types.append("ellipse") - centroid = metadata.get("centroid", (0, 0)) - properties["label"].append(metadata.get("label", "")) - properties["area"].append(metadata.get("area", 0)) - properties["centroid_y"].append(centroid[0]) - properties["centroid_x"].append(centroid[1]) - - elif shape_type == "point": - coords = np.array([shape_dict["coordinates"]]) - napari_shapes.append(coords) - shape_types.append("point") - point_coords = shape_dict["coordinates"] - properties["label"].append(metadata.get("label", "")) - properties["area"].append(0) - properties["centroid_y"].append(point_coords[0]) - properties["centroid_x"].append(point_coords[1]) + NapariShapeConverter.for_shape_dict(shape_dict).append_napari_format( + shape_dict, + napari_shapes, + shape_types, + properties, + ) return napari_shapes, shape_types, properties diff --git a/src/polystore/streaming/_streaming_backend.py b/src/polystore/streaming/_streaming_backend.py index 417baa2..1f3a70a 100644 --- a/src/polystore/streaming/_streaming_backend.py +++ b/src/polystore/streaming/_streaming_backend.py @@ -9,9 +9,12 @@ import os import time import uuid +from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, List, Set, Union +from typing import Any, Callable, List, Mapping, Set, Union import numpy as np +from arraybridge import convert_memory, detect_memory_type +from arraybridge.types import MemoryType as ArrayBridgeMemoryType from ..base import DataSink from ..constants import TransportMode @@ -19,11 +22,62 @@ from ..roi import ROI, PointShape from ..zmq_config import POLYSTORE_ZMQ_CONFIG from zmqruntime.ack_listener import GlobalAckListener -from zmqruntime.transport import coerce_transport_mode, get_zmq_transport_url +from zmqruntime.transport import coerce_transport_mode logger = logging.getLogger(__name__) +PrepareStreamingItem = Callable[[Any, Union[str, Path], Any], tuple[dict, str]] + + +@dataclass(frozen=True) +class StreamingComponentMetadata: + """Message metadata for one streamed item.""" + + parsed_filename_metadata: Mapping[str, Any] + source: str + + def to_payload(self) -> dict[str, Any]: + if isinstance(self.parsed_filename_metadata, Mapping): + metadata = dict(self.parsed_filename_metadata) + else: + raise TypeError( + "Streaming component metadata must be a mapping, " + f"got {type(self.parsed_filename_metadata).__name__}." + ) + metadata["source"] = self.source + return metadata + + +@dataclass(frozen=True) +class StreamingBatchRequest: + """Shared provenance for one streaming batch.""" + + data_list: List[Any] + file_paths: List[Union[str, Path]] + microscope_handler: Any + source: str + prepare_item: PrepareStreamingItem + component_metadata: Mapping[str, Any] | None = None + + +class StreamingPayloadMemoryAuthority: + """Memory conversion authority for streamable image payloads.""" + + @staticmethod + def to_numpy(data: Any) -> np.ndarray: + if isinstance(data, np.ndarray): + return data + if isinstance(data, (list, tuple)): + return np.asarray(data) + return convert_memory( + data, + detect_memory_type(data), + ArrayBridgeMemoryType.NUMPY.value, + gpu_id=0, + ) + + class StreamingBackend(DataSink): """ Abstract base class for ZeroMQ-based streaming backends. @@ -104,55 +158,13 @@ def __init__(self, transport_config=None): self._shared_memory_blocks = {} self._transport_config = transport_config or POLYSTORE_ZMQ_CONFIG - def _get_publisher(self, host: str, port: int, transport_mode: TransportMode, transport_config=None): - """ - Lazy initialization of ZeroMQ publisher (common for all streaming backends). - - Uses REQ socket for Fiji (synchronous request/reply with blocking) - and PUB socket for Napari (broadcast pattern). - - Args: - host: Host to connect to (ignored for IPC mode) - port: Port to connect to - transport_mode: IPC or TCP transport (required - comes from config) - - Returns: - ZeroMQ publisher socket - """ - # Generate transport URL using centralized function - transport_config = transport_config or self._transport_config - url = get_zmq_transport_url( - port, - host=host, - mode=coerce_transport_mode(transport_mode), - config=transport_config, - ) - - key = url # Use URL as key instead of host:port - if key not in self._publishers: - try: - import zmq - if self._context is None: - self._context = zmq.Context() - - # Use REQ socket for all viewers (synchronous request/reply) - # All viewers must send acknowledgment after processing - publisher = self._context.socket(zmq.REQ) - - publisher.connect(url) - socket_name = "REQ" - logger.info(f"{self.VIEWER_TYPE} streaming {socket_name} socket connected to {url}") - time.sleep(0.1) - self._publishers[key] = publisher - - except ImportError: - logger.error("ZeroMQ not available - streaming disabled") - raise RuntimeError("ZeroMQ required for streaming") - - return self._publishers[key] - - def _parse_component_metadata(self, file_path: Union[str, Path], microscope_handler, - source: str) -> dict: + def _parse_component_metadata( + self, + file_path: Union[str, Path], + microscope_handler, + source: str, + component_metadata: Mapping[str, Any] | None = None, + ) -> dict: """ Parse component metadata from filename (common for all streaming backends). @@ -165,12 +177,17 @@ def _parse_component_metadata(self, file_path: Union[str, Path], microscope_hand Component metadata dict with source added """ filename = os.path.basename(str(file_path)) - component_metadata = microscope_handler.parser.parse_filename(filename) - - # Add pre-built source value directly - component_metadata['source'] = source - - return component_metadata + parsed_metadata = ( + component_metadata + if component_metadata is not None + else microscope_handler.parser.parse_filename(filename) + ) + if parsed_metadata is None: + raise ValueError( + "Streaming component metadata requires explicit component_metadata " + f"or a parser-readable filename; got {filename!r}." + ) + return StreamingComponentMetadata(parsed_metadata, source).to_payload() def _detect_data_type(self, data: Any): """ @@ -206,9 +223,7 @@ def _create_shared_memory(self, data: Any, file_path: Union[str, Path]) -> dict: Returns: Dict with shared memory metadata """ - # Convert to numpy - np_data = data.cpu().numpy() if hasattr(data, 'cpu') else \ - data.get() if hasattr(data, 'get') else np.asarray(data) + np_data = StreamingPayloadMemoryAuthority.to_numpy(data) # Create shared memory with hash-based naming to avoid "File name too long" errors # Hash the timestamp and object ID to create a short, unique name @@ -269,13 +284,7 @@ def _register_with_queue_tracker( tracker.register_sent(image_id) def _build_component_modes(self, display_config) -> dict: - component_modes = {} - for comp_name in display_config.COMPONENT_ORDER: - mode_field = f"{comp_name}_mode" - if hasattr(display_config, mode_field): - mode = getattr(display_config, mode_field) - component_modes[comp_name] = mode.value - return component_modes + return display_config.component_modes() def _build_display_config_base(self, display_config, component_modes: dict) -> dict: return { @@ -304,20 +313,14 @@ def _collect_component_names_metadata( try: for comp_name in component_names: - method_name = f"get_{comp_name}_values" - method = getattr(microscope_handler.metadata_handler, method_name, None) - if callable(method): - try: - metadata = method(plate_path) - if verbose and log_prefix: - logger.info(f"{log_prefix}: Got {comp_name} metadata: {metadata}") - if metadata: - component_names_metadata[comp_name] = metadata - except Exception as e: - if verbose and log_prefix: - logger.warning(f"{log_prefix}: Could not get {comp_name} metadata: {e}", exc_info=True) - elif verbose and log_prefix: - logger.info(f"{log_prefix}: No method {method_name} on metadata_handler") + metadata = microscope_handler.metadata_handler.get_component_values( + plate_path, + comp_name, + ) + if verbose and log_prefix: + logger.info(f"{log_prefix}: Got {comp_name} metadata: {metadata}") + if metadata: + component_names_metadata[comp_name] = metadata except Exception as e: if verbose and log_prefix: logger.warning(f"{log_prefix}: Could not get component metadata: {e}", exc_info=True) @@ -326,24 +329,23 @@ def _collect_component_names_metadata( def _prepare_batch_items( self, - data_list: List[Any], - file_paths: List[Union[str, Path]], - microscope_handler, - source: str, - prepare_item: Callable[[Any, Union[str, Path], Any], tuple[dict, str]], + request: StreamingBatchRequest, ) -> tuple[list[dict], list[str]]: batch_images = [] image_ids = [] - for data, file_path in zip(data_list, file_paths): + for data, file_path in zip(request.data_list, request.file_paths): image_id = str(uuid.uuid4()) image_ids.append(image_id) data_type = self._detect_data_type(data) component_metadata = self._parse_component_metadata( - file_path, microscope_handler, source + file_path, + request.microscope_handler, + request.source, + request.component_metadata, ) - item_data, data_type_value = prepare_item(data, file_path, data_type) + item_data, data_type_value = request.prepare_item(data, file_path, data_type) batch_images.append( { @@ -363,9 +365,10 @@ def _build_batch_message( microscope_handler, source: str, display_config, - prepare_item: Callable[[Any, Union[str, Path], Any], tuple[dict, str]], + prepare_item: PrepareStreamingItem, plate_path: Union[str, Path, None] = None, component_names_kwargs: dict | None = None, + component_metadata: Mapping[str, Any] | None = None, display_payload_extra: dict | None = None, message_extra: dict | None = None, ) -> tuple[dict, list[dict], list[str]]: @@ -373,11 +376,14 @@ def _build_batch_message( raise ValueError("data_list and file_paths must have the same length") batch_images, image_ids = self._prepare_batch_items( - data_list, - file_paths, - microscope_handler, - source, - prepare_item, + StreamingBatchRequest( + data_list=data_list, + file_paths=file_paths, + microscope_handler=microscope_handler, + source=source, + prepare_item=prepare_item, + component_metadata=component_metadata, + ) ) component_modes = self._build_component_modes(display_config) diff --git a/src/polystore/streaming/receivers/napari/layer_key.py b/src/polystore/streaming/receivers/napari/layer_key.py index dec6fff..51b7d67 100644 --- a/src/polystore/streaming/receivers/napari/layer_key.py +++ b/src/polystore/streaming/receivers/napari/layer_key.py @@ -14,13 +14,7 @@ def normalize_component_layout(display_config: Any) -> tuple[dict[str, str], lis component_order = display_config["component_order"] return component_modes, component_order - component_order = list(display_config.COMPONENT_ORDER) - component_modes: dict[str, str] = {} - for component in component_order: - mode_field = f"{component}_mode" - mode_value = display_config.__getattribute__(mode_field) - component_modes[component] = mode_value.value - return component_modes, component_order + return display_config.component_modes(), list(display_config.COMPONENT_ORDER) def build_layer_key( @@ -38,9 +32,4 @@ def build_layer_key( layer_key = "_".join(layer_key_parts) if layer_key_parts else "default_layer" - if data_type == StreamingDataType.SHAPES: - return f"{layer_key}_shapes" - if data_type == StreamingDataType.POINTS: - return f"{layer_key}_points" - return layer_key - + return f"{layer_key}{data_type.napari_layer_suffix}" diff --git a/src/polystore/streaming/receivers/napari/napari_batch_processor.py b/src/polystore/streaming/receivers/napari/napari_batch_processor.py index b8dcbdd..ad485af 100644 --- a/src/polystore/streaming/receivers/napari/napari_batch_processor.py +++ b/src/polystore/streaming/receivers/napari/napari_batch_processor.py @@ -1,20 +1,16 @@ import logging from typing import Any, Dict, List, Optional -from polystore.streaming.receivers.core import DebouncedBatchEngine - logger = logging.getLogger(__name__) class NapariBatchProcessor: """ - Batch processor for Napari viewer with configurable batching strategies. - - Accumulates items and displays them based on batch_size configuration: - - None: Wait for all items in operation, then display once - - N: Display every N items incrementally - - Uses debouncing to collect items arriving in rapid succession. + Batch processor for Napari viewer display operations. + + Napari layer mutation must run on the Qt event-loop thread. OpenHCS owns that + Qt-thread debounce before this processor is called, so this class only + adapts batch payloads into the server display operation. """ def __init__( @@ -29,22 +25,15 @@ def __init__( Args: napari_server: Reference to NapariViewerServer for display operations - batch_size: Number of items to batch before displaying - None = wait for all (default), N = display every N items - debounce_delay_ms: Wait time after last item before processing (ms) - max_debounce_wait_ms: Maximum total wait time before forcing display (ms) + batch_size: Reserved for compatibility with viewer configuration + debounce_delay_ms: Qt-thread debounce delay owned by the caller + max_debounce_wait_ms: Reserved for compatibility with viewer configuration """ self.napari_server = napari_server self.batch_size = batch_size self.debounce_delay_ms = debounce_delay_ms self.max_debounce_wait_ms = max_debounce_wait_ms - self._engine = DebouncedBatchEngine( - process_fn=self._process_batch, - debounce_delay_ms=debounce_delay_ms, - max_debounce_wait_ms=max_debounce_wait_ms, - ) - logger.info( f"NapariBatchProcessor: Created with batch_size={batch_size}, " f"debounce={debounce_delay_ms}ms, max_wait={max_debounce_wait_ms}ms" @@ -58,7 +47,7 @@ def add_items( component_names_metadata: Dict[str, Any], ): """ - Add items to the batch for processing. + Display items already released by the Qt-thread debounce. Args: layer_key: Unique identifier for the layer @@ -66,9 +55,9 @@ def add_items( display_config: Display configuration dict component_names_metadata: Component name mappings for dimension labels """ - self._engine.enqueue( - items=items, - context={ + self._process_batch( + items, + { "display_config": display_config, "component_names_metadata": component_names_metadata, "layer_key": layer_key, @@ -81,12 +70,11 @@ def add_items( ) def flush(self) -> None: - """Force immediate processing of the pending batch.""" - self._engine.flush() + """Compatibility no-op; OpenHCS owns the Qt-thread debounce timer.""" def _process_batch(self, items: List[Dict[str, Any]], context: Dict[str, Any]) -> None: """Process callback used by shared debounced batch engine.""" - self.napari_server._display_layer_batch( + self.napari_server.display_layer_batch( layer_key=context["layer_key"], items=items, display_config=context["display_config"], diff --git a/src/polystore/streaming_constants.py b/src/polystore/streaming_constants.py index d7f0596..05c834c 100644 --- a/src/polystore/streaming_constants.py +++ b/src/polystore/streaming_constants.py @@ -15,6 +15,21 @@ class StreamingDataType(Enum): POINTS = "points" # Napari points layer (e.g., skeleton tracings) ROIS = "rois" # Fiji ROI payloads + @property + def uses_napari_vector_payload(self) -> bool: + """Whether napari should receive this type through vector layer payloads.""" + return self in (type(self).SHAPES, type(self).POINTS) + + @property + def napari_layer_suffix(self) -> str: + """Layer-key suffix contributed by this data type.""" + return { + type(self).IMAGE: "", + type(self).SHAPES: "_shapes", + type(self).POINTS: "_points", + type(self).ROIS: "", + }[self] + class NapariShapeType(Enum): """Napari shape types for ROI visualization.""" diff --git a/src/polystore/virtual_workspace.py b/src/polystore/virtual_workspace.py index 45081a3..bec8be5 100644 --- a/src/polystore/virtual_workspace.py +++ b/src/polystore/virtual_workspace.py @@ -205,10 +205,20 @@ def list_files(self, directory: Union[str, Path], pattern: Optional[str] = None, if self._mapping_cache is None: self._load_mapping() - logger.info(f"VirtualWorkspace.list_files called: directory={directory}, recursive={recursive}, pattern={pattern}, extensions={extensions}") - logger.info(f" plate_root={self.plate_root}") - logger.info(f" relative_dir_str='{relative_dir_str}'") - logger.info(f" mapping has {len(self._mapping_cache)} entries") + logger.debug( + "VirtualWorkspace.list_files directory=%s recursive=%s pattern=%s extensions=%s", + directory, + recursive, + pattern, + extensions, + ) + logger.debug(" plate_root=%s", self.plate_root) + logger.debug(" relative_dir_str=%r", relative_dir_str) + logger.debug(" mapping has %s entries", len(self._mapping_cache)) + + lowercase_extensions = ( + None if extensions is None else {ext.lower() for ext in extensions} + ) # Filter paths in this directory results = [] @@ -230,20 +240,20 @@ def list_files(self, directory: Union[str, Path], pattern: Optional[str] = None, vpath = Path(virtual_relative) if pattern and not fnmatch(vpath.name, pattern): continue - if extensions and vpath.suffix not in extensions: + if lowercase_extensions and vpath.suffix.lower() not in lowercase_extensions: continue # Return absolute path results.append(str(self.plate_root / virtual_relative)) - logger.info(f" VirtualWorkspace.list_files returning {len(results)} files") + logger.debug(" VirtualWorkspace.list_files returning %s files", len(results)) if len(results) == 0 and len(self._mapping_cache) > 0: # Log first few mapping keys to help debug sample_keys = list(self._mapping_cache.keys())[:3] - logger.info(f" Sample mapping keys: {sample_keys}") + logger.debug(" Sample mapping keys: %s", sample_keys) if not recursive and relative_dir_str == '': sample_parents = [str(Path(k).parent).replace('\\', '/') for k in sample_keys] - logger.info(f" Sample parent dirs: {sample_parents}") + logger.debug(" Sample parent dirs: %s", sample_parents) logger.info(f" Expected parent to match: '{relative_dir_str}'") return sorted(results) diff --git a/tests/test_memory_backend.py b/tests/test_memory_backend.py index f55996b..ec8a080 100644 --- a/tests/test_memory_backend.py +++ b/tests/test_memory_backend.py @@ -109,6 +109,17 @@ def test_list_files_with_extension_filter(self): npy_files = self.backend.list_files("/test", extensions={".npy"}) assert len(npy_files) == 2 + def test_list_files_extension_filter_is_case_insensitive(self): + """Test extension filtering matches backend contract case-insensitively.""" + self.backend.save(np.array([1]), "/test/image.TIF") + self.backend.save(np.array([2]), "/test/image.tif") + self.backend.save("text", "/test/notes.TXT") + + tif_files = self.backend.list_files("/test", extensions={".tif"}) + + assert len(tif_files) == 2 + assert {path.name for path in tif_files} == {"image.TIF", "image.tif"} + def test_list_files_recursive(self): """Test recursive file listing.""" # Create files in multiple levels diff --git a/tests/test_roi.py b/tests/test_roi.py new file mode 100644 index 0000000..565022f --- /dev/null +++ b/tests/test_roi.py @@ -0,0 +1,79 @@ +import numpy as np + +from polystore.roi import MaskShape +from polystore.roi import PolygonShape +from polystore.roi import load_rois_from_json +from polystore.roi import extract_rois_from_labeled_mask + + +def test_extract_rois_from_labeled_mask_applies_spatial_origin_to_polygons(): + labels = np.zeros((8, 8), dtype=np.int32) + labels[2:6, 3:7] = 1 + + rois = extract_rois_from_labeled_mask( + labels, + min_area=0, + extract_contours=True, + spatial_origin_yx=(10, 20), + ) + + assert len(rois) == 1 + assert rois[0].metadata["bbox"] == (12, 23, 16, 27) + assert rois[0].metadata["centroid"] == (13.5, 24.5) + assert isinstance(rois[0].shapes[0], PolygonShape) + assert float(rois[0].shapes[0].coordinates[:, 0].min()) >= 11.5 + assert float(rois[0].shapes[0].coordinates[:, 1].min()) >= 22.5 + + +def test_extract_rois_from_labeled_mask_applies_spatial_origin_to_mask_bbox(): + labels = np.zeros((8, 8), dtype=np.int32) + labels[2:6, 3:7] = 1 + + rois = extract_rois_from_labeled_mask( + labels, + min_area=0, + extract_contours=False, + spatial_origin_yx=(10, 20), + ) + + assert len(rois) == 1 + assert isinstance(rois[0].shapes[0], MaskShape) + assert rois[0].shapes[0].bbox == (12, 23, 16, 27) + + +def test_extract_rois_from_labeled_mask_records_source_canvas_shape(): + labels = np.zeros((8, 8), dtype=np.int32) + labels[2:6, 3:7] = 1 + + rois = extract_rois_from_labeled_mask( + labels, + min_area=0, + source_spatial_shape_yx=(100, 200), + ) + + assert len(rois) == 1 + assert rois[0].metadata["source_spatial_shape_yx"] == (100, 200) + + +def test_load_rois_from_json_decodes_shapes_through_nominal_registry(tmp_path): + roi_path = tmp_path / "rois.json" + roi_path.write_text( + """ + [ + { + "metadata": {"label": 1}, + "shapes": [ + {"type": "polygon", "coordinates": [[1, 2], [3, 4], [5, 6]]}, + {"type": "mask", "mask": [[true, false], [false, true]], "bbox": [10, 20, 12, 22]} + ] + } + ] + """ + ) + + rois = load_rois_from_json(roi_path) + + assert len(rois) == 1 + assert isinstance(rois[0].shapes[0], PolygonShape) + assert isinstance(rois[0].shapes[1], MaskShape) + assert rois[0].shapes[1].bbox == (10, 20, 12, 22) diff --git a/tests/test_streaming_metadata.py b/tests/test_streaming_metadata.py new file mode 100644 index 0000000..95d3b00 --- /dev/null +++ b/tests/test_streaming_metadata.py @@ -0,0 +1,98 @@ +from types import SimpleNamespace + +import pytest + +from polystore.streaming._streaming_backend import StreamingBackend +from polystore.streaming._streaming_backend import StreamingBatchRequest + + +class MetadataProbeStreamingBackend(StreamingBackend): + VIEWER_TYPE = "probe" + SHM_PREFIX = "probe_" + + def save_batch(self, data_list, file_paths, **kwargs): + raise NotImplementedError + + +def test_streaming_component_metadata_rejects_unparsed_artifact_filename() -> None: + backend = MetadataProbeStreamingBackend() + microscope_handler = SimpleNamespace( + parser=SimpleNamespace(parse_filename=lambda _filename: None) + ) + + with pytest.raises(ValueError, match="explicit component_metadata"): + backend._parse_component_metadata( + "A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip", + microscope_handler, + source="IdentifyPrimaryObjects", + ) + + +def test_streaming_batch_items_reject_unparsed_artifact_filename() -> None: + backend = MetadataProbeStreamingBackend() + microscope_handler = SimpleNamespace( + parser=SimpleNamespace(parse_filename=lambda _filename: None) + ) + + with pytest.raises(ValueError, match="explicit component_metadata"): + backend._prepare_batch_items( + StreamingBatchRequest( + data_list=[object()], + file_paths=["A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip"], + microscope_handler=microscope_handler, + source="IdentifyPrimaryObjects", + prepare_item=lambda _data, _path, _data_type: ({"payload": "ok"}, "image"), + ) + ) + + +def test_streaming_component_metadata_preserves_parsed_filename_fields() -> None: + backend = MetadataProbeStreamingBackend() + microscope_handler = SimpleNamespace( + parser=SimpleNamespace( + parse_filename=lambda _filename: {"well": "A01", "channel": 1} + ) + ) + + metadata = backend._parse_component_metadata( + "A01_s001_w1_z001_t001.TIF", + microscope_handler, + source="Crop", + ) + + assert metadata == {"well": "A01", "channel": 1, "source": "Crop"} + + +def test_streaming_component_metadata_prefers_explicit_metadata() -> None: + backend = MetadataProbeStreamingBackend() + microscope_handler = SimpleNamespace( + parser=SimpleNamespace(parse_filename=lambda _filename: None) + ) + + metadata = backend._parse_component_metadata( + "A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip", + microscope_handler, + source="IdentifyPrimaryObjects", + component_metadata={"well": "A01", "site": 1, "channel": 1}, + ) + + assert metadata == { + "well": "A01", + "site": 1, + "channel": 1, + "source": "IdentifyPrimaryObjects", + } + + +def test_streaming_component_metadata_rejects_invalid_parser_result() -> None: + backend = MetadataProbeStreamingBackend() + microscope_handler = SimpleNamespace( + parser=SimpleNamespace(parse_filename=lambda _filename: ["not", "metadata"]) + ) + + with pytest.raises(TypeError, match="must be a mapping"): + backend._parse_component_metadata( + "A01_s001_w1_z001_t001.TIF", + microscope_handler, + source="Crop", + )