diff --git a/docs/design/thread-safe-mode.md b/docs/design/thread-safe-mode.md
deleted file mode 100644
index 5d7472667..000000000
--- a/docs/design/thread-safe-mode.md
+++ /dev/null
@@ -1,387 +0,0 @@
-# Thread-Safe Mode Specification
-
-## Problem
-
-DataJoint uses global state (`dj.config`, `dj.conn()`) that is not thread-safe. Multi-tenant applications (web servers, async workers) need isolated connections per request/task.
-
-## Solution
-
-Introduce **Instance** objects that encapsulate config and connection. The `dj` module provides a global config that can be modified before connecting, and a lazily-loaded singleton connection. New isolated instances are created with `dj.Instance()`.
-
-## API
-
-### Legacy API (global config + singleton connection)
-
-```python
-import datajoint as dj
-
-# Configure credentials (no connection yet)
-dj.config.database.user = "user"
-dj.config.database.password = "password"
-
-# First call to conn() or Schema() creates the singleton connection
-dj.conn() # Creates connection using dj.config credentials
-schema = dj.Schema("my_schema")
-
-@schema
-class Mouse(dj.Manual):
- definition = "..."
-```
-
-Alternatively, pass credentials directly to `conn()`:
-```python
-dj.conn(host="localhost", user="user", password="password")
-```
-
-Internally:
-- `dj.config` → delegates to `_global_config` (with thread-safety check)
-- `dj.conn()` → returns `_singleton_connection` (created lazily)
-- `dj.Schema()` → uses `_singleton_connection`
-- `dj.FreeTable()` → uses `_singleton_connection`
-
-### New API (isolated instance)
-
-```python
-import datajoint as dj
-
-inst = dj.Instance(
- host="localhost",
- user="user",
- password="password",
-)
-schema = inst.Schema("my_schema")
-
-@schema
-class Mouse(dj.Manual):
- definition = "..."
-```
-
-### Instance structure
-
-Each instance has:
-- `inst.config` - Config (created fresh at instance creation)
-- `inst.connection` - Connection (created at instance creation)
-- `inst.Schema()` - Schema factory using instance's connection
-- `inst.FreeTable()` - FreeTable factory using instance's connection
-
-```python
-inst = dj.Instance(host="localhost", user="u", password="p")
-inst.config # Config instance
-inst.connection # Connection instance
-inst.Schema("name") # Creates schema using inst.connection
-inst.FreeTable("db.tbl") # Access table using inst.connection
-```
-
-### Table base classes vs instance methods
-
-**Base classes** (`dj.Manual`, `dj.Lookup`, etc.) - Used with the `@schema` decorator:
-```python
-@schema
-class Mouse(dj.Manual): # the schema, not the base class, supplies the connection
- definition = "..."
-```
-
-**Instance methods** (`inst.Schema()`, `inst.FreeTable()`) - Bound directly to the instance's connection:
-```python
-schema = inst.Schema("my_schema") # Uses inst.connection
-table = inst.FreeTable("db.table") # Uses inst.connection
-```
-
-### Thread-safe mode
-
-```bash
-export DJ_THREAD_SAFE=true
-```
-
-`thread_safe` is checked dynamically on each access to global state.
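-
-A minimal sketch of the dynamic check; the accepted truthy spellings are an assumption of this sketch (the helper name matches the implementation section below):
-
-```python
-import os
-
-def _load_thread_safe() -> bool:
- # Re-read the environment on every call so the flag can change at runtime.
- # The accepted spellings ("1"/"true"/"yes") are assumed, not specified above.
- return os.environ.get("DJ_THREAD_SAFE", "").lower() in ("1", "true", "yes")
-```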
-
-When `thread_safe=True`, accessing global state raises `ThreadSafetyError`:
-- `dj.config` raises `ThreadSafetyError`
-- `dj.conn()` raises `ThreadSafetyError`
-- `dj.Schema()` raises `ThreadSafetyError` (without explicit connection)
-- `dj.FreeTable()` raises `ThreadSafetyError` (without explicit connection)
-- `dj.Instance()` works - isolated instances are always allowed
-
-```python
-# thread_safe=True
-
-dj.config # ThreadSafetyError
-dj.conn() # ThreadSafetyError
-dj.Schema("name") # ThreadSafetyError
-
-inst = dj.Instance(host="h", user="u", password="p") # OK
-inst.Schema("name") # OK
-```
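-
-For multi-tenant servers, the intended pattern is one isolated instance per request. A minimal sketch (the request-handling scaffolding is hypothetical, not part of DataJoint):
-
-```python
-import datajoint as dj
-
-def handle_request(tenant): # hypothetical per-request entry point
- # Each request gets its own config and connection; no global state is
- # touched, so DJ_THREAD_SAFE=true never raises here.
- inst = dj.Instance(host=tenant.host, user=tenant.user, password=tenant.password)
- schema = inst.Schema(tenant.schema_name)
- ...
-```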
-
-## Behavior Summary
-
-| Operation | `thread_safe=False` | `thread_safe=True` |
-|-----------|--------------------|--------------------|
-| `dj.config` | `_global_config` | `ThreadSafetyError` |
-| `dj.conn()` | `_singleton_connection` | `ThreadSafetyError` |
-| `dj.Schema()` | Uses singleton | `ThreadSafetyError` |
-| `dj.FreeTable()` | Uses singleton | `ThreadSafetyError` |
-| `dj.Instance()` | Works | Works |
-| `inst.config` | Works | Works |
-| `inst.connection` | Works | Works |
-| `inst.Schema()` | Works | Works |
-
-## Lazy Loading
-
-The global config is created at module import time. The singleton connection is created lazily on first access:
-
-```python
-dj.config.database.user = "user" # Modifies global config (no connection yet)
-dj.config.database.password = "pw"
-dj.conn() # Creates singleton connection using global config
-dj.Schema("name") # Uses existing singleton connection
-```
-
-## Usage Example
-
-```python
-import datajoint as dj
-
-# Create isolated instance
-inst = dj.Instance(
- host="localhost",
- user="user",
- password="password",
-)
-
-# Create schema
-schema = inst.Schema("my_schema")
-
-@schema
-class Mouse(dj.Manual):
- definition = """
- mouse_id: int
- """
-
-# Use tables
-Mouse().insert1({"mouse_id": 1})
-Mouse().fetch()
-```
-
-## Architecture
-
-### Object graph
-
-There is exactly **one** global `Config` object created at import time in `settings.py`. Both the legacy API and the `Instance` API hang off `Connection` objects, each of which carries a `_config` reference.
-
-```
-settings.py
- config = _create_config() ← THE single global Config
-
-instance.py
- _global_config = settings.config ← same object (not a copy)
- _singleton_connection = None ← lazily created Connection
-
-__init__.py
- dj.config = _ConfigProxy() ← proxy → _global_config (with thread-safety check)
- dj.conn() ← returns _singleton_connection
- dj.Schema() ← uses _singleton_connection
- dj.FreeTable() ← uses _singleton_connection
-
-Connection (singleton)
- _config → _global_config ← same Config that dj.config writes to
-
-Connection (Instance)
- _config → fresh Config ← isolated per-instance
-```
-
-### Config flow: singleton path
-
-```
-dj.config["safemode"] = False
- ↓ _ConfigProxy.__setitem__
-_global_config["safemode"] = False (same object as settings.config)
- ↓
-Connection._config["safemode"] (points to _global_config)
- ↓
-schema.drop() reads self.connection._config["safemode"] → False ✓
-```
-
-### Config flow: Instance path
-
-```
-inst = dj.Instance(host=..., user=..., password=...)
- ↓
-inst.config = _create_config() (fresh Config, independent)
-inst.connection._config = inst.config
- ↓
-inst.config["safemode"] = False
- ↓
-schema.drop() reads self.connection._config["safemode"] → False ✓
-```
-
-### Key invariant
-
-**All runtime config reads go through `self.connection._config`**, never through the global `config` directly. This ensures both the singleton and Instance paths read the correct config.
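-
-A simplified sketch of the pattern (the real `delete()`/`drop()` methods carry more logic, and `user_confirmed` is illustrative):
-
-```python
-class Table:
- def drop(self):
- # Read safemode from the connection's config rather than the global
- # settings.config, so Instance-level overrides take effect.
- if self.connection._config["safemode"] and not user_confirmed():
- return
- ...
-```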
-
-### Connection-scoped config reads
-
-Every module that previously imported `from .settings import config` now reads config from the connection:
-
-| Module | What was read | How it's read now |
-|--------|--------------|-------------------|
-| `schemas.py` | `config["safemode"]`, `config.database.create_tables` | `self.connection._config[...]` |
-| `table.py` | `config["safemode"]` in `delete()`, `drop()` | `self.connection._config["safemode"]` |
-| `expression.py` | `config["loglevel"]` in `__repr__()` | `self.connection._config["loglevel"]` |
-| `preview.py` | `config["display.*"]` (8 reads) | `query_expression.connection._config[...]` |
-| `autopopulate.py` | `config.jobs.allow_new_pk_fields`, `auto_refresh` | `self.connection._config.jobs.*` |
-| `jobs.py` | `config.jobs.default_priority`, `stale_timeout`, `keep_completed` | `self.connection._config.jobs.*` |
-| `declare.py` | `config.jobs.add_job_metadata` | `config` param (threaded from `table.py`) |
-| `diagram.py` | `config.display.diagram_direction` | `self._connection._config.display.*` |
-| `staged_insert.py` | `config.get_store_spec()` | `self._table.connection._config.get_store_spec()` |
-| `hash_registry.py` | `config.get_store_spec()` in 5 functions | `config` kwarg (falls back to `settings.config`) |
-| `builtin_codecs/hash.py` | `config` via hash_registry | `_config` from key dict → `config` kwarg to hash_registry |
-| `builtin_codecs/attach.py` | `config.get("download_path")` | `_config` from key dict (falls back to `settings.config`) |
-| `builtin_codecs/filepath.py` | `config.get_store_spec()` | `_config` from key dict (falls back to `settings.config`) |
-| `builtin_codecs/schema.py` | `config.get_store_spec()` in helpers | `config` kwarg to `_build_path()`, `_get_backend()` |
-| `builtin_codecs/npy.py` | `config` via schema helpers | `_config` from key dict → `config` kwarg to helpers |
-| `builtin_codecs/object.py` | `config` via schema helpers | `_config` from key dict → `config` kwarg to helpers |
-| `gc.py` | `config` via hash_registry | `schemas[0].connection._config` → `config` kwarg |
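-
-For example, the safemode read in `table.py` changed from a module-level import to a connection-scoped read (sketch):
-
-```python
-# before
-from .settings import config
-if config["safemode"]: ...
-
-# after
-if self.connection._config["safemode"]: ...
-```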
-
-### Functions that receive config as a parameter
-
-Some module-level functions cannot access `self.connection`. Config is threaded through:
-
-| Function | Caller | How config arrives |
-|----------|--------|--------------------|
-| `declare()` in `declare.py` | `Table.declare()` in `table.py` | `config=self.connection._config` kwarg |
-| `_get_job_version()` in `jobs.py` | `AutoPopulate._make_tuples()`, `Job.reserve()` | `config=self.connection._config` positional arg |
-| `get_store_backend()` in `hash_registry.py` | codecs, gc.py | `config` kwarg from key dict or schema connection |
-| `get_store_subfolding()` in `hash_registry.py` | `put_hash()` | `config` kwarg chained from caller |
-| `put_hash()` in `hash_registry.py` | `HashCodec.encode()` | `config` kwarg from `_config` in key dict |
-| `get_hash()` in `hash_registry.py` | `HashCodec.decode()` | `config` kwarg from `_config` in key dict |
-| `delete_path()` in `hash_registry.py` | `gc.collect()` | `config` kwarg from `schemas[0].connection._config` |
-| `decode_attribute()` in `codecs.py` | `expression.py` fetch methods | `connection` kwarg → extracts `connection._config` |
-
-All functions accept `config=None` and fall back to the global `settings.config` for backward compatibility.
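-
-The fallback pattern, sketched with `put_hash()` (signature simplified; the real function takes more arguments):
-
-```python
-def put_hash(data, store, *, config=None):
- from . import settings
- cfg = config if config is not None else settings.config # backward-compatible fallback
- spec = cfg.get_store_spec(store)
- ...
-```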
-
-## Implementation
-
-### 1. Create Instance class
-
-```python
-class Instance:
- def __init__(self, host, user, password, port=3306, **kwargs):
- self.config = _create_config() # Fresh config with defaults
- # Apply any config overrides from kwargs
- self.connection = Connection(host, user, password, port, ...)
- self.connection._config = self.config
-
- def Schema(self, name, **kwargs):
- return Schema(name, connection=self.connection, **kwargs)
-
- def FreeTable(self, full_table_name):
- return FreeTable(self.connection, full_table_name)
-```
-
-### 2. Global config and singleton connection
-
-```python
-# settings.py - THE single global config
-config = _create_config() # Created at import time
-
-# instance.py - reuses the same config object
-_global_config = settings.config # Same reference, not a copy
-_singleton_connection = None # Created lazily
-
-def _check_thread_safe():
- if _load_thread_safe():
- raise ThreadSafetyError(
- "Global DataJoint state is disabled in thread-safe mode. "
- "Use dj.Instance() to create an isolated instance."
- )
-
-def _get_singleton_connection():
- _check_thread_safe()
- global _singleton_connection
- if _singleton_connection is None:
- _singleton_connection = Connection(
- host=_global_config.database.host,
- user=_global_config.database.user,
- password=_global_config.database.password,
- ...
- )
- _singleton_connection._config = _global_config
- return _singleton_connection
-```
-
-### 3. Legacy API with thread-safety checks
-
-```python
-# dj.config -> global config with thread-safety check
-class _ConfigProxy:
- def __getattr__(self, name):
- _check_thread_safe()
- return getattr(_global_config, name)
- def __setattr__(self, name, value):
- _check_thread_safe()
- setattr(_global_config, name, value)
-
-config = _ConfigProxy()
-
-# dj.conn() -> singleton connection (persistent across calls)
-def conn(host=None, user=None, password=None, *, reset=False):
- global _singleton_connection
- _check_thread_safe()
- credentials_provided = any(v is not None for v in (host, user, password))
- if reset or (_singleton_connection is None and credentials_provided):
- _singleton_connection = Connection(...)
- _singleton_connection._config = _global_config
- return _get_singleton_connection()
-
-# dj.Schema() -> uses singleton connection
-def Schema(name, connection=None, **kwargs):
- if connection is None:
- _check_thread_safe()
- connection = _get_singleton_connection()
- return _Schema(name, connection=connection, **kwargs)
-
-# dj.FreeTable() -> uses singleton connection
-def FreeTable(conn_or_name, full_table_name=None):
- if full_table_name is None:
- _check_thread_safe()
- return _FreeTable(_get_singleton_connection(), conn_or_name)
- else:
- return _FreeTable(conn_or_name, full_table_name)
-```
-
-## Global State Audit
-
-All module-level mutable state was reviewed for thread-safety implications.
-
-### Guarded (blocked in thread-safe mode)
-
-| State | Location | Mechanism |
-|-------|----------|-----------|
-| `config` singleton | `settings.py:979` | `_ConfigProxy` raises `ThreadSafetyError`; use `inst.config` instead |
-| `conn()` singleton | `connection.py:108` | `_check_thread_safe()` guard; use `inst.connection` instead |
-
-These are the two globals that carry connection-scoped state (credentials, database settings) and are the primary source of cross-tenant interference.
-
-### Safe by design (no guard needed)
-
-| State | Location | Rationale |
-|-------|----------|-----------|
-| `_codec_registry` | `codecs.py:47` | Effectively immutable after import. Registration runs in `__init_subclass__` under Python's import lock. Runtime mutation (`_load_entry_points`) is idempotent under the GIL. Codecs are part of the type system, not connection-scoped. |
-| `_entry_points_loaded` | `codecs.py:48` | Bool flag for idempotent lazy loading; worst case under concurrent access is redundant work, not corruption. |
-
-### Low risk (no guard needed)
-
-| State | Location | Rationale |
-|-------|----------|-----------|
-| Logging side effects | `logging.py:8,17,40-45,56` | Standard Python logging configuration. Monkey-patches `Logger` and replaces `sys.excepthook` at import time. Not DataJoint-specific mutable state. |
-| `use_32bit_dims` | `blob.py:65` | Runtime flag affecting deserialization. Rarely changed; not connection-scoped. |
-| `compression` dict | `blob.py:61` | Decompressor function registry. Populated at import time, effectively read-only thereafter. |
-| `_lazy_modules` | `__init__.py:92` | Import caching via `globals()` mutation. Protected by Python's import lock. |
-| `ADAPTERS` dict | `adapters/__init__.py:16` | Backend registry. Populated at import time, read-only in practice. |
-
-### Design principle
-
-Only state that is **connection-scoped** (credentials, database settings, connection objects) needs thread-safe guards. State that is **code-scoped** (type registries, import caches, logging configuration) is shared across all threads by design and does not vary between tenants.
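-
-Illustratively (a sketch, not the actual module layout):
-
-```python
-# Code-scoped: one registry shared by every thread; populated at import
-# time and read-only afterward, so no guard is needed.
-_codec_registry: dict[str, type] = {}
-
-# Connection-scoped: guarded; every Instance owns a fresh Config and Connection.
-inst_a = dj.Instance(host="h", user="a", password="pa")
-inst_b = dj.Instance(host="h", user="b", password="pb")
-assert inst_a.config is not inst_b.config
-```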
-
-## Error Messages
-
-- Singleton access: `"Global DataJoint state is disabled in thread-safe mode. Use dj.Instance() to create an isolated instance."`
diff --git a/pixi.lock b/pixi.lock
index c425c2176..0421929da 100644
--- a/pixi.lock
+++ b/pixi.lock
@@ -2092,8 +2092,8 @@ packages:
requires_python: '>=3.8'
- pypi: ./
name: datajoint
- version: 2.1.1
- sha256: 267defaa9ea7f22a8497568e8a14679be178f78cd3b34a4132609a57f0f71227
+ version: 2.2.0.dev0
+ sha256: 48335cedf96fa3b5efd3ddf880bd5065813f2baea43cad01a2fddbba94e561ec
requires_dist:
- deepdiff
- fsspec>=2023.1.0
diff --git a/pyproject.toml b/pyproject.toml
index 20832342b..5bf25dc29 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -235,6 +235,7 @@ markers = [
]
+
[tool.pixi.workspace]
channels = ["conda-forge"]
platforms = ["linux-64", "osx-arm64", "linux-aarch64"]
diff --git a/src/datajoint/adapters/base.py b/src/datajoint/adapters/base.py
index 5a595001b..da4779543 100644
--- a/src/datajoint/adapters/base.py
+++ b/src/datajoint/adapters/base.py
@@ -647,7 +647,7 @@ def create_index_ddl(
# Generate index name from table and columns if not provided
if index_name is None:
# Extract table name from full_table_name for index naming
- table_part = full_table_name.split(".")[-1].strip('`"')
+ _, table_part = self.split_full_table_name(full_table_name)
col_part = "_".join(columns)[:30] # Truncate for long column lists
index_name = f"idx_{table_part}_{col_part}"
unique_clause = "UNIQUE " if unique else ""
@@ -830,6 +830,26 @@ def load_foreign_keys_sql(self, schemas_list: str, like_pattern: str) -> str:
"""
...
+ def find_downstream_schemas_sql(self, schemas_list: str) -> str:
+ """
+ Generate query to find schemas with FK references to the given schemas.
+
+ Used to discover unloaded schemas that depend on loaded ones.
+
+ Parameters
+ ----------
+ schemas_list : str
+ Comma-separated, quoted schema names for an IN clause.
+
+ Returns
+ -------
+ str
+ SQL query returning rows with a single column ``schema_name``
+ containing distinct schema names that reference the given schemas.
+ """
+ raise NotImplementedError
+
@abstractmethod
def get_constraint_info_sql(self, constraint_name: str, schema_name: str, table_name: str) -> str:
"""
diff --git a/src/datajoint/adapters/mysql.py b/src/datajoint/adapters/mysql.py
index 3c28a85e6..1888eccf4 100644
--- a/src/datajoint/adapters/mysql.py
+++ b/src/datajoint/adapters/mysql.py
@@ -687,6 +687,15 @@ def load_foreign_keys_sql(self, schemas_list: str, like_pattern: str) -> str:
f"OR referenced_table_schema is not NULL AND table_schema in ({schemas_list}))"
)
+ def find_downstream_schemas_sql(self, schemas_list: str) -> str:
+ """Find schemas with FK references to the given schemas."""
+ return (
+ f"SELECT DISTINCT table_schema as schema_name "
+ f"FROM information_schema.key_column_usage "
+ f"WHERE referenced_table_schema IN ({schemas_list}) "
+ f"AND table_schema NOT IN ({schemas_list})"
+ )
+
def get_constraint_info_sql(self, constraint_name: str, schema_name: str, table_name: str) -> str:
"""Query to get FK constraint details from information_schema."""
return (
diff --git a/src/datajoint/adapters/postgres.py b/src/datajoint/adapters/postgres.py
index cfb99b728..543e972d3 100644
--- a/src/datajoint/adapters/postgres.py
+++ b/src/datajoint/adapters/postgres.py
@@ -847,6 +847,20 @@ def load_foreign_keys_sql(self, schemas_list: str, like_pattern: str) -> str:
f"ORDER BY c.conname, cols.ord"
)
+ def find_downstream_schemas_sql(self, schemas_list: str) -> str:
+ """Find schemas with FK references to the given schemas."""
+ return (
+ f"SELECT DISTINCT ns1.nspname as schema_name "
+ f"FROM pg_constraint c "
+ f"JOIN pg_class cl1 ON c.conrelid = cl1.oid "
+ f"JOIN pg_namespace ns1 ON cl1.relnamespace = ns1.oid "
+ f"JOIN pg_class cl2 ON c.confrelid = cl2.oid "
+ f"JOIN pg_namespace ns2 ON cl2.relnamespace = ns2.oid "
+ f"WHERE c.contype = 'f' "
+ f"AND ns2.nspname IN ({schemas_list}) "
+ f"AND ns1.nspname NOT IN ({schemas_list})"
+ )
+
def get_constraint_info_sql(self, constraint_name: str, schema_name: str, table_name: str) -> str:
"""
Query to get FK constraint details from information_schema.
diff --git a/src/datajoint/declare.py b/src/datajoint/declare.py
index 8807348ad..1370628bc 100644
--- a/src/datajoint/declare.py
+++ b/src/datajoint/declare.py
@@ -369,8 +369,8 @@ def prepare_declare(
adapter,
fk_attribute_map,
)
- elif re.match(r"^(unique\s+)?index\s*\(.*\)$", line, re.I): # index
- compile_index(line, index_sql, adapter)
+ elif re.match(r"^(unique\s+)?index\s*\(.*\)\s*(#.*)?$", line, re.I): # index
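+ # strip the trailing "# comment" (now permitted by the regex) before compiling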
+ compile_index(re.sub(r"\s*#.*$", "", line), index_sql, adapter)
else:
name, sql, store, comment = compile_attribute(line, in_key, foreign_key_sql, context, adapter)
if store:
diff --git a/src/datajoint/dependencies.py b/src/datajoint/dependencies.py
index 99556345e..08fb50e1b 100644
--- a/src/datajoint/dependencies.py
+++ b/src/datajoint/dependencies.py
@@ -140,9 +140,9 @@ def clear(self) -> None:
self._node_alias_count = itertools.count() # reset alias IDs for consistency
super().clear()
- def load(self, force: bool = True) -> None:
+ def load(self, force: bool = True, schema_names: set[str] | None = None) -> None:
"""
- Load dependencies for all loaded schemas.
+ Load dependencies for the given schemas.
Called before operations requiring dependencies: delete, drop,
populate, progress.
@@ -151,6 +151,8 @@ def load(self, force: bool = True) -> None:
----------
force : bool, optional
If True (default), reload even if already loaded.
+ schema_names : set[str], optional
+ Schema names to load. If None, uses all activated schemas.
"""
# reload from scratch to prevent duplication of renamed edges
if self._loaded and not force:
@@ -162,7 +164,11 @@ def load(self, force: bool = True) -> None:
adapter = self._conn.adapter
# Build schema list for IN clause
- schemas_list = ", ".join(adapter.quote_string(s) for s in self._conn.schemas)
+ names = schema_names if schema_names is not None else set(self._conn.schemas)
+ if not names:
+ self._loaded = True
+ return
+ schemas_list = ", ".join(adapter.quote_string(s) for s in names)
# Load primary keys and foreign keys via adapter methods
# Note: Both PyMySQL and psycopg use %s placeholders, so escape % as %%
@@ -220,6 +226,39 @@ def load(self, force: bool = True) -> None:
raise DataJointError("DataJoint can only work with acyclic dependencies")
self._loaded = True
+ def load_all_downstream(self) -> None:
+ """
+ Load dependencies including all downstream schemas reachable via FK chains.
+
+ Iteratively discovers schemas that reference the currently loaded
+ schemas, expanding the dependency graph until no new schemas are
+ found. This ensures that cascade delete and drop reach all
+ dependent tables, even those in schemas that haven't been
+ explicitly activated.
+
+ Called automatically by ``Diagram.cascade()`` and ``Table.drop()``.
+ Call manually before constructing a ``Diagram`` to include
+ cross-schema dependencies in visualization::
+
+ conn.dependencies.load_all_downstream()
+ dj.Diagram(schema) # now includes all downstream schemas
+ """
+ adapter = self._conn.adapter
+ known_schemas = set(self._conn.schemas)
+ if not known_schemas:
+ self.load()
+ return
+
+ while True:
+ schemas_list = ", ".join(adapter.quote_string(s) for s in known_schemas)
+ result = self._conn.query(adapter.find_downstream_schemas_sql(schemas_list))
+ new_schemas = {row[0] for row in result} - known_schemas
+ if not new_schemas:
+ break
+ known_schemas |= new_schemas
+
+ self.load(force=True, schema_names=known_schemas)
+
def topo_sort(self) -> list[str]:
"""
Return table names in topological order.
diff --git a/src/datajoint/diagram.py b/src/datajoint/diagram.py
index 75e00c21c..1436ec7b8 100644
--- a/src/datajoint/diagram.py
+++ b/src/datajoint/diagram.py
@@ -1,12 +1,21 @@
"""
-Diagram visualization for DataJoint schemas.
+Diagram for DataJoint schemas.
-This module provides the Diagram class for visualizing schema structure
-as directed acyclic graphs showing tables and their foreign key relationships.
+This module provides the Diagram class for constructing derived views of the
+dependency graph. Diagram supports set operators (+, -, *) for selecting subsets
+of tables, restriction propagation (cascade, restrict) for selecting subsets of
+data, and inspection (counts, prune) for viewing those selections.
+
+Mutation operations (delete, drop) live in Table, which uses Diagram internally
+for graph computation.
+
+Visualization methods (draw, make_dot, make_svg, etc.) require matplotlib and
+pygraphviz. All other methods are always available.
"""
from __future__ import annotations
+import copy as copy_module
import functools
import inspect
import io
@@ -14,7 +23,8 @@
import networkx as nx
-from .dependencies import topo_sort
+from .condition import AndList
+from .dependencies import extract_master, topo_sort
from .errors import DataJointError
from .table import Table, lookup_class_name
from .user_tables import Computed, Imported, Lookup, Manual, Part, _AliasNode, _get_tier
@@ -37,1002 +47,1367 @@
logger = logging.getLogger(__name__.split(".")[0])
-if not diagram_active: # noqa: C901
+class Diagram(nx.DiGraph): # noqa: C901
+ """
+ Schema diagram as a directed acyclic graph (DAG).
+
+ Visualizes tables and foreign key relationships derived from
+ ``connection.dependencies``.
+
+ Parameters
+ ----------
+ source : Table, Schema, or module
+ A table object, table class, schema, or module with a schema.
+ context : dict, optional
+ Namespace for resolving table class names. If None, uses caller's
+ frame globals/locals.
+
+ Examples
+ --------
+ >>> diag = dj.Diagram(schema.MyTable)
+ >>> diag.draw()
+
+ Operators:
+
+ - ``diag1 + diag2`` - union of diagrams
+ - ``diag1 - diag2`` - difference of diagrams
+ - ``diag1 * diag2`` - intersection of diagrams
+ - ``diag + n`` - expand n levels of successors (children)
+ - ``diag - n`` - expand n levels of predecessors (parents)
+
+ >>> dj.Diagram(schema.Table) + 1 - 1 # immediate ancestors and descendants
+
+ Notes
+ -----
+ ``diagram + 1 - 1`` may differ from ``diagram - 1 + 1``.
+ Only tables in activated schemas are displayed. To include tables in
+ downstream schemas that depend on the current schema but haven't been
+ explicitly activated::
+
+ conn.dependencies.load_all_downstream()
+ dj.Diagram(schema) # now includes all downstream schemas
+
+ ``Diagram.cascade()`` calls ``load_all_downstream()`` automatically.
+
+ Layout direction is controlled via ``dj.config.display.diagram_direction``
+ (default ``"TB"``). Use ``dj.config.override()`` to change temporarily::
+
+ with dj.config.override(display_diagram_direction="LR"):
+ dj.Diagram(schema).draw()
+ """
+
+ def __init__(self, source, context=None) -> None:
+ if isinstance(source, Diagram):
+ # copy constructor
+ self.nodes_to_show = set(source.nodes_to_show)
+ self._expanded_nodes = set(source._expanded_nodes)
+ self.context = source.context
+ self._connection = source._connection
+ self._cascade_restrictions = copy_module.deepcopy(source._cascade_restrictions)
+ self._restrict_conditions = copy_module.deepcopy(source._restrict_conditions)
+ self._restriction_attrs = copy_module.deepcopy(source._restriction_attrs)
+ super().__init__(source)
+ return
+
+ # get the caller's context
+ if context is None:
+ frame = inspect.currentframe().f_back
+ self.context = dict(frame.f_globals, **frame.f_locals)
+ del frame
+ else:
+ self.context = context
+
+ # find connection in the source
+ try:
+ connection = source.connection
+ except AttributeError:
+ try:
+ connection = source.schema.connection
+ except AttributeError:
+ raise DataJointError("Could not find database connection in %s" % repr(source))
+
+ # initialize graph from dependencies
+ connection.dependencies.load()
+ super().__init__(connection.dependencies)
+ self._connection = connection
+ self._cascade_restrictions = {}
+ self._restrict_conditions = {}
+ self._restriction_attrs = {}
+
+ # Enumerate nodes from all the items in the list
+ self.nodes_to_show = set()
+ try:
+ self.nodes_to_show.add(source.full_table_name)
+ except AttributeError:
+ try:
+ database = source.database
+ except AttributeError:
+ try:
+ database = source.schema.database
+ except AttributeError:
+ raise DataJointError("Cannot plot Diagram for %s" % repr(source))
+ for node in self.nodes():
+ # Handle both MySQL backticks and PostgreSQL double quotes
+ if node.startswith("`%s`" % database) or node.startswith('"%s"' % database):
+ self.nodes_to_show.add(node)
+ # All nodes start as expanded
+ self._expanded_nodes = set(self.nodes_to_show)
+
+ @classmethod
+ def from_sequence(cls, sequence) -> "Diagram":
+ """
+ Create combined Diagram from a sequence of sources.
+
+ Parameters
+ ----------
+ sequence : iterable
+ Sequence of table objects, classes, or schemas.
- class Diagram:
+ Returns
+ -------
+ Diagram
+ Union of diagrams: ``Diagram(arg1) + ... + Diagram(argn)``.
"""
- Schema diagram (disabled).
+ return functools.reduce(lambda x, y: x + y, map(Diagram, sequence))
- Diagram visualization requires matplotlib and pygraphviz packages.
- Install them to enable this feature.
+ def add_parts(self) -> "Diagram":
+ """
+ Add part tables of all masters already in the diagram.
- See Also
+ Returns
+ -------
+ Diagram
+ New diagram with part tables included.
+ """
+
+ split = self._connection.adapter.split_full_table_name
+
+ def is_part(part, master):
+ p_schema, p_table = split(part)
+ m_schema, m_table = split(master)
+ return m_schema == p_schema and m_table + "__" == p_table[: len(m_table) + 2]
+
+ self = Diagram(self) # copy
+ self.nodes_to_show.update(n for n in self.nodes() if any(is_part(n, m) for m in self.nodes_to_show))
+ return self
+
+ def collapse(self) -> "Diagram":
+ """
+ Mark all nodes in this diagram as collapsed.
+
+ Collapsed nodes are shown as a single node per schema. When combined
+ with other diagrams using ``+``, expanded nodes win: if a node is
+ expanded in either operand, it remains expanded in the result.
+
+ Returns
+ -------
+ Diagram
+ A copy of this diagram with all nodes collapsed.
+
+ Examples
--------
- https://docs.datajoint.com/how-to/installation/
+ >>> # Show schema1 expanded, schema2 collapsed into single nodes
+ >>> dj.Diagram(schema1) + dj.Diagram(schema2).collapse()
+
+ >>> # Collapse all three schemas together
+ >>> (dj.Diagram(schema1) + dj.Diagram(schema2) + dj.Diagram(schema3)).collapse()
+
+ >>> # Expand one table from collapsed schema
+ >>> dj.Diagram(schema).collapse() + dj.Diagram(SingleTable)
"""
+ result = Diagram(self)
+ result._expanded_nodes = set() # All nodes collapsed
+ return result
- def __init__(self, *args, **kwargs) -> None:
- logger.warning("Please install matplotlib and pygraphviz libraries to enable the Diagram feature.")
+ def __add__(self, arg) -> "Diagram":
+ """
+ Union or downstream expansion.
-else:
+ Parameters
+ ----------
+ arg : Diagram or int
+ Another Diagram for union, or positive int for downstream expansion.
+
+ Returns
+ -------
+ Diagram
+ Combined or expanded diagram.
+ """
+ result = Diagram(self) # copy
+ try:
+ # Merge nodes and edges from the other diagram
+ result.add_nodes_from(arg.nodes(data=True))
+ result.add_edges_from(arg.edges(data=True))
+ result.nodes_to_show.update(arg.nodes_to_show)
+ # Merge contexts for class name lookups
+ result.context = {**result.context, **arg.context}
+ # Expanded wins: union of expanded nodes from both operands
+ result._expanded_nodes = self._expanded_nodes | arg._expanded_nodes
+ except AttributeError:
+ try:
+ result.nodes_to_show.add(arg.full_table_name)
+ result._expanded_nodes.add(arg.full_table_name)
+ except AttributeError:
+ for i in range(arg):
+ new = nx.algorithms.boundary.node_boundary(result, result.nodes_to_show)
+ if not new:
+ break
+ # add nodes referenced by aliased nodes
+ new.update(nx.algorithms.boundary.node_boundary(result, (a for a in new if a.isdigit())))
+ result.nodes_to_show.update(new)
+ # New nodes from expansion are expanded
+ result._expanded_nodes = result._expanded_nodes | result.nodes_to_show
+ return result
+
+ def __sub__(self, arg) -> "Diagram":
+ """
+ Difference or upstream expansion.
+
+ Parameters
+ ----------
+ arg : Diagram or int
+ Another Diagram for difference, or positive int for upstream expansion.
+
+ Returns
+ -------
+ Diagram
+ Reduced or expanded diagram.
+ """
+ self = Diagram(self) # copy
+ try:
+ self.nodes_to_show.difference_update(arg.nodes_to_show)
+ except AttributeError:
+ try:
+ self.nodes_to_show.remove(arg.full_table_name)
+ except AttributeError:
+ for i in range(arg):
+ graph = nx.DiGraph(self).reverse()
+ new = nx.algorithms.boundary.node_boundary(graph, self.nodes_to_show)
+ if not new:
+ break
+ # add nodes referenced by aliased nodes
+ new.update(nx.algorithms.boundary.node_boundary(graph, (a for a in new if a.isdigit())))
+ self.nodes_to_show.update(new)
+ return self
+
+ def __mul__(self, arg) -> "Diagram":
+ """
+ Intersection of two diagrams.
- class Diagram(nx.DiGraph):
+ Parameters
+ ----------
+ arg : Diagram
+ Another Diagram.
+
+ Returns
+ -------
+ Diagram
+ Diagram with nodes present in both operands.
"""
- Schema diagram as a directed acyclic graph (DAG).
+ self = Diagram(self) # copy
+ self.nodes_to_show.intersection_update(arg.nodes_to_show)
+ return self
- Visualizes tables and foreign key relationships derived from
- ``connection.dependencies``.
+ def topo_sort(self) -> list[str]:
+ """
+ Return nodes in topological order.
+
+ Returns
+ -------
+ list[str]
+ Node names in topological order.
+ """
+ return topo_sort(self)
+
+ @classmethod
+ def cascade(cls, table_expr, part_integrity="enforce"):
+ """
+ Create a cascade diagram for a table expression.
+
+ Builds a Diagram from the table's dependency graph, includes all
+ descendants (across all downstream schemas, including ones not yet
+ activated; see ``load_all_downstream``), and propagates the
+ restriction downstream using OR convergence — a child row is
+ affected if *any* restricted ancestor taints it.
Parameters
----------
- source : Table, Schema, or module
- A table object, table class, schema, or module with a schema.
- context : dict, optional
- Namespace for resolving table class names. If None, uses caller's
- frame globals/locals.
+ table_expr : QueryExpression
+ A (possibly restricted) table expression
+ (e.g., ``Session & 'subject_id=1'``).
+ part_integrity : str, optional
+ ``"enforce"`` (default), ``"ignore"``, or ``"cascade"``.
+
+ Returns
+ -------
+ Diagram
+ New Diagram with cascade restrictions applied, trimmed to
+ the seed table and its affected descendants.
Examples
--------
- >>> diag = dj.Diagram(schema.MyTable)
- >>> diag.draw()
+ >>> # Preview cascade impact across all downstream schemas
+ >>> dj.Diagram.cascade(Session & 'subject_id=1').counts()
- Operators:
+ >>> # Inspect the cascade subgraph
+ >>> dj.Diagram.cascade(Session & 'subject_id=1')
+ """
+ conn = table_expr.connection
+ conn.dependencies.load_all_downstream()
+ node = table_expr.full_table_name
+
+ result = cls.__new__(cls)
+ nx.DiGraph.__init__(result, conn.dependencies)
+ result._connection = conn
+ result.context = {}
+ result._cascade_restrictions = {}
+ result._restrict_conditions = {}
+ result._restriction_attrs = {}
+
+ # Include seed + all descendants
+ descendants = set(nx.descendants(result, node)) | {node}
+ result.nodes_to_show = descendants
+ result._expanded_nodes = set(descendants)
+
+ # Seed restriction
+ restriction = AndList(table_expr.restriction)
+ result._cascade_restrictions[node] = [restriction] if restriction else []
+ result._restriction_attrs[node] = set(table_expr.restriction_attributes)
+
+ # Propagate downstream
+ result._propagate_restrictions(node, mode="cascade", part_integrity=part_integrity)
+
+ # Trim graph to cascade subgraph: only restricted tables
+ # (seed + descendants) plus alias nodes connecting them.
+ keep = set(result._cascade_restrictions)
+ for alias in (n for n in result.nodes() if n.isdigit()):
+ if set(result.predecessors(alias)) & keep and set(result.successors(alias)) & keep:
+ keep.add(alias)
+ result.remove_nodes_from(set(result.nodes()) - keep)
+ result.nodes_to_show &= keep
+ result._expanded_nodes &= keep
+ return result
+
+ def _restricted_table(self, node):
+ """
+ Return a FreeTable for ``node`` with this diagram's restrictions applied.
- - ``diag1 + diag2`` - union of diagrams
- - ``diag1 - diag2`` - difference of diagrams
- - ``diag1 * diag2`` - intersection of diagrams
- - ``diag + n`` - expand n levels of successors (children)
- - ``diag - n`` - expand n levels of predecessors (parents)
+ Cascade restrictions are OR-combined (a row is affected if ANY
+ FK reference points to a deleted row). Restrict conditions are
+ AND-combined (a row is included only when ALL ancestor conditions
+ are satisfied).
+ """
+ from .table import FreeTable
+
+ ft = FreeTable(self._connection, node)
+ restrictions = (self._cascade_restrictions or self._restrict_conditions).get(node, [])
+ if not restrictions:
+ return ft
+ if self._cascade_restrictions:
+ # OR semantics — passing a list to restrict() creates an OrList
+ return ft.restrict(restrictions)
+ else:
+ # AND semantics — each restriction narrows further
+ for r in restrictions:
+ ft = ft.restrict(r)
+ return ft
+
+ def restrict(self, table_expr):
+ """
+ Apply restrict condition and propagate downstream.
- >>> dj.Diagram(schema.Table) + 1 - 1 # immediate ancestors and descendants
+ AND at convergence — a child row is included only if it satisfies
+ *all* restricted ancestors. Used for export. Can be chained.
- Notes
- -----
- ``diagram + 1 - 1`` may differ from ``diagram - 1 + 1``.
- Only tables loaded in the connection are displayed.
+ Cannot be called on a Diagram produced by ``Diagram.cascade()``.
+
+ Parameters
+ ----------
+ table_expr : QueryExpression
+ A restricted table expression.
+
+ Returns
+ -------
+ Diagram
+ New Diagram with restrict conditions applied.
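+
+ Examples
+ --------
+ >>> # sketch with assumed table names
+ >>> diag = dj.Diagram(schema).restrict(Session & 'subject_id=1')
+ >>> diag.counts()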
+ """
+ if self._cascade_restrictions:
+ raise DataJointError(
+ "Cannot apply restrict() on a Diagram produced by Diagram.cascade(). "
+ "cascade and restrict are mutually exclusive modes."
+ )
+ result = Diagram(self)
+ node = table_expr.full_table_name
+ if node not in result.nodes():
+ raise DataJointError(f"Table {node} is not in the diagram.")
+ # Seed restriction (AND accumulation)
+ result._restrict_conditions.setdefault(node, AndList()).extend(table_expr.restriction)
+ result._restriction_attrs.setdefault(node, set()).update(table_expr.restriction_attributes)
+ # Propagate downstream
+ result._propagate_restrictions(node, mode="restrict")
+ return result
+
+ def _propagate_restrictions(self, start_node, mode, part_integrity="enforce"):
+ """
+ Propagate restrictions from start_node to all its descendants.
+
+ Walks the dependency graph in topological order, applying
+ propagation rules at each edge. Only processes descendants of
+ start_node to avoid duplicate propagation when chaining.
+ """
+ from .table import FreeTable
+
+ sorted_nodes = topo_sort(self)
+ # Only propagate through descendants of start_node
+ allowed_nodes = {start_node} | set(nx.descendants(self, start_node))
+ propagated_edges = set()
+ visited_masters = set()
+
+ restrictions = self._cascade_restrictions if mode == "cascade" else self._restrict_conditions
+
+ # Multiple passes to handle part_integrity="cascade" upward propagation.
+ # When a part table triggers its master to join the cascade, the master's
+ # other descendants need processing in a subsequent pass. The loop
+ # terminates when no new nodes are added — guaranteed in a DAG.
+ any_new = True
+ while any_new:
+ any_new = False
+
+ for node in sorted_nodes:
+ if node not in restrictions or node not in allowed_nodes:
+ continue
+
+ # Build parent FreeTable with current restriction
+ parent_ft = self._restricted_table(node)
+
+ parent_attrs = self._restriction_attrs.get(node, set())
+
+ for _, target, edge_props in self.out_edges(node, data=True):
+ attr_map = edge_props.get("attr_map", {})
+ aliased = edge_props.get("aliased", False)
+
+ if target.isdigit():
+ # Alias node — follow through to real child
+ for _, child_node, _ in self.out_edges(target, data=True):
+ edge_key = (node, target, child_node)
+ if edge_key in propagated_edges:
+ continue
+ propagated_edges.add(edge_key)
+ was_new = child_node not in restrictions
+ self._apply_propagation_rule(
+ parent_ft,
+ parent_attrs,
+ child_node,
+ attr_map,
+ True,
+ mode,
+ restrictions,
+ )
+ if was_new and child_node in restrictions:
+ any_new = True
+ else:
+ edge_key = (node, target)
+ if edge_key in propagated_edges:
+ continue
+ propagated_edges.add(edge_key)
+ was_new = target not in restrictions
+ self._apply_propagation_rule(
+ parent_ft,
+ parent_attrs,
+ target,
+ attr_map,
+ aliased,
+ mode,
+ restrictions,
+ )
+ if was_new and target in restrictions:
+ any_new = True
+
+ # part_integrity="cascade": propagate up from part to master
+ if part_integrity == "cascade" and mode == "cascade":
+ master_name = extract_master(target)
+ if (
+ master_name
+ and master_name in self.nodes()
+ and master_name not in restrictions
+ and master_name not in visited_masters
+ ):
+ visited_masters.add(master_name)
+ child_ft = self._restricted_table(target)
+ master_ft = FreeTable(self._connection, master_name)
+ from .condition import make_condition
+
+ master_restr = make_condition(
+ master_ft,
+ (master_ft.proj() & child_ft.proj()).to_arrays(),
+ master_ft.restriction_attributes,
+ )
+ restrictions[master_name] = [master_restr]
+ self._restriction_attrs[master_name] = set()
+ allowed_nodes.add(master_name)
+ allowed_nodes.update(nx.descendants(self, master_name))
+ any_new = True
+
+ def _apply_propagation_rule(
+ self,
+ parent_ft,
+ parent_attrs,
+ child_node,
+ attr_map,
+ aliased,
+ mode,
+ restrictions,
+ ):
+ """
+ Apply one of the 3 propagation rules to a parent→child edge.
- Layout direction is controlled via ``dj.config.display.diagram_direction``
- (default ``"TB"``). Use ``dj.config.override()`` to change temporarily::
+ Rules (from table.py restriction propagation):
- with dj.config.override(display_diagram_direction="LR"):
- dj.Diagram(schema).draw()
+ 1. Non-aliased AND parent restriction attrs ⊆ child PK:
+ Copy parent restriction directly.
+ 2. Aliased FK (attr_map renames columns):
+ ``parent.proj(**{fk: pk for fk, pk in attr_map.items()})``
+ 3. Non-aliased AND parent restriction attrs ⊄ child PK:
+ ``parent.proj()``
"""
+ child_pk = self.nodes[child_node].get("primary_key", set())
- def __init__(self, source, context=None) -> None:
- if isinstance(source, Diagram):
- # copy constructor
- self.nodes_to_show = set(source.nodes_to_show)
- self._expanded_nodes = set(source._expanded_nodes)
- self.context = source.context
- self._connection = source._connection
- super().__init__(source)
- return
-
- # get the caller's context
- if context is None:
- frame = inspect.currentframe().f_back
- self.context = dict(frame.f_globals, **frame.f_locals)
- del frame
+ if not aliased and parent_attrs and parent_attrs <= child_pk:
+ # Rule 1: copy parent restriction directly
+ parent_restr = restrictions.get(
+ parent_ft.full_table_name,
+ [] if mode == "cascade" else AndList(),
+ )
+ if mode == "cascade":
+ restrictions.setdefault(child_node, []).extend(parent_restr)
+ else:
+ restrictions.setdefault(child_node, AndList()).extend(parent_restr)
+ child_attrs = set(parent_attrs)
+ elif aliased:
+ # Rule 2: aliased FK — project with renaming
+ child_item = parent_ft.proj(**{fk: pk for fk, pk in attr_map.items()})
+ if mode == "cascade":
+ restrictions.setdefault(child_node, []).append(child_item)
else:
- self.context = context
+ restrictions.setdefault(child_node, AndList()).append(child_item)
+ child_attrs = set(attr_map.keys())
+ else:
+ # Rule 3: non-aliased, restriction attrs ⊄ child PK — project
+ child_item = parent_ft.proj()
+ if mode == "cascade":
+ restrictions.setdefault(child_node, []).append(child_item)
+ else:
+ restrictions.setdefault(child_node, AndList()).append(child_item)
+ child_attrs = set(attr_map.values())
- # find connection in the source
- try:
- connection = source.connection
- except AttributeError:
- try:
- connection = source.schema.connection
- except AttributeError:
- raise DataJointError("Could not find database connection in %s" % repr(source[0]))
+ self._restriction_attrs.setdefault(child_node, set()).update(child_attrs)
- # initialize graph from dependencies
- self._connection = connection
- connection.dependencies.load()
- super().__init__(connection.dependencies)
+ def counts(self):
+ """
+ Return affected row counts per table without modifying data.
- # Enumerate nodes from all the items in the list
- self.nodes_to_show = set()
- try:
- self.nodes_to_show.add(source.full_table_name)
- except AttributeError:
- try:
- database = source.database
- except AttributeError:
- try:
- database = source.schema.database
- except AttributeError:
- raise DataJointError("Cannot plot Diagram for %s" % repr(source))
- for node in self:
- # Handle both MySQL backticks and PostgreSQL double quotes
- if node.startswith("`%s`" % database) or node.startswith('"%s"' % database):
- self.nodes_to_show.add(node)
- # All nodes start as expanded
- self._expanded_nodes = set(self.nodes_to_show)
-
- @classmethod
- def from_sequence(cls, sequence) -> "Diagram":
- """
- Create combined Diagram from a sequence of sources.
-
- Parameters
- ----------
- sequence : iterable
- Sequence of table objects, classes, or schemas.
-
- Returns
- -------
- Diagram
- Union of diagrams: ``Diagram(arg1) + ... + Diagram(argn)``.
- """
- return functools.reduce(lambda x, y: x + y, map(Diagram, sequence))
-
- def add_parts(self) -> "Diagram":
- """
- Add part tables of all masters already in the diagram.
-
- Returns
- -------
- Diagram
- New diagram with part tables included.
- """
-
- def is_part(part, master):
- part = [s.strip("`") for s in part.split(".")]
- master = [s.strip("`") for s in master.split(".")]
- return master[0] == part[0] and master[1] + "__" == part[1][: len(master[1]) + 2]
-
- self = Diagram(self) # copy
- self.nodes_to_show.update(n for n in self.nodes() if any(is_part(n, m) for m in self.nodes_to_show))
- return self
-
- def collapse(self) -> "Diagram":
- """
- Mark all nodes in this diagram as collapsed.
-
- Collapsed nodes are shown as a single node per schema. When combined
- with other diagrams using ``+``, expanded nodes win: if a node is
- expanded in either operand, it remains expanded in the result.
-
- Returns
- -------
- Diagram
- A copy of this diagram with all nodes collapsed.
-
- Examples
- --------
- >>> # Show schema1 expanded, schema2 collapsed into single nodes
- >>> dj.Diagram(schema1) + dj.Diagram(schema2).collapse()
-
- >>> # Collapse all three schemas together
- >>> (dj.Diagram(schema1) + dj.Diagram(schema2) + dj.Diagram(schema3)).collapse()
-
- >>> # Expand one table from collapsed schema
- >>> dj.Diagram(schema).collapse() + dj.Diagram(SingleTable)
- """
- result = Diagram(self)
- result._expanded_nodes = set() # All nodes collapsed
- return result
-
- def __add__(self, arg) -> "Diagram":
- """
- Union or downstream expansion.
-
- Parameters
- ----------
- arg : Diagram or int
- Another Diagram for union, or positive int for downstream expansion.
-
- Returns
- -------
- Diagram
- Combined or expanded diagram.
- """
- result = Diagram(self) # copy
- try:
- # Merge nodes and edges from the other diagram
- result.add_nodes_from(arg.nodes(data=True))
- result.add_edges_from(arg.edges(data=True))
- result.nodes_to_show.update(arg.nodes_to_show)
- # Merge contexts for class name lookups
- result.context = {**result.context, **arg.context}
- # Expanded wins: union of expanded nodes from both operands
- result._expanded_nodes = self._expanded_nodes | arg._expanded_nodes
- except AttributeError:
- try:
- result.nodes_to_show.add(arg.full_table_name)
- result._expanded_nodes.add(arg.full_table_name)
- except AttributeError:
- for i in range(arg):
- new = nx.algorithms.boundary.node_boundary(result, result.nodes_to_show)
- if not new:
- break
- # add nodes referenced by aliased nodes
- new.update(nx.algorithms.boundary.node_boundary(result, (a for a in new if a.isdigit())))
- result.nodes_to_show.update(new)
- # New nodes from expansion are expanded
- result._expanded_nodes = result._expanded_nodes | result.nodes_to_show
- return result
-
- def __sub__(self, arg) -> "Diagram":
- """
- Difference or upstream expansion.
-
- Parameters
- ----------
- arg : Diagram or int
- Another Diagram for difference, or positive int for upstream expansion.
-
- Returns
- -------
- Diagram
- Reduced or expanded diagram.
- """
- self = Diagram(self) # copy
- try:
- self.nodes_to_show.difference_update(arg.nodes_to_show)
- except AttributeError:
- try:
- self.nodes_to_show.remove(arg.full_table_name)
- except AttributeError:
- for i in range(arg):
- graph = nx.DiGraph(self).reverse()
- new = nx.algorithms.boundary.node_boundary(graph, self.nodes_to_show)
- if not new:
- break
- # add nodes referenced by aliased nodes
- new.update(nx.algorithms.boundary.node_boundary(graph, (a for a in new if a.isdigit())))
- self.nodes_to_show.update(new)
- return self
-
- def __mul__(self, arg) -> "Diagram":
- """
- Intersection of two diagrams.
-
- Parameters
- ----------
- arg : Diagram
- Another Diagram.
-
- Returns
- -------
- Diagram
- Diagram with nodes present in both operands.
- """
- self = Diagram(self) # copy
- self.nodes_to_show.intersection_update(arg.nodes_to_show)
- return self
-
- def topo_sort(self) -> list[str]:
- """
- Return nodes in topological order.
-
- Returns
- -------
- list[str]
- Node names in topological order.
- """
- return topo_sort(self)
-
- def _make_graph(self) -> nx.DiGraph:
- """
- Build graph object ready for drawing.
-
- Returns
- -------
- nx.DiGraph
- Graph with nodes relabeled to class names.
- """
- # mark "distinguished" tables, i.e. those that introduce new primary key
- # attributes
- # Filter nodes_to_show to only include nodes that exist in the graph
- valid_nodes = self.nodes_to_show.intersection(set(self.nodes()))
- for name in valid_nodes:
- foreign_attributes = set(
- attr for p in self.in_edges(name, data=True) for attr in p[2]["attr_map"] if p[2]["primary"]
- )
- self.nodes[name]["distinguished"] = (
- "primary_key" in self.nodes[name] and foreign_attributes < self.nodes[name]["primary_key"]
- )
- # include aliased nodes that are sandwiched between two displayed nodes
- gaps = set(nx.algorithms.boundary.node_boundary(self, valid_nodes)).intersection(
- nx.algorithms.boundary.node_boundary(nx.DiGraph(self).reverse(), valid_nodes)
+ Returns
+ -------
+ dict[str, int]
+ Mapping of full table name to affected row count.
+ """
+ restrictions = self._cascade_restrictions or self._restrict_conditions
+ if not restrictions:
+ raise DataJointError(
+ "No restrictions applied. " "Use Diagram.cascade(table_expr) or diag.restrict(table_expr) first."
)
- nodes = valid_nodes.union(a for a in gaps if a.isdigit())
- # construct subgraph and rename nodes to class names
- graph = nx.DiGraph(nx.DiGraph(self).subgraph(nodes))
- nx.set_node_attributes(graph, name="node_type", values={n: _get_tier(n) for n in graph})
- # relabel nodes to class names
- mapping = {node: lookup_class_name(node, self.context) or node for node in graph.nodes()}
- new_names = list(mapping.values())
- if len(new_names) > len(set(new_names)):
- raise DataJointError("Some classes have identical names. The Diagram cannot be plotted.")
- nx.relabel_nodes(graph, mapping, copy=False)
- return graph
-
- def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, str]]:
- """
- Apply collapse logic to the graph.
-
- Nodes in nodes_to_show but not in _expanded_nodes are collapsed into
- single schema nodes.
-
- Parameters
- ----------
- graph : nx.DiGraph
- The graph from _make_graph().
-
- Returns
- -------
- tuple[nx.DiGraph, dict[str, str]]
- Modified graph and mapping of collapsed schema labels to their table count.
- """
- # Filter to valid nodes (those that exist in the underlying graph)
- valid_nodes = self.nodes_to_show.intersection(set(self.nodes()))
- valid_expanded = self._expanded_nodes.intersection(set(self.nodes()))
-
- # If all nodes are expanded, no collapse needed
- if valid_expanded >= valid_nodes:
- return graph, {}
-
- # Map full_table_names to class_names
- full_to_class = {node: lookup_class_name(node, self.context) or node for node in valid_nodes}
- class_to_full = {v: k for k, v in full_to_class.items()}
-
- # Identify expanded class names
- expanded_class_names = {full_to_class.get(node, node) for node in valid_expanded}
-
- # Identify nodes to collapse (class names)
- nodes_to_collapse = set(graph.nodes()) - expanded_class_names
-
- if not nodes_to_collapse:
- return graph, {}
-
- # Group collapsed nodes by schema
- collapsed_by_schema = {} # schema_name -> list of class_names
- for class_name in nodes_to_collapse:
- full_name = class_to_full.get(class_name)
- if full_name:
- parts = full_name.replace('"', "`").split("`")
- if len(parts) >= 2:
- schema_name = parts[1]
- if schema_name not in collapsed_by_schema:
- collapsed_by_schema[schema_name] = []
- collapsed_by_schema[schema_name].append(class_name)
-
- if not collapsed_by_schema:
- return graph, {}
-
- # Determine labels for collapsed schemas
- schema_modules = {}
- for schema_name, class_names in collapsed_by_schema.items():
- schema_modules[schema_name] = set()
- for class_name in class_names:
- cls = self._resolve_class(class_name)
- if cls is not None and hasattr(cls, "__module__"):
- module_name = cls.__module__.split(".")[-1]
- schema_modules[schema_name].add(module_name)
-
- # Collect module names for ALL schemas in the diagram (not just collapsed)
- all_schema_modules = {} # schema_name -> module_name
- for node in graph.nodes():
- full_name = class_to_full.get(node)
- if full_name:
- parts = full_name.replace('"', "`").split("`")
- if len(parts) >= 2:
- db_schema = parts[1]
- cls = self._resolve_class(node)
- if cls is not None and hasattr(cls, "__module__"):
- module_name = cls.__module__.split(".")[-1]
- all_schema_modules[db_schema] = module_name
-
- # Check which module names are shared by multiple schemas
- module_to_schemas = {}
- for db_schema, module_name in all_schema_modules.items():
- if module_name not in module_to_schemas:
- module_to_schemas[module_name] = []
- module_to_schemas[module_name].append(db_schema)
-
- ambiguous_modules = {m for m, schemas in module_to_schemas.items() if len(schemas) > 1}
-
- # Determine labels for collapsed schemas
- collapsed_labels = {} # schema_name -> label
- for schema_name, modules in schema_modules.items():
- if len(modules) == 1:
- module_name = next(iter(modules))
- # Use database schema name if module is ambiguous
- if module_name in ambiguous_modules:
- label = schema_name
- else:
- label = module_name
- else:
- label = schema_name
- collapsed_labels[schema_name] = label
-
- # Build counts using final labels
- collapsed_counts = {} # label -> count of tables
- for schema_name, class_names in collapsed_by_schema.items():
- label = collapsed_labels[schema_name]
- collapsed_counts[label] = len(class_names)
-
- # Create new graph with collapsed nodes
- new_graph = nx.DiGraph()
-
- # Map old node names to new names (collapsed nodes -> schema label)
- node_mapping = {}
- for node in graph.nodes():
- full_name = class_to_full.get(node)
- if full_name:
- parts = full_name.replace('"', "`").split("`")
- if len(parts) >= 2 and node in nodes_to_collapse:
- schema_name = parts[1]
- node_mapping[node] = collapsed_labels[schema_name]
- else:
- node_mapping[node] = node
- else:
- # Alias nodes - check if they should be collapsed
- # An alias node should be collapsed if ALL its neighbors are collapsed
- neighbors = set(graph.predecessors(node)) | set(graph.successors(node))
- if neighbors and neighbors <= nodes_to_collapse:
- # Get schema from first neighbor
- neighbor = next(iter(neighbors))
- full_name = class_to_full.get(neighbor)
- if full_name:
- parts = full_name.replace('"', "`").split("`")
- if len(parts) >= 2:
- schema_name = parts[1]
- node_mapping[node] = collapsed_labels[schema_name]
- continue
- node_mapping[node] = node
- # Build reverse mapping: label -> schema_name
- label_to_schema = {label: schema for schema, label in collapsed_labels.items()}
-
- # Add nodes
- added_collapsed = set()
- for old_node, new_node in node_mapping.items():
- if new_node in collapsed_counts:
- # This is a collapsed schema node
- if new_node not in added_collapsed:
- schema_name = label_to_schema.get(new_node, new_node)
- new_graph.add_node(
- new_node,
- node_type=None,
- collapsed=True,
- table_count=collapsed_counts[new_node],
- schema_name=schema_name,
- )
- added_collapsed.add(new_node)
- else:
- new_graph.add_node(new_node, **graph.nodes[old_node])
-
- # Add edges (avoiding self-loops and duplicates)
- for src, dest, data in graph.edges(data=True):
- new_src = node_mapping[src]
- new_dest = node_mapping[dest]
- if new_src != new_dest and not new_graph.has_edge(new_src, new_dest):
- new_graph.add_edge(new_src, new_dest, **data)
-
- return new_graph, collapsed_counts
-
- def _resolve_class(self, name: str):
- """
- Safely resolve a table class from a dotted name without eval().
-
- Parameters
- ----------
- name : str
- Dotted class name like "MyTable" or "Module.MyTable".
-
- Returns
- -------
- type or None
- The table class if found, otherwise None.
- """
- parts = name.split(".")
- obj = self.context.get(parts[0])
- for part in parts[1:]:
- if obj is None:
- return None
- obj = getattr(obj, part, None)
- if obj is not None and isinstance(obj, type) and issubclass(obj, Table):
- return obj
- return None
-
- @staticmethod
- def _encapsulate_edge_attributes(graph: nx.DiGraph) -> None:
- """
- Encapsulate edge attr_map in double quotes for pydot compatibility.
-
- Modifies graph in place.
-
- See Also
- --------
- https://github.com/pydot/pydot/issues/258#issuecomment-795798099
- """
- for u, v, *_, edgedata in graph.edges(data=True):
- if "attr_map" in edgedata:
- graph.edges[u, v]["attr_map"] = '"{0}"'.format(edgedata["attr_map"])
-
- @staticmethod
- def _encapsulate_node_names(graph: nx.DiGraph) -> None:
- """
- Encapsulate node names in double quotes for pydot compatibility.
-
- Modifies graph in place.
-
- See Also
- --------
- https://github.com/datajoint/datajoint-python/pull/1176
- """
- nx.relabel_nodes(
- graph,
- {node: '"{0}"'.format(node) for node in graph.nodes()},
- copy=False,
+ result = {}
+ for ft in self:
+ if ft.full_table_name in restrictions:
+ count = len(ft)
+ result[ft.full_table_name] = count
+ logger.info("{table} ({count} tuples)".format(table=ft.full_table_name, count=count))
+ return result
+
+ def __iter__(self):
+ """
+ Iterate over non-alias nodes in topological order (parents first).
+
+ Yields restricted ``FreeTable`` objects when cascade or restrict
+ conditions have been applied, and unrestricted ``FreeTable``
+ objects otherwise.
+
+ Alias nodes (used internally for multi-FK edges) are skipped.
+ """
+ for node in topo_sort(self):
+ if not node.isdigit() and node in self.nodes_to_show:
+ yield self._restricted_table(node)
+
+ def __reversed__(self):
+ """
+ Iterate in reverse topological order (leaves first).
+
+ Same as ``__iter__`` but reversed; useful for cascading
+ deletes and drops.
+ """
+ for node in reversed(topo_sort(self)):
+ if not node.isdigit() and node in self.nodes_to_show:
+ yield self._restricted_table(node)
+
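A minimal sketch of how these iterators are meant to be consumed; ``Session`` and the restriction are hypothetical stand-ins for a real pipeline table:

```python
import datajoint as dj

# Hypothetical: Session is a table in an activated schema.
diagram = dj.Diagram.cascade(Session & "session_id = 1")

# Parents first: preview the cascade top-down.
for ft in diagram:
    print(ft.full_table_name, len(ft))

# Leaves first: the order a cascading delete actually uses.
for ft in reversed(diagram):
    print("delete order:", ft.full_table_name)
```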
+ def prune(self):
+ """
+ Remove tables with zero matching rows from the diagram.
+
+ Without prior restrictions, removes physically empty tables.
+ After ``restrict()``, removes tables where the restricted query
+ yields zero rows. Cannot be used on a cascade Diagram (cascade
+ is for delete, where zero-count tables must remain in the graph
+ to handle concurrent inserts safely).
+
+ Returns
+ -------
+ Diagram
+ New Diagram with empty tables removed.
+ """
+ from .table import FreeTable
+
+ if self._cascade_restrictions:
+ raise DataJointError(
+ "prune() cannot be used on a Diagram produced by Diagram.cascade(). "
+ "Cascade diagrams must retain all descendant tables for safe deletion."
)
- def make_dot(self):
- """
- Generate a pydot graph object.
-
- Returns
- -------
- pydot.Dot
- The graph object ready for rendering.
-
- Notes
- -----
- Layout direction is controlled via ``dj.config.display.diagram_direction``.
- Tables are grouped by schema, with the Python module name shown as the
- group label when available.
- """
- direction = self._connection._config.display.diagram_direction
- graph = self._make_graph()
-
- # Apply collapse logic if needed
- graph, collapsed_counts = self._apply_collapse(graph)
-
- # Build schema mapping: class_name -> schema_name
- # Group by database schema, label with Python module name if 1:1 mapping
- schema_map = {} # class_name -> schema_name
- schema_modules = {} # schema_name -> set of module names
-
- for full_name in self.nodes_to_show:
- # Extract schema from full table name like `schema`.`table` or "schema"."table"
- parts = full_name.replace('"', "`").split("`")
- if len(parts) >= 2:
- schema_name = parts[1] # schema is between first pair of backticks
- class_name = lookup_class_name(full_name, self.context) or full_name
- schema_map[class_name] = schema_name
-
- # Collect all module names for this schema
- if schema_name not in schema_modules:
- schema_modules[schema_name] = set()
- cls = self._resolve_class(class_name)
+ result = Diagram(self)
+
+ if result._restrict_conditions:
+ for node in list(result._restrict_conditions):
+ if node.isdigit():
+ continue
+ if len(result._restricted_table(node)) == 0:
+ result._restrict_conditions.pop(node)
+ result._restriction_attrs.pop(node, None)
+ result.nodes_to_show.discard(node)
+ else:
+ # Unrestricted: check physical row counts
+ for node in list(result.nodes_to_show):
+ if node.isdigit():
+ continue
+ ft = FreeTable(self._connection, node)
+ if len(ft) == 0:
+ result.nodes_to_show.discard(node)
+
+ return result
+
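Usage sketch (``schema`` is assumed to be an activated ``dj.Schema``); note that ``prune()`` returns a new ``Diagram`` rather than mutating in place:

```python
import datajoint as dj

full = dj.Diagram(schema)   # schema assumed to be an activated dj.Schema
trimmed = full.prune()      # drop physically empty tables

# prune() returns a new Diagram; the original keeps all its nodes.
assert set(trimmed.nodes_to_show) <= set(full.nodes_to_show)
```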
+ def _make_graph(self) -> nx.DiGraph:
+ """
+ Build graph object ready for drawing.
+
+ Returns
+ -------
+ nx.DiGraph
+ Graph with nodes relabeled to class names.
+ """
+ # mark "distinguished" tables, i.e. those that introduce new primary key
+ # attributes
+ # Filter nodes_to_show to only include nodes that exist in the graph
+ valid_nodes = self.nodes_to_show.intersection(set(self.nodes()))
+ for name in valid_nodes:
+ foreign_attributes = set(
+ attr for p in self.in_edges(name, data=True) for attr in p[2]["attr_map"] if p[2]["primary"]
+ )
+ self.nodes[name]["distinguished"] = (
+ "primary_key" in self.nodes[name] and foreign_attributes < self.nodes[name]["primary_key"]
+ )
+ # include aliased nodes that are sandwiched between two displayed nodes
+ gaps = set(nx.algorithms.boundary.node_boundary(self, valid_nodes)).intersection(
+ nx.algorithms.boundary.node_boundary(nx.DiGraph(self).reverse(), valid_nodes)
+ )
+ nodes = valid_nodes.union(a for a in gaps if a.isdigit())
+ # construct subgraph and rename nodes to class names
+ graph = nx.DiGraph(nx.DiGraph(self).subgraph(nodes))
+ nx.set_node_attributes(graph, name="node_type", values={n: _get_tier(n) for n in graph})
+ # relabel nodes to class names
+ mapping = {node: lookup_class_name(node, self.context) or node for node in graph.nodes()}
+ new_names = list(mapping.values())
+ if len(new_names) > len(set(new_names)):
+ raise DataJointError("Some classes have identical names. The Diagram cannot be plotted.")
+ nx.relabel_nodes(graph, mapping, copy=False)
+ return graph
+
+ def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, str]]:
+ """
+ Apply collapse logic to the graph.
+
+ Nodes in nodes_to_show but not in _expanded_nodes are collapsed into
+ single schema nodes.
+
+ Parameters
+ ----------
+ graph : nx.DiGraph
+ The graph from _make_graph().
+
+ Returns
+ -------
+ tuple[nx.DiGraph, dict[str, str]]
+ Modified graph and mapping of collapsed schema labels to their table count.
+ """
+ # Filter to valid nodes (those that exist in the underlying graph)
+ valid_nodes = self.nodes_to_show.intersection(set(self.nodes()))
+ valid_expanded = self._expanded_nodes.intersection(set(self.nodes()))
+
+ # If all nodes are expanded, no collapse needed
+ if valid_expanded >= valid_nodes:
+ return graph, {}
+
+ # Map full_table_names to class_names
+ full_to_class = {node: lookup_class_name(node, self.context) or node for node in valid_nodes}
+ class_to_full = {v: k for k, v in full_to_class.items()}
+
+ # Identify expanded class names
+ expanded_class_names = {full_to_class.get(node, node) for node in valid_expanded}
+
+ # Identify nodes to collapse (class names)
+ nodes_to_collapse = set(graph.nodes()) - expanded_class_names
+
+ if not nodes_to_collapse:
+ return graph, {}
+
+ # Group collapsed nodes by schema
+ collapsed_by_schema = {} # schema_name -> list of class_names
+ for class_name in nodes_to_collapse:
+ full_name = class_to_full.get(class_name)
+ if full_name:
+ schema_name, _ = self._connection.adapter.split_full_table_name(full_name)
+ if schema_name:
+ if schema_name not in collapsed_by_schema:
+ collapsed_by_schema[schema_name] = []
+ collapsed_by_schema[schema_name].append(class_name)
+
+ if not collapsed_by_schema:
+ return graph, {}
+
+ # Determine labels for collapsed schemas
+ schema_modules = {}
+ for schema_name, class_names in collapsed_by_schema.items():
+ schema_modules[schema_name] = set()
+ for class_name in class_names:
+ cls = self._resolve_class(class_name)
+ if cls is not None and hasattr(cls, "__module__"):
+ module_name = cls.__module__.split(".")[-1]
+ schema_modules[schema_name].add(module_name)
+
+ # Collect module names for ALL schemas in the diagram (not just collapsed)
+ all_schema_modules = {} # schema_name -> module_name
+ for node in graph.nodes():
+ full_name = class_to_full.get(node)
+ if full_name:
+ db_schema, _ = self._connection.adapter.split_full_table_name(full_name)
+ if db_schema:
+ cls = self._resolve_class(node)
if cls is not None and hasattr(cls, "__module__"):
module_name = cls.__module__.split(".")[-1]
- schema_modules[schema_name].add(module_name)
-
- # Determine cluster labels: use module name if 1:1, else database schema name
- cluster_labels = {} # schema_name -> label
- for schema_name, modules in schema_modules.items():
- if len(modules) == 1:
- cluster_labels[schema_name] = next(iter(modules))
- else:
- cluster_labels[schema_name] = schema_name
-
- # Disambiguate labels if multiple schemas share the same module name
- # (e.g., all defined in __main__ in a notebook)
- label_counts = {}
- for label in cluster_labels.values():
- label_counts[label] = label_counts.get(label, 0) + 1
-
- for schema_name, label in cluster_labels.items():
- if label_counts[label] > 1:
- # Multiple schemas share this module name - add schema name
- cluster_labels[schema_name] = f"{label} ({schema_name})"
-
- # Assign alias nodes (orange dots) to the same schema as their child table
- for node, data in graph.nodes(data=True):
- if data.get("node_type") is _AliasNode:
- # Find the child (successor) - the table that declares the renamed FK
- successors = list(graph.successors(node))
- if successors and successors[0] in schema_map:
- schema_map[node] = schema_map[successors[0]]
-
- # Assign collapsed nodes to their schema so they appear in the cluster
- for node, data in graph.nodes(data=True):
- if data.get("collapsed") and data.get("schema_name"):
- schema_map[node] = data["schema_name"]
-
- scale = 1.2 # scaling factor for fonts and boxes
- label_props = { # http://matplotlib.org/examples/color/named_colors.html
- None: dict(
- shape="circle",
- color="#FFFF0040",
- fontcolor="yellow",
- fontsize=round(scale * 8),
- size=0.4 * scale,
- fixed=False,
- ),
- _AliasNode: dict(
- shape="circle",
- color="#FF880080",
- fontcolor="#FF880080",
- fontsize=round(scale * 0),
- size=0.05 * scale,
- fixed=True,
- ),
- Manual: dict(
- shape="box",
- color="#00FF0030",
- fontcolor="darkgreen",
- fontsize=round(scale * 10),
- size=0.4 * scale,
- fixed=False,
- ),
- Lookup: dict(
- shape="plaintext",
- color="#00000020",
- fontcolor="black",
- fontsize=round(scale * 8),
- size=0.4 * scale,
- fixed=False,
- ),
- Computed: dict(
- shape="ellipse",
- color="#FF000020",
- fontcolor="#7F0000A0",
- fontsize=round(scale * 10),
- size=0.4 * scale,
- fixed=False,
- ),
- Imported: dict(
- shape="ellipse",
- color="#00007F40",
- fontcolor="#00007FA0",
- fontsize=round(scale * 10),
- size=0.4 * scale,
- fixed=False,
- ),
- Part: dict(
- shape="plaintext",
- color="#00000000",
- fontcolor="black",
- fontsize=round(scale * 8),
- size=0.1 * scale,
- fixed=False,
- ),
- "collapsed": dict(
- shape="box3d",
- color="#80808060",
- fontcolor="#404040",
- fontsize=round(scale * 10),
- size=0.5 * scale,
- fixed=False,
- ),
- }
- # Build node_props, handling collapsed nodes specially
- node_props = {}
- for node, d in graph.nodes(data=True):
- if d.get("collapsed"):
- node_props[node] = label_props["collapsed"]
+ all_schema_modules[db_schema] = module_name
+
+ # Check which module names are shared by multiple schemas
+ module_to_schemas = {}
+ for db_schema, module_name in all_schema_modules.items():
+ if module_name not in module_to_schemas:
+ module_to_schemas[module_name] = []
+ module_to_schemas[module_name].append(db_schema)
+
+ ambiguous_modules = {m for m, schemas in module_to_schemas.items() if len(schemas) > 1}
+
+ # Determine labels for collapsed schemas
+ collapsed_labels = {} # schema_name -> label
+ for schema_name, modules in schema_modules.items():
+ if len(modules) == 1:
+ module_name = next(iter(modules))
+ # Use database schema name if module is ambiguous
+ if module_name in ambiguous_modules:
+ label = schema_name
else:
- node_props[node] = label_props[d["node_type"]]
-
- self._encapsulate_node_names(graph)
- self._encapsulate_edge_attributes(graph)
- dot = nx.drawing.nx_pydot.to_pydot(graph)
- dot.set_rankdir(direction)
- for node in dot.get_nodes():
- node.set_shape("circle")
- name = node.get_name().strip('"')
- props = node_props[name]
- node.set_fontsize(props["fontsize"])
- node.set_fontcolor(props["fontcolor"])
- node.set_shape(props["shape"])
- node.set_fontname("arial")
- node.set_fixedsize("shape" if props["fixed"] else False)
- node.set_width(props["size"])
- node.set_height(props["size"])
-
- # Handle collapsed nodes specially
- node_data = graph.nodes.get(f'"{name}"', {})
- if node_data.get("collapsed"):
- table_count = node_data.get("table_count", 0)
- label = f"({table_count} tables)" if table_count != 1 else "(1 table)"
- node.set_label(label)
- node.set_tooltip(f"Collapsed schema: {table_count} tables")
+ label = module_name
+ else:
+ label = schema_name
+ collapsed_labels[schema_name] = label
+
+ # Build counts using final labels
+ collapsed_counts = {} # label -> count of tables
+ for schema_name, class_names in collapsed_by_schema.items():
+ label = collapsed_labels[schema_name]
+ collapsed_counts[label] = len(class_names)
+
+ # Create new graph with collapsed nodes
+ new_graph = nx.DiGraph()
+
+ # Map old node names to new names (collapsed nodes -> schema label)
+ node_mapping = {}
+ for node in graph.nodes():
+ full_name = class_to_full.get(node)
+ if full_name:
+ schema_name, _ = self._connection.adapter.split_full_table_name(full_name)
+ if schema_name and node in nodes_to_collapse:
+ node_mapping[node] = collapsed_labels[schema_name]
else:
- cls = self._resolve_class(name)
- if cls is not None:
- description = cls().describe(context=self.context).split("\n")
- description = (
- (
- "-" * 30
- if q.startswith("---")
- else (q.replace("->", "→") if "->" in q else q.split(":")[0])
- )
- for q in description
- if not q.startswith("#")
- )
- node.set_tooltip("
".join(description))
- # Strip module prefix from label if it matches the cluster label
- display_name = name
- schema_name = schema_map.get(name)
- if schema_name and "." in name:
- cluster_label = cluster_labels.get(schema_name)
- if cluster_label and name.startswith(cluster_label + "."):
- display_name = name[len(cluster_label) + 1 :]
- node.set_label("<" + display_name + ">" if node.get("distinguished") == "True" else display_name)
- node.set_color(props["color"])
- node.set_style("filled")
-
- for edge in dot.get_edges():
- # see https://graphviz.org/doc/info/attrs.html
- src = edge.get_source()
- dest = edge.get_destination()
- props = graph.get_edge_data(src, dest)
- if props is None:
- raise DataJointError("Could not find edge with source '{}' and destination '{}'".format(src, dest))
- edge.set_color("#00000040")
- edge.set_style("solid" if props.get("primary") else "dashed")
- dest_node_type = graph.nodes[dest].get("node_type")
- master_part = dest_node_type is Part and dest.startswith(src + ".")
- edge.set_weight(3 if master_part else 1)
- edge.set_arrowhead("none")
- edge.set_penwidth(0.75 if props.get("multi") else 2)
-
- # Group nodes into schema clusters (always on)
- if schema_map:
- import pydot
-
- # Group nodes by schema
- schemas = {}
- for node in list(dot.get_nodes()):
- name = node.get_name().strip('"')
- schema_name = schema_map.get(name)
- if schema_name:
- if schema_name not in schemas:
- schemas[schema_name] = []
- schemas[schema_name].append(node)
-
- # Create clusters for each schema
- # Use Python module name if 1:1 mapping, otherwise database schema name
- for schema_name, nodes in schemas.items():
- label = cluster_labels.get(schema_name, schema_name)
- cluster = pydot.Cluster(
- f"cluster_{schema_name}",
- label=label,
- style="dashed",
- color="gray",
- fontcolor="gray",
+ node_mapping[node] = node
+ else:
+ # Alias nodes - check if they should be collapsed
+ # An alias node should be collapsed if ALL its neighbors are collapsed
+ neighbors = set(graph.predecessors(node)) | set(graph.successors(node))
+ if neighbors and neighbors <= nodes_to_collapse:
+ # Get schema from first neighbor
+ neighbor = next(iter(neighbors))
+ full_name = class_to_full.get(neighbor)
+ if full_name:
+ schema_name, _ = self._connection.adapter.split_full_table_name(full_name)
+ if schema_name:
+ node_mapping[node] = collapsed_labels[schema_name]
+ continue
+ node_mapping[node] = node
+
+ # Build reverse mapping: label -> schema_name
+ label_to_schema = {label: schema for schema, label in collapsed_labels.items()}
+
+ # Add nodes
+ added_collapsed = set()
+ for old_node, new_node in node_mapping.items():
+ if new_node in collapsed_counts:
+ # This is a collapsed schema node
+ if new_node not in added_collapsed:
+ schema_name = label_to_schema.get(new_node, new_node)
+ new_graph.add_node(
+ new_node,
+ node_type=None,
+ collapsed=True,
+ table_count=collapsed_counts[new_node],
+ schema_name=schema_name,
)
- for node in nodes:
- cluster.add_node(node)
- dot.add_subgraph(cluster)
+ added_collapsed.add(new_node)
+ else:
+ new_graph.add_node(new_node, **graph.nodes[old_node])
- return dot
+ # Add edges (avoiding self-loops and duplicates)
+ for src, dest, data in graph.edges(data=True):
+ new_src = node_mapping[src]
+ new_dest = node_mapping[dest]
+ if new_src != new_dest and not new_graph.has_edge(new_src, new_dest):
+ new_graph.add_edge(new_src, new_dest, **data)
- def make_svg(self):
- from IPython.display import SVG
+ return new_graph, collapsed_counts
- return SVG(self.make_dot().create_svg())
+ def _resolve_class(self, name: str):
+ """
+ Safely resolve a table class from a dotted name without eval().
- def make_png(self):
- return io.BytesIO(self.make_dot().create_png())
+ Parameters
+ ----------
+ name : str
+ Dotted class name like "MyTable" or "Module.MyTable".
- def make_image(self):
- if plot_active:
- return plt.imread(self.make_png())
- else:
- raise DataJointError("pyplot was not imported")
-
- def make_mermaid(self) -> str:
- """
- Generate Mermaid diagram syntax.
-
- Produces a flowchart in Mermaid syntax that can be rendered in
- Markdown documentation, GitHub, or https://mermaid.live.
-
- Returns
- -------
- str
- Mermaid flowchart syntax.
-
- Notes
- -----
- Layout direction is controlled via ``dj.config.display.diagram_direction``.
- Tables are grouped by schema using Mermaid subgraphs, with the Python
- module name shown as the group label when available.
-
- Examples
- --------
- >>> print(dj.Diagram(schema).make_mermaid())
- flowchart TB
- subgraph my_pipeline
- Mouse[Mouse]:::manual
- Session[Session]:::manual
- Neuron([Neuron]):::computed
- end
- Mouse --> Session
- Session --> Neuron
- """
- graph = self._make_graph()
- direction = self._connection._config.display.diagram_direction
-
- # Apply collapse logic if needed
- graph, collapsed_counts = self._apply_collapse(graph)
-
- # Build schema mapping for grouping
- schema_map = {} # class_name -> schema_name
- schema_modules = {} # schema_name -> set of module names
-
- for full_name in self.nodes_to_show:
- parts = full_name.replace('"', "`").split("`")
- if len(parts) >= 2:
- schema_name = parts[1]
- class_name = lookup_class_name(full_name, self.context) or full_name
- schema_map[class_name] = schema_name
-
- # Collect all module names for this schema
- if schema_name not in schema_modules:
- schema_modules[schema_name] = set()
- cls = self._resolve_class(class_name)
- if cls is not None and hasattr(cls, "__module__"):
- module_name = cls.__module__.split(".")[-1]
- schema_modules[schema_name].add(module_name)
+ Returns
+ -------
+ type or None
+ The table class if found, otherwise None.
+ """
+ parts = name.split(".")
+ obj = self.context.get(parts[0])
+ for part in parts[1:]:
+ if obj is None:
+ return None
+ obj = getattr(obj, part, None)
+ if obj is not None and isinstance(obj, type) and issubclass(obj, Table):
+ return obj
+ return None
+
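The same getattr walk in isolation, as a runnable illustration of why no ``eval()`` is needed; names here are stand-ins, and the real method additionally requires a ``Table`` subclass:

```python
import types

class Session:  # stand-in; the real check demands a Table subclass
    pass

mod = types.SimpleNamespace(Session=Session)
context = {"ephys": mod}  # mimics Diagram.context

def resolve(name, context):
    """Walk a dotted name with getattr() instead of eval()."""
    parts = name.split(".")
    obj = context.get(parts[0])
    for part in parts[1:]:
        if obj is None:
            return None
        obj = getattr(obj, part, None)
    return obj

assert resolve("ephys.Session", context) is Session
assert resolve("ephys.Missing", context) is None
assert resolve("nope.Session", context) is None
```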
+ @staticmethod
+ def _encapsulate_edge_attributes(graph: nx.DiGraph) -> None:
+ """
+ Encapsulate edge attr_map in double quotes for pydot compatibility.
- # Determine cluster labels: use module name if 1:1, else database schema name
- cluster_labels = {}
- for schema_name, modules in schema_modules.items():
- if len(modules) == 1:
- cluster_labels[schema_name] = next(iter(modules))
- else:
- cluster_labels[schema_name] = schema_name
-
- # Assign alias nodes to the same schema as their child table
- for node, data in graph.nodes(data=True):
- if data.get("node_type") is _AliasNode:
- successors = list(graph.successors(node))
- if successors and successors[0] in schema_map:
- schema_map[node] = schema_map[successors[0]]
-
- lines = [f"flowchart {direction}"]
-
- # Define class styles matching Graphviz colors
- lines.append(" classDef manual fill:#90EE90,stroke:#006400")
- lines.append(" classDef lookup fill:#D3D3D3,stroke:#696969")
- lines.append(" classDef computed fill:#FFB6C1,stroke:#8B0000")
- lines.append(" classDef imported fill:#ADD8E6,stroke:#00008B")
- lines.append(" classDef part fill:#FFFFFF,stroke:#000000")
- lines.append(" classDef collapsed fill:#808080,stroke:#404040")
- lines.append("")
-
- # Shape mapping: Manual=box, Computed/Imported=stadium, Lookup/Part=box
- shape_map = {
- Manual: ("[", "]"), # box
- Lookup: ("[", "]"), # box
- Computed: ("([", "])"), # stadium/pill
- Imported: ("([", "])"), # stadium/pill
- Part: ("[", "]"), # box
- _AliasNode: ("((", "))"), # circle
- None: ("((", "))"), # circle
- }
-
- tier_class = {
- Manual: "manual",
- Lookup: "lookup",
- Computed: "computed",
- Imported: "imported",
- Part: "part",
- _AliasNode: "",
- None: "",
- }
-
- # Group nodes by schema into subgraphs (including collapsed nodes)
+ Modifies graph in place.
+
+ See Also
+ --------
+ https://github.com/pydot/pydot/issues/258#issuecomment-795798099
+ """
+ for u, v, *_, edgedata in graph.edges(data=True):
+ if "attr_map" in edgedata:
+ graph.edges[u, v]["attr_map"] = '"{0}"'.format(edgedata["attr_map"])
+
+ @staticmethod
+ def _encapsulate_node_names(graph: nx.DiGraph) -> None:
+ """
+ Encapsulate node names in double quotes for pydot compatibility.
+
+ Modifies graph in place.
+
+ See Also
+ --------
+ https://github.com/datajoint/datajoint-python/pull/1176
+ """
+ nx.relabel_nodes(
+ graph,
+ {node: '"{0}"'.format(node) for node in graph.nodes()},
+ copy=False,
+ )
+
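Why the quoting is needed: unquoted DOT identifiers may not contain dots or spaces, so part-table class names like ``Master.Part`` would otherwise produce invalid DOT. A small illustration (assumes only that ``networkx`` is installed):

```python
import networkx as nx

g = nx.DiGraph()
g.add_edge("Master.Part", "Other")

# Quote every node name so pydot emits valid DOT identifiers.
nx.relabel_nodes(g, {n: '"{0}"'.format(n) for n in g.nodes()}, copy=False)
assert '"Master.Part"' in g.nodes
```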
+ def make_dot(self):
+ """
+ Generate a pydot graph object.
+
+ Returns
+ -------
+ pydot.Dot
+ The graph object ready for rendering.
+
+ Raises
+ ------
+ DataJointError
+ If pygraphviz/pydot is not installed.
+
+ Notes
+ -----
+ Layout direction is controlled via ``dj.config.display.diagram_direction``.
+ Tables are grouped by schema, with the Python module name shown as the
+ group label when available.
+ """
+ if not diagram_active:
+ raise DataJointError("Install pygraphviz and pydot libraries to enable diagram visualization.")
+ direction = self._connection._config.display.diagram_direction
+ graph = self._make_graph()
+
+ # Apply collapse logic if needed
+ graph, collapsed_counts = self._apply_collapse(graph)
+
+ # Build schema mapping: class_name -> schema_name
+ # Group by database schema, label with Python module name if 1:1 mapping
+ schema_map = {} # class_name -> schema_name
+ schema_modules = {} # schema_name -> set of module names
+
+ for full_name in self.nodes_to_show:
+ schema_name, _ = self._connection.adapter.split_full_table_name(full_name)
+ if schema_name:
+ class_name = lookup_class_name(full_name, self.context) or full_name
+ schema_map[class_name] = schema_name
+
+ # Collect all module names for this schema
+ if schema_name not in schema_modules:
+ schema_modules[schema_name] = set()
+ cls = self._resolve_class(class_name)
+ if cls is not None and hasattr(cls, "__module__"):
+ module_name = cls.__module__.split(".")[-1]
+ schema_modules[schema_name].add(module_name)
+
+ # Determine cluster labels: use module name if 1:1, else database schema name
+ cluster_labels = {} # schema_name -> label
+ for schema_name, modules in schema_modules.items():
+ if len(modules) == 1:
+ cluster_labels[schema_name] = next(iter(modules))
+ else:
+ cluster_labels[schema_name] = schema_name
+
+ # Disambiguate labels if multiple schemas share the same module name
+ # (e.g., all defined in __main__ in a notebook)
+ label_counts = {}
+ for label in cluster_labels.values():
+ label_counts[label] = label_counts.get(label, 0) + 1
+
+ for schema_name, label in cluster_labels.items():
+ if label_counts[label] > 1:
+ # Multiple schemas share this module name - add schema name
+ cluster_labels[schema_name] = f"{label} ({schema_name})"
+
+ # Assign alias nodes (orange dots) to the same schema as their child table
+ for node, data in graph.nodes(data=True):
+ if data.get("node_type") is _AliasNode:
+ # Find the child (successor) - the table that declares the renamed FK
+ successors = list(graph.successors(node))
+ if successors and successors[0] in schema_map:
+ schema_map[node] = schema_map[successors[0]]
+
+ # Assign collapsed nodes to their schema so they appear in the cluster
+ for node, data in graph.nodes(data=True):
+ if data.get("collapsed") and data.get("schema_name"):
+ schema_map[node] = data["schema_name"]
+
+ scale = 1.2 # scaling factor for fonts and boxes
+ label_props = { # http://matplotlib.org/examples/color/named_colors.html
+ None: dict(
+ shape="circle",
+ color="#FFFF0040",
+ fontcolor="yellow",
+ fontsize=round(scale * 8),
+ size=0.4 * scale,
+ fixed=False,
+ ),
+ _AliasNode: dict(
+ shape="circle",
+ color="#FF880080",
+ fontcolor="#FF880080",
+ fontsize=round(scale * 0),
+ size=0.05 * scale,
+ fixed=True,
+ ),
+ Manual: dict(
+ shape="box",
+ color="#00FF0030",
+ fontcolor="darkgreen",
+ fontsize=round(scale * 10),
+ size=0.4 * scale,
+ fixed=False,
+ ),
+ Lookup: dict(
+ shape="plaintext",
+ color="#00000020",
+ fontcolor="black",
+ fontsize=round(scale * 8),
+ size=0.4 * scale,
+ fixed=False,
+ ),
+ Computed: dict(
+ shape="ellipse",
+ color="#FF000020",
+ fontcolor="#7F0000A0",
+ fontsize=round(scale * 10),
+ size=0.4 * scale,
+ fixed=False,
+ ),
+ Imported: dict(
+ shape="ellipse",
+ color="#00007F40",
+ fontcolor="#00007FA0",
+ fontsize=round(scale * 10),
+ size=0.4 * scale,
+ fixed=False,
+ ),
+ Part: dict(
+ shape="plaintext",
+ color="#00000000",
+ fontcolor="black",
+ fontsize=round(scale * 8),
+ size=0.1 * scale,
+ fixed=False,
+ ),
+ "collapsed": dict(
+ shape="box3d",
+ color="#80808060",
+ fontcolor="#404040",
+ fontsize=round(scale * 10),
+ size=0.5 * scale,
+ fixed=False,
+ ),
+ }
+ # Build node_props, handling collapsed nodes specially
+ node_props = {}
+ for node, d in graph.nodes(data=True):
+ if d.get("collapsed"):
+ node_props[node] = label_props["collapsed"]
+ else:
+ node_props[node] = label_props[d["node_type"]]
+
+ self._encapsulate_node_names(graph)
+ self._encapsulate_edge_attributes(graph)
+ dot = nx.drawing.nx_pydot.to_pydot(graph)
+ dot.set_rankdir(direction)
+ for node in dot.get_nodes():
+ node.set_shape("circle")
+ name = node.get_name().strip('"')
+ props = node_props[name]
+ node.set_fontsize(props["fontsize"])
+ node.set_fontcolor(props["fontcolor"])
+ node.set_shape(props["shape"])
+ node.set_fontname("arial")
+ node.set_fixedsize("shape" if props["fixed"] else False)
+ node.set_width(props["size"])
+ node.set_height(props["size"])
+
+ # Handle collapsed nodes specially
+ node_data = graph.nodes.get(f'"{name}"', {})
+ if node_data.get("collapsed"):
+ table_count = node_data.get("table_count", 0)
+ label = f"({table_count} tables)" if table_count != 1 else "(1 table)"
+ node.set_label(label)
+ node.set_tooltip(f"Collapsed schema: {table_count} tables")
+ else:
+ cls = self._resolve_class(name)
+ if cls is not None:
+ description = cls().describe(context=self.context).split("\n")
+ description = (
+ ("-" * 30 if q.startswith("---") else (q.replace("->", "→") if "->" in q else q.split(":")[0]))
+ for q in description
+ if not q.startswith("#")
+ )
+ node.set_tooltip("
".join(description))
+ # Strip module prefix from label if it matches the cluster label
+ display_name = name
+ schema_name = schema_map.get(name)
+ if schema_name and "." in name:
+ cluster_label = cluster_labels.get(schema_name)
+ if cluster_label and name.startswith(cluster_label + "."):
+ display_name = name[len(cluster_label) + 1 :]
+ node.set_label("<" + display_name + ">" if node.get("distinguished") == "True" else display_name)
+ node.set_color(props["color"])
+ node.set_style("filled")
+
+ for edge in dot.get_edges():
+ # see https://graphviz.org/doc/info/attrs.html
+ src = edge.get_source()
+ dest = edge.get_destination()
+ props = graph.get_edge_data(src, dest)
+ if props is None:
+ raise DataJointError("Could not find edge with source '{}' and destination '{}'".format(src, dest))
+ edge.set_color("#00000040")
+ edge.set_style("solid" if props.get("primary") else "dashed")
+ dest_node_type = graph.nodes[dest].get("node_type")
+ master_part = dest_node_type is Part and dest.startswith(src + ".")
+ edge.set_weight(3 if master_part else 1)
+ edge.set_arrowhead("none")
+ edge.set_penwidth(0.75 if props.get("multi") else 2)
+
+ # Group nodes into schema clusters (always on)
+ if schema_map:
+ import pydot
+
+ # Group nodes by schema
schemas = {}
- for node, data in graph.nodes(data=True):
- if data.get("collapsed"):
- # Collapsed nodes use their schema_name attribute
- schema_name = data.get("schema_name")
- else:
- schema_name = schema_map.get(node)
+ for node in list(dot.get_nodes()):
+ name = node.get_name().strip('"')
+ schema_name = schema_map.get(name)
if schema_name:
if schema_name not in schemas:
schemas[schema_name] = []
- schemas[schema_name].append((node, data))
+ schemas[schema_name].append(node)
- # Add nodes grouped by schema subgraphs
+ # Create clusters for each schema
+ # Use Python module name if 1:1 mapping, otherwise database schema name
for schema_name, nodes in schemas.items():
label = cluster_labels.get(schema_name, schema_name)
- lines.append(f" subgraph {label}")
- for node, data in nodes:
- safe_id = node.replace(".", "_").replace(" ", "_")
- if data.get("collapsed"):
- # Collapsed node - show only table count
- table_count = data.get("table_count", 0)
- count_text = f"{table_count} tables" if table_count != 1 else "1 table"
- lines.append(f' {safe_id}[["({count_text})"]]:::collapsed')
- else:
- # Regular node
- tier = data.get("node_type")
- left, right = shape_map.get(tier, ("[", "]"))
- cls = tier_class.get(tier, "")
- # Strip module prefix from display name if it matches the cluster label
- display_name = node
- if "." in node and node.startswith(label + "."):
- display_name = node[len(label) + 1 :]
- class_suffix = f":::{cls}" if cls else ""
- lines.append(f" {safe_id}{left}{display_name}{right}{class_suffix}")
- lines.append(" end")
-
- lines.append("")
-
- # Add edges
- for src, dest, data in graph.edges(data=True):
- safe_src = src.replace(".", "_").replace(" ", "_")
- safe_dest = dest.replace(".", "_").replace(" ", "_")
- # Solid arrow for primary FK, dotted for non-primary
- style = "-->" if data.get("primary") else "-.->"
- lines.append(f" {safe_src} {style} {safe_dest}")
-
- return "\n".join(lines)
-
- def _repr_svg_(self):
- return self.make_svg()._repr_svg_()
-
- def draw(self):
- if plot_active:
- plt.imshow(self.make_image())
- plt.gca().axis("off")
- plt.show()
+ cluster = pydot.Cluster(
+ f"cluster_{schema_name}",
+ label=label,
+ style="dashed",
+ color="gray",
+ fontcolor="gray",
+ )
+ for node in nodes:
+ cluster.add_node(node)
+ dot.add_subgraph(cluster)
+
+ return dot
+
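A sketch of driving ``make_dot()`` directly (``schema`` is assumed to exist; ``create_svg()`` is the same pydot call that ``make_svg()`` wraps below):

```python
import datajoint as dj

# Layout direction is read from config, per the Notes above.
dj.config.display.diagram_direction = "LR"   # left-to-right

dot = dj.Diagram(schema).make_dot()          # schema assumed to exist
with open("pipeline.svg", "wb") as f:
    f.write(dot.create_svg())
```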
+ def make_svg(self):
+ from IPython.display import SVG
+
+ return SVG(self.make_dot().create_svg())
+
+ def make_png(self):
+ return io.BytesIO(self.make_dot().create_png())
+
+ def make_image(self):
+ if plot_active:
+ return plt.imread(self.make_png())
+ else:
+ raise DataJointError("pyplot was not imported")
+
+ def make_mermaid(self) -> str:
+ """
+ Generate Mermaid diagram syntax.
+
+ Produces a flowchart in Mermaid syntax that can be rendered in
+ Markdown documentation, GitHub, or https://mermaid.live.
+
+ Returns
+ -------
+ str
+ Mermaid flowchart syntax.
+
+ Notes
+ -----
+ Layout direction is controlled via ``dj.config.display.diagram_direction``.
+ Tables are grouped by schema using Mermaid subgraphs, with the Python
+ module name shown as the group label when available.
+
+ Examples
+ --------
+ >>> print(dj.Diagram(schema).make_mermaid())
+ flowchart TB
+ subgraph my_pipeline
+ Mouse[Mouse]:::manual
+ Session[Session]:::manual
+ Neuron([Neuron]):::computed
+ end
+ Mouse --> Session
+ Session --> Neuron
+ """
+ graph = self._make_graph()
+ direction = self._connection._config.display.diagram_direction
+
+ # Apply collapse logic if needed
+ graph, collapsed_counts = self._apply_collapse(graph)
+
+ # Build schema mapping for grouping
+ schema_map = {} # class_name -> schema_name
+ schema_modules = {} # schema_name -> set of module names
+
+ for full_name in self.nodes_to_show:
+ schema_name, _ = self._connection.adapter.split_full_table_name(full_name)
+ if schema_name:
+ class_name = lookup_class_name(full_name, self.context) or full_name
+ schema_map[class_name] = schema_name
+
+ # Collect all module names for this schema
+ if schema_name not in schema_modules:
+ schema_modules[schema_name] = set()
+ cls = self._resolve_class(class_name)
+ if cls is not None and hasattr(cls, "__module__"):
+ module_name = cls.__module__.split(".")[-1]
+ schema_modules[schema_name].add(module_name)
+
+ # Determine cluster labels: use module name if 1:1, else database schema name
+ cluster_labels = {}
+ for schema_name, modules in schema_modules.items():
+ if len(modules) == 1:
+ cluster_labels[schema_name] = next(iter(modules))
else:
- raise DataJointError("pyplot was not imported")
-
- def save(self, filename: str, format: str | None = None) -> None:
- """
- Save diagram to file.
-
- Parameters
- ----------
- filename : str
- Output filename.
- format : str, optional
- File format (``'png'``, ``'svg'``, or ``'mermaid'``).
- Inferred from extension if None.
-
- Raises
- ------
- DataJointError
- If format is unsupported.
-
- Notes
- -----
- Layout direction is controlled via ``dj.config.display.diagram_direction``.
- Tables are grouped by schema, with the Python module name shown as the
- group label when available.
- """
- if format is None:
- if filename.lower().endswith(".png"):
- format = "png"
- elif filename.lower().endswith(".svg"):
- format = "svg"
- elif filename.lower().endswith((".mmd", ".mermaid")):
- format = "mermaid"
- if format is None:
- raise DataJointError("Could not infer format from filename. Specify format explicitly.")
- if format.lower() == "png":
- with open(filename, "wb") as f:
- f.write(self.make_png().getbuffer().tobytes())
- elif format.lower() == "svg":
- with open(filename, "w") as f:
- f.write(self.make_svg().data)
- elif format.lower() == "mermaid":
- with open(filename, "w") as f:
- f.write(self.make_mermaid())
+ cluster_labels[schema_name] = schema_name
+
+ # Assign alias nodes to the same schema as their child table
+ for node, data in graph.nodes(data=True):
+ if data.get("node_type") is _AliasNode:
+ successors = list(graph.successors(node))
+ if successors and successors[0] in schema_map:
+ schema_map[node] = schema_map[successors[0]]
+
+ lines = [f"flowchart {direction}"]
+
+ # Define class styles matching Graphviz colors
+ lines.append(" classDef manual fill:#90EE90,stroke:#006400")
+ lines.append(" classDef lookup fill:#D3D3D3,stroke:#696969")
+ lines.append(" classDef computed fill:#FFB6C1,stroke:#8B0000")
+ lines.append(" classDef imported fill:#ADD8E6,stroke:#00008B")
+ lines.append(" classDef part fill:#FFFFFF,stroke:#000000")
+ lines.append(" classDef collapsed fill:#808080,stroke:#404040")
+ lines.append("")
+
+ # Shape mapping: Manual=box, Computed/Imported=stadium, Lookup/Part=box
+ shape_map = {
+ Manual: ("[", "]"), # box
+ Lookup: ("[", "]"), # box
+ Computed: ("([", "])"), # stadium/pill
+ Imported: ("([", "])"), # stadium/pill
+ Part: ("[", "]"), # box
+ _AliasNode: ("((", "))"), # circle
+ None: ("((", "))"), # circle
+ }
+
+ tier_class = {
+ Manual: "manual",
+ Lookup: "lookup",
+ Computed: "computed",
+ Imported: "imported",
+ Part: "part",
+ _AliasNode: "",
+ None: "",
+ }
+
+ # Group nodes by schema into subgraphs (including collapsed nodes)
+ schemas = {}
+ for node, data in graph.nodes(data=True):
+ if data.get("collapsed"):
+ # Collapsed nodes use their schema_name attribute
+ schema_name = data.get("schema_name")
else:
- raise DataJointError("Unsupported file format")
+ schema_name = schema_map.get(node)
+ if schema_name:
+ if schema_name not in schemas:
+ schemas[schema_name] = []
+ schemas[schema_name].append((node, data))
+
+ # Add nodes grouped by schema subgraphs
+ for schema_name, nodes in schemas.items():
+ label = cluster_labels.get(schema_name, schema_name)
+ lines.append(f" subgraph {label}")
+ for node, data in nodes:
+ safe_id = node.replace(".", "_").replace(" ", "_")
+ if data.get("collapsed"):
+ # Collapsed node - show only table count
+ table_count = data.get("table_count", 0)
+ count_text = f"{table_count} tables" if table_count != 1 else "1 table"
+ lines.append(f' {safe_id}[["({count_text})"]]:::collapsed')
+ else:
+ # Regular node
+ tier = data.get("node_type")
+ left, right = shape_map.get(tier, ("[", "]"))
+ cls = tier_class.get(tier, "")
+ # Strip module prefix from display name if it matches the cluster label
+ display_name = node
+ if "." in node and node.startswith(label + "."):
+ display_name = node[len(label) + 1 :]
+ class_suffix = f":::{cls}" if cls else ""
+ lines.append(f" {safe_id}{left}{display_name}{right}{class_suffix}")
+ lines.append(" end")
+
+ lines.append("")
+
+ # Add edges
+ for src, dest, data in graph.edges(data=True):
+ safe_src = src.replace(".", "_").replace(" ", "_")
+ safe_dest = dest.replace(".", "_").replace(" ", "_")
+ # Solid arrow for primary FK, dotted for non-primary
+ style = "-->" if data.get("primary") else "-.->"
+ lines.append(f" {safe_src} {style} {safe_dest}")
+
+ return "\n".join(lines)
+
+ def _repr_svg_(self):
+ return self.make_svg()._repr_svg_()
+
+ def draw(self):
+ if plot_active:
+ plt.imshow(self.make_image())
+ plt.gca().axis("off")
+ plt.show()
+ else:
+ raise DataJointError("pyplot was not imported")
+
+ def save(self, filename: str, format: str | None = None) -> None:
+ """
+ Save diagram to file.
- @staticmethod
- def _layout(graph, **kwargs):
- return pydot_layout(graph, prog="dot", **kwargs)
+ Parameters
+ ----------
+ filename : str
+ Output filename.
+ format : str, optional
+ File format (``'png'``, ``'svg'``, or ``'mermaid'``).
+ Inferred from extension if None.
+
+ Raises
+ ------
+ DataJointError
+ If format is unsupported.
+
+ Notes
+ -----
+ Layout direction is controlled via ``dj.config.display.diagram_direction``.
+ Tables are grouped by schema, with the Python module name shown as the
+ group label when available.
+ """
+ if format is None:
+ if filename.lower().endswith(".png"):
+ format = "png"
+ elif filename.lower().endswith(".svg"):
+ format = "svg"
+ elif filename.lower().endswith((".mmd", ".mermaid")):
+ format = "mermaid"
+ if format is None:
+ raise DataJointError("Could not infer format from filename. Specify format explicitly.")
+ if format.lower() == "png":
+ with open(filename, "wb") as f:
+ f.write(self.make_png().getbuffer().tobytes())
+ elif format.lower() == "svg":
+ with open(filename, "w") as f:
+ f.write(self.make_svg().data)
+ elif format.lower() == "mermaid":
+ with open(filename, "w") as f:
+ f.write(self.make_mermaid())
+ else:
+ raise DataJointError("Unsupported file format")
+
+ @staticmethod
+ def _layout(graph, **kwargs):
+ return pydot_layout(graph, prog="dot", **kwargs)
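End-to-end usage of the formats handled by ``save()`` (a sketch; ``schema`` is assumed):

```python
import datajoint as dj

diagram = dj.Diagram(schema)

diagram.save("pipeline.svg")              # format inferred from extension
diagram.save("pipeline.mmd")              # Mermaid source for docs
diagram.save("pipeline", format="png")    # explicit when no extension
```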
diff --git a/src/datajoint/table.py b/src/datajoint/table.py
index 83e73e2bb..7f8cbaf70 100644
--- a/src/datajoint/table.py
+++ b/src/datajoint/table.py
@@ -14,6 +14,7 @@
from .condition import make_condition
from .declare import alter, declare
+from .dependencies import extract_master
from .errors import (
AccessError,
DataJointError,
@@ -24,7 +25,7 @@
from .expression import QueryExpression
from .heading import Heading
from .staged_insert import staged_insert1 as _staged_insert1
-from .utils import get_master, is_camel_case, user_choice
+from .utils import is_camel_case, user_choice
logger = logging.getLogger(__name__.split(".")[0])
@@ -984,6 +985,19 @@ def delete(
"""
Deletes the contents of the table and its dependent tables, recursively.
+ Uses graph-driven cascade: builds a dependency diagram, propagates
+ restrictions to all descendants, then deletes in reverse topological
+ order (leaves first).
+
+ With ``safemode=True`` (the default), delete previews all affected
+ tables and row counts, executes within a transaction, and asks for
+ confirmation before committing. Declining rolls back all changes,
+ effectively a built-in dry run.
+
+ To preview cascade impact without executing, use ``Diagram``::
+
+ dj.Diagram.cascade(MyTable & restriction).counts()
+
Args:
transaction: If `True`, use of the entire delete becomes an atomic transaction.
This is the default and recommended behavior. Set to `False` if this delete is
@@ -999,182 +1013,100 @@ def delete(
Number of deleted rows (excluding those from dependent tables).
Raises:
- DataJointError: Delete exceeds maximum number of delete attempts.
DataJointError: When deleting within an existing transaction.
DataJointError: Deleting a part table before its master (when part_integrity="enforce").
ValueError: Invalid part_integrity value.
"""
if part_integrity not in ("enforce", "ignore", "cascade"):
- raise ValueError(f"part_integrity must be 'enforce', 'ignore', or 'cascade', got {part_integrity!r}")
- deleted = set()
- visited_masters = set()
-
- def cascade(table):
- """service function to perform cascading deletes recursively."""
- max_attempts = 50
- for _ in range(max_attempts):
- # Set savepoint before delete attempt (for PostgreSQL transaction handling)
- savepoint_name = f"cascade_delete_{id(table)}"
- if transaction:
- table.connection.query(f"SAVEPOINT {savepoint_name}")
+ raise ValueError(f"part_integrity must be 'enforce', 'ignore', or 'cascade', " f"got {part_integrity!r}")
+ from .diagram import Diagram
- try:
- delete_count = table.delete_quick(get_count=True)
- except IntegrityError as error:
- # Rollback to savepoint so we can continue querying (PostgreSQL requirement)
- if transaction:
- table.connection.query(f"ROLLBACK TO SAVEPOINT {savepoint_name}")
- # Use adapter to parse FK error message
- match = table.connection.adapter.parse_foreign_key_error(error.args[0])
- if match is None:
- raise DataJointError(
- "Cascading deletes failed because the error message is missing foreign key information. "
- "Make sure you have REFERENCES privilege to all dependent tables."
- ) from None
-
- # Strip quotes from parsed values for backend-agnostic processing
- quote_chars = ("`", '"')
-
- def strip_quotes(s):
- if s and any(s.startswith(q) for q in quote_chars):
- return s.strip('`"')
- return s
-
- # Extract schema and table name from child (work with unquoted names)
- child_table_raw = strip_quotes(match["child"])
- if "." in child_table_raw:
- child_parts = child_table_raw.split(".")
- child_schema = strip_quotes(child_parts[0])
- child_table_name = strip_quotes(child_parts[1])
- else:
- # Add schema from current table
- schema_parts = table.full_table_name.split(".")
- child_schema = strip_quotes(schema_parts[0])
- child_table_name = child_table_raw
-
- # If FK/PK attributes not in error message, query information_schema
- if match["fk_attrs"] is None or match["pk_attrs"] is None:
- constraint_query = table.connection.adapter.get_constraint_info_sql(
- strip_quotes(match["name"]),
- child_schema,
- child_table_name,
- )
+ diagram = Diagram.cascade(self, part_integrity=part_integrity)
- results = table.connection.query(
- constraint_query,
- args=(strip_quotes(match["name"]), child_schema, child_table_name),
- ).fetchall()
- if results:
- match["fk_attrs"], match["parent"], match["pk_attrs"] = list(map(list, zip(*results)))
- match["parent"] = match["parent"][0] # All rows have same parent
-
- # Build properly quoted full table name for FreeTable
- child_full_name = (
- f"{table.connection.adapter.quote_identifier(child_schema)}."
- f"{table.connection.adapter.quote_identifier(child_table_name)}"
- )
-
- # Restrict child by table if
- # 1. if table's restriction attributes are not in child's primary key
- # 2. if child renames any attributes
- # Otherwise restrict child by table's restriction.
- child = FreeTable(table.connection, child_full_name)
- if set(table.restriction_attributes) <= set(child.primary_key) and match["fk_attrs"] == match["pk_attrs"]:
- child._restriction = table._restriction
- child._restriction_attributes = table.restriction_attributes
- elif match["fk_attrs"] != match["pk_attrs"]:
- child &= table.proj(**dict(zip(match["fk_attrs"], match["pk_attrs"])))
- else:
- child &= table.proj()
-
- master_name = get_master(child.full_table_name, table.connection.adapter)
- if (
- part_integrity == "cascade"
- and master_name
- and master_name != table.full_table_name
- and master_name not in visited_masters
- ):
- master = FreeTable(table.connection, master_name)
- master._restriction_attributes = set()
- master._restriction = [
- make_condition( # &= may create a subquery referencing the target table
- master,
- (master.proj() & child.proj()).to_arrays(),
- master._restriction_attributes,
- )
- ]
- visited_masters.add(master_name)
- cascade(master)
- else:
- cascade(child)
- else:
- # Successful delete - release savepoint
- if transaction:
- table.connection.query(f"RELEASE SAVEPOINT {savepoint_name}")
- deleted.add(table.full_table_name)
- logger.info("Deleting {count} rows from {table}".format(count=delete_count, table=table.full_table_name))
- break
- else:
- raise DataJointError("Exceeded maximum number of delete attempts.")
- return delete_count
+ conn = self.connection
+ prompt = conn._config["safemode"] if prompt is None else prompt
- prompt = self.connection._config["safemode"] if prompt is None else prompt
+ # Preview
+ if prompt:
+ for ft in diagram:
+ logger.info("{table} ({count} tuples)".format(table=ft.full_table_name, count=len(ft)))
# Start transaction
if transaction:
- if not self.connection.in_transaction:
- self.connection.start_transaction()
+ if not conn.in_transaction:
+ conn.start_transaction()
else:
if not prompt:
transaction = False
else:
raise DataJointError(
- "Delete cannot use a transaction within an ongoing transaction. Set transaction=False or prompt=False."
+ "Delete cannot use a transaction within an "
+ "ongoing transaction. Set transaction=False "
+ "or prompt=False."
)
- # Cascading delete
+ # Execute deletes in reverse topological order (leaves first)
+ root_count = 0
+ deleted_tables = set()
try:
- delete_count = cascade(self)
- except:
+ for ft in reversed(diagram):
+ count = ft.delete_quick(get_count=True)
+ if count > 0:
+ deleted_tables.add(ft.full_table_name)
+ logger.info("Deleting {count} rows from {table}".format(count=count, table=ft.full_table_name))
+ if ft.full_table_name == self.full_table_name:
+ root_count = count
+ except IntegrityError as error:
if transaction:
- self.connection.cancel_transaction()
+ conn.cancel_transaction()
+ match = conn.adapter.parse_foreign_key_error(error.args[0])
+ if match:
+ raise DataJointError(
+ "Delete blocked by table {child} in an unloaded "
+ "schema. Activate all dependent schemas before "
+ "deleting.".format(child=match["child"])
+ ) from None
+ raise DataJointError("Delete blocked by FK in unloaded/inaccessible schema.") from None
+ except Exception:
+ if transaction:
+ conn.cancel_transaction()
raise
- if part_integrity == "enforce":
- # Avoid deleting from part before master (See issue #151)
- for part in deleted:
- master = get_master(part, self.connection.adapter)
- if master and master not in deleted:
+ # Post-check part_integrity="enforce": roll back if a part table
+ # had rows deleted without its master also having rows deleted.
+ if part_integrity == "enforce" and deleted_tables:
+ for table_name in deleted_tables:
+ master = extract_master(table_name)
+ if master and master not in deleted_tables:
if transaction:
- self.connection.cancel_transaction()
+ conn.cancel_transaction()
raise DataJointError(
- "Attempt to delete part table {part} before deleting from its master {master} first. "
- "Use part_integrity='ignore' to allow, or part_integrity='cascade' to also delete master.".format(
- part=part, master=master
- )
+ f"Attempt to delete part table {table_name} before "
+ f"its master {master}. Delete from the master first, "
+ f"or use part_integrity='ignore' or 'cascade'."
)
# Confirm and commit
- if delete_count == 0:
+ if root_count == 0:
if prompt:
logger.warning("Nothing to delete.")
if transaction:
- self.connection.cancel_transaction()
+ conn.cancel_transaction()
elif not transaction:
logger.info("Delete completed")
else:
if not prompt or user_choice("Commit deletes?", default="no") == "yes":
if transaction:
- self.connection.commit_transaction()
+ conn.commit_transaction()
if prompt:
logger.info("Delete committed.")
else:
if transaction:
- self.connection.cancel_transaction()
+ conn.cancel_transaction()
if prompt:
logger.warning("Delete cancelled")
- delete_count = 0 # Reset count when delete is cancelled
- return delete_count
+ root_count = 0
+ return root_count
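How the reworked ``delete()`` is exercised; the table name and restriction are hypothetical:

```python
# Interactive: logs per-table row counts, then asks "Commit deletes?".
(Session & "session_id = 1").delete()

# Non-interactive, still transactional:
(Session & "session_id = 1").delete(prompt=False)

# Also delete masters whose parts fall inside the cascade:
(Session & "session_id = 1").delete(part_integrity="cascade")
```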
def drop_quick(self):
"""
@@ -1214,41 +1146,59 @@ def drop_quick(self):
else:
logger.info("Nothing to drop: table %s is not declared" % self.full_table_name)
- def drop(self, prompt: bool | None = None):
+ def drop(self, prompt: bool | None = None, part_integrity: str = "enforce"):
"""
Drop the table and all tables that reference it, recursively.
+ Uses graph-driven traversal: builds a dependency diagram and drops
+ in reverse topological order (leaves first).
+
+ With ``safemode=True`` (the default), drop previews all affected
+ tables and row counts and asks for confirmation before proceeding.
+
Args:
prompt: If `True`, show what will be dropped and ask for confirmation.
If `False`, drop without confirmation. Default is `dj.config['safemode']`.
+ part_integrity: Policy for master-part integrity. One of:
+ - ``"enforce"`` (default): Error if parts would be dropped without masters.
+ - ``"ignore"``: Allow dropping parts without masters.
"""
if self.restriction:
raise DataJointError(
- "A table with an applied restriction cannot be dropped. Call drop() on the unrestricted Table."
+ "A table with an applied restriction cannot be dropped. " "Call drop() on the unrestricted Table."
)
- prompt = self.connection._config["safemode"] if prompt is None else prompt
+ import networkx as nx
+ from .diagram import Diagram
- self.connection.dependencies.load()
- do_drop = True
- tables = [table for table in self.connection.dependencies.descendants(self.full_table_name) if not table.isdigit()]
+ self.connection.dependencies.load_all_downstream()
+ diagram = Diagram(self)
+ # Expand to include all descendants (cross-schema)
+ descendants = set(nx.descendants(diagram, self.full_table_name)) | {self.full_table_name}
+ diagram.nodes_to_show = descendants
+ diagram._expanded_nodes = set(descendants)
+ conn = self.connection
+ prompt = conn._config["safemode"] if prompt is None else prompt
- # avoid dropping part tables without their masters: See issue #374
- for part in tables:
- master = get_master(part, self.connection.adapter)
- if master and master not in tables:
- raise DataJointError(
- "Attempt to drop part table {part} before dropping its master. Drop {master} first.".format(
- part=part, master=master
+ table_names = [ft.full_table_name for ft in diagram]
+
+ if part_integrity == "enforce":
+ for name in table_names:
+ master = extract_master(name)
+ if master and master not in table_names:
+ raise DataJointError(
+ "Attempt to drop part table {part} before its " "master {master}. Drop the master first.".format(
+ part=name, master=master
+ )
)
- )
+ do_drop = True
if prompt:
- for table in tables:
- logger.info(table + " (%d tuples)" % len(FreeTable(self.connection, table)))
+ for ft in diagram:
+ logger.info("{table} ({count} tuples)".format(table=ft.full_table_name, count=len(ft)))
do_drop = user_choice("Proceed?", default="no") == "yes"
if do_drop:
- for table in reversed(tables):
- FreeTable(self.connection, table).drop_quick()
+ for ft in reversed(diagram):
+ ft.drop_quick()
logger.info("Tables dropped. Restart kernel.")
def describe(self, context=None, printout=False):
@@ -1611,8 +1561,7 @@ class FreeTable(Table):
"""
def __init__(self, conn, full_table_name):
- # Backend-agnostic quote stripping (MySQL uses `, PostgreSQL uses ")
- self.database, self._table_name = (s.strip('`"') for s in full_table_name.split("."))
+ self.database, self._table_name = conn.adapter.split_full_table_name(full_table_name)
self._connection = conn
self._support = [full_table_name]
self._heading = Heading(
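The adapter call is what makes this backend-agnostic; the return shape below is inferred from how this diff unpacks it, not from a documented API:

```python
# Inferred behavior: one method handles both quoting styles.
conn.adapter.split_full_table_name('`ephys`.`session`')
# -> ('ephys', 'session')        (MySQL backticks)
conn.adapter.split_full_table_name('"ephys"."session"')
# -> ('ephys', 'session')        (PostgreSQL double quotes)
```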
diff --git a/src/datajoint/user_tables.py b/src/datajoint/user_tables.py
index 7822fa9e2..514f4eb60 100644
--- a/src/datajoint/user_tables.py
+++ b/src/datajoint/user_tables.py
@@ -235,7 +235,7 @@ def delete(self, part_integrity: str = "enforce", **kwargs):
"or use part_integrity='ignore' to break integrity, "
"or part_integrity='cascade' to also delete master."
)
- super().delete(part_integrity=part_integrity, **kwargs)
+ return super().delete(part_integrity=part_integrity, **kwargs)
def drop(self, part_integrity: str = "enforce"):
"""
@@ -251,7 +251,7 @@ def drop(self, part_integrity: str = "enforce"):
DataJointError: If part_integrity="enforce" (direct Part drops prohibited)
"""
if part_integrity == "ignore":
- super().drop()
+ return super().drop(part_integrity="ignore")
elif part_integrity == "enforce":
raise DataJointError("Cannot drop a Part directly. Drop master instead, or use part_integrity='ignore' to force.")
else:
diff --git a/src/datajoint/utils.py b/src/datajoint/utils.py
index 0441af354..e36267936 100644
--- a/src/datajoint/utils.py
+++ b/src/datajoint/utils.py
@@ -37,46 +37,6 @@ def user_choice(prompt, choices=("yes", "no"), default=None):
return response
-def get_master(full_table_name: str, adapter=None) -> str:
- """
- Get the master table name from a part table name.
-
- If the table name is that of a part table, then return what the master table name would be.
- This follows DataJoint's table naming convention where a master and a part must be in the
- same schema and the part table is prefixed with the master table name + ``__``.
-
- Parameters
- ----------
- full_table_name : str
- Full table name including part.
- adapter : DatabaseAdapter, optional
- Database adapter for backend-specific parsing. Default None.
-
- Returns
- -------
- str
- Supposed master full table name or empty string if not a part table name.
-
- Examples
- --------
- >>> get_master('`ephys`.`session__recording`') # MySQL part table
- '`ephys`.`session`'
- >>> get_master('"ephys"."session__recording"') # PostgreSQL part table
- '"ephys"."session"'
- >>> get_master('`ephys`.`session`') # Not a part table
- ''
- """
- if adapter is not None:
- result = adapter.get_master_table_name(full_table_name)
- return result if result else ""
-
- # Fallback: handle both MySQL backticks and PostgreSQL double quotes
- match = re.match(r'(?P<master>(?P<q>[`"])[\w]+(?P=q)\.(?P=q)[\w]+)__[\w]+(?P=q)', full_table_name)
- if match:
- return match["master"] + match["q"]
- return ""
-
-
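``get_master`` is superseded by ``extract_master`` from ``datajoint.dependencies``; the expected behavior below is inferred from its use in ``table.py`` in this diff:

```python
from datajoint.dependencies import extract_master

# Same master__part naming convention, now centralized:
extract_master('`ephys`.`session__recording`')   # -> '`ephys`.`session`'
extract_master('`ephys`.`session`')              # -> falsy: not a part
```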
def is_camel_case(s):
"""
Check if a string is in CamelCase notation.
diff --git a/src/datajoint/version.py b/src/datajoint/version.py
index 871a28cbb..9a1d4aff2 100644
--- a/src/datajoint/version.py
+++ b/src/datajoint/version.py
@@ -1,4 +1,4 @@
# version bump auto managed by Github Actions:
# label_prs.yaml(prep), release.yaml(bump), post_release.yaml(edit)
# manually set this version will be eventually overwritten by the above actions
-__version__ = "2.1.1"
+__version__ = "2.2.0.dev0"
diff --git a/tests/integration/test_cascade_delete.py b/tests/integration/test_cascade_delete.py
index caf5f331b..3bc3dc73b 100644
--- a/tests/integration/test_cascade_delete.py
+++ b/tests/integration/test_cascade_delete.py
@@ -188,3 +188,107 @@ class Observation(dj.Manual):
assert remaining_obs[0]["obs_id"] == 3
assert remaining_obs[0]["subject_id"] == 2
assert remaining_obs[0]["measurement"] == 15.3
+
+
+def test_delete_preview_with_counts(schema_by_backend):
+ """Diagram.cascade().counts() previews affected rows without deleting."""
+
+ @schema_by_backend
+ class Parent(dj.Manual):
+ definition = """
+ parent_id : int
+ ---
+ name : varchar(255)
+ """
+
+ @schema_by_backend
+ class Child(dj.Manual):
+ definition = """
+ -> Parent
+ child_id : int
+ ---
+ data : varchar(255)
+ """
+
+ Parent.insert1((1, "P1"))
+ Parent.insert1((2, "P2"))
+ Child.insert1((1, 1, "C1-1"))
+ Child.insert1((1, 2, "C1-2"))
+ Child.insert1((2, 1, "C2-1"))
+
+ # Preview restricted cascade via Diagram
+ counts = dj.Diagram.cascade(Parent & {"parent_id": 1}).counts()
+
+ assert isinstance(counts, dict)
+ assert counts[Parent.full_table_name] == 1
+ assert counts[Child.full_table_name] == 2
+
+ # Data must still be intact
+ assert len(Parent()) == 2
+ assert len(Child()) == 3
+
+
+def test_cascade_discovers_downstream_schema(connection_by_backend, db_creds_by_backend):
+ """Cascade delete discovers and includes tables in unloaded downstream schemas."""
+ import time
+
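+    # Millisecond-timestamp suffix keeps schema names unique across repeated runs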
+ backend = db_creds_by_backend["backend"]
+ test_id = str(int(time.time() * 1000))[-8:]
+
+ upstream_name = f"djtest_upstream_{backend}_{test_id}"[:64]
+ downstream_name = f"djtest_downstream_{backend}_{test_id}"[:64]
+
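+    # Backend-appropriate identifier quoting (backticks on MySQL, double quotes on PostgreSQL)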
+ qi = connection_by_backend.adapter.quote_identifier
+
+ # Clean up any previous runs
+ for name in (downstream_name, upstream_name):
+ try:
+ connection_by_backend.query(f"DROP DATABASE IF EXISTS {qi(name)}")
+ except Exception:
+ pass
+
+ # Create upstream schema and table
+ upstream = dj.Schema(upstream_name, connection=connection_by_backend)
+
+ @upstream
+ class Parent(dj.Manual):
+ definition = """
+ parent_id : int
+ ---
+ name : varchar(100)
+ """
+
+    # Create the downstream schema with an FK to upstream, as a separate schema object
+ downstream = dj.Schema(downstream_name, connection=connection_by_backend)
+
+ @downstream
+ class Child(dj.Manual):
+ definition = """
+ -> Parent
+ child_id : int
+ ---
+ data : varchar(100)
+ """
+
+ # Insert data
+ Parent.insert1(dict(parent_id=1, name="Alice"))
+ Child.insert1(dict(parent_id=1, child_id=1, data="row1"))
+ Child.insert1(dict(parent_id=1, child_id=2, data="row2"))
+
+ # Verify cascade preview discovers the downstream schema
+ counts = dj.Diagram.cascade(Parent & "parent_id=1").counts()
+ assert Parent.full_table_name in counts
+ assert Child.full_table_name in counts
+ assert counts[Child.full_table_name] == 2
+
+ # Verify actual delete cascades across schemas
+ (Parent & "parent_id=1").delete()
+ assert len(Parent()) == 0
+ assert len(Child()) == 0
+
+ # Clean up
+ for name in (downstream_name, upstream_name):
+ try:
+ connection_by_backend.query(f"DROP DATABASE IF EXISTS {qi(name)}")
+ except Exception:
+ pass
diff --git a/tests/integration/test_cli.py b/tests/integration/test_cli.py
index 35230ea4e..1f8144f0f 100644
--- a/tests/integration/test_cli.py
+++ b/tests/integration/test_cli.py
@@ -3,6 +3,7 @@
"""
import subprocess
+import sys
import pytest
@@ -31,7 +32,8 @@ def test_cli_help(capsys):
def test_cli_config():
process = subprocess.Popen(
- ["dj"],
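+        # Invoke the CLI through the current interpreter; the "dj" console script may not be on PATH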
+ [sys.executable, "-m", "datajoint.cli"],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
@@ -50,7 +51,7 @@ def test_cli_config():
def test_cli_args():
process = subprocess.Popen(
- ["dj", "-u", "test_user", "-p", "test_pass", "--host", "test_host"],
+ [sys.executable, "-m", "datajoint.cli", "-u", "test_user", "-p", "test_pass", "--host", "test_host"],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
@@ -82,7 +83,9 @@ class IJ(dj.Lookup):
# Pass credentials via CLI args to avoid prompting for username
process = subprocess.Popen(
[
- "dj",
+ sys.executable,
+ "-m",
+ "datajoint.cli",
"-u",
db_creds_root["user"],
"-p",
diff --git a/tests/integration/test_erd.py b/tests/integration/test_erd.py
index 95077da50..d746bf49e 100644
--- a/tests/integration/test_erd.py
+++ b/tests/integration/test_erd.py
@@ -1,6 +1,8 @@
+import pytest
+
import datajoint as dj
-from tests.schema_simple import LOCALS_SIMPLE, A, B, D, E, G, L
+from tests.schema_simple import LOCALS_SIMPLE, A, B, D, E, G, L, Profile, Website
def test_decorator(schema_simp):
@@ -61,3 +63,98 @@ def test_part_table_parsing(schema_simp):
graph = erd._make_graph()
assert "OutfitLaunch" in graph.nodes()
assert "OutfitLaunch.OutfitPiece" in graph.nodes()
+
+
+# --- prune() tests ---
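+# prune() removes zero-row tables from a diagram preview; cascade diagrams reject it (tested below).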
+
+
+@pytest.fixture
+def schema_simp_pop(schema_simp):
+ """Populate the simple schema for prune tests."""
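+    # Empty the tables child-first, then repopulate so row counts are deterministic for the prune tests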
+ Profile().delete()
+ Website().delete()
+ G().delete()
+ E().delete()
+ D().delete()
+ B().delete()
+ L().delete()
+ A().delete()
+
+ A().insert(A.contents, skip_duplicates=True)
+ L().insert(L.contents, skip_duplicates=True)
+ B().populate()
+ D().populate()
+ E().populate()
+ G().populate()
+ yield schema_simp
+
+
+def test_prune_unrestricted(schema_simp_pop):
+ """Prune on unrestricted diagram removes physically empty tables."""
+ diag = dj.Diagram(schema_simp_pop, context=LOCALS_SIMPLE)
+ original_count = len(diag.nodes_to_show)
+ pruned = diag.prune()
+
+ # Populated tables (A, L, B, B.C, D, E, E.F, G, etc.) should survive
+ for cls in (A, B, D, E, L):
+ assert cls.full_table_name in pruned.nodes_to_show, f"{cls.__name__} should not be pruned"
+
+ # Empty tables like Profile should be removed
+ assert Profile.full_table_name not in pruned.nodes_to_show, "empty Profile should be pruned"
+
+ # Pruned diagram should have fewer nodes
+ assert len(pruned.nodes_to_show) < original_count
+
+
+def test_prune_after_restrict(schema_simp_pop):
+ """Prune after restrict removes tables with zero matching rows."""
+ diag = dj.Diagram(schema_simp_pop, context=LOCALS_SIMPLE)
+ restricted = diag.restrict(A & "id_a=0")
+ counts = restricted.counts()
+
+ pruned = restricted.prune()
+ pruned_counts = pruned.counts()
+
+ # Every table in pruned preview should have > 0 rows
+ assert all(c > 0 for c in pruned_counts.values()), "pruned diagram should have no zero-count tables"
+
+ # Tables with zero rows in the original preview should be gone
+ for table, count in counts.items():
+ if count == 0:
+ assert table not in pruned._restrict_conditions, f"{table} had 0 rows but was not pruned"
+
+
+def test_prune_raises_on_cascade(schema_simp_pop):
+    """prune() raises on a cascade Diagram; a cascade must retain all tables for safe deletion."""
+ cascaded = dj.Diagram.cascade(A & "id_a=0")
+    with pytest.raises(dj.DataJointError, match="prune.*cannot be used.*cascade"):
+ cascaded.prune()
+
+
+def test_prune_idempotent(schema_simp_pop):
+ """Pruning twice gives the same result."""
+ diag = dj.Diagram(schema_simp_pop, context=LOCALS_SIMPLE)
+ restricted = diag.restrict(A & "id_a=0")
+ pruned_once = restricted.prune()
+ pruned_twice = pruned_once.prune()
+
+ assert pruned_once.nodes_to_show == pruned_twice.nodes_to_show
+ assert set(pruned_once._restrict_conditions) == set(pruned_twice._restrict_conditions)
+
+
+def test_prune_then_restrict(schema_simp_pop):
+ """Restrict can be called after prune."""
+ diag = dj.Diagram(schema_simp_pop, context=LOCALS_SIMPLE)
+ pruned = diag.restrict(A & "id_a < 5").prune()
+ # Restrict again on the same seed table with a tighter condition
+ further = pruned.restrict(A & "id_a=0")
+
+ # Should not raise; further restriction should narrow results
+ counts = further.counts()
+ assert all(c >= 0 for c in counts.values())
+ # Tighter restriction should produce fewer or equal rows
+ pruned_counts = pruned.counts()
+ for table in counts:
+ assert counts[table] <= pruned_counts.get(table, 0)