diff --git a/CHANGELOG.md b/CHANGELOG.md index 6d89ae5246..59c7b5a387 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,6 +51,13 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm - **GFQL / Cypher**: Extracted `ASTNormalizer` into `graphistry/compute/gfql/cypher/ast_normalizer.py` and moved shortestPath + WHERE-pattern-predicate rewrite ownership out of `lowering.py`, with parity-preserving wiring in compile/lowering flows and focused regression coverage for rewrite behavior and invocation order (#1117). - **GFQL / Cypher compiler**: Lowering now functionally consumes `BoundIR` metadata for the M1 integration slice: binder-provided params are merged into effective lowering params (runtime overrides preserved) with binder metadata keys filtered out of runtime-param resolution, scope membership narrowing uses the active scope frame for WITH-boundary correctness, semantic-table entity kinds inform alias table routing, and nullable alias metadata is wired into optional-only alias detection. `_StageScope` duplicated table bookkeeping was reduced, binder now runs pre- and post-normalization in compile flow, and binder-path regression tests were added for these code paths (#1116). +### Changed +- **Collections**: Autofix validation now drops invalid collections (e.g., invalid GFQL ops) and non-string collection color fields instead of string-coercing them; warnings still emit when `warn=True`. +- **Collections**: `collections(...)` now always canonicalizes to URL-encoded JSON (string inputs are parsed + re-encoded); the `encode` parameter was removed to avoid ambiguous behavior. +- **Collections**: Set collections now require an `id` field (server requires it for subgraph storage); missing IDs are warned and dropped in autofix mode rather than auto-generated. +- **Collections**: Intersection collections now cross-validate that referenced set IDs exist; dangling references are warned and dropped in autofix mode. 
+- **Collections**: GFQL parsing consolidated to use `normalize_gfql_to_wire` from `graphistry/compute/ast.py` as the canonical implementation with precise exception handling. + ### Tests - **GFQL / Cypher binder**: Added PR-4 white-box binder semantic conformance coverage for name resolution success/failure (including unresolved alias errors), WITH scope-reset visibility, OPTIONAL MATCH `null_extended_from` lineage as `frozenset` clause ids, label narrowing from MATCH labels + conjunctive `WHERE alias:Label` checks, and SchemaConfidence rules (min-rule propagation, operand inheritance, and strong literal/`COUNT` behavior). Parser/lowering regression lanes remain green (#1114). - **Plugins / cuDF**: 14 GPU tests in `TestCpuOnlyPluginsCudfRoundTrip` (`test_call_operations_gpu.py`) verifying real cuDF→pandas→cuDF round-trip for `compute_igraph` (pagerank, spanning_tree Graph-returning path, articulation_points list-return path, edge-attribute merge path), `layout_igraph`, `layout_graphviz`, `render_graphviz`, `execute_call`, `ensure_pandas` nullable dtype preservation, and `restore_engine` conversion. Requires `TEST_CUDF=1` and RAPIDS. @@ -412,6 +419,17 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm ### Infra - **CI**: Added pandas 2.2.3/3.0.0 compatibility jobs and minimal suite coverage. +### Added +- **Collections**: New `g.collections(...)` API for defining subsets via GFQL expressions with priority-based visual encodings. Includes helper constructors `graphistry.collection_set(...)` and `graphistry.collection_intersection(...)`, support for `showCollections`, `collectionsGlobalNodeColor`, and `collectionsGlobalEdgeColor` URL params, and automatic JSON encoding. Accepts GFQL AST, Chain objects, or wire-protocol dicts (#874).
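The canonicalization described in the Changed entries above (string inputs parsed, then re-encoded as compact URL-encoded JSON) can be sketched with the stdlib alone; the helper name below is illustrative, not the shipped API (which lives in `graphistry/validate/validate_collections.py`):

```python
import json
from urllib.parse import quote, unquote

def canonicalize_collections(value):
    # Sketch of the behavior described above: accept plain JSON or
    # URL-encoded JSON strings (or already-parsed values), then re-encode
    # as compact, fully percent-escaped JSON. Name is hypothetical.
    if isinstance(value, str):
        try:
            value = json.loads(value)
        except json.JSONDecodeError:
            value = json.loads(unquote(value))
    return quote(json.dumps(value, separators=(',', ':')), safe='')

raw = '[{"type": "set", "id": "a", "expr": {"type": "gfql_chain", "gfql": [{"type": "Node"}]}}]'
encoded = canonicalize_collections(raw)
assert encoded.startswith('%5B')                      # '[' percent-encoded
assert canonicalize_collections(encoded) == encoded   # idempotent on its own output
assert json.loads(unquote(encoded)) == json.loads(raw)
```

Because both plain and already-encoded strings normalize to the same output, a removed `encode` flag has nothing left to disambiguate, which matches the rationale given for dropping it.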
+- **Docs / Collections**: Added collections usage guide in visualization/layout/settings, tutorial notebook (`demos/more_examples/graphistry_features/collections.ipynb`), and cross-references in 10-minute guides, cheatsheet, and GFQL docs (#875). + +### Tests +- **Collections**: Added `test_collections.py` covering encoding, GFQL Chain/AST normalization, wire-protocol acceptance, validation modes, and helper constructors. + ## [0.50.4 - 2026-01-15] ### Fixed diff --git a/ai/prompts/PLAN.md b/ai/prompts/PLAN.md index d31d05227d..57a02a19e2 100644 --- a/ai/prompts/PLAN.md +++ b/ai/prompts/PLAN.md @@ -125,7 +125,7 @@ git log --oneline -n 10 - Source: `graphistry/` - Tests: `graphistry/tests/` (mirrors source structure: `graphistry/foo/bar.py` → `graphistry/tests/foo/test_bar.py`) - Docs: `docs/` -- Plans: `plans/` (gitignored - safe for auxiliary files, temp secrets, working data) +- Plans: `plans/` (gitignored - safe for auxiliary files, temp secrets, working data; Codex: avoid `~/.codex/plans`; if used, copy here then delete) - AI prompts: `ai/prompts/` - AI docs: `ai/docs/` diff --git a/graphistry/Plottable.py b/graphistry/Plottable.py index 18ed6cc67a..c94da3295e 100644 --- a/graphistry/Plottable.py +++ b/graphistry/Plottable.py @@ -17,6 +17,7 @@ from graphistry.Engine import EngineAbstractType from graphistry.utils.json import JSONVal from graphistry.client_session import ClientSession, AuthManagerProtocol +from graphistry.models.collections import CollectionsInput from graphistry.models.types import ValidationParam if TYPE_CHECKING: @@ -783,6 +784,17 @@ def
settings(self, ) -> 'Plottable': ... + def collections( + self, + collections: Optional[CollectionsInput] = None, + show_collections: Optional[bool] = None, + collections_global_node_color: Optional[str] = None, + collections_global_edge_color: Optional[str] = None, + validate: ValidationParam = 'autofix', + warn: bool = True + ) -> 'Plottable': + ... + def privacy(self, mode: Optional[PrivacyMode] = None, notify: Optional[bool] = None, invited_users: Optional[List[str]] = None, message: Optional[str] = None, mode_action: Optional[ModeAction] = None) -> 'Plottable': ... diff --git a/graphistry/PlotterBase.py b/graphistry/PlotterBase.py index 28453952ae..247274d7d9 100644 --- a/graphistry/PlotterBase.py +++ b/graphistry/PlotterBase.py @@ -2,6 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Union, Tuple, cast, overload, TYPE_CHECKING from typing_extensions import Literal from graphistry.io.types import ComplexEncodingsDict +from graphistry.models.collections import CollectionsInput from graphistry.models.types import ValidationMode, ValidationParam from graphistry.plugins_types.hypergraph import HypergraphResult from graphistry.render.resolve_render_mode import resolve_render_mode @@ -1872,7 +1873,8 @@ def graph(self, ig: Any) -> Plottable: def settings(self, height=None, url_params={}, render=None): """Specify iframe height and add URL parameter dictionary. - The library takes care of URI component encoding for the dictionary. + Collections URL params are normalized and URL-encoded at plot time; other + params should already be URL-safe. :param height: Height in pixels. 
:type height: int @@ -1892,6 +1894,51 @@ def settings(self, height=None, url_params={}, render=None): return res + def collections( + self, + collections: Optional[CollectionsInput] = None, + show_collections: Optional[bool] = None, + collections_global_node_color: Optional[str] = None, + collections_global_edge_color: Optional[str] = None, + validate: ValidationParam = 'autofix', + warn: bool = True + ) -> 'Plottable': + """Set collections URL parameters. Additive over previous settings. + + :param collections: List/dict of collections or JSON/URL-encoded JSON string (stored as URL-encoded JSON). + :param show_collections: Toggle collections panel display. + :param collections_global_node_color: Hex color for non-collection nodes (leading # stripped). + :param collections_global_edge_color: Hex color for non-collection edges (leading # stripped). + :param validate: Validation mode. 'autofix' (default) drops invalid collections and color fields with warnings, 'strict' raises on issues. + :param warn: Whether to emit warnings when validate='autofix'. validate=False forces warn=False. 
+ """ + from graphistry.validate.validate_collections import ( + encode_collections, + normalize_collections, + normalize_collections_url_params, + ) + + settings: Dict[str, Any] = {} + if collections is not None: + normalized = normalize_collections(collections, validate=validate, warn=warn) + settings['collections'] = encode_collections(normalized) + extras: Dict[str, Any] = {} + if show_collections is not None: + extras['showCollections'] = show_collections + if collections_global_node_color is not None: + extras['collectionsGlobalNodeColor'] = collections_global_node_color + if collections_global_edge_color is not None: + extras['collectionsGlobalEdgeColor'] = collections_global_edge_color + if extras: + extras = normalize_collections_url_params(extras, validate=validate, warn=warn) + settings.update(extras) + + if len(settings.keys()) > 0: + return self.settings(url_params={**self._url_params, **settings}) + else: + return self + + def privacy( self, mode: Optional[Mode] = None, @@ -2239,7 +2286,11 @@ def plot( 'viztoken': str(uuid.uuid4()) } - viz_url = self._pygraphistry._viz_url(info, self._url_params) + # Validate collections in url_params (catches bypass of .collections() method) + from graphistry.validate.validate_collections import normalize_collections_url_params + url_params = normalize_collections_url_params(self._url_params, validate=validate_mode, warn=warn) + + viz_url = self._pygraphistry._viz_url(info, url_params) cfg_client_protocol_hostname = self.session.client_protocol_hostname full_url = ('%s:%s' % (self.session.protocol, viz_url)) if cfg_client_protocol_hostname is None else viz_url diff --git a/graphistry/__init__.py b/graphistry/__init__.py index dea30e5c25..df57db79af 100644 --- a/graphistry/__init__.py +++ b/graphistry/__init__.py @@ -24,6 +24,7 @@ nodes, graph, settings, + collections, encode_point_color, encode_point_size, encode_point_icon, @@ -65,6 +66,13 @@ from_cugraph ) +from graphistry.collections import ( + collection_set, + 
collection_intersection, + CollectionSet, + CollectionIntersection, +) + from graphistry.compute import ( n, e, e_forward, e_reverse, e_undirected, let, ref, diff --git a/graphistry/collections.py b/graphistry/collections.py new file mode 100644 index 0000000000..a133b746ab --- /dev/null +++ b/graphistry/collections.py @@ -0,0 +1,77 @@ +from typing import Optional, Sequence, TypeVar + +from graphistry.models.collections import ( + CollectionIntersection, + CollectionExprInput, + CollectionSet, +) + +CollectionDict = TypeVar("CollectionDict", CollectionSet, CollectionIntersection) + + +def _apply_collection_metadata(collection: CollectionDict, **metadata: Optional[str]) -> CollectionDict: + value = metadata.get("id") + if value is not None: + collection["id"] = value + value = metadata.get("name") + if value is not None: + collection["name"] = value + value = metadata.get("description") + if value is not None: + collection["description"] = value + value = metadata.get("node_color") + if value is not None: + collection["node_color"] = value + value = metadata.get("edge_color") + if value is not None: + collection["edge_color"] = value + return collection + + +def collection_set( + *, + expr: CollectionExprInput, + id: Optional[str] = None, + name: Optional[str] = None, + description: Optional[str] = None, + node_color: Optional[str] = None, + edge_color: Optional[str] = None, +) -> CollectionSet: + """Build a collection dict for a GFQL-defined set.""" + from graphistry.compute.ast import normalize_gfql_to_wire + collection: CollectionSet = {"type": "set", "expr": {"type": "gfql_chain", "gfql": normalize_gfql_to_wire(expr)}} + return _apply_collection_metadata( + collection, + id=id, + name=name, + description=description, + node_color=node_color, + edge_color=edge_color, + ) + + +def collection_intersection( + *, + sets: Sequence[str], + id: Optional[str] = None, + name: Optional[str] = None, + description: Optional[str] = None, + node_color: Optional[str] = None, + 
edge_color: Optional[str] = None, +) -> CollectionIntersection: + """Build a collection dict for an intersection of set IDs.""" + collection: CollectionIntersection = { + "type": "intersection", + "expr": { + "type": "intersection", + "sets": list(sets), + }, + } + return _apply_collection_metadata( + collection, + id=id, + name=name, + description=description, + node_color=node_color, + edge_color=edge_color, + ) diff --git a/graphistry/compute/ast.py b/graphistry/compute/ast.py index 5e1b058d46..8f265ecca3 100644 --- a/graphistry/compute/ast.py +++ b/graphistry/compute/ast.py @@ -1588,6 +1588,60 @@ def from_json(o: JSONVal, validate: bool = True) -> Union[ASTNode, ASTEdge, ASTL return out +def normalize_gfql_to_wire(expr: Any) -> List[Dict[str, JSONVal]]: + """ + Normalize GFQL expression to wire format (list of JSON-serializable dicts). + + Accepts: + - Chain object + - Single ASTObject + - List of ASTObjects + - Dict with 'type': 'Chain' and 'chain' key + - Dict with 'type': 'gfql_chain' and 'gfql' key + - Dict with just 'chain' or 'gfql' key + - Single dict (parsed as AST op) + + Returns: + - List of JSON-serializable dicts ready for wire protocol + + Raises: + - TypeError: if expr type is not supported + - ValueError: if expr is empty + - GFQLSyntaxError: if dict cannot be parsed as valid AST + """ + from graphistry.compute.chain import Chain + + def _normalize_op(op: object) -> Dict[str, JSONVal]: + if isinstance(op, ASTObject): + return op.to_json() + if isinstance(op, dict): + return from_json(op, validate=True).to_json() + raise TypeError("GFQL operations must be AST objects or dictionaries") + + def _normalize_ops(raw: object) -> List[Dict[str, JSONVal]]: + if isinstance(raw, Chain): + return _normalize_ops(raw.to_json().get("chain", [])) + if isinstance(raw, ASTObject): + return [raw.to_json()] + if isinstance(raw, list): + if len(raw) == 0: + raise ValueError("GFQL operations list cannot be empty") + return [_normalize_op(op) for op in raw] + if 
isinstance(raw, dict): + if raw.get("type") == "Chain" and "chain" in raw: + return _normalize_ops(raw.get("chain")) + if raw.get("type") == "gfql_chain" and "gfql" in raw: + return _normalize_ops(raw.get("gfql")) + if "chain" in raw: + return _normalize_ops(raw.get("chain")) + if "gfql" in raw: + return _normalize_ops(raw.get("gfql")) + return [_normalize_op(raw)] + raise TypeError("GFQL expr must be Chain, ASTObject, list, or dict") + + return _normalize_ops(expr) + + ############################################################################### # User-friendly aliases for public API diff --git a/graphistry/models/collections.py b/graphistry/models/collections.py new file mode 100644 index 0000000000..d12896d663 --- /dev/null +++ b/graphistry/models/collections.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from typing import Dict, List, TYPE_CHECKING, Union +from typing_extensions import Literal, NotRequired, Required, TypedDict + +from graphistry.utils.json import JSONVal + +if TYPE_CHECKING: + from graphistry.compute.ast import ASTObject + from graphistry.compute.chain import Chain + + +CollectionExprInput = Union[ + "Chain", + "ASTObject", + List["ASTObject"], + Dict[str, JSONVal], + List[Dict[str, JSONVal]], +] + + +class IntersectionExpr(TypedDict): + type: Literal["intersection"] + sets: List[str] + + +class CollectionBase(TypedDict, total=False): + id: str + name: str + description: str + node_color: str + edge_color: str + + +class CollectionSet(CollectionBase): + type: NotRequired[Literal["set"]] + expr: Required[CollectionExprInput] + + +class CollectionIntersection(CollectionBase): + type: NotRequired[Literal["intersection"]] + expr: Required[IntersectionExpr] + + +Collection = Union[CollectionSet, CollectionIntersection] +CollectionsInput = Union[str, Collection, List[Collection]] diff --git a/graphistry/pygraphistry.py b/graphistry/pygraphistry.py index fe460e74c3..d08176b283 100644 --- a/graphistry/pygraphistry.py +++ 
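The input shapes accepted by `normalize_gfql_to_wire` above can be illustrated with a simplified stand-in. Unlike the real function, this sketch skips AST validation (no `from_json` round-trip) and does not handle `Chain`/`ASTObject` instances, only wire-style dicts and lists:

```python
def to_wire(expr):
    # Simplified mirror of the dict/list dispatch above: unwrap 'chain'
    # or 'gfql' wrappers, reject empty lists, and treat any other dict
    # as a single already-serialized op (no AST validation here).
    if isinstance(expr, dict):
        for key in ('chain', 'gfql'):
            if key in expr:
                return to_wire(expr[key])
        return [expr]
    if isinstance(expr, list):
        if not expr:
            raise ValueError("GFQL operations list cannot be empty")
        return [op for item in expr for op in to_wire(item)]
    raise TypeError("GFQL expr must be a dict or list in this sketch")

node = {"type": "Node", "filter_dict": {}}
assert to_wire(node) == [node]
assert to_wire({"type": "Chain", "chain": [node]}) == [node]
assert to_wire({"type": "gfql_chain", "gfql": [node]}) == [node]
```

All four dict forms listed in the docstring (`type: Chain`, `type: gfql_chain`, bare `chain`, bare `gfql`) collapse to the same unwrapping step, which is why the real implementation can share one recursive `_normalize_ops`.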
b/graphistry/pygraphistry.py @@ -6,6 +6,8 @@ from graphistry.plugins_types.gexf_types import GexfEdgeViz, GexfNodeViz, GexfParseEngine from graphistry.client_session import ClientSession, ApiVersion, ENV_GRAPHISTRY_API_KEY, DatasetInfo, AuthManagerProtocol, strtobool from graphistry.Engine import EngineAbstractType +from graphistry.models.collections import CollectionsInput +from graphistry.models.types import ValidationParam from graphistry.otel import inject_trace_headers, otel as otel_config """Top-level import of class PyGraphistry as "Graphistry". Used to connect to the Graphistry server and then create a base plotter.""" @@ -2376,6 +2378,24 @@ def settings(self, height=None, url_params={}, render=None): return self._plotter().settings(height, url_params, render) + def collections( + self, + collections: Optional[CollectionsInput] = None, + show_collections: Optional[bool] = None, + collections_global_node_color: Optional[str] = None, + collections_global_edge_color: Optional[str] = None, + validate: ValidationParam = 'autofix', + warn: bool = True + ): + return self._plotter().collections( + collections=collections, + show_collections=show_collections, + collections_global_node_color=collections_global_node_color, + collections_global_edge_color=collections_global_edge_color, + validate=validate, + warn=warn + ) + def _viz_url(self, info: DatasetInfo, url_params: Dict[str, Any]) -> str: splash_time = int(calendar.timegm(time.gmtime())) + 15 extra = "&".join([k + "=" + str(v) for k, v in list(url_params.items())]) @@ -2604,6 +2624,7 @@ def _handle_api_response(self, response): pipe = PyGraphistry.pipe graph = PyGraphistry.graph settings = PyGraphistry.settings +collections = PyGraphistry.collections hypergraph = PyGraphistry.hypergraph bolt = PyGraphistry.bolt cypher = PyGraphistry.cypher diff --git a/graphistry/tests/test_collections.py b/graphistry/tests/test_collections.py new file mode 100644 index 0000000000..3e7750cb14 --- /dev/null +++ 
b/graphistry/tests/test_collections.py @@ -0,0 +1,306 @@ +# -*- coding: utf-8 -*- + +import json +import pytest +from urllib.parse import unquote + +import graphistry +from graphistry.collections import collection_intersection, collection_set +from graphistry.validate.validate_collections import normalize_collections_url_params + + +def decode_collections(encoded: str): + return json.loads(unquote(encoded)) + + +def collections_url_params(collections, **kwargs): + return graphistry.bind().collections(collections=collections, **kwargs)._url_params + + +def test_collections_encodes_and_normalizes(): + node_filter = graphistry.n({"subscribed_to_newsletter": True}) + collections = [ + { + "type": "set", + "id": "newsletter_subscribers", + "name": "Newsletter Subscribers", + "node_color": "#32CD32", + "expr": { + "type": "gfql_chain", + "gfql": [node_filter], + }, + } + ] + + url_params = collections_url_params( + collections, + show_collections=True, + collections_global_node_color="#00FF00", + collections_global_edge_color="#00AA00", + ) + + decoded = decode_collections(url_params["collections"]) + assert decoded == [ + { + "type": "set", + "id": "newsletter_subscribers", + "name": "Newsletter Subscribers", + "node_color": "#32CD32", + "expr": { + "type": "gfql_chain", + "gfql": [node_filter.to_json()], + }, + } + ] + assert url_params["showCollections"] is True + assert url_params["collectionsGlobalNodeColor"] == "00FF00" + assert url_params["collectionsGlobalEdgeColor"] == "00AA00" + + +@pytest.mark.parametrize("expr", [graphistry.n({"vip": True}), [graphistry.n({"vip": True})]]) +def test_collection_set_wraps_ast_expr(expr): + collection = collection_set(expr=expr, id="vip") + assert collection["expr"]["type"] == "gfql_chain" + assert collection["expr"]["gfql"][0]["type"] == "Node" + + +def test_collection_helpers_build_sets_and_intersections(): + collections = [ + collection_set(expr=[graphistry.n({"vip": True})], id="vip", name="VIP", node_color="#FFAA00"), + 
collection_intersection(sets=["vip"], id="vip_intersection", name="VIP Intersection", node_color="#00BFFF"), + ] + decoded = decode_collections(collections_url_params(collections)["collections"]) + assert decoded[0]["type"] == "set" + assert decoded[0]["expr"]["type"] == "gfql_chain" + assert decoded[1]["expr"] == {"type": "intersection", "sets": ["vip"]} + + +def test_collections_accepts_chain_and_preserves_dataset_id(): + node = graphistry.n({"type": "user"}) + chain = graphistry.Chain([node]) + g2 = graphistry.bind(dataset_id="dataset_123").collections(collections={"type": "set", "id": "my_set", "expr": chain}) + decoded = decode_collections(g2._url_params["collections"]) + assert decoded == [ + { + "type": "set", + "id": "my_set", + "expr": { + "type": "gfql_chain", + "gfql": [node.to_json()], + }, + } + ] + assert g2._dataset_id == "dataset_123" + + +def test_collections_string_input_is_encoded(): + # Include a set so the intersection has valid references + raw = '[{"type":"set","id":"a","expr":{"type":"gfql_chain","gfql":[{"type":"Node"}]}},{"type":"intersection","id":"b","expr":{"type":"intersection","sets":["a"]}}]' + url_params = collections_url_params(raw) + assert url_params["collections"].startswith("%5B") + decoded = decode_collections(url_params["collections"]) + # Node normalizes to include filter_dict: {} + assert decoded == [ + { + "type": "set", + "id": "a", + "expr": {"type": "gfql_chain", "gfql": [{"type": "Node", "filter_dict": {}}]}, + }, + { + "type": "intersection", + "id": "b", + "expr": {"type": "intersection", "sets": ["a"]}, + } + ] + + +def test_collections_accepts_wire_protocol_chain(): + chain_json = { + "type": "Chain", + "chain": [ + { + "type": "Node", + "filter_dict": { + "type": "user" + } + } + ] + } + decoded = decode_collections( + collections_url_params({"type": "set", "id": "users", "expr": chain_json})["collections"] + ) + assert decoded == [ + { + "type": "set", + "id": "users", + "expr": { + "type": "gfql_chain", + 
"gfql": chain_json["chain"], + }, + } + ] + + +def test_collections_accepts_let_expr(): + dag = graphistry.let({"seed": graphistry.n({"type": "user"})}) + decoded = decode_collections( + collections_url_params({"type": "set", "id": "users", "expr": dag})["collections"] + ) + assert decoded[0]["expr"]["type"] == "gfql_chain" + assert decoded[0]["expr"]["gfql"][0]["type"] == "Let" + + +def test_collections_drop_unexpected_fields_autofix(): + collections = [ + { + "type": "set", + "id": "vip_set", + "expr": [graphistry.n({"vip": True})], + "unexpected": "drop-me", + } + ] + decoded = decode_collections( + collections_url_params(collections, validate="autofix", warn=False)["collections"] + ) + assert "unexpected" not in decoded[0] + + +def test_collections_show_collections_coerces_autofix(): + g2 = graphistry.bind().collections(show_collections="true", validate="autofix") + assert g2._url_params["showCollections"] is True + + +def test_collections_show_collections_strict_raises(): + with pytest.raises(ValueError): + graphistry.bind().collections(show_collections="maybe", validate="strict") + + +def test_collections_validation_strict_raises(): + # Missing 'type' field in GFQL op causes validation error + bad_collections = [{"type": "set", "id": "bad_set", "expr": [{"filter_dict": {"a": 1}}]}] + with pytest.raises(ValueError): + graphistry.bind().collections(collections=bad_collections, validate="strict") + + +def test_collections_autofix_drops_invalid_colors(): + collections = [ + { + "type": "set", + "id": "vip_set", + "expr": [graphistry.n({"vip": True})], + "node_color": 123, + "edge_color": {"bad": True}, + } + ] + with pytest.warns(RuntimeWarning): + url_params = collections_url_params(collections, validate="autofix", warn=True) + decoded = decode_collections(url_params["collections"]) + assert "node_color" not in decoded[0] + assert "edge_color" not in decoded[0] + + +def test_collections_autofix_drops_invalid_gfql_ops(): + # Collection with invalid GFQL op 
(missing 'type' field) gets dropped in autofix + collections = [ + { + "type": "set", + "id": "bad_set", + "expr": [graphistry.n({"vip": True}), {"filter_dict": {"a": 1}}], + } + ] + with pytest.warns(RuntimeWarning): + url_params = collections_url_params(collections, validate="autofix", warn=True) + # Collection dropped due to invalid GFQL, so no collections key or empty + assert "collections" not in url_params or decode_collections(url_params["collections"]) == [] + + +def test_plot_url_param_validation_autofix_warns(): + bad = '[{"type":"set","expr":[{"filter_dict":{"a":1}}]}]' + with pytest.warns(RuntimeWarning): + normalized = normalize_collections_url_params({"collections": bad}, validate="autofix", warn=True) + assert "collections" not in normalized or normalized["collections"].startswith("%5B") + + +def test_collections_autofix_generates_missing_ids(): + # Collections without IDs get auto-generated IDs in autofix mode (kebab-case) + collections = [ + {"type": "set", "expr": [graphistry.n({"a": 1})]}, + {"type": "intersection", "expr": {"type": "intersection", "sets": ["set-0"]}}, + ] + with pytest.warns(RuntimeWarning): + url_params = collections_url_params(collections, validate="autofix", warn=True) + decoded = decode_collections(url_params["collections"]) + assert decoded[0]["id"] == "set-0" + assert decoded[1]["id"] == "intersection-1" + + +def test_collections_intersection_of_intersections(): + # Backend supports intersections-of-intersections (DAG structure) + collections = [ + {"type": "set", "id": "set_a", "expr": [graphistry.n({"a": 1})]}, + {"type": "set", "id": "set_b", "expr": [graphistry.n({"b": 1})]}, + {"type": "intersection", "id": "inter_ab", "expr": {"type": "intersection", "sets": ["set_a", "set_b"]}}, + {"type": "intersection", "id": "inter_of_inter", "expr": {"type": "intersection", "sets": ["set_a", "inter_ab"]}}, + ] + url_params = collections_url_params(collections) + decoded = decode_collections(url_params["collections"]) + assert 
len(decoded) == 4 + assert decoded[3]["id"] == "inter_of_inter" + assert decoded[3]["expr"]["sets"] == ["set_a", "inter_ab"] + + +def test_collections_intersection_self_reference_rejected(): + # Intersection cannot reference itself + collections = [ + {"type": "set", "id": "set_a", "expr": [graphistry.n({"a": 1})]}, + {"type": "intersection", "id": "bad_inter", "expr": {"type": "intersection", "sets": ["set_a", "bad_inter"]}}, + ] + with pytest.raises(ValueError): + collections_url_params(collections, validate="strict") + + +def test_collections_intersection_cycle_rejected(): + # Cycles in intersection DAG are rejected + collections = [ + {"type": "set", "id": "set_a", "expr": [graphistry.n({"a": 1})]}, + {"type": "intersection", "id": "inter_a", "expr": {"type": "intersection", "sets": ["set_a", "inter_b"]}}, + {"type": "intersection", "id": "inter_b", "expr": {"type": "intersection", "sets": ["set_a", "inter_a"]}}, + ] + with pytest.raises(ValueError): + collections_url_params(collections, validate="strict") + + +def test_collections_intersection_cycle_autofix_drops(): + # In autofix mode, cyclic intersections are dropped + collections = [ + {"type": "set", "id": "set_a", "expr": [graphistry.n({"a": 1})]}, + {"type": "intersection", "id": "inter_a", "expr": {"type": "intersection", "sets": ["set_a", "inter_b"]}}, + {"type": "intersection", "id": "inter_b", "expr": {"type": "intersection", "sets": ["set_a", "inter_a"]}}, + ] + with pytest.warns(RuntimeWarning): + url_params = collections_url_params(collections, validate="autofix", warn=True) + decoded = decode_collections(url_params["collections"]) + # Both cyclic intersections dropped, only set remains + assert len(decoded) == 1 + assert decoded[0]["id"] == "set_a" + + +def test_collections_malformed_ast_autofix_drops(): + # AST from_json uses bare asserts - these should be caught, not crash + # {"type": "Let"} missing required 'bindings' field + from graphistry.validate.validate_collections import 
normalize_collections + collections = [ + {"type": "set", "id": "good", "expr": [{"type": "Node"}]}, + {"type": "set", "id": "bad-let", "expr": [{"type": "Let"}]}, # missing bindings + ] + result = normalize_collections(collections, validate="autofix", warn=False) + ids = [c.get("id") for c in result] + assert "good" in ids + assert "bad-let" not in ids + + +def test_collections_malformed_ast_strict_raises(): + from graphistry.validate.validate_collections import normalize_collections + collections = [{"type": "set", "id": "bad", "expr": [{"type": "Let"}]}] + with pytest.raises(ValueError): + normalize_collections(collections, validate="strict") diff --git a/graphistry/validate/validate_collections.py b/graphistry/validate/validate_collections.py new file mode 100644 index 0000000000..6fc87e2643 --- /dev/null +++ b/graphistry/validate/validate_collections.py @@ -0,0 +1,488 @@ +import json +from typing import Any, Dict, List, Optional, Tuple, Union +from urllib.parse import quote, unquote + +from graphistry.client_session import strtobool +from graphistry.compute.exceptions import GFQLSyntaxError, GFQLValidationError +from graphistry.models.collections import Collection, CollectionsInput +from graphistry.models.types import ValidationMode, ValidationParam +from graphistry.util import warn as emit_warn +_ALLOWED_COLLECTION_FIELDS_ORDER = ( + 'type', + 'id', + 'name', + 'description', + 'node_color', + 'edge_color', + 'expr', +) +_ALLOWED_COLLECTION_FIELDS_SET = set(_ALLOWED_COLLECTION_FIELDS_ORDER) + + +def normalize_validation_params( + validate: ValidationParam = 'autofix', + warn: bool = True +) -> Tuple[ValidationMode, bool]: + if validate is True: + validate_mode: ValidationMode = 'strict' + elif validate is False: + validate_mode = 'autofix' + warn = False + else: + validate_mode = validate + return validate_mode, warn + + +def encode_collections(collections: List[Dict[str, Any]]) -> str: + json_str = json.dumps(collections, separators=(',', ':'), 
ensure_ascii=True) + return quote(json_str, safe='') + + +def _issue( + message: str, + data: Optional[Dict[str, Any]], + validate_mode: ValidationMode, + warn: bool +) -> None: + error = ValueError({'message': message, 'data': data} if data else {'message': message}) + if validate_mode in ('strict', 'strict-fast'): + raise error + if warn and validate_mode == 'autofix': + emit_warn(f"Collections validation warning: {message} ({data})") + + +def _parse_collections_input( + collections: CollectionsInput, + validate_mode: ValidationMode, + warn: bool +) -> Union[List[Dict[str, Any]], List[Collection]]: + """Parse collections input to a list of dicts, handling list/dict/JSON string inputs.""" + if isinstance(collections, list): + return collections + if isinstance(collections, dict): + return [collections] + if isinstance(collections, str): + try: + parsed = json.loads(collections) + except json.JSONDecodeError: + try: + parsed = json.loads(unquote(collections)) + except json.JSONDecodeError as exc: + _issue('Collections string must be JSON or URL-encoded JSON', {'error': str(exc)}, validate_mode, warn) + return [] + # Coerce parsed JSON to list + if isinstance(parsed, list): + return parsed + if isinstance(parsed, dict): + return [parsed] + _issue('Collections JSON must be a list or dict', {'type': type(parsed).__name__}, validate_mode, warn) + return [] + _issue('Collections must be a list, dict, or JSON string', {'type': type(collections).__name__}, validate_mode, warn) + return [] + + +def _normalize_str_field( + entry: Dict[str, Any], + key: str, + validate_mode: ValidationMode, + warn: bool, + entry_index: int, + autofix_drop: bool +) -> None: + if key not in entry or entry[key] is None: + return + if isinstance(entry[key], str): + return + _issue( + f'Collection field "{key}" should be a string', + {'index': entry_index, 'value': entry[key], 'type': type(entry[key]).__name__}, + validate_mode, + warn + ) + if validate_mode == 'autofix': + if autofix_drop: + 
entry.pop(key, None) + else: + entry[key] = str(entry[key]) + + +def _normalize_sets_list( + sets_value: Any, + validate_mode: ValidationMode, + warn: bool, + entry_index: int +) -> Optional[List[str]]: + if not isinstance(sets_value, list): + _issue( + 'Intersection sets must be a list of strings', + {'index': entry_index, 'value': sets_value, 'type': type(sets_value).__name__}, + validate_mode, + warn + ) + return None + out: List[str] = [] + for set_id in sets_value: + if isinstance(set_id, str): + out.append(set_id) + continue + _issue( + 'Intersection set IDs must be strings', + {'index': entry_index, 'value': set_id, 'type': type(set_id).__name__}, + validate_mode, + warn + ) + if validate_mode == 'autofix': + out.append(str(set_id)) + + if len(out) == 0: + _issue( + 'Intersection sets list cannot be empty', + {'index': entry_index}, + validate_mode, + warn + ) + return None + + return out + + +def _normalize_gfql_ops( + gfql_ops: Any, + validate_mode: ValidationMode, + warn: bool, + entry_index: int +) -> Optional[List[Dict[str, Any]]]: + """ + Normalize GFQL operations to a list of JSON-serializable dicts. + + Uses normalize_gfql_to_wire from compute/ast.py as the canonical implementation, + wrapping with error handling for validation modes. 
+ """ + if gfql_ops is None: + _issue('GFQL chain is missing', {'index': entry_index}, validate_mode, warn) + return None + + # Handle JSON string input + if isinstance(gfql_ops, str): + try: + gfql_ops = json.loads(gfql_ops) + except json.JSONDecodeError as exc: + _issue('GFQL chain string must be JSON', {'index': entry_index, 'error': str(exc)}, validate_mode, warn) + return None + + # Use canonical implementation from compute/ast.py + try: + from graphistry.compute.ast import normalize_gfql_to_wire + return normalize_gfql_to_wire(gfql_ops) + except (TypeError, ValueError, AssertionError, GFQLValidationError, GFQLSyntaxError) as exc: + # AssertionError: AST from_json methods use bare asserts for required fields + _issue( + 'Invalid GFQL operation in collection', + {'index': entry_index, 'error': str(exc)}, + validate_mode, + warn + ) + return None + + +def _normalize_gfql_expr( + expr: Any, + validate_mode: ValidationMode, + warn: bool, + entry_index: int +) -> Optional[Dict[str, Any]]: + if isinstance(expr, dict) and expr.get('type') == 'intersection': + _issue('Set collection expr cannot be intersection', {'index': entry_index}, validate_mode, warn) + return None + ops = _normalize_gfql_ops(expr, validate_mode, warn, entry_index) + if ops is None: + return None + return {'type': 'gfql_chain', 'gfql': ops} + + +def _normalize_intersection_expr( + expr: Any, + validate_mode: ValidationMode, + warn: bool, + entry_index: int +) -> Optional[Dict[str, Any]]: + if not isinstance(expr, dict): + _issue('Intersection expr must be a dict', {'index': entry_index}, validate_mode, warn) + return None + expr_type = expr.get('type', 'intersection') + if expr_type != 'intersection': + _issue( + 'Intersection expr type must be "intersection"', + {'index': entry_index, 'value': expr_type}, + validate_mode, + warn + ) + return None + sets_value = expr.get('sets', expr.get('intersection')) + if sets_value is None: + _issue('Intersection expr missing "sets"', {'index': entry_index}, 
validate_mode, warn) + return None + sets_list = _normalize_sets_list(sets_value, validate_mode, warn, entry_index) + if sets_list is None: + return None + return {'type': 'intersection', 'sets': sets_list} + + +def normalize_collections( + collections: CollectionsInput, + validate: ValidationParam = 'autofix', + warn: bool = True +) -> List[Dict[str, Any]]: + validate_mode, warn = normalize_validation_params(validate, warn) + items = _parse_collections_input(collections, validate_mode, warn) + + normalized: List[Dict[str, Any]] = [] + for idx, entry in enumerate(items): + if not isinstance(entry, dict): + _issue( + 'Collection entries must be dictionaries', + {'index': idx, 'type': type(entry).__name__}, + validate_mode, + warn + ) + if validate_mode == 'autofix': + continue + return [] + + # Convert to plain dict for uniform handling (TypedDicts become regular dicts) + entry_dict: Dict[str, Any] = dict(entry) + + unexpected_fields = [key for key in entry_dict.keys() if key not in _ALLOWED_COLLECTION_FIELDS_SET] + if unexpected_fields: + _issue( + 'Unexpected fields in collection', + {'index': idx, 'fields': unexpected_fields}, + validate_mode, + warn + ) + + normalized_entry = {key: entry_dict[key] for key in _ALLOWED_COLLECTION_FIELDS_ORDER if key in entry_dict} + collection_type = normalized_entry.get('type', 'set') + if not isinstance(collection_type, str): + _issue( + 'Collection type must be a string', + {'index': idx, 'value': collection_type, 'type': type(collection_type).__name__}, + validate_mode, + warn + ) + # str() coercion is pointless - it won't produce 'set' or 'intersection' + # so we skip this entry in autofix mode, or fail in strict mode + if validate_mode == 'autofix': + continue + return [] + collection_type = collection_type.lower() + + if collection_type not in ('set', 'intersection'): + _issue( + 'Collection type must be "set" or "intersection"', + {'index': idx, 'value': collection_type}, + validate_mode, + warn + ) + if validate_mode == 
'autofix': + continue + return [] + + normalized_entry['type'] = collection_type + + for field in ('id', 'name', 'description'): + _normalize_str_field(normalized_entry, field, validate_mode, warn, idx, autofix_drop=False) + for field in ('node_color', 'edge_color'): + _normalize_str_field(normalized_entry, field, validate_mode, warn, idx, autofix_drop=True) + + # Validate id field - required by server for all collection types (used as storage key) + if 'id' not in normalized_entry or normalized_entry.get('id') is None: + _issue( + f'{collection_type.capitalize()} collection missing id field (server requires it for storage)', + {'index': idx}, + validate_mode, + warn + ) + # _issue raises in strict modes; in autofix mode, warn and drop the entry + # rather than auto-generating an id (matches documented autofix behavior) + continue + + expr = normalized_entry.get('expr') + if collection_type == 'intersection': + normalized_expr = _normalize_intersection_expr(expr, validate_mode, warn, idx) + else: + normalized_expr = _normalize_gfql_expr(expr, validate_mode, warn, idx) + if normalized_expr is None: + if validate_mode == 'autofix': + continue + return [] + normalized_entry['expr'] = normalized_expr + normalized_entry = { + key: normalized_entry[key] + for key in _ALLOWED_COLLECTION_FIELDS_ORDER + if key in normalized_entry + } + normalized.append(normalized_entry) + + # Cross-validate intersection set references + normalized = _validate_intersection_references(normalized, validate_mode, warn) + + return normalized + + +def _validate_intersection_references( + collections: List[Dict[str, Any]], + validate_mode: ValidationMode, + warn: bool +) -> List[Dict[str, Any]]: + """ + Validate intersection references form a valid DAG. + + Checks: + 1. All referenced IDs exist as 'set' or 'intersection' collections + 2. No self-references (intersection referencing itself) + 3. No cycles (A->B->A) + + Dangling/cyclic references cause backend errors ("Infinite loop detected"). 
+ In strict mode, raise on first issue. In autofix mode, drop invalid intersections. + """ + # Build ID -> collection type mapping for valid reference targets + collection_ids: Dict[str, str] = {} + for c in collections: + cid = c.get('id') + ctype = c.get('type') + if cid and ctype in ('set', 'intersection'): + collection_ids[cid] = ctype + + # Build dependency graph for cycle detection + # intersection_id -> set of IDs it references + dependencies: Dict[str, List[str]] = {} + for c in collections: + if c.get('type') == 'intersection' and c.get('id'): + expr = c.get('expr', {}) + dependencies[c['id']] = expr.get('sets', []) + + def has_cycle(start_id: str, visited: set, path: set) -> bool: + """DFS cycle detection.""" + if start_id in path: + return True + if start_id not in dependencies: + return False # It's a set, not an intersection - no further deps + if start_id in visited: + return False + + visited.add(start_id) + path.add(start_id) + + for dep_id in dependencies.get(start_id, []): + if has_cycle(dep_id, visited, path): + return True + + path.remove(start_id) + return False + + valid_collections: List[Dict[str, Any]] = [] + for idx, collection in enumerate(collections): + if collection.get('type') == 'intersection': + coll_id = collection.get('id') + expr = collection.get('expr', {}) + referenced_ids = expr.get('sets', []) + + # Check for self-reference + if coll_id and coll_id in referenced_ids: + _issue( + 'Intersection references itself', + {'index': idx, 'id': coll_id}, + validate_mode, + warn + ) + if validate_mode == 'autofix': + continue + return [] + + # Check all referenced IDs exist and are valid types + missing = [sid for sid in referenced_ids if sid not in collection_ids] + if missing: + _issue( + 'Intersection references non-existent collection IDs', + {'index': idx, 'missing': missing, 'available': list(collection_ids.keys())}, + validate_mode, + warn + ) + if validate_mode == 'autofix': + continue + return [] + + # Check for cycles + if 
coll_id and has_cycle(coll_id, set(), set()): + _issue( + 'Intersection creates a dependency cycle', + {'index': idx, 'id': coll_id}, + validate_mode, + warn + ) + if validate_mode == 'autofix': + continue + return [] + + valid_collections.append(collection) + + return valid_collections + + +def normalize_collections_url_params( + url_params: Dict[str, Any], + validate: ValidationParam = 'autofix', + warn: bool = True +) -> Dict[str, Any]: + validate_mode, warn = normalize_validation_params(validate, warn) + updated = dict(url_params) + + if 'collections' in updated: + normalized = normalize_collections(updated['collections'], validate_mode, warn) + if len(normalized) > 0: + updated['collections'] = encode_collections(normalized) + else: + if validate_mode in ('strict', 'strict-fast'): + return updated + updated.pop('collections', None) + + if 'showCollections' in updated: + value = updated['showCollections'] + if not isinstance(value, bool): + try: + updated['showCollections'] = strtobool(value) + except Exception as exc: + _issue( + 'showCollections must be a boolean', + {'value': value, 'error': str(exc)}, + validate_mode, + warn + ) + if validate_mode == 'autofix': + updated.pop('showCollections', None) + + for color_key in ('collectionsGlobalNodeColor', 'collectionsGlobalEdgeColor'): + if color_key in updated: + value = updated[color_key] + if value is None: + updated.pop(color_key, None) + continue + if not isinstance(value, str): + _issue( + f'{color_key} must be a string', + {'value': value, 'type': type(value).__name__}, + validate_mode, + warn + ) + if validate_mode == 'autofix': + value = str(value) + if isinstance(value, str) and value.startswith('#'): + value = value[1:] + updated[color_key] = value + + return updated
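The DFS cycle check that `_validate_intersection_references` runs over intersection references can be illustrated standalone. Below is a minimal sketch, a simplified mirror of the diff's `has_cycle` helper rather than the library API: `deps` maps intersection IDs to the collection IDs they reference, and IDs absent from `deps` are treated as plain sets with no outgoing references.

```python
from typing import Dict, List, Set


def has_cycle(start: str, deps: Dict[str, List[str]],
              visited: Set[str], path: Set[str]) -> bool:
    """DFS: a cycle exists iff we re-enter a node already on the current path."""
    if start in path:
        return True  # back-edge: start is an ancestor in the current DFS path
    if start not in deps or start in visited:
        return False  # plain set (no outgoing refs) or subtree already cleared
    visited.add(start)
    path.add(start)
    if any(has_cycle(dep, deps, visited, path) for dep in deps[start]):
        return True
    path.remove(start)  # backtrack: start is no longer on the active path
    return False


# A -> B -> A is cyclic; C -> s1 (a plain set) is acyclic
deps = {"A": ["B"], "B": ["A"], "C": ["s1"]}
assert has_cycle("A", deps, set(), set())
assert not has_cycle("C", deps, set(), set())
```

As in the diff, a shared `visited` set memoizes cleared subtrees so repeated references stay linear, while the separate `path` set distinguishes true back-edges from cross-edges to already-validated intersections.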