Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
27ba8d2
Centroid extraction + squidpy obsm["spatial"] caching; shared scatter…
timtreis Jun 8, 2026
3bad4c0
Simplify centroid cache: dedup region mask/keys, fix half-write, trim…
timtreis Jun 8, 2026
23da9bc
Store centroids intrinsic (coordinate-system-independent cache)
timtreis Jun 8, 2026
40d8cef
Simplify centroids: numpy-affine transform on cache hits, drop privat…
timtreis Jun 8, 2026
fc3af02
Add as_points fast mode: render shapes/labels as centroid dots
timtreis Jun 8, 2026
169d897
Fix as_points crash for labels without a color column
timtreis Jun 8, 2026
4678297
Compute label centroids with a streaming, out-of-core bincount aggreg…
timtreis Jun 8, 2026
13b1df3
Merge main; refactor as_points to reuse measure_obs primitives
timtreis Jun 9, 2026
f8918fb
Simplify _render_centroids_as_points: take render_params, not 6 unpac…
timtreis Jun 9, 2026
79892ee
Tidy labels as_points color alignment
timtreis Jun 9, 2026
85692e9
Trim verbose as_points comments to the load-bearing invariants
timtreis Jun 9, 2026
b087c1c
Make labels as_points fast: centroids on the rendered raster, not scale0
timtreis Jun 9, 2026
57bcf8d
Fast axis-bounds in show(): skip per-geometry transform when axis-ali…
timtreis Jun 9, 2026
ae5fcfb
refactor: tidy _get_extent_fast (fold helper, total_bounds, batched c…
timtreis Jun 10, 2026
4b9891c
perf(datashader): use _element_extent_fast for the shapes canvas extent
timtreis Jun 10, 2026
ee2513e
fix(labels): correct as_points no-color rendering + validate size (co…
timtreis Jun 10, 2026
158b108
refactor(labels): dedup as_points size validation, lazy color conversion
timtreis Jun 10, 2026
df14a19
Merge branch 'main' of https://github.com/scverse/spatialdata-plot in…
timtreis Jun 10, 2026
ae2f124
perf(extent): reuse fetched transformations in _element_extent_fast
timtreis Jun 10, 2026
6b78f38
fix(extent): defer all-empty elements to get_extent; trim comments
timtreis Jun 10, 2026
ddb4cc5
test(as_points): add visual tests for shapes/labels centroid mode + size
timtreis Jun 10, 2026
1e88495
test(as_points): add CI-generated baselines for the 4 as_points visua…
timtreis Jun 10, 2026
97c5624
perf(shapes): vectorize the all-empty-geometry guard in show()
timtreis Jun 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,15 @@
_expand_color_panels,
_get_cs_contents,
_get_elements_to_be_rendered,
_get_extent_fast,
_get_valid_cs,
_get_wanted_render_elements,
_maybe_set_colors,
_mpl_ax_contains_elements,
_prepare_cmap_norm,
_prepare_params_plot,
_set_outline,
_validate_as_points_size,
_validate_graph_render_params,
_validate_image_render_params,
_validate_label_render_params,
Expand Down Expand Up @@ -332,6 +334,8 @@ def render_shapes(
colorbar_params: dict[str, object] | None = None,
datashader_reduction: _DsReduction | None = None,
transfunc: Callable[[float], float] | None = None,
as_points: bool = False,
size: float | int = 1.0,
) -> sd.SpatialData:
"""
Render shapes elements in SpatialData.
Expand Down Expand Up @@ -448,6 +452,8 @@ def render_shapes(
sd.SpatialData
A copy of the SpatialData object with the rendering parameters stored in its plotting tree.
"""
if as_points:
_validate_as_points_size(size)
panel_param_dicts = _expand_color_panels(
self._sdata,
color,
Expand Down Expand Up @@ -515,6 +521,8 @@ def render_shapes(
ds_reduction=param_values["ds_reduction"],
colorbar=param_values["colorbar"],
colorbar_params=param_values["colorbar_params"],
as_points=as_points,
size=size,
panel_key=panel_key,
)
n_steps += 1
Expand Down Expand Up @@ -953,6 +961,8 @@ def render_labels(
table_layer: str | None = None,
gene_symbols: str | None = None,
transfunc: Callable[[float], float] | None = None,
as_points: bool = False,
size: float | int = 1.0,
) -> sd.SpatialData:
"""
Render labels elements in SpatialData.
Expand Down Expand Up @@ -1044,6 +1054,8 @@ def render_labels(
sd.SpatialData
A copy of the SpatialData object with the rendering parameters stored in its plotting tree.
"""
if as_points:
_validate_as_points_size(size)
panel_param_dicts = _expand_color_panels(
self._sdata,
color,
Expand Down Expand Up @@ -1101,6 +1113,8 @@ def render_labels(
zorder=n_steps,
colorbar=param_values["colorbar"],
colorbar_params=param_values["colorbar_params"],
as_points=as_points,
size=size,
panel_key=panel_key,
)
n_steps += 1
Expand Down Expand Up @@ -1812,15 +1826,16 @@ def _draw_colorbar(
empty_shape_elements = [
name
for name in wanted_elements
if name in sdata.shapes and not sdata.shapes[name]["geometry"].apply(lambda g: not g.is_empty).any()
if name in sdata.shapes and sdata.shapes[name]["geometry"].is_empty.all()
]
if empty_shape_elements:
raise ValueError(
f"Cannot render shape element(s) {empty_shape_elements} in coordinate system {cs!r}: "
"all geometries are empty. Drop the element or restore at least one non-empty geometry."
)

extent = get_extent(
# fast path for axis-aligned transforms; identical result, falls back to get_extent otherwise
extent = _get_extent_fast(
sdata,
coordinate_system=cs,
has_images=has_images and wants_images,
Expand Down
180 changes: 170 additions & 10 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from anndata import AnnData
from matplotlib import patheffects
from matplotlib.cm import ScalarMappable
from matplotlib.colors import ListedColormap, Normalize
from matplotlib.colors import Colormap, ListedColormap, Normalize
from scanpy._settings import settings as sc_settings
from scanpy.plotting._tools.scatterplots import _add_categorical_legend
from spatialdata import get_extent, get_values
Expand Down Expand Up @@ -77,11 +77,13 @@
_maybe_set_colors,
_mpl_ax_contains_elements,
_multiscale_to_spatial_image,
_pixel_to_coord,
_prepare_cmap_norm,
_prepare_transformation,
_rasterize_if_necessary,
_rasterize_if_necessary_datashader,
_set_color_source_vec,
_stream_label_centroid_stats,
_validate_polygons,
)

Expand Down Expand Up @@ -750,6 +752,30 @@ def _render_shapes(
color_source_vector = color_source_vector.remove_unused_categories()

shapes = gpd.GeoDataFrame(shapes, geometry="geometry")

if render_params.as_points:
# Fast mode: draw one dot per shape at its centroid instead of its geometry.
logger.info("`as_points=True`: rendering shape centroids; `outline_*` and `shape` are ignored.")
centroids = shapes.geometry.centroid # intrinsic coords, positionally aligned to color_vector
_render_centroids_as_points(
ax,
render_params,
x=centroids.x.to_numpy(),
y=centroids.y.to_numpy(),
color_vector=color_vector,
color_source_vector=color_source_vector,
norm=norm,
na_color=render_params.cmap_params.na_color,
transform=trans_data, # intrinsic -> coordinate system -> display
adata=table,
col_for_color=col_for_color,
palette=palette,
fig_params=fig_params,
legend_params=legend_params,
colorbar_requests=colorbar_requests,
)
return

# convert shapes if necessary
if render_params.shape is not None:
current_type = shapes["geometry"].type
Expand Down Expand Up @@ -1051,6 +1077,92 @@ def _render_shapes(
)


def _scatter_points(
ax: matplotlib.axes.SubplotBase,
x: Any,
y: Any,
color_vector: Any,
*,
size: float,
cmap: Colormap,
norm: Normalize | None,
alpha: float,
trans_data: Any,
zorder: int,
) -> Any:
"""Draw one marker per (x, y) colored by ``color_vector`` via ``ax.scatter``.

Shared scatter primitive for points and the centroid "fast mode" of shapes/labels;
``color_vector`` is per-point hex strings or numeric values mapped through ``cmap``/``norm``.
"""
return ax.scatter(
x,
y,
s=size,
c=color_vector,
rasterized=sc_settings._vector_friendly,
cmap=cmap,
norm=norm,
alpha=alpha,
transform=trans_data,
zorder=zorder,
plotnonfinite=True, # nan points should be rendered as well
)


def _render_centroids_as_points(
ax: matplotlib.axes.SubplotBase,
render_params: ShapesRenderParams | LabelsRenderParams,
*,
x: Any,
y: Any,
color_vector: Any,
color_source_vector: pd.Series | None,
norm: Normalize | None,
na_color: Any,
transform: Any,
adata: AnnData | None,
col_for_color: str | None,
palette: Any,
fig_params: FigParams,
legend_params: LegendParams,
colorbar_requests: list[ColorbarSpec] | None,
) -> None:
"""Render one dot per cell at ``(x, y)`` colored like the fill, with legend/colorbar.

Shared "fast mode" draw for shapes/labels; style comes off ``render_params``. ``norm``/``na_color``
stay explicit because they differ between the shapes (locally adjusted) and labels paths.
"""
cax = _scatter_points(
ax,
x,
y,
color_vector,
size=render_params.size,
cmap=render_params.cmap_params.cmap,
norm=norm,
alpha=render_params.fill_alpha,
trans_data=transform,
zorder=render_params.zorder,
)
_add_legend_and_colorbar(
ax=ax,
cax=cax,
fig_params=fig_params,
adata=adata,
col_for_color=col_for_color,
color_source_vector=color_source_vector,
color_vector=color_vector,
palette=palette,
alpha=render_params.fill_alpha,
na_color=na_color,
legend_params=legend_params,
colorbar=render_params.colorbar,
colorbar_params=render_params.colorbar_params,
colorbar_requests=colorbar_requests,
)


def _render_points(
sdata: sd.SpatialData,
render_params: PointsRenderParams,
Expand Down Expand Up @@ -1403,18 +1515,17 @@ def _render_points(
elif method == "matplotlib":
# update axis limits if plot was empty before (necessary if datashader comes after)
update_parameters = not _mpl_ax_contains_elements(ax)
cax = ax.scatter(
cax = _scatter_points(
ax,
adata[:, 0].X.flatten(),
adata[:, 1].X.flatten(),
s=render_params.size,
c=color_vector,
rasterized=sc_settings._vector_friendly,
color_vector,
size=render_params.size,
cmap=render_params.cmap_params.cmap,
norm=norm,
alpha=render_params.alpha,
transform=trans_data,
trans_data=trans_data,
zorder=render_params.zorder,
plotnonfinite=True, # nan points should be rendered as well
)
if update_parameters:
# necessary if points are plotted with mpl first and then with datashader
Expand Down Expand Up @@ -2187,9 +2298,9 @@ def _render_labels(
len(instance_id),
)

# rasterize could have removed labels from label
# only problematic if color is specified
if rasterize and (col_for_color is not None or col_for_outline_color is not None):
# rasterize/downsampling can drop labels from the raster; remove their (now-absent) instance ids
# so per-instance colors stay aligned and as_points does not emit dots for dropped cells.
if rasterize and (col_for_color is not None or col_for_outline_color is not None or render_params.as_points):
mask = np.isin(instance_id, unique_labels)
instance_id = instance_id[mask]
if col_for_color is not None:
Expand Down Expand Up @@ -2238,6 +2349,55 @@ def _render_labels(
if color_source_vector is None and render_params.transfunc is not None:
color_vector = render_params.transfunc(color_vector)

if render_params.as_points:
# Fast mode: one dot per label at its centroid. Compute on the *rendered* raster (already
# downsampled to ~display resolution above) and draw with its `trans_data`, so this is cheap
# and the dots land where the cells are. Centroid error is sub-pixel at display resolution.
logger.info("`as_points=True`: rendering label centroids; `contour_px` and `outline_*` are ignored.")
keep = instance_id != 0 # background label 0 has no centroid
point_ids = instance_id[keep]
labels, x_idx, y_idx, _area = _stream_label_centroid_stats(label.data)
centroids = pd.DataFrame(
{
"x": _pixel_to_coord(x_idx, label.coords["x"].values),
"y": _pixel_to_coord(y_idx, label.coords["y"].values),
},
index=labels,
)
# coerce so str/object table ids (e.g. Xenium) match the integer raster labels instead of NaN
centroids = centroids.reindex(point_ids.astype(labels.dtype, copy=False))
if col_for_color is None and not na_color.color_modified_by_user():
# no color column: one distinct random colour per cell, matching the mask path
# (`_map_color_seg` Case C) instead of collapsing every dot to a single na_color.
point_color_vector = np.random.default_rng(42).random((len(point_ids), 3))
point_color_source_vector = None
elif len(color_vector) == len(instance_id):
# data-driven colour is per-instance
point_color_vector = np.asarray(color_vector)[keep]
point_color_source_vector = None if color_source_vector is None else color_source_vector[keep]
else:
# literal colour / user-set na_color -> one colour per centroid
point_color_vector = np.asarray([na_color.get_hex_with_alpha()] * len(point_ids))
point_color_source_vector = None
_render_centroids_as_points(
ax,
render_params,
x=centroids["x"].to_numpy(),
y=centroids["y"].to_numpy(),
color_vector=point_color_vector,
color_source_vector=point_color_source_vector,
norm=copy(render_params.cmap_params.norm), # ax.scatter autoscales in place; don't mutate the shared norm
na_color=na_color,
transform=trans_data, # rendered-raster intrinsic coords -> coordinate system -> display
adata=table if table_name is not None else None,
col_for_color=col_for_color,
palette=palette,
fig_params=fig_params,
legend_params=legend_params,
colorbar_requests=colorbar_requests,
)
return

def _draw_labels(
seg_erosionpx: int | None,
seg_boundaries: bool,
Expand Down
6 changes: 6 additions & 0 deletions src/spatialdata_plot/pl/render_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,9 @@ class ShapesRenderParams:
table_name: str | None = None
table_layer: str | None = None
shape: Literal["circle", "hex", "visium_hex", "square"] | None = None
# Fast mode: render each shape as a single dot at its centroid instead of its geometry.
as_points: bool = False
size: float = 1.0 # marker size for as_points (matplotlib scatter ``s``)
ds_reduction: _DsReduction | None = None
colorbar: bool | str | None = "auto"
colorbar_params: dict[str, object] | None = None
Expand Down Expand Up @@ -328,6 +331,9 @@ class LabelsRenderParams:
zorder: int = 0
colorbar: bool | str | None = "auto"
colorbar_params: dict[str, object] | None = None
# Fast mode: render each label as a single dot at its centroid instead of the mask.
as_points: bool = False
size: float = 1.0 # marker size for as_points (matplotlib scatter ``s``)
# Multi-panel color: when set, this render entry belongs to the panel identified by this
# color key. ``None`` means the entry is shared across every panel (e.g. a background layer).
panel_key: str | None = None
Expand Down
Loading
Loading