diff --git a/docs/design/thread-safe-mode.md b/docs/design/thread-safe-mode.md deleted file mode 100644 index 5d7472667..000000000 --- a/docs/design/thread-safe-mode.md +++ /dev/null @@ -1,387 +0,0 @@ -# Thread-Safe Mode Specification - -## Problem - -DataJoint uses global state (`dj.config`, `dj.conn()`) that is not thread-safe. Multi-tenant applications (web servers, async workers) need isolated connections per request/task. - -## Solution - -Introduce **Instance** objects that encapsulate config and connection. The `dj` module provides a global config that can be modified before connecting, and a lazily-loaded singleton connection. New isolated instances are created with `dj.Instance()`. - -## API - -### Legacy API (global config + singleton connection) - -```python -import datajoint as dj - -# Configure credentials (no connection yet) -dj.config.database.user = "user" -dj.config.database.password = "password" - -# First call to conn() or Schema() creates the singleton connection -dj.conn() # Creates connection using dj.config credentials -schema = dj.Schema("my_schema") - -@schema -class Mouse(dj.Manual): - definition = "..." -``` - -Alternatively, pass credentials directly to `conn()`: -```python -dj.conn(host="localhost", user="user", password="password") -``` - -Internally: -- `dj.config` → delegates to `_global_config` (with thread-safety check) -- `dj.conn()` → returns `_singleton_connection` (created lazily) -- `dj.Schema()` → uses `_singleton_connection` -- `dj.FreeTable()` → uses `_singleton_connection` - -### New API (isolated instance) - -```python -import datajoint as dj - -inst = dj.Instance( - host="localhost", - user="user", - password="password", -) -schema = inst.Schema("my_schema") - -@schema -class Mouse(dj.Manual): - definition = "..." -``` - -### Instance structure - -Each instance has: -- `inst.config` - Config (created fresh at instance creation) -- `inst.connection` - Connection (created at instance creation) -- `inst.Schema()` - Schema factory using instance's connection -- `inst.FreeTable()` - FreeTable factory using instance's connection - -```python -inst = dj.Instance(host="localhost", user="u", password="p") -inst.config # Config instance -inst.connection # Connection instance -inst.Schema("name") # Creates schema using inst.connection -inst.FreeTable("db.tbl") # Access table using inst.connection -``` - -### Table base classes vs instance methods - -**Base classes** (`dj.Manual`, `dj.Lookup`, etc.) - Used with `@schema` decorator: -```python -@schema -class Mouse(dj.Manual): # dj.Manual - schema links to connection - definition = "..." -``` - -**Instance methods** (`inst.Schema()`, `inst.FreeTable()`) - Need connection directly: -```python -schema = inst.Schema("my_schema") # Uses inst.connection -table = inst.FreeTable("db.table") # Uses inst.connection -``` - -### Thread-safe mode - -```bash -export DJ_THREAD_SAFE=true -``` - -`thread_safe` is checked dynamically on each access to global state. 
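-
-A minimal sketch of the dynamic check (illustrative; the actual helper is the
-`_load_thread_safe()` referenced in the Implementation section below, and its
-exact parsing of the variable may differ):
-
-```python
-import os
-
-def _load_thread_safe() -> bool:
-    # Read the environment on every call so the flag can be flipped at
-    # runtime, not just at import time.
-    return os.environ.get("DJ_THREAD_SAFE", "").lower() in ("1", "true", "yes")
-```
-
-Because the check is dynamic, setting `DJ_THREAD_SAFE` after import still takes
-effect on the next access to global state.
-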
- -When `thread_safe=True`, accessing global state raises `ThreadSafetyError`: -- `dj.config` raises `ThreadSafetyError` -- `dj.conn()` raises `ThreadSafetyError` -- `dj.Schema()` raises `ThreadSafetyError` (without explicit connection) -- `dj.FreeTable()` raises `ThreadSafetyError` (without explicit connection) -- `dj.Instance()` works - isolated instances are always allowed - -```python -# thread_safe=True - -dj.config # ThreadSafetyError -dj.conn() # ThreadSafetyError -dj.Schema("name") # ThreadSafetyError - -inst = dj.Instance(host="h", user="u", password="p") # OK -inst.Schema("name") # OK -``` - -## Behavior Summary - -| Operation | `thread_safe=False` | `thread_safe=True` | -|-----------|--------------------|--------------------| -| `dj.config` | `_global_config` | `ThreadSafetyError` | -| `dj.conn()` | `_singleton_connection` | `ThreadSafetyError` | -| `dj.Schema()` | Uses singleton | `ThreadSafetyError` | -| `dj.FreeTable()` | Uses singleton | `ThreadSafetyError` | -| `dj.Instance()` | Works | Works | -| `inst.config` | Works | Works | -| `inst.connection` | Works | Works | -| `inst.Schema()` | Works | Works | - -## Lazy Loading - -The global config is created at module import time. The singleton connection is created lazily on first access: - -```python -dj.config.database.user = "user" # Modifies global config (no connection yet) -dj.config.database.password = "pw" -dj.conn() # Creates singleton connection using global config -dj.Schema("name") # Uses existing singleton connection -``` - -## Usage Example - -```python -import datajoint as dj - -# Create isolated instance -inst = dj.Instance( - host="localhost", - user="user", - password="password", -) - -# Create schema -schema = inst.Schema("my_schema") - -@schema -class Mouse(dj.Manual): - definition = """ - mouse_id: int - """ - -# Use tables -Mouse().insert1({"mouse_id": 1}) -Mouse().fetch() -``` - -## Architecture - -### Object graph - -There is exactly **one** global `Config` object created at import time in `settings.py`. Both the legacy API and the `Instance` API hang off `Connection` objects, each of which carries a `_config` reference. - -``` -settings.py - config = _create_config() ← THE single global Config - -instance.py - _global_config = settings.config ← same object (not a copy) - _singleton_connection = None ← lazily created Connection - -__init__.py - dj.config = _ConfigProxy() ← proxy → _global_config (with thread-safety check) - dj.conn() ← returns _singleton_connection - dj.Schema() ← uses _singleton_connection - dj.FreeTable() ← uses _singleton_connection - -Connection (singleton) - _config → _global_config ← same Config that dj.config writes to - -Connection (Instance) - _config → fresh Config ← isolated per-instance -``` - -### Config flow: singleton path - -``` -dj.config["safemode"] = False - ↓ _ConfigProxy.__setitem__ -_global_config["safemode"] = False (same object as settings.config) - ↓ -Connection._config["safemode"] (points to _global_config) - ↓ -schema.drop() reads self.connection._config["safemode"] → False ✓ -``` - -### Config flow: Instance path - -``` -inst = dj.Instance(host=..., user=..., password=...) - ↓ -inst.config = _create_config() (fresh Config, independent) -inst.connection._config = inst.config - ↓ -inst.config["safemode"] = False - ↓ -schema.drop() reads self.connection._config["safemode"] → False ✓ -``` - -### Key invariant - -**All runtime config reads go through `self.connection._config`**, never through the global `config` directly. 
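-For example, a drop routine reads safemode through the connection (a sketch
-only; actual method bodies differ):
-
-```python
-class Schema:
-    def drop(self):
-        # Connection-scoped read: resolves to the global config on the
-        # singleton path and to the per-instance config on the Instance path.
-        if self.connection._config["safemode"]:
-            ...  # prompt for confirmation before dropping
-```
-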
This ensures both the singleton and Instance paths read the correct config. - -### Connection-scoped config reads - -Every module that previously imported `from .settings import config` now reads config from the connection: - -| Module | What was read | How it's read now | -|--------|--------------|-------------------| -| `schemas.py` | `config["safemode"]`, `config.database.create_tables` | `self.connection._config[...]` | -| `table.py` | `config["safemode"]` in `delete()`, `drop()` | `self.connection._config["safemode"]` | -| `expression.py` | `config["loglevel"]` in `__repr__()` | `self.connection._config["loglevel"]` | -| `preview.py` | `config["display.*"]` (8 reads) | `query_expression.connection._config[...]` | -| `autopopulate.py` | `config.jobs.allow_new_pk_fields`, `auto_refresh` | `self.connection._config.jobs.*` | -| `jobs.py` | `config.jobs.default_priority`, `stale_timeout`, `keep_completed` | `self.connection._config.jobs.*` | -| `declare.py` | `config.jobs.add_job_metadata` | `config` param (threaded from `table.py`) | -| `diagram.py` | `config.display.diagram_direction` | `self._connection._config.display.*` | -| `staged_insert.py` | `config.get_store_spec()` | `self._table.connection._config.get_store_spec()` | -| `hash_registry.py` | `config.get_store_spec()` in 5 functions | `config` kwarg (falls back to `settings.config`) | -| `builtin_codecs/hash.py` | `config` via hash_registry | `_config` from key dict → `config` kwarg to hash_registry | -| `builtin_codecs/attach.py` | `config.get("download_path")` | `_config` from key dict (falls back to `settings.config`) | -| `builtin_codecs/filepath.py` | `config.get_store_spec()` | `_config` from key dict (falls back to `settings.config`) | -| `builtin_codecs/schema.py` | `config.get_store_spec()` in helpers | `config` kwarg to `_build_path()`, `_get_backend()` | -| `builtin_codecs/npy.py` | `config` via schema helpers | `_config` from key dict → `config` kwarg to helpers | -| `builtin_codecs/object.py` | `config` via schema helpers | `_config` from key dict → `config` kwarg to helpers | -| `gc.py` | `config` via hash_registry | `schemas[0].connection._config` → `config` kwarg | - -### Functions that receive config as a parameter - -Some module-level functions cannot access `self.connection`. Config is threaded through: - -| Function | Caller | How config arrives | -|----------|--------|--------------------| -| `declare()` in `declare.py` | `Table.declare()` in `table.py` | `config=self.connection._config` kwarg | -| `_get_job_version()` in `jobs.py` | `AutoPopulate._make_tuples()`, `Job.reserve()` | `config=self.connection._config` positional arg | -| `get_store_backend()` in `hash_registry.py` | codecs, gc.py | `config` kwarg from key dict or schema connection | -| `get_store_subfolding()` in `hash_registry.py` | `put_hash()` | `config` kwarg chained from caller | -| `put_hash()` in `hash_registry.py` | `HashCodec.encode()` | `config` kwarg from `_config` in key dict | -| `get_hash()` in `hash_registry.py` | `HashCodec.decode()` | `config` kwarg from `_config` in key dict | -| `delete_path()` in `hash_registry.py` | `gc.collect()` | `config` kwarg from `schemas[0].connection._config` | -| `decode_attribute()` in `codecs.py` | `expression.py` fetch methods | `connection` kwarg → extracts `connection._config` | - -All functions accept `config=None` and fall back to the global `settings.config` for backward compatibility. - -## Implementation - -### 1. 
Create Instance class - -```python -class Instance: - def __init__(self, host, user, password, port=3306, **kwargs): - self.config = _create_config() # Fresh config with defaults - # Apply any config overrides from kwargs - self.connection = Connection(host, user, password, port, ...) - self.connection._config = self.config - - def Schema(self, name, **kwargs): - return Schema(name, connection=self.connection, **kwargs) - - def FreeTable(self, full_table_name): - return FreeTable(self.connection, full_table_name) -``` - -### 2. Global config and singleton connection - -```python -# settings.py - THE single global config -config = _create_config() # Created at import time - -# instance.py - reuses the same config object -_global_config = settings.config # Same reference, not a copy -_singleton_connection = None # Created lazily - -def _check_thread_safe(): - if _load_thread_safe(): - raise ThreadSafetyError( - "Global DataJoint state is disabled in thread-safe mode. " - "Use dj.Instance() to create an isolated instance." - ) - -def _get_singleton_connection(): - _check_thread_safe() - global _singleton_connection - if _singleton_connection is None: - _singleton_connection = Connection( - host=_global_config.database.host, - user=_global_config.database.user, - password=_global_config.database.password, - ... - ) - _singleton_connection._config = _global_config - return _singleton_connection -``` - -### 3. Legacy API with thread-safety checks - -```python -# dj.config -> global config with thread-safety check -class _ConfigProxy: - def __getattr__(self, name): - _check_thread_safe() - return getattr(_global_config, name) - def __setattr__(self, name, value): - _check_thread_safe() - setattr(_global_config, name, value) - -config = _ConfigProxy() - -# dj.conn() -> singleton connection (persistent across calls) -def conn(host=None, user=None, password=None, *, reset=False): - _check_thread_safe() - if reset or (_singleton_connection is None and credentials_provided): - _singleton_connection = Connection(...) - _singleton_connection._config = _global_config - return _get_singleton_connection() - -# dj.Schema() -> uses singleton connection -def Schema(name, connection=None, **kwargs): - if connection is None: - _check_thread_safe() - connection = _get_singleton_connection() - return _Schema(name, connection=connection, **kwargs) - -# dj.FreeTable() -> uses singleton connection -def FreeTable(conn_or_name, full_table_name=None): - if full_table_name is None: - _check_thread_safe() - return _FreeTable(_get_singleton_connection(), conn_or_name) - else: - return _FreeTable(conn_or_name, full_table_name) -``` - -## Global State Audit - -All module-level mutable state was reviewed for thread-safety implications. - -### Guarded (blocked in thread-safe mode) - -| State | Location | Mechanism | -|-------|----------|-----------| -| `config` singleton | `settings.py:979` | `_ConfigProxy` raises `ThreadSafetyError`; use `inst.config` instead | -| `conn()` singleton | `connection.py:108` | `_check_thread_safe()` guard; use `inst.connection` instead | - -These are the two globals that carry connection-scoped state (credentials, database settings) and are the primary source of cross-tenant interference. - -### Safe by design (no guard needed) - -| State | Location | Rationale | -|-------|----------|-----------| -| `_codec_registry` | `codecs.py:47` | Effectively immutable after import. Registration runs in `__init_subclass__` under Python's import lock. 
Runtime mutation (`_load_entry_points`) is idempotent under the GIL. Codecs are part of the type system, not connection-scoped. | -| `_entry_points_loaded` | `codecs.py:48` | Bool flag for idempotent lazy loading; worst case under concurrent access is redundant work, not corruption. | - -### Low risk (no guard needed) - -| State | Location | Rationale | -|-------|----------|-----------| -| Logging side effects | `logging.py:8,17,40-45,56` | Standard Python logging configuration. Monkey-patches `Logger` and replaces `sys.excepthook` at import time. Not DataJoint-specific mutable state. | -| `use_32bit_dims` | `blob.py:65` | Runtime flag affecting deserialization. Rarely changed; not connection-scoped. | -| `compression` dict | `blob.py:61` | Decompressor function registry. Populated at import time, effectively read-only thereafter. | -| `_lazy_modules` | `__init__.py:92` | Import caching via `globals()` mutation. Protected by Python's import lock. | -| `ADAPTERS` dict | `adapters/__init__.py:16` | Backend registry. Populated at import time, read-only in practice. | - -### Design principle - -Only state that is **connection-scoped** (credentials, database settings, connection objects) needs thread-safe guards. State that is **code-scoped** (type registries, import caches, logging configuration) is shared across all threads by design and does not vary between tenants. - -## Error Messages - -- Singleton access: `"Global DataJoint state is disabled in thread-safe mode. Use dj.Instance() to create an isolated instance."` diff --git a/pixi.lock b/pixi.lock index c425c2176..0421929da 100644 --- a/pixi.lock +++ b/pixi.lock @@ -2092,8 +2092,8 @@ packages: requires_python: '>=3.8' - pypi: ./ name: datajoint - version: 2.1.1 - sha256: 267defaa9ea7f22a8497568e8a14679be178f78cd3b34a4132609a57f0f71227 + version: 2.2.0.dev0 + sha256: 48335cedf96fa3b5efd3ddf880bd5065813f2baea43cad01a2fddbba94e561ec requires_dist: - deepdiff - fsspec>=2023.1.0 diff --git a/pyproject.toml b/pyproject.toml index 20832342b..5bf25dc29 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -235,6 +235,7 @@ markers = [ ] + [tool.pixi.workspace] channels = ["conda-forge"] platforms = ["linux-64", "osx-arm64", "linux-aarch64"] diff --git a/src/datajoint/adapters/base.py b/src/datajoint/adapters/base.py index 5a595001b..da4779543 100644 --- a/src/datajoint/adapters/base.py +++ b/src/datajoint/adapters/base.py @@ -647,7 +647,7 @@ def create_index_ddl( # Generate index name from table and columns if not provided if index_name is None: # Extract table name from full_table_name for index naming - table_part = full_table_name.split(".")[-1].strip('`"') + _, table_part = self.split_full_table_name(full_table_name) col_part = "_".join(columns)[:30] # Truncate for long column lists index_name = f"idx_{table_part}_{col_part}" unique_clause = "UNIQUE " if unique else "" @@ -830,6 +830,26 @@ def load_foreign_keys_sql(self, schemas_list: str, like_pattern: str) -> str: """ ... + def find_downstream_schemas_sql(self, schemas_list: str) -> str: + """ + Generate query to find schemas with FK references to the given schemas. + + Used to discover unloaded schemas that depend on loaded ones. + + Parameters + ---------- + schemas_list : str + Comma-separated, quoted schema names for an IN clause. + + Returns + ------- + str + SQL query returning rows with a single column ``schema_name`` + containing distinct schema names that reference the given schemas. + """ + raise NotImplementedError + ... 
+ @abstractmethod def get_constraint_info_sql(self, constraint_name: str, schema_name: str, table_name: str) -> str: """ diff --git a/src/datajoint/adapters/mysql.py b/src/datajoint/adapters/mysql.py index 3c28a85e6..1888eccf4 100644 --- a/src/datajoint/adapters/mysql.py +++ b/src/datajoint/adapters/mysql.py @@ -687,6 +687,15 @@ def load_foreign_keys_sql(self, schemas_list: str, like_pattern: str) -> str: f"OR referenced_table_schema is not NULL AND table_schema in ({schemas_list}))" ) + def find_downstream_schemas_sql(self, schemas_list: str) -> str: + """Find schemas with FK references to the given schemas.""" + return ( + f"SELECT DISTINCT table_schema as schema_name " + f"FROM information_schema.key_column_usage " + f"WHERE referenced_table_schema IN ({schemas_list}) " + f"AND table_schema NOT IN ({schemas_list})" + ) + def get_constraint_info_sql(self, constraint_name: str, schema_name: str, table_name: str) -> str: """Query to get FK constraint details from information_schema.""" return ( diff --git a/src/datajoint/adapters/postgres.py b/src/datajoint/adapters/postgres.py index cfb99b728..543e972d3 100644 --- a/src/datajoint/adapters/postgres.py +++ b/src/datajoint/adapters/postgres.py @@ -847,6 +847,20 @@ def load_foreign_keys_sql(self, schemas_list: str, like_pattern: str) -> str: f"ORDER BY c.conname, cols.ord" ) + def find_downstream_schemas_sql(self, schemas_list: str) -> str: + """Find schemas with FK references to the given schemas.""" + return ( + f"SELECT DISTINCT ns1.nspname as schema_name " + f"FROM pg_constraint c " + f"JOIN pg_class cl1 ON c.conrelid = cl1.oid " + f"JOIN pg_namespace ns1 ON cl1.relnamespace = ns1.oid " + f"JOIN pg_class cl2 ON c.confrelid = cl2.oid " + f"JOIN pg_namespace ns2 ON cl2.relnamespace = ns2.oid " + f"WHERE c.contype = 'f' " + f"AND ns2.nspname IN ({schemas_list}) " + f"AND ns1.nspname NOT IN ({schemas_list})" + ) + def get_constraint_info_sql(self, constraint_name: str, schema_name: str, table_name: str) -> str: """ Query to get FK constraint details from information_schema. diff --git a/src/datajoint/declare.py b/src/datajoint/declare.py index 8807348ad..1370628bc 100644 --- a/src/datajoint/declare.py +++ b/src/datajoint/declare.py @@ -369,8 +369,8 @@ def prepare_declare( adapter, fk_attribute_map, ) - elif re.match(r"^(unique\s+)?index\s*\(.*\)$", line, re.I): # index - compile_index(line, index_sql, adapter) + elif re.match(r"^(unique\s+)?index\s*\(.*\)\s*(#.*)?$", line, re.I): # index + compile_index(re.sub(r"\s*#.*$", "", line), index_sql, adapter) else: name, sql, store, comment = compile_attribute(line, in_key, foreign_key_sql, context, adapter) if store: diff --git a/src/datajoint/dependencies.py b/src/datajoint/dependencies.py index 99556345e..08fb50e1b 100644 --- a/src/datajoint/dependencies.py +++ b/src/datajoint/dependencies.py @@ -140,9 +140,9 @@ def clear(self) -> None: self._node_alias_count = itertools.count() # reset alias IDs for consistency super().clear() - def load(self, force: bool = True) -> None: + def load(self, force: bool = True, schema_names: set[str] | None = None) -> None: """ - Load dependencies for all loaded schemas. + Load dependencies for the given schemas. Called before operations requiring dependencies: delete, drop, populate, progress. @@ -151,6 +151,8 @@ def load(self, force: bool = True) -> None: ---------- force : bool, optional If True (default), reload even if already loaded. + schema_names : set[str], optional + Schema names to load. If None, uses all activated schemas. 
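+
+        Examples
+        --------
+        Illustrative only (assumes a schema named ``my_schema`` is activated
+        on this connection)::
+
+            conn.dependencies.load()                            # all activated schemas
+            conn.dependencies.load(schema_names={"my_schema"})  # just one schema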
""" # reload from scratch to prevent duplication of renamed edges if self._loaded and not force: @@ -162,7 +164,11 @@ def load(self, force: bool = True) -> None: adapter = self._conn.adapter # Build schema list for IN clause - schemas_list = ", ".join(adapter.quote_string(s) for s in self._conn.schemas) + names = schema_names if schema_names is not None else set(self._conn.schemas) + if not names: + self._loaded = True + return + schemas_list = ", ".join(adapter.quote_string(s) for s in names) # Load primary keys and foreign keys via adapter methods # Note: Both PyMySQL and psycopg use %s placeholders, so escape % as %% @@ -220,6 +226,39 @@ def load(self, force: bool = True) -> None: raise DataJointError("DataJoint can only work with acyclic dependencies") self._loaded = True + def load_all_downstream(self) -> None: + """ + Load dependencies including all downstream schemas reachable via FK chains. + + Iteratively discovers schemas that reference the currently loaded + schemas, expanding the dependency graph until no new schemas are + found. This ensures that cascade delete and drop reach all + dependent tables, even those in schemas that haven't been + explicitly activated. + + Called automatically by ``Diagram.cascade()`` and ``Table.drop()``. + Call manually before constructing a ``Diagram`` to include + cross-schema dependencies in visualization:: + + conn.dependencies.load_all_downstream() + dj.Diagram(schema) # now includes all downstream schemas + """ + adapter = self._conn.adapter + known_schemas = set(self._conn.schemas) + if not known_schemas: + self.load() + return + + while True: + schemas_list = ", ".join(adapter.quote_string(s) for s in known_schemas) + result = self._conn.query(adapter.find_downstream_schemas_sql(schemas_list)) + new_schemas = {row[0] for row in result} - known_schemas + if not new_schemas: + break + known_schemas |= new_schemas + + self.load(force=True, schema_names=known_schemas) + def topo_sort(self) -> list[str]: """ Return table names in topological order. diff --git a/src/datajoint/diagram.py b/src/datajoint/diagram.py index 75e00c21c..1436ec7b8 100644 --- a/src/datajoint/diagram.py +++ b/src/datajoint/diagram.py @@ -1,12 +1,21 @@ """ -Diagram visualization for DataJoint schemas. +Diagram for DataJoint schemas. -This module provides the Diagram class for visualizing schema structure -as directed acyclic graphs showing tables and their foreign key relationships. +This module provides the Diagram class for constructing derived views of the +dependency graph. Diagram supports set operators (+, -, *) for selecting subsets +of tables, restriction propagation (cascade, restrict) for selecting subsets of +data, and inspection (counts, prune) for viewing those selections. + +Mutation operations (delete, drop) live in Table, which uses Diagram internally +for graph computation. + +Visualization methods (draw, make_dot, make_svg, etc.) require matplotlib and +pygraphviz. All other methods are always available. 
""" from __future__ import annotations +import copy as copy_module import functools import inspect import io @@ -14,7 +23,8 @@ import networkx as nx -from .dependencies import topo_sort +from .condition import AndList +from .dependencies import extract_master, topo_sort from .errors import DataJointError from .table import Table, lookup_class_name from .user_tables import Computed, Imported, Lookup, Manual, Part, _AliasNode, _get_tier @@ -37,1002 +47,1367 @@ logger = logging.getLogger(__name__.split(".")[0]) -if not diagram_active: # noqa: C901 +class Diagram(nx.DiGraph): # noqa: C901 + """ + Schema diagram as a directed acyclic graph (DAG). + + Visualizes tables and foreign key relationships derived from + ``connection.dependencies``. + + Parameters + ---------- + source : Table, Schema, or module + A table object, table class, schema, or module with a schema. + context : dict, optional + Namespace for resolving table class names. If None, uses caller's + frame globals/locals. + + Examples + -------- + >>> diag = dj.Diagram(schema.MyTable) + >>> diag.draw() + + Operators: + + - ``diag1 + diag2`` - union of diagrams + - ``diag1 - diag2`` - difference of diagrams + - ``diag1 * diag2`` - intersection of diagrams + - ``diag + n`` - expand n levels of successors (children) + - ``diag - n`` - expand n levels of predecessors (parents) + + >>> dj.Diagram(schema.Table) + 1 - 1 # immediate ancestors and descendants + + Notes + ----- + ``diagram + 1 - 1`` may differ from ``diagram - 1 + 1``. + Only tables in activated schemas are displayed. To include tables in + downstream schemas that depend on the current schema but haven't been + explicitly activated:: + + conn.dependencies.load_all_downstream() + dj.Diagram(schema) # now includes all downstream schemas + + ``Diagram.cascade()`` calls ``load_all_downstream()`` automatically. + + Layout direction is controlled via ``dj.config.display.diagram_direction`` + (default ``"TB"``). 
Use ``dj.config.override()`` to change temporarily:: + + with dj.config.override(display_diagram_direction="LR"): + dj.Diagram(schema).draw() + """ + + def __init__(self, source, context=None) -> None: + if isinstance(source, Diagram): + # copy constructor + self.nodes_to_show = set(source.nodes_to_show) + self._expanded_nodes = set(source._expanded_nodes) + self.context = source.context + self._connection = source._connection + self._cascade_restrictions = copy_module.deepcopy(source._cascade_restrictions) + self._restrict_conditions = copy_module.deepcopy(source._restrict_conditions) + self._restriction_attrs = copy_module.deepcopy(source._restriction_attrs) + super().__init__(source) + return + + # get the caller's context + if context is None: + frame = inspect.currentframe().f_back + self.context = dict(frame.f_globals, **frame.f_locals) + del frame + else: + self.context = context + + # find connection in the source + try: + connection = source.connection + except AttributeError: + try: + connection = source.schema.connection + except AttributeError: + raise DataJointError("Could not find database connection in %s" % repr(source)) + + # initialize graph from dependencies + connection.dependencies.load() + super().__init__(connection.dependencies) + self._connection = connection + self._cascade_restrictions = {} + self._restrict_conditions = {} + self._restriction_attrs = {} + + # Enumerate nodes from all the items in the list + self.nodes_to_show = set() + try: + self.nodes_to_show.add(source.full_table_name) + except AttributeError: + try: + database = source.database + except AttributeError: + try: + database = source.schema.database + except AttributeError: + raise DataJointError("Cannot plot Diagram for %s" % repr(source)) + for node in self.nodes(): + # Handle both MySQL backticks and PostgreSQL double quotes + if node.startswith("`%s`" % database) or node.startswith('"%s"' % database): + self.nodes_to_show.add(node) + # All nodes start as expanded + self._expanded_nodes = set(self.nodes_to_show) + + @classmethod + def from_sequence(cls, sequence) -> "Diagram": + """ + Create combined Diagram from a sequence of sources. + + Parameters + ---------- + sequence : iterable + Sequence of table objects, classes, or schemas. - class Diagram: + Returns + ------- + Diagram + Union of diagrams: ``Diagram(arg1) + ... + Diagram(argn)``. """ - Schema diagram (disabled). + return functools.reduce(lambda x, y: x + y, map(Diagram, sequence)) - Diagram visualization requires matplotlib and pygraphviz packages. - Install them to enable this feature. + def add_parts(self) -> "Diagram": + """ + Add part tables of all masters already in the diagram. - See Also + Returns + ------- + Diagram + New diagram with part tables included. + """ + + split = self._connection.adapter.split_full_table_name + + def is_part(part, master): + p_schema, p_table = split(part) + m_schema, m_table = split(master) + return m_schema == p_schema and m_table + "__" == p_table[: len(m_table) + 2] + + self = Diagram(self) # copy + self.nodes_to_show.update(n for n in self.nodes() if any(is_part(n, m) for m in self.nodes_to_show)) + return self + + def collapse(self) -> "Diagram": + """ + Mark all nodes in this diagram as collapsed. + + Collapsed nodes are shown as a single node per schema. When combined + with other diagrams using ``+``, expanded nodes win: if a node is + expanded in either operand, it remains expanded in the result. + + Returns + ------- + Diagram + A copy of this diagram with all nodes collapsed. 
+ + Examples -------- - https://docs.datajoint.com/how-to/installation/ + >>> # Show schema1 expanded, schema2 collapsed into single nodes + >>> dj.Diagram(schema1) + dj.Diagram(schema2).collapse() + + >>> # Collapse all three schemas together + >>> (dj.Diagram(schema1) + dj.Diagram(schema2) + dj.Diagram(schema3)).collapse() + + >>> # Expand one table from collapsed schema + >>> dj.Diagram(schema).collapse() + dj.Diagram(SingleTable) """ + result = Diagram(self) + result._expanded_nodes = set() # All nodes collapsed + return result - def __init__(self, *args, **kwargs) -> None: - logger.warning("Please install matplotlib and pygraphviz libraries to enable the Diagram feature.") + def __add__(self, arg) -> "Diagram": + """ + Union or downstream expansion. -else: + Parameters + ---------- + arg : Diagram or int + Another Diagram for union, or positive int for downstream expansion. + + Returns + ------- + Diagram + Combined or expanded diagram. + """ + result = Diagram(self) # copy + try: + # Merge nodes and edges from the other diagram + result.add_nodes_from(arg.nodes(data=True)) + result.add_edges_from(arg.edges(data=True)) + result.nodes_to_show.update(arg.nodes_to_show) + # Merge contexts for class name lookups + result.context = {**result.context, **arg.context} + # Expanded wins: union of expanded nodes from both operands + result._expanded_nodes = self._expanded_nodes | arg._expanded_nodes + except AttributeError: + try: + result.nodes_to_show.add(arg.full_table_name) + result._expanded_nodes.add(arg.full_table_name) + except AttributeError: + for i in range(arg): + new = nx.algorithms.boundary.node_boundary(result, result.nodes_to_show) + if not new: + break + # add nodes referenced by aliased nodes + new.update(nx.algorithms.boundary.node_boundary(result, (a for a in new if a.isdigit()))) + result.nodes_to_show.update(new) + # New nodes from expansion are expanded + result._expanded_nodes = result._expanded_nodes | result.nodes_to_show + return result + + def __sub__(self, arg) -> "Diagram": + """ + Difference or upstream expansion. + + Parameters + ---------- + arg : Diagram or int + Another Diagram for difference, or positive int for upstream expansion. + + Returns + ------- + Diagram + Reduced or expanded diagram. + """ + self = Diagram(self) # copy + try: + self.nodes_to_show.difference_update(arg.nodes_to_show) + except AttributeError: + try: + self.nodes_to_show.remove(arg.full_table_name) + except AttributeError: + for i in range(arg): + graph = nx.DiGraph(self).reverse() + new = nx.algorithms.boundary.node_boundary(graph, self.nodes_to_show) + if not new: + break + # add nodes referenced by aliased nodes + new.update(nx.algorithms.boundary.node_boundary(graph, (a for a in new if a.isdigit()))) + self.nodes_to_show.update(new) + return self + + def __mul__(self, arg) -> "Diagram": + """ + Intersection of two diagrams. - class Diagram(nx.DiGraph): + Parameters + ---------- + arg : Diagram + Another Diagram. + + Returns + ------- + Diagram + Diagram with nodes present in both operands. """ - Schema diagram as a directed acyclic graph (DAG). + self = Diagram(self) # copy + self.nodes_to_show.intersection_update(arg.nodes_to_show) + return self - Visualizes tables and foreign key relationships derived from - ``connection.dependencies``. + def topo_sort(self) -> list[str]: + """ + Return nodes in topological order. + + Returns + ------- + list[str] + Node names in topological order. 
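+
+        Example (illustrative)::
+
+            dj.Diagram(schema).topo_sort()  # parents listed before children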
+ """ + return topo_sort(self) + + @classmethod + def cascade(cls, table_expr, part_integrity="enforce"): + """ + Create a cascade diagram for a table expression. + + Builds a Diagram from the table's dependency graph, includes all + descendants (across all loaded schemas), and propagates the + restriction downstream using OR convergence — a child row is + affected if *any* restricted ancestor taints it. Parameters ---------- - source : Table, Schema, or module - A table object, table class, schema, or module with a schema. - context : dict, optional - Namespace for resolving table class names. If None, uses caller's - frame globals/locals. + table_expr : QueryExpression + A (possibly restricted) table expression + (e.g., ``Session & 'subject_id=1'``). + part_integrity : str, optional + ``"enforce"`` (default), ``"ignore"``, or ``"cascade"``. + + Returns + ------- + Diagram + New Diagram with cascade restrictions applied, trimmed to + the seed table and its affected descendants. Examples -------- - >>> diag = dj.Diagram(schema.MyTable) - >>> diag.draw() + >>> # Preview cascade impact across all downstream schemas + >>> dj.Diagram.cascade(Session & 'subject_id=1').counts() - Operators: + >>> # Inspect the cascade subgraph + >>> dj.Diagram.cascade(Session & 'subject_id=1') + """ + conn = table_expr.connection + conn.dependencies.load_all_downstream() + node = table_expr.full_table_name + + result = cls.__new__(cls) + nx.DiGraph.__init__(result, conn.dependencies) + result._connection = conn + result.context = {} + result._cascade_restrictions = {} + result._restrict_conditions = {} + result._restriction_attrs = {} + + # Include seed + all descendants + descendants = set(nx.descendants(result, node)) | {node} + result.nodes_to_show = descendants + result._expanded_nodes = set(descendants) + + # Seed restriction + restriction = AndList(table_expr.restriction) + result._cascade_restrictions[node] = [restriction] if restriction else [] + result._restriction_attrs[node] = set(table_expr.restriction_attributes) + + # Propagate downstream + result._propagate_restrictions(node, mode="cascade", part_integrity=part_integrity) + + # Trim graph to cascade subgraph: only restricted tables + # (seed + descendants) plus alias nodes connecting them. + keep = set(result._cascade_restrictions) + for alias in (n for n in result.nodes() if n.isdigit()): + if set(result.predecessors(alias)) & keep and set(result.successors(alias)) & keep: + keep.add(alias) + result.remove_nodes_from(set(result.nodes()) - keep) + result.nodes_to_show &= keep + result._expanded_nodes &= keep + return result + + def _restricted_table(self, node): + """ + Return a FreeTable for ``node`` with this diagram's restrictions applied. - - ``diag1 + diag2`` - union of diagrams - - ``diag1 - diag2`` - difference of diagrams - - ``diag1 * diag2`` - intersection of diagrams - - ``diag + n`` - expand n levels of successors (children) - - ``diag - n`` - expand n levels of predecessors (parents) + Cascade restrictions are OR-combined (a row is affected if ANY + FK reference points to a deleted row). Restrict conditions are + AND-combined (a row is included only when ALL ancestor conditions + are satisfied). 
+ """ + from .table import FreeTable + + ft = FreeTable(self._connection, node) + restrictions = (self._cascade_restrictions or self._restrict_conditions).get(node, []) + if not restrictions: + return ft + if self._cascade_restrictions: + # OR semantics — passing a list to restrict() creates an OrList + return ft.restrict(restrictions) + else: + # AND semantics — each restriction narrows further + for r in restrictions: + ft = ft.restrict(r) + return ft + + def restrict(self, table_expr): + """ + Apply restrict condition and propagate downstream. - >>> dj.Diagram(schema.Table) + 1 - 1 # immediate ancestors and descendants + AND at convergence — a child row is included only if it satisfies + *all* restricted ancestors. Used for export. Can be chained. - Notes - ----- - ``diagram + 1 - 1`` may differ from ``diagram - 1 + 1``. - Only tables loaded in the connection are displayed. + Cannot be called on a Diagram produced by ``Diagram.cascade()``. + + Parameters + ---------- + table_expr : QueryExpression + A restricted table expression. + + Returns + ------- + Diagram + New Diagram with restrict conditions applied. + """ + if self._cascade_restrictions: + raise DataJointError( + "Cannot apply restrict() on a Diagram produced by Diagram.cascade(). " + "cascade and restrict are mutually exclusive modes." + ) + result = Diagram(self) + node = table_expr.full_table_name + if node not in result.nodes(): + raise DataJointError(f"Table {node} is not in the diagram.") + # Seed restriction (AND accumulation) + result._restrict_conditions.setdefault(node, AndList()).extend(table_expr.restriction) + result._restriction_attrs.setdefault(node, set()).update(table_expr.restriction_attributes) + # Propagate downstream + result._propagate_restrictions(node, mode="restrict") + return result + + def _propagate_restrictions(self, start_node, mode, part_integrity="enforce"): + """ + Propagate restrictions from start_node to all its descendants. + + Walks the dependency graph in topological order, applying + propagation rules at each edge. Only processes descendants of + start_node to avoid duplicate propagation when chaining. + """ + from .table import FreeTable + + sorted_nodes = topo_sort(self) + # Only propagate through descendants of start_node + allowed_nodes = {start_node} | set(nx.descendants(self, start_node)) + propagated_edges = set() + visited_masters = set() + + restrictions = self._cascade_restrictions if mode == "cascade" else self._restrict_conditions + + # Multiple passes to handle part_integrity="cascade" upward propagation. + # When a part table triggers its master to join the cascade, the master's + # other descendants need processing in a subsequent pass. The loop + # terminates when no new nodes are added — guaranteed in a DAG. 
+ any_new = True + while any_new: + any_new = False + + for node in sorted_nodes: + if node not in restrictions or node not in allowed_nodes: + continue + + # Build parent FreeTable with current restriction + parent_ft = self._restricted_table(node) + + parent_attrs = self._restriction_attrs.get(node, set()) + + for _, target, edge_props in self.out_edges(node, data=True): + attr_map = edge_props.get("attr_map", {}) + aliased = edge_props.get("aliased", False) + + if target.isdigit(): + # Alias node — follow through to real child + for _, child_node, _ in self.out_edges(target, data=True): + edge_key = (node, target, child_node) + if edge_key in propagated_edges: + continue + propagated_edges.add(edge_key) + was_new = child_node not in restrictions + self._apply_propagation_rule( + parent_ft, + parent_attrs, + child_node, + attr_map, + True, + mode, + restrictions, + ) + if was_new and child_node in restrictions: + any_new = True + else: + edge_key = (node, target) + if edge_key in propagated_edges: + continue + propagated_edges.add(edge_key) + was_new = target not in restrictions + self._apply_propagation_rule( + parent_ft, + parent_attrs, + target, + attr_map, + aliased, + mode, + restrictions, + ) + if was_new and target in restrictions: + any_new = True + + # part_integrity="cascade": propagate up from part to master + if part_integrity == "cascade" and mode == "cascade": + master_name = extract_master(target) + if ( + master_name + and master_name in self.nodes() + and master_name not in restrictions + and master_name not in visited_masters + ): + visited_masters.add(master_name) + child_ft = self._restricted_table(target) + master_ft = FreeTable(self._connection, master_name) + from .condition import make_condition + + master_restr = make_condition( + master_ft, + (master_ft.proj() & child_ft.proj()).to_arrays(), + master_ft.restriction_attributes, + ) + restrictions[master_name] = [master_restr] + self._restriction_attrs[master_name] = set() + allowed_nodes.add(master_name) + allowed_nodes.update(nx.descendants(self, master_name)) + any_new = True + + def _apply_propagation_rule( + self, + parent_ft, + parent_attrs, + child_node, + attr_map, + aliased, + mode, + restrictions, + ): + """ + Apply one of the 3 propagation rules to a parent→child edge. - Layout direction is controlled via ``dj.config.display.diagram_direction`` - (default ``"TB"``). Use ``dj.config.override()`` to change temporarily:: + Rules (from table.py restriction propagation): - with dj.config.override(display_diagram_direction="LR"): - dj.Diagram(schema).draw() + 1. Non-aliased AND parent restriction attrs ⊆ child PK: + Copy parent restriction directly. + 2. Aliased FK (attr_map renames columns): + ``parent.proj(**{fk: pk for fk, pk in attr_map.items()})`` + 3. 
Non-aliased AND parent restriction attrs ⊄ child PK: + ``parent.proj()`` """ + child_pk = self.nodes[child_node].get("primary_key", set()) - def __init__(self, source, context=None) -> None: - if isinstance(source, Diagram): - # copy constructor - self.nodes_to_show = set(source.nodes_to_show) - self._expanded_nodes = set(source._expanded_nodes) - self.context = source.context - self._connection = source._connection - super().__init__(source) - return - - # get the caller's context - if context is None: - frame = inspect.currentframe().f_back - self.context = dict(frame.f_globals, **frame.f_locals) - del frame + if not aliased and parent_attrs and parent_attrs <= child_pk: + # Rule 1: copy parent restriction directly + parent_restr = restrictions.get( + parent_ft.full_table_name, + [] if mode == "cascade" else AndList(), + ) + if mode == "cascade": + restrictions.setdefault(child_node, []).extend(parent_restr) + else: + restrictions.setdefault(child_node, AndList()).extend(parent_restr) + child_attrs = set(parent_attrs) + elif aliased: + # Rule 2: aliased FK — project with renaming + child_item = parent_ft.proj(**{fk: pk for fk, pk in attr_map.items()}) + if mode == "cascade": + restrictions.setdefault(child_node, []).append(child_item) else: - self.context = context + restrictions.setdefault(child_node, AndList()).append(child_item) + child_attrs = set(attr_map.keys()) + else: + # Rule 3: non-aliased, restriction attrs ⊄ child PK — project + child_item = parent_ft.proj() + if mode == "cascade": + restrictions.setdefault(child_node, []).append(child_item) + else: + restrictions.setdefault(child_node, AndList()).append(child_item) + child_attrs = set(attr_map.values()) - # find connection in the source - try: - connection = source.connection - except AttributeError: - try: - connection = source.schema.connection - except AttributeError: - raise DataJointError("Could not find database connection in %s" % repr(source[0])) + self._restriction_attrs.setdefault(child_node, set()).update(child_attrs) - # initialize graph from dependencies - self._connection = connection - connection.dependencies.load() - super().__init__(connection.dependencies) + def counts(self): + """ + Return affected row counts per table without modifying data. - # Enumerate nodes from all the items in the list - self.nodes_to_show = set() - try: - self.nodes_to_show.add(source.full_table_name) - except AttributeError: - try: - database = source.database - except AttributeError: - try: - database = source.schema.database - except AttributeError: - raise DataJointError("Cannot plot Diagram for %s" % repr(source)) - for node in self: - # Handle both MySQL backticks and PostgreSQL double quotes - if node.startswith("`%s`" % database) or node.startswith('"%s"' % database): - self.nodes_to_show.add(node) - # All nodes start as expanded - self._expanded_nodes = set(self.nodes_to_show) - - @classmethod - def from_sequence(cls, sequence) -> "Diagram": - """ - Create combined Diagram from a sequence of sources. - - Parameters - ---------- - sequence : iterable - Sequence of table objects, classes, or schemas. - - Returns - ------- - Diagram - Union of diagrams: ``Diagram(arg1) + ... + Diagram(argn)``. - """ - return functools.reduce(lambda x, y: x + y, map(Diagram, sequence)) - - def add_parts(self) -> "Diagram": - """ - Add part tables of all masters already in the diagram. - - Returns - ------- - Diagram - New diagram with part tables included. 
- """ - - def is_part(part, master): - part = [s.strip("`") for s in part.split(".")] - master = [s.strip("`") for s in master.split(".")] - return master[0] == part[0] and master[1] + "__" == part[1][: len(master[1]) + 2] - - self = Diagram(self) # copy - self.nodes_to_show.update(n for n in self.nodes() if any(is_part(n, m) for m in self.nodes_to_show)) - return self - - def collapse(self) -> "Diagram": - """ - Mark all nodes in this diagram as collapsed. - - Collapsed nodes are shown as a single node per schema. When combined - with other diagrams using ``+``, expanded nodes win: if a node is - expanded in either operand, it remains expanded in the result. - - Returns - ------- - Diagram - A copy of this diagram with all nodes collapsed. - - Examples - -------- - >>> # Show schema1 expanded, schema2 collapsed into single nodes - >>> dj.Diagram(schema1) + dj.Diagram(schema2).collapse() - - >>> # Collapse all three schemas together - >>> (dj.Diagram(schema1) + dj.Diagram(schema2) + dj.Diagram(schema3)).collapse() - - >>> # Expand one table from collapsed schema - >>> dj.Diagram(schema).collapse() + dj.Diagram(SingleTable) - """ - result = Diagram(self) - result._expanded_nodes = set() # All nodes collapsed - return result - - def __add__(self, arg) -> "Diagram": - """ - Union or downstream expansion. - - Parameters - ---------- - arg : Diagram or int - Another Diagram for union, or positive int for downstream expansion. - - Returns - ------- - Diagram - Combined or expanded diagram. - """ - result = Diagram(self) # copy - try: - # Merge nodes and edges from the other diagram - result.add_nodes_from(arg.nodes(data=True)) - result.add_edges_from(arg.edges(data=True)) - result.nodes_to_show.update(arg.nodes_to_show) - # Merge contexts for class name lookups - result.context = {**result.context, **arg.context} - # Expanded wins: union of expanded nodes from both operands - result._expanded_nodes = self._expanded_nodes | arg._expanded_nodes - except AttributeError: - try: - result.nodes_to_show.add(arg.full_table_name) - result._expanded_nodes.add(arg.full_table_name) - except AttributeError: - for i in range(arg): - new = nx.algorithms.boundary.node_boundary(result, result.nodes_to_show) - if not new: - break - # add nodes referenced by aliased nodes - new.update(nx.algorithms.boundary.node_boundary(result, (a for a in new if a.isdigit()))) - result.nodes_to_show.update(new) - # New nodes from expansion are expanded - result._expanded_nodes = result._expanded_nodes | result.nodes_to_show - return result - - def __sub__(self, arg) -> "Diagram": - """ - Difference or upstream expansion. - - Parameters - ---------- - arg : Diagram or int - Another Diagram for difference, or positive int for upstream expansion. - - Returns - ------- - Diagram - Reduced or expanded diagram. - """ - self = Diagram(self) # copy - try: - self.nodes_to_show.difference_update(arg.nodes_to_show) - except AttributeError: - try: - self.nodes_to_show.remove(arg.full_table_name) - except AttributeError: - for i in range(arg): - graph = nx.DiGraph(self).reverse() - new = nx.algorithms.boundary.node_boundary(graph, self.nodes_to_show) - if not new: - break - # add nodes referenced by aliased nodes - new.update(nx.algorithms.boundary.node_boundary(graph, (a for a in new if a.isdigit()))) - self.nodes_to_show.update(new) - return self - - def __mul__(self, arg) -> "Diagram": - """ - Intersection of two diagrams. - - Parameters - ---------- - arg : Diagram - Another Diagram. 
- - Returns - ------- - Diagram - Diagram with nodes present in both operands. - """ - self = Diagram(self) # copy - self.nodes_to_show.intersection_update(arg.nodes_to_show) - return self - - def topo_sort(self) -> list[str]: - """ - Return nodes in topological order. - - Returns - ------- - list[str] - Node names in topological order. - """ - return topo_sort(self) - - def _make_graph(self) -> nx.DiGraph: - """ - Build graph object ready for drawing. - - Returns - ------- - nx.DiGraph - Graph with nodes relabeled to class names. - """ - # mark "distinguished" tables, i.e. those that introduce new primary key - # attributes - # Filter nodes_to_show to only include nodes that exist in the graph - valid_nodes = self.nodes_to_show.intersection(set(self.nodes())) - for name in valid_nodes: - foreign_attributes = set( - attr for p in self.in_edges(name, data=True) for attr in p[2]["attr_map"] if p[2]["primary"] - ) - self.nodes[name]["distinguished"] = ( - "primary_key" in self.nodes[name] and foreign_attributes < self.nodes[name]["primary_key"] - ) - # include aliased nodes that are sandwiched between two displayed nodes - gaps = set(nx.algorithms.boundary.node_boundary(self, valid_nodes)).intersection( - nx.algorithms.boundary.node_boundary(nx.DiGraph(self).reverse(), valid_nodes) + Returns + ------- + dict[str, int] + Mapping of full table name to affected row count. + """ + restrictions = self._cascade_restrictions or self._restrict_conditions + if not restrictions: + raise DataJointError( + "No restrictions applied. " "Use Diagram.cascade(table_expr) or diag.restrict(table_expr) first." ) - nodes = valid_nodes.union(a for a in gaps if a.isdigit()) - # construct subgraph and rename nodes to class names - graph = nx.DiGraph(nx.DiGraph(self).subgraph(nodes)) - nx.set_node_attributes(graph, name="node_type", values={n: _get_tier(n) for n in graph}) - # relabel nodes to class names - mapping = {node: lookup_class_name(node, self.context) or node for node in graph.nodes()} - new_names = list(mapping.values()) - if len(new_names) > len(set(new_names)): - raise DataJointError("Some classes have identical names. The Diagram cannot be plotted.") - nx.relabel_nodes(graph, mapping, copy=False) - return graph - - def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, str]]: - """ - Apply collapse logic to the graph. - - Nodes in nodes_to_show but not in _expanded_nodes are collapsed into - single schema nodes. - - Parameters - ---------- - graph : nx.DiGraph - The graph from _make_graph(). - - Returns - ------- - tuple[nx.DiGraph, dict[str, str]] - Modified graph and mapping of collapsed schema labels to their table count. 
- """ - # Filter to valid nodes (those that exist in the underlying graph) - valid_nodes = self.nodes_to_show.intersection(set(self.nodes())) - valid_expanded = self._expanded_nodes.intersection(set(self.nodes())) - - # If all nodes are expanded, no collapse needed - if valid_expanded >= valid_nodes: - return graph, {} - - # Map full_table_names to class_names - full_to_class = {node: lookup_class_name(node, self.context) or node for node in valid_nodes} - class_to_full = {v: k for k, v in full_to_class.items()} - - # Identify expanded class names - expanded_class_names = {full_to_class.get(node, node) for node in valid_expanded} - - # Identify nodes to collapse (class names) - nodes_to_collapse = set(graph.nodes()) - expanded_class_names - - if not nodes_to_collapse: - return graph, {} - - # Group collapsed nodes by schema - collapsed_by_schema = {} # schema_name -> list of class_names - for class_name in nodes_to_collapse: - full_name = class_to_full.get(class_name) - if full_name: - parts = full_name.replace('"', "`").split("`") - if len(parts) >= 2: - schema_name = parts[1] - if schema_name not in collapsed_by_schema: - collapsed_by_schema[schema_name] = [] - collapsed_by_schema[schema_name].append(class_name) - - if not collapsed_by_schema: - return graph, {} - - # Determine labels for collapsed schemas - schema_modules = {} - for schema_name, class_names in collapsed_by_schema.items(): - schema_modules[schema_name] = set() - for class_name in class_names: - cls = self._resolve_class(class_name) - if cls is not None and hasattr(cls, "__module__"): - module_name = cls.__module__.split(".")[-1] - schema_modules[schema_name].add(module_name) - - # Collect module names for ALL schemas in the diagram (not just collapsed) - all_schema_modules = {} # schema_name -> module_name - for node in graph.nodes(): - full_name = class_to_full.get(node) - if full_name: - parts = full_name.replace('"', "`").split("`") - if len(parts) >= 2: - db_schema = parts[1] - cls = self._resolve_class(node) - if cls is not None and hasattr(cls, "__module__"): - module_name = cls.__module__.split(".")[-1] - all_schema_modules[db_schema] = module_name - - # Check which module names are shared by multiple schemas - module_to_schemas = {} - for db_schema, module_name in all_schema_modules.items(): - if module_name not in module_to_schemas: - module_to_schemas[module_name] = [] - module_to_schemas[module_name].append(db_schema) - - ambiguous_modules = {m for m, schemas in module_to_schemas.items() if len(schemas) > 1} - - # Determine labels for collapsed schemas - collapsed_labels = {} # schema_name -> label - for schema_name, modules in schema_modules.items(): - if len(modules) == 1: - module_name = next(iter(modules)) - # Use database schema name if module is ambiguous - if module_name in ambiguous_modules: - label = schema_name - else: - label = module_name - else: - label = schema_name - collapsed_labels[schema_name] = label - - # Build counts using final labels - collapsed_counts = {} # label -> count of tables - for schema_name, class_names in collapsed_by_schema.items(): - label = collapsed_labels[schema_name] - collapsed_counts[label] = len(class_names) - - # Create new graph with collapsed nodes - new_graph = nx.DiGraph() - - # Map old node names to new names (collapsed nodes -> schema label) - node_mapping = {} - for node in graph.nodes(): - full_name = class_to_full.get(node) - if full_name: - parts = full_name.replace('"', "`").split("`") - if len(parts) >= 2 and node in nodes_to_collapse: - schema_name = 
parts[1] - node_mapping[node] = collapsed_labels[schema_name] - else: - node_mapping[node] = node - else: - # Alias nodes - check if they should be collapsed - # An alias node should be collapsed if ALL its neighbors are collapsed - neighbors = set(graph.predecessors(node)) | set(graph.successors(node)) - if neighbors and neighbors <= nodes_to_collapse: - # Get schema from first neighbor - neighbor = next(iter(neighbors)) - full_name = class_to_full.get(neighbor) - if full_name: - parts = full_name.replace('"', "`").split("`") - if len(parts) >= 2: - schema_name = parts[1] - node_mapping[node] = collapsed_labels[schema_name] - continue - node_mapping[node] = node - # Build reverse mapping: label -> schema_name - label_to_schema = {label: schema for schema, label in collapsed_labels.items()} - - # Add nodes - added_collapsed = set() - for old_node, new_node in node_mapping.items(): - if new_node in collapsed_counts: - # This is a collapsed schema node - if new_node not in added_collapsed: - schema_name = label_to_schema.get(new_node, new_node) - new_graph.add_node( - new_node, - node_type=None, - collapsed=True, - table_count=collapsed_counts[new_node], - schema_name=schema_name, - ) - added_collapsed.add(new_node) - else: - new_graph.add_node(new_node, **graph.nodes[old_node]) - - # Add edges (avoiding self-loops and duplicates) - for src, dest, data in graph.edges(data=True): - new_src = node_mapping[src] - new_dest = node_mapping[dest] - if new_src != new_dest and not new_graph.has_edge(new_src, new_dest): - new_graph.add_edge(new_src, new_dest, **data) - - return new_graph, collapsed_counts - - def _resolve_class(self, name: str): - """ - Safely resolve a table class from a dotted name without eval(). - - Parameters - ---------- - name : str - Dotted class name like "MyTable" or "Module.MyTable". - - Returns - ------- - type or None - The table class if found, otherwise None. - """ - parts = name.split(".") - obj = self.context.get(parts[0]) - for part in parts[1:]: - if obj is None: - return None - obj = getattr(obj, part, None) - if obj is not None and isinstance(obj, type) and issubclass(obj, Table): - return obj - return None - - @staticmethod - def _encapsulate_edge_attributes(graph: nx.DiGraph) -> None: - """ - Encapsulate edge attr_map in double quotes for pydot compatibility. - - Modifies graph in place. - - See Also - -------- - https://github.com/pydot/pydot/issues/258#issuecomment-795798099 - """ - for u, v, *_, edgedata in graph.edges(data=True): - if "attr_map" in edgedata: - graph.edges[u, v]["attr_map"] = '"{0}"'.format(edgedata["attr_map"]) - - @staticmethod - def _encapsulate_node_names(graph: nx.DiGraph) -> None: - """ - Encapsulate node names in double quotes for pydot compatibility. - - Modifies graph in place. - - See Also - -------- - https://github.com/datajoint/datajoint-python/pull/1176 - """ - nx.relabel_nodes( - graph, - {node: '"{0}"'.format(node) for node in graph.nodes()}, - copy=False, + result = {} + for ft in self: + if ft.full_table_name in restrictions: + count = len(ft) + result[ft.full_table_name] = count + logger.info("{table} ({count} tuples)".format(table=ft.full_table_name, count=count)) + return result + + def __iter__(self): + """ + Iterate over non-alias nodes in topological order (parents first). + + Yields restricted ``FreeTable`` objects when cascade or restrict + conditions have been applied, unrestricted ``FreeTable`` otherwise. + + Alias nodes (used internally for multi-FK edges) are skipped. 
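+
+        Example (illustrative)::
+
+            for free_table in dj.Diagram.cascade(Session & "subject_id=1"):
+                print(free_table.full_table_name)  # parents first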
+ """ + for node in topo_sort(self): + if not node.isdigit() and node in self.nodes_to_show: + yield self._restricted_table(node) + + def __reversed__(self): + """ + Iterate in reverse topological order (leaves first). + + Same as ``__iter__`` but reversed — useful for cascading + deletes and drops. + """ + for node in reversed(topo_sort(self)): + if not node.isdigit() and node in self.nodes_to_show: + yield self._restricted_table(node) + + def prune(self): + """ + Remove tables with zero matching rows from the diagram. + + Without prior restrictions, removes physically empty tables. + After ``restrict()``, removes tables where the restricted query + yields zero rows. Cannot be used on a cascade Diagram (cascade + is for delete, where zero-count tables must remain in the graph + to handle concurrent inserts safely). + + Returns + ------- + Diagram + New Diagram with empty tables removed. + """ + from .table import FreeTable + + if self._cascade_restrictions: + raise DataJointError( + "prune() cannot be used on a Diagram produced by Diagram.cascade(). " + "Cascade diagrams must retain all descendant tables for safe deletion." ) - def make_dot(self): - """ - Generate a pydot graph object. - - Returns - ------- - pydot.Dot - The graph object ready for rendering. - - Notes - ----- - Layout direction is controlled via ``dj.config.display.diagram_direction``. - Tables are grouped by schema, with the Python module name shown as the - group label when available. - """ - direction = self._connection._config.display.diagram_direction - graph = self._make_graph() - - # Apply collapse logic if needed - graph, collapsed_counts = self._apply_collapse(graph) - - # Build schema mapping: class_name -> schema_name - # Group by database schema, label with Python module name if 1:1 mapping - schema_map = {} # class_name -> schema_name - schema_modules = {} # schema_name -> set of module names - - for full_name in self.nodes_to_show: - # Extract schema from full table name like `schema`.`table` or "schema"."table" - parts = full_name.replace('"', "`").split("`") - if len(parts) >= 2: - schema_name = parts[1] # schema is between first pair of backticks - class_name = lookup_class_name(full_name, self.context) or full_name - schema_map[class_name] = schema_name - - # Collect all module names for this schema - if schema_name not in schema_modules: - schema_modules[schema_name] = set() - cls = self._resolve_class(class_name) + result = Diagram(self) + + if result._restrict_conditions: + for node in list(result._restrict_conditions): + if node.isdigit(): + continue + if len(result._restricted_table(node)) == 0: + result._restrict_conditions.pop(node) + result._restriction_attrs.pop(node, None) + result.nodes_to_show.discard(node) + else: + # Unrestricted: check physical row counts + for node in list(result.nodes_to_show): + if node.isdigit(): + continue + ft = FreeTable(self._connection, node) + if len(ft) == 0: + result.nodes_to_show.discard(node) + + return result + + def _make_graph(self) -> nx.DiGraph: + """ + Build graph object ready for drawing. + + Returns + ------- + nx.DiGraph + Graph with nodes relabeled to class names. + """ + # mark "distinguished" tables, i.e. 
those that introduce new primary key
+        # attributes
+        # Filter nodes_to_show to only include nodes that exist in the graph
+        valid_nodes = self.nodes_to_show.intersection(set(self.nodes()))
+        for name in valid_nodes:
+            foreign_attributes = set(
+                attr for p in self.in_edges(name, data=True) for attr in p[2]["attr_map"] if p[2]["primary"]
+            )
+            self.nodes[name]["distinguished"] = (
+                "primary_key" in self.nodes[name] and foreign_attributes < self.nodes[name]["primary_key"]
+            )
+        # include aliased nodes that are sandwiched between two displayed nodes
+        gaps = set(nx.algorithms.boundary.node_boundary(self, valid_nodes)).intersection(
+            nx.algorithms.boundary.node_boundary(nx.DiGraph(self).reverse(), valid_nodes)
+        )
+        nodes = valid_nodes.union(a for a in gaps if a.isdigit())
+        # construct subgraph and rename nodes to class names
+        graph = nx.DiGraph(nx.DiGraph(self).subgraph(nodes))
+        nx.set_node_attributes(graph, name="node_type", values={n: _get_tier(n) for n in graph})
+        # relabel nodes to class names
+        mapping = {node: lookup_class_name(node, self.context) or node for node in graph.nodes()}
+        new_names = list(mapping.values())
+        if len(new_names) > len(set(new_names)):
+            raise DataJointError("Some classes have identical names. The Diagram cannot be plotted.")
+        nx.relabel_nodes(graph, mapping, copy=False)
+        return graph
+
+    def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, int]]:
+        """
+        Apply collapse logic to the graph.
+
+        Nodes in nodes_to_show but not in _expanded_nodes are collapsed into
+        single schema nodes.
+
+        Parameters
+        ----------
+        graph : nx.DiGraph
+            The graph from _make_graph().
+
+        Returns
+        -------
+        tuple[nx.DiGraph, dict[str, int]]
+            Modified graph and mapping of collapsed schema labels to their table count.
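+
+        Examples
+        --------
+        A sketch of the returned mapping (``diagram`` and ``lab`` are
+        illustrative)::
+
+            >>> graph, counts = diagram._apply_collapse(diagram._make_graph())
+            >>> counts
+            {'lab': 3}  # label of the collapsed schema -> number of tables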
+ """ + # Filter to valid nodes (those that exist in the underlying graph) + valid_nodes = self.nodes_to_show.intersection(set(self.nodes())) + valid_expanded = self._expanded_nodes.intersection(set(self.nodes())) + + # If all nodes are expanded, no collapse needed + if valid_expanded >= valid_nodes: + return graph, {} + + # Map full_table_names to class_names + full_to_class = {node: lookup_class_name(node, self.context) or node for node in valid_nodes} + class_to_full = {v: k for k, v in full_to_class.items()} + + # Identify expanded class names + expanded_class_names = {full_to_class.get(node, node) for node in valid_expanded} + + # Identify nodes to collapse (class names) + nodes_to_collapse = set(graph.nodes()) - expanded_class_names + + if not nodes_to_collapse: + return graph, {} + + # Group collapsed nodes by schema + collapsed_by_schema = {} # schema_name -> list of class_names + for class_name in nodes_to_collapse: + full_name = class_to_full.get(class_name) + if full_name: + schema_name, _ = self._connection.adapter.split_full_table_name(full_name) + if schema_name: + if schema_name not in collapsed_by_schema: + collapsed_by_schema[schema_name] = [] + collapsed_by_schema[schema_name].append(class_name) + + if not collapsed_by_schema: + return graph, {} + + # Determine labels for collapsed schemas + schema_modules = {} + for schema_name, class_names in collapsed_by_schema.items(): + schema_modules[schema_name] = set() + for class_name in class_names: + cls = self._resolve_class(class_name) + if cls is not None and hasattr(cls, "__module__"): + module_name = cls.__module__.split(".")[-1] + schema_modules[schema_name].add(module_name) + + # Collect module names for ALL schemas in the diagram (not just collapsed) + all_schema_modules = {} # schema_name -> module_name + for node in graph.nodes(): + full_name = class_to_full.get(node) + if full_name: + db_schema, _ = self._connection.adapter.split_full_table_name(full_name) + if db_schema: + cls = self._resolve_class(node) if cls is not None and hasattr(cls, "__module__"): module_name = cls.__module__.split(".")[-1] - schema_modules[schema_name].add(module_name) - - # Determine cluster labels: use module name if 1:1, else database schema name - cluster_labels = {} # schema_name -> label - for schema_name, modules in schema_modules.items(): - if len(modules) == 1: - cluster_labels[schema_name] = next(iter(modules)) - else: - cluster_labels[schema_name] = schema_name - - # Disambiguate labels if multiple schemas share the same module name - # (e.g., all defined in __main__ in a notebook) - label_counts = {} - for label in cluster_labels.values(): - label_counts[label] = label_counts.get(label, 0) + 1 - - for schema_name, label in cluster_labels.items(): - if label_counts[label] > 1: - # Multiple schemas share this module name - add schema name - cluster_labels[schema_name] = f"{label} ({schema_name})" - - # Assign alias nodes (orange dots) to the same schema as their child table - for node, data in graph.nodes(data=True): - if data.get("node_type") is _AliasNode: - # Find the child (successor) - the table that declares the renamed FK - successors = list(graph.successors(node)) - if successors and successors[0] in schema_map: - schema_map[node] = schema_map[successors[0]] - - # Assign collapsed nodes to their schema so they appear in the cluster - for node, data in graph.nodes(data=True): - if data.get("collapsed") and data.get("schema_name"): - schema_map[node] = data["schema_name"] - - scale = 1.2 # scaling factor for fonts and boxes - 
label_props = { # http://matplotlib.org/examples/color/named_colors.html - None: dict( - shape="circle", - color="#FFFF0040", - fontcolor="yellow", - fontsize=round(scale * 8), - size=0.4 * scale, - fixed=False, - ), - _AliasNode: dict( - shape="circle", - color="#FF880080", - fontcolor="#FF880080", - fontsize=round(scale * 0), - size=0.05 * scale, - fixed=True, - ), - Manual: dict( - shape="box", - color="#00FF0030", - fontcolor="darkgreen", - fontsize=round(scale * 10), - size=0.4 * scale, - fixed=False, - ), - Lookup: dict( - shape="plaintext", - color="#00000020", - fontcolor="black", - fontsize=round(scale * 8), - size=0.4 * scale, - fixed=False, - ), - Computed: dict( - shape="ellipse", - color="#FF000020", - fontcolor="#7F0000A0", - fontsize=round(scale * 10), - size=0.4 * scale, - fixed=False, - ), - Imported: dict( - shape="ellipse", - color="#00007F40", - fontcolor="#00007FA0", - fontsize=round(scale * 10), - size=0.4 * scale, - fixed=False, - ), - Part: dict( - shape="plaintext", - color="#00000000", - fontcolor="black", - fontsize=round(scale * 8), - size=0.1 * scale, - fixed=False, - ), - "collapsed": dict( - shape="box3d", - color="#80808060", - fontcolor="#404040", - fontsize=round(scale * 10), - size=0.5 * scale, - fixed=False, - ), - } - # Build node_props, handling collapsed nodes specially - node_props = {} - for node, d in graph.nodes(data=True): - if d.get("collapsed"): - node_props[node] = label_props["collapsed"] + all_schema_modules[db_schema] = module_name + + # Check which module names are shared by multiple schemas + module_to_schemas = {} + for db_schema, module_name in all_schema_modules.items(): + if module_name not in module_to_schemas: + module_to_schemas[module_name] = [] + module_to_schemas[module_name].append(db_schema) + + ambiguous_modules = {m for m, schemas in module_to_schemas.items() if len(schemas) > 1} + + # Determine labels for collapsed schemas + collapsed_labels = {} # schema_name -> label + for schema_name, modules in schema_modules.items(): + if len(modules) == 1: + module_name = next(iter(modules)) + # Use database schema name if module is ambiguous + if module_name in ambiguous_modules: + label = schema_name else: - node_props[node] = label_props[d["node_type"]] - - self._encapsulate_node_names(graph) - self._encapsulate_edge_attributes(graph) - dot = nx.drawing.nx_pydot.to_pydot(graph) - dot.set_rankdir(direction) - for node in dot.get_nodes(): - node.set_shape("circle") - name = node.get_name().strip('"') - props = node_props[name] - node.set_fontsize(props["fontsize"]) - node.set_fontcolor(props["fontcolor"]) - node.set_shape(props["shape"]) - node.set_fontname("arial") - node.set_fixedsize("shape" if props["fixed"] else False) - node.set_width(props["size"]) - node.set_height(props["size"]) - - # Handle collapsed nodes specially - node_data = graph.nodes.get(f'"{name}"', {}) - if node_data.get("collapsed"): - table_count = node_data.get("table_count", 0) - label = f"({table_count} tables)" if table_count != 1 else "(1 table)" - node.set_label(label) - node.set_tooltip(f"Collapsed schema: {table_count} tables") + label = module_name + else: + label = schema_name + collapsed_labels[schema_name] = label + + # Build counts using final labels + collapsed_counts = {} # label -> count of tables + for schema_name, class_names in collapsed_by_schema.items(): + label = collapsed_labels[schema_name] + collapsed_counts[label] = len(class_names) + + # Create new graph with collapsed nodes + new_graph = nx.DiGraph() + + # Map old node names to new 
names (collapsed nodes -> schema label) + node_mapping = {} + for node in graph.nodes(): + full_name = class_to_full.get(node) + if full_name: + schema_name, _ = self._connection.adapter.split_full_table_name(full_name) + if schema_name and node in nodes_to_collapse: + node_mapping[node] = collapsed_labels[schema_name] else: - cls = self._resolve_class(name) - if cls is not None: - description = cls().describe(context=self.context).split("\n") - description = ( - ( - "-" * 30 - if q.startswith("---") - else (q.replace("->", "→") if "->" in q else q.split(":")[0]) - ) - for q in description - if not q.startswith("#") - ) - node.set_tooltip(" ".join(description)) - # Strip module prefix from label if it matches the cluster label - display_name = name - schema_name = schema_map.get(name) - if schema_name and "." in name: - cluster_label = cluster_labels.get(schema_name) - if cluster_label and name.startswith(cluster_label + "."): - display_name = name[len(cluster_label) + 1 :] - node.set_label("<" + display_name + ">" if node.get("distinguished") == "True" else display_name) - node.set_color(props["color"]) - node.set_style("filled") - - for edge in dot.get_edges(): - # see https://graphviz.org/doc/info/attrs.html - src = edge.get_source() - dest = edge.get_destination() - props = graph.get_edge_data(src, dest) - if props is None: - raise DataJointError("Could not find edge with source '{}' and destination '{}'".format(src, dest)) - edge.set_color("#00000040") - edge.set_style("solid" if props.get("primary") else "dashed") - dest_node_type = graph.nodes[dest].get("node_type") - master_part = dest_node_type is Part and dest.startswith(src + ".") - edge.set_weight(3 if master_part else 1) - edge.set_arrowhead("none") - edge.set_penwidth(0.75 if props.get("multi") else 2) - - # Group nodes into schema clusters (always on) - if schema_map: - import pydot - - # Group nodes by schema - schemas = {} - for node in list(dot.get_nodes()): - name = node.get_name().strip('"') - schema_name = schema_map.get(name) - if schema_name: - if schema_name not in schemas: - schemas[schema_name] = [] - schemas[schema_name].append(node) - - # Create clusters for each schema - # Use Python module name if 1:1 mapping, otherwise database schema name - for schema_name, nodes in schemas.items(): - label = cluster_labels.get(schema_name, schema_name) - cluster = pydot.Cluster( - f"cluster_{schema_name}", - label=label, - style="dashed", - color="gray", - fontcolor="gray", + node_mapping[node] = node + else: + # Alias nodes - check if they should be collapsed + # An alias node should be collapsed if ALL its neighbors are collapsed + neighbors = set(graph.predecessors(node)) | set(graph.successors(node)) + if neighbors and neighbors <= nodes_to_collapse: + # Get schema from first neighbor + neighbor = next(iter(neighbors)) + full_name = class_to_full.get(neighbor) + if full_name: + schema_name, _ = self._connection.adapter.split_full_table_name(full_name) + if schema_name: + node_mapping[node] = collapsed_labels[schema_name] + continue + node_mapping[node] = node + + # Build reverse mapping: label -> schema_name + label_to_schema = {label: schema for schema, label in collapsed_labels.items()} + + # Add nodes + added_collapsed = set() + for old_node, new_node in node_mapping.items(): + if new_node in collapsed_counts: + # This is a collapsed schema node + if new_node not in added_collapsed: + schema_name = label_to_schema.get(new_node, new_node) + new_graph.add_node( + new_node, + node_type=None, + collapsed=True, + 
table_count=collapsed_counts[new_node], + schema_name=schema_name, ) - for node in nodes: - cluster.add_node(node) - dot.add_subgraph(cluster) + added_collapsed.add(new_node) + else: + new_graph.add_node(new_node, **graph.nodes[old_node]) - return dot + # Add edges (avoiding self-loops and duplicates) + for src, dest, data in graph.edges(data=True): + new_src = node_mapping[src] + new_dest = node_mapping[dest] + if new_src != new_dest and not new_graph.has_edge(new_src, new_dest): + new_graph.add_edge(new_src, new_dest, **data) - def make_svg(self): - from IPython.display import SVG + return new_graph, collapsed_counts - return SVG(self.make_dot().create_svg()) + def _resolve_class(self, name: str): + """ + Safely resolve a table class from a dotted name without eval(). - def make_png(self): - return io.BytesIO(self.make_dot().create_png()) + Parameters + ---------- + name : str + Dotted class name like "MyTable" or "Module.MyTable". - def make_image(self): - if plot_active: - return plt.imread(self.make_png()) - else: - raise DataJointError("pyplot was not imported") - - def make_mermaid(self) -> str: - """ - Generate Mermaid diagram syntax. - - Produces a flowchart in Mermaid syntax that can be rendered in - Markdown documentation, GitHub, or https://mermaid.live. - - Returns - ------- - str - Mermaid flowchart syntax. - - Notes - ----- - Layout direction is controlled via ``dj.config.display.diagram_direction``. - Tables are grouped by schema using Mermaid subgraphs, with the Python - module name shown as the group label when available. - - Examples - -------- - >>> print(dj.Diagram(schema).make_mermaid()) - flowchart TB - subgraph my_pipeline - Mouse[Mouse]:::manual - Session[Session]:::manual - Neuron([Neuron]):::computed - end - Mouse --> Session - Session --> Neuron - """ - graph = self._make_graph() - direction = self._connection._config.display.diagram_direction - - # Apply collapse logic if needed - graph, collapsed_counts = self._apply_collapse(graph) - - # Build schema mapping for grouping - schema_map = {} # class_name -> schema_name - schema_modules = {} # schema_name -> set of module names - - for full_name in self.nodes_to_show: - parts = full_name.replace('"', "`").split("`") - if len(parts) >= 2: - schema_name = parts[1] - class_name = lookup_class_name(full_name, self.context) or full_name - schema_map[class_name] = schema_name - - # Collect all module names for this schema - if schema_name not in schema_modules: - schema_modules[schema_name] = set() - cls = self._resolve_class(class_name) - if cls is not None and hasattr(cls, "__module__"): - module_name = cls.__module__.split(".")[-1] - schema_modules[schema_name].add(module_name) + Returns + ------- + type or None + The table class if found, otherwise None. + """ + parts = name.split(".") + obj = self.context.get(parts[0]) + for part in parts[1:]: + if obj is None: + return None + obj = getattr(obj, part, None) + if obj is not None and isinstance(obj, type) and issubclass(obj, Table): + return obj + return None + + @staticmethod + def _encapsulate_edge_attributes(graph: nx.DiGraph) -> None: + """ + Encapsulate edge attr_map in double quotes for pydot compatibility. 
- # Determine cluster labels: use module name if 1:1, else database schema name - cluster_labels = {} - for schema_name, modules in schema_modules.items(): - if len(modules) == 1: - cluster_labels[schema_name] = next(iter(modules)) - else: - cluster_labels[schema_name] = schema_name - - # Assign alias nodes to the same schema as their child table - for node, data in graph.nodes(data=True): - if data.get("node_type") is _AliasNode: - successors = list(graph.successors(node)) - if successors and successors[0] in schema_map: - schema_map[node] = schema_map[successors[0]] - - lines = [f"flowchart {direction}"] - - # Define class styles matching Graphviz colors - lines.append(" classDef manual fill:#90EE90,stroke:#006400") - lines.append(" classDef lookup fill:#D3D3D3,stroke:#696969") - lines.append(" classDef computed fill:#FFB6C1,stroke:#8B0000") - lines.append(" classDef imported fill:#ADD8E6,stroke:#00008B") - lines.append(" classDef part fill:#FFFFFF,stroke:#000000") - lines.append(" classDef collapsed fill:#808080,stroke:#404040") - lines.append("") - - # Shape mapping: Manual=box, Computed/Imported=stadium, Lookup/Part=box - shape_map = { - Manual: ("[", "]"), # box - Lookup: ("[", "]"), # box - Computed: ("([", "])"), # stadium/pill - Imported: ("([", "])"), # stadium/pill - Part: ("[", "]"), # box - _AliasNode: ("((", "))"), # circle - None: ("((", "))"), # circle - } - - tier_class = { - Manual: "manual", - Lookup: "lookup", - Computed: "computed", - Imported: "imported", - Part: "part", - _AliasNode: "", - None: "", - } - - # Group nodes by schema into subgraphs (including collapsed nodes) + Modifies graph in place. + + See Also + -------- + https://github.com/pydot/pydot/issues/258#issuecomment-795798099 + """ + for u, v, *_, edgedata in graph.edges(data=True): + if "attr_map" in edgedata: + graph.edges[u, v]["attr_map"] = '"{0}"'.format(edgedata["attr_map"]) + + @staticmethod + def _encapsulate_node_names(graph: nx.DiGraph) -> None: + """ + Encapsulate node names in double quotes for pydot compatibility. + + Modifies graph in place. + + See Also + -------- + https://github.com/datajoint/datajoint-python/pull/1176 + """ + nx.relabel_nodes( + graph, + {node: '"{0}"'.format(node) for node in graph.nodes()}, + copy=False, + ) + + def make_dot(self): + """ + Generate a pydot graph object. + + Returns + ------- + pydot.Dot + The graph object ready for rendering. + + Raises + ------ + DataJointError + If pygraphviz/pydot is not installed. + + Notes + ----- + Layout direction is controlled via ``dj.config.display.diagram_direction``. + Tables are grouped by schema, with the Python module name shown as the + group label when available. 
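+
+        Examples
+        --------
+        A minimal sketch (``schema`` is an illustrative, already-activated
+        schema object)::
+
+            >>> dot = dj.Diagram(schema).make_dot()
+            >>> dot.write_svg("pipeline.svg")  # pydot.Dot generates write_<format> helpers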
+ """ + if not diagram_active: + raise DataJointError("Install pygraphviz and pydot libraries to enable diagram visualization.") + direction = self._connection._config.display.diagram_direction + graph = self._make_graph() + + # Apply collapse logic if needed + graph, collapsed_counts = self._apply_collapse(graph) + + # Build schema mapping: class_name -> schema_name + # Group by database schema, label with Python module name if 1:1 mapping + schema_map = {} # class_name -> schema_name + schema_modules = {} # schema_name -> set of module names + + for full_name in self.nodes_to_show: + schema_name, _ = self._connection.adapter.split_full_table_name(full_name) + if schema_name: + class_name = lookup_class_name(full_name, self.context) or full_name + schema_map[class_name] = schema_name + + # Collect all module names for this schema + if schema_name not in schema_modules: + schema_modules[schema_name] = set() + cls = self._resolve_class(class_name) + if cls is not None and hasattr(cls, "__module__"): + module_name = cls.__module__.split(".")[-1] + schema_modules[schema_name].add(module_name) + + # Determine cluster labels: use module name if 1:1, else database schema name + cluster_labels = {} # schema_name -> label + for schema_name, modules in schema_modules.items(): + if len(modules) == 1: + cluster_labels[schema_name] = next(iter(modules)) + else: + cluster_labels[schema_name] = schema_name + + # Disambiguate labels if multiple schemas share the same module name + # (e.g., all defined in __main__ in a notebook) + label_counts = {} + for label in cluster_labels.values(): + label_counts[label] = label_counts.get(label, 0) + 1 + + for schema_name, label in cluster_labels.items(): + if label_counts[label] > 1: + # Multiple schemas share this module name - add schema name + cluster_labels[schema_name] = f"{label} ({schema_name})" + + # Assign alias nodes (orange dots) to the same schema as their child table + for node, data in graph.nodes(data=True): + if data.get("node_type") is _AliasNode: + # Find the child (successor) - the table that declares the renamed FK + successors = list(graph.successors(node)) + if successors and successors[0] in schema_map: + schema_map[node] = schema_map[successors[0]] + + # Assign collapsed nodes to their schema so they appear in the cluster + for node, data in graph.nodes(data=True): + if data.get("collapsed") and data.get("schema_name"): + schema_map[node] = data["schema_name"] + + scale = 1.2 # scaling factor for fonts and boxes + label_props = { # http://matplotlib.org/examples/color/named_colors.html + None: dict( + shape="circle", + color="#FFFF0040", + fontcolor="yellow", + fontsize=round(scale * 8), + size=0.4 * scale, + fixed=False, + ), + _AliasNode: dict( + shape="circle", + color="#FF880080", + fontcolor="#FF880080", + fontsize=round(scale * 0), + size=0.05 * scale, + fixed=True, + ), + Manual: dict( + shape="box", + color="#00FF0030", + fontcolor="darkgreen", + fontsize=round(scale * 10), + size=0.4 * scale, + fixed=False, + ), + Lookup: dict( + shape="plaintext", + color="#00000020", + fontcolor="black", + fontsize=round(scale * 8), + size=0.4 * scale, + fixed=False, + ), + Computed: dict( + shape="ellipse", + color="#FF000020", + fontcolor="#7F0000A0", + fontsize=round(scale * 10), + size=0.4 * scale, + fixed=False, + ), + Imported: dict( + shape="ellipse", + color="#00007F40", + fontcolor="#00007FA0", + fontsize=round(scale * 10), + size=0.4 * scale, + fixed=False, + ), + Part: dict( + shape="plaintext", + color="#00000000", + fontcolor="black", 
+ fontsize=round(scale * 8), + size=0.1 * scale, + fixed=False, + ), + "collapsed": dict( + shape="box3d", + color="#80808060", + fontcolor="#404040", + fontsize=round(scale * 10), + size=0.5 * scale, + fixed=False, + ), + } + # Build node_props, handling collapsed nodes specially + node_props = {} + for node, d in graph.nodes(data=True): + if d.get("collapsed"): + node_props[node] = label_props["collapsed"] + else: + node_props[node] = label_props[d["node_type"]] + + self._encapsulate_node_names(graph) + self._encapsulate_edge_attributes(graph) + dot = nx.drawing.nx_pydot.to_pydot(graph) + dot.set_rankdir(direction) + for node in dot.get_nodes(): + node.set_shape("circle") + name = node.get_name().strip('"') + props = node_props[name] + node.set_fontsize(props["fontsize"]) + node.set_fontcolor(props["fontcolor"]) + node.set_shape(props["shape"]) + node.set_fontname("arial") + node.set_fixedsize("shape" if props["fixed"] else False) + node.set_width(props["size"]) + node.set_height(props["size"]) + + # Handle collapsed nodes specially + node_data = graph.nodes.get(f'"{name}"', {}) + if node_data.get("collapsed"): + table_count = node_data.get("table_count", 0) + label = f"({table_count} tables)" if table_count != 1 else "(1 table)" + node.set_label(label) + node.set_tooltip(f"Collapsed schema: {table_count} tables") + else: + cls = self._resolve_class(name) + if cls is not None: + description = cls().describe(context=self.context).split("\n") + description = ( + ("-" * 30 if q.startswith("---") else (q.replace("->", "→") if "->" in q else q.split(":")[0])) + for q in description + if not q.startswith("#") + ) + node.set_tooltip(" ".join(description)) + # Strip module prefix from label if it matches the cluster label + display_name = name + schema_name = schema_map.get(name) + if schema_name and "." 
in name: + cluster_label = cluster_labels.get(schema_name) + if cluster_label and name.startswith(cluster_label + "."): + display_name = name[len(cluster_label) + 1 :] + node.set_label("<" + display_name + ">" if node.get("distinguished") == "True" else display_name) + node.set_color(props["color"]) + node.set_style("filled") + + for edge in dot.get_edges(): + # see https://graphviz.org/doc/info/attrs.html + src = edge.get_source() + dest = edge.get_destination() + props = graph.get_edge_data(src, dest) + if props is None: + raise DataJointError("Could not find edge with source '{}' and destination '{}'".format(src, dest)) + edge.set_color("#00000040") + edge.set_style("solid" if props.get("primary") else "dashed") + dest_node_type = graph.nodes[dest].get("node_type") + master_part = dest_node_type is Part and dest.startswith(src + ".") + edge.set_weight(3 if master_part else 1) + edge.set_arrowhead("none") + edge.set_penwidth(0.75 if props.get("multi") else 2) + + # Group nodes into schema clusters (always on) + if schema_map: + import pydot + + # Group nodes by schema schemas = {} - for node, data in graph.nodes(data=True): - if data.get("collapsed"): - # Collapsed nodes use their schema_name attribute - schema_name = data.get("schema_name") - else: - schema_name = schema_map.get(node) + for node in list(dot.get_nodes()): + name = node.get_name().strip('"') + schema_name = schema_map.get(name) if schema_name: if schema_name not in schemas: schemas[schema_name] = [] - schemas[schema_name].append((node, data)) + schemas[schema_name].append(node) - # Add nodes grouped by schema subgraphs + # Create clusters for each schema + # Use Python module name if 1:1 mapping, otherwise database schema name for schema_name, nodes in schemas.items(): label = cluster_labels.get(schema_name, schema_name) - lines.append(f" subgraph {label}") - for node, data in nodes: - safe_id = node.replace(".", "_").replace(" ", "_") - if data.get("collapsed"): - # Collapsed node - show only table count - table_count = data.get("table_count", 0) - count_text = f"{table_count} tables" if table_count != 1 else "1 table" - lines.append(f' {safe_id}[["({count_text})"]]:::collapsed') - else: - # Regular node - tier = data.get("node_type") - left, right = shape_map.get(tier, ("[", "]")) - cls = tier_class.get(tier, "") - # Strip module prefix from display name if it matches the cluster label - display_name = node - if "." 
in node and node.startswith(label + "."): - display_name = node[len(label) + 1 :] - class_suffix = f":::{cls}" if cls else "" - lines.append(f" {safe_id}{left}{display_name}{right}{class_suffix}") - lines.append(" end") - - lines.append("") - - # Add edges - for src, dest, data in graph.edges(data=True): - safe_src = src.replace(".", "_").replace(" ", "_") - safe_dest = dest.replace(".", "_").replace(" ", "_") - # Solid arrow for primary FK, dotted for non-primary - style = "-->" if data.get("primary") else "-.->" - lines.append(f" {safe_src} {style} {safe_dest}") - - return "\n".join(lines) - - def _repr_svg_(self): - return self.make_svg()._repr_svg_() - - def draw(self): - if plot_active: - plt.imshow(self.make_image()) - plt.gca().axis("off") - plt.show() + cluster = pydot.Cluster( + f"cluster_{schema_name}", + label=label, + style="dashed", + color="gray", + fontcolor="gray", + ) + for node in nodes: + cluster.add_node(node) + dot.add_subgraph(cluster) + + return dot + + def make_svg(self): + from IPython.display import SVG + + return SVG(self.make_dot().create_svg()) + + def make_png(self): + return io.BytesIO(self.make_dot().create_png()) + + def make_image(self): + if plot_active: + return plt.imread(self.make_png()) + else: + raise DataJointError("pyplot was not imported") + + def make_mermaid(self) -> str: + """ + Generate Mermaid diagram syntax. + + Produces a flowchart in Mermaid syntax that can be rendered in + Markdown documentation, GitHub, or https://mermaid.live. + + Returns + ------- + str + Mermaid flowchart syntax. + + Notes + ----- + Layout direction is controlled via ``dj.config.display.diagram_direction``. + Tables are grouped by schema using Mermaid subgraphs, with the Python + module name shown as the group label when available. + + Examples + -------- + >>> print(dj.Diagram(schema).make_mermaid()) + flowchart TB + subgraph my_pipeline + Mouse[Mouse]:::manual + Session[Session]:::manual + Neuron([Neuron]):::computed + end + Mouse --> Session + Session --> Neuron + """ + graph = self._make_graph() + direction = self._connection._config.display.diagram_direction + + # Apply collapse logic if needed + graph, collapsed_counts = self._apply_collapse(graph) + + # Build schema mapping for grouping + schema_map = {} # class_name -> schema_name + schema_modules = {} # schema_name -> set of module names + + for full_name in self.nodes_to_show: + schema_name, _ = self._connection.adapter.split_full_table_name(full_name) + if schema_name: + class_name = lookup_class_name(full_name, self.context) or full_name + schema_map[class_name] = schema_name + + # Collect all module names for this schema + if schema_name not in schema_modules: + schema_modules[schema_name] = set() + cls = self._resolve_class(class_name) + if cls is not None and hasattr(cls, "__module__"): + module_name = cls.__module__.split(".")[-1] + schema_modules[schema_name].add(module_name) + + # Determine cluster labels: use module name if 1:1, else database schema name + cluster_labels = {} + for schema_name, modules in schema_modules.items(): + if len(modules) == 1: + cluster_labels[schema_name] = next(iter(modules)) else: - raise DataJointError("pyplot was not imported") - - def save(self, filename: str, format: str | None = None) -> None: - """ - Save diagram to file. - - Parameters - ---------- - filename : str - Output filename. - format : str, optional - File format (``'png'``, ``'svg'``, or ``'mermaid'``). - Inferred from extension if None. - - Raises - ------ - DataJointError - If format is unsupported. 
- - Notes - ----- - Layout direction is controlled via ``dj.config.display.diagram_direction``. - Tables are grouped by schema, with the Python module name shown as the - group label when available. - """ - if format is None: - if filename.lower().endswith(".png"): - format = "png" - elif filename.lower().endswith(".svg"): - format = "svg" - elif filename.lower().endswith((".mmd", ".mermaid")): - format = "mermaid" - if format is None: - raise DataJointError("Could not infer format from filename. Specify format explicitly.") - if format.lower() == "png": - with open(filename, "wb") as f: - f.write(self.make_png().getbuffer().tobytes()) - elif format.lower() == "svg": - with open(filename, "w") as f: - f.write(self.make_svg().data) - elif format.lower() == "mermaid": - with open(filename, "w") as f: - f.write(self.make_mermaid()) + cluster_labels[schema_name] = schema_name + + # Assign alias nodes to the same schema as their child table + for node, data in graph.nodes(data=True): + if data.get("node_type") is _AliasNode: + successors = list(graph.successors(node)) + if successors and successors[0] in schema_map: + schema_map[node] = schema_map[successors[0]] + + lines = [f"flowchart {direction}"] + + # Define class styles matching Graphviz colors + lines.append(" classDef manual fill:#90EE90,stroke:#006400") + lines.append(" classDef lookup fill:#D3D3D3,stroke:#696969") + lines.append(" classDef computed fill:#FFB6C1,stroke:#8B0000") + lines.append(" classDef imported fill:#ADD8E6,stroke:#00008B") + lines.append(" classDef part fill:#FFFFFF,stroke:#000000") + lines.append(" classDef collapsed fill:#808080,stroke:#404040") + lines.append("") + + # Shape mapping: Manual=box, Computed/Imported=stadium, Lookup/Part=box + shape_map = { + Manual: ("[", "]"), # box + Lookup: ("[", "]"), # box + Computed: ("([", "])"), # stadium/pill + Imported: ("([", "])"), # stadium/pill + Part: ("[", "]"), # box + _AliasNode: ("((", "))"), # circle + None: ("((", "))"), # circle + } + + tier_class = { + Manual: "manual", + Lookup: "lookup", + Computed: "computed", + Imported: "imported", + Part: "part", + _AliasNode: "", + None: "", + } + + # Group nodes by schema into subgraphs (including collapsed nodes) + schemas = {} + for node, data in graph.nodes(data=True): + if data.get("collapsed"): + # Collapsed nodes use their schema_name attribute + schema_name = data.get("schema_name") else: - raise DataJointError("Unsupported file format") + schema_name = schema_map.get(node) + if schema_name: + if schema_name not in schemas: + schemas[schema_name] = [] + schemas[schema_name].append((node, data)) + + # Add nodes grouped by schema subgraphs + for schema_name, nodes in schemas.items(): + label = cluster_labels.get(schema_name, schema_name) + lines.append(f" subgraph {label}") + for node, data in nodes: + safe_id = node.replace(".", "_").replace(" ", "_") + if data.get("collapsed"): + # Collapsed node - show only table count + table_count = data.get("table_count", 0) + count_text = f"{table_count} tables" if table_count != 1 else "1 table" + lines.append(f' {safe_id}[["({count_text})"]]:::collapsed') + else: + # Regular node + tier = data.get("node_type") + left, right = shape_map.get(tier, ("[", "]")) + cls = tier_class.get(tier, "") + # Strip module prefix from display name if it matches the cluster label + display_name = node + if "." 
in node and node.startswith(label + "."): + display_name = node[len(label) + 1 :] + class_suffix = f":::{cls}" if cls else "" + lines.append(f" {safe_id}{left}{display_name}{right}{class_suffix}") + lines.append(" end") + + lines.append("") + + # Add edges + for src, dest, data in graph.edges(data=True): + safe_src = src.replace(".", "_").replace(" ", "_") + safe_dest = dest.replace(".", "_").replace(" ", "_") + # Solid arrow for primary FK, dotted for non-primary + style = "-->" if data.get("primary") else "-.->" + lines.append(f" {safe_src} {style} {safe_dest}") + + return "\n".join(lines) + + def _repr_svg_(self): + return self.make_svg()._repr_svg_() + + def draw(self): + if plot_active: + plt.imshow(self.make_image()) + plt.gca().axis("off") + plt.show() + else: + raise DataJointError("pyplot was not imported") + + def save(self, filename: str, format: str | None = None) -> None: + """ + Save diagram to file. - @staticmethod - def _layout(graph, **kwargs): - return pydot_layout(graph, prog="dot", **kwargs) + Parameters + ---------- + filename : str + Output filename. + format : str, optional + File format (``'png'``, ``'svg'``, or ``'mermaid'``). + Inferred from extension if None. + + Raises + ------ + DataJointError + If format is unsupported. + + Notes + ----- + Layout direction is controlled via ``dj.config.display.diagram_direction``. + Tables are grouped by schema, with the Python module name shown as the + group label when available. + """ + if format is None: + if filename.lower().endswith(".png"): + format = "png" + elif filename.lower().endswith(".svg"): + format = "svg" + elif filename.lower().endswith((".mmd", ".mermaid")): + format = "mermaid" + if format is None: + raise DataJointError("Could not infer format from filename. Specify format explicitly.") + if format.lower() == "png": + with open(filename, "wb") as f: + f.write(self.make_png().getbuffer().tobytes()) + elif format.lower() == "svg": + with open(filename, "w") as f: + f.write(self.make_svg().data) + elif format.lower() == "mermaid": + with open(filename, "w") as f: + f.write(self.make_mermaid()) + else: + raise DataJointError("Unsupported file format") + + @staticmethod + def _layout(graph, **kwargs): + return pydot_layout(graph, prog="dot", **kwargs) diff --git a/src/datajoint/table.py b/src/datajoint/table.py index 83e73e2bb..7f8cbaf70 100644 --- a/src/datajoint/table.py +++ b/src/datajoint/table.py @@ -14,6 +14,7 @@ from .condition import make_condition from .declare import alter, declare +from .dependencies import extract_master from .errors import ( AccessError, DataJointError, @@ -24,7 +25,7 @@ from .expression import QueryExpression from .heading import Heading from .staged_insert import staged_insert1 as _staged_insert1 -from .utils import get_master, is_camel_case, user_choice +from .utils import is_camel_case, user_choice logger = logging.getLogger(__name__.split(".")[0]) @@ -984,6 +985,19 @@ def delete( """ Deletes the contents of the table and its dependent tables, recursively. + Uses graph-driven cascade: builds a dependency diagram, propagates + restrictions to all descendants, then deletes in reverse topological + order (leaves first). + + With ``safemode=True`` (the default), delete previews all affected + tables and row counts, executes within a transaction, and asks for + confirmation before committing. Declining rolls back all changes — + effectively a built-in dry run. 
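+
+        A minimal sketch of that dry run (``MyTable`` and ``restriction``
+        are illustrative)::
+
+            (MyTable & restriction).delete()  # answer "no" at the prompt to roll back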
+ + To preview cascade impact without executing, use ``Diagram``:: + + dj.Diagram.cascade(MyTable & restriction).counts() + Args: transaction: If `True`, use of the entire delete becomes an atomic transaction. This is the default and recommended behavior. Set to `False` if this delete is @@ -999,182 +1013,100 @@ def delete( Number of deleted rows (excluding those from dependent tables). Raises: - DataJointError: Delete exceeds maximum number of delete attempts. DataJointError: When deleting within an existing transaction. DataJointError: Deleting a part table before its master (when part_integrity="enforce"). ValueError: Invalid part_integrity value. """ if part_integrity not in ("enforce", "ignore", "cascade"): - raise ValueError(f"part_integrity must be 'enforce', 'ignore', or 'cascade', got {part_integrity!r}") - deleted = set() - visited_masters = set() - - def cascade(table): - """service function to perform cascading deletes recursively.""" - max_attempts = 50 - for _ in range(max_attempts): - # Set savepoint before delete attempt (for PostgreSQL transaction handling) - savepoint_name = f"cascade_delete_{id(table)}" - if transaction: - table.connection.query(f"SAVEPOINT {savepoint_name}") + raise ValueError(f"part_integrity must be 'enforce', 'ignore', or 'cascade', " f"got {part_integrity!r}") + from .diagram import Diagram - try: - delete_count = table.delete_quick(get_count=True) - except IntegrityError as error: - # Rollback to savepoint so we can continue querying (PostgreSQL requirement) - if transaction: - table.connection.query(f"ROLLBACK TO SAVEPOINT {savepoint_name}") - # Use adapter to parse FK error message - match = table.connection.adapter.parse_foreign_key_error(error.args[0]) - if match is None: - raise DataJointError( - "Cascading deletes failed because the error message is missing foreign key information. " - "Make sure you have REFERENCES privilege to all dependent tables." - ) from None - - # Strip quotes from parsed values for backend-agnostic processing - quote_chars = ("`", '"') - - def strip_quotes(s): - if s and any(s.startswith(q) for q in quote_chars): - return s.strip('`"') - return s - - # Extract schema and table name from child (work with unquoted names) - child_table_raw = strip_quotes(match["child"]) - if "." in child_table_raw: - child_parts = child_table_raw.split(".") - child_schema = strip_quotes(child_parts[0]) - child_table_name = strip_quotes(child_parts[1]) - else: - # Add schema from current table - schema_parts = table.full_table_name.split(".") - child_schema = strip_quotes(schema_parts[0]) - child_table_name = child_table_raw - - # If FK/PK attributes not in error message, query information_schema - if match["fk_attrs"] is None or match["pk_attrs"] is None: - constraint_query = table.connection.adapter.get_constraint_info_sql( - strip_quotes(match["name"]), - child_schema, - child_table_name, - ) + diagram = Diagram.cascade(self, part_integrity=part_integrity) - results = table.connection.query( - constraint_query, - args=(strip_quotes(match["name"]), child_schema, child_table_name), - ).fetchall() - if results: - match["fk_attrs"], match["parent"], match["pk_attrs"] = list(map(list, zip(*results))) - match["parent"] = match["parent"][0] # All rows have same parent - - # Build properly quoted full table name for FreeTable - child_full_name = ( - f"{table.connection.adapter.quote_identifier(child_schema)}." - f"{table.connection.adapter.quote_identifier(child_table_name)}" - ) - - # Restrict child by table if - # 1. 
if table's restriction attributes are not in child's primary key - # 2. if child renames any attributes - # Otherwise restrict child by table's restriction. - child = FreeTable(table.connection, child_full_name) - if set(table.restriction_attributes) <= set(child.primary_key) and match["fk_attrs"] == match["pk_attrs"]: - child._restriction = table._restriction - child._restriction_attributes = table.restriction_attributes - elif match["fk_attrs"] != match["pk_attrs"]: - child &= table.proj(**dict(zip(match["fk_attrs"], match["pk_attrs"]))) - else: - child &= table.proj() - - master_name = get_master(child.full_table_name, table.connection.adapter) - if ( - part_integrity == "cascade" - and master_name - and master_name != table.full_table_name - and master_name not in visited_masters - ): - master = FreeTable(table.connection, master_name) - master._restriction_attributes = set() - master._restriction = [ - make_condition( # &= may cause in target tables in subquery - master, - (master.proj() & child.proj()).to_arrays(), - master._restriction_attributes, - ) - ] - visited_masters.add(master_name) - cascade(master) - else: - cascade(child) - else: - # Successful delete - release savepoint - if transaction: - table.connection.query(f"RELEASE SAVEPOINT {savepoint_name}") - deleted.add(table.full_table_name) - logger.info("Deleting {count} rows from {table}".format(count=delete_count, table=table.full_table_name)) - break - else: - raise DataJointError("Exceeded maximum number of delete attempts.") - return delete_count + conn = self.connection + prompt = conn._config["safemode"] if prompt is None else prompt - prompt = self.connection._config["safemode"] if prompt is None else prompt + # Preview + if prompt: + for ft in diagram: + logger.info("{table} ({count} tuples)".format(table=ft.full_table_name, count=len(ft))) # Start transaction if transaction: - if not self.connection.in_transaction: - self.connection.start_transaction() + if not conn.in_transaction: + conn.start_transaction() else: if not prompt: transaction = False else: raise DataJointError( - "Delete cannot use a transaction within an ongoing transaction. Set transaction=False or prompt=False." + "Delete cannot use a transaction within an " + "ongoing transaction. Set transaction=False " + "or prompt=False." ) - # Cascading delete + # Execute deletes in reverse topological order (leaves first) + root_count = 0 + deleted_tables = set() try: - delete_count = cascade(self) - except: + for ft in reversed(diagram): + count = ft.delete_quick(get_count=True) + if count > 0: + deleted_tables.add(ft.full_table_name) + logger.info("Deleting {count} rows from {table}".format(count=count, table=ft.full_table_name)) + if ft.full_table_name == self.full_table_name: + root_count = count + except IntegrityError as error: if transaction: - self.connection.cancel_transaction() + conn.cancel_transaction() + match = conn.adapter.parse_foreign_key_error(error.args[0]) + if match: + raise DataJointError( + "Delete blocked by table {child} in an unloaded " + "schema. 
Activate all dependent schemas before " + "deleting.".format(child=match["child"]) + ) from None + raise DataJointError("Delete blocked by FK in unloaded/inaccessible schema.") from None + except Exception: + if transaction: + conn.cancel_transaction() raise - if part_integrity == "enforce": - # Avoid deleting from part before master (See issue #151) - for part in deleted: - master = get_master(part, self.connection.adapter) - if master and master not in deleted: + # Post-check part_integrity="enforce": roll back if a part table + # had rows deleted without its master also having rows deleted. + if part_integrity == "enforce" and deleted_tables: + for table_name in deleted_tables: + master = extract_master(table_name) + if master and master not in deleted_tables: if transaction: - self.connection.cancel_transaction() + conn.cancel_transaction() raise DataJointError( - "Attempt to delete part table {part} before deleting from its master {master} first. " - "Use part_integrity='ignore' to allow, or part_integrity='cascade' to also delete master.".format( - part=part, master=master - ) + f"Attempt to delete part table {table_name} before " + f"its master {master}. Delete from the master first, " + f"or use part_integrity='ignore' or 'cascade'." ) # Confirm and commit - if delete_count == 0: + if root_count == 0: if prompt: logger.warning("Nothing to delete.") if transaction: - self.connection.cancel_transaction() + conn.cancel_transaction() elif not transaction: logger.info("Delete completed") else: if not prompt or user_choice("Commit deletes?", default="no") == "yes": if transaction: - self.connection.commit_transaction() + conn.commit_transaction() if prompt: logger.info("Delete committed.") else: if transaction: - self.connection.cancel_transaction() + conn.cancel_transaction() if prompt: logger.warning("Delete cancelled") - delete_count = 0 # Reset count when delete is cancelled - return delete_count + root_count = 0 + return root_count def drop_quick(self): """ @@ -1214,41 +1146,59 @@ def drop_quick(self): else: logger.info("Nothing to drop: table %s is not declared" % self.full_table_name) - def drop(self, prompt: bool | None = None): + def drop(self, prompt: bool | None = None, part_integrity: str = "enforce"): """ Drop the table and all tables that reference it, recursively. + Uses graph-driven traversal: builds a dependency diagram and drops + in reverse topological order (leaves first). + + With ``safemode=True`` (the default), drop previews all affected + tables and row counts and asks for confirmation before proceeding. + Args: prompt: If `True`, show what will be dropped and ask for confirmation. If `False`, drop without confirmation. Default is `dj.config['safemode']`. + part_integrity: Policy for master-part integrity. One of: + - ``"enforce"`` (default): Error if parts would be dropped without masters. + - ``"ignore"``: Allow dropping parts without masters. """ if self.restriction: raise DataJointError( - "A table with an applied restriction cannot be dropped. Call drop() on the unrestricted Table." + "A table with an applied restriction cannot be dropped. " "Call drop() on the unrestricted Table." 
) - prompt = self.connection._config["safemode"] if prompt is None else prompt + import networkx as nx + from .diagram import Diagram - self.connection.dependencies.load() - do_drop = True - tables = [table for table in self.connection.dependencies.descendants(self.full_table_name) if not table.isdigit()] + self.connection.dependencies.load_all_downstream() + diagram = Diagram(self) + # Expand to include all descendants (cross-schema) + descendants = set(nx.descendants(diagram, self.full_table_name)) | {self.full_table_name} + diagram.nodes_to_show = descendants + diagram._expanded_nodes = set(descendants) + conn = self.connection + prompt = conn._config["safemode"] if prompt is None else prompt - # avoid dropping part tables without their masters: See issue #374 - for part in tables: - master = get_master(part, self.connection.adapter) - if master and master not in tables: - raise DataJointError( - "Attempt to drop part table {part} before dropping its master. Drop {master} first.".format( - part=part, master=master + table_names = [ft.full_table_name for ft in diagram] + + if part_integrity == "enforce": + for name in table_names: + master = extract_master(name) + if master and master not in table_names: + raise DataJointError( + "Attempt to drop part table {part} before its " "master {master}. Drop the master first.".format( + part=name, master=master + ) ) - ) + do_drop = True if prompt: - for table in tables: - logger.info(table + " (%d tuples)" % len(FreeTable(self.connection, table))) + for ft in diagram: + logger.info("{table} ({count} tuples)".format(table=ft.full_table_name, count=len(ft))) do_drop = user_choice("Proceed?", default="no") == "yes" if do_drop: - for table in reversed(tables): - FreeTable(self.connection, table).drop_quick() + for ft in reversed(diagram): + ft.drop_quick() logger.info("Tables dropped. Restart kernel.") def describe(self, context=None, printout=False): @@ -1611,8 +1561,7 @@ class FreeTable(Table): """ def __init__(self, conn, full_table_name): - # Backend-agnostic quote stripping (MySQL uses `, PostgreSQL uses ") - self.database, self._table_name = (s.strip('`"') for s in full_table_name.split(".")) + self.database, self._table_name = conn.adapter.split_full_table_name(full_table_name) self._connection = conn self._support = [full_table_name] self._heading = Heading( diff --git a/src/datajoint/user_tables.py b/src/datajoint/user_tables.py index 7822fa9e2..514f4eb60 100644 --- a/src/datajoint/user_tables.py +++ b/src/datajoint/user_tables.py @@ -235,7 +235,7 @@ def delete(self, part_integrity: str = "enforce", **kwargs): "or use part_integrity='ignore' to break integrity, " "or part_integrity='cascade' to also delete master." ) - super().delete(part_integrity=part_integrity, **kwargs) + return super().delete(part_integrity=part_integrity, **kwargs) def drop(self, part_integrity: str = "enforce"): """ @@ -251,7 +251,7 @@ def drop(self, part_integrity: str = "enforce"): DataJointError: If part_integrity="enforce" (direct Part drops prohibited) """ if part_integrity == "ignore": - super().drop() + return super().drop(part_integrity="ignore") elif part_integrity == "enforce": raise DataJointError("Cannot drop a Part directly. 
Drop master instead, or use part_integrity='ignore' to force.") else: diff --git a/src/datajoint/utils.py b/src/datajoint/utils.py index 0441af354..e36267936 100644 --- a/src/datajoint/utils.py +++ b/src/datajoint/utils.py @@ -37,46 +37,6 @@ def user_choice(prompt, choices=("yes", "no"), default=None): return response -def get_master(full_table_name: str, adapter=None) -> str: - """ - Get the master table name from a part table name. - - If the table name is that of a part table, then return what the master table name would be. - This follows DataJoint's table naming convention where a master and a part must be in the - same schema and the part table is prefixed with the master table name + ``__``. - - Parameters - ---------- - full_table_name : str - Full table name including part. - adapter : DatabaseAdapter, optional - Database adapter for backend-specific parsing. Default None. - - Returns - ------- - str - Supposed master full table name or empty string if not a part table name. - - Examples - -------- - >>> get_master('`ephys`.`session__recording`') # MySQL part table - '`ephys`.`session`' - >>> get_master('"ephys"."session__recording"') # PostgreSQL part table - '"ephys"."session"' - >>> get_master('`ephys`.`session`') # Not a part table - '' - """ - if adapter is not None: - result = adapter.get_master_table_name(full_table_name) - return result if result else "" - - # Fallback: handle both MySQL backticks and PostgreSQL double quotes - match = re.match(r'(?P(?P[`"])[\w]+(?P=q)\.(?P=q)[\w]+)__[\w]+(?P=q)', full_table_name) - if match: - return match["master"] + match["q"] - return "" - - def is_camel_case(s): """ Check if a string is in CamelCase notation. diff --git a/src/datajoint/version.py b/src/datajoint/version.py index 871a28cbb..9a1d4aff2 100644 --- a/src/datajoint/version.py +++ b/src/datajoint/version.py @@ -1,4 +1,4 @@ # version bump auto managed by Github Actions: # label_prs.yaml(prep), release.yaml(bump), post_release.yaml(edit) # manually set this version will be eventually overwritten by the above actions -__version__ = "2.1.1" +__version__ = "2.2.0.dev0" diff --git a/tests/integration/test_cascade_delete.py b/tests/integration/test_cascade_delete.py index caf5f331b..3bc3dc73b 100644 --- a/tests/integration/test_cascade_delete.py +++ b/tests/integration/test_cascade_delete.py @@ -188,3 +188,107 @@ class Observation(dj.Manual): assert remaining_obs[0]["obs_id"] == 3 assert remaining_obs[0]["subject_id"] == 2 assert remaining_obs[0]["measurement"] == 15.3 + + +def test_delete_preview_with_counts(schema_by_backend): + """Diagram.cascade().counts() previews affected rows without deleting.""" + + @schema_by_backend + class Parent(dj.Manual): + definition = """ + parent_id : int + --- + name : varchar(255) + """ + + @schema_by_backend + class Child(dj.Manual): + definition = """ + -> Parent + child_id : int + --- + data : varchar(255) + """ + + Parent.insert1((1, "P1")) + Parent.insert1((2, "P2")) + Child.insert1((1, 1, "C1-1")) + Child.insert1((1, 2, "C1-2")) + Child.insert1((2, 1, "C2-1")) + + # Preview restricted cascade via Diagram + counts = dj.Diagram.cascade(Parent & {"parent_id": 1}).counts() + + assert isinstance(counts, dict) + assert counts[Parent.full_table_name] == 1 + assert counts[Child.full_table_name] == 2 + + # Data must still be intact + assert len(Parent()) == 2 + assert len(Child()) == 3 + + +def test_cascade_discovers_downstream_schema(connection_by_backend, db_creds_by_backend): + """Cascade delete discovers and includes tables in unloaded 
downstream schemas.""" + import time + + backend = db_creds_by_backend["backend"] + test_id = str(int(time.time() * 1000))[-8:] + + upstream_name = f"djtest_upstream_{backend}_{test_id}"[:64] + downstream_name = f"djtest_downstream_{backend}_{test_id}"[:64] + + qi = connection_by_backend.adapter.quote_identifier + + # Clean up any previous runs + for name in (downstream_name, upstream_name): + try: + connection_by_backend.query(f"DROP DATABASE IF EXISTS {qi(name)}") + except Exception: + pass + + # Create upstream schema and table + upstream = dj.Schema(upstream_name, connection=connection_by_backend) + + @upstream + class Parent(dj.Manual): + definition = """ + parent_id : int + --- + name : varchar(100) + """ + + # Create downstream schema with FK to upstream — separate schema object + downstream = dj.Schema(downstream_name, connection=connection_by_backend) + + @downstream + class Child(dj.Manual): + definition = """ + -> Parent + child_id : int + --- + data : varchar(100) + """ + + # Insert data + Parent.insert1(dict(parent_id=1, name="Alice")) + Child.insert1(dict(parent_id=1, child_id=1, data="row1")) + Child.insert1(dict(parent_id=1, child_id=2, data="row2")) + + # Verify cascade preview discovers the downstream schema + counts = dj.Diagram.cascade(Parent & "parent_id=1").counts() + assert Parent.full_table_name in counts + assert Child.full_table_name in counts + assert counts[Child.full_table_name] == 2 + + # Verify actual delete cascades across schemas + (Parent & "parent_id=1").delete() + assert len(Parent()) == 0 + assert len(Child()) == 0 + + # Clean up + for name in (downstream_name, upstream_name): + try: + connection_by_backend.query(f"DROP DATABASE IF EXISTS {qi(name)}") + except Exception: + pass diff --git a/tests/integration/test_cli.py b/tests/integration/test_cli.py index 35230ea4e..1f8144f0f 100644 --- a/tests/integration/test_cli.py +++ b/tests/integration/test_cli.py @@ -3,6 +3,7 @@ """ import subprocess +import sys import pytest @@ -31,7 +32,7 @@ def test_cli_help(capsys): def test_cli_config(): process = subprocess.Popen( - ["dj"], + [sys.executable, "-m", "datajoint.cli"], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, @@ -50,7 +51,7 @@ def test_cli_config(): def test_cli_args(): process = subprocess.Popen( - ["dj", "-u", "test_user", "-p", "test_pass", "--host", "test_host"], + [sys.executable, "-m", "datajoint.cli", "-u", "test_user", "-p", "test_pass", "--host", "test_host"], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, @@ -82,7 +83,9 @@ class IJ(dj.Lookup): # Pass credentials via CLI args to avoid prompting for username process = subprocess.Popen( [ - "dj", + sys.executable, + "-m", + "datajoint.cli", "-u", db_creds_root["user"], "-p", diff --git a/tests/integration/test_erd.py b/tests/integration/test_erd.py index 95077da50..d746bf49e 100644 --- a/tests/integration/test_erd.py +++ b/tests/integration/test_erd.py @@ -1,6 +1,8 @@ +import pytest as _pytest + import datajoint as dj -from tests.schema_simple import LOCALS_SIMPLE, A, B, D, E, G, L +from tests.schema_simple import LOCALS_SIMPLE, A, B, D, E, G, L, Profile, Website def test_decorator(schema_simp): @@ -61,3 +63,96 @@ def test_part_table_parsing(schema_simp): graph = erd._make_graph() assert "OutfitLaunch" in graph.nodes() assert "OutfitLaunch.OutfitPiece" in graph.nodes() + + +# --- prune() tests --- + + +@_pytest.fixture +def schema_simp_pop(schema_simp): + """Populate the simple schema for prune tests.""" + Profile().delete() + Website().delete() + 
G().delete() + E().delete() + D().delete() + B().delete() + L().delete() + A().delete() + + A().insert(A.contents, skip_duplicates=True) + L().insert(L.contents, skip_duplicates=True) + B().populate() + D().populate() + E().populate() + G().populate() + yield schema_simp + + +def test_prune_unrestricted(schema_simp_pop): + """Prune on unrestricted diagram removes physically empty tables.""" + diag = dj.Diagram(schema_simp_pop, context=LOCALS_SIMPLE) + original_count = len(diag.nodes_to_show) + pruned = diag.prune() + + # Populated tables (A, L, B, B.C, D, E, E.F, G, etc.) should survive + for cls in (A, B, D, E, L): + assert cls.full_table_name in pruned.nodes_to_show, f"{cls.__name__} should not be pruned" + + # Empty tables like Profile should be removed + assert Profile.full_table_name not in pruned.nodes_to_show, "empty Profile should be pruned" + + # Pruned diagram should have fewer nodes + assert len(pruned.nodes_to_show) < original_count + + +def test_prune_after_restrict(schema_simp_pop): + """Prune after restrict removes tables with zero matching rows.""" + diag = dj.Diagram(schema_simp_pop, context=LOCALS_SIMPLE) + restricted = diag.restrict(A & "id_a=0") + counts = restricted.counts() + + pruned = restricted.prune() + pruned_counts = pruned.counts() + + # Every table in pruned preview should have > 0 rows + assert all(c > 0 for c in pruned_counts.values()), "pruned diagram should have no zero-count tables" + + # Tables with zero rows in the original preview should be gone + for table, count in counts.items(): + if count == 0: + assert table not in pruned._restrict_conditions, f"{table} had 0 rows but was not pruned" + + +def test_prune_raises_on_cascade(schema_simp_pop): + """prune() raises on a cascade Diagram — cascade must retain all tables for safe deletion.""" + cascaded = dj.Diagram.cascade(A & "id_a=0") + with _pytest.raises(dj.DataJointError, match="prune.*cannot be used.*cascade"): + cascaded.prune() + + +def test_prune_idempotent(schema_simp_pop): + """Pruning twice gives the same result.""" + diag = dj.Diagram(schema_simp_pop, context=LOCALS_SIMPLE) + restricted = diag.restrict(A & "id_a=0") + pruned_once = restricted.prune() + pruned_twice = pruned_once.prune() + + assert pruned_once.nodes_to_show == pruned_twice.nodes_to_show + assert set(pruned_once._restrict_conditions) == set(pruned_twice._restrict_conditions) + + +def test_prune_then_restrict(schema_simp_pop): + """Restrict can be called after prune.""" + diag = dj.Diagram(schema_simp_pop, context=LOCALS_SIMPLE) + pruned = diag.restrict(A & "id_a < 5").prune() + # Restrict again on the same seed table with a tighter condition + further = pruned.restrict(A & "id_a=0") + + # Should not raise; further restriction should narrow results + counts = further.counts() + assert all(c >= 0 for c in counts.values()) + # Tighter restriction should produce fewer or equal rows + pruned_counts = pruned.counts() + for table in counts: + assert counts[table] <= pruned_counts.get(table, 0)
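
For reference, a minimal end-to-end sketch of the graph-driven workflow added in this patch; `Parent` and `schema` are illustrative names rather than fixtures from this repository:

```python
import datajoint as dj

# Preview a restricted cascade without deleting: row counts per affected table.
counts = dj.Diagram.cascade(Parent & "parent_id=1").counts()

# Propagate a restriction through a diagram, drop zero-row tables, then render.
diagram = dj.Diagram(schema).restrict(Parent & "parent_id=1").prune()
print(diagram.make_mermaid())  # Mermaid source; make_dot() and save("pipeline.svg") also work
```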