Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,31 @@ run = data_contract.test()
if not run.has_passed():
print("Data quality validation failed.")
# Abort pipeline, alert, or take corrective actions...

# run quality checks with DQX engine (Databricks server type only)
# requires: pip install datacontract-cli[dqx]
data_contract_dqx = DataContract(
data_contract_file="odcs.yaml",
server="production",
test_engine="dqx",
)
run_dqx = data_contract_dqx.test()

# access all executed DQX rule checks
all_checks = run_dqx.checks

# access failed/error/warning DQX rule checks
failed_checks = [c for c in run_dqx.checks if c.result in ("failed", "error", "warning")]

for check in failed_checks:
print(f"[{check.result}] {check.name} ({check.model})")
print(f"reason: {check.reason}")

# full structured output
print(run_dqx.pretty())

if not run_dqx.has_passed():
print("DQX data quality validation failed.")
```

## How to
Expand Down Expand Up @@ -236,6 +261,7 @@ A list of available extras:
| Avro Support | `pip install datacontract-cli[avro]` |
| Google BigQuery | `pip install datacontract-cli[bigquery]` |
| Databricks Integration | `pip install datacontract-cli[databricks]` |
| DQX (Databricks quality checks) | `pip install datacontract-cli[dqx]` |
| DuckDB (local/S3/GCS/Azure file testing) | `pip install datacontract-cli[duckdb]` |
| Iceberg | `pip install datacontract-cli[iceberg]` |
| Kafka Integration | `pip install datacontract-cli[kafka]` |
Expand Down
9 changes: 8 additions & 1 deletion datacontract/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,11 +323,18 @@ async def test(
examples=["https://api.datamesh-manager.com/api/test-results"],
),
] = None,
test_engine: Annotated[
str,
Query(
description="The engine used for quality checks. Supported values: soda (default), dqx (Databricks only).",
examples=["soda", "dqx"],
),
] = "soda",
) -> Run:
check_api_key(api_key)
logging.info("Testing data contract...")
logging.info(body)
return DataContract(data_contract_str=body, server=server, publish_url=publish_url).test()
return DataContract(data_contract_str=body, server=server, publish_url=publish_url, test_engine=test_engine).test()


@app.post(
Expand Down
7 changes: 7 additions & 0 deletions datacontract/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,12 @@ def test(
typer.Option(help="SSL verification when publishing the data contract."),
] = True,
debug: debug_option = None,
test_engine: Annotated[
str,
typer.Option(
help="The engine used for quality checks. Supported values: `soda` (default), `dqx` (Databricks only)."
),
] = "soda",
):
"""
Run schema and quality tests on configured servers.
Expand All @@ -177,6 +183,7 @@ def test(
publish_url=publish,
server=server,
ssl_verification=ssl_verification,
test_engine=test_engine,
).test()
if logs:
_print_logs(run)
Expand Down
11 changes: 10 additions & 1 deletion datacontract/data_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
inline_definitions: bool = True,
ssl_verification: bool = True,
publish_test_results: bool = False,
test_engine: str = "soda",
):
self._data_contract_file = data_contract_file
self._data_contract_str = data_contract_str
Expand All @@ -44,6 +45,7 @@ def __init__(
self._duckdb_connection = duckdb_connection
self._inline_definitions = inline_definitions
self._ssl_verification = ssl_verification
self._test_engine = test_engine

@classmethod
def init(cls, template: typing.Optional[str], schema: typing.Optional[str] = None) -> OpenDataContractStandard:
Expand Down Expand Up @@ -103,7 +105,14 @@ def test(self) -> Run:
inline_definitions=self._inline_definitions,
)

execute_data_contract_test(data_contract, run, self._server, self._spark, self._duckdb_connection)
execute_data_contract_test(
data_contract,
run,
self._server,
self._spark,
self._duckdb_connection,
self._test_engine,
)

except DataContractException as e:
run.checks.append(
Expand Down
25 changes: 21 additions & 4 deletions datacontract/engines/data_contract_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from datacontract.engines.datacontract.check_that_datacontract_contains_valid_servers_configuration import (
check_that_datacontract_contains_valid_server_configuration,
)
from datacontract.engines.dqx.check_dqx_execute import check_dqx_execute
from datacontract.engines.fastjsonschema.check_jsonschema import check_jsonschema
from datacontract.engines.soda.check_soda_execute import check_soda_execute
from datacontract.model.exceptions import DataContractException
Expand All @@ -27,6 +28,7 @@ def execute_data_contract_test(
server_name: str = None,
spark: "SparkSession" = None,
duckdb_connection: "DuckDBPyConnection" = None,
test_engine: str = "soda",
):
if data_contract.schema_ is None or len(data_contract.schema_) == 0:
raise DataContractException(
Expand All @@ -53,13 +55,28 @@ def execute_data_contract_test(
if server.type == "api":
server = process_api_response(run, server)

run.checks.extend(create_checks(data_contract, server))
normalized_test_engine = test_engine.lower() if test_engine is not None else "soda"

if normalized_test_engine not in ["soda", "dqx"]:
raise DataContractException(
type="test",
name="Check that test engine is supported",
result=ResultEnum.error,
reason=f"Unsupported test engine '{test_engine}'. Supported values are: soda, dqx.",
engine="datacontract",
)

if normalized_test_engine == "soda":
run.checks.extend(create_checks(data_contract, server))

# TODO check server is supported type for nicer error messages
# TODO check server credentials are complete for nicer error messages
if server.format == "json" and server.type != "kafka":
check_jsonschema(run, data_contract, server)
check_soda_execute(run, data_contract, server, spark, duckdb_connection)
if normalized_test_engine == "dqx":
check_dqx_execute(run, data_contract, server, spark)
else:
if server.format == "json" and server.type != "kafka":
check_jsonschema(run, data_contract, server)
check_soda_execute(run, data_contract, server, spark, duckdb_connection)


def get_server(data_contract: OpenDataContractStandard, server_name: str = None) -> Server | None:
Expand Down
Empty file.
201 changes: 201 additions & 0 deletions datacontract/engines/dqx/check_dqx_execute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
import typing
import uuid

import yaml
from open_data_contract_standard.model import OpenDataContractStandard, Server

if typing.TYPE_CHECKING:
from pyspark.sql import SparkSession

from datacontract.engines.data_contract_checks import to_schema_name
from datacontract.export.dqx_exporter import extract_quality_rules
from datacontract.model.run import Check, ResultEnum, Run


def _get_rule_check_name(dqx_rule: dict, index: int) -> str:
check_metadata = dqx_rule.get("check") or {}
function_name = check_metadata.get("function") or "dqx_rule"
return dqx_rule.get("name") or f"{function_name}_{index + 1}"


def _build_dqx_check(schema_name: str, check_name: str, dqx_rule: dict, result: ResultEnum, reason: str) -> Check:
    """Build one Run check entry describing the outcome of a single DQX rule.

    The rule itself is serialized to YAML and stored as the check
    implementation so the executed rule is visible in the run output.
    """
    return Check(
        id=str(uuid.uuid4()),
        key=f"{schema_name}__{check_name}__dqx",
        category="quality",
        type="custom",
        name=check_name,
        model=schema_name,
        engine="dqx",
        language="python",
        implementation=yaml.dump(dqx_rule, sort_keys=False),
        result=result,
        reason=reason,
    )


def _append_error_checks(run: Run, rules_by_schema: list, reason: str) -> None:
    """Append an error-result check for every collected rule.

    Used when DQX cannot execute at all (missing dependencies, no Spark
    session) so each declared rule is still reported in the run.
    """
    for schema_name, dqx_rules in rules_by_schema:
        for index, dqx_rule in enumerate(dqx_rules):
            check_name = _get_rule_check_name(dqx_rule, index)
            run.checks.append(_build_dqx_check(schema_name, check_name, dqx_rule, ResultEnum.error, reason))


def check_dqx_execute(
    run: Run,
    data_contract: OpenDataContractStandard,
    server: Server,
    spark: "SparkSession" = None,
):
    """Execute the contract's DQX quality rules against Databricks tables.

    For each schema that declares DQX rules, reads the corresponding table
    and applies each rule individually, appending one Check per rule to
    ``run.checks`` (passed / warning / failed / error).

    Args:
        run: The run object that collects checks and log messages.
        data_contract: The parsed data contract; may be None if invalid.
        server: The server to test against; only type 'databricks' is supported.
        spark: Optional pre-existing Spark session (programmatic API).
    """
    if data_contract is None:
        run.log_warn("Cannot run engine dqx, as data contract is invalid")
        return

    # DQX reads Databricks tables; any other server type is a no-op.
    if server.type != "databricks":
        run.log_info(
            f"DQX execution is only available for server type 'databricks'. "
            f"Configured server type is '{server.type}'. Skipping DQX checks."
        )
        return

    # Collect (schema_name, rules) pairs for every model that declares DQX rules.
    rules_by_schema: list[tuple[str, list[dict]]] = []
    for schema_obj in data_contract.schema_ or []:
        schema_name = to_schema_name(schema_obj, server.type)
        dqx_rules = extract_quality_rules(schema_obj)
        if dqx_rules:
            rules_by_schema.append((schema_name, dqx_rules))

    if not rules_by_schema:
        # Nothing to execute — avoid creating a workspace client or Spark session.
        run.log_info("No DQX quality rules found in data contract. Skipping DQX checks.")
        return

    try:
        from databricks.labs.dqx.engine import DQEngine
        from databricks.sdk import WorkspaceClient
        from pyspark.sql import SparkSession
    except ImportError:
        run.log_warn(
            "Cannot run engine dqx, dependencies are missing. "
            "Install datacontract-cli[dqx] to enable DQX execution."
        )
        _append_error_checks(
            run,
            rules_by_schema,
            "DQX dependencies are missing. Install datacontract-cli[dqx] to enable DQX execution.",
        )
        return

    # Resolve or create a Spark session.
    # Priority:
    # 1. Explicitly provided Spark session (programmatic API).
    # 2. Existing active session (e.g. running on Databricks cluster).
    # 3. Databricks Connect session (if databricks-connect is installed).
    # 4. Fallback SparkSession.builder.getOrCreate() (e.g. local Spark).
    spark_session = spark or SparkSession.getActiveSession()

    if spark_session is None:
        # Try Databricks Connect first (optional dependency).
        try:
            from databricks.connect import DatabricksSession  # type: ignore[import-not-found]

            run.log_info("Creating Spark session via Databricks Connect (DatabricksSession).")
            spark_session = DatabricksSession.builder.getOrCreate()
        except Exception:
            spark_session = None

    if spark_session is None:
        try:
            run.log_info("Creating Spark session via SparkSession.builder.getOrCreate().")
            spark_session = SparkSession.builder.getOrCreate()
        except Exception:
            spark_session = None

    if spark_session is None:
        run.log_warn("Cannot run engine dqx, as no active Spark session is available.")
        _append_error_checks(run, rules_by_schema, "No active Spark session is available to execute DQX checks.")
        return

    run.log_info("Running engine dqx")
    dq_engine = DQEngine(workspace_client=WorkspaceClient(), spark=spark_session)

    for schema_name, dqx_rules in rules_by_schema:
        run.log_info(f"Running {len(dqx_rules)} DQX checks for model {schema_name}")

        try:
            model_df = spark_session.read.table(schema_name)
        except Exception as exc:
            # Table cannot be read: report every rule of this model as errored.
            run.log_error(str(exc))
            for index, dqx_rule in enumerate(dqx_rules):
                check_name = _get_rule_check_name(dqx_rule, index)
                run.checks.append(_build_dqx_check(schema_name, check_name, dqx_rule, ResultEnum.error, str(exc)))
            continue

        for index, dqx_rule in enumerate(dqx_rules):
            check_name = _get_rule_check_name(dqx_rule, index)

            try:
                # Apply the single rule and split rows into passing/violating sets.
                passed_df, invalid_df = dq_engine.apply_checks_by_metadata_and_split(model_df, [dqx_rule])
                violations = invalid_df.count()

                criticality = str(dqx_rule.get("criticality", "error")).lower()
                if violations == 0:
                    result = ResultEnum.passed
                    reason = f"all {passed_df.count()} row(s) passed the DQX rule"
                elif criticality == "warn":
                    # 'warn' criticality downgrades violations to a warning result.
                    result = ResultEnum.warning
                    reason = f"{violations} row(s) violated the warning DQX rule"
                else:
                    result = ResultEnum.failed
                    reason = f"{violations} row(s) violated the DQX rule"

                run.checks.append(_build_dqx_check(schema_name, check_name, dqx_rule, result, reason))
            except Exception as exc:
                run.log_error(str(exc))
                run.checks.append(_build_dqx_check(schema_name, check_name, dqx_rule, ResultEnum.error, str(exc)))
7 changes: 7 additions & 0 deletions datacontract/export/dqx_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ def process_quality_rule(rule: DataQuality, column_name: str) -> Dict[str, Any]:
implementation = rule.implementation
check = implementation[DqxKeys.CHECK]

# Ensure each rule has a stable name so that DQX doesn't
# try to infer it from the Spark column expression (which can
# trigger issues with certain Spark Connect representations).
if "name" not in implementation:
function_name = check.get(DqxKeys.FUNCTION, "dqx_rule")
implementation["name"] = f"{column_name}__{function_name}" if column_name else function_name

if column_name:
arguments = check.setdefault(DqxKeys.ARGUMENTS, {})

Expand Down
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ databricks = [
"numpy>=1.26.4,<2.0.0", # pyspark incompatible with numpy 2.0
]

dqx = [
"databricks-labs-dqx[datacontract]>=0.13.0,<1.0.0",
]

iceberg = [
"pyiceberg==0.10.0"
]
Expand Down Expand Up @@ -134,7 +138,7 @@ protobuf = [
]

all = [
"datacontract-cli[kafka,bigquery,csv,excel,snowflake,postgres,databricks,sqlserver,s3,athena,trino,dbt,dbml,duckdb,iceberg,parquet,rdf,api,protobuf,oracle]"
"datacontract-cli[kafka,bigquery,csv,excel,snowflake,postgres,databricks,dqx,sqlserver,s3,athena,trino,dbt,dbml,duckdb,iceberg,parquet,rdf,api,protobuf,oracle]"
]

# for development, we pin all libraries to an exact version
Expand Down
Loading
Loading