Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/datajoint/adapters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,39 @@ def parameter_placeholder(self) -> str:
"""
...

def make_full_table_name(self, database: str, table_name: str) -> str:
"""
Construct a fully-qualified table name for this backend.

Default implementation produces a two-part name (``schema.table``).
Backends that require additional namespace levels (e.g., Databricks
``catalog.schema.table``) should override this method.

Parameters
----------
database : str
Schema/database name.
table_name : str
Table name (including tier prefix).

Returns
-------
str
Fully-qualified, quoted table name.
"""
return f"{self.quote_identifier(database)}.{self.quote_identifier(table_name)}"

@property
def foreign_key_action_clause(self) -> str:
"""
Referential action clause appended to FOREIGN KEY declarations.

Default: ``ON UPDATE CASCADE ON DELETE RESTRICT`` (MySQL/PostgreSQL).
Backends that don't support referential actions (e.g., Databricks)
should override to return ``""``.
"""
return " ON UPDATE CASCADE ON DELETE RESTRICT"

# =========================================================================
# Type Mapping
# =========================================================================
Expand Down
3 changes: 2 additions & 1 deletion src/datajoint/autopopulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,9 +776,10 @@ def _update_job_metadata(self, key, start_time, duration, version):
from .condition import make_condition

pk_condition = make_condition(self, key, set())
q = self.connection.adapter.quote_identifier
self.connection.query(
f"UPDATE {self.full_table_name} SET "
"`_job_start_time`=%s, `_job_duration`=%s, `_job_version`=%s "
f"{q('_job_start_time')}=%s, {q('_job_duration')}=%s, {q('_job_version')}=%s "
f"WHERE {pk_condition}",
args=(start_time, duration, version[:64] if version else ""),
)
2 changes: 1 addition & 1 deletion src/datajoint/codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ def decode_attribute(attr, data, squeeze: bool = False, connection=None):
# psycopg2 auto-deserializes JSON to dict/list; only parse strings
if isinstance(data, str):
data = json.loads(data)
elif final_dtype.lower() in ("longblob", "blob", "mediumblob", "tinyblob"):
elif final_dtype.lower() in ("longblob", "blob", "mediumblob", "tinyblob", "bytes", "binary"):
pass # Blob data is already bytes
elif final_dtype.lower() == "binary(16)":
data = uuid_module.UUID(bytes=data)
Expand Down
12 changes: 9 additions & 3 deletions src/datajoint/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,19 +486,25 @@ def start_transaction(self) -> None:
"""
if self.in_transaction:
raise errors.DataJointError("Nested connections are not supported.")
self.query(self.adapter.start_transaction_sql())
sql = self.adapter.start_transaction_sql()
if sql:
self.query(sql)
self._in_transaction = True
logger.debug("Transaction started")

def cancel_transaction(self) -> None:
"""Cancel the current transaction and roll back all changes."""
self.query(self.adapter.rollback_sql())
sql = self.adapter.rollback_sql()
if sql:
self.query(sql)
self._in_transaction = False
logger.debug("Transaction cancelled. Rolling back ...")

def commit_transaction(self) -> None:
"""Commit all changes and close the transaction."""
self.query(self.adapter.commit_sql())
sql = self.adapter.commit_sql()
if sql:
self.query(sql)
self._in_transaction = False
logger.debug("Transaction committed and closed.")

Expand Down
18 changes: 5 additions & 13 deletions src/datajoint/declare.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,10 +296,10 @@ def compile_foreign_key(
parent_full_name = ref.support[0]
# Parse as database.table using the adapter's quoting convention
parts = adapter.split_full_table_name(parent_full_name)
ref_table_name = f"{adapter.quote_identifier(parts[0])}.{adapter.quote_identifier(parts[1])}"
ref_table_name = adapter.make_full_table_name(parts[0], parts[1])

foreign_key_sql.append(
f"FOREIGN KEY ({fk_cols}) REFERENCES {ref_table_name} ({pk_cols}) ON UPDATE CASCADE ON DELETE RESTRICT"
f"FOREIGN KEY ({fk_cols}) REFERENCES {ref_table_name} ({pk_cols}){adapter.foreign_key_action_clause}"
)

# declare unique index
Expand Down Expand Up @@ -432,16 +432,8 @@ def declare(
DataJointError
If table name exceeds max length or has no primary key.
"""
# Parse table name without assuming quote character
# Extract schema.table from quoted name using adapter
quote_char = adapter.quote_identifier("x")[0] # Get quote char from adapter
parts = full_table_name.split(".")
if len(parts) == 2:
schema_name = parts[0].strip(quote_char)
table_name = parts[1].strip(quote_char)
else:
schema_name = None
table_name = parts[0].strip(quote_char)
# Parse table name using adapter (handles 2-part and 3-part names)
schema_name, table_name = adapter.split_full_table_name(full_table_name)

if len(table_name) > MAX_TABLE_NAME_LENGTH:
raise DataJointError(
Expand Down Expand Up @@ -924,7 +916,7 @@ def compile_attribute(
# Check for invalid default values on blob types (after type substitution)
# Note: blob → longblob, so check for NATIVE_BLOB or longblob result
final_type = match["type"].lower()
if ("blob" in final_type) and match["default"] not in {"DEFAULT NULL", "NOT NULL"}:
if ("blob" in final_type or final_type == "binary") and match["default"] not in {"DEFAULT NULL", "NOT NULL"}:
raise DataJointError("The default value for blob attributes can only be NULL in:\n{line}".format(line=line))

# Use adapter to format column definition
Expand Down
15 changes: 7 additions & 8 deletions src/datajoint/lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,13 @@ def lineage_table_exists(connection, database):
bool
True if the table exists, False otherwise.
"""
result = connection.query(
"""
SELECT COUNT(*) FROM information_schema.tables
WHERE table_schema = %s AND table_name = '~lineage'
""",
args=(database,),
).fetchone()
return result[0] > 0
try:
result = connection.query(
connection.adapter.get_table_info_sql(database, "~lineage")
).fetchone()
return result is not None
except Exception:
return False


def get_lineage(connection, database, table_name, attribute_name):
Expand Down
9 changes: 5 additions & 4 deletions src/datajoint/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def make_classes(self, into: dict[str, Any] | None = None) -> None:
tables = [
row[0]
for row in self.connection.query(self.connection.adapter.list_tables_sql(self.database))
if lookup_class_name("`{db}`.`{tab}`".format(db=self.database, tab=row[0]), into, 0) is None
if lookup_class_name(self.connection.adapter.make_full_table_name(self.database, row[0]), into, 0) is None
]
master_classes = (Lookup, Manual, Imported, Computed)
part_tables = []
Expand Down Expand Up @@ -421,7 +421,8 @@ def exists(self) -> bool:
"""
if self.database is None:
raise DataJointError("Schema must be activated first.")
return bool(self.connection.query(self.connection.adapter.schema_exists_sql(self.database)).rowcount)
result = self.connection.query(self.connection.adapter.schema_exists_sql(self.database))
return result.fetchone() is not None

@property
def lineage_table_exists(self) -> bool:
Expand Down Expand Up @@ -502,7 +503,7 @@ def jobs(self) -> list[Job]:
# Iterate over auto-populated tables and check if their job table exists
for table_name in self.list_tables():
adapter = self.connection.adapter
full_name = f"{adapter.quote_identifier(self.database)}." f"{adapter.quote_identifier(table_name)}"
full_name = adapter.make_full_table_name(self.database, table_name)
table = FreeTable(self.connection, full_name)
tier = _get_tier(table.full_table_name)
if tier in (Computed, Imported):
Expand Down Expand Up @@ -603,7 +604,7 @@ def get_table(self, name: str) -> FreeTable:
raise DataJointError(f"Table `{name}` does not exist in schema `{self.database}`.")

adapter = self.connection.adapter
full_name = f"{adapter.quote_identifier(self.database)}.{adapter.quote_identifier(table_name)}"
full_name = adapter.make_full_table_name(self.database, table_name)
return FreeTable(self.connection, full_name)

def __getitem__(self, name: str) -> FreeTable:
Expand Down
2 changes: 1 addition & 1 deletion src/datajoint/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ class DatabaseSettings(BaseSettings):
host: str = Field(default="localhost", validation_alias="DJ_HOST")
user: str | None = Field(default=None, validation_alias="DJ_USER")
password: SecretStr | None = Field(default=None, validation_alias="DJ_PASS")
backend: Literal["mysql", "postgresql"] = Field(
backend: Literal["mysql", "postgresql", "databricks"] = Field(
default="mysql",
validation_alias="DJ_BACKEND",
description="Database backend: 'mysql' or 'postgresql'",
Expand Down
4 changes: 2 additions & 2 deletions src/datajoint/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def is_declared(self):
True if the table is declared in the schema.
"""
query = self.connection.adapter.get_table_info_sql(self.database, self.table_name)
return self.connection.query(query).rowcount > 0
return self.connection.query(query).fetchone() is not None

@property
def full_table_name(self):
Expand All @@ -474,7 +474,7 @@ def full_table_name(self):
f"Class {self.__class__.__name__} is not associated with a schema. "
"Apply a schema decorator or use schema() to bind it."
)
return f"{self.adapter.quote_identifier(self.database)}.{self.adapter.quote_identifier(self.table_name)}"
return self.adapter.make_full_table_name(self.database, self.table_name)

@property
def adapter(self):
Expand Down
6 changes: 2 additions & 4 deletions src/datajoint/user_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ def full_table_name(cls):
"""The fully qualified table name (quoted per backend)."""
if cls.database is None:
return None
adapter = cls._connection.adapter
return f"{adapter.quote_identifier(cls.database)}.{adapter.quote_identifier(cls.table_name)}"
return cls._connection.adapter.make_full_table_name(cls.database, cls.table_name)


class UserTable(Table, metaclass=TableMeta):
Expand Down Expand Up @@ -186,8 +185,7 @@ def full_table_name(cls):
"""The fully qualified table name (quoted per backend)."""
if cls.database is None or cls.table_name is None:
return None
adapter = cls._connection.adapter
return f"{adapter.quote_identifier(cls.database)}.{adapter.quote_identifier(cls.table_name)}"
return cls._connection.adapter.make_full_table_name(cls.database, cls.table_name)

@property
def master(cls):
Expand Down
Loading