Skip to content
2 changes: 2 additions & 0 deletions src/spatialdata_plot/pl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from ._palette import make_palette, make_palette_from_data
from .basic import PlotAccessor
from .utils import measure_obs

# Public names of this subpackage: the plotting accessor from `.basic`, the palette helpers
# from `._palette`, and `measure_obs`, which is re-exported from `.utils`.
__all__ = [
"PlotAccessor",
"make_palette",
"make_palette_from_data",
"measure_obs",
]
313 changes: 312 additions & 1 deletion src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import spatialdata as sd
from anndata import AnnData
from cycler import Cycler, cycler
from dask.array.core import slices_from_chunks
from datashader.core import Canvas
from geopandas import GeoDataFrame
from matplotlib import colors, patheffects, rcParams
Expand Down Expand Up @@ -59,6 +60,7 @@
from skimage.util import map_array
from spatialdata import (
SpatialData,
deepcopy as sd_deepcopy,
get_element_annotators,
get_extent,
get_values,
Expand All @@ -67,7 +69,14 @@
)
from spatialdata._core.query.relational_query import _locate_value
from spatialdata._types import ArrayLike
from spatialdata.models import Image2DModel, Labels2DModel, SpatialElement, get_table_keys
from spatialdata.models import (
Image2DModel,
Labels2DModel,
ShapesModel,
SpatialElement,
get_model,
get_table_keys,
)
from spatialdata.transformations.operations import get_transformation
from spatialdata.transformations.transformations import Scale, Translation
from spatialdata.transformations.transformations import Sequence as TransformSequence
Expand Down Expand Up @@ -4417,3 +4426,305 @@ def _convert_alpha_to_datashader_range(alpha: float) -> float:
"""Convert alpha from the range [0, 1] to the range [0, 255] used in datashader."""
# prevent a value of 255, bc that led to fully colored test plots instead of just colored points/shapes
return min([254, alpha * 255])


# --- Per-cell measurements into the annotating table (centroid / area / equivalent diameter) ---

# Destination keys in the annotating table. Measurements are stored in the element's intrinsic
# coordinates, i.e. they do not depend on any coordinate system.
# obsm["spatial"] is the squidpy/scanpy convention for per-cell coordinates (an N x 2 array).
_CENTROID_OBSM_KEY = "spatial"  # obsm key receiving the (n_obs, 2) centroid array
_AREA_OBS_KEY = "area"  # obs column receiving the per-instance area
_DIAMETER_OBS_KEY = "equivalent_diameter"  # obs column receiving 2 * sqrt(area / pi)


def _pixel_to_coord(idx: ArrayLike, coord: ArrayLike) -> ArrayLike:
    """Convert (possibly fractional) pixel indices along one axis into intrinsic coordinates.

    The axis is treated as regularly spaced: the step is the difference between the first two
    entries of ``coord`` and falls back to 1.0 when the axis has a single entry.
    """
    origin = coord[0]
    if len(coord) > 1:
        step = coord[1] - origin
    else:
        step = 1.0
    return origin + step * np.asarray(idx)


def _stream_label_centroid_stats(data: Any) -> tuple[ArrayLike, ArrayLike, ArrayLike, ArrayLike]:
    """Accumulate per-label pixel statistics of a 2D label raster, one block at a time.

    The raster is visited twice, block by block (the chunk grid of an array exposing ``chunks``,
    or horizontal strips of at most ~8M pixels for anything else). The first visit collects the
    sorted set of label values that occur; the second maps every pixel to the position of its
    label in that set and adds up, per label, the pixel count and the column / row indices.
    The accumulators are therefore sized by the number of distinct labels rather than by the
    raster size or the largest label id, and because the sums are additive the block layout
    does not affect the result.

    Parameters
    ----------
    data
        2D label array. Objects exposing ``chunks`` are treated as dask arrays; everything else
        is converted with ``np.asarray``. Float-valued blocks are cast to ``int64``.

    Returns
    -------
    ``(labels, mean_x_index, mean_y_index, area)`` for every label except background ``0``:
    the label values, their mean 0-based column and row indices, and their pixel counts.
    """
    height, width = data.shape
    if hasattr(data, "chunks"):  # dask: walk the array's own chunk grid
        tiles = slices_from_chunks(data.chunks)
    else:
        data = np.asarray(data)
        # full-width strips, capped so the per-strip weight arrays stay around 8M entries
        strip = max(1, min(height, (8 << 20) // max(1, width)))
        tiles = [(slice(top, min(top + strip, height)), slice(0, width)) for top in range(0, height, strip)]

    def _fetch(tile_rows: slice, tile_cols: slice) -> np.ndarray:
        piece = data[tile_rows, tile_cols]
        if hasattr(piece, "compute"):
            piece = piece.compute()
        piece = np.asarray(piece)
        # float-stored label rasters are cast so unique/searchsorted operate on integers
        if piece.dtype.kind == "f":
            return piece.astype(np.int64)
        return piece

    # First visit: sorted set of label values present anywhere in the raster.
    present = np.zeros(0, dtype=np.int64)
    for tile_rows, tile_cols in tiles:
        present = np.union1d(present, np.unique(_fetch(tile_rows, tile_cols)))
    n_labels = present.size

    # Second visit: additive count / column-sum / row-sum per position in ``present``.
    pixel_count = np.zeros(n_labels)
    col_total = np.zeros(n_labels)
    row_total = np.zeros(n_labels)
    for tile_rows, tile_cols in tiles:
        piece = _fetch(tile_rows, tile_cols)
        tile_h, tile_w = piece.shape
        dense = np.searchsorted(present, piece.reshape(-1))
        # absolute (whole-raster) column / row index of every pixel, flattened row-major
        col_index = np.arange(tile_cols.start, tile_cols.start + tile_w, dtype=np.float64)
        row_index = np.arange(tile_rows.start, tile_rows.start + tile_h, dtype=np.float64)
        col_weights = np.broadcast_to(col_index, (tile_h, tile_w)).reshape(-1)
        row_weights = np.broadcast_to(row_index[:, np.newaxis], (tile_h, tile_w)).reshape(-1)
        pixel_count += np.bincount(dense, minlength=n_labels)
        col_total += np.bincount(dense, weights=col_weights, minlength=n_labels)
        row_total += np.bincount(dense, weights=row_weights, minlength=n_labels)

    foreground = present != 0  # background is dropped; every remaining label has count >= 1
    area = pixel_count[foreground]
    return present[foreground], col_total[foreground] / area, row_total[foreground] / area, area


def _compute_element_measurements(sdata: SpatialData, element_name: str) -> pd.DataFrame:
    """Build a per-instance table of intrinsic ``["x", "y", "area"]`` for one element.

    Shapes: ``x``/``y`` come from the geometry centroids. When every geometry is a ``Point``
    (circles, whose shapely area is 0) the area is ``pi * radius**2`` from the ``radius`` column;
    otherwise the geometric area is used. The frame is indexed like the element.

    2D labels: statistics come from ``_stream_label_centroid_stats`` run on the raster data
    (the ``scale0`` level for a multiscale ``DataTree``); ``area`` is the pixel count and the mean
    pixel indices are mapped onto the raster's ``x``/``y`` coordinates. The frame is indexed by
    label value.

    Raises
    ------
    NotImplementedError
        If the element is neither a shapes nor a 2D-labels element.
    """
    element = sdata[element_name]
    model = get_model(element)

    if model is not ShapesModel and model is not Labels2DModel:
        raise NotImplementedError(
            f"Measurement is only supported for shapes and 2D labels; element {element_name!r} is a {model.__name__}."
        )

    if model is ShapesModel:
        geoms = element.geometry
        centers = geoms.centroid
        # Decide by geometry type rather than by the presence of a column: circles are stored
        # as Point geometries plus a radius.
        only_points = (geoms.geom_type == "Point").all()
        if only_points:
            areas = np.pi * np.asarray(element["radius"], dtype=float) ** 2
        else:
            areas = geoms.area.to_numpy()
        columns = {"x": centers.x.to_numpy(), "y": centers.y.to_numpy(), "area": areas}
        return pd.DataFrame(columns, index=element.index)

    # 2D labels: a multiscale element keeps its full-resolution raster under "scale0".
    if isinstance(element, DataTree):
        raster = next(iter(element["scale0"].values()))
    else:
        raster = element
    label_ids, mean_col, mean_row, pixel_area = _stream_label_centroid_stats(raster.data)
    # The streaming pass reports mean 0-based pixel indices; translate them to raster coords.
    x_coords = _pixel_to_coord(mean_col, raster.coords["x"].values)
    y_coords = _pixel_to_coord(mean_row, raster.coords["y"].values)
    return pd.DataFrame(
        {"x": x_coords, "y": y_coords, "area": np.asarray(pixel_area, dtype=float)},
        index=label_ids,
    )


def _valid_spatial_obsm(arr: ArrayLike, n_obs: int) -> bool:
    """Return ``True`` when ``arr`` is a 2D array of shape ``(n_obs, 2)``, i.e. one xy pair per row."""
    if arr.ndim != 2:
        return False
    return bool(arr.shape == (n_obs, 2))


def _obsm_region_finite(table: AnnData, key: str, mask: ArrayLike) -> bool:
    """Report whether ``obsm[key]`` already carries finite coordinates on all ``mask`` rows.

    Returns ``False`` when the key is absent, when the stored array is not a valid
    ``(n_obs, 2)`` coordinate array, when ``mask`` selects no rows, or when any selected value
    is NaN/inf.
    """
    if key in table.obsm:
        stored = np.asarray(table.obsm[key])
        if _valid_spatial_obsm(stored, table.n_obs):
            selected = stored[mask].astype(float)
            return bool(selected.size > 0 and np.isfinite(selected).all())
    return False


def _check_obs_numeric(table: AnnData, key: str) -> None:
    """Guard an ``obs`` write target: an existing ``obs[key]`` must have a numeric dtype.

    A missing column is fine. Intended to be called before anything is written, so a bad
    column cannot leave the table partially updated.

    Raises
    ------
    ValueError
        If ``obs[key]`` exists and is not numeric.
    """
    if key not in table.obs:
        return
    column = table.obs[key]
    if is_numeric_dtype(column):
        return
    raise ValueError(
        f"Cannot write measurements into obs[{key!r}]: the existing column is "
        f"{column.dtype} (not numeric). Drop or rename the column first."
    )


def _write_region(table: AnnData, mask: ArrayLike, key: str, values: ArrayLike, *, obsm: bool) -> None:
    """Store ``values`` on the ``mask`` rows of ``obsm[key]`` (2D) or ``obs[key]`` (1D).

    An existing entry is copied as float and only its ``mask`` rows are replaced; a missing
    entry is created NaN-filled first, so rows outside ``mask`` end up NaN.

    Raises
    ------
    ValueError
        If ``obsm`` is true and the existing ``obsm[key]`` is not an ``(n_obs, 2)`` array (for
        instance a 3-column xyz array); it is rejected rather than silently replaced.
    """
    target = table.obsm if obsm else table.obs
    if key not in target:
        blank_shape = (table.n_obs, 2) if obsm else table.n_obs
        buffer = np.full(blank_shape, np.nan)
    else:
        current = np.asarray(target[key])
        if obsm and not _valid_spatial_obsm(current, table.n_obs):
            raise ValueError(
                f"Refusing to overwrite obsm[{key!r}] with shape {current.shape}; expected "
                f"({table.n_obs}, 2). Remove it first if you want it replaced."
            )
        buffer = current.astype(float, copy=True)
    buffer[mask] = np.asarray(values, dtype=float)
    target[key] = buffer


def _measure_into_table(
    sdata: SpatialData, element_name: str, table_name: str, *, centroids: bool, area: bool, diameter: bool
) -> None:
    """Compute the requested measurements for one element and write them into its table.

    Only the rows whose region equals ``element_name`` are touched (a table may annotate several
    elements). Centroids that are already finite for those rows (e.g. reader-provided
    ``obsm["spatial"]``) are kept: a warning is emitted and the centroid write is skipped.
    ``area``/``diameter`` columns are overwritten on those rows.

    The ``obs`` target columns are validated *before* the element is measured, so a non-numeric
    existing column fails immediately instead of after a full pass over a large raster. An
    incompatible ``obsm["spatial"]`` is rejected by ``_write_region`` on the centroid write,
    which is the first write performed, so no failure leaves the table partially written.

    Parameters
    ----------
    sdata
        Object holding both the element and the table.
    element_name
        Name of the shapes / 2D-labels element to measure.
    table_name
        Name of the table in ``sdata.tables`` to write into.
    centroids, area, diameter
        Which measurements to write (``obsm["spatial"]``, ``obs["area"]``,
        ``obs["equivalent_diameter"]``).

    Raises
    ------
    ValueError
        If the table has no rows for ``element_name``, or if an existing ``obs`` target column
        is not numeric.
    """
    table = sdata.tables[table_name]
    _, region_key, instance_key = get_table_keys(table)
    mask = (table.obs[region_key].astype(str) == str(element_name)).to_numpy()
    if not mask.any():
        raise ValueError(f"Table {table_name!r} does not annotate element {element_name!r} (no matching rows).")

    # Never clobber centroids already populated for this element's rows (reader/prior-call coords).
    if centroids and _obsm_region_finite(table, _CENTROID_OBSM_KEY, mask):
        warnings.warn(
            f"obsm[{_CENTROID_OBSM_KEY!r}] is already populated for element {element_name!r}; not "
            f"overwriting its centroids (remove it to recompute).",
            UserWarning,
            stacklevel=3,
        )
        centroids = False
    if not (centroids or area or diameter):
        return

    # Validate obs targets before the (potentially expensive) measurement so that an existing
    # non-numeric column raises right away and nothing has been computed or written yet.
    if area:
        _check_obs_numeric(table, _AREA_OBS_KEY)
    if diameter:
        _check_obs_numeric(table, _DIAMETER_OBS_KEY)

    keys = table.obs[instance_key].to_numpy()[mask]
    meas = _compute_element_measurements(sdata, element_name).reindex(keys)
    # Instance ids annotated in the table but absent from the element reindex to NaN -> warn.
    missing = int(meas[["x", "y"]].isna().any(axis=1).sum())
    if missing:
        warnings.warn(
            f"{missing}/{len(keys)} instances annotated for {element_name!r} have no match in the "
            f"element (instance-id dtype mismatch, e.g. str vs int?); writing NaN for them.",
            UserWarning,
            stacklevel=3,
        )

    area_vals = meas["area"].to_numpy()
    if centroids:
        _write_region(table, mask, _CENTROID_OBSM_KEY, meas[["x", "y"]].to_numpy(), obsm=True)
    if area:
        _write_region(table, mask, _AREA_OBS_KEY, area_vals, obsm=False)
    if diameter:
        # equivalent diameter: diameter of the circle with the same area
        _write_region(table, mask, _DIAMETER_OBS_KEY, 2.0 * np.sqrt(area_vals / np.pi), obsm=False)


def _resolve_measure_table(sdata: SpatialData, element_name: str, table_name: str | None) -> str:
    """Pick the table that measurements for ``element_name`` are written into.

    An explicit ``table_name`` only has to exist in ``sdata.tables``. Without one, the element
    must be annotated by exactly one table, which is then returned.

    Raises
    ------
    KeyError
        If an explicit ``table_name`` is not in ``sdata.tables``.
    ValueError
        If no table is given and the element has zero or several annotating tables.
    """
    if table_name is None:
        candidates = sorted(get_element_annotators(sdata, element_name))
        n_candidates = len(candidates)
        if n_candidates == 0:
            raise ValueError(
                f"Element {element_name!r} has no annotating table; per-cell measurements need a table "
                f"to write into. Pass `table_name=` or annotate the element first."
            )
        if n_candidates > 1:
            raise ValueError(
                f"Element {element_name!r} is annotated by multiple tables ({', '.join(candidates)}); "
                f"pass `table_name=` to pick one."
            )
        return candidates[0]

    if table_name in sdata.tables:
        return table_name
    raise KeyError(f"Table {table_name!r} not found in `sdata.tables`.")


def measure_obs(
    sdata: SpatialData,
    element: str | None = None,
    *,
    table_name: str | None = None,
    centroids: bool = True,
    area: bool = True,
    diameter: bool = True,
    inplace: bool = True,
) -> SpatialData | None:
    """Measure per-cell centroids, area and equivalent diameter into an element's annotating table.

    Computes one centroid, area and equivalent diameter per instance of a shapes or 2D-labels
    element and writes them, squidpy-style, into the annotating :class:`~anndata.AnnData` table:
    centroids go to ``table.obsm["spatial"]`` (an ``(n_obs, 2)`` array, the squidpy convention),
    area and diameter to ``table.obs["area"]`` and ``table.obs["equivalent_diameter"]``. Values are
    stored in the element's *intrinsic* pixel coordinates/units (which align directly with the
    element's own raster/geometry).
    Labels area is the pixel count; shapes area is ``geometry.area`` (``pi*r**2`` for circles);
    equivalent diameter is ``2 * sqrt(area / pi)``. Persisting them once lets later renders (and
    downstream tools such as squidpy) reuse them instead of recomputing.

    Centroids already present for an element's rows (e.g. a reader-provided ``obsm["spatial"]``) are
    **not** overwritten — a warning is emitted and that element's centroid write is skipped (remove
    the key to recompute); ``area``/``diameter`` overwrite our own columns. Instances annotated in
    the table but absent from the element are written as NaN with a warning. Per-cell measurements
    need a table to write into, so an element without an annotating table cannot be measured (this
    raises). The label path never densifies the raster — it streams it block by block with memory
    O(n_labels), scaling to Xenium-size masks.

    Parameters
    ----------
    sdata
        The ``SpatialData`` object holding the element and its annotating table.
    element
        Name of the shapes/2D-labels element to measure. If ``None``, the elements are selected
        automatically: without ``table_name``, every shapes/2D-labels element that has exactly
        one annotating table; with ``table_name``, every shapes/2D-labels element annotated by
        that table.
    table_name
        Name of the annotating table to write into. If ``None``, it is inferred from the element's
        annotators (an error is raised when there are zero or several).
    centroids, area, diameter
        Which measurements to compute/write. At least one must be ``True``. They are written to
        ``obsm["spatial"]``, ``obs["area"]`` and ``obs["equivalent_diameter"]`` respectively.
    inplace
        If ``True`` (default), mutate ``sdata``'s table in place and return ``None``. If ``False``,
        operate on a deep copy and return the modified ``SpatialData``.

    Returns
    -------
    ``None`` if ``inplace`` is ``True``, otherwise the modified deep-copied ``SpatialData``.

    Raises
    ------
    ValueError
        If no measurement is requested, if no element qualifies for automatic selection, or if
        the table cannot be resolved / does not annotate the element.
    KeyError
        If ``table_name`` is given but not present in ``sdata.tables``.
    """
    if not (centroids or area or diameter):
        raise ValueError("Nothing to measure: set at least one of `centroids`, `area`, `diameter` to True.")
    target = sdata if inplace else sd_deepcopy(sdata)
    if element is None:
        measurable = [
            n for n in list(target.shapes) + list(target.labels) if get_model(target[n]) in (ShapesModel, Labels2DModel)
        ]
        if table_name is None:
            # no table given: keep the elements whose table can be inferred unambiguously
            names = [n for n in measurable if len(get_element_annotators(target, n)) == 1]
            if not names:
                raise ValueError(
                    "No shapes/2D-labels element with a single annotating table was found; nothing to "
                    "measure. Pass an explicit `element=` (and `table_name=` if ambiguous)."
                )
        else:
            # A table was named: measure exactly the elements it annotates. Selecting by
            # "single annotator" here would pick elements of other tables (which then fail
            # mid-loop) and skip elements that this table shares with another one.
            if table_name not in target.tables:
                raise KeyError(f"Table {table_name!r} not found in `sdata.tables`.")
            names = [n for n in measurable if table_name in get_element_annotators(target, n)]
            if not names:
                raise ValueError(
                    f"No shapes/2D-labels element annotated by table {table_name!r} was found; "
                    "nothing to measure."
                )
    else:
        names = [element]
    for name in names:
        table = _resolve_measure_table(target, name, table_name)
        _measure_into_table(target, name, table, centroids=centroids, area=area, diameter=diameter)
    return None if inplace else target
Loading
Loading