diff --git a/pyproject.toml b/pyproject.toml index 05c5baa2..2cbd0355 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,10 @@ dependencies = [ "tifffile>=2023.8.12", "ome-types", "xmltodict", + # CosMx reader + "polars", + "dask", + "zarr>=3", ] [project.optional-dependencies] @@ -188,6 +192,7 @@ lint.ignore = [ ] [tool.ruff.lint.per-file-ignores] "src/spatialdata_io/__init__.py" = ["I001"] +"tests/*" = ["D"] [tool.jupytext] formats = "ipynb,md" diff --git a/src/spatialdata_io/readers/cosmx.py b/src/spatialdata_io/readers/cosmx.py deleted file mode 100644 index 4c918a4f..00000000 --- a/src/spatialdata_io/readers/cosmx.py +++ /dev/null @@ -1,294 +0,0 @@ -from __future__ import annotations - -import os -import re -from pathlib import Path -from types import MappingProxyType -from typing import TYPE_CHECKING, Any - -import dask.array as da -import numpy as np -import pandas as pd -import pyarrow as pa -from anndata import AnnData -from dask_image.imread import imread -from scipy.sparse import csr_matrix -from skimage.transform import estimate_transform -from spatialdata import SpatialData -from spatialdata._logging import logger -from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, TableModel -from spatialdata.transformations.transformations import Affine, Identity - -from spatialdata_io._constants._constants import CosmxKeys -from spatialdata_io._docs import inject_docs -from spatialdata_io.readers._utils._utils import _set_reader_metadata - -if TYPE_CHECKING: - from collections.abc import Mapping - - from dask.dataframe import DataFrame as DaskDataFrame - -__all__ = ["cosmx"] - - -@inject_docs(cx=CosmxKeys) -def cosmx( - path: str | Path, - dataset_id: str | None = None, - transcripts: bool = True, - imread_kwargs: Mapping[str, Any] = MappingProxyType({}), - image_models_kwargs: Mapping[str, Any] = MappingProxyType({}), -) -> SpatialData: - """Read *Cosmx Nanostring* data. 
- - This function reads the following files: - - - ``_`{cx.COUNTS_SUFFIX!r}```: Counts matrix. - - ``_`{cx.METADATA_SUFFIX!r}```: Metadata file. - - ``_`{cx.FOV_SUFFIX!r}```: Field of view file. - - ``{cx.IMAGES_DIR!r}``: Directory containing the images. - - ``{cx.LABELS_DIR!r}``: Directory containing the labels. - - .. seealso:: - - - `Nanostring Spatial Molecular Imager `_. - - Parameters - ---------- - path - Path to the root directory containing *Nanostring* files. - dataset_id - Name of the dataset. - transcripts - Whether to also read in transcripts information. - imread_kwargs - Keyword arguments passed to :func:`dask_image.imread.imread`. - image_models_kwargs - Keyword arguments passed to :class:`spatialdata.models.Image2DModel`. - - Returns - ------- - :class:`spatialdata.SpatialData` - """ - path = Path(path) - - # tries to infer dataset_id from the name of the counts file - if dataset_id is None: - counts_files = [f for f in os.listdir(path) if str(f).endswith(CosmxKeys.COUNTS_SUFFIX)] - if len(counts_files) == 1: - found = re.match(rf"(.*)_{CosmxKeys.COUNTS_SUFFIX}", counts_files[0]) - if found: - dataset_id = found.group(1) - if dataset_id is None: - raise ValueError("Could not infer `dataset_id` from the name of the counts file. 
Please specify it manually.") - - # check for file existence - counts_file = path / f"{dataset_id}_{CosmxKeys.COUNTS_SUFFIX}" - if not counts_file.exists(): - raise FileNotFoundError(f"Counts file not found: {counts_file}.") - if transcripts: - transcripts_file = path / f"{dataset_id}_{CosmxKeys.TRANSCRIPTS_SUFFIX}" - if not transcripts_file.exists(): - raise FileNotFoundError(f"Transcripts file not found: {transcripts_file}.") - else: - transcripts_file = None - meta_file = path / f"{dataset_id}_{CosmxKeys.METADATA_SUFFIX}" - if not meta_file.exists(): - raise FileNotFoundError(f"Metadata file not found: {meta_file}.") - fov_file = path / f"{dataset_id}_{CosmxKeys.FOV_SUFFIX}" - if not fov_file.exists(): - raise FileNotFoundError(f"Found field of view file: {fov_file}.") - images_dir = path / CosmxKeys.IMAGES_DIR - if not images_dir.exists(): - raise FileNotFoundError(f"Images directory not found: {images_dir}.") - labels_dir = path / CosmxKeys.LABELS_DIR - if not labels_dir.exists(): - raise FileNotFoundError(f"Labels directory not found: {labels_dir}.") - - counts = pd.read_csv(counts_file, header=0, index_col=CosmxKeys.INSTANCE_KEY) - counts.index = counts.index.astype(str).str.cat(counts.pop(CosmxKeys.FOV).astype(str).values, sep="_") - - obs = pd.read_csv(meta_file, header=0, index_col=CosmxKeys.INSTANCE_KEY) - obs[CosmxKeys.FOV] = pd.Categorical(obs[CosmxKeys.FOV].astype(str)) - obs[CosmxKeys.REGION_KEY] = pd.Categorical(obs[CosmxKeys.FOV].astype(str).apply(lambda s: s + "_labels")) - obs[CosmxKeys.INSTANCE_KEY] = obs.index.astype(np.int64) - obs.rename_axis(None, inplace=True) - obs.index = obs.index.astype(str).str.cat(obs[CosmxKeys.FOV].values, sep="_") - - common_index = obs.index.intersection(counts.index) - - adata = AnnData( - csr_matrix(counts.loc[common_index, :].values), - dtype=counts.values.dtype, - obs=obs.loc[common_index, :], - ) - adata.var_names = counts.columns - - table = TableModel.parse( - adata, - 
region=list(set(adata.obs[CosmxKeys.REGION_KEY].astype(str).tolist())), - region_key=CosmxKeys.REGION_KEY.value, - instance_key=CosmxKeys.INSTANCE_KEY.value, - ) - - fovs_counts = list(map(str, adata.obs.fov.astype(int).unique())) - - affine_transforms_to_global = {} - - for fov in fovs_counts: - idx = table.obs.fov.astype(str) == fov - loc = table[idx, :].obs[[CosmxKeys.X_LOCAL_CELL, CosmxKeys.Y_LOCAL_CELL]].values - glob = table[idx, :].obs[[CosmxKeys.X_GLOBAL_CELL, CosmxKeys.Y_GLOBAL_CELL]].values - out = estimate_transform(ttype="affine", src=loc, dst=glob) - affine_transforms_to_global[fov] = Affine( - # out.params, input_coordinate_system=input_cs, output_coordinate_system=output_cs - out.params, - input_axes=("x", "y"), - output_axes=("x", "y"), - ) - - table.obsm["global"] = table.obs[[CosmxKeys.X_GLOBAL_CELL, CosmxKeys.Y_GLOBAL_CELL]].to_numpy() - table.obsm["spatial"] = table.obs[[CosmxKeys.X_LOCAL_CELL, CosmxKeys.Y_LOCAL_CELL]].to_numpy() - table.obs.drop( - columns=[CosmxKeys.X_LOCAL_CELL, CosmxKeys.Y_LOCAL_CELL, CosmxKeys.X_GLOBAL_CELL, CosmxKeys.Y_GLOBAL_CELL], - inplace=True, - ) - - # prepare to read images and labels - file_extensions = (".jpg", ".png", ".jpeg", ".tif", ".tiff") - pat = re.compile(r".*_F(\d+)") - - # check if fovs are correct for images and labels - fovs_images = [] - for fname in os.listdir(path / CosmxKeys.IMAGES_DIR): - if fname.endswith(file_extensions): - fovs_images.append(str(int(pat.findall(fname)[0]))) - - fovs_labels = [] - for fname in os.listdir(path / CosmxKeys.LABELS_DIR): - if fname.endswith(file_extensions): - fovs_labels.append(str(int(pat.findall(fname)[0]))) - - fovs_images_and_labels = set(fovs_images).intersection(set(fovs_labels)) - fovs_diff = fovs_images_and_labels.difference(set(fovs_counts)) - if len(fovs_diff): - logger.warning( - f"Found images and labels for {len(fovs_images)} FOVs, but only {len(fovs_counts)} FOVs in the counts file.\n" - + f"The following FOVs are missing: {fovs_diff} \n" - + "... 
will use only fovs in Table." - ) - - # read images - images = {} - for fname in os.listdir(path / CosmxKeys.IMAGES_DIR): - if fname.endswith(file_extensions): - fov = str(int(pat.findall(fname)[0])) - if fov in fovs_counts: - aff = affine_transforms_to_global[fov] - im = imread(path / CosmxKeys.IMAGES_DIR / fname, **imread_kwargs).squeeze() - flipped_im = da.flip(im, axis=0) - parsed_im = Image2DModel.parse( - flipped_im, - transformations={ - fov: Identity(), - "global": aff, - "global_only_image": aff, - }, - dims=("y", "x", "c"), - rgb=None, - **image_models_kwargs, - ) - images[f"{fov}_image"] = parsed_im - else: - logger.warning(f"FOV {fov} not found in counts file. Skipping image {fname}.") - - # read labels - labels = {} - for fname in os.listdir(path / CosmxKeys.LABELS_DIR): - if fname.endswith(file_extensions): - fov = str(int(pat.findall(fname)[0])) - if fov in fovs_counts: - aff = affine_transforms_to_global[fov] - la = imread(path / CosmxKeys.LABELS_DIR / fname, **imread_kwargs).squeeze() - flipped_la = da.flip(la, axis=0) - parsed_la = Labels2DModel.parse( - flipped_la, - transformations={ - fov: Identity(), - "global": aff, - "global_only_labels": aff, - }, - dims=("y", "x"), - **image_models_kwargs, - ) - labels[f"{fov}_labels"] = parsed_la - else: - logger.warning(f"FOV {fov} not found in counts file. 
Skipping labels {fname}.") - - points: dict[str, DaskDataFrame] = {} - if transcripts: - # assert transcripts_file is not None - # from pyarrow.csv import read_csv - # - # ptable = read_csv(path / transcripts_file) # , header=0) - # for fov in fovs_counts: - # aff = affine_transforms_to_global[fov] - # sub_table = ptable.filter(pa.compute.equal(ptable.column(CosmxKeys.FOV), int(fov))).to_pandas() - # sub_table[CosmxKeys.INSTANCE_KEY] = sub_table[CosmxKeys.INSTANCE_KEY].astype("category") - # # we rename z because we want to treat the data as 2d - # sub_table.rename(columns={"z": "z_raw"}, inplace=True) - # points[fov] = PointsModel.parse( - # sub_table, - # coordinates={"x": CosmxKeys.X_LOCAL_TRANSCRIPT, "y": CosmxKeys.Y_LOCAL_TRANSCRIPT}, - # feature_key=CosmxKeys.TARGET_OF_TRANSCRIPT, - # instance_key=CosmxKeys.INSTANCE_KEY, - # transformations={ - # fov: Identity(), - # "global": aff, - # "global_only_labels": aff, - # }, - # ) - # let's convert the .csv to .parquet and let's read it with pyarrow.parquet for faster subsetting - import tempfile - - import pyarrow.parquet as pq - - with tempfile.TemporaryDirectory() as tmpdir: - print("converting .csv to .parquet to improve the speed of the slicing operations... 
", end="", flush=True) - assert transcripts_file is not None - transcripts_data = pd.read_csv(transcripts_file, header=0) - transcripts_data.to_parquet(Path(tmpdir) / "transcripts.parquet") - print("done") - - ptable = pq.read_table(Path(tmpdir) / "transcripts.parquet") - for fov in fovs_counts: - aff = affine_transforms_to_global[fov] - sub_table = ptable.filter(pa.compute.equal(ptable.column(CosmxKeys.FOV), int(fov))).to_pandas() - sub_table[CosmxKeys.INSTANCE_KEY] = sub_table[CosmxKeys.INSTANCE_KEY].astype("category") - # we rename z because we want to treat the data as 2d - sub_table.rename(columns={"z": "z_raw"}, inplace=True) - if len(sub_table) > 0: - points[f"{fov}_points"] = PointsModel.parse( - sub_table, - coordinates={"x": CosmxKeys.X_LOCAL_TRANSCRIPT, "y": CosmxKeys.Y_LOCAL_TRANSCRIPT}, - feature_key=CosmxKeys.TARGET_OF_TRANSCRIPT, - instance_key=CosmxKeys.INSTANCE_KEY, - transformations={ - fov: Identity(), - "global": aff, - "global_only_labels": aff, - }, - ) - - # TODO: what to do with fov file? 
- # if fov_file is not None: - # fov_positions = pd.read_csv(path / fov_file, header=0, index_col=CosmxKeys.FOV) - # for fov, row in fov_positions.iterrows(): - # try: - # adata.uns["spatial"][str(fov)]["metadata"] = row.to_dict() - # except KeyError: - # logg.warning(f"FOV `{str(fov)}` does not exist, skipping it.") - # continue - - sdata = SpatialData(images=images, labels=labels, points=points, tables={"table": table}) - return _set_reader_metadata(sdata, "cosmx") diff --git a/src/spatialdata_io/readers/cosmx/__init__.py b/src/spatialdata_io/readers/cosmx/__init__.py new file mode 100644 index 00000000..438cb288 --- /dev/null +++ b/src/spatialdata_io/readers/cosmx/__init__.py @@ -0,0 +1,7 @@ +"""NanoString CosMx reader.""" + +from __future__ import annotations + +from ._reader import cosmx + +__all__ = ["cosmx"] diff --git a/src/spatialdata_io/readers/cosmx/_discovery.py b/src/spatialdata_io/readers/cosmx/_discovery.py new file mode 100644 index 00000000..265a2095 --- /dev/null +++ b/src/spatialdata_io/readers/cosmx/_discovery.py @@ -0,0 +1,376 @@ +"""File discovery and dataset-ID inference for CosMx data.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from spatialdata._logging import logger + +if TYPE_CHECKING: + from collections.abc import Iterable + from pathlib import Path + +_FLAT_FILE_TARGETS = {"exprMat_file", "fov_positions_file", "metadata_file", "tx_file", "polygons"} +_DIR_TARGETS = [ + "AnalysisResults", + "CellStatsDir", + "RunSummary", + "CellComposite", + "CellLabels", + "CellOverlay", + "CellType_Accessory_Data", + "CompartmentLabels", + "ProteinDir", + "Morphology2D", + "Morphology3D", +] +_SCAN_TARGETS = list(_FLAT_FILE_TARGETS) + _DIR_TARGETS + + +def _infer_dataset_id(path: Path, dataset_id: str | None) -> str | None: + """Infer the dataset ID from marker-file prefixes on disk. + + Handles multimodal layouts (e.g. ``S3RNA`` / ``S3Protein`` → ``S3``). + + Parameters + ---------- + path + Root directory to scan. 
+ dataset_id + Explicit ID override. Validated against what is found on disk. + + Returns + ------- + The inferred or validated dataset ID, or ``None`` if nothing was found. + """ + suffixes = [".csv", ".csv.gz", ".parquet"] + found_ids: list[str] = [] + + def _base_id(pid: str) -> str: + for suf in ("RNA", "Protein"): + if pid.endswith(suf) and len(pid) > len(suf): + return pid[: -len(suf)] + return pid + + for marker in _FLAT_FILE_TARGETS: + for suffix in suffixes: + for file in path.rglob(f"*{marker}{suffix}"): + stem = file.name + for s in suffixes: + if stem.endswith(s): + stem = stem.removesuffix(s) + if stem.endswith(f"_{marker}"): + prefix = stem.removesuffix(f"_{marker}") + elif stem.endswith(f"-{marker}"): + prefix = stem.removesuffix(f"-{marker}") + else: + continue + if prefix: + found_ids.append(prefix) + base = _base_id(prefix) + if base and base != prefix: + found_ids.append(base) + + unique_ids = list(dict.fromkeys(found_ids)) + unique_bases = list(dict.fromkeys(_base_id(x) for x in found_ids if x)) + + if dataset_id is not None: + if not unique_ids: + raise ValueError( + f"Provided dataset_id={dataset_id!r} but no CosMx marker files " + f"with an ID prefix were found under {path}." + ) + if dataset_id in unique_ids: + return dataset_id + if dataset_id in unique_bases and len(set(unique_bases)) == 1: + return dataset_id + raise ValueError(f"Provided dataset_id={dataset_id!r} not among inferred IDs {unique_ids}.") + + if not unique_ids: + logger.warning("Could not infer dataset_id from marker files under %s.", path) + return None + if len(unique_ids) == 1: + return unique_ids[0] + if len(set(unique_bases)) == 1 and unique_bases: + return unique_bases[0] + + raise ValueError(f"Found multiple possible dataset IDs {unique_ids}. Please specify dataset_id=... 
explicitly.") + + +# --------------------------------------------------------------------------- +# recursive file scanner +# --------------------------------------------------------------------------- + +_WHITELIST_DIRS = { + "flatFiles", + "Flatfiles_RNA", + "Flatfiles_Protein", + "images", + "AnalysisResults", + "AnalysisResults_RNA", + "AnalysisResults_Protein", + "CellStatsDir", + "RunSummary", + "CellComposite", + "CellLabels", + "CellOverlay", + "CellType_Accessory_Data", + "CompartmentLabels", + "ProteinImages", + "Morphology2D", + "Morphology3D", +} + + +def _scan_for_files_to_read( + path: Path, + targets: Iterable[str], + dataset_id: str | None = None, + max_depth: int = 4, +) -> dict[str, Path | None]: + """Recursively scan *path* for CosMx data files matching *targets*.""" + found: dict[str, Path | None] = dict.fromkeys(targets) + base = path.resolve() + + dataset_prefixes: set[str] = set() + if dataset_id: + dataset_prefixes = {dataset_id, f"{dataset_id}RNA", f"{dataset_id}Protein"} + + def _strip_all_extensions(p: Path) -> str: + q = p + while q.suffix: + q = q.with_suffix("") + return q.name + + def _matches_target(stem: str, target: str) -> bool: + return stem.endswith(f"_{target}") or stem.endswith(f"-{target}") or stem == target + + def _recurse(curr: Path, depth: int) -> None: + if depth > max_depth: + return + for child in curr.iterdir(): + if dataset_id is not None: + for t, val in found.items(): + if val is not None: + continue + wanted_names = { + f"{dataset_id}_{t}", + f"{dataset_id}-{t}", + f"{dataset_id}RNA_{t}", + f"{dataset_id}RNA-{t}", + f"{dataset_id}Protein_{t}", + f"{dataset_id}Protein-{t}", + t, + } + if child.is_dir() and child.name in wanted_names: + found[t] = child + break + if child.is_file(): + stem = _strip_all_extensions(child) + if stem in wanted_names: + found[t] = child + break + for t, val in found.items(): + if val is not None: + continue + if child.is_dir() and child.name == t: + found[t] = child + break + if 
child.is_file() and _matches_target(_strip_all_extensions(child), t): + found[t] = child + break + if child.is_dir() and any(v is None for v in found.values()): + if dataset_prefixes: + if child.name not in _WHITELIST_DIRS and not any( + child.name.startswith(pref) for pref in dataset_prefixes + ): + continue + _recurse(child, depth + 1) + if all(v is not None for v in found.values()): + return + + _recurse(base, 0) + return found + + +# --------------------------------------------------------------------------- +# multimodal detection +# --------------------------------------------------------------------------- + + +def _modality_from_ancestors(file: Path, root: Path) -> str | None: + """Infer a modality name (``RNA`` / ``Protein``) from ancestor directory names.""" + try: + rel = file.relative_to(root) + except ValueError: + return None + for part in rel.parts[:-1]: + upper = part.upper() + if "RNA" in upper and "PROTEIN" not in upper: + return "RNA" + if "PROTEIN" in upper: + return "Protein" + return None + + +def _discover_modalities(path: Path, base_id: str | None) -> dict[str, dict[str, Path]]: + """Detect multiple modality prefixes (e.g. ``S3RNA`` / ``S3Protein``). + + Handles two conventions: + + * **Prefix-based** (V2): filenames encode the modality + (``S3RNA_exprMat_file`` vs ``S3Protein_exprMat_file``). + * **Directory-based** (breast multiomics): filenames share the same + prefix but live under modality-specific directories + (``Flatfiles_RNA/…/BreastCancer_exprMat_file`` vs + ``Flatfiles_Protein/…/BreastCancer_exprMat_file``). + + For directory-based layouts the returned dict includes a ``flat_root`` + key pointing to the directory that contains the flat files for that + modality so the scanner can be scoped correctly. 
+ """ + modalities: dict[str, dict[str, Path]] = {} + expr_files = list(path.rglob("*exprMat_file.csv*")) + for f in expr_files: + stem = f.name + for suf in (".csv.gz", ".csv", ".gz"): + if stem.endswith(suf): + stem = stem.removesuffix(suf) + if stem.endswith("_exprMat_file"): + prefix = stem.removesuffix("_exprMat_file") + elif stem.endswith("-exprMat_file"): + prefix = stem.removesuffix("-exprMat_file") + else: + continue + if not prefix: + continue + if base_id and not prefix.startswith(base_id): + continue + mod = prefix + if base_id and prefix.startswith(base_id): + mod = prefix[len(base_id) :].lstrip("_-") or base_id + # No modality suffix in the filename — try the directory tree. + flat_root = None + if mod == base_id: + dir_mod = _modality_from_ancestors(f, path) + if dir_mod: + mod = dir_mod + flat_root = f.parent + if mod not in modalities: + info: dict[str, Path] = {"prefix": prefix, "exprMat_file": f} + if flat_root is not None: + info["flat_root"] = flat_root + modalities[mod] = info + return modalities + + +# --------------------------------------------------------------------------- +# high-level dataset setup +# --------------------------------------------------------------------------- + + +def _label_dir_fallback( + standalone: Path | None, + cell_stats_dir: Path | None, + prefix: str, +) -> Path | None: + """Return the standalone label directory, or fall back to *cell_stats_dir*. + + Some datasets (e.g. breast multiomics) package ``CellLabels_F*.tif`` and + ``CompartmentLabels_F*.tif`` inside per-FOV subdirectories of + ``CellStatsDir/`` instead of separate top-level directories. When the + standalone directory is absent we check whether the per-FOV TIFFs exist + inside *cell_stats_dir* and, if so, return it as the label root. The + reader uses ``rglob`` so it finds the TIFFs in either layout. + """ + if standalone is not None: + return standalone + if cell_stats_dir is None: + return None + # Quick check: at least one per-FOV TIF present? 
+ if any(cell_stats_dir.rglob(f"{prefix}_F*.tif")): + return cell_stats_dir + return None + + +def _set_up_cosmx_dataset_for_conversion( + path: Path, + dataset_id: str | None = None, +): + """Build a :class:`CosMxDataset` descriptor from a directory.""" + # import here to avoid circular dependency + from ._reader import CosMxDataset + + path = path.resolve() + inferred_id = _infer_dataset_id(path, dataset_id) + modal_map = _discover_modalities(path, inferred_id) + + def _build_dataset( + prefix: str | None, + flat_root: Path | None = None, + modality: str | None = None, + ) -> CosMxDataset: + if flat_root is not None: + # Directory-based multimodal: scan flat files from the modality + # subtree, shared directory targets from the dataset root. + flat_files = _scan_for_files_to_read( + path=flat_root, + targets=_FLAT_FILE_TARGETS, + dataset_id=prefix, + ) + dir_files = _scan_for_files_to_read( + path=path, + targets=_DIR_TARGETS, + dataset_id=prefix, + ) + files = {**flat_files, **dir_files} + else: + files = _scan_for_files_to_read(path=path, targets=_SCAN_TARGETS, dataset_id=prefix) + # Fall back to modality-suffixed AnalysisResults directory + # (e.g. AnalysisResults_RNA, AnalysisResults_Protein). 
+ analysis_dir = files.get("AnalysisResults") + if analysis_dir is None and modality is not None: + candidate = path / f"AnalysisResults_{modality}" + if candidate.is_dir(): + analysis_dir = candidate + return CosMxDataset( + path=path, + dataset_id=prefix, + exprMat_file=files.get("exprMat_file"), + fov_positions_file=files.get("fov_positions_file"), + metadata_file=files.get("metadata_file"), + tx_file=files.get("tx_file"), + polygons_file=files.get("polygons"), + analysis_results_dir=analysis_dir, + cell_stats_dir=files.get("CellStatsDir"), + run_summary_dir=files.get("RunSummary"), + cell_composite_dir=files.get("CellComposite"), + cell_labels_dir=_label_dir_fallback( + files.get("CellLabels"), + files.get("CellStatsDir"), + "CellLabels", + ), + cell_overlay_dir=files.get("CellOverlay"), + celltype_accessory_data=files.get("CellType_Accessory_Data"), + compartment_labels_dir=_label_dir_fallback( + files.get("CompartmentLabels"), + files.get("CellStatsDir"), + "CompartmentLabels", + ), + protein_dir=files.get("ProteinDir"), + morphology_2d_dir=files.get("Morphology2D"), + morphology_3d_dir=files.get("Morphology3D"), + ) + + if len(modal_map) > 1: + children = { + mod_name: _build_dataset( + info["prefix"], + flat_root=info.get("flat_root"), + modality=mod_name, + ) + for mod_name, info in modal_map.items() + } + return CosMxDataset(path=path, dataset_id=inferred_id, modalities=children) + + return _build_dataset(inferred_id) diff --git a/src/spatialdata_io/readers/cosmx/_io.py b/src/spatialdata_io/readers/cosmx/_io.py new file mode 100644 index 00000000..be586978 --- /dev/null +++ b/src/spatialdata_io/readers/cosmx/_io.py @@ -0,0 +1,714 @@ +"""CSV / TIFF / Parquet I/O helpers for the CosMx reader.""" + +from __future__ import annotations + +import csv +import gzip +import math +import re +import shutil +import subprocess +from typing import TYPE_CHECKING, Any + +import dask.array as da +import pandas as pd +import polars as pl +import pyarrow as pa +import 
pyarrow.csv as pacsv +import pyarrow.parquet as pq +import scipy.sparse +import shapely.geometry as sgeom +import tifffile +from anndata import AnnData +from anndata.utils import make_index_unique +from dask_image.imread import imread +from spatialdata._logging import logger +from tqdm import tqdm + +from ._utils import ( + _match_header, + _pandas_categoricals_to_string, + _to_float01_dtype_max, +) + +if TYPE_CHECKING: + import io + from pathlib import Path + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +COSMX_PIXEL_SIZE = 0.120280945 +MM_TO_PX = 1000.0 / COSMX_PIXEL_SIZE +COSMX_FOV_SIZE_PX = 4256.0 + + +# --------------------------------------------------------------------------- +# default image kwargs +# --------------------------------------------------------------------------- + + +def _default_image_kwargs( + image_models_kwargs: dict[str, Any] | None = None, + imread_kwargs: dict[str, Any] | None = None, +) -> tuple[dict[str, Any], dict[str, Any]]: + image_models_kwargs = {} if image_models_kwargs is None else image_models_kwargs + imread_kwargs = {} if imread_kwargs is None else imread_kwargs + + if "chunks" not in image_models_kwargs: + image_models_kwargs["chunks"] = (1, 1024, 1024) + if "scale_factors" not in image_models_kwargs: + image_models_kwargs["scale_factors"] = [2, 2, 2, 2] + + return image_models_kwargs, imread_kwargs + + +# --------------------------------------------------------------------------- +# FOV positions +# --------------------------------------------------------------------------- + + +def _read_fov_locs( + csv_path: Path, + *, + fovs: list[int] | None = None, +) -> pd.DataFrame: + df = pd.read_csv(csv_path) + + # 1) find FOV column + cols_lower = {c.lower(): c for c in df.columns} + fov_col: str | None = None + for cand in ("fov", "fov_id", "fov_idx", "field_of_view", "roi", "order"): + if cand in 
cols_lower: + fov_col = cols_lower[cand] + break + if fov_col is None: + raise ValueError(f"{csv_path.name}: cannot identify FOV column in {list(df.columns)!r}") + + # 2) detect orientation source BEFORE we synthesize px from mm + raw_cols = [c.lower() for c in df.columns] + file_has_px = any("px" in c for c in raw_cols) + file_has_mm = any("mm" in c for c in raw_cols) + needs_extra_flip = not (file_has_px and file_has_mm) + + # 3) try pixel columns first + def _find_px(colnames: list[str]) -> tuple[str | None, str | None]: + x_px_col = None + y_px_col = None + for c in colnames: + lc = c.lower() + if "px" in lc: + if lc.startswith("x") or "x_" in lc: + x_px_col = c + elif lc.startswith("y") or "y_" in lc: + y_px_col = c + return x_px_col, y_px_col + + x_px_col, y_px_col = _find_px(list(df.columns)) + + # 4) if no px -> look for mm and convert + if x_px_col is None or y_px_col is None: + x_mm_col = None + y_mm_col = None + for c in df.columns: + lc = c.lower() + if lc in {"x_mm", "x_global_mm"} or lc.endswith("_x_mm"): + x_mm_col = c + elif lc in {"y_mm", "y_global_mm"} or lc.endswith("_y_mm"): + y_mm_col = c + if x_mm_col is None and "X_mm" in df.columns: + x_mm_col = "X_mm" + if y_mm_col is None and "Y_mm" in df.columns: + y_mm_col = "Y_mm" + + if x_mm_col is None or y_mm_col is None: + raise ValueError( + f"{csv_path.name}: neither pixel nor mm coordinate columns found; got columns {list(df.columns)!r}" + ) + + df["x_global_px"] = df[x_mm_col].astype(float) * MM_TO_PX + df["y_global_px"] = df[y_mm_col].astype(float) * MM_TO_PX + x_px_col = "x_global_px" + y_px_col = "y_global_px" + + # 5) normalize names + fov_ser = pd.to_numeric(df[fov_col], errors="coerce").astype("Int64") + if fov_ser.isna().any(): + raise ValueError(f"{csv_path.name}: FOV column {fov_col!r} contains NaNs / non-numeric values.") + df["fov"] = fov_ser.astype(int) + + df = df.rename( + columns={ + x_px_col: "xmin", + y_px_col: "ymin", + } + ) + + # 6) add width/height and the per-FOV flip flag 
+ df["xmax"] = df["xmin"].astype(float) + COSMX_FOV_SIZE_PX + df["ymax"] = df["ymin"].astype(float) + COSMX_FOV_SIZE_PX + df["flip_y"] = bool(needs_extra_flip) + + # 7) index + optional subset + df = df.set_index("fov", verify_integrity=True).sort_index() + + if fovs is not None: + fovs_sorted = sorted(int(f) for f in fovs) + missing = set(fovs_sorted) - set(df.index) + if missing: + raise KeyError(f"Requested FOVs not found in positions file {csv_path.name}: {sorted(missing)}") + df = df.loc[fovs_sorted] + + return df + + +# --------------------------------------------------------------------------- +# image helpers +# --------------------------------------------------------------------------- + + +def _get_cosmx_morphology_coords(images_dir: Path) -> list[str]: + images_paths = list(images_dir.glob("*.TIF")) + if len(images_paths) == 0: + raise FileNotFoundError(f"Expected to find images inside {images_dir}") + + with tifffile.TiffFile(images_paths[0]) as tif: + description = tif.pages[0].description + substrings = re.findall(r'"BiologicalTarget": "(.*?)",', description) + channels = re.findall(r'"ChannelId": "(.*?)",', description) + channel_order = list(re.findall(r'"ChannelOrder": "(.*?)",', description)[0]) + return [substrings[channels.index(x)] if x in channels else x for x in channel_order] + + +def _get_cosmx_protein_name(image_path: Path) -> str: + with tifffile.TiffFile(image_path) as tif: + description = tif.pages[0].description + substrings = re.findall(r'"DisplayName": "(.*?)",', description) + return substrings[0].replace("/", ".") + + +def _read_protein_fov(protein_dir: Path) -> tuple[da.Array, list[str]]: + images_paths = list(protein_dir.rglob("*.TIF")) + protein_imgs = [imread(image_path) for image_path in images_paths] + protein_imgs = [_to_float01_dtype_max(img) for img in protein_imgs] + protein_image = da.concatenate(protein_imgs, axis=0) + channel_names = [_get_cosmx_protein_name(image_path) for image_path in images_paths] + return 
protein_image, channel_names + + +def _find_matching_fov_file(images_dir: Path, fov: int) -> Path: + pattern = re.compile(rf".*_F0*{fov}\.TIF") + fov_files = [file for file in images_dir.rglob("*") if pattern.match(file.name)] + if len(fov_files) == 0: + raise FileNotFoundError(f"No file matches the pattern {pattern} inside {images_dir}") + if len(fov_files) != 1: + raise ValueError(f"Multiple files match the pattern {pattern}: {', '.join(map(str, fov_files))}") + return fov_files[0] + + +def _read_fov_image( + morphology_path: Path, + protein_path: Path | None, + morphology_coords: list[str], + *, + selected_channels: list[str] | None = None, + **imread_kwargs: Any, +) -> tuple[da.Array, list[str]]: + image = imread(morphology_path, **imread_kwargs) + image = _to_float01_dtype_max(image) + + protein_names: list[str] = [] + if protein_path is not None: + protein_image, protein_names = _read_protein_fov(protein_path) + image = da.concatenate([image, protein_image], axis=0) + + all_names = make_index_unique(pd.Index(morphology_coords + protein_names)).tolist() + + if selected_channels is not None: + name_to_idx = {n: i for i, n in enumerate(all_names)} + present = [c for c in selected_channels if c in name_to_idx] + missing_here = [c for c in selected_channels if c not in name_to_idx] + if missing_here: + logger.warning( + "FOV %s: skipping %d missing requested channel(s): %s. Present: %s", + morphology_path.name, + len(missing_here), + missing_here, + all_names, + ) + if not present: + raise ValueError( + f"No requested channels present in file {morphology_path.name}. " + f"Requested {selected_channels}, available {all_names}." 
+ ) + idx = [name_to_idx[c] for c in present] + image = image[idx, :, :] + all_names = present + + return image, all_names + + +# --------------------------------------------------------------------------- +# expr / metadata readers (polars) +# --------------------------------------------------------------------------- + + +def _read_expr_mat_polars( + expr_path: Path, + n_rows: int | None = None, + fovs: set[int] | None = None, +) -> tuple[AnnData, pd.DataFrame | None]: + """Read expression matrix, filtering out background (cell_ID=0). + + Returns + ------- + tuple of (AnnData, bg_df or None) + The cell expression table and, if any cell_ID=0 rows were present, + a DataFrame of per-FOV background signal indexed by ``fov``. + """ + sample = pl.read_csv(expr_path, n_rows=1) + expr_cols = [c for c in sample.columns if c not in ("fov", "cell_ID")] + + lf = pl.scan_csv(expr_path, n_rows=n_rows) + if fovs is not None: + lf = lf.filter(pl.col("fov").is_in(fovs)) + + lf = lf.with_columns( + (pl.col("fov").cast(pl.Utf8) + "_" + pl.col("cell_ID").cast(pl.Utf8)).alias("fov_cellID") + ).select(["fov", "cell_ID", "fov_cellID"] + expr_cols) + pdf = lf.collect().to_pandas() + + # Separate background rows (cell_ID=0) before building the AnnData. 
def _read_metadata_polars(
    meta_path: Path,
    n_rows: int | None = None,
) -> pd.DataFrame:
    """Load the cell metadata CSV as a pandas frame indexed by ``fov_cellID``.

    ``fov_cellID`` is the string ``"<fov>_<cell_ID>"``. Rows with
    ``cell_ID == 0`` are dropped. The index is checked for uniqueness, so
    duplicate ``(fov, cell_ID)`` pairs raise.

    Parameters
    ----------
    meta_path
        Path of the metadata CSV.
    n_rows
        Optional cap on the number of rows read; the same value is passed as
        the schema-inference length.
    """
    fov_as_text = pl.col("fov").cast(pl.Utf8)
    cell_as_text = pl.col("cell_ID").cast(pl.Utf8)
    composite_key = (fov_as_text + "_" + cell_as_text).alias("fov_cellID")

    frame = pl.read_csv(
        meta_path,
        n_rows=n_rows,
        infer_schema_length=n_rows,
    )
    frame = frame.with_columns(composite_key)
    frame = frame.filter(pl.col("cell_ID") != 0)

    return frame.to_pandas().set_index("fov_cellID", verify_integrity=True)
This mirrors the + placement :func:`_read_polygons_csv` applies to polygon vertices, so + transcripts co-register with polygons by construction (both share the + same per-FOV local coordinate system). + + The flip height is the fixed CosMx FOV size, not ``fov_locs['ymax']`` — + image stitching may overwrite ``ymax`` with the actual raster height, + which must not perturb coordinate placement. + + Looked up per row via ``Series.map`` (no join/shuffle), so it works for + both pandas and Dask DataFrames. *df* must contain *fov_col*, + *x_local_col* and *y_local_col*. + """ + x0 = fov_locs["xmin"].astype(float).to_dict() + y0 = fov_locs["ymin"].astype(float) + flip = fov_locs["flip_y"].astype(bool) if "flip_y" in fov_locs.columns else pd.Series(False, index=fov_locs.index) + # Express the y flip as a per-FOV affine (offset + sign * y_local) so the + # placement is plain arithmetic — no boolean ``Series.where``, whose + # condition is fragile on Dask (``Series.map`` of a dict yields object + # dtype there, which ``where`` rejects): + # flip: y = y0 + y_local -> offset = y0, sign = +1 + # noflip: y = y0 + (FOV - y_local) -> offset = y0 + FOV, sign = -1 + y_off = (y0 + (~flip) * COSMX_FOV_SIZE_PX).to_dict() + y_sgn = flip.map({True: 1.0, False: -1.0}).to_dict() + + # ``.astype(float)`` is required: on Dask, ``Series.map(dict)`` yields object + # dtype, which would propagate through the arithmetic below. 
def _read_polygons_csv(
    csv_path: Path,
    *,
    fov_locs: pd.DataFrame,
    use_polars: bool = True,
    n_workers: int | None = None,
    fov_set: set[int] | None = None,
) -> pd.DataFrame:
    """Read a polygons CSV and build one global-pixel geometry per cell.

    Vertices are grouped per ``(fov, cell_ID, polygon_index)``, shifted from
    per-FOV local pixels to the stitched grid using ``fov_locs``, and turned
    into shapely polygons. Cells with several polygons are merged into a
    single ``MultiPolygon`` row.

    Parameters
    ----------
    csv_path
        Polygons CSV, optionally gzip-compressed (``.gz`` suffix).
    fov_locs
        Per-FOV table indexed by FOV id with ``xmin`` / ``ymin`` columns and,
        optionally, ``ymax`` and ``flip_y``.
    use_polars
        Group vertices with polars (default) instead of pandas.
    n_workers
        Accepted for interface compatibility; not used by this function.
    fov_set
        If given, only polygons from these FOVs are kept.

    Returns
    -------
    pandas.DataFrame
        Always contains ``fov``, ``cell_ID``, ``geometry``, ``bounds_max_x``
        and ``bounds_max_y``. When no cell needed merging, the frame also
        keeps ``polygon_index``, ``vx_local`` and ``vy_local``.

    Raises
    ------
    ValueError
        If ``fov_locs`` is ``None``, the file has no header, or the required
        ``fov`` / ``cell_ID`` / local coordinate columns are missing.
    KeyError
        If a polygon references a FOV that is absent from ``fov_locs``.
    """
    if fov_locs is None:
        raise ValueError("fov_locs is required to globalize polygon coordinates.")

    opener = gzip.open if csv_path.suffix == ".gz" else open
    with opener(csv_path, "rt") as fh:
        raw_hdr = next(csv.reader(fh), None)
    if not raw_hdr:
        raise ValueError(f"{csv_path.name}: file is empty, no header row found.")
    rename = _match_header(raw_hdr)

    xl = next((c for c in raw_hdr if re.fullmatch(r"x_local_px", c, flags=re.I)), None)
    yl = next((c for c in raw_hdr if re.fullmatch(r"y_local_px", c, flags=re.I)), None)
    if xl is None or yl is None:
        xl = next((c for c in raw_hdr if c.strip().lower() in {"x", "x_px"}), None)
        yl = next((c for c in raw_hdr if c.strip().lower() in {"y", "y_px"}), None)
    if xl is None or yl is None:
        raise ValueError(f"{csv_path.name}: missing x_local_px / y_local_px columns (or acceptable fallback).")

    core_cols = {orig: canon for orig, canon in rename.items() if canon in {"fov", "cell_ID", "polygon_index"}}
    missing_keys = {"fov", "cell_ID"} - set(core_cols.values())
    if missing_keys:
        raise ValueError(f"{csv_path.name}: missing required column(s): {', '.join(sorted(missing_keys))}.")
    has_polygon_index = "polygon_index" in core_cols.values()

    # Without an explicit polygon index every cell is assumed to own exactly
    # one polygon, so all of its vertices share index 0. (A per-row counter
    # here would put each vertex in its own group and yield only degenerate
    # one-point polygons.)
    if use_polars:
        select_cols = [pl.col(orig).alias(canon) for orig, canon in core_cols.items()] + [
            pl.col(xl).alias("x_local"),
            pl.col(yl).alias("y_local"),
        ]
        lf = pl.scan_csv(csv_path).select(select_cols)
        if not has_polygon_index:
            lf = lf.with_columns(pl.lit(0).alias("polygon_index"))
        pdf = (
            # maintain_order keeps the output row order deterministic.
            lf.group_by(["fov", "cell_ID", "polygon_index"], maintain_order=True)
            .agg([pl.col("x_local").alias("vx_local"), pl.col("y_local").alias("vy_local")])
            .collect()
            .to_pandas()
        )
    else:
        usecols = list(core_cols.keys()) + [xl, yl]
        raw = pd.read_csv(
            csv_path,
            usecols=usecols,
            compression="gzip" if csv_path.suffix == ".gz" else None,
        ).rename(columns={**core_cols, xl: "x_local", yl: "y_local"})
        if not has_polygon_index:
            raw["polygon_index"] = 0
        pdf = (
            raw.groupby(["fov", "cell_ID", "polygon_index"])
            .agg({"x_local": list, "y_local": list})
            .reset_index()
            .rename(columns={"x_local": "vx_local", "y_local": "vy_local"})
        )

    if fov_set is not None:
        pdf = pdf[pdf["fov"].isin(fov_set)].reset_index(drop=True)

    x_off = fov_locs["xmin"].astype(float).to_dict()
    y_off = fov_locs["ymin"].astype(float).to_dict()
    y_max = fov_locs["ymax"].astype(float).to_dict() if "ymax" in fov_locs.columns else None
    flip_map = fov_locs["flip_y"].astype(bool).to_dict() if "flip_y" in fov_locs.columns else None

    def _to_float(value: object) -> float | None:
        """Parse *value* as a float; ``None`` when it is missing or not numeric."""
        try:
            return float(value)  # type: ignore[arg-type]
        except (TypeError, ValueError):
            return None

    # Largest finite local y seen per FOV; used as a fallback flip height.
    local_max_y: dict[int, float] = {}
    for fov, vy_local in zip(pdf["fov"], pdf["vy_local"], strict=False):
        finite = [v for v in map(_to_float, vy_local) if v is not None and math.isfinite(v)]
        if finite:
            local_max_y[fov] = max(local_max_y.get(fov, -math.inf), max(finite))

    geoms: list[sgeom.Polygon | None] = []
    bad = 0

    for fov, vx_local, vy_local in zip(pdf["fov"], pdf["vx_local"], pdf["vy_local"], strict=False):
        if fov not in x_off:
            raise KeyError(f"FOV {fov} in polygons has no entry in fov_locs")

        x0 = x_off[fov]
        y0 = y_off[fov]

        # NOTE(review): the flip height prefers ``ymax - ymin`` from fov_locs,
        # whereas ``place_local_in_fov_grid`` always uses COSMX_FOV_SIZE_PX.
        # Confirm the two placements agree when ``ymax`` has been overwritten.
        if y_max is not None and not math.isnan(y_max[fov]) and y_max[fov] > y0:
            h = y_max[fov] - y0
        else:
            h = max(COSMX_FOV_SIZE_PX, local_max_y.get(fov, 0.0))

        extra_flip = bool(flip_map.get(fov, False)) if flip_map is not None else False

        pts: list[tuple[float, float]] = []
        for xl_, yl_ in zip(vx_local, vy_local, strict=False):
            x_local = _to_float(xl_)
            y_local = _to_float(yl_)
            if x_local is None or y_local is None:
                # Unparsable vertex: skip it; the polygon is rejected below
                # if too few distinct vertices remain.
                continue
            x = x0 + x_local
            y = y0 + y_local if extra_flip else y0 + (h - y_local)
            if math.isfinite(x) and math.isfinite(y):
                pts.append((x, y))

        # A polygon needs at least three distinct vertices.
        if len({(round(x, 6), round(y, 6)) for x, y in pts}) < 3:
            geoms.append(None)
            bad += 1
            continue

        if pts[0] != pts[-1]:
            pts.append(pts[0])  # close the ring

        poly = sgeom.Polygon(pts)
        if not poly.is_valid or poly.is_empty:
            geoms.append(None)
            bad += 1
        else:
            geoms.append(poly)

    pdf["geometry"] = geoms
    if bad:
        logger.warning("%s: skipped %d malformed polygon(s).", csv_path.name, bad)
    pdf = pdf.dropna(subset=["geometry"]).reset_index(drop=True)

    # Merge multi-polygon cells: a cell with multiple polygon_index values
    # must become a single MultiPolygon row to avoid duplicate global_cell_id.
    if (pdf.groupby(["fov", "cell_ID"]).size() > 1).any():

        def _merge_geoms(geoms):
            geoms = [g for g in geoms if g is not None]
            if len(geoms) == 1:
                return geoms[0]
            return sgeom.MultiPolygon(geoms)

        pdf = pdf.groupby(["fov", "cell_ID"], sort=False).agg({"geometry": _merge_geoms}).reset_index()

    pdf["bounds_max_x"] = pdf.geometry.map(lambda g: g.bounds[2])
    pdf["bounds_max_y"] = pdf.geometry.map(lambda g: g.bounds[3])

    return pdf
def _count_csv_rows(src: Path, *, n_workers: int | None = None) -> int:
    """Return the number of data rows (lines minus the header) in *src*.

    ``.gz`` files are decompressed with ``pigz`` when it is on ``PATH`` and
    with :mod:`gzip` otherwise; other files are read directly. Lines are
    counted from the raw byte stream, and a final line without a trailing
    newline is counted too.

    Parameters
    ----------
    src
        CSV file, optionally gzip-compressed (``.gz`` suffix).
    n_workers
        Thread count passed to ``pigz`` (``-p``); defaults to 1.

    Raises
    ------
    subprocess.CalledProcessError
        If ``pigz`` exits with a non-zero status.
    """

    def _count_lines(stream) -> int:
        """Count lines in a binary stream, including an unterminated last line."""
        lines = 0
        last_byte = b"\n"  # an empty stream must count as zero lines
        while chunk := stream.read(2**20):
            lines += chunk.count(b"\n")
            last_byte = chunk[-1:]
        if last_byte != b"\n":
            lines += 1
        return lines

    if src.suffix == ".gz" and shutil.which("pigz"):
        # Argument list, no shell: paths with spaces or shell metacharacters
        # are passed through verbatim, and pigz's own exit status is visible.
        cmd = ["pigz", "-dc", f"-p{n_workers or 1}", str(src)]
        with subprocess.Popen(cmd, stdout=subprocess.PIPE) as proc:
            lines = _count_lines(proc.stdout)
        if proc.returncode != 0:
            raise subprocess.CalledProcessError(proc.returncode, cmd)
        return max(lines - 1, 0)

    opener = gzip.open if src.suffix == ".gz" else open
    with opener(src, "rb") as fh:
        return max(_count_lines(fh) - 1, 0)
+ """ + if path.is_file(): + return path + + if dataset_id is None: + raise FileNotFoundError("Transcript path is a directory but no dataset_id was provided.") + + gz_path = path / f"{dataset_id}_tx_file.csv.gz" + csv_path = path / f"{dataset_id}_tx_file.csv" + src = gz_path if gz_path.exists() else csv_path + if not src.exists(): + raise FileNotFoundError(f"No *_tx_file.* for dataset {dataset_id!r} under {path}") + return src + + +def _parquet_cache_for_tx( + path: Path, + dataset_id: str | None, + *, + n_workers: int | None = None, + row_group_rows: int = 5_000_000, +) -> Path: + src = _resolve_tx_path(path, dataset_id) + parquet_path = src.with_suffix(".parquet") + + if parquet_path.exists() and _is_valid_parquet(parquet_path): + logger.info("[transcripts] Found Parquet cache — skipping conversion.") + return parquet_path + + if parquet_path.exists(): + logger.warning("[transcripts] Found corrupted Parquet cache — rebuilding.") + parquet_path.unlink() + + size_gb = src.stat().st_size / 1_073_741_824 + logger.warning( + "[transcripts] Converting %s (%.1f GB) to Parquet for faster reads. 
def _maybe_warn_big_file(src: Path, *, threshold_gb: float = 1.0) -> bool:
    """Return ``True`` when *src* is larger on disk than *threshold_gb* GiB.

    Despite its name, this function does not log anything; it only reports
    whether the file exceeds the threshold, and the caller decides what to do.

    Parameters
    ----------
    src
        File whose on-disk size is checked.
    threshold_gb
        Size limit in GiB (1 GiB = 1_073_741_824 bytes); the comparison is strict.
    """
    size_gb = src.stat().st_size / 1_073_741_824
    return size_gb > threshold_gb
Missing: {missing}") + + _pandas_categoricals_to_string(df) + + fov_ser = pd.to_numeric(df["fov"], errors="coerce").fillna(0).astype(int) + cell_ser = pd.to_numeric(df["cell_ID"], errors="coerce").fillna(0).astype(int) + df["fov"] = fov_ser + df["cell_ID"] = cell_ser + + max_cell = int(cell_ser.max()) if len(cell_ser) else 0 + df["unique_cell_id"] = fov_ser * (max_cell + 1) * (cell_ser > 0).astype(int) + cell_ser + + return df diff --git a/src/spatialdata_io/readers/cosmx/_reader.py b/src/spatialdata_io/readers/cosmx/_reader.py new file mode 100644 index 00000000..96a01db5 --- /dev/null +++ b/src/spatialdata_io/readers/cosmx/_reader.py @@ -0,0 +1,1512 @@ +"""CosMx reader for spatialdata. + +Reads NanoString CosMx spatial transcriptomics data into a +:class:`spatialdata.SpatialData` object. Supports all known CosMx export +formats (AtomX flat files, nested CellStatsDir layouts, multimodal +RNA+Protein runs) and transparently handles stitching, polygon +rasterization, and coordinate normalization. + +Public API +---------- +cosmx + Top-level reader function (single or multimodal datasets). +CosMxDataset + Frozen dataclass describing the on-disk layout of a CosMx dataset. +CosMxDatasetReader + Stateful reader that orchestrates per-element I/O. 
+""" + +from __future__ import annotations + +import math +import re +import textwrap +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import dask.array as da +import dask.dataframe as dd +import geopandas as gpd +import numpy as np +import pandas as pd +import shapely.geometry as sgeom +import tifffile +from shapely.affinity import translate as _translate +from spatialdata import SpatialData, sanitize_table +from spatialdata._logging import logger +from spatialdata.models import ( + Image2DModel, + Labels2DModel, + PointsModel, + ShapesModel, + TableModel, +) +from spatialdata.transformations import Translation + +from spatialdata_io._constants._constants import CosmxKeys +from spatialdata_io._docs import inject_docs +from spatialdata_io.readers._utils._utils import _set_reader_metadata + +from ._discovery import ( + _infer_dataset_id, + _set_up_cosmx_dataset_for_conversion, +) +from ._io import ( + COSMX_FOV_SIZE_PX, + _default_image_kwargs, + _find_matching_fov_file, + _get_cosmx_morphology_coords, + _maybe_warn_big_file, + _parquet_cache_for_tx, + _read_expr_mat_polars, + _read_fov_image, + _read_fov_locs, + _read_metadata_polars, + _read_polygons_csv, + _read_transcripts_csv, + place_local_in_fov_grid, +) +from ._stitching import ( + _canvas_from_fov_locs_for_polygons, + _plot_fov_preview, + _polygons_to_label_raster, + _read_stitched_cell_labels_from_dir, + _read_stitched_image, + stitch_segmentation_label_image, +) +from ._utils import ( + _dask_categoricals_to_string, + _match_canonical, + _normalize_image_channels, + _pandas_categoricals_to_string, + detect_fovs_with_data, +) + +if TYPE_CHECKING: + from anndata import AnnData + +__all__ = ["cosmx"] + +# --------------------------------------------------------------------------- +# Dataset descriptor +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True, kw_only=True) +class CosMxDataset: + 
"""Frozen descriptor of a CosMx dataset directory layout. + + Automatically infers ``dataset_id`` from marker-file prefixes on disk. + For multimodal runs (e.g. S3RNA + S3Protein), the ``modalities`` field + maps each modality name to its own ``CosMxDataset``. + """ + + path: Path = field(default=Path("."), metadata={"help": "Root directory."}) + dataset_id: str | None = field(default=None) + exprMat_file: Path | None = field(default=None) + fov_positions_file: Path | None = field(default=None) + metadata_file: Path | None = field(default=None) + tx_file: Path | None = field(default=None) + polygons_file: Path | None = field(default=None) + analysis_results_dir: Path | None = field(default=None) + cell_stats_dir: Path | None = field(default=None) + run_summary_dir: Path | None = field(default=None) + cell_composite_dir: Path | None = field(default=None) + cell_labels_dir: Path | None = field(default=None) + cell_overlay_dir: Path | None = field(default=None) + celltype_accessory_data: Path | None = field(default=None) + compartment_labels_dir: Path | None = field(default=None) + protein_dir: Path | None = field(default=None) + morphology_2d_dir: Path | None = field(default=None) + morphology_3d_dir: Path | None = field(default=None) + modalities: dict[str, CosMxDataset] | None = field(default=None) + + def __post_init__(self) -> None: + object.__setattr__(self, "path", self.path.resolve()) + if self.modalities: + if not isinstance(self.modalities, dict) or not self.modalities: + raise ValueError("modalities must be a non-empty dict.") + return + # At least one data file must be present. 
def _element_name(base: str, fovs: set[int] | None, *, max_listed: int = 8) -> str:
    """Derive an element name from *base* and the selected FOVs.

    - no FOV selection: ``base``
    - exactly one FOV: ``F<fov:05d>_<base>``
    - up to *max_listed* FOVs: ``subset_F…_F…_<base>`` (FOVs in ascending order)
    - more than *max_listed* FOVs: ``base``
    """
    if not fovs:
        return base

    tags = [f"F{fov_id:05d}" for fov_id in sorted(fovs)]
    if len(tags) == 1:
        return f"{tags[0]}_{base}"
    if len(tags) > max_listed:
        return base
    return "subset_" + "_".join(tags) + f"_{base}"
int((~mask).sum()), source_label) + adata = table.table if hasattr(table, "table") else table + ad2 = adata[mask].copy() + region_val = ad2.obs["region_key"].iloc[0] if "region_key" in ad2.obs else None + filtered[name] = _parse_cell_table(ad2, region=region_val, region_key="region_key" if region_val else None) + return filtered + + +# --------------------------------------------------------------------------- +# CosMxDatasetReader +# --------------------------------------------------------------------------- + + +class CosMxDatasetReader: + """Stateful reader that orchestrates reading of all CosMx elements. + + Manages FOV subsetting, coordinate origin normalization, + global_cell_id computation, and lazy polygon caching. + + Parameters + ---------- + dataset + A :class:`CosMxDataset` describing the files on disk. + fovs + Optional set of FOV IDs to read (``None`` = all). + n_workers + Number of parallel workers for dask/zarr operations. + flip_image + Whether to flip images vertically. + polygons_as_labels + If ``True``, rasterize polygons into a label image. + fov_locs + Pre-parsed FOV positions (skips re-reading the CSV). + align_rasters_to_polygons + If ``True``, anchor images/labels to the polygon coordinate origin. + keep_polygons_after_rasterize + If ``True``, keep vector polygons even when rasterizing to labels. 
+ """ + + def __init__( + self, + dataset: CosMxDataset, + *, + fovs: set[int] | None = None, + n_workers: int | None = None, + flip_image: bool | None = None, + polygons_as_labels: bool = False, + fov_locs: pd.DataFrame | None = None, + align_rasters_to_polygons: bool | None = None, + keep_polygons_after_rasterize: bool = False, + image_normalization_percentile: float | None = None, + ) -> None: + self.dataset = dataset + self.fovs = fovs + self.n_workers = n_workers or 1 + self.flip_image = bool(flip_image) if flip_image is not None else False + self.polygons_as_labels = polygons_as_labels + self.keep_polygons_after_rasterize = keep_polygons_after_rasterize + if image_normalization_percentile is not None: + if not 0.0 <= image_normalization_percentile <= 100.0: + raise ValueError( + f"image_normalization_percentile must be in [0, 100] or None, got " + f"{image_normalization_percentile!r}." + ) + if image_normalization_percentile < 1.0: + logger.warning( + "image_normalization_percentile=%s is < 1; percentiles are 0-100 — did you mean e.g. 
99.9?", + image_normalization_percentile, + ) + self.image_normalization_percentile = image_normalization_percentile + self._image_norm_logged = False + + # Read or reuse FOV positions + if fov_locs is not None: + self.fov_locs = fov_locs + elif dataset.fov_positions_file is not None: + self.fov_locs = _read_fov_locs( + dataset.fov_positions_file, + fovs=sorted(fovs) if fovs else None, + ) + else: + raise ValueError("FOV positions file is required to read CosMx data.") + + self._base_origin_x = float(self.fov_locs["xmin"].min()) + self._base_origin_y = float(self.fov_locs["ymin"].min()) + + # Polygon cache and origin (set lazily by _get_polygons) + self._poly_df: pd.DataFrame | None = None + self._origin_x: float | None = None + self._origin_y: float | None = None + + if align_rasters_to_polygons is None: + align_rasters_to_polygons = polygons_as_labels + self.align_rasters_to_polygons = bool(align_rasters_to_polygons) + + # Frozen after first use to ensure consistent IDs across elements + self.max_cell_id: int | None = None + + # -- Polygon loading and origin computation -- + + def _get_polygons(self) -> pd.DataFrame: + """Load polygons (cached). 
Sets ``_origin_x`` / ``_origin_y``.""" + if self._poly_df is not None: + return self._poly_df + if self.dataset.polygons_file is None: + raise FileNotFoundError("Dataset has no polygons_file.") + + poly_df = _read_polygons_csv( + self.dataset.polygons_file, + fov_locs=self.fov_locs, + n_workers=self.n_workers, + fov_set=self.fovs, + use_polars=True, + ) + poly_df["global_cell_id"] = self.global_cell_id(poly_df) + + # Compute origin from actual polygon bounds + xs = [float(g.bounds[0]) for g in poly_df["geometry"] if g is not None and not g.is_empty] + ys = [float(g.bounds[1]) for g in poly_df["geometry"] if g is not None and not g.is_empty] + ox = min(xs) if xs else self._base_origin_x + oy = min(ys) if ys else self._base_origin_y + + # Shift polygons to origin + if ox != 0.0 or oy != 0.0: + poly_df["geometry"] = [ + _translate(g, xoff=-ox, yoff=-oy) if g is not None else None for g in poly_df["geometry"] + ] + poly_df = poly_df.dropna(subset=["geometry"]).reset_index(drop=True) + + self._origin_x = ox + self._origin_y = oy + self._poly_df = poly_df + return self._poly_df + + def _ensure_origin(self) -> tuple[float, float]: + """Return the coordinate origin, loading polygons if needed.""" + if self._origin_x is not None and self._origin_y is not None: + return self._origin_x, self._origin_y + if self.dataset.polygons_file is not None and (self.align_rasters_to_polygons or self.polygons_as_labels): + self._get_polygons() + return self._origin_x, self._origin_y # type: ignore[return-value] + self._origin_x = self._base_origin_x + self._origin_y = self._base_origin_y + return self._origin_x, self._origin_y + + # -- Global cell ID -- + + def _lock_max_cell_id(self, current_max: int) -> int: + """Lock or validate ``max_cell_id`` and return the base multiplier. + + On first call, locks ``max_cell_id`` to *current_max*. On subsequent + calls, validates that *current_max* does not exceed the locked value. + + Returns ``max_cell_id + 1`` (the base used in the ID formula). 
+ """ + if self.max_cell_id is None: + self.max_cell_id = current_max + elif current_max > self.max_cell_id: + raise ValueError( + f"cell_ID up to {current_max} but reader locked to " + f"max_cell_id={self.max_cell_id}. Load the file with the " + "largest per-FOV cell_ID first (usually the polygon CSV)." + ) + return self.max_cell_id + 1 + + def global_cell_id(self, df: pd.DataFrame) -> pd.Series: + """Compute ``global_cell_id = fov * (max_cell_id + 1) * (cell_ID > 0) + cell_ID``. + + Locks ``max_cell_id`` on first call for consistency across elements. + """ + if "fov" not in df.columns or "cell_ID" not in df.columns: + raise ValueError("Expected columns 'fov' and 'cell_ID'.") + fov_ser = pd.to_numeric(df["fov"], errors="coerce").fillna(0).astype(int) + cell_ser = pd.to_numeric(df["cell_ID"], errors="coerce").fillna(0).astype(int) + current_max = int(cell_ser.max()) if len(cell_ser) else 0 + base = self._lock_max_cell_id(current_max) + return fov_ser * base * (cell_ser > 0).astype(int) + cell_ser + + # -- Read images -- + + def read_images( + self, + *, + read_proteins: bool, + image_models_kwargs: dict[str, Any], + imread_kwargs: dict[str, Any], + channels: list[str] | None = None, + ) -> dict[str, Image2DModel]: + """Read morphology (and optionally protein) images.""" + if self.dataset.morphology_2d_dir is None: + return {} + + images_dir = self.dataset.morphology_2d_dir + morphology_coords = _get_cosmx_morphology_coords(images_dir) + + if self.align_rasters_to_polygons and self.dataset.polygons_file is not None: + self._get_polygons() + ox, oy = self._ensure_origin() + + # Morphology TIFFs are stored y-inverted relative to the FOV-grid + # placement, so they need a vertical flip to co-register with + # transcripts/labels. This must NOT depend on whether polygons were + # read — gating it on ``align_rasters_to_polygons`` left px-only images + # mirrored when read without polygons (issue #42). 
``flip_image`` + # defaults to True (a user may pass ``flip_image=False`` for a dataset + # stored the other way). + per_fov_flip = self.flip_image + + protein_dirs: dict[int, Path] = {} + if read_proteins and self.dataset.analysis_results_dir is not None: + protein_dirs = { + int(p.parent.name[3:]): p for p in self.dataset.analysis_results_dir.rglob("**/FOV*/ProteinImages") + } + + transform = _translation_transform(ox, oy) + + # Single FOV — no stitching + if self.fovs and len(self.fovs) == 1: + fov = next(iter(self.fovs)) + try: + fov_path = _find_matching_fov_file(images_dir, fov) + except FileNotFoundError: + logger.warning("No image file found for FOV %d — skipping images.", fov) + return {} + image, c_coords = _read_fov_image( + fov_path, + protein_dirs.get(fov), + morphology_coords, + selected_channels=channels, + **imread_kwargs, + ) + if per_fov_flip: + image = image[:, ::-1, :] + return { + _element_name("image", self.fovs): self._normalize_and_parse_image( + image, c_coords, transform, image_models_kwargs + ) + } + + # Multi-FOV — stitch + fovs_filter = self.fovs if self.fovs else None + sel_locs = self.fov_locs.loc[sorted(self.fovs)] if self.fovs else self.fov_locs + prot_subset = {k: v for k, v in protein_dirs.items() if k in self.fovs} if self.fovs else protein_dirs + + # Tighten the image canvas to image-bearing FOVs (issue #37), but only + # when the base FOV origin governs; when polygons drive the origin keep + # it so the image stays co-registered with transcripts/labels. 
def _normalize_and_parse_image(
    self,
    image: da.Array,
    c_coords: list[str],
    transform: Any,
    image_models_kwargs: dict[str, Any],
) -> Image2DModel:
    """Build the image element, applying the optional per-channel stretch first.

    The assembled image is passed through ``_normalize_image_channels`` using the
    reader's ``image_normalization_percentile`` and is then parsed as a
    ``(c, y, x)`` image whose ``global`` transformation is ``transform``.

    Parameters
    ----------
    image
        Assembled (single-FOV or stitched) image.
    c_coords
        Channel names, forwarded to the normalizer and to the image model.
    transform
        Transformation registered under the ``global`` coordinate system.
    image_models_kwargs
        Extra keyword arguments forwarded to ``Image2DModel.parse``.

    Returns
    -------
    The parsed image element. When at least one channel divisor differs from
    ``1.0``, ``attrs["cosmx_image_normalization"]`` holds the percentile and the
    per-channel divisors so the stretch can be reversed later.
    """
    percentile = self.image_normalization_percentile
    normalized, divisors = _normalize_image_channels(
        image,
        c_coords,
        percentile=percentile,
        n_workers=self.n_workers,
    )

    # Divisors that are all exactly 1.0 mean that nothing was applied.
    was_stretched = not all(divisor == 1.0 for divisor in divisors.values())

    # Announce the normalization a single time per reader instance.
    if was_stretched and not self._image_norm_logged:
        logger.info(
            "Applying per-channel %.4g-th-percentile image normalization (scale-only, reversible via attrs).",
            percentile,
        )
        self._image_norm_logged = True

    element = Image2DModel.parse(
        normalized,
        dims=("c", "y", "x"),
        c_coords=c_coords,
        transformations={"global": transform},
        **image_models_kwargs,
    )
    if not was_stretched:
        return element

    element.attrs["cosmx_image_normalization"] = {"percentile": percentile, "channel_scales": divisors}
    return element
def _read_transcripts_parquet(self, src: Path) -> dict[str, PointsModel]:
    """Load a large transcripts file lazily through its Parquet cache.

    Steps, in order: obtain the Parquet cache for ``src``, open it with Dask,
    keep only the requested FOVs (when any were requested), place the points
    with ``_place_transcripts``, derive ``global_cell_id`` when both ``fov`` and
    ``cell_ID`` columns are present, convert categoricals to strings, and hand
    the frame to ``_finalize_transcripts``.

    Returns an empty dict when the maximum ``cell_ID`` of the filtered frame is
    missing (NaN), which is reported as "no transcripts for the requested FOVs".
    """
    cache_path = _parquet_cache_for_tx(src, self.dataset.dataset_id, n_workers=self.n_workers)
    logger.info("[transcripts] Reading from Parquet cache...")
    frame = dd.read_parquet(cache_path, engine="pyarrow")

    if self.fovs:
        wanted_fovs = list(self.fovs)
        frame = frame[frame["fov"].isin(wanted_fovs)]

    frame = self._place_transcripts(frame)

    if {"fov", "cell_ID"} <= set(frame.columns):
        observed_max = frame["cell_ID"].max().compute()
        if pd.isna(observed_max):
            logger.warning("[transcripts] No transcripts found for the requested FOVs — skipping.")
            return {}
        base = self._lock_max_cell_id(int(observed_max))
        # Rows whose cell_ID is not positive keep their cell_ID unchanged;
        # all other rows are offset by ``fov * base``.
        frame = frame.assign(global_cell_id=frame["fov"] * base * (frame["cell_ID"] > 0) + frame["cell_ID"])

    frame = _dask_categoricals_to_string(frame)
    name, coord_map = self._transcript_naming(frame)

    # Apply coord transforms for Dask: only needed when no ``x`` column exists yet.
    if isinstance(frame, dd.DataFrame) and "x" not in frame.columns:
        origin_x, origin_y = self._ensure_origin()
        frame = frame.assign(
            x=frame["x_global_px"] - origin_x,
            y=frame["y_global_px"] - origin_y,
        )

    return self._finalize_transcripts(frame, name, coord_map)
def read_shapes(self) -> dict[str, ShapesModel]:
    """Return the cell polygons as a single vector-shapes element.

    Nothing is returned (empty dict) when the dataset has no polygons file, or
    when polygons are rasterized into labels and the vector copy is not kept
    (``polygons_as_labels`` set and ``keep_polygons_after_rasterize`` unset).

    Otherwise the polygons are wrapped in a GeoDataFrame indexed by
    ``global_cell_id`` (the column is kept as well) and parsed with a
    translation by the reader's ``_origin_x`` / ``_origin_y`` in the ``global``
    coordinate system.
    """
    vector_copy_dropped = self.polygons_as_labels and not self.keep_polygons_after_rasterize
    if self.dataset.polygons_file is None or vector_copy_dropped:
        return {}

    polygons = self._get_polygons()
    # Read the origin only after the polygons were fetched, as the original does.
    origin_x = self._origin_x
    origin_y = self._origin_y

    geo_frame = gpd.GeoDataFrame(polygons.copy(), geometry="geometry")
    geo_frame = geo_frame.set_index("global_cell_id", drop=False)

    element_name = _element_name("cells_polygons", self.fovs)
    element = ShapesModel.parse(
        geo_frame,
        transformations={"global": _translation_transform(origin_x, origin_y)},
    )
    return {element_name: element}
def read_labels(self) -> tuple[dict[str, Labels2DModel], set[int]]:
    """Produce the cell-label element from the first available source.

    Sources are tried in this order:

    1. polygon rasterization, when ``polygons_as_labels`` is set and the
       dataset has a polygons file;
    2. the flat ``CellLabels`` directory of pre-computed TIFFs;
    3. the legacy ``CellStatsDir`` layout.

    Returns ``(labels_dict, label_ids_present)``; both are empty when none of
    the sources is available.
    """
    dataset = self.dataset

    if self.polygons_as_labels and dataset.polygons_file is not None:
        # Source 1: polygon -> raster
        result = self._labels_from_polygons()
    elif dataset.cell_labels_dir is not None:
        # Source 2: flat CellLabels directory
        result = self._labels_from_cell_labels_dir()
    elif dataset.cell_stats_dir is not None:
        # Source 3: legacy CellStatsDir
        result = self._labels_from_cell_stats_dir()
    else:
        result = ({}, set())

    return result
Skipping labels.") + return {}, set() + + canvas_min_x, canvas_min_y, canvas_w, canvas_h = _canvas_from_fov_locs_for_polygons( + self.fov_locs, + origin_x=ox, + origin_y=oy, + fovs=self.fovs, + ) + + raster = _polygons_to_label_raster( + poly_df, + chunks=(2048, 2048), + n_jobs=self.n_workers, + canvas_min_x=canvas_min_x, + canvas_min_y=canvas_min_y, + canvas_width=canvas_w, + canvas_height=canvas_h, + ) + raster.attrs.pop("transform", None) + + present_ids = {int(x) for x in da.unique(raster.data).compute() if x != 0} + name = _element_name("polygons_labels", self.fovs) + lbl = _parse_labels(raster, ox, oy) + return {name: lbl}, present_ids + + def _labels_from_cell_labels_dir(self) -> tuple[dict[str, Labels2DModel], set[int]]: + """Stitch pre-computed CellLabels TIFFs.""" + # The per-FOV tile flip is derived from the ``flip_y`` column inside the + # stitcher (it is the inverse of the transcript-placement flip). We pass + # only ``self.flip_image`` here as the fallback used when no ``flip_y`` + # column is present (labels read without transcripts). Folding ``flip_y`` + # into ``flip_image`` here would double-apply it and mirror the labels. + fov_local_to_global = self._build_cell_label_luts() + + # When no cell_info.csv is available, build LUTs from the TIFFs + # themselves using the same global_cell_id formula as read_tables(). + # Without this, labels keep local cell IDs that collide across FOVs + # and don't match the global IDs in the expression table. 
def _build_cell_label_luts(self) -> dict[int, np.ndarray] | None:
    """Build per-FOV local→global cell ID look-up tables from ``cell_info.csv``.

    The file ``<dataset_id>_cell_info.csv`` is read from the dataset's
    ``cell_stats_dir``. For every FOV listed in its ``fov`` column, a 1-D
    ``int64`` array is built so that ``lut[local_cellID] == global_cell_id``;
    positions with no listed cell (including index 0 unless the file lists it)
    stay ``0``.

    Returns
    -------
    A dict keyed by FOV id (built-in ``int``, ascending order), or ``None`` when
    the dataset has no ``cell_stats_dir`` / ``dataset_id`` or the CSV file does
    not exist.
    """
    dataset = self.dataset
    if dataset.cell_stats_dir is None or dataset.dataset_id is None:
        return None
    info_path = dataset.cell_stats_dir / f"{dataset.dataset_id}_cell_info.csv"
    if not info_path.exists():
        return None

    cell_info = pd.read_csv(info_path)
    luts: dict[int, np.ndarray] = {}
    # One pass over the table: groupby yields each FOV's rows once (sorted by
    # FOV, rows with a missing FOV dropped), instead of re-masking the whole
    # frame twice per FOV and walking the rows with ``iterrows``.
    for fov, rows in cell_info.groupby("fov", sort=True):
        local_ids = rows["cellID"].astype(int).to_numpy()
        global_ids = rows["global_cell_id"].astype(np.int64).to_numpy()
        # int64 keeps the dtype platform-independent and consistent with the
        # LUTs built from the TIFFs.
        lut = np.zeros(int(local_ids.max()) + 1, dtype=np.int64)
        # Vectorised scatter; for a repeated local id the last row wins, as
        # with the former row-by-row assignment.
        lut[local_ids] = global_ids
        luts[int(fov)] = lut
    return luts
Some label IDs may not " + "match table rows.", + global_max_local, + self.max_cell_id, + ) + + base = self.max_cell_id + 1 + + # Pass 2: build per-FOV LUTs. + luts: dict[int, np.ndarray] = {} + for fov, max_local in fov_max_ids.items(): + lut = np.zeros(max_local + 1, dtype=np.int64) + for local_id in range(1, max_local + 1): + lut[local_id] = fov * base + local_id + # cell_ID == 0 stays 0 (background) + luts[fov] = lut + + return luts + + # -- Read tables (expression + metadata) -- + + def read_tables(self, regions: list[str], *, modality: str | None = None) -> dict[str, AnnData]: + """Read expression matrix and metadata into an AnnData table.""" + if self.dataset.exprMat_file is None: + return {} + + adata, bg_df = _read_expr_mat_polars( + self.dataset.exprMat_file, + n_rows=None, + fovs=self.fovs if self.fovs else None, + ) + + if adata.n_obs == 0: + label = modality or "expression" + logger.warning("[tables] No %s data for the requested FOVs — skipping.", label) + return {} + + if self.dataset.metadata_file is not None: + meta = _read_metadata_polars(self.dataset.metadata_file, n_rows=None) + common = adata.obs.index.intersection(meta.index) + adata = adata[common].copy() + adata.obs = meta.loc[common] + + # Compute global cell IDs + if "fov" in adata.obs.columns and "cell_ID" in adata.obs.columns: + tmp = adata.obs.reset_index(drop=False) + tmp["global_cell_id"] = self.global_cell_id(tmp) + adata.obs["global_cell_id"] = tmp.set_index("fov_cellID").loc[adata.obs.index, "global_cell_id"] + else: + adata.obs["global_cell_id"] = np.arange(1, adata.n_obs + 1, dtype=int) + + # Determine region + region = regions[0] if regions else "cells" + if ( + region == "cells" + and self.polygons_as_labels + and self.dataset.polygons_file is not None + and not self.keep_polygons_after_rasterize + ): + region = _element_name("polygons_labels", self.fovs) + + adata.obs["region_key"] = pd.Series(region, index=adata.obs.index, dtype="category") + + # Zarr-safe, case-insensitively 
unique obs/var keys (spatialdata core), + # then categoricals → string (zarr cannot serialize categoricals). + sanitize_table(adata) + _pandas_categoricals_to_string(adata.obs) + + if bg_df is not None: + adata.uns["fov_bg_signal"] = bg_df + + table = _parse_cell_table(adata, region=region) + return {"cell_data": table} + + +# --------------------------------------------------------------------------- +# Shared read orchestration (single- and multimodal paths) +# --------------------------------------------------------------------------- + + +def _read_images_and_shapes( + reader: CosMxDatasetReader, + *, + read_images: bool, + read_proteins: bool, + read_polygons: bool, + image_models_kwargs: dict[str, Any] | None, + imread_kwargs: dict[str, Any] | None, + channels: list[str] | None, +) -> tuple[dict[str, Any], dict[str, Any]]: + """Read morphology/protein images and polygon shapes (gate shared by both paths).""" + images = ( + reader.read_images( + read_proteins=read_proteins, + image_models_kwargs=image_models_kwargs, + imread_kwargs=imread_kwargs, + channels=channels, + ) + if read_images and reader.dataset.morphology_2d_dir is not None + else {} + ) + shapes = reader.read_shapes() if read_polygons else {} + return images, shapes + + +# --------------------------------------------------------------------------- +# Multi-modal orchestrator +# --------------------------------------------------------------------------- + + +def _cosmx_multi( + dataset: CosMxDataset, + *, + read_images: bool, + read_labels: bool, + read_proteins: bool, + read_transcripts: bool, + read_polygons: bool, + read_gexp: bool, + fovs: list[int] | int | None, + channels: list[str] | None, + n_workers: int, + flip_image: bool | None, + image_normalization_percentile: float | None, + image_models_kwargs: dict[str, Any] | None, + imread_kwargs: dict[str, Any] | None, + polygons_as_labels: bool, + keep_polygons_after_rasterize: bool, + align_rasters_to_polygons: bool | None, + 
add_fovs_as_shapes: bool, + skip_empty_fovs: bool, + preview_fovs: bool, +) -> SpatialData | None: + """Read a multimodal CosMx dataset (e.g. RNA + Protein).""" + modalities = dataset.modalities + if not modalities: + raise ValueError("Expected modalities for multi-modal dataset.") + + fov_set = _normalize_fovs(fovs) + image_models_kwargs, imread_kwargs = _default_image_kwargs(image_models_kwargs, imread_kwargs) + + # Pick modality for shared spatial elements (prefer one with polygons/labels) + label_ds = next( + (ds for ds in modalities.values() if ds.polygons_file is not None or ds.cell_labels_dir is not None), + next(iter(modalities.values())), + ) + if label_ds.fov_positions_file is None: + raise ValueError("FOV positions file required for multimodal runs.") + + fov_locs = _read_fov_locs(label_ds.fov_positions_file, fovs=sorted(fov_set) if fov_set else None) + + if skip_empty_fovs and (read_images or read_labels): + fov_locs = _prune_empty_fovs(dataset, fov_locs, fov_set) + + if preview_fovs: + _plot_fov_preview(fov_locs, fov_set) + return None + + n_fovs = len(fov_set) if fov_set else len(fov_locs) + n_mods = len(modalities) + logger.info( + "Reading multimodal CosMx dataset (%d modalities, %d FOV(s)). 
" + "Source files are gzip-compressed CSVs that must be fully decompressed " + "even for a FOV subset — this may take a while on first load.", + n_mods, + n_fovs, + ) + + # Shared reader for spatial elements + logger.debug("[1/5] Setting up reader and loading polygons...") + reader = CosMxDatasetReader( + label_ds, + fovs=fov_set, + n_workers=n_workers, + flip_image=flip_image, + polygons_as_labels=polygons_as_labels, + fov_locs=fov_locs, + align_rasters_to_polygons=align_rasters_to_polygons, + keep_polygons_after_rasterize=keep_polygons_after_rasterize, + image_normalization_percentile=image_normalization_percentile, + ) + + # Pre-load polygons to freeze max_cell_id + base_max_cell_id: int | None = None + try: + reader._get_polygons() + base_max_cell_id = reader.max_cell_id + except Exception as e: + logger.debug("Could not pre-load polygons to freeze max_cell_id: %s", e) + + def _modality_reader(ds: CosMxDataset) -> CosMxDatasetReader: + """Secondary reader (transcripts/tables) sharing the frozen max_cell_id.""" + r = CosMxDatasetReader( + ds, + fovs=fov_set, + n_workers=n_workers, + flip_image=flip_image, + polygons_as_labels=False, + fov_locs=fov_locs, + align_rasters_to_polygons=align_rasters_to_polygons, + ) + if base_max_cell_id is not None: + r.max_cell_id = base_max_cell_id + return r + + logger.debug("[2/5] Reading images and labels...") + images, shapes = _read_images_and_shapes( + reader, + read_images=read_images, + read_proteins=read_proteins, + read_polygons=read_polygons, + image_models_kwargs=image_models_kwargs, + imread_kwargs=imread_kwargs, + channels=channels, + ) + labels, label_ids = reader.read_labels() if read_labels else ({}, set()) + + if add_fovs_as_shapes: + shapes = {**shapes, **reader.build_fov_shapes()} + + # Transcripts (from whichever modality has tx_file) + logger.debug("[3/5] Reading transcripts...") + points: dict[str, PointsModel] = {} + tx_mod = next((name for name, ds in modalities.items() if ds.tx_file is not None), None) + 
if read_transcripts and tx_mod is not None: + tx_points = _modality_reader(modalities[tx_mod]).read_transcripts() + if tx_points: + points = {"transcripts": next(iter(tx_points.values()))} + + # Label-anchored co-registration (same as cosmx()): the protein-derived + # segmentation read above is the ground truth; filter each modality's table to + # the rasterised label IDs. (Modalities quantify different cell subsets, so the + # labels cannot be anchored on any single modality's table.) + tables: dict[str, AnnData] = {} + region_refs = list(labels.keys()) or list(shapes.keys()) or ["cells"] + if read_gexp: + for mod_name, mod_ds in modalities.items(): + logger.debug("[4/5] Reading %s data...", mod_name) + mod_tables = _modality_reader(mod_ds).read_tables(region_refs, modality=mod_name) + if label_ids: + mod_tables = _filter_tables_to_ids(mod_tables, label_ids, source_label="rasterised labels") + if mod_tables: + tables[mod_name] = next(iter(mod_tables.values())) + + logger.debug("[5/5] Assembling SpatialData object...") + return _assemble_sdata(images, points, labels, shapes, tables) + + +# --------------------------------------------------------------------------- +# SpatialData assembly +# --------------------------------------------------------------------------- + + +def _assemble_sdata( + images: dict[str, Any], + points: dict[str, Any], + labels: dict[str, Any], + shapes: dict[str, Any], + tables: dict[str, Any], +) -> SpatialData: + """Build and annotate the final SpatialData object.""" + # Move each image's per-channel normalization divisors from the (non-persisted) element + # attrs into the persisted SpatialData attrs, so the reversible stretch survives a + # write/read round-trip. Popped from the element so the metadata has a single home. 
+ attrs: dict[str, Any] = {} + image_norm = { + nm: im.attrs.pop("cosmx_image_normalization") + for nm, im in images.items() + if "cosmx_image_normalization" in getattr(im, "attrs", {}) + } + if image_norm: + attrs["cosmx_image_normalization"] = image_norm + + sdata = SpatialData( + images=images, + points=points, + labels=labels, + shapes=shapes, + tables=tables, + attrs=attrs, + ) + + # Wire tables to label elements + if labels and sdata.tables: + main_label = next(iter(labels)) + for table_name, table in list(sdata.tables.items()): + raw = table.table if hasattr(table, "table") else table + if "region_key" not in raw.obs: + continue + reg_vals = pd.Series(raw.obs["region_key"]).astype("string").unique() + if len(reg_vals) == 1 and reg_vals[0] == main_label: + sdata.set_table_annotates_spatialelement(table_name, region=main_label) + continue + raw.uns.pop("spatialdata_attrs", None) + raw.obs["region_key"] = pd.Series(main_label, index=raw.obs.index, dtype="category") + sdata.tables[table_name] = _parse_cell_table(raw, region=main_label) + sdata.set_table_annotates_spatialelement(table_name, region=main_label) + + return _set_reader_metadata(sdata, "cosmx") + + +# --------------------------------------------------------------------------- +# FOV normalization +# --------------------------------------------------------------------------- + + +def _normalize_fovs(fovs: list[int | str] | int | str | None) -> set[int] | None: + """Convert the user-facing ``fovs`` argument to an optional set of ints.""" + if fovs is None: + return None + if isinstance(fovs, int | str): + return {int(fovs)} + return {int(f) for f in fovs} + + +def _present_fovs(dataset: CosMxDataset) -> set[int]: + """FOV ids with per-FOV data on disk, unioned over the dataset + modalities (#37).""" + sources = [dataset, *(dataset.modalities.values() if dataset.modalities else [])] + present: set[int] = set() + for s in sources: + present |= detect_fovs_with_data( + 
def _prune_empty_fovs(
    dataset: CosMxDataset,
    fov_locs: pd.DataFrame | None,
    fov_set: set[int] | None,
) -> pd.DataFrame | None:
    """Remove from *fov_locs* the FOVs that are listed but have no data on disk (#37).

    ``fov_locs`` is returned untouched when it is ``None`` or empty, when
    ``_present_fovs`` detects no FOV with data at all, or when nothing is left
    to drop. FOVs named in ``fov_set`` (explicit user requests) are never
    dropped. Otherwise the rows of the surviving FOVs are returned in ascending
    FOV order and the dropped ids are logged.
    """
    if fov_locs is None or fov_locs.empty:
        return fov_locs

    with_data = _present_fovs(dataset)
    if not with_data:
        # Detection found nothing: keep every listed FOV.
        return fov_locs

    in_positions = set(map(int, fov_locs.index))
    protected = fov_set if fov_set is not None else set()
    to_drop = in_positions - with_data - protected

    pruned = fov_locs
    if to_drop:
        logger.info(
            "skip_empty_fovs: dropping %d FOV(s) with no data files: %s",
            len(to_drop),
            sorted(to_drop),
        )
        pruned = fov_locs.loc[sorted(in_positions - to_drop)]
    return pruned
+ logger.warning("Could not pre-scan %s for max cell_ID: %s", path, e) + continue + # Match the id column through the shared canonical alias set, NOT a hardcoded + # ("cell_ID", "cellID") pair: expression/metadata use ``cell_ID`` while polygon + # CSVs use ``cellID``, and other exports use ``cell_id``/``object_id``. Including + # the polygon column is essential — orphan (segmented-only) cells are often the + # highest per-FOV IDs and would otherwise be missed, mis-locking max_cell_id. + id_col = _match_canonical(cols, "cell_ID") + if id_col is None: + logger.warning("No cell_ID-like column found in %s (columns: %s) — skipping for max cell_ID.", path, cols) + continue + fov_col = _match_canonical(cols, "fov") + if fov_set is not None and fov_col is not None: + lf = lf.filter(pl.col(fov_col).is_in(list(fov_set))) + try: + val = lf.select(pl.col(id_col).max()).collect().item() + except Exception as e: + logger.warning("Could not read max %s from %s: %s", id_col, path, e) + continue + if val is not None: + max_ids.append(int(val)) + + if max_ids: + reader.max_cell_id = max(max_ids) + logger.debug("Pre-scanned max_cell_id=%d from %d file(s).", reader.max_cell_id, len(max_ids)) + + +# --------------------------------------------------------------------------- +# Public entry point +# --------------------------------------------------------------------------- + + +@inject_docs(cx=CosmxKeys) +def cosmx( + path: str | Path, + dataset_id: str | None = None, + *, + read_images: bool = True, + read_labels: bool = True, + read_proteins: bool = True, + read_transcripts: bool = True, + read_polygons: bool = True, + read_gexp: bool = True, + fovs: list[int | str] | int | str | None = None, + channels: list[str] | None = None, + n_workers: int | None = None, + flip_image: bool | None = None, + image_normalization_percentile: float | None = None, + image_models_kwargs: dict[str, Any] | None = None, + imread_kwargs: dict[str, Any] | None = None, + polygons_as_labels: bool = True, + 
keep_polygons_after_rasterize: bool = False, + align_rasters_to_polygons: bool | None = None, + add_fovs_as_shapes: bool = True, + skip_empty_fovs: bool = True, + preview_fovs: bool = False, +) -> SpatialData | None: + """Read *CosMx Nanostring* data into a :class:`spatialdata.SpatialData` object. + + Supports all known CosMx export formats: flat CSV files, nested + CellStatsDir layouts, multimodal RNA+Protein runs, and both old-style + (px-only / mm-only) and new-style (px+mm) FOV positions files. + + Files are recognized by their standard CosMx suffixes — counts + ``{cx.COUNTS_SUFFIX!r}``, metadata ``{cx.METADATA_SUFFIX!r}``, transcripts + ``{cx.TRANSCRIPTS_SUFFIX!r}``, and FOV positions ``{cx.FOV_SUFFIX!r}`` — + typically prefixed with the dataset id. + + Parameters + ---------- + path + Path to the root directory containing CosMx files. + dataset_id + Optional dataset identifier. Inferred from file prefixes if not given. + read_images + Whether to read morphology images. + read_labels + Whether to read or rasterize cell labels. + read_proteins + Whether to include protein image channels. + read_transcripts + Whether to read transcript coordinates. + read_polygons + Whether to read cell boundary polygons as shapes. + read_gexp + Whether to read the gene expression matrix. + fovs + Specific FOV(s) to read. ``None`` reads all FOVs. + channels + Specific image channel names to include. + n_workers + Number of parallel workers for stitching operations. + flip_image + Flip morphology images vertically to co-register with transcripts and + labels. Defaults to ``True`` when ``None`` (CosMx morphology TIFFs are + stored y-inverted relative to the FOV-grid placement); pass ``False`` for + a dataset stored the other way. + image_normalization_percentile + Optional per-channel contrast normalization for morphology/protein images, applied + post-stitching on top of the dtype-max scaling. ``None`` (default) keeps the plain + dtype-max ``[0, 1]`` image. 
A float in ``[0, 100]`` (e.g. ``99.9``) divides each + channel by that percentile of its non-zero pixels, recovering channels whose real + signal sits far below the dtype ceiling and would otherwise render near-black + (issue #38). The stretch is scale-only (no clipping), so the brightest pixels may + exceed ``1.0``; it is reversible via the divisors recorded in + ``sdata.attrs['cosmx_image_normalization']`` — these are in dtype-max ``[0, 1]`` units + (they reverse to the dtype-max image, not raw counts). The percentile is approximate + for large multi-chunk images, so the exact divisor may vary slightly with chunking. + image_models_kwargs + Extra kwargs for :class:`spatialdata.models.Image2DModel`. + imread_kwargs + Extra kwargs for :func:`dask_image.imread.imread`. + polygons_as_labels + Rasterize polygons into a label image. + keep_polygons_after_rasterize + Keep vector polygons even when ``polygons_as_labels=True``. + align_rasters_to_polygons + Anchor images and labels to the polygon coordinate origin. + add_fovs_as_shapes + Add FOV bounding boxes as shape elements. + skip_empty_fovs + Drop FOVs that appear in the positions file but have no data files on + disk (issue #37). CosMx positions files often list more FOVs than ship + data; those phantom FOVs otherwise inflate the image canvas, desync the + image and label canvases, and add empty FOV boxes. Defaults to ``True``; + FOVs are pruned only when detection finds a strict, non-empty subset with + data (transcript-/gexp-only datasets are left untouched). Pass ``False`` + to keep every listed FOV. + preview_fovs + Show a preview plot of FOV positions and return ``None``. + + Returns + ------- + :class:`spatialdata.SpatialData` or ``None`` (if ``preview_fovs=True``). 
+ """ + n_workers = n_workers or 1 + path = Path(path).resolve() + + dataset = _set_up_cosmx_dataset_for_conversion(path=path, dataset_id=dataset_id) + + # Morphology TIFFs are stored y-inverted relative to the FOV-grid placement, + # so they need a vertical flip to co-register with transcripts/labels. When + # the user does not specify, default to flipping (issue #42). This replaces + # an earlier heuristic that inferred the flip from the transcript coordinate + # convention (``not flip_y``) — the wrong signal, which left px-only images + # mirrored when read without polygons. Resolve here, BEFORE the multimodal + # dispatch, so both single- and multi-modal paths get the same default. + if flip_image is None: + flip_image = True + flip_image = bool(flip_image) + + # --- Multimodal dispatch --- + if dataset.modalities: + return _cosmx_multi( + dataset=dataset, + read_images=read_images, + read_labels=read_labels, + read_proteins=read_proteins, + read_transcripts=read_transcripts, + read_polygons=read_polygons, + read_gexp=read_gexp, + fovs=fovs, + channels=channels, + n_workers=n_workers, + flip_image=flip_image, + image_normalization_percentile=image_normalization_percentile, + image_models_kwargs=image_models_kwargs, + imread_kwargs=imread_kwargs, + polygons_as_labels=polygons_as_labels, + keep_polygons_after_rasterize=keep_polygons_after_rasterize, + align_rasters_to_polygons=align_rasters_to_polygons, + add_fovs_as_shapes=add_fovs_as_shapes, + skip_empty_fovs=skip_empty_fovs, + preview_fovs=preview_fovs, + ) + + # --- Single-modality path --- + logger.info( + "Reading single-modality CosMx dataset (ID: %s). 
" + "Gzip-compressed CSV sources must be fully decompressed even for a " + "FOV subset — this may take a while on first load.", + dataset.dataset_id, + ) + + fov_set = _normalize_fovs(fovs) + image_models_kwargs, imread_kwargs = _default_image_kwargs(image_models_kwargs, imread_kwargs) + + fov_locs = _read_fov_locs(dataset.fov_positions_file) if dataset.fov_positions_file else None + + if skip_empty_fovs and (read_images or read_labels): + fov_locs = _prune_empty_fovs(dataset, fov_locs, fov_set) + + if preview_fovs: + if fov_locs is None: + raise ValueError("preview_fovs=True but no FOV positions file found.") + _plot_fov_preview(fov_locs, fov_set) + return None + + if align_rasters_to_polygons is None: + align_rasters_to_polygons = bool(read_polygons or (read_labels and polygons_as_labels)) + + reader = CosMxDatasetReader( + dataset, + fovs=fov_set, + n_workers=n_workers, + flip_image=flip_image, + polygons_as_labels=polygons_as_labels, + fov_locs=fov_locs, + align_rasters_to_polygons=align_rasters_to_polygons, + keep_polygons_after_rasterize=keep_polygons_after_rasterize, + image_normalization_percentile=image_normalization_percentile, + ) + + # Pre-scan cell_ID max across all available files to avoid + # load-order dependency crashes in global_cell_id(). + _prescan_max_cell_id(reader, dataset, fov_set) + + # --- Read elements --- + images, shapes = _read_images_and_shapes( + reader, + read_images=read_images, + read_proteins=read_proteins, + read_polygons=read_polygons, + image_models_kwargs=image_models_kwargs, + imread_kwargs=imread_kwargs, + channels=channels, + ) + points = reader.read_transcripts() if read_transcripts and dataset.tx_file is not None else {} + + # Label-anchored co-registration: the segmentation is the ground truth (in + # multimodal runs it is the protein-derived segmentation, shared across + # modalities), so rasterise every cell and filter the table to the cells that + # actually rasterised. 
_cosmx_multi runs the same cascade per modality. + labels, label_ids = reader.read_labels() if read_labels else ({}, set()) + + region_refs = list(labels.keys()) or list(shapes.keys()) or ["cells"] + tables = reader.read_tables(region_refs) if read_gexp else {} + if label_ids: + tables = _filter_tables_to_ids(tables, label_ids, source_label="rasterised labels") + + if add_fovs_as_shapes: + shapes = {**shapes, **reader.build_fov_shapes()} + + return _assemble_sdata(images, points, labels, shapes, tables) diff --git a/src/spatialdata_io/readers/cosmx/_stitching.py b/src/spatialdata_io/readers/cosmx/_stitching.py new file mode 100644 index 00000000..14c48cd5 --- /dev/null +++ b/src/spatialdata_io/readers/cosmx/_stitching.py @@ -0,0 +1,922 @@ +"""Stitching and rasterization helpers for CosMx multi-FOV data. + +All functions in this module produce zarr-backed dask arrays using +``zarr.storage.LocalStore(tempfile.mkdtemp(...))`` so that the resulting +arrays are compatible with ``SpatialData.write()`` and do not duplicate +data in memory. 
+""" + +from __future__ import annotations + +import math +import tempfile +from pathlib import Path +from typing import Any + +import dask +import dask.array as da +import numpy as np +import pandas as pd +import tifffile +import xarray as xr +import zarr +from dask.diagnostics import ProgressBar +from dask.utils import SerializableLock +from skimage.draw import polygon +from spatialdata._logging import logger +from tqdm.auto import tqdm + +from ._io import ( + COSMX_FOV_SIZE_PX, + _read_fov_image, +) +from ._utils import FOV_DIR_RE, MORPH_FOV_RE, compute_with_limit + +# --------------------------------------------------------------------------- +# Tile clipping helper +# --------------------------------------------------------------------------- + + +def _clip_tile_to_canvas( + y0: int, + y1: int, + x0: int, + x1: int, + canvas_h: int, + canvas_w: int, +) -> tuple[slice, slice, slice, slice] | None: + """Clip a tile's placement to the canvas bounds. + + Returns ``(global_y, global_x, src_y, src_x)`` slices suitable for + ``da.store(tile[src_y, src_x], zarr, regions=(global_y, global_x))``, + or ``None`` if the tile falls entirely outside the canvas. 
+ """ + gy0, gx0 = max(0, y0), max(0, x0) + gy1, gx1 = min(canvas_h, y1), min(canvas_w, x1) + if gy0 >= gy1 or gx0 >= gx1: + return None + sy0, sx0 = gy0 - y0, gx0 - x0 + sy1 = sy0 + (gy1 - gy0) + sx1 = sx0 + (gx1 - gx0) + return slice(gy0, gy1), slice(gx0, gx1), slice(sy0, sy1), slice(sx0, sx1) + + +def _open_temp_zarr(prefix: str, shape: tuple[int, ...], chunks: tuple[int, ...], dtype: Any) -> zarr.Array: + """Allocate a fresh temp-backed zarr array for scratch stitching output.""" + store = zarr.storage.LocalStore(tempfile.mkdtemp(prefix=prefix)) + return zarr.create_array(store=store, shape=shape, chunks=chunks, dtype=dtype) + + +# --------------------------------------------------------------------------- +# FOV grid normalisation +# --------------------------------------------------------------------------- + + +def _snap_fov_grid( + fov_locs: pd.DataFrame, + fov_arrays: dict[int, da.Array] | dict[int, np.ndarray] | None = None, + *, + fov_size: float = COSMX_FOV_SIZE_PX, + tol: int = 8, +) -> pd.DataFrame: + """Normalize FOV boxes to a 0-based canvas using the CosMx spec size. + + CosMx local coordinates span 0 .. *fov_size* (default 4256) for both + axes. Polygon CSVs are written in that frame and the rasterizer uses + ``fov_locs["xmax"]`` / ``fov_locs["ymax"]`` to build the canvas. + Forcing every FOV to the same extent prevents a single + cropped / malformed FOV from shrinking the canvas. + + A warning is emitted when the actual image size deviates from the spec + by more than *tol* pixels. + + Parameters + ---------- + fov_locs + DataFrame indexed by FOV id with at least ``xmin`` and ``ymin``. + fov_arrays + Optional mapping of FOV id to the image array for that FOV + (used only for sanity-checking dimensions). + fov_size + Expected tile width **and** height in pixels. + tol + Tolerance in pixels before a size-mismatch warning is raised. 
+ + Returns + ------- + pd.DataFrame + A copy of *fov_locs* with added columns ``x0``, ``x1``, ``y0``, + ``y1`` (canvas coordinates) and updated ``xmax`` / ``ymax``. + """ + if fov_locs.empty: + return fov_locs + + fov_locs = fov_locs.copy() + + x_min = float(fov_locs["xmin"].min()) + y_min = float(fov_locs["ymin"].min()) + + wanted_w = int(round(fov_size)) + wanted_h = int(round(fov_size)) + + for f in fov_locs.index: + x0 = int(round(float(fov_locs.at[f, "xmin"]) - x_min)) + y0 = int(round(float(fov_locs.at[f, "ymin"]) - y_min)) + + w = wanted_w + h = wanted_h + + arr = fov_arrays.get(int(f)) if fov_arrays is not None else None + if arr is not None: + if arr.ndim == 3: + _, ah, aw = arr.shape + else: + ah, aw = arr.shape + if abs(aw - wanted_w) > tol or abs(ah - wanted_h) > tol: + logger.warning( + "FOV %s has image %sx%s but CosMx expects %sx%s. Forcing spec size.", + f, + aw, + ah, + wanted_w, + wanted_h, + ) + + x1 = x0 + w + y1 = y0 + h + + fov_locs.at[f, "x0"] = x0 + fov_locs.at[f, "x1"] = x1 + fov_locs.at[f, "y0"] = y0 + fov_locs.at[f, "y1"] = y1 + + fov_locs.at[f, "xmax"] = float(fov_locs.at[f, "xmin"]) + w + fov_locs.at[f, "ymax"] = float(fov_locs.at[f, "ymin"]) + h + + return fov_locs + + +def _label_tile_flip(fov_locs: pd.DataFrame, fov: int, flip_image: bool) -> bool: + """Per-FOV label-tile flip: the inverse of ``flip_y``. + + Transcripts are placed in coordinate space via ``flip_y`` (see + :func:`spatialdata_io.readers.cosmx._io.place_local_in_fov_grid`), so a label + raster co-registers iff it is flipped ``not flip_y``. When ``flip_y`` is + unavailable (column absent, or FOV missing from the table) fall back to + ``flip_image``. Single source of truth for both label stitchers (#39 / #41). 
+ """ + if "flip_y" in fov_locs.columns and fov in fov_locs.index: + return not bool(fov_locs.at[fov, "flip_y"]) + return flip_image + + +# --------------------------------------------------------------------------- +# Multi-FOV image stitcher +# --------------------------------------------------------------------------- + + +def _read_stitched_image( + images_dir: Path, + fov_locs: pd.DataFrame, + protein_dir_dict: dict[int, Path], + morphology_coords: list[str], + flip_image: bool, + *, + fovs_filter: set[int] | None = None, + selected_channels: list[str] | None = None, + n_workers: int | None = None, + tighten_to_seen: bool = False, + **imread_kwargs: Any, +) -> tuple[da.Array, list[str], float, float] | tuple[None, list[str], float, float]: + """Stitch per-FOV morphology (and optional protein) TIFFs into one image. + + Each FOV tile is rechunked to align with the zarr target, clipped to + the canvas bounds, and written via ``da.store``. The result is a + dask array backed by a temporary ``zarr.storage.LocalStore``. + + Parameters + ---------- + images_dir + Directory containing per-FOV Morphology2D TIFFs. + fov_locs + FOV positions table (indexed by FOV id). + protein_dir_dict + Mapping of FOV id to the directory holding protein TIFFs for + that FOV (may be empty). + morphology_coords + Channel names parsed from the TIFF metadata. + flip_image + Whether to flip each tile vertically before stitching. + fovs_filter + If given, only these FOV ids are stitched. + selected_channels + Restrict the output to these channel names. + n_workers + Parallelism cap passed to :func:`compute_with_limit`. + tighten_to_seen + Snap the canvas over only image-bearing FOVs, so phantom / image-less + FOVs neither anchor the origin nor inflate it (issue #37). + **imread_kwargs + Forwarded to :func:`dask_image.imread.imread`. 
+ + Returns + ------- + tuple[da.Array, list[str], float, float] + ``(stitched_image, channel_names, origin_x, origin_y)``; origin is the + canvas min over image-bearing FOVs. + """ + tif_re = MORPH_FOV_RE + all_tifs = list(images_dir.glob("*.TIF")) + + fov_images: dict[int, da.Array] = {} + fov_channels: dict[int, list[str]] = {} + seen_fovs: list[int] = [] + + for img_path in all_tifs: + m = tif_re.match(img_path.name) + if not m: + continue + fov = int(m.group(1)) + if fovs_filter is not None and fov not in fovs_filter: + continue + if fov not in fov_locs.index: + logger.warning("Image for FOV %d ignored (no entry in positions table).", fov) + continue + + img, c_names = _read_fov_image( + img_path, + protein_dir_dict.get(fov), + morphology_coords, + selected_channels=selected_channels, + **imread_kwargs, + ) + if flip_image: + img = img[:, ::-1, :] + + fov_images[fov] = img + fov_channels[fov] = c_names + seen_fovs.append(fov) + + _, h, w = img.shape + fov_locs.loc[fov, "xmax"] = float(fov_locs.loc[fov, "xmin"]) + w + fov_locs.loc[fov, "ymax"] = float(fov_locs.loc[fov, "ymin"]) + h + + if not seen_fovs: + logger.warning("No matching FOV images found to stitch — skipping images.") + return None, [], 0.0, 0.0 + + # Snap only image-bearing FOVs when tightening (issue #37). 
+ snap_locs = fov_locs.loc[sorted(seen_fovs)].copy() if tighten_to_seen else fov_locs + fov_locs = _snap_fov_grid(snap_locs, fov_images) + + used_ox = float(fov_locs.loc[seen_fovs, "xmin"].min()) + used_oy = float(fov_locs.loc[seen_fovs, "ymin"].min()) + + H = int(fov_locs.loc[seen_fovs, "y1"].max()) + W = int(fov_locs.loc[seen_fovs, "x1"].max()) + + if selected_channels is not None: + all_channels = [ch for ch in selected_channels if any(ch in v for v in fov_channels.values())] + if not all_channels: + raise ValueError(f"None of the requested channels are present in the selected FOVs: {selected_channels}.") + else: + all_channels = sorted(set().union(*(set(v) for v in fov_channels.values()))) + + channel_to_idx = {ch: i for i, ch in enumerate(all_channels)} + sample_dtype = fov_images[seen_fovs[0]].dtype + + cy = min(1024, H) + cx = min(1024, W) + + z_img = _open_temp_zarr("cosmx_stitch_", (len(all_channels), H, W), (1, cy, cx), sample_dtype) + + store_ops: list[dask.delayed] = [] + lock = SerializableLock() # shared across tiles: serialises writes to shared zarr chunks + + for fov in tqdm(seen_fovs, desc="Reading FOVs", unit="fov"): + img = fov_images[fov].rechunk((1, cy, cx)) + c_here = fov_channels[fov] + + y0 = int(fov_locs.loc[fov, "y0"]) + y1 = int(fov_locs.loc[fov, "y1"]) + x0 = int(fov_locs.loc[fov, "x0"]) + x1 = int(fov_locs.loc[fov, "x1"]) + + slices = _clip_tile_to_canvas(y0, y1, x0, x1, H, W) + if slices is None: + continue + gl_y, gl_x, src_y, src_x = slices + + for local_ci, ch in enumerate(c_here): + global_ci = channel_to_idx[ch] + + src = img[local_ci : local_ci + 1, src_y, src_x] + region = (slice(global_ci, global_ci + 1), gl_y, gl_x) + + op = da.store(src, z_img, regions=region, lock=lock, compute=False) + store_ops.append(op) + + with ProgressBar(): + logger.info("Stitching FOVs") + compute_with_limit(*store_ops, n_workers=n_workers) + + stitched = da.from_zarr(z_img) + return stitched, all_channels, used_ox, used_oy + + +# 
# ---------------------------------------------------------------------------
# Cell-label TIFF stitcher
# ---------------------------------------------------------------------------


def _store_label_tiles(
    prefix: str,
    fov_tiles: dict[int, da.Array],
    fov_locs: pd.DataFrame,
    *,
    n_workers: int | None = None,
) -> da.Array:
    """Write per-FOV label tiles into one temp-backed zarr canvas.

    Shared tail of both label stitchers in this module.

    Parameters
    ----------
    prefix
        Prefix of the temporary directory backing the zarr store.
    fov_tiles
        Non-empty mapping of FOV id to its 2-D label tile.
    fov_locs
        Output of :func:`_snap_fov_grid`; must carry the canvas columns
        ``x0`` / ``x1`` / ``y0`` / ``y1`` for every FOV in *fov_tiles*. The
        canvas extent is the maximum ``y1`` / ``x1`` over **all** rows.
    n_workers
        Parallelism cap passed to :func:`compute_with_limit`.

    Returns
    -------
    da.Array
        The stitched ``(y, x)`` label image, read back from the zarr store.
    """
    total_y = int(fov_locs["y1"].max())
    total_x = int(fov_locs["x1"].max())

    cy = min(1024, total_y)
    cx = min(1024, total_x)

    # The canvas adopts the dtype of the first tile.
    canvas_dtype = next(iter(fov_tiles.values())).dtype
    z_arr = _open_temp_zarr(prefix, (total_y, total_x), (cy, cx), canvas_dtype)

    lock = SerializableLock()  # shared across tiles: serialises writes to shared zarr chunks
    store_ops = []

    for fov, tile in fov_tiles.items():
        y0 = int(fov_locs.loc[fov, "y0"])
        y1 = int(fov_locs.loc[fov, "y1"])
        x0 = int(fov_locs.loc[fov, "x0"])
        x1 = int(fov_locs.loc[fov, "x1"])

        slices = _clip_tile_to_canvas(y0, y1, x0, x1, total_y, total_x)
        if slices is None:
            continue  # tile lies entirely outside the canvas
        gl_y, gl_x, src_y, src_x = slices

        src = tile.rechunk((cy, cx))[src_y, src_x]
        store_ops.append(da.store(src, z_arr, regions=(gl_y, gl_x), lock=lock, compute=False))

    with ProgressBar():
        compute_with_limit(*store_ops, n_workers=n_workers)

    return da.from_zarr(z_arr)


def _read_stitched_cell_labels_from_dir(
    cell_labels_dir: Path,
    fov_locs: pd.DataFrame,
    *,
    fovs: set[int] | None = None,
    flip_image: bool = False,
    n_workers: int | None = None,
    fov_local_to_global: dict[int, np.ndarray] | None = None,
) -> tuple[da.Array | None, set[int], pd.DataFrame]:
    """Read flat ``CellLabels_Fxxx.tif`` images and stitch them.

    The grid is snapped only to the FOVs for which a TIFF was actually
    found, which removes blank space caused by FOVs listed in the
    positions file but absent from disk.

    Parameters
    ----------
    cell_labels_dir
        Directory containing ``CellLabels_F*.tif`` files.
    fov_locs
        FOV positions table (indexed by FOV id). **Mutated in place**: the
        ``xmax`` / ``ymax`` entries of every stitched FOV are set from the
        tile size.
    fovs
        Optional subset of FOV ids to include.
    flip_image
        Fallback tile flip, used only when ``flip_y`` is unavailable for the FOV.
    n_workers
        Parallelism cap for :func:`compute_with_limit` (applies to both the
        stitching and the final unique-id scan).
    fov_local_to_global
        Per-FOV look-up table mapping local cell ids to global ids.

    Returns
    -------
    tuple[da.Array | None, set[int], pd.DataFrame]
        ``(stitched_labels, present_cell_ids, fov_locs_used)``. When no
        matching TIFF is found a warning is logged and
        ``(None, set(), fov_locs)`` is returned.

    Raises
    ------
    ValueError
        If a label image is not 2-D, or contains an id outside the FOV's
        look-up table.
    """
    from ._utils import find_cell_label_tifs

    fov_tif_map = find_cell_label_tifs(cell_labels_dir)

    fov_tiles: dict[int, da.Array] = {}

    for fov, img_path in fov_tif_map.items():
        if fovs is not None and fov not in fovs:
            continue
        if fov not in fov_locs.index:
            logger.warning("Cell labels for FOV %d ignored (no entry in positions table).", fov)
            continue

        arr = tifffile.imread(img_path)
        if arr.ndim != 2:
            raise ValueError(f"Expected 2D label image for {img_path}, got shape {arr.shape}")

        if _label_tile_flip(fov_locs, fov, flip_image):  # see :func:`_label_tile_flip`
            arr = arr[::-1, :]

        if fov_local_to_global is not None and fov in fov_local_to_global:
            lut = fov_local_to_global[fov]
            if arr.max() >= len(lut):
                raise ValueError(
                    f"FOV {fov}: label image contains id {int(arr.max())} but mapping has length {len(lut)}."
                )
            arr = lut[arr]

        h, w = arr.shape
        fov_locs.loc[fov, "xmax"] = float(fov_locs.loc[fov, "xmin"]) + w
        fov_locs.loc[fov, "ymax"] = float(fov_locs.loc[fov, "ymin"]) + h

        fov_tiles[fov] = da.from_array(arr, chunks=arr.shape)

    if not fov_tiles:
        logger.warning("No matching CellLabels TIFs found in %s — skipping labels.", cell_labels_dir)
        return None, set(), fov_locs

    # Snap over the FOVs that actually have a tile, so absent FOVs neither
    # anchor the origin nor enlarge the canvas.
    stitched_fov_locs = _snap_fov_grid(fov_locs.loc[sorted(fov_tiles)].copy(), fov_tiles)

    stitched = _store_label_tiles("cosmx_cell_labels_", fov_tiles, stitched_fov_locs, n_workers=n_workers)

    # Honour the n_workers cap for the id scan as well.
    (uniq,) = compute_with_limit(da.unique(stitched), n_workers=n_workers)
    uniq_np = np.asarray(uniq, dtype=int)
    present_ids = {int(x) for x in uniq_np if x != 0}  # 0 is background

    return stitched, present_ids, stitched_fov_locs


# ---------------------------------------------------------------------------
# Polygon → label rasterization
# ---------------------------------------------------------------------------


def _polygons_to_label_raster(
    poly_df: pd.DataFrame,
    *,
    chunks: tuple[int, int] = (2048, 2048),
    n_jobs: int | None = None,
    canvas_min_x: float | None = None,
    canvas_min_y: float | None = None,
    canvas_width: int | None = None,
    canvas_height: int | None = None,
) -> xr.DataArray:
    """Rasterize polygons into a uint32 label image.

    Operates tile-by-tile to keep peak memory bounded, writing directly
    into a zarr ``LocalStore``. ``skimage.draw.polygon`` is used
    for the actual rasterization. Only the exterior ring of each polygon is
    filled (interior holes are not carved out), and where polygons overlap
    the one that comes later in *poly_df* wins.

    Parameters
    ----------
    poly_df
        DataFrame with a ``geometry`` column of :class:`shapely.geometry.Polygon`
        objects and, optionally, a ``global_cell_id`` column. Without that
        column, labels are ``1 .. len(poly_df)`` in row order.
    chunks
        Tile size ``(height, width)`` for the zarr output.
    n_jobs
        Parallelism cap for :func:`compute_with_limit`.
    canvas_min_x, canvas_min_y
        Origin of the rasterization canvas (defaults to polygon bounds).
    canvas_width, canvas_height
        Canvas extent in pixels (defaults to polygon bounds).

        The four ``canvas_*`` arguments are only honoured when **all** of
        them are given; otherwise the polygon bounds are used for all four.

    Returns
    -------
    xr.DataArray
        2-D label image with dims ``("y", "x")`` and ``uint32`` dtype,
        backed by a dask array.

    Raises
    ------
    ValueError
        If *poly_df* is empty.
    """
    if poly_df.empty:
        raise ValueError("poly_df is empty, cannot rasterize polygons.")

    geom_arr = poly_df.geometry.to_list()
    bounds = np.array([g.bounds for g in geom_arr], dtype=float)

    if canvas_min_x is None or canvas_min_y is None or canvas_width is None or canvas_height is None:
        # nan-aware reductions: empty geometries may report NaN bounds and
        # must not poison the canvas extent (they are skipped when burning).
        min_x = float(np.nanmin(bounds[:, 0]))
        min_y = float(np.nanmin(bounds[:, 1]))
        width = int(math.ceil(float(np.nanmax(bounds[:, 2]) - min_x)))
        height = int(math.ceil(float(np.nanmax(bounds[:, 3]) - min_y)))
    else:
        min_x = float(canvas_min_x)
        min_y = float(canvas_min_y)
        width = int(canvas_width)
        height = int(canvas_height)

    cy, cx = chunks
    n_tiles_y = int(math.ceil(height / cy))
    n_tiles_x = int(math.ceil(width / cx))

    if "global_cell_id" in poly_df.columns:
        label_ids = poly_df["global_cell_id"].to_numpy(dtype=np.uint32, copy=False)
    else:
        label_ids = np.arange(len(poly_df), dtype=np.uint32) + 1

    z_arr = _open_temp_zarr("cosmx_labels_", (height, width), chunks, "uint32")

    def _burn_row(ty: int) -> None:
        """Rasterize every tile of tile-row *ty* and write it to the zarr array."""
        y0 = ty * cy
        y1 = min((ty + 1) * cy, height)

        row_min_y = min_y + y0
        row_max_y = min_y + y1

        # Pre-filter on y once per row, then on x per tile.
        row_idxs = np.nonzero((bounds[:, 1] < row_max_y) & (bounds[:, 3] > row_min_y))[0]
        row_bounds = bounds[row_idxs]

        for tx in range(n_tiles_x):
            x0 = tx * cx
            x1 = min((tx + 1) * cx, width)

            tile_min_x = min_x + x0
            tile_max_x = min_x + x1

            in_tile = (row_bounds[:, 0] < tile_max_x) & (row_bounds[:, 2] > tile_min_x)
            tile_canvas = np.zeros((y1 - y0, x1 - x0), dtype=np.uint32)

            for i in row_idxs[in_tile]:
                geom = geom_arr[i]
                if geom is None or geom.is_empty:
                    continue

                # Handle both Polygon and MultiPolygon geometries.
                parts = geom.geoms if hasattr(geom, "geoms") else [geom]
                for part in parts:
                    if part.is_empty:
                        continue
                    xs, ys = part.exterior.coords.xy
                    # Shift into tile-local pixel coordinates.
                    xs = np.asarray(xs, dtype=float) - tile_min_x
                    ys = np.asarray(ys, dtype=float) - row_min_y

                    rr, cc = polygon(ys, xs, shape=tile_canvas.shape)
                    tile_canvas[rr, cc] = label_ids[i]

            # A tile without polygons is written as all zeros.
            z_arr[y0:y1, x0:x1] = tile_canvas

    # One task per tile row; each task writes only to its own row of tiles.
    tasks = [dask.delayed(_burn_row)(ty) for ty in range(n_tiles_y)]
    if tasks:
        compute_with_limit(*tasks, n_workers=n_jobs)

    darr = da.from_zarr(z_arr)

    # Offset an integer range instead of ``np.arange(min_x, min_x + width)``:
    # with a float start the latter can yield ``width + 1`` samples, which
    # would no longer match the raster shape.
    xs = min_x + np.arange(width, dtype=float)
    ys = min_y + np.arange(height, dtype=float)

    return xr.DataArray(
        darr,
        dims=("y", "x"),
        coords={"y": ys, "x": xs},
    )


# ---------------------------------------------------------------------------
# Canvas from FOV positions (for polygon rasterization)
# ---------------------------------------------------------------------------


def _canvas_from_fov_locs_for_polygons(
    fov_locs: pd.DataFrame,
    *,
    origin_x: float,
    origin_y: float,
    fovs: set[int] | None = None,
    fov_size: float = COSMX_FOV_SIZE_PX,
) -> tuple[float, float, int, int]:
    """Compute canvas dimensions from FOV positions in polygon space.

    The result is in the same coordinate frame as the polygons (i.e.
    after shifting by ``-origin_x``, ``-origin_y``).

    Parameters
    ----------
    fov_locs
        FOV positions table with ``xmin`` / ``ymin`` and, optionally,
        ``xmax`` / ``ymax`` columns.
    origin_x, origin_y
        Global origin subtracted from FOV coordinates.
    fovs
        Optional FOV id subset. Every id must be present in *fov_locs*
        (``KeyError`` otherwise).
    fov_size
        Expected tile extent in pixels; used as the upper bound for any FOV
        whose ``xmax`` / ``ymax`` is missing or NaN.

    Returns
    -------
    tuple[float, float, int, int]
        ``(canvas_min_x, canvas_min_y, width, height)``

    Raises
    ------
    ValueError
        If no FOV is selected.
    """
    locs = fov_locs if fovs is None else fov_locs.loc[sorted(fovs)]
    if locs.empty:
        raise ValueError("No FOV positions selected; cannot derive a canvas for polygon rasterization.")

    def _axis_extent(lo_col: str, hi_col: str, origin: float) -> tuple[float, float]:
        """Return ``(min, max)`` of one axis in the origin-shifted frame."""
        lows = locs[lo_col].astype(float) - origin
        spec_highs = lows + fov_size
        if hi_col in locs.columns:
            # The stitchers fill xmax/ymax row by row, so FOVs that had no
            # tile can still hold NaN here; fall back to the spec size.
            highs = (locs[hi_col].astype(float) - origin).fillna(spec_highs)
        else:
            highs = spec_highs
        return float(lows.to_numpy().min()), float(highs.to_numpy().max())

    min_x, max_x = _axis_extent("xmin", "xmax", origin_x)
    min_y, max_y = _axis_extent("ymin", "ymax", origin_y)

    width = int(math.ceil(max_x - min_x))
    height = int(math.ceil(max_y - min_y))
    return min_x, min_y, width, height


# ---------------------------------------------------------------------------
# Legacy label stitcher (CellStatsDir layout)
# ---------------------------------------------------------------------------


def _find_dir(path: Path, name: str) -> Path:
    """Locate a sub-directory by *name* under *path*.

    A direct child ``path / name`` wins; otherwise *path* is searched
    recursively and exactly one matching **directory** must exist (regular
    files that happen to carry the same name are ignored).

    Raises
    ------
    FileNotFoundError
        If no directory with *name* is found, or multiple matches exist.
    """
    direct = path / name
    if direct.is_dir():
        return direct

    # ``rglob`` already searches recursively, so no explicit ``**/`` prefix.
    paths = [p for p in path.rglob(name) if p.is_dir()]
    if len(paths) != 1:
        raise FileNotFoundError(f"Found {len(paths)} path(s) with name {name} inside {path}")
    return paths[0]


def stitch_segmentation_label_image(
    path: str | Path,
    fov_position_file: str | Path,
    cell_info_file: str | Path,
    dataset_id: str | None = None,
    seg_dir_name: str = "CellStatsDir",
    label_prefix: str = "CellLabels",
    flip_image: bool = False,
    n_workers: int | None = None,
    fovs: set[int] | None = None,
) -> tuple[da.Array, pd.DataFrame]:
    """Stitch per-FOV label TIFFs from a CellStatsDir layout.

    Each ``FOVxxx/CellLabels_Fxxx.tif`` is read, local cell ids are
    mapped to unique global ids via *cell_info_file*, and the tiles are
    stitched into a single zarr-backed dask array.

    Global ids are assigned ``1, 2, ...`` walking the FOVs in ascending
    order and, within a FOV, the local ids in ascending order. Label values
    that do not occur in *cell_info_file* map to ``0``.

    Parameters
    ----------
    path
        Root directory of the CosMx dataset.
    fov_position_file
        Path to the FOV positions CSV.
    cell_info_file
        Path to the CSV containing ``fov`` and ``cellID`` columns.
    dataset_id
        Optional dataset ID override.
    seg_dir_name
        Name of the segmentation directory to search for.
    label_prefix
        Filename prefix for the label TIFFs.
    flip_image
        Fallback tile flip, used only when ``flip_y`` is unavailable for the FOV.
    n_workers
        Parallelism cap for :func:`compute_with_limit`.
    fovs
        Optional set of FOV IDs to include. If ``None``, all FOVs are used.

    Returns
    -------
    tuple[da.Array, pd.DataFrame]
        ``(stitched_labels, cell_info_df)`` where *cell_info_df* has an
        added ``global_cell_id`` column.

    Raises
    ------
    FileNotFoundError
        If the segmentation directory, or the label TIFF of a selected FOV
        directory, cannot be found.
    ValueError
        If a label image is not 2-D, contains an id beyond the FOV's
        mapping, or no label tile is left to stitch.
    """
    from ._discovery import _infer_dataset_id
    from ._io import _read_fov_locs

    path = Path(path)
    # NOTE(review): the inferred id is not used below; the call is kept
    # unchanged in case callers rely on it failing for an un-inferable
    # dataset — confirm before removing.
    dataset_id = _infer_dataset_id(path, dataset_id)
    fov_locs = _read_fov_locs(Path(fov_position_file))

    df = pd.read_csv(cell_info_file)
    if fovs is not None:
        df = df[df["fov"].isin(fovs)].reset_index(drop=True)
    df = df.copy()

    # Build one local->global look-up table per FOV and assign the global
    # ids column-wise (no per-row ``apply``; also fine for an empty table).
    fov_col = df["fov"].to_numpy()
    cell_col = df["cellID"].astype(int).to_numpy()
    global_ids = np.zeros(len(df), dtype=int)

    mapping: dict[int, np.ndarray] = {}
    offset = 0
    for fov in sorted(df["fov"].unique()):
        in_fov = fov_col == fov
        local_ids = np.unique(cell_col[in_fov])
        if local_ids.size == 0:
            continue
        arr_map = np.zeros(int(local_ids.max()) + 1, dtype=int)
        arr_map[local_ids] = np.arange(offset + 1, offset + 1 + local_ids.size)
        offset += local_ids.size
        mapping[fov] = arr_map
        global_ids[in_fov] = arr_map[cell_col[in_fov]]

    df["global_cell_id"] = global_ids

    seg_root = _find_dir(path, seg_dir_name)

    fov_tiles: dict[int, da.Array] = {}
    for sub in seg_root.iterdir():
        if not sub.is_dir():
            continue
        m = FOV_DIR_RE.match(sub.name)
        if m is None:
            continue
        fov = int(m.group(1))
        if fovs is not None and fov not in fovs:
            continue

        tif_path = sub / f"{label_prefix}_F{fov:03d}.tif"
        if not tif_path.exists():
            raise FileNotFoundError(f"Expected {tif_path} for FOV {fov}")

        # Decide whether the tile is usable before paying for the read.
        if fov not in mapping:
            logger.warning("Segmentation labels for FOV %d ignored (no entry in %s).", fov, cell_info_file)
            continue
        if fov not in fov_locs.index:
            # Same policy as the flat-directory stitcher: skip, don't crash.
            logger.warning("Segmentation labels for FOV %d ignored (no entry in positions table).", fov)
            continue

        arr = tifffile.imread(tif_path)
        if arr.ndim != 2:
            raise ValueError(f"Expected 2D label image for {tif_path}, got shape {arr.shape}")

        lut = mapping[fov]
        if int(arr.max()) >= len(lut):
            raise ValueError(f"FOV {fov}: label image contains id {int(arr.max())} but mapping has length {len(lut)}.")
        arr = lut[arr]
        if _label_tile_flip(fov_locs, fov, flip_image):  # see :func:`_label_tile_flip`
            arr = arr[::-1, :]

        y_size, x_size = arr.shape
        fov_locs.loc[fov, "xmax"] = fov_locs.loc[fov, "xmin"] + x_size
        fov_locs.loc[fov, "ymax"] = fov_locs.loc[fov, "ymin"] + y_size

        fov_tiles[fov] = da.from_array(arr, chunks=arr.shape)

    if not fov_tiles:
        raise ValueError(
            f"No '{label_prefix}' label tiles to stitch under {seg_root} "
            "(no FOV directory matched the selection, cell info and positions table)."
        )

    # The grid is snapped over the whole positions table here (not only the
    # FOVs that produced a tile), so the canvas spans every listed FOV.
    fov_locs = _snap_fov_grid(fov_locs, fov_tiles)

    stitched = _store_label_tiles("cosmx_seg_labels_", fov_tiles, fov_locs, n_workers=n_workers)
    return stitched, df


#
--------------------------------------------------------------------------- +# FOV preview plot +# --------------------------------------------------------------------------- + + +def _plot_fov_preview( + fov_locs: pd.DataFrame, + selected_fovs: set[int] | None = None, + *, + fov_size: float = COSMX_FOV_SIZE_PX, +) -> None: + """Render a preview of the FOV layout using matplotlib. + + Each FOV is drawn as a square with the real CosMx tile size (default + 4256 x 4256 px) and labelled with its FOV id. The Y axis is inverted + to match the CosMx top-left origin convention. + + Parameters + ---------- + fov_locs + FOV positions table with ``xmin`` and ``ymin`` columns. + selected_fovs + If given, only these FOV ids are drawn. + fov_size + Tile size in pixels. + + Raises + ------ + ValueError + If *fov_locs* is empty or missing required columns. + """ + import matplotlib.pyplot as plt + from matplotlib.patches import Rectangle + + if fov_locs is None or fov_locs.empty: + raise ValueError("preview_fovs=True but FOV locations dataframe is empty.") + + df = fov_locs.reset_index().rename(columns={"index": "fov"}) + if selected_fovs is not None: + df = df[df["fov"].isin(selected_fovs)].copy() + + if df.empty: + raise ValueError("preview_fovs=True but no matching FOVs after applying the user subset.") + + if "xmin" not in df.columns or "ymin" not in df.columns: + raise ValueError("preview_fovs=True but FOV positions do not have xmin/ymin columns.") + + x0s = df["xmin"].astype(float).to_numpy() + y0s = df["ymin"].astype(float).to_numpy() + fov_ids = df["fov"].astype(int).to_numpy() + + x_min = float(x0s.min()) + x_max = float((x0s + fov_size).max()) + y_min = float(y0s.min()) + y_max = float((y0s + fov_size).max()) + n = len(fov_ids) + + if n <= 40: + fontsize = 12 + elif n <= 120: + fontsize = 10 + elif n <= 300: + fontsize = 8 + else: + fontsize = 6 + + fig, ax = plt.subplots(figsize=(8, 6)) + + for x, y, _f in zip(x0s, y0s, fov_ids, strict=False): + rect = Rectangle( + (x, 
y), + fov_size, + fov_size, + fill=False, + edgecolor="black", + linewidth=0.8, + ) + ax.add_patch(rect) + + text_kwargs = { + "ha": "center", + "va": "center", + "color": "black", + "fontsize": fontsize, + } + for x, y, f in zip(x0s, y0s, fov_ids, strict=False): + cx = x + fov_size / 2.0 + cy = y + fov_size / 2.0 + ax.text(cx, cy, str(f), **text_kwargs) + + ax.set_xlim(x_min - fov_size * 0.05, x_max + fov_size * 0.05) + ax.set_ylim(y_max + fov_size * 0.05, y_min - fov_size * 0.05) + ax.set_aspect("equal") + + ax.set_xlabel("Global X (px)") + ax.set_ylabel("Global Y (px)") + ax.set_title("FOV Locations") + ax.grid(linestyle="--", alpha=0.35) + fig.tight_layout() + plt.show() diff --git a/src/spatialdata_io/readers/cosmx/_utils.py b/src/spatialdata_io/readers/cosmx/_utils.py new file mode 100644 index 00000000..1d1d2217 --- /dev/null +++ b/src/spatialdata_io/readers/cosmx/_utils.py @@ -0,0 +1,294 @@ +"""Pure utilities for the CosMx reader — no domain logic.""" + +from __future__ import annotations + +import re +from contextlib import contextmanager +from multiprocessing.pool import ThreadPool +from typing import TYPE_CHECKING, Any + +import dask +import dask.array as da +import dask.dataframe as dd +import numpy as np +import pandas as pd +from spatialdata._logging import logger + +if TYPE_CHECKING: + from pathlib import Path + +# --------------------------------------------------------------------------- +# Dask parallelism helpers +# --------------------------------------------------------------------------- + + +@contextmanager +def dask_thread_pool(n_workers: int | None): + """Temporarily cap dask parallelism to *n_workers* threads.""" + if n_workers is None or n_workers <= 0: + yield + return + with ThreadPool(n_workers) as pool: + with dask.config.set(scheduler="threads", pool=pool): + yield + + +def compute_with_limit(*tasks: Any, n_workers: int | None = None) -> tuple[Any, ...]: + """``dask.compute`` with an optional thread-pool cap.""" + if not tasks: + 
return () + with dask_thread_pool(n_workers): + return dask.compute(*tasks) + + +# --------------------------------------------------------------------------- +# Dtype helpers +# --------------------------------------------------------------------------- + + +def _to_float01_dtype_max(arr: da.Array) -> da.Array: + """Normalize an array to float32 in [0, 1] by its dtype range (legacy CosMx behavior). + + Unsigned/signed integers are scaled by their dtype min/max; floating arrays are assumed + already normalized and passed through unchanged (dtype/precision preserved). + """ + dt = arr.dtype + if np.issubdtype(dt, np.floating): + return arr + if np.issubdtype(dt, np.unsignedinteger): + return (arr.astype("float32") / float(np.iinfo(dt).max)).astype("float32") + if np.issubdtype(dt, np.signedinteger): + info = np.iinfo(dt) + return ((arr.astype("float32") - float(info.min)) / float(info.max - info.min)).astype("float32") + return arr.astype("float32") + + +def _normalize_image_channels( + arr: da.Array, + channel_names: list[str] | None = None, + *, + percentile: float | None = None, + n_workers: int | None = None, +) -> tuple[da.Array, dict[str, float]]: + """Optionally apply a per-channel percentile contrast stretch to a ``(c, y, x)`` image. + + ``percentile=None`` returns *arr* unchanged (the IO layer already dtype-max normalized + it to ``[0, 1]``). A float such as ``99.9`` divides each channel by that percentile of + its **non-zero, finite** pixels — non-zero so the zero padding between non-contiguous + FOVs does not bias the estimate — recovering channels whose real signal sits far below + the dtype ceiling and otherwise render near-black (issue #38). + + The stretch is scale-only (no clipping), hence exactly reversible via the returned + *scales* (channel name -> divisor). ``da.percentile`` is approximate for multi-chunk + images, so a divisor is a close estimate. A channel with no positive signal is left + unscaled (divisor ``1.0``) with a warning. 
+ """ + if percentile is None: + return arr.astype("float32"), {} + + is_2d = arr.ndim == 2 + channels = [arr] if is_2d else [arr[i] for i in range(arr.shape[0])] + if channel_names is not None and len(channel_names) == len(channels): + names = list(channel_names) + else: + names = [str(i) for i in range(len(channels))] + + scales: dict[str, float] = {} + out: list[da.Array] = [] + for nm, ch in zip(names, channels, strict=False): + flat = ch.ravel() + sample = flat[(flat != 0) & da.isfinite(flat)] + try: + (p,) = compute_with_limit(da.percentile(sample, percentile), n_workers=n_workers) + p = float(np.asarray(p).ravel()[0]) + except ValueError: # no non-zero finite pixels + p = float("nan") + if np.isfinite(p) and p > 0.0: + div = p + else: + logger.warning( + "Image channel %r: no positive signal at the %.4g-th percentile; left unscaled.", nm, percentile + ) + div = 1.0 + scales[nm] = div + out.append(ch.astype("float32") / div) + + return (out[0] if is_2d else da.stack(out, axis=0)), scales + + +# --------------------------------------------------------------------------- +# Categorical / string helpers (zarr compatibility) +# --------------------------------------------------------------------------- + + +def _pandas_categoricals_to_string(df: pd.DataFrame) -> pd.DataFrame: + """Convert categorical columns to string dtype (zarr cannot serialize categoricals).""" + if df.empty: + return df + cat_cols = [c for c in df.columns if isinstance(df[c].dtype, pd.CategoricalDtype)] + for c in cat_cols: + df[c] = df[c].astype("string").fillna("") + return df + + +def _dask_categoricals_to_string(df: dd.DataFrame) -> dd.DataFrame: + """Partition-wise categorical → string conversion for dask DataFrames.""" + cat_cols = [ + col for col, dtype in df.dtypes.items() if isinstance(dtype, pd.CategoricalDtype) or str(dtype) == "category" + ] + if not cat_cols: + return df + + def _to_str(pdf: pd.DataFrame) -> pd.DataFrame: + for col in cat_cols: + if col in pdf.columns: + pdf[col] 
= pdf[col].astype("string").fillna("") + return pdf + + return df.map_partitions(_to_str) + + +# --------------------------------------------------------------------------- +# CSV header matching for polygon files +# --------------------------------------------------------------------------- + +# Canonical column names → list of regex patterns to try (case-insensitive). +_CANONICAL: dict[str, list[str]] = { + "fov": [r"^fov$", r"^roi$", r"^field_?of_?view$", r"^fov_id$"], + "cell_ID": [r"^cell[_ ]?id$", r"^cellid$", r"^cell$", r"^object_id$", r"^cell_identifier$"], + "polygon_index": [ + r"^polygon_index$", + r"^(poly|shape|segm)[-_ ]?index$", + r"^(poly|shape|segm)[-_ ]?idx$", + r"^object_index$", + r"^segmentation_index$", + ], + "x": [r"^x_global_px$", r"^x_local_px$", r"^x[_ ]?px$", r"^x$"], + "y": [r"^y_global_px$", r"^y_local_px$", r"^y[_ ]?px$", r"^y$"], +} + + +def _match_canonical(hdr: list[str], canon: str) -> str | None: + """Return the raw column in *hdr* matching canonical name *canon*, else ``None``. + + Uses the same case-insensitive alias patterns as :func:`_match_header`, but for a + single column and without requiring the full polygon schema. This is the shared + entry point so every caller honours the *same* alias set (e.g. ``cell_ID`` also + matching ``cellID``/``cell_id``/``object_id``) — a narrower ad-hoc match risks + missing columns one path accepts and another silently drops. + """ + hdr_stripped = [c.strip() for c in hdr] + for pat in _CANONICAL.get(canon, []): + hit = next( + ( + orig + for orig, stripped in zip(hdr, hdr_stripped, strict=False) + if re.fullmatch(pat, stripped, flags=re.IGNORECASE) + ), + None, + ) + if hit is not None: + return hit + return None + + +def _match_header(hdr: list[str]) -> dict[str, str]: + """Map raw CSV column names to canonical names via regex matching. + + Returns ``{original_name: canonical_name}`` for each matched column. + ``polygon_index`` is optional; ``fov``, ``cell_ID``, ``x``, ``y`` are required. 
_CELL_LABELS_RE = re.compile(r"CellLabels_F(\d+)\.tif$", re.IGNORECASE)


def find_cell_label_tifs(directory: Path) -> dict[int, Path]:
    """Find ``CellLabels_F*.tif`` files under *directory* (recursive).

    The extension is matched case-insensitively. Files whose name does not
    carry a numeric FOV id after ``CellLabels_F`` are ignored.

    If the same FOV id occurs more than once (e.g. in several nested
    sub-directories), the path that sorts first is kept. Candidates are
    sorted before being scanned because ``rglob`` yields entries in
    filesystem order, which would otherwise make the chosen file depend on
    the platform and on directory history.

    Parameters
    ----------
    directory
        Root directory to search.

    Returns
    -------
    A ``{fov_id: path}`` mapping.
    """
    result: dict[int, Path] = {}
    for p in sorted(directory.rglob("CellLabels_F*.[tT][iI][fF]")):
        m = _CELL_LABELS_RE.search(p.name)
        if m:
            # setdefault: first path in sorted order wins on duplicate FOV ids.
            result.setdefault(int(m.group(1)), p)
    return result
MORPH_FOV_RE = re.compile(r".*_F(\d+)", re.IGNORECASE)
FOV_DIR_RE = re.compile(r"FOV0*(\d+)$", re.IGNORECASE)


def detect_fovs_with_data(
    *,
    morphology_2d_dir: Path | None = None,
    cell_labels_dir: Path | None = None,
    cell_stats_dir: Path | None = None,
    analysis_results_dir: Path | None = None,
) -> set[int]:
    """Collect the FOV ids that have at least one per-FOV file on disk (issue #37).

    The result is the union over every source that is given and exists:

    - ``*.TIF`` files directly inside *morphology_2d_dir* whose name matches
      ``MORPH_FOV_RE``;
    - ``CellLabels`` TIFFs found by ``find_cell_label_tifs`` under
      *cell_labels_dir* and under *cell_stats_dir* (recursive);
    - ``FOV*/ProteinImages`` directories anywhere under
      *analysis_results_dir* whose parent name matches ``FOV_DIR_RE``.

    Each source is listed once. Transcripts/exprMat are not per-FOV, so an
    empty result means "cannot tell" — callers must not prune on it.
    """
    found: set[int] = set()

    if morphology_2d_dir and morphology_2d_dir.exists():
        for tif in morphology_2d_dir.glob("*.TIF"):
            hit = MORPH_FOV_RE.match(tif.name)
            if hit:
                found.add(int(hit.group(1)))

    # Both label locations use the same recursive CellLabels discovery.
    for labels_root in (cell_labels_dir, cell_stats_dir):
        if labels_root and labels_root.exists():
            found.update(find_cell_label_tifs(labels_root))

    if analysis_results_dir and analysis_results_dir.exists():
        for protein_dir in analysis_results_dir.rglob("FOV*/ProteinImages"):
            hit = FOV_DIR_RE.match(protein_dir.parent.name)
            if hit:
                found.add(int(hit.group(1)))

    return found
+ +Fixture matrix (derived from real data): + + Fixture A — "new-style" (lymph/tonsil-like) + px+mm FOV positions, flip_y=False, .csv.gz files, uppercase FOV column, + Morphology2D images, single modality, TWO ADJACENT FOVs + + Fixture B — "old-style px-only" (hippocampus-like) + px-only FOV positions, flip_y=True, .csv (uncompressed), lowercase fov + column, Morphology2D images, single modality, TWO NON-ADJACENT FOVs + + Fixture C — "old-style mm-only" (pancreas-like) + mm-only positions, flip_y=True, extra columns (Slide, ROI, Order), + CellLabels TIFFs instead of Morphology2D, single modality, SINGLE FOV + +After standardization, ALL fixtures should produce the same normalized +SpatialData structure. + +Run: pixi run -e dev-py313 python -m pytest tests/test_cosmx_regression.py -v +""" + +from __future__ import annotations + +import json +import math +from typing import TYPE_CHECKING + +import numpy as np +import pandas as pd +import pytest + +if TYPE_CHECKING: + from pathlib import Path + +# --------------------------------------------------------------------------- +# constants +# --------------------------------------------------------------------------- +COSMX_PIXEL_SIZE = 0.120280945 +MM_TO_PX = 1000.0 / COSMX_PIXEL_SIZE +# The reader forces FOV_SIZE_PX=4256 in _snap_fov_grid. +# Our polygon local coords and FOV positions must be consistent with this. +# Real CosmX = 4256. We use a small size for fast test TIFFs. +# _snap_fov_grid will warn but still force 4256 for coordinates. 
+TIFF_FOV_SIZE = 256 +# The reader's internal spec size (used for polygon coords, FOV positions, grid snapping) +SPEC_FOV_SIZE = 4256 +N_CHANNELS = 3 +N_CELLS_PER_FOV = 5 +N_GENES = 10 + + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- + + +def _tiff_description() -> str: + """JSON that _get_cosmx_morphology_coords can parse via regex.""" + return json.dumps( + { + "ChannelOrder": "BGU", + "MorphologyKit": { + "MorphologyReagents": [ + {"BiologicalTarget": "Histone", "Fluorophore": {"ChannelId": "B"}}, + {"BiologicalTarget": "DNA", "Fluorophore": {"ChannelId": "G"}}, + {"BiologicalTarget": "rRNA", "Fluorophore": {"ChannelId": "U"}}, + ] + }, + } + ) + + +def _write_morphology_tiff(path: Path, n_channels: int = N_CHANNELS, fov_size: int = SPEC_FOV_SIZE): + """Write a multi-page TIFF that dask_image.imread reads as (C, Y, X).""" + import tifffile + + path.parent.mkdir(parents=True, exist_ok=True) + desc = _tiff_description() + data = np.random.default_rng(0).integers(0, 65535, (n_channels, fov_size, fov_size), dtype=np.uint16) + tifffile.imwrite(str(path), data, description=desc, photometric="minisblack") + + +def _write_asymmetric_morphology_tiff(path: Path, n_channels: int = N_CHANNELS, fov_size: int = TIFF_FOV_SIZE): + """Morphology TIFF whose top half (100) differs from its bottom half (200), + so a vertical flip is detectable by comparing top- vs bottom-row means. 
+ """ + import tifffile + + path.parent.mkdir(parents=True, exist_ok=True) + data = np.zeros((n_channels, fov_size, fov_size), dtype=np.uint16) + data[:, : fov_size // 2, :] = 100 + data[:, fov_size // 2 :, :] = 200 + tifffile.imwrite(str(path), data, description=_tiff_description(), photometric="minisblack") + + +def _write_cell_label_tiff(path: Path, n_cells: int = N_CELLS_PER_FOV, fov_size: int = TIFF_FOV_SIZE): + """Write a CellLabels TIFF with small blocks of unique cell IDs.""" + import tifffile + + path.parent.mkdir(parents=True, exist_ok=True) + data = np.zeros((fov_size, fov_size), dtype=np.uint16) + block = fov_size // (n_cells + 2) + for cid in range(1, n_cells + 1): + y0, y1 = cid * block, cid * block + block + x0, x1 = cid * block, cid * block + block + data[y0:y1, x0:x1] = cid + tifffile.imwrite(str(path), data) + + +def _write_csv(path: Path, df: pd.DataFrame, compress: bool = False): + path.parent.mkdir(parents=True, exist_ok=True) + df.to_csv(path, index=False, compression="gzip" if compress else None) + + +def _make_polygon_rows(fov: int, cell_id: int, cx: float, cy: float, r: float = 50.0, n_pts: int = 6): + """Hexagon vertices for one cell in LOCAL coordinates.""" + rows = [] + for i in range(n_pts): + angle = 2 * math.pi * i / n_pts + rows.append( + { + "fov": fov, + "cellID": cell_id, + "x_local_px": cx + r * math.cos(angle), + "y_local_px": cy + r * math.sin(angle), + } + ) + return rows + + +def _make_polygon_df(fov_positions: dict[int, tuple[float, float]], n_cells: int = N_CELLS_PER_FOV): + """Build polygon CSV with both local and global coords. 
def _make_expr_mat(fovs: list[int], n_cells: int = N_CELLS_PER_FOV, n_genes: int = N_GENES):
    """Build a synthetic expression table with one row per (fov, cell).

    Columns are ``fov``, ``cell_ID`` and ``Gene_0`` .. ``Gene_{n_genes-1}``.
    Counts are integers in ``[0, 100)`` drawn one at a time from a generator
    seeded with 42, in row-major order (FOV, then cell, then gene), so the
    table is reproducible.
    """
    rng = np.random.default_rng(42)
    genes = [f"Gene_{idx}" for idx in range(n_genes)]
    records = [
        {
            "fov": fov,
            "cell_ID": cell,
            # Scalar draws, gene by gene, keep the random stream order fixed.
            **{gene: int(rng.integers(0, 100)) for gene in genes},
        }
        for fov in fovs
        for cell in range(1, n_cells + 1)
    ]
    return pd.DataFrame(records)
FOVs 1 & 2 are side-by-side.""" + root = tmp_path_factory.mktemp("fixture_a") + prefix = "S0" + fovs = [1, 2] + + # Adjacent: FOV 2 is directly right of FOV 1 + fov_pos = {1: (0.0, 0.0), 2: (float(SPEC_FOV_SIZE), 0.0)} + x_mm = {f: x / MM_TO_PX for f, (x, _) in fov_pos.items()} + y_mm = {f: y / MM_TO_PX for f, (_, y) in fov_pos.items()} + + fov_df = pd.DataFrame( + { + "FOV": fovs, + "x_global_px": [fov_pos[f][0] for f in fovs], + "y_global_px": [fov_pos[f][1] for f in fovs], + "x_global_mm": [x_mm[f] for f in fovs], + "y_global_mm": [y_mm[f] for f in fovs], + } + ) + _write_csv(root / f"{prefix}_fov_positions_file.csv.gz", fov_df, compress=True) + _write_csv(root / f"{prefix}_exprMat_file.csv.gz", _make_expr_mat(fovs), compress=True) + _write_csv(root / f"{prefix}_metadata_file.csv.gz", _make_metadata(fovs), compress=True) + _write_csv(root / f"{prefix}-polygons.csv.gz", _make_polygon_df(fov_pos), compress=True) + + morph = root / "Morphology2D" + for fov in fovs: + _write_morphology_tiff(morph / f"20240101_S0_F{fov:05d}.TIF", fov_size=TIFF_FOV_SIZE) + + return root + + +# --------------------------------------------------------------------------- +# Fixture B: old-style px-only, TWO NON-ADJACENT FOVs +# --------------------------------------------------------------------------- +@pytest.fixture(scope="session") +def fixture_b(tmp_path_factory) -> Path: + """px-only → flip_y=True, .csv, lowercase fov. 
FOVs 1 & 5 are spaced apart.""" + root = tmp_path_factory.mktemp("fixture_b") + prefix = "Run5642_S3_Quarter" + fovs = [1, 5] + + # Non-adjacent: FOV 5 is 3 FOV-widths away + fov_pos = {1: (0.0, 0.0), 5: (3.0 * SPEC_FOV_SIZE, 2.0 * SPEC_FOV_SIZE)} + + fov_df = pd.DataFrame( + { + "fov": fovs, + "x_global_px": [fov_pos[f][0] for f in fovs], + "y_global_px": [fov_pos[f][1] for f in fovs], + } + ) + _write_csv(root / f"{prefix}_fov_positions_file.csv", fov_df) + _write_csv(root / f"{prefix}_exprMat_file.csv", _make_expr_mat(fovs)) + _write_csv(root / f"{prefix}_metadata_file.csv", _make_metadata(fovs)) + _write_csv(root / f"{prefix}-polygons.csv", _make_polygon_df(fov_pos)) + + morph = root / "Morphology2D" + for fov in fovs: + _write_morphology_tiff(morph / f"20240101_S0_F{fov:05d}.TIF", fov_size=TIFF_FOV_SIZE) + + return root + + +# --------------------------------------------------------------------------- +# Fixture C: mm-only, SINGLE FOV, CellLabels instead of Morphology2D +# --------------------------------------------------------------------------- +@pytest.fixture(scope="session") +def fixture_c(tmp_path_factory) -> Path: + """mm-only → flip_y=True, extra columns, CellLabels dir, single FOV.""" + root = tmp_path_factory.mktemp("fixture_c") + prefix = "Pancreas" + fovs = [1] + + x_mm_val = 1.0 + y_mm_val = 2.0 + fov_pos = {1: (x_mm_val * MM_TO_PX, y_mm_val * MM_TO_PX)} + + fov_df = pd.DataFrame( + { + "Slide": ["S1"], + "X_mm": [x_mm_val], + "Y_mm": [y_mm_val], + "Z_mm": [0.0], + "ZOffset_mm": [0.0], + "ROI": [1], + "FOV": fovs, + "Order": fovs, + "Run_Tissue_name": ["tissue"], + } + ) + _write_csv(root / f"{prefix}_fov_positions_file.csv", fov_df) + _write_csv(root / f"{prefix}_exprMat_file.csv", _make_expr_mat(fovs)) + _write_csv(root / f"{prefix}_metadata_file.csv", _make_metadata(fovs)) + _write_csv(root / f"{prefix}-polygons.csv", _make_polygon_df(fov_pos)) + + labels_dir = root / "CellLabels" + for fov in fovs: + _write_cell_label_tiff(labels_dir / 
f"CellLabels_F{fov:03d}.tif") + + return root + + +# --------------------------------------------------------------------------- +# Fixture MM: multimodal (RNA + Protein), prefix-based (V2-style), SINGLE FOV +# --------------------------------------------------------------------------- +@pytest.fixture(scope="session") +def fixture_multimodal(tmp_path_factory) -> Path: + """Two modality prefixes (``S0RNA`` / ``S0Protein``) → base id ``S0`` → + triggers the multimodal (``_cosmx_multi``) dispatch. px-only (flip_y=True), + single FOV. Morphology is asymmetric (top 100 / bottom 200) so the image + flip is detectable. The RNA modality carries the shared spatial files + (it's chosen as the label modality because it has polygons). + """ + root = tmp_path_factory.mktemp("fixture_mm") + fovs = [1] + fov_pos = {1: (0.0, 0.0)} + + for prefix in ("S0RNA", "S0Protein"): + _write_csv(root / f"{prefix}_exprMat_file.csv", _make_expr_mat(fovs)) + _write_csv(root / f"{prefix}_metadata_file.csv", _make_metadata(fovs)) + + # Shared spatial files live with the RNA (label) modality. 
+ fov_df = pd.DataFrame( + { + "fov": fovs, + "x_global_px": [fov_pos[f][0] for f in fovs], # px-only → flip_y=True + "y_global_px": [fov_pos[f][1] for f in fovs], + } + ) + _write_csv(root / "S0RNA_fov_positions_file.csv", fov_df) + _write_csv(root / "S0RNA-polygons.csv", _make_polygon_df(fov_pos)) + + morph = root / "Morphology2D" + for fov in fovs: + _write_asymmetric_morphology_tiff(morph / f"20240101_S0_F{fov:05d}.TIF") + + return root + + +# --------------------------------------------------------------------------- +# shared sdio fixture +# --------------------------------------------------------------------------- +@pytest.fixture(scope="session") +def sdio(): + import spatialdata_io as _sdio + + return _sdio + + +def _read(sdio, path: Path, **kw): + """Read a fixture with sensible defaults for fast testing.""" + defaults = {"n_workers": 1, "read_transcripts": False} + defaults.update(kw) + return sdio.cosmx(path, **defaults) + + +# ═══════════════════════════════════════════════════════════════════════════ +# 1. DISCOVERY +# ═══════════════════════════════════════════════════════════════════════════ + + +class TestDatasetDiscovery: + def test_fixture_a(self, fixture_a): + from spatialdata_io.readers.cosmx._discovery import _set_up_cosmx_dataset_for_conversion + + ds = _set_up_cosmx_dataset_for_conversion(fixture_a) + assert ds.dataset_id == "S0" + assert ds.fov_positions_file is not None + assert ds.polygons_file is not None + assert ds.exprMat_file is not None + assert ds.metadata_file is not None + assert ds.morphology_2d_dir is not None + assert ds.modalities is None + + +# ═══════════════════════════════════════════════════════════════════════════ +# 2. 
class TestFovPositions:
    """Checks on ``_read_fov_locs`` across the three FOV-positions file layouts."""

    def test_mm_only_flip_and_conversion(self, fixture_c):
        """An mm-only positions file must set ``flip_y`` and yield pixel coordinates."""
        from spatialdata_io.readers.cosmx._io import _read_fov_locs

        fl = _read_fov_locs(fixture_c / "Pancreas_fov_positions_file.csv")
        assert fl["flip_y"].all()
        # fixture_c places its single FOV at X_mm=1.0, so a value above 1000
        # can only come from the mm→px conversion.
        assert fl["xmin"].iloc[0] > 1000  # mm→px conversion

    def test_all_produce_standard_columns(self, fixture_a, fixture_b, fixture_c):
        """Every layout must expose the same bounding-box columns with positive extent."""
        from spatialdata_io.readers.cosmx._io import _read_fov_locs

        paths = [
            fixture_a / "S0_fov_positions_file.csv.gz",
            fixture_b / "Run5642_S3_Quarter_fov_positions_file.csv",
            fixture_c / "Pancreas_fov_positions_file.csv",
        ]
        for p in paths:
            fl = _read_fov_locs(p)
            for col in ("xmin", "ymin", "xmax", "ymax", "flip_y"):
                assert col in fl.columns, f"{p.name} missing {col}"
            assert (fl["xmax"] > fl["xmin"]).all()
            assert (fl["ymax"] > fl["ymin"]).all()

    def test_missing_fov_raises(self, fixture_a):
        """Requesting a FOV id that is not in the file (99; fixture_a has 1 and 2) raises ``KeyError``."""
        from spatialdata_io.readers.cosmx._io import _read_fov_locs

        with pytest.raises(KeyError):
            _read_fov_locs(fixture_a / "S0_fov_positions_file.csv.gz", fovs=[99])
+ """ + from spatialdata_io.readers.cosmx._reader import CosMxDatasetReader, _prescan_max_cell_id + + # Create two CSVs: polygons with max cell_ID=10, expr with max=20 + poly_csv = tmp_path / "poly.csv" + expr_csv = tmp_path / "expr.csv" + poly_csv.write_text("fov,cell_ID,x,y\n1,10,0,0\n") + expr_csv.write_text("fov,cell_ID,Gene_0\n1,20,5\n") + + class FakeDataset: + polygons_file = poly_csv + exprMat_file = expr_csv + metadata_file = None + + class FakeReader: + max_cell_id = None + + reader = FakeReader() + _prescan_max_cell_id(reader, FakeDataset(), fov_set=None) + assert reader.max_cell_id == 20, "Should lock to the global max across all files" + + # Now calling global_cell_id with cell_ID=20 should NOT crash + real_reader = CosMxDatasetReader.__new__(CosMxDatasetReader) + real_reader.max_cell_id = reader.max_cell_id + gid = real_reader.global_cell_id(pd.DataFrame({"fov": [1], "cell_ID": [20]})) + assert gid.iloc[0] == 1 * 21 + 20 + + def test_prescan_matches_aliased_id_and_fov_columns(self, tmp_path): + """The pre-scan must honour the same column aliases as the header + canonicaliser. A polygon CSV using ``object_id``/``roi`` (instead of + ``cell_ID``/``fov``) must still contribute its max — otherwise an + orphan (segmented-only) cell with the highest per-FOV id is missed and + max_cell_id is mis-locked. The old hardcoded ("cell_ID", "cellID") + match would skip this file entirely. + """ + from spatialdata_io.readers.cosmx._reader import _prescan_max_cell_id + + poly_csv = tmp_path / "poly.csv" + expr_csv = tmp_path / "expr.csv" + # Orphan id 30 lives only in the polygons, under aliased headers. The + # fov-2 id=99 row must be excluded once we restrict to FOV 1 — which + # only works if the aliased ``roi`` column is resolved too. 
+ poly_csv.write_text("roi,object_id,x,y\n1,30,0,0\n2,99,0,0\n") + expr_csv.write_text("fov,cell_ID,Gene_0\n1,20,5\n") + + class FakeDataset: + polygons_file = poly_csv + exprMat_file = expr_csv + metadata_file = None + + class FakeReader: + max_cell_id = None + + reader = FakeReader() + _prescan_max_cell_id(reader, FakeDataset(), fov_set={1}) + assert reader.max_cell_id == 30 + + def test_prescan_warns_instead_of_silently_dropping_a_file(self, tmp_path, caplog): + """A file that exists but has no recognizable cell-ID column must be + surfaced (warning), not silently swallowed — a silent drop in a pre-scan + that establishes the max_cell_id invariant hides the exact schema + mismatch that mis-locks the max. + """ + import logging + + from spatialdata._logging import logger as sd_logger + + from spatialdata_io.readers.cosmx._reader import _prescan_max_cell_id + + bad_csv = tmp_path / "bad.csv" + bad_csv.write_text("fov,mystery,x,y\n1,7,0,0\n") + + class FakeDataset: + polygons_file = bad_csv + exprMat_file = None + metadata_file = None + + class FakeReader: + max_cell_id = None + + reader = FakeReader() + # spatialdata's logger does not propagate to the root logger caplog + # attaches to, so enable it for the duration of this assertion. + prev_propagate = sd_logger.propagate + sd_logger.propagate = True + try: + with caplog.at_level(logging.WARNING): + _prescan_max_cell_id(reader, FakeDataset(), fov_set=None) + finally: + sd_logger.propagate = prev_propagate + + assert reader.max_cell_id is None + assert any("cell_ID" in rec.getMessage() for rec in caplog.records) + + +# ═══════════════════════════════════════════════════════════════════════════ +# 4b. TILE CLIPPING HELPER +# ═══════════════════════════════════════════════════════════════════════════ + + +# ═══════════════════════════════════════════════════════════════════════════ +# 4c. 
TRANSCRIPT PLACEMENT (local-coord; replaces the old fov_shift / +4256) +# ═══════════════════════════════════════════════════════════════════════════ + + +class TestTranscriptPlacement: + """Transcripts must be placed with the SAME per-FOV local->grid mapping as + polygons, so they co-register by construction (issue #39 / #18). + """ + + _fov_locs = pd.DataFrame( + { + "xmin": [0.0, 4256.0], + "ymin": [0.0, 0.0], + "xmax": [4256.0, 8512.0], + "ymax": [4256.0, 4256.0], + "flip_y": [False, False], + }, + index=pd.Index([1, 2], name="fov"), + ) + + def _df(self): + return pd.DataFrame( + { + "fov": [1, 2], + "x_local_px": [10.0, 20.0], + "y_local_px": [100.0, 200.0], + # raw global deliberately bogus to prove it is NOT used + "x_global_px": [-9e9, -9e9], + "y_global_px": [-9e9, -9e9], + } + ) + + def test_flip_matches_polygon_formula(self): + from spatialdata_io.readers.cosmx._io import place_local_in_fov_grid + + fov_locs = self._fov_locs.copy() + fov_locs["flip_y"] = [True, True] + out = place_local_in_fov_grid(self._df(), fov_locs) + # flip: y = y0 + y_local + assert out["x_global_px"].tolist() == [10.0, 4256.0 + 20.0] + assert out["y_global_px"].tolist() == [100.0, 200.0] + + @pytest.mark.parametrize("flip", [True, False]) + def test_dask_path_matches_pandas(self, flip): + """Regression: the real transcript path runs on a Dask DataFrame (CSV is + converted to Parquet, then read lazily). A prior rewrite expressed the + flip as ``y_local.where(fov.map(flip_dict), ...)`` — on Dask, + ``Series.map(dict)`` is object dtype and ``where`` rejects a non-bool + condition, crashing every ``read_labels=True`` read. The synthetic tests + only covered the pandas path, so it slipped through. Assert the Dask path + runs AND yields the same numbers as pandas. 
def _lookup_centres(raster, fov_locs, fov, n_cells, ox, oy):
    """Sample the label *raster* under each cell's block centre.

    One synthetic transcript is put at the local centre of every cell block
    (the same block layout ``_write_cell_label_tiff`` draws for a
    ``SPEC_FOV_SIZE`` tile), mapped to global coordinates with
    ``place_local_in_fov_grid``, shifted by the raster origin ``(ox, oy)`` and
    rounded to a pixel index. The returned list holds the raster value at each
    of those pixels, or ``-1`` where the pixel falls outside the raster.
    """
    from spatialdata_io.readers.cosmx._io import place_local_in_fov_grid

    block = SPEC_FOV_SIZE // (n_cells + 2)
    half = block // 2
    centres = [idx * block + half for idx in range(1, n_cells + 1)]
    local = pd.DataFrame({"fov": [fov] * n_cells, "x_local_px": centres, "y_local_px": centres})
    placed = place_local_in_fov_grid(local, fov_locs)

    def _to_index(column, origin):
        # Global coordinate -> integer pixel index relative to the raster origin.
        return np.round(placed[column].to_numpy() - origin).astype(int)

    col = _to_index("x_global_px", ox)
    row = _to_index("y_global_px", oy)
    height, width = raster.shape[-2:]
    on_canvas = (row >= 0) & (row < height) & (col >= 0) & (col < width)
    # Clip before indexing so off-canvas points cannot raise; they are masked to -1 below.
    sampled = raster[np.clip(row, 0, height - 1), np.clip(col, 0, width - 1)]
    return np.where(on_canvas, sampled, -1).tolist()
+ + Regression for the stitcher tile-flip bug (#39): the CellLabels stitcher + flipped each tile by ``flip_y`` (``do_flip = flip_image or flip_y``), but + transcripts with ``flip_y=True`` are placed DIRECT (``y = y0 + y_local``) — + so the raster ended up vertically mirrored relative to the points. Confirmed + on pancreas (mm-only, the only real dataset with CellLabels TIFFs): the + stitched-label lookup ``label[row, col] == global_cell_id`` was 0.8% direct / + 100% within-FOV-mirrored, and 100% / 0.8% after the fix. + + Earlier regression tests only exercised the pandas transcript path and the + polygon-rasterized label path, so neither this flip nor the Dask ``.where`` + crash was covered. This test drives the CellLabels-TIFF stitcher and the + transcript placement together and asserts they index the same cell. + """ + + @staticmethod + def _stitch_and_lookup(tmp_path: Path, flip_y: bool): + import tifffile + + from spatialdata_io.readers.cosmx._stitching import _read_stitched_cell_labels_from_dir + + # Use the real FOV size so the flip arithmetic in placement (which uses + # COSMX_FOV_SIZE_PX) matches the label tile height. 
+ fov_size = SPEC_FOV_SIZE + n_cells = 5 + labels_dir = tmp_path / "CellLabels" + _write_cell_label_tiff(labels_dir / "CellLabels_F001.tif", n_cells=n_cells, fov_size=fov_size) + + xmin, ymin = 1000.0, 2000.0 # non-zero origin to catch frame mistakes + fov_locs = pd.DataFrame( + { + "xmin": [xmin], + "ymin": [ymin], + "xmax": [xmin + fov_size], + "ymax": [ymin + fov_size], + "flip_y": [flip_y], + }, + index=pd.Index([1], name="fov"), + ) + + stitched, _, used = _read_stitched_cell_labels_from_dir( + labels_dir, + fov_locs.copy(), + fovs={1}, + flip_image=False, + n_workers=1, + fov_local_to_global=None, + ) + arr = np.asarray(stitched.compute() if hasattr(stitched, "compute") else stitched) + ox, oy = float(used["xmin"].min()), float(used["ymin"].min()) + looked_up = _lookup_centres(arr, fov_locs, 1, n_cells, ox, oy) + return looked_up, list(range(1, n_cells + 1)) + + @pytest.mark.parametrize("flip_y", [True, False]) + def test_transcripts_hit_their_cell(self, tmp_path, flip_y): + looked_up, expected = self._stitch_and_lookup(tmp_path, flip_y) + assert looked_up == expected, ( + f"flip_y={flip_y}: each transcript placed from local coords must " + f"index its own cell in the stitched CellLabels raster. Got " + f"{looked_up}, expected {expected}. A mismatch means the label tile " + f"flip and the transcript placement disagree (the #39 mirror bug)." 
+ ) + + +class TestLabelTileFlip: + """Unit tests for the per-FOV label-tile flip rule (``_label_tile_flip``).""" + + @staticmethod + def _locs(flip_y: bool): + return pd.DataFrame( + {"xmin": [0.0, 4256.0], "ymin": [0.0, 0.0], "flip_y": [flip_y, flip_y]}, + index=pd.Index([1, 2], name="fov"), + ) + + def test_flip_y_true_means_no_tile_flip(self): + from spatialdata_io.readers.cosmx._stitching import _label_tile_flip + + locs = self._locs(True) + assert _label_tile_flip(locs, 1, flip_image=False) is False + # flip_y is authoritative — it overrides the flip_image fallback + assert _label_tile_flip(locs, 2, flip_image=True) is False + + def test_fallback_when_fov_absent_from_locs(self): + from spatialdata_io.readers.cosmx._stitching import _label_tile_flip + + locs = self._locs(True) # only FOVs 1, 2 + assert _label_tile_flip(locs, 99, flip_image=True) is True + assert _label_tile_flip(locs, 99, flip_image=False) is False + + def test_cellstats_stitcher_coregisters_px_only(self, tmp_path): + """End-to-end #41: the legacy CellStatsDir stitcher must co-register + labels with transcripts on a px-only (flip_y=True) dataset. + + We pass ``flip_image=True`` to prove ``flip_y`` overrides it: pre-fix the + tile flipped by ``flip_image`` (mirrored vs the direct-placed + transcripts); post-fix it flips ``not flip_y`` (= no flip). 
+ """ + from spatialdata_io.readers.cosmx._io import _read_fov_locs + from spatialdata_io.readers.cosmx._stitching import stitch_segmentation_label_image + + prefix = "S0" + fov = 1 + n_cells = N_CELLS_PER_FOV + + # px-only positions -> flip_y=True (the convention the bug affects) + pos_file = tmp_path / f"{prefix}_fov_positions_file.csv" + _write_csv( + pos_file, + pd.DataFrame( + { + "fov": [fov], + "x_global_px": [1000.0], + "y_global_px": [2000.0], + } + ), + ) + cs = tmp_path / "CellStatsDir" + ci_file = cs / f"{prefix}_cell_info.csv" + _write_csv( + ci_file, + pd.DataFrame( + { + "fov": [fov] * n_cells, + "cellID": list(range(1, n_cells + 1)), + } + ), + ) + _write_cell_label_tiff( + cs / f"FOV{fov:03d}" / f"CellLabels_F{fov:03d}.tif", n_cells=n_cells, fov_size=SPEC_FOV_SIZE + ) + + stitched, df = stitch_segmentation_label_image( + path=tmp_path, + fov_position_file=str(pos_file), + cell_info_file=str(ci_file), + dataset_id=prefix, + flip_image=True, + n_workers=1, + fovs={fov}, # inert here: overridden by flip_y + ) + lab = np.asarray(stitched.compute() if hasattr(stitched, "compute") else stitched) + + # The stitcher returns only (labels, cell_info); re-read positions for the + # global origin (xmin/ymin) — _snap_fov_grid zeroes only the raster origin. + fov_locs = _read_fov_locs(pos_file) + ox, oy = float(fov_locs.loc[fov, "xmin"]), float(fov_locs.loc[fov, "ymin"]) + gid = {int(r.cellID): int(r.global_cell_id) for r in df.itertuples()} + hit = _lookup_centres(lab, fov_locs, fov, n_cells, ox, oy) + expected = [gid[c] for c in range(1, n_cells + 1)] + assert hit == expected, ( + f"cell_stats labels not co-registered with transcripts (#41): got {hit}, expected {expected}" + ) + + +class TestTranscriptAlignmentE2E: + """End-to-end: transcript points must land on their cells' polygons in the + output 'global' coordinate system, regardless of the global-frame offset + (the scenario that used to trigger fov_shift). 
+ """ + + @staticmethod + def _build(root: Path, prefix: str, poly_y_offset: float) -> Path: + """Dataset where polygon GLOBAL y sits *poly_y_offset* from the FOV + position (e.g. -SPEC_FOV_SIZE reproduces the old fov_shift trigger). + Transcripts are written at each cell's local centroid. + """ + fovs = [1, 2] + fov_pos = {1: (0.0, 0.0), 2: (float(SPEC_FOV_SIZE), 0.0)} + x_mm = {f: x / MM_TO_PX for f, (x, _) in fov_pos.items()} + y_mm = {f: y / MM_TO_PX for f, (_, y) in fov_pos.items()} + fov_df = pd.DataFrame( + { + "FOV": fovs, + "x_global_px": [fov_pos[f][0] for f in fovs], + "y_global_px": [fov_pos[f][1] for f in fovs], + "x_global_mm": [x_mm[f] for f in fovs], + "y_global_mm": [y_mm[f] for f in fovs], + } + ) + _write_csv(root / f"{prefix}_fov_positions_file.csv.gz", fov_df, compress=True) + _write_csv(root / f"{prefix}_exprMat_file.csv.gz", _make_expr_mat(fovs), compress=True) + _write_csv(root / f"{prefix}_metadata_file.csv.gz", _make_metadata(fovs), compress=True) + + # polygons globalised with an arbitrary y offset + poly_off = {f: (x, y + poly_y_offset) for f, (x, y) in fov_pos.items()} + _write_csv(root / f"{prefix}-polygons.csv.gz", _make_polygon_df(poly_off), compress=True) + + # transcripts: one per cell at the cell's LOCAL centroid + spacing = SPEC_FOV_SIZE / (N_CELLS_PER_FOV + 1) + rows = [] + for fov in fovs: + ox, oy = poly_off[fov] + for cid in range(1, N_CELLS_PER_FOV + 1): + cl = spacing * cid + rows.append( + { + "fov": fov, + "cell_ID": cid, + "x_local_px": cl, + "y_local_px": cl, + "x_global_px": cl + ox, + "y_global_px": cl + oy, + "target": f"Gene_{cid % 3}", + } + ) + _write_csv(root / f"{prefix}_tx_file.csv.gz", pd.DataFrame(rows), compress=True) + + morph = root / "Morphology2D" + for fov in fovs: + _write_morphology_tiff(morph / f"20240101_S0_F{fov:05d}.TIF", fov_size=TIFF_FOV_SIZE) + return root + + @staticmethod + def _gy_extent(elem): + import numpy as np + from spatialdata.transformations import get_transformation + + t = 
get_transformation(elem, "global") + m = t.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y")) + a, b = float(m[1, 1]), float(m[1, 2]) + try: # points + d = elem.compute() + ys = np.asarray(d["y"], float) + return a * ys.min() + b, a * ys.max() + b + except AttributeError: # shapes + lo = float(elem.geometry.bounds["miny"].min()) + hi = float(elem.geometry.bounds["maxy"].max()) + return a * lo + b, a * hi + b + + @pytest.mark.parametrize("offset", [0.0, -SPEC_FOV_SIZE, 3.0 * SPEC_FOV_SIZE]) + def test_transcripts_overlap_polygons(self, sdio, tmp_path, offset): + root = self._build(tmp_path, "TxAlign", offset) + # keep cell polygons as shapes (don't rasterise) so we can compare + sdata = sdio.cosmx( + root, + n_workers=1, + read_transcripts=True, + read_proteins=False, + polygons_as_labels=False, + ) + pts = next(iter(sdata.points.values())) + shp = next(v for k, v in sdata.shapes.items() if "box" not in k.lower()) + p_lo, p_hi = self._gy_extent(pts) + s_lo, s_hi = self._gy_extent(shp) + # transcript y-extent must sit within the polygon y-extent (+margin), + # i.e. no one-FOV-height drift regardless of the global offset + margin = SPEC_FOV_SIZE * 0.25 + assert p_lo >= s_lo - margin and p_hi <= s_hi + margin, ( + f"offset={offset}: points y[{p_lo:.0f},{p_hi:.0f}] not within polygons y[{s_lo:.0f},{s_hi:.0f}]" + ) + + +# ═══════════════════════════════════════════════════════════════════════════ +# 5. 
# POLYGON READING
# ═══════════════════════════════════════════════════════════════════════════


class TestPolygonReading:
    """Checks on the frame returned by ``_read_polygons_csv``."""

    @staticmethod
    def _load_polygons(positions_file, polygons_file, fov_set):
        """Read FOV positions, then the polygon CSV restricted to *fov_set*."""
        from spatialdata_io.readers.cosmx._io import _read_fov_locs, _read_polygons_csv

        fov_locs = _read_fov_locs(positions_file)
        return _read_polygons_csv(polygons_file, fov_locs=fov_locs, fov_set=fov_set, use_polars=True)

    def test_valid_geometries(self, fixture_a):
        polygons = self._load_polygons(
            fixture_a / "S0_fov_positions_file.csv.gz",
            fixture_a / "S0-polygons.csv.gz",
            {1},
        )
        assert len(polygons) > 0
        assert all(geom.is_valid for geom in polygons["geometry"])

    def test_flip_y_still_valid(self, fixture_b):
        polygons = self._load_polygons(
            fixture_b / "Run5642_S3_Quarter_fov_positions_file.csv",
            fixture_b / "Run5642_S3_Quarter-polygons.csv",
            {1},
        )
        assert len(polygons) > 0
        assert all(geom.is_valid for geom in polygons["geometry"])

    def test_coords_finite(self, fixture_a):
        polygons = self._load_polygons(
            fixture_a / "S0_fov_positions_file.csv.gz",
            fixture_a / "S0-polygons.csv.gz",
            {1, 2},
        )
        # Every bounding-box coordinate of every geometry must be a finite number.
        assert all(math.isfinite(coord) for geom in polygons["geometry"] for coord in geom.bounds)

    def test_cell_count_per_fov(self, fixture_a):
        polygons = self._load_polygons(
            fixture_a / "S0_fov_positions_file.csv.gz",
            fixture_a / "S0-polygons.csv.gz",
            {1},
        )
        assert polygons["cell_ID"].nunique() == N_CELLS_PER_FOV

    def test_multi_polygon_cells_merged(self, tmp_path):
        """Cells with multiple polygon_index values are merged into MultiPolygon."""
        import shapely.geometry as sgeom

        from spatialdata_io.readers.cosmx._io import _read_polygons_csv

        spacing = SPEC_FOV_SIZE / (N_CELLS_PER_FOV + 1)
        # (cell_ID, centre x, centre y, polygon_index): cell 1 is given TWO
        # parts (the second one offset by 200), cell 2 a single part.
        parts = [
            (1, spacing, spacing, 0),
            (1, spacing + 200, spacing + 200, 1),
            (2, spacing * 2, spacing * 2, 0),
        ]
        rows = []
        for cell_id, centre_x, centre_y, part_index in parts:
            for vertex in _make_polygon_rows(1, cell_id, centre_x, centre_y, r=50.0):
                vertex["polygon_index"] = part_index
                rows.append(vertex)

        poly_csv = tmp_path / "multi_poly.csv"
        _write_csv(poly_csv, pd.DataFrame(rows))

        # One FOV of SPEC_FOV_SIZE x SPEC_FOV_SIZE anchored at the origin.
        fov_locs = pd.DataFrame(
            {
                "xmin": [0.0],
                "ymin": [0.0],
                "xmax": [float(SPEC_FOV_SIZE)],
                "ymax": [float(SPEC_FOV_SIZE)],
            },
            index=pd.Index([1], name="fov"),
        )

        pdf = _read_polygons_csv(poly_csv, fov_locs=fov_locs, fov_set={1}, use_polars=True)

        # Exactly one output row per cell.
        assert len(pdf) == 2, f"Expected 2 rows (one per cell), got {len(pdf)}"
        assert set(pdf["cell_ID"]) == {1, 2}

        # The two-part cell becomes a MultiPolygon, the single-part cell a Polygon.
        merged = pdf.loc[pdf["cell_ID"] == 1, "geometry"].iloc[0]
        single = pdf.loc[pdf["cell_ID"] == 2, "geometry"].iloc[0]
        assert isinstance(merged, sgeom.MultiPolygon), f"Expected MultiPolygon, got {type(merged)}"
        assert len(merged.geoms) == 2
        assert isinstance(single, sgeom.Polygon)


# ═══════════════════════════════════════════════════════════════════════════
# 6.
# END-TO-END NORMALIZED OUTPUT
# ═══════════════════════════════════════════════════════════════════════════


def _assert_normalized(sdata):
    """Invariants that must hold for ALL fixture outputs."""
    from spatialdata import SpatialData

    assert isinstance(sdata, SpatialData)

    # tables
    assert len(sdata.tables) > 0
    for tbl in sdata.tables.values():
        assert tbl.n_obs > 0
        assert tbl.n_vars > 0
        assert "global_cell_id" in tbl.obs.columns
        assert (tbl.obs["global_cell_id"] != 0).all(), "background rows should be filtered"
        assert "region_key" in tbl.obs.columns
        # region_key is allowed to be categorical (spatialdata enforces it)
        for col in tbl.obs.columns:
            if col == "region_key":
                continue
            assert not isinstance(tbl.obs[col].dtype, pd.CategoricalDtype), f"categorical {col} will break zarr"

    # shapes
    assert len(sdata.shapes) > 0
    fov_box_keys = [k for k in sdata.shapes if "fov_box" in k]
    assert len(fov_box_keys) > 0


class TestNormalizedOutput:
    """All fixtures produce a standardized SpatialData after cosmx()."""

    def test_fixture_a_adjacent(self, sdio, fixture_a):
        """Two adjacent FOVs: images should stitch, labels/tables should align."""
        sdata = _read(sdio, fixture_a, fovs=[1, 2])
        _assert_normalized(sdata)
        assert len(sdata.images) > 0
        assert len(sdata.labels) > 0

    def test_fixture_b_non_adjacent(self, sdio, fixture_b):
        """Two non-adjacent FOVs: gap between FOVs should be handled."""
        sdata = _read(sdio, fixture_b, fovs=[1, 5])
        _assert_normalized(sdata)
        assert len(sdata.images) > 0
        assert len(sdata.labels) > 0

    def test_fixture_c_single_fov(self, sdio, fixture_c):
        """Single FOV with CellLabels (no Morphology2D)."""
        sdata = _read(sdio, fixture_c, fovs=[1], read_images=False)
        _assert_normalized(sdata)
        assert len(sdata.labels) > 0

    def test_single_fov_from_multi(self, sdio, fixture_a):
        """Selecting 1 FOV from a multi-FOV dataset should use F00001_ prefix."""
        sdata = _read(sdio, fixture_a, fovs=[1])
        _assert_normalized(sdata)
        all_names = list(sdata.images.keys()) + list(sdata.labels.keys()) + list(sdata.shapes.keys())
        assert any("F00001" in n for n in all_names), f"Expected F00001 prefix, got {all_names}"

    def test_polygons_as_shapes(self, sdio, fixture_a):
        """polygons_as_labels=False → shape elements contain cell polygons."""
        sdata = _read(sdio, fixture_a, fovs=[1, 2], polygons_as_labels=False)
        poly_keys = [k for k in sdata.shapes if "cells_polygons" in k]
        assert len(poly_keys) > 0
        gdf = sdata.shapes[poly_keys[0]]
        assert len(gdf) > 0
        assert gdf.geometry.is_valid.all()


# ═══════════════════════════════════════════════════════════════════════════
# 7. FOV PLACEMENT GEOMETRY
# ═══════════════════════════════════════════════════════════════════════════


def _first_image_base_array(sdata):
    """Return the base-resolution array of the first image element.

    Single place for the unwrapping that ``_img_shape``, ``_img_array`` and the
    stitched-width test all need, so the three cannot drift apart.
    """
    img = next(iter(sdata.images.values()))
    # Image2DModel with scale_factors produces a DataTree
    if hasattr(img, "ds"):
        # xarray DataTree: base level is scale0
        return img["scale0"].ds["image"]
    if hasattr(img, "shape"):
        return img
    # Neither attribute: fall back to the first variable of the first tree node.
    first_level = img[next(iter(img.keys()))]
    return first_level.ds[next(iter(first_level.ds.keys()))]


def _img_shape(sdata):
    """Get (C, Y, X) shape from the first image (may be DataTree or DataArray)."""
    return tuple(_first_image_base_array(sdata).shape)


class TestLabelAnchoring:
    """Co-registration is label-anchored: the segmentation is the ground truth."""

    def test_segmented_but_unquantified_cell_kept(self, sdio, tmp_path):
        # A polygon with no expression row (cell 6, the highest cell_ID) is an
        # orphan. It must be KEPT in the labels (segmentation = ground truth) while
        # the table holds only the 5 quantified cells. Its high cell_ID also guards
        # the max_cell_id prescan, which must read the polygon 'cellID' column.
        import dask.array as da

        root = tmp_path / "orphan"
        root.mkdir()
        fov_pos = {1: (0.0, 0.0)}
        fov_df = pd.DataFrame(
            {"FOV": [1], "x_global_px": [0.0], "y_global_px": [0.0], "x_global_mm": [0.0], "y_global_mm": [0.0]}
        )
        _write_csv(root / "S0_fov_positions_file.csv.gz", fov_df, compress=True)
        _write_csv(root / "S0-polygons.csv.gz", _make_polygon_df(fov_pos, n_cells=6), compress=True)
        _write_csv(root / "S0_exprMat_file.csv.gz", _make_expr_mat([1], n_cells=5), compress=True)
        _write_csv(root / "S0_metadata_file.csv.gz", _make_metadata([1], n_cells=5), compress=True)

        sd = sdio.cosmx(root, fovs=[1], read_images=False, read_transcripts=False, read_proteins=False, n_workers=1)
        lab = next(iter(sd.labels.values()))
        arr = lab.data if hasattr(lab, "data") else lab["scale0"]["image"].data
        # Count distinct non-background (non-zero) label values.
        n_label = int((np.asarray(da.unique(arr).compute()) != 0).sum())
        n_table = next(iter(sd.tables.values())).n_obs
        assert n_label == 6, f"labels must keep all 6 segmented cells (incl. the orphan), got {n_label}"
        assert n_table == 5, f"table must hold only the 5 quantified cells, got {n_table}"


class TestFovPlacement:
    """Verify stitching geometry for adjacent, non-adjacent, and single FOV."""

    def test_adjacent_stitched_width(self, sdio, fixture_a):
        """Two side-by-side FOVs → stitched canvas ≈ 2 * SPEC_FOV_SIZE wide."""
        import dask.array as da

        sdata = _read(sdio, fixture_a, fovs=[1, 2])
        if not sdata.images:
            pytest.skip("No images")
        _, _, w = _img_shape(sdata)
        # _snap_fov_grid forces 4256, so canvas = 2*4256
        assert w > SPEC_FOV_SIZE, f"Expected width > {SPEC_FOV_SIZE}, got {w}"
        assert w <= 2 * SPEC_FOV_SIZE + 10

        # Verify the image has actual data, not all zeros
        arr = _first_image_base_array(sdata)
        data = arr.data if isinstance(arr.data, da.Array) else arr
        # Check a region that should have FOV data
        sample = data[0, :TIFF_FOV_SIZE, :TIFF_FOV_SIZE].compute()
        assert sample.max() > 0, "Stitched image is all zeros in FOV 1 region"

    def test_non_adjacent_has_gap(self, sdio, fixture_b):
        """Non-adjacent FOVs → canvas larger than 2 * SPEC_FOV_SIZE."""
        sdata = _read(sdio, fixture_b, fovs=[1, 5])
        if not sdata.images:
            pytest.skip("No images")
        _, _, w = _img_shape(sdata)
        assert w > 2 * SPEC_FOV_SIZE, f"Non-adjacent gap not reflected: w={w}"

    def test_single_fov_size(self, sdio, fixture_a):
        """Single FOV → image matches the actual TIFF size (no snapping)."""
        sdata = _read(sdio, fixture_a, fovs=[1])
        if not sdata.images:
            pytest.skip("No images")
        _, h, w = _img_shape(sdata)
        assert h == TIFF_FOV_SIZE
        assert w == TIFF_FOV_SIZE

    def test_fov_boxes_geometry(self, sdio, fixture_a):
        """FOV boxes should be SPEC_FOV_SIZE x SPEC_FOV_SIZE squares."""
        sdata = _read(sdio, fixture_a, fovs=[1, 2])
        fov_keys = [k for k in sdata.shapes if "fov_box" in k]
        assert len(fov_keys) == 1
        gdf = sdata.shapes[fov_keys[0]]
        assert len(gdf) == 2
        for geom in gdf.geometry:
            minx, miny, maxx, maxy = geom.bounds
            w = maxx - minx
            h = maxy - miny
            assert abs(w - SPEC_FOV_SIZE) < 1
            assert abs(h - SPEC_FOV_SIZE) < 1

    def test_non_adjacent_fov_boxes_separated(self, sdio, fixture_b):
        """Non-adjacent FOV boxes should not overlap."""
        sdata = _read(sdio, fixture_b, fovs=[1, 5])
        fov_keys = [k for k in sdata.shapes if "fov_box" in k]
        assert len(fov_keys) == 1
        gdf = sdata.shapes[fov_keys[0]]
        assert len(gdf) == 2
        box1, box2 = gdf.geometry.iloc[0], gdf.geometry.iloc[1]
        assert not box1.intersects(box2), "Non-adjacent FOV boxes should not overlap"


def _img_array(sdata):
    """Materialize the first image element as a numpy array (C, Y, X)."""
    data = _first_image_base_array(sdata).data
    # Lazy (dask-like) data exposes ``compute``; in-memory data is used directly.
    return np.asarray(data.compute() if hasattr(data, "compute") else data)


class TestImageFlipOrientation:
    """Morphology image orientation must not depend on whether polygons are read.

    Regression for #42: the morphology vertical flip was gated on
    ``align_rasters_to_polygons`` (which is True only when polygons/labels are
    read), so px-only (``flip_y=True``) images were flipped when polygons were
    read but left mirrored otherwise. The flip is now applied consistently
    (``flip_image`` defaults to True), independent of polygon reading.

    fixture_b is px-only (``flip_y=True``) with random-valued Morphology2D TIFFs,
    so a vertical flip is detectable. Pre-fix: read_polygons=True flips the image
    while read_polygons=False does not, so the two arrays differ. Post-fix: both
    are flipped, so the arrays are identical.
    """

    def test_orientation_independent_of_polygons(self, sdio, fixture_b):
        common = {
            "n_workers": 1,
            "read_transcripts": False,
            "read_proteins": False,
            "read_gexp": False,
            "read_labels": False,
        }
        sd_with = sdio.cosmx(fixture_b, read_polygons=True, **common)
        sd_without = sdio.cosmx(fixture_b, read_polygons=False, **common)
        if not sd_with.images or not sd_without.images:
            pytest.skip("No images")
        a_with = _img_array(sd_with)
        a_without = _img_array(sd_without)
        assert a_with.shape == a_without.shape, (
            f"image shape changed with read_polygons: {a_with.shape} vs {a_without.shape}"
        )
        assert np.array_equal(a_with, a_without), (
            "morphology image orientation changed depending on read_polygons (#42)"
        )

    def test_explicit_flip_image_false_respected(self, sdio, fixture_b):
        """An explicit flip_image=False must override the default flip.

        Single FOV so a per-tile flip equals a whole-canvas flip (the reader
        flips each FOV tile, not the stitched canvas).
        """
        common = {
            "n_workers": 1,
            "fovs": [1],
            "read_transcripts": False,
            "read_proteins": False,
            "read_gexp": False,
            "read_labels": False,
            "read_polygons": False,
        }
        a_default = _img_array(sdio.cosmx(fixture_b, **common))  # flip_image=None -> True
        a_noflip = _img_array(sdio.cosmx(fixture_b, flip_image=False, **common))
        assert np.array_equal(a_default, a_noflip[:, ::-1, :]), (
            "flip_image=False should yield the un-flipped image (vertical mirror of the default flipped one)"
        )


class TestMultimodalImageFlip:
    """The morphology flip must also be applied on the multimodal (RNA+Protein)
    read path.

    Regression guard for #42: the `flip_image` default (True) is resolved at the
    top of `cosmx()`, BEFORE the multimodal dispatch — resolving it afterwards
    left multimodal `flip_image=None -> False`, so multimodal morphology images
    came out un-flipped (mirrored). The single-modality flip tests above do NOT
    exercise `_cosmx_multi`, so this dedicated multimodal fixture is needed.

    fixture_multimodal's morphology is asymmetric (top 100 / bottom 200); a
    correct vertical flip makes the stored top half hold the larger (bottom)
    value, which we detect via top- vs bottom-region means (robust to the
    reader's float intensity normalization).
    """

    def test_dispatches_to_multimodal(self, fixture_multimodal):
        from spatialdata_io.readers.cosmx._discovery import (
            _set_up_cosmx_dataset_for_conversion,
        )

        ds = _set_up_cosmx_dataset_for_conversion(fixture_multimodal)
        assert ds.modalities is not None, "fixture did not register as multimodal"
        assert set(ds.modalities) == {"RNA", "Protein"}

    def test_multimodal_morphology_is_flipped(self, sdio, fixture_multimodal):
        sdata = sdio.cosmx(
            fixture_multimodal,
            fovs=[1],
            n_workers=1,
            read_transcripts=False,
            read_proteins=False,
            read_gexp=False,
            read_labels=False,
            read_polygons=True,
        )
        if not sdata.images:
            pytest.skip("No images")
        arr = _img_array(sdata)  # (C, Y, X), float-normalized
        h = arr.shape[1]
        top = float(arr[0, : h // 2, :].mean())
        bottom = float(arr[0, h // 2 :, :].mean())
        # raw top=100, bottom=200; after the vertical flip top must hold the
        # larger (originally-bottom) value.
        assert top > bottom, (
            "multimodal morphology image was not vertically flipped (#42): the "
            f"flip default did not reach the multimodal path (top={top:.5f} "
            f"!> bottom={bottom:.5f})"
        )


# ═══════════════════════════════════════════════════════════════════════════
# 8.
# PREVIEW
# ═══════════════════════════════════════════════════════════════════════════


class TestPreview:
    """With ``preview_fovs=True`` the reader is expected to return ``None``."""

    def test_returns_none(self, sdio, fixture_a):
        import matplotlib

        # Non-interactive backend, so the preview never needs a display.
        matplotlib.use("Agg")
        assert sdio.cosmx(fixture_a, preview_fovs=True) is None

    def test_with_subset(self, sdio, fixture_a):
        import matplotlib

        matplotlib.use("Agg")
        assert sdio.cosmx(fixture_a, fovs=[1], preview_fovs=True) is None


# ═══════════════════════════════════════════════════════════════════════════
# 9. ZARR ROUND-TRIP
# ═══════════════════════════════════════════════════════════════════════════


class TestZarrRoundTrip:
    """Writing to zarr and reading back must preserve elements and label IDs."""

    def test_write_read(self, sdio, fixture_a, tmp_path):
        import dask.array as da
        import spatialdata as sd

        def _to_numpy(lbl):
            """Materialize a label element as a numpy array."""
            return lbl.data.compute() if isinstance(lbl.data, da.Array) else np.asarray(lbl)

        sdata = _read(sdio, fixture_a, fovs=[1], read_images=False)

        # Pick the label to follow BY NAME. Taking "the first label" separately
        # on each side could compare two different elements if the element
        # order differs after reading back.
        label_name = next(iter(sdata.labels), None)

        # Verify labels have non-zero values BEFORE writing
        if label_name is not None:
            lbl_np = _to_numpy(sdata.labels[label_name])
            assert lbl_np.max() > 0, "Labels are all zeros before writing"

        zarr_path = tmp_path / "test.zarr"
        sdata.write(zarr_path)
        sdata2 = sd.read_zarr(zarr_path)
        assert set(sdata.labels.keys()) == set(sdata2.labels.keys())
        assert set(sdata.shapes.keys()) == set(sdata2.shapes.keys())
        assert set(sdata.tables.keys()) == set(sdata2.tables.keys())
        for name in sdata.tables:
            assert sdata.tables[name].n_obs == sdata2.tables[name].n_obs

        # Verify the SAME label survived round-trip with non-zero values; the
        # key-set assertion above guarantees the name exists in sdata2.
        if label_name is not None:
            lbl2_np = _to_numpy(sdata2.labels[label_name])
            assert lbl2_np.max() > 0, "Labels are all zeros after round-trip"
            assert set(np.unique(lbl2_np)) == set(np.unique(lbl_np)), "Label IDs changed after round-trip"


# ═══════════════════════════════════════════════════════════════════════════
# 10.
SNAPSHOT VALUES +# ═══════════════════════════════════════════════════════════════════════════ + + +# ═══════════════════════════════════════════════════════════════════════════ +# skip_empty_fovs / phantom FOVs (issue #37) +# ═══════════════════════════════════════════════════════════════════════════ + + +@pytest.fixture(scope="session") +def fixture_phantom(tmp_path_factory) -> Path: + """Positions list FOV 1 (phantom: no data files) + FOV 2 (real). px-only. + + Only FOV 2 ships a Morphology2D TIFF and a CellLabels TIFF, so a correct + reader should size both rasters to one FOV. Phantom FOV 1 sits at the origin + (smaller coords) so leaving it in inflates the image canvas (#37). + """ + root = tmp_path_factory.mktemp("fixture_phantom") + prefix = "Run_Ph" + pos_fovs = [1, 2] + real = [2] + fov_pos = {1: (0.0, 0.0), 2: (float(SPEC_FOV_SIZE), 0.0)} + + fov_df = pd.DataFrame( + { + "fov": pos_fovs, + "x_global_px": [fov_pos[f][0] for f in pos_fovs], + "y_global_px": [fov_pos[f][1] for f in pos_fovs], + } + ) + _write_csv(root / f"{prefix}_fov_positions_file.csv", fov_df) + _write_csv(root / f"{prefix}_exprMat_file.csv", _make_expr_mat(real)) + _write_csv(root / f"{prefix}_metadata_file.csv", _make_metadata(real)) + + morph = root / "Morphology2D" + labels = root / "CellLabels" + for fov in real: + _write_morphology_tiff(morph / f"20240101_S0_F{fov:05d}.TIF", fov_size=TIFF_FOV_SIZE) + _write_cell_label_tiff(labels / f"CellLabels_F{fov:03d}.tif") + return root + + +@pytest.fixture(scope="session") +def fixture_no_perfov(tmp_path_factory) -> Path: + """Positions for 2 FOVs but NO per-FOV image/label files (transcript/gexp + style). ``detect_fovs_with_data`` returns empty → ``skip_empty_fovs`` must + keep all FOVs rather than dropping the whole dataset (#37 safe rule). 
+ """ + root = tmp_path_factory.mktemp("fixture_noperfov") + prefix = "Run_NP" + fovs = [1, 2] + fov_pos = {1: (0.0, 0.0), 2: (float(SPEC_FOV_SIZE), 0.0)} + fov_df = pd.DataFrame( + { + "fov": fovs, + "x_global_px": [fov_pos[f][0] for f in fovs], + "y_global_px": [fov_pos[f][1] for f in fovs], + } + ) + _write_csv(root / f"{prefix}_fov_positions_file.csv", fov_df) + _write_csv(root / f"{prefix}_exprMat_file.csv", _make_expr_mat(fovs)) + _write_csv(root / f"{prefix}_metadata_file.csv", _make_metadata(fovs)) + return root + + +def _label_shape(sdata): + """(Y, X) shape of the first label element.""" + lbl = next(iter(sdata.labels.values())) + if hasattr(lbl, "ds"): + arr = lbl["scale0"].ds["image"] + elif hasattr(lbl, "shape"): + arr = lbl + else: + arr = lbl[list(lbl.keys())[0]].ds[list(lbl[list(lbl.keys())[0]].ds.keys())[0]] + return tuple(arr.shape) + + +def _global_x(elem): + """Global-transform x translation of a raster element.""" + from spatialdata.transformations import get_transformation + + t = get_transformation(elem, "global") + return float(t.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y"))[0, 2]) + + +class TestSkipEmptyFovs: + """Regression for #37: phantom FOVs (positions but no data files) must not + inflate the image canvas, desync the image/label canvases, or add empty + FOV boxes. 
+ """ + + _COMMON = { + "n_workers": 1, + "read_transcripts": False, + "read_proteins": False, + "read_gexp": False, + "read_polygons": False, + "polygons_as_labels": False, + "read_images": True, + "read_labels": True, + } + + def test_prune_tightens_and_matches_labels(self, sdio, fixture_phantom): + sd = sdio.cosmx(fixture_phantom, skip_empty_fovs=True, **self._COMMON) + _, ih, iw = _img_shape(sd) + lh, lw = _label_shape(sd) + assert iw <= SPEC_FOV_SIZE + 10, f"image canvas not tightened: {iw}" + assert (ih, iw) == (lh, lw), f"image {ih}x{iw} != label {lh}x{lw}" + # both rasters land on the real FOV 2 (global x = SPEC_FOV_SIZE) + assert abs(_global_x(next(iter(sd.images.values()))) - SPEC_FOV_SIZE) <= 1 + assert abs(_global_x(next(iter(sd.labels.values()))) - SPEC_FOV_SIZE) <= 1 + + def test_prune_drops_phantom_fov_box(self, sdio, fixture_phantom): + sd = sdio.cosmx(fixture_phantom, skip_empty_fovs=True, add_fovs_as_shapes=True, **self._COMMON) + keys = [k for k in sd.shapes if "fov_box" in k] + assert keys and len(sd.shapes[keys[0]]) == 1, "phantom FOV box not pruned" + + def test_no_skip_still_tightens_image_via_seen_filter(self, sdio, fixture_phantom): + # Part 2: even with skip_empty_fovs=False, the image canvas is tightened + # to image-bearing FOVs and stays co-registered with the labels. + sd = sdio.cosmx(fixture_phantom, skip_empty_fovs=False, **self._COMMON) + _, ih, iw = _img_shape(sd) + lh, lw = _label_shape(sd) + assert iw <= SPEC_FOV_SIZE + 10, f"image canvas not tightened w/o skip: {iw}" + assert (ih, iw) == (lh, lw) + assert abs(_global_x(next(iter(sd.images.values()))) - SPEC_FOV_SIZE) <= 1 + + def test_no_skip_keeps_phantom_fov_box(self, sdio, fixture_phantom): + # Read without labels: the label path independently prunes fov_locs to + # seen labels, so this isolates the skip_empty_fovs=False behaviour + # (phantom FOV kept in fov_locs → still gets a box). 
+ sd = sdio.cosmx( + fixture_phantom, + skip_empty_fovs=False, + add_fovs_as_shapes=True, + n_workers=1, + read_transcripts=False, + read_proteins=False, + read_gexp=False, + read_polygons=False, + polygons_as_labels=False, + read_images=True, + read_labels=False, + ) + keys = [k for k in sd.shapes if "fov_box" in k] + assert keys and len(sd.shapes[keys[0]]) == 2, "skip_empty_fovs=False must keep every listed FOV box" + + def test_explicit_phantom_request_not_read_as_all(self, sdio, fixture_phantom): + # Requesting only the phantom FOV must NOT collapse to "all FOVs": the + # reader should produce no image (FOV 1 has none), not silently load FOV 2. + sd = sdio.cosmx( + fixture_phantom, + skip_empty_fovs=True, + fovs=[1], + n_workers=1, + read_transcripts=False, + read_proteins=False, + read_gexp=False, + read_polygons=False, + polygons_as_labels=False, + read_images=True, + read_labels=False, + ) + assert not sd.images, "requesting a phantom FOV silently read other FOVs' images" + + def test_no_perfov_files_keeps_all_fovs(self, sdio, fixture_no_perfov): + # detection finds no per-FOV files → must NOT prune (would nuke the set) + sd = sdio.cosmx( + fixture_no_perfov, + skip_empty_fovs=True, + n_workers=1, + read_images=False, + read_labels=False, + read_transcripts=False, + read_gexp=False, + read_polygons=False, + polygons_as_labels=False, + add_fovs_as_shapes=True, + ) + keys = [k for k in sd.shapes if "fov_box" in k] + assert keys and len(sd.shapes[keys[0]]) == 2 + + +class TestDetectFovsWithData: + """Unit tests for the per-FOV data detector (#37).""" + + def test_detects_across_sources(self, tmp_path): + from spatialdata_io.readers.cosmx._utils import detect_fovs_with_data + + morph = tmp_path / "Morphology2D" + _write_morphology_tiff(morph / "20240101_S0_F00003.TIF", fov_size=TIFF_FOV_SIZE) + labels = tmp_path / "CellLabels" + _write_cell_label_tiff(labels / "CellLabels_F007.tif") + # CellStatsDir: a FOV with a label TIF counts; an empty FOV dir does NOT. 
+ _write_cell_label_tiff(tmp_path / "CellStatsDir" / "FOV010" / "CellLabels_F010.tif") + (tmp_path / "CellStatsDir" / "FOV011").mkdir(parents=True) + got = detect_fovs_with_data( + morphology_2d_dir=morph, + cell_labels_dir=labels, + cell_stats_dir=tmp_path / "CellStatsDir", + ) + assert got == {3, 7, 10} # 11 excluded: empty dir, no TIF + + def test_empty_when_no_files(self, tmp_path): + from spatialdata_io.readers.cosmx._utils import detect_fovs_with_data + + assert detect_fovs_with_data(morphology_2d_dir=tmp_path / "missing") == set() + + +# ═══════════════════════════════════════════════════════════════════════════ +# IMAGE NORMALIZATION (issue #38) — opt-in, scale-only +# ═══════════════════════════════════════════════════════════════════════════ + + +def _write_low_signal_morphology_tiff( + path: Path, n_channels: int = N_CHANNELS, fov_size: int = TIFF_FOV_SIZE, bg: int = 100, bright: int = 5000 +): + """Morphology TIFF whose real signal (``bright``) sits FAR below the uint16 + ceiling (65535) — so dtype-max scaling renders it near-black (#38) while a + per-channel percentile stretch recovers it. A quarter-FOV block is bright + (well above the 99.9th-percentile floor); the rest is low background. + """ + import tifffile + + path.parent.mkdir(parents=True, exist_ok=True) + data = np.full((n_channels, fov_size, fov_size), bg, dtype=np.uint16) + q = fov_size // 4 + data[:, :q, :q] = bright + tifffile.imwrite(str(path), data, description=_tiff_description(), photometric="minisblack") + + +@pytest.fixture(scope="session") +def fixture_lowsig(tmp_path_factory) -> Path: + """px+mm, single modality, TWO ADJACENT FOVs with LOW-signal morphology + (bright=5000 ≪ 65535) — exercises the #38 stretch end-to-end. 
+ """ + root = tmp_path_factory.mktemp("fixture_lowsig") + prefix = "S0" + fovs = [1, 2] + fov_pos = {1: (0.0, 0.0), 2: (float(SPEC_FOV_SIZE), 0.0)} + x_mm = {f: x / MM_TO_PX for f, (x, _) in fov_pos.items()} + y_mm = {f: y / MM_TO_PX for f, (_, y) in fov_pos.items()} + fov_df = pd.DataFrame( + { + "FOV": fovs, + "x_global_px": [fov_pos[f][0] for f in fovs], + "y_global_px": [fov_pos[f][1] for f in fovs], + "x_global_mm": [x_mm[f] for f in fovs], + "y_global_mm": [y_mm[f] for f in fovs], + } + ) + _write_csv(root / f"{prefix}_fov_positions_file.csv.gz", fov_df, compress=True) + _write_csv(root / f"{prefix}_exprMat_file.csv.gz", _make_expr_mat(fovs), compress=True) + _write_csv(root / f"{prefix}_metadata_file.csv.gz", _make_metadata(fovs), compress=True) + _write_csv(root / f"{prefix}-polygons.csv.gz", _make_polygon_df(fov_pos), compress=True) + morph = root / "Morphology2D" + for fov in fovs: + _write_low_signal_morphology_tiff(morph / f"20240101_S0_F{fov:05d}.TIF", fov_size=TIFF_FOV_SIZE) + return root + + +class TestImageNormalizationPrimitive: + """Unit tests for ``_normalize_image_channels`` (issue #38).""" + + @staticmethod + def _norm(arr, names, **kw): + import dask.array as da + + from spatialdata_io.readers.cosmx._utils import _normalize_image_channels + + return _normalize_image_channels(da.from_array(arr, chunks=(1,) + arr.shape[1:]), names, **kw) + + def test_none_is_passthrough(self): + # Opt-in: percentile=None leaves the (already dtype-max normalized) image untouched. + ch = np.full((1, 64, 64), 0.0015, dtype="float32") + ch[0, :8, :8] = 0.05 + out, scales = self._norm(ch, ["DNA"], percentile=None) + out = np.asarray(out.compute()) + assert out.dtype == np.float32 and scales == {} + assert np.array_equal(out, ch) + + def test_percentile_recovers_low_signal(self): + # Regression for #38: signal at 3000 (≪ ceiling) reaches ~1.0 via its own percentile. 
+ ch = np.full((1, 128, 128), 50, dtype="uint16") + ch[0, :16, :16] = 3000 + out, scales = self._norm(ch, ["DNA"], percentile=99.9) + out = np.asarray(out.compute()) + assert out.dtype == np.float32 + assert out.max() > 0.9, f"low-signal channel not recovered: max={out.max()}" + assert scales["DNA"] > 0 + + def test_scale_only_is_reversible(self): + # No clipping: the brightest pixels (above the percentile) exceed 1.0, and + # multiplying back by the divisor recovers the input exactly. + ch = np.full((1, 64, 64), 40, dtype="uint16") + ch[0, :10, :10] = 5000 # bulk signal -> ~99.9th percentile + ch[0, 0, :2] = 50000 # a few pixels brighter than the percentile + out, scales = self._norm(ch, ["DNA"], percentile=99.9) + out = np.asarray(out.compute()) + assert out.max() > 1.0, "scale-only must not clip the brightest pixels to 1.0" + np.testing.assert_allclose(out * scales["DNA"], ch.astype("float32"), rtol=1e-4) + + def test_multi_chunk_percentile(self): + # Production stitched images are multi-chunk, where da.percentile is approximate. + # Exercise that path: it must run, recover signal, and stay exactly reversible. + import dask.array as da + + from spatialdata_io.readers.cosmx._utils import _normalize_image_channels + + arr = np.full((1, 300, 300), 40, dtype="uint16") + arr[0, :30, :30] = 8000 + out, scales = _normalize_image_channels(da.from_array(arr, chunks=(1, 128, 128)), ["DNA"], percentile=99.9) + out = np.asarray(out.compute()) + assert out.dtype == np.float32 and scales["DNA"] > 0 + assert out.max() > 0.5, "signal not recovered on the multi-chunk path" + np.testing.assert_allclose(out * scales["DNA"], arr.astype("float32"), rtol=1e-4) + + def test_zeros_ignored_in_percentile(self): + # Zero padding (inter-FOV gaps) must not bias the percentile toward 0. 
+ ch = np.zeros((1, 64, 64), dtype="uint16") # mostly zero "canvas" + ch[0, :8, :8] = 100 # a single covered FOV: background + ch[0, :2, :2] = 5000 # signal within it + _, scales = self._norm(ch, ["DNA"], percentile=99.9) + assert scales["DNA"] > 100, "percentile collapsed onto the zero padding" + + def test_empty_channel_left_unscaled_without_crash(self): + zeros = np.zeros((1, 32, 32), dtype="uint16") + out, scales = self._norm(zeros, ["empty"], percentile=99.9) + out = np.asarray(out.compute()) + assert out.max() == 0.0 and scales["empty"] == 1.0 # unscaled fallback, no crash + + def test_float_channel_with_nan_does_not_propagate(self): + fl = np.full((1, 16, 16), 5.0, dtype="float32") + fl[0, 0, 0] = np.nan + out, _ = self._norm(fl, ["f"], percentile=99.9) + out = np.asarray(out.compute()) + assert np.isnan(out).sum() == 1, "a single NaN must not spread to the whole channel" + + +class TestImageNormalizationEndToEnd: + """End-to-end through ``read_images`` — single-, multi-FOV, and multimodal paths (#38).""" + + def test_default_is_legacy_dim(self, sdio, fixture_lowsig): + # Opt-in: the default (None) keeps dtype-max scaling, so low signal stays dim. 
+ arr = _img_array(_read(sdio, fixture_lowsig, fovs=[1])) + assert arr.dtype == np.float32 + assert arr.max() < 0.2, "default must keep legacy dtype-max (dim) scaling" + + def test_percentile_recovers_single_fov(self, sdio, fixture_lowsig): + arr = _img_array(_read(sdio, fixture_lowsig, fovs=[1], image_normalization_percentile=99.9)) + assert arr.max() > 0.9, "opt-in percentile did not recover the low-signal single-FOV image (#38)" + + def test_percentile_recovers_stitched_multi_fov(self, sdio, fixture_lowsig): + arr = _img_array(_read(sdio, fixture_lowsig, fovs=[1, 2], image_normalization_percentile=99.9)) + assert arr.dtype == np.float32 + assert arr.max() > 0.9, "opt-in percentile did not reach the stitched multi-FOV image (#38)" + + def test_multimodal_percentile_reaches_morphology(self, sdio, fixture_multimodal): + # The param must thread through _cosmx_multi; fixture morphology is low-signal. + kw = { + "fovs": [1], + "n_workers": 1, + "read_transcripts": False, + "read_proteins": False, + "read_gexp": False, + "read_labels": False, + "read_polygons": True, + } + arr = _img_array(sdio.cosmx(fixture_multimodal, image_normalization_percentile=99.9, **kw)) + assert arr.max() > 0.9, "opt-in normalization did not reach the multimodal morphology image" + arr0 = _img_array(sdio.cosmx(fixture_multimodal, **kw)) + assert arr0.max() < 0.05, "default should keep legacy (dim) scaling on the multimodal path" + + def test_scales_recorded_reversible_and_survive_round_trip(self, sdio, fixture_lowsig, tmp_path): + import spatialdata as sd_mod + + sd = _read(sdio, fixture_lowsig, fovs=[1], image_normalization_percentile=99.9) + meta = sd.attrs.get("cosmx_image_normalization") + assert meta, "per-channel normalization divisors not recorded in sdata.attrs" + (img_meta,) = list(meta.values()) + assert img_meta["percentile"] == 99.9 and img_meta["channel_scales"] + # Reversible: normalized * divisor recovers the dtype-max image read with percentile=None. 
+ a_norm = np.asarray(_img_array(sd)) + a_raw = np.asarray(_img_array(_read(sdio, fixture_lowsig, fovs=[1]))) + divs = np.array(list(img_meta["channel_scales"].values()), dtype="float32")[:, None, None] + np.testing.assert_allclose(a_norm * divs, a_raw, rtol=1e-3, atol=1e-4) + zp = tmp_path / "norm.zarr" + sd.write(zp) + assert sd_mod.read_zarr(zp).attrs.get("cosmx_image_normalization"), "normalization attrs lost on round-trip" + + def test_out_of_range_percentile_raises(self, sdio, fixture_lowsig): + with pytest.raises(ValueError, match="image_normalization_percentile"): + _read(sdio, fixture_lowsig, fovs=[1], image_normalization_percentile=150.0)