diff --git a/src/datajoint/adapters/base.py b/src/datajoint/adapters/base.py index 011f306ab..f5138d8c1 100644 --- a/src/datajoint/adapters/base.py +++ b/src/datajoint/adapters/base.py @@ -238,6 +238,39 @@ def parameter_placeholder(self) -> str: """ ... + def make_full_table_name(self, database: str, table_name: str) -> str: + """ + Construct a fully-qualified table name for this backend. + + Default implementation produces a two-part name (``schema.table``). + Backends that require additional namespace levels (e.g., Databricks + ``catalog.schema.table``) should override this method. + + Parameters + ---------- + database : str + Schema/database name. + table_name : str + Table name (including tier prefix). + + Returns + ------- + str + Fully-qualified, quoted table name. + """ + return f"{self.quote_identifier(database)}.{self.quote_identifier(table_name)}" + + @property + def foreign_key_action_clause(self) -> str: + """ + Referential action clause appended to FOREIGN KEY declarations. + + Default: ``ON UPDATE CASCADE ON DELETE RESTRICT`` (MySQL/PostgreSQL). + Backends that don't support referential actions (e.g., Databricks) + should override to return ``""``. + """ + return " ON UPDATE CASCADE ON DELETE RESTRICT" + # ========================================================================= # Type Mapping # ========================================================================= diff --git a/src/datajoint/autopopulate.py b/src/datajoint/autopopulate.py index 02410c9bc..0e0cbe866 100644 --- a/src/datajoint/autopopulate.py +++ b/src/datajoint/autopopulate.py @@ -776,9 +776,10 @@ def _update_job_metadata(self, key, start_time, duration, version): from .condition import make_condition pk_condition = make_condition(self, key, set()) + q = self.connection.adapter.quote_identifier self.connection.query( f"UPDATE {self.full_table_name} SET " - "`_job_start_time`=%s, `_job_duration`=%s, `_job_version`=%s " + f"{q('_job_start_time')}=%s, {q('_job_duration')}=%s, {q('_job_version')}=%s " f"WHERE {pk_condition}", args=(start_time, duration, version[:64] if version else ""), ) diff --git a/src/datajoint/codecs.py b/src/datajoint/codecs.py index d7fbaf42d..53b2956ab 100644 --- a/src/datajoint/codecs.py +++ b/src/datajoint/codecs.py @@ -557,7 +557,7 @@ def decode_attribute(attr, data, squeeze: bool = False, connection=None): # psycopg2 auto-deserializes JSON to dict/list; only parse strings if isinstance(data, str): data = json.loads(data) - elif final_dtype.lower() in ("longblob", "blob", "mediumblob", "tinyblob"): + elif final_dtype.lower() in ("longblob", "blob", "mediumblob", "tinyblob", "bytes", "binary"): pass # Blob data is already bytes elif final_dtype.lower() == "binary(16)": data = uuid_module.UUID(bytes=data) diff --git a/src/datajoint/connection.py b/src/datajoint/connection.py index e9eab0921..fd3bf35bc 100644 --- a/src/datajoint/connection.py +++ b/src/datajoint/connection.py @@ -486,19 +486,25 @@ def start_transaction(self) -> None: """ if self.in_transaction: raise errors.DataJointError("Nested connections are not supported.") - self.query(self.adapter.start_transaction_sql()) + sql = self.adapter.start_transaction_sql() + if sql: + self.query(sql) self._in_transaction = True logger.debug("Transaction started") def cancel_transaction(self) -> None: """Cancel the current transaction and roll back all changes.""" - self.query(self.adapter.rollback_sql()) + sql = self.adapter.rollback_sql() + if sql: + self.query(sql) self._in_transaction = False logger.debug("Transaction cancelled. Rolling back ...") def commit_transaction(self) -> None: """Commit all changes and close the transaction.""" - self.query(self.adapter.commit_sql()) + sql = self.adapter.commit_sql() + if sql: + self.query(sql) self._in_transaction = False logger.debug("Transaction committed and closed.") diff --git a/src/datajoint/declare.py b/src/datajoint/declare.py index 4edb0c22f..f598c2ac3 100644 --- a/src/datajoint/declare.py +++ b/src/datajoint/declare.py @@ -296,10 +296,10 @@ def compile_foreign_key( parent_full_name = ref.support[0] # Parse as database.table using the adapter's quoting convention parts = adapter.split_full_table_name(parent_full_name) - ref_table_name = f"{adapter.quote_identifier(parts[0])}.{adapter.quote_identifier(parts[1])}" + ref_table_name = adapter.make_full_table_name(parts[0], parts[1]) foreign_key_sql.append( - f"FOREIGN KEY ({fk_cols}) REFERENCES {ref_table_name} ({pk_cols}) ON UPDATE CASCADE ON DELETE RESTRICT" + f"FOREIGN KEY ({fk_cols}) REFERENCES {ref_table_name} ({pk_cols}){adapter.foreign_key_action_clause}" ) # declare unique index @@ -432,16 +432,8 @@ def declare( DataJointError If table name exceeds max length or has no primary key. """ - # Parse table name without assuming quote character - # Extract schema.table from quoted name using adapter - quote_char = adapter.quote_identifier("x")[0] # Get quote char from adapter - parts = full_table_name.split(".") - if len(parts) == 2: - schema_name = parts[0].strip(quote_char) - table_name = parts[1].strip(quote_char) - else: - schema_name = None - table_name = parts[0].strip(quote_char) + # Parse table name using adapter (handles 2-part and 3-part names) + schema_name, table_name = adapter.split_full_table_name(full_table_name) if len(table_name) > MAX_TABLE_NAME_LENGTH: raise DataJointError( @@ -924,7 +916,7 @@ def compile_attribute( # Check for invalid default values on blob types (after type substitution) # Note: blob → longblob, so check for NATIVE_BLOB or longblob result final_type = match["type"].lower() - if ("blob" in final_type) and match["default"] not in {"DEFAULT NULL", "NOT NULL"}: + if ("blob" in final_type or final_type == "binary") and match["default"] not in {"DEFAULT NULL", "NOT NULL"}: raise DataJointError("The default value for blob attributes can only be NULL in:\n{line}".format(line=line)) # Use adapter to format column definition diff --git a/src/datajoint/lineage.py b/src/datajoint/lineage.py index bb911a876..bd821404c 100644 --- a/src/datajoint/lineage.py +++ b/src/datajoint/lineage.py @@ -79,14 +79,13 @@ def lineage_table_exists(connection, database): bool True if the table exists, False otherwise. """ - result = connection.query( - """ - SELECT COUNT(*) FROM information_schema.tables - WHERE table_schema = %s AND table_name = '~lineage' - """, - args=(database,), - ).fetchone() - return result[0] > 0 + try: + result = connection.query( + connection.adapter.get_table_info_sql(database, "~lineage") + ).fetchone() + return result is not None + except Exception: + return False def get_lineage(connection, database, table_name, attribute_name): diff --git a/src/datajoint/schemas.py b/src/datajoint/schemas.py index 8747cdbf2..272c59fba 100644 --- a/src/datajoint/schemas.py +++ b/src/datajoint/schemas.py @@ -344,7 +344,7 @@ def make_classes(self, into: dict[str, Any] | None = None) -> None: tables = [ row[0] for row in self.connection.query(self.connection.adapter.list_tables_sql(self.database)) - if lookup_class_name("`{db}`.`{tab}`".format(db=self.database, tab=row[0]), into, 0) is None + if lookup_class_name(self.connection.adapter.make_full_table_name(self.database, row[0]), into, 0) is None ] master_classes = (Lookup, Manual, Imported, Computed) part_tables = [] @@ -421,7 +421,8 @@ def exists(self) -> bool: """ if self.database is None: raise DataJointError("Schema must be activated first.") - return bool(self.connection.query(self.connection.adapter.schema_exists_sql(self.database)).rowcount) + result = self.connection.query(self.connection.adapter.schema_exists_sql(self.database)) + return result.fetchone() is not None @property def lineage_table_exists(self) -> bool: @@ -502,7 +503,7 @@ def jobs(self) -> list[Job]: # Iterate over auto-populated tables and check if their job table exists for table_name in self.list_tables(): adapter = self.connection.adapter - full_name = f"{adapter.quote_identifier(self.database)}." f"{adapter.quote_identifier(table_name)}" + full_name = adapter.make_full_table_name(self.database, table_name) table = FreeTable(self.connection, full_name) tier = _get_tier(table.full_table_name) if tier in (Computed, Imported): @@ -603,7 +604,7 @@ def get_table(self, name: str) -> FreeTable: raise DataJointError(f"Table `{name}` does not exist in schema `{self.database}`.") adapter = self.connection.adapter - full_name = f"{adapter.quote_identifier(self.database)}.{adapter.quote_identifier(table_name)}" + full_name = adapter.make_full_table_name(self.database, table_name) return FreeTable(self.connection, full_name) def __getitem__(self, name: str) -> FreeTable: diff --git a/src/datajoint/settings.py b/src/datajoint/settings.py index 7019d8345..dd85c3866 100644 --- a/src/datajoint/settings.py +++ b/src/datajoint/settings.py @@ -190,7 +190,7 @@ class DatabaseSettings(BaseSettings): host: str = Field(default="localhost", validation_alias="DJ_HOST") user: str | None = Field(default=None, validation_alias="DJ_USER") password: SecretStr | None = Field(default=None, validation_alias="DJ_PASS") - backend: Literal["mysql", "postgresql"] = Field( + backend: Literal["mysql", "postgresql", "databricks"] = Field( default="mysql", validation_alias="DJ_BACKEND", description="Database backend: 'mysql' or 'postgresql'", diff --git a/src/datajoint/table.py b/src/datajoint/table.py index 256fab6e9..aa7374218 100644 --- a/src/datajoint/table.py +++ b/src/datajoint/table.py @@ -457,7 +457,7 @@ def is_declared(self): True if the table is declared in the schema. """ query = self.connection.adapter.get_table_info_sql(self.database, self.table_name) - return self.connection.query(query).rowcount > 0 + return self.connection.query(query).fetchone() is not None @property def full_table_name(self): @@ -474,7 +474,7 @@ def full_table_name(self): f"Class {self.__class__.__name__} is not associated with a schema. " "Apply a schema decorator or use schema() to bind it." ) - return f"{self.adapter.quote_identifier(self.database)}.{self.adapter.quote_identifier(self.table_name)}" + return self.adapter.make_full_table_name(self.database, self.table_name) @property def adapter(self): diff --git a/src/datajoint/user_tables.py b/src/datajoint/user_tables.py index 4c2ba8d4c..7822fa9e2 100644 --- a/src/datajoint/user_tables.py +++ b/src/datajoint/user_tables.py @@ -106,8 +106,7 @@ def full_table_name(cls): """The fully qualified table name (quoted per backend).""" if cls.database is None: return None - adapter = cls._connection.adapter - return f"{adapter.quote_identifier(cls.database)}.{adapter.quote_identifier(cls.table_name)}" + return cls._connection.adapter.make_full_table_name(cls.database, cls.table_name) class UserTable(Table, metaclass=TableMeta): @@ -186,8 +185,7 @@ def full_table_name(cls): """The fully qualified table name (quoted per backend).""" if cls.database is None or cls.table_name is None: return None - adapter = cls._connection.adapter - return f"{adapter.quote_identifier(cls.database)}.{adapter.quote_identifier(cls.table_name)}" + return cls._connection.adapter.make_full_table_name(cls.database, cls.table_name) @property def master(cls):