diff --git a/knip.json b/knip.json index fae5b9c1..f4092643 100644 --- a/knip.json +++ b/knip.json @@ -3,6 +3,7 @@ "ignoreWorkspaces": [ "packages/shared", "packages/lakebase", + "packages/appkit-py", "apps/**", "docs" ], @@ -15,7 +16,9 @@ "**/*.example.tsx", "template/**", "tools/**", - "docs/**" + "docs/**", + "client/**", + "test-e2e-minimal.ts" ], "ignoreBinaries": ["tarball"] } diff --git a/packages/appkit-py/.gitignore b/packages/appkit-py/.gitignore new file mode 100644 index 00000000..6719fa2a --- /dev/null +++ b/packages/appkit-py/.gitignore @@ -0,0 +1,25 @@ +# Python +__pycache__/ +*.py[cod] +*.egg-info/ +*.egg +dist/ +build/ +.eggs/ + +# Virtual environment +.venv/ +venv/ + +# IDE +.idea/ +.vscode/ +*.swp + +# Testing +.pytest_cache/ +htmlcov/ +.coverage + +# OS +.DS_Store diff --git a/packages/appkit-py/pyproject.toml b/packages/appkit-py/pyproject.toml new file mode 100644 index 00000000..19ca8d47 --- /dev/null +++ b/packages/appkit-py/pyproject.toml @@ -0,0 +1,48 @@ +[project] +name = "appkit-py" +version = "0.1.0" +description = "Python backend for Databricks AppKit — 100% API compatible with the TypeScript version" +requires-python = ">=3.12" +dependencies = [ + "fastapi>=0.115", + "uvicorn[standard]>=0.30", + "starlette>=0.40", + "databricks-sdk>=0.30", + "pyarrow>=14.0", + "httpx>=0.27", + "pydantic>=2.0", + "cachetools>=5.3", + "python-dotenv>=1.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0", + "pytest-asyncio>=0.23", + "httpx>=0.27", + "pytest-cov>=5.0", + "ruff>=0.5", + "mypy>=1.10", +] + +[build-system] +requires = ["setuptools>=68.0"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] +markers = [ + "integration: marks tests that require a running backend server", + "unit: marks unit tests that run in isolation", +] + +[tool.ruff] +target-version = "py312" +line-length = 100 + +[tool.ruff.lint] +select = ["E", "F", "I", "W"] diff --git a/packages/appkit-py/src/appkit_py/__init__.py b/packages/appkit-py/src/appkit_py/__init__.py new file mode 100644 index 00000000..fc431487 --- /dev/null +++ b/packages/appkit-py/src/appkit_py/__init__.py @@ -0,0 +1 @@ +"""Python backend for Databricks AppKit — 100% API compatible with the TypeScript version.""" diff --git a/packages/appkit-py/src/appkit_py/__main__.py b/packages/appkit-py/src/appkit_py/__main__.py new file mode 100644 index 00000000..6f18a45f --- /dev/null +++ b/packages/appkit-py/src/appkit_py/__main__.py @@ -0,0 +1,25 @@ +"""Entry point for running the AppKit Python backend with `python -m appkit_py`.""" + +import os + +from dotenv import load_dotenv + + +def main() -> None: + load_dotenv() + + import uvicorn + + from appkit_py.server import create_server + + # Match TS AppKit env vars for compatibility + host = os.environ.get("FLASK_RUN_HOST", os.environ.get("APPKIT_HOST", "0.0.0.0")) + port = int(os.environ.get("DATABRICKS_APP_PORT", "8000")) + log_level = os.environ.get("APPKIT_LOG_LEVEL", "info") + + app = create_server() + uvicorn.run(app, host=host, port=port, log_level=log_level) + + +if __name__ == "__main__": + main() diff --git a/packages/appkit-py/src/appkit_py/app/__init__.py b/packages/appkit-py/src/appkit_py/app/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/appkit-py/src/appkit_py/cache/__init__.py b/packages/appkit-py/src/appkit_py/cache/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/appkit-py/src/appkit_py/cache/cache_manager.py b/packages/appkit-py/src/appkit_py/cache/cache_manager.py new file mode 100644 index 00000000..95c45d0d --- /dev/null +++ b/packages/appkit-py/src/appkit_py/cache/cache_manager.py @@ -0,0 +1,86 @@ +"""CacheManager with TTL-based in-memory caching. + +Mirrors the TypeScript CacheManager from packages/appkit/src/cache/index.ts. +""" + +from __future__ import annotations + +import hashlib +import json +import time +from typing import Any, Awaitable, Callable, TypeVar + +T = TypeVar("T") + + +class CacheManager: + """In-memory TTL cache with SHA256 key generation.""" + + _instance: CacheManager | None = None + + def __init__(self) -> None: + self._store: dict[str, tuple[Any, float]] = {} # key -> (value, expires_at) + + @classmethod + def get_instance(cls) -> CacheManager: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + @classmethod + def get_instance_sync(cls) -> CacheManager: + return cls.get_instance() + + @classmethod + def reset(cls) -> None: + cls._instance = None + + def generate_key(self, parts: list[Any], user_key: str) -> str: + """Generate a SHA256 cache key from parts and user key.""" + raw = json.dumps([user_key] + [str(p) for p in parts], sort_keys=True) + return hashlib.sha256(raw.encode()).hexdigest() + + async def get_or_execute( + self, + key_parts: list[Any], + fn: Callable[[], Awaitable[T]], + user_key: str, + ttl: float = 300, + ) -> T: + """Get cached value or execute function and cache the result.""" + cache_key = self.generate_key(key_parts, user_key) + + # Check cache + if cache_key in self._store: + value, expires_at = self._store[cache_key] + if time.time() < expires_at: + return value + else: + del self._store[cache_key] + + # Execute and cache + result = await fn() + self._store[cache_key] = (result, time.time() + ttl) + return result + + def get(self, key: str) -> Any | None: + if key in self._store: + value, expires_at = self._store[key] + if time.time() < expires_at: + return value + del self._store[key] + return None + + def set(self, key: str, value: Any, ttl: float = 300) -> None: + self._store[key] = (value, time.time() + ttl) + + def delete(self, key: str) -> None: + self._store.pop(key, None) + + def has(self, key: str) -> bool: + if key in self._store: + _, expires_at = self._store[key] + if time.time() < expires_at: + return True + del self._store[key] + return False diff --git a/packages/appkit-py/src/appkit_py/connectors/__init__.py b/packages/appkit-py/src/appkit_py/connectors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/appkit-py/src/appkit_py/connectors/files/__init__.py b/packages/appkit-py/src/appkit_py/connectors/files/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/appkit-py/src/appkit_py/connectors/files/client.py b/packages/appkit-py/src/appkit_py/connectors/files/client.py new file mode 100644 index 00000000..21598488 --- /dev/null +++ b/packages/appkit-py/src/appkit_py/connectors/files/client.py @@ -0,0 +1,168 @@ +"""Files connector wrapping databricks.sdk. + +Mirrors packages/appkit/src/connectors/files/client.ts +""" + +from __future__ import annotations + +import asyncio +import io +import logging +import mimetypes +from typing import Any + +from databricks.sdk import WorkspaceClient + +logger = logging.getLogger("appkit.connector.files") + +# Maximum path length (matching TS) +MAX_PATH_LENGTH = 4096 + + +class FilesConnector: + """Perform file operations on Unity Catalog Volumes via Databricks SDK.""" + + def __init__(self, default_volume: str | None = None) -> None: + self.default_volume = default_volume or "" + + def resolve_path(self, file_path: str) -> str: + """Resolve a relative path against the default volume. + + Rejects path traversal sequences to prevent escaping the volume. + """ + # Reject traversal sequences + if ".." in file_path: + raise ValueError(f"Path must not contain '..': {file_path}") + + if file_path.startswith("/Volumes/"): + return file_path + # Strip leading slash and join with volume path + clean = file_path.lstrip("/") + return f"{self.default_volume.rstrip('/')}/{clean}" + + async def list( + self, client: WorkspaceClient, directory_path: str | None = None + ) -> list[dict[str, Any]]: + """List directory contents.""" + path = self.resolve_path(directory_path or "") + entries = await asyncio.to_thread( + lambda: list(client.files.list_directory_contents(path)) + ) + return [ + { + "name": e.name, + "path": e.path, + "is_directory": e.is_directory, + "file_size": e.file_size, + "last_modified": e.last_modified, + } + for e in entries + ] + + async def read( + self, client: WorkspaceClient, file_path: str, options: dict | None = None + ) -> str: + """Read file as text, enforcing optional maxSize limit.""" + max_size = (options or {}).get("maxSize") + path = self.resolve_path(file_path) + response = await asyncio.to_thread(client.files.download, path) + + if max_size: + content = response.contents.read(max_size + 1) + if isinstance(content, bytes) and len(content) > max_size: + raise ValueError( + f"File exceeds maximum read size ({max_size} bytes)" + ) + else: + content = response.contents.read() + + if isinstance(content, bytes): + return content.decode("utf-8", errors="replace") + return content + + async def download( + self, client: WorkspaceClient, file_path: str + ) -> dict[str, Any]: + """Download file as binary stream.""" + path = self.resolve_path(file_path) + response = await asyncio.to_thread(client.files.download, path) + return {"contents": response.contents, "content_type": response.content_type} + + async def exists(self, client: WorkspaceClient, file_path: str) -> bool: + """Check if a file exists.""" + path = self.resolve_path(file_path) + try: + await asyncio.to_thread(client.files.get_metadata, path) + return True + except Exception: + return False + + async def metadata( + self, client: WorkspaceClient, file_path: str + ) -> dict[str, Any]: + """Get file metadata.""" + path = self.resolve_path(file_path) + meta = await asyncio.to_thread(client.files.get_metadata, path) + return { + "contentLength": meta.content_length, + "contentType": meta.content_type, + "lastModified": str(meta.last_modified) if meta.last_modified else None, + } + + async def upload( + self, + client: WorkspaceClient, + file_path: str, + contents: bytes | io.IOBase, + options: dict | None = None, + ) -> None: + """Upload file contents.""" + path = self.resolve_path(file_path) + overwrite = (options or {}).get("overwrite", True) + if isinstance(contents, bytes): + contents = io.BytesIO(contents) + await asyncio.to_thread( + client.files.upload, path, contents, overwrite=overwrite + ) + + async def create_directory( + self, client: WorkspaceClient, directory_path: str + ) -> None: + """Create a directory.""" + path = self.resolve_path(directory_path) + await asyncio.to_thread(client.files.create_directory, path) + + async def delete(self, client: WorkspaceClient, file_path: str) -> None: + """Delete a file.""" + path = self.resolve_path(file_path) + await asyncio.to_thread(client.files.delete, path) + + async def preview( + self, client: WorkspaceClient, file_path: str + ) -> dict[str, Any]: + """Get a preview of a file (metadata + text preview for text files).""" + path = self.resolve_path(file_path) + meta = await asyncio.to_thread(client.files.get_metadata, path) + content_type = meta.content_type or mimetypes.guess_type(file_path)[0] or "" + is_text = content_type.startswith("text/") or content_type in ( + "application/json", "application/xml", "application/javascript", + ) + is_image = content_type.startswith("image/") + + text_preview = None + if is_text: + try: + response = await asyncio.to_thread(client.files.download, path) + raw = response.contents.read(1024) + text_preview = raw.decode("utf-8", errors="replace") if isinstance(raw, bytes) else raw + except Exception: + pass + + return { + "contentLength": meta.content_length, + "contentType": meta.content_type, + "lastModified": str(meta.last_modified) if meta.last_modified else None, + "textPreview": text_preview, + "isText": is_text, + "isImage": is_image, + } diff --git a/packages/appkit-py/src/appkit_py/connectors/genie/__init__.py b/packages/appkit-py/src/appkit_py/connectors/genie/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/appkit-py/src/appkit_py/connectors/genie/client.py b/packages/appkit-py/src/appkit_py/connectors/genie/client.py new file mode 100644 index 00000000..a3602787 --- /dev/null +++ b/packages/appkit-py/src/appkit_py/connectors/genie/client.py @@ -0,0 +1,234 @@ +"""Genie connector wrapping databricks.sdk. + +Mirrors packages/appkit/src/connectors/genie/client.ts +""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Any, AsyncGenerator + +from databricks.sdk import WorkspaceClient + +logger = logging.getLogger("appkit.connector.genie") + + +class GenieConnector: + """Interact with Databricks AI/BI Genie via the SDK.""" + + def __init__(self, timeout: float = 120.0, max_messages: int = 200) -> None: + self.timeout = timeout + self.max_messages = max_messages + + async def stream_send_message( + self, + client: WorkspaceClient, + space_id: str, + content: str, + conversation_id: str | None = None, + *, + timeout: float | None = None, + signal: asyncio.Event | None = None, + ) -> AsyncGenerator[dict[str, Any], None]: + """Send a message and stream events.""" + if conversation_id: + # Existing conversation + waiter = await asyncio.to_thread( + client.genie.create_message, space_id, conversation_id, content + ) + else: + # New conversation + waiter = await asyncio.to_thread( + client.genie.start_conversation, space_id, content + ) + + # Yield message_start + msg_id = getattr(waiter, "message_id", None) or "pending" + conv_id = conversation_id or getattr(waiter, "conversation_id", None) or "new" + yield { + "type": "message_start", + "conversationId": conv_id, + "messageId": msg_id, + "spaceId": space_id, + } + + # Yield status + yield {"type": "status", "status": "EXECUTING"} + + # Wait for completion + try: + result = await asyncio.to_thread( + waiter.result, timeout=self.timeout + ) + + conv_id = result.conversation_id or conv_id + msg_id = result.id or msg_id + + # Build message response + message_response = { + "messageId": msg_id, + "conversationId": conv_id, + "spaceId": space_id, + "status": result.status.value if result.status else "COMPLETED", + "content": result.content or "", + "attachments": [], + } + + if result.attachments: + for att in result.attachments: + att_data: dict[str, Any] = {} + if att.query: + att_data["query"] = { + "title": getattr(att.query, "title", None), + "description": getattr(att.query, "description", None), + "query": getattr(att.query, "query", None), + } + if att.text: + att_data["text"] = {"content": getattr(att.text, "content", None)} + message_response["attachments"].append(att_data) + + yield {"type": "message_result", "message": message_response} + + # Fetch query results for attachments + if result.attachments: + for att in result.attachments: + if att.query and hasattr(att, "id") and att.id: + try: + query_result = await asyncio.to_thread( + client.genie.execute_message_attachment_query, + space_id, conv_id, msg_id, att.id, + ) + yield { + "type": "query_result", + "attachmentId": att.id, + "statementId": getattr(query_result, "statement_id", ""), + "data": _serialize_query_result(query_result), + } + except Exception as exc: + logger.warning("Failed to fetch query result: %s", exc) + + except Exception as exc: + yield {"type": "error", "error": str(exc)} + + async def stream_conversation( + self, + client: WorkspaceClient, + space_id: str, + conversation_id: str, + *, + include_query_results: bool = True, + page_token: str | None = None, + signal: asyncio.Event | None = None, + ) -> AsyncGenerator[dict[str, Any], None]: + """Stream conversation history.""" + try: + result = await asyncio.to_thread( + client.genie.list_conversation_messages, + space_id, conversation_id, + page_token=page_token, + page_size=self.max_messages, + ) + + messages = result.messages or [] + for msg in messages: + yield { + "type": "message_result", + "message": { + "messageId": msg.id, + "conversationId": conversation_id, + "spaceId": space_id, + "status": msg.status.value if msg.status else "COMPLETED", + "content": msg.content or "", + "attachments": [], + }, + } + + yield { + "type": "history_info", + "conversationId": conversation_id, + "spaceId": space_id, + "nextPageToken": result.next_page_token, + "loadedCount": len(messages), + } + + except Exception as exc: + yield {"type": "error", "error": str(exc)} + + async def stream_get_message( + self, + client: WorkspaceClient, + space_id: str, + conversation_id: str, + message_id: str, + *, + timeout: float | None = None, + signal: asyncio.Event | None = None, + ) -> AsyncGenerator[dict[str, Any], None]: + """Stream a single message (poll until complete).""" + try: + result = await asyncio.to_thread( + client.genie.get_message, + space_id, conversation_id, message_id, + ) + + yield { + "type": "message_result", + "message": { + "messageId": result.id, + "conversationId": conversation_id, + "spaceId": space_id, + "status": result.status.value if result.status else "COMPLETED", + "content": result.content or "", + "attachments": [], + }, + } + + except Exception as exc: + yield {"type": "error", "error": str(exc)} + + async def get_conversation( + self, + client: WorkspaceClient, + space_id: str, + conversation_id: str, + ) -> dict[str, Any]: + """Get full conversation (non-streaming).""" + result = await asyncio.to_thread( + client.genie.list_conversation_messages, + space_id, conversation_id, + ) + return { + "messages": [ + { + "messageId": msg.id, + "conversationId": conversation_id, + "spaceId": space_id, + "status": msg.status.value if msg.status else "COMPLETED", + "content": msg.content or "", + } + for msg in (result.messages or []) + ], + "nextPageToken": result.next_page_token, + } + + +def _serialize_query_result(result: Any) -> dict[str, Any]: + """Serialize a GenieGetMessageQueryResultResponse to match TS format.""" + columns = [] + data_array = [] + if hasattr(result, "columns") and result.columns: + columns = [{"name": c.name, "type_name": c.type_name} for c in result.columns] + if hasattr(result, "statement_response") and result.statement_response: + sr = result.statement_response + if sr.manifest and sr.manifest.schema and sr.manifest.schema.columns: + columns = [ + {"name": c.name, "type_name": c.type_name} + for c in sr.manifest.schema.columns + ] + if sr.result and sr.result.data_array: + data_array = sr.result.data_array + return { + "manifest": {"schema": {"columns": columns}}, + "result": {"data_array": data_array}, + } diff --git a/packages/appkit-py/src/appkit_py/connectors/sql_warehouse/__init__.py b/packages/appkit-py/src/appkit_py/connectors/sql_warehouse/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/appkit-py/src/appkit_py/connectors/sql_warehouse/client.py b/packages/appkit-py/src/appkit_py/connectors/sql_warehouse/client.py new file mode 100644 index 00000000..3a28109c --- /dev/null +++ b/packages/appkit-py/src/appkit_py/connectors/sql_warehouse/client.py @@ -0,0 +1,187 @@ +"""SQL Warehouse connector wrapping databricks.sdk. + +Mirrors packages/appkit/src/connectors/sql-warehouse/client.ts +""" + +from __future__ import annotations + +import asyncio +import base64 +import logging +import time +from typing import Any + +import httpx +import pyarrow as pa +import pyarrow.ipc as ipc +from databricks.sdk import WorkspaceClient +from databricks.sdk.service.sql import ( + Disposition, + Format, + StatementParameterListItem, + StatementResponse, + StatementState, +) + +logger = logging.getLogger("appkit.connector.sql") + +# States that indicate the query is still running +_PENDING_STATES = {StatementState.PENDING, StatementState.RUNNING} +_FAILED_STATES = {StatementState.FAILED, StatementState.CANCELED, StatementState.CLOSED} + + +def decode_arrow_attachment(attachment_b64: str) -> list[dict[str, Any]]: + """Decode a base64 Arrow IPC attachment into row dicts. + + Mirrors the TS _transformArrowAttachment: base64 → Arrow IPC → row objects. + """ + buf = base64.b64decode(attachment_b64) + reader = ipc.open_stream(buf) + table = reader.read_all() + return table.to_pylist() + + +class SQLWarehouseConnector: + """Execute SQL statements against a Databricks SQL Warehouse.""" + + def __init__(self, timeout: float = 60.0) -> None: + self.timeout = timeout + + async def execute_statement( + self, + client: WorkspaceClient, + *, + statement: str, + warehouse_id: str, + parameters: list[dict[str, Any]] | None = None, + disposition: str = "INLINE", + format: str = "JSON_ARRAY", + wait_timeout: str = "30s", + ) -> StatementResponse: + """Execute a SQL statement and poll until completion.""" + sdk_params = None + if parameters: + sdk_params = [ + StatementParameterListItem( + name=p["name"], + value=p.get("value"), + type=p.get("type"), + ) + for p in parameters + ] + + disp = Disposition(disposition) + fmt = Format(format) + + response = await asyncio.to_thread( + client.statement_execution.execute_statement, + statement=statement, + warehouse_id=warehouse_id, + parameters=sdk_params, + disposition=disp, + format=fmt, + wait_timeout=wait_timeout, + ) + + # Poll if still pending + if response.status and response.status.state in _PENDING_STATES: + response = await self._poll_until_done(client, response.statement_id) + + # Check for terminal failure states + if response.status and response.status.state in _FAILED_STATES: + error_msg = "" + if response.status.error: + error_msg = getattr(response.status.error, "message", str(response.status.error)) + raise RuntimeError( + f"Statement {response.statement_id} failed with state " + f"{response.status.state.value}: {error_msg}" + ) + + return response + + def transform_result(self, response: StatementResponse) -> list[dict[str, Any]]: + """Transform a StatementResponse into row dicts. + + Handles three result shapes (matching TS _transformDataArray): + 1. Inline Arrow IPC attachment (serverless warehouses) → decode base64 + 2. data_array (classic warehouses) → zip with column names + 3. external_links (large results) → not transformed here + """ + result = response.result + if result is None: + return [] + + # 1. Inline Arrow IPC attachment + attachment = getattr(result, "attachment", None) + if attachment: + try: + return decode_arrow_attachment(attachment) + except Exception as exc: + logger.warning("Failed to decode inline Arrow IPC attachment: %s", exc) + # Fall through to data_array + + # 2. data_array (JSON format) + if result.data_array: + columns: list[str] = [] + if response.manifest and response.manifest.schema and response.manifest.schema.columns: + columns = [c.name for c in response.manifest.schema.columns] + rows: list[dict[str, Any]] = [] + for row in result.data_array: + if columns: + rows.append(dict(zip(columns, row))) + else: + rows.append({"values": row}) + return rows + + return [] + + async def _poll_until_done( + self, client: WorkspaceClient, statement_id: str + ) -> StatementResponse: + """Poll a statement until it reaches a terminal state.""" + delay = 1.0 + deadline = time.monotonic() + self.timeout + + while time.monotonic() < deadline: + await asyncio.sleep(delay) + response = await asyncio.to_thread( + client.statement_execution.get_statement, statement_id + ) + if response.status and response.status.state not in _PENDING_STATES: + return response + delay = min(delay * 1.5, 5.0) + + raise TimeoutError(f"Statement {statement_id} did not complete within {self.timeout}s") + + async def get_arrow_data( + self, client: WorkspaceClient, job_id: str + ) -> dict[str, Any]: + """Fetch Arrow binary data for a completed statement. + + Downloads external link chunks and concatenates into a single buffer. + """ + response = await asyncio.to_thread( + client.statement_execution.get_statement, job_id + ) + + if not response.result: + raise ValueError(f"No result available for job {job_id}") + + # Check for inline attachment first + attachment = getattr(response.result, "attachment", None) + if attachment: + return {"data": base64.b64decode(attachment)} + + # Download from external links + if response.result.external_links: + chunks: list[bytes] = [] + async with httpx.AsyncClient(timeout=30.0) as http: + for link in response.result.external_links: + url = getattr(link, "external_link", None) or getattr(link, "url", None) + if url: + resp = await http.get(url) + resp.raise_for_status() + chunks.append(resp.content) + return {"data": b"".join(chunks)} + + raise ValueError(f"No Arrow data available for job {job_id}") diff --git a/packages/appkit-py/src/appkit_py/context/__init__.py b/packages/appkit-py/src/appkit_py/context/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/appkit-py/src/appkit_py/context/execution_context.py b/packages/appkit-py/src/appkit_py/context/execution_context.py new file mode 100644 index 00000000..bd23e141 --- /dev/null +++ b/packages/appkit-py/src/appkit_py/context/execution_context.py @@ -0,0 +1,50 @@ +"""Execution context using Python contextvars. + +This is the Python equivalent of the TypeScript AsyncLocalStorage-based +context from packages/appkit/src/context/execution-context.ts. +""" + +from __future__ import annotations + +import contextvars +from typing import Any, Awaitable, Callable, TypeVar + +from .user_context import UserContext + +T = TypeVar("T") + +_user_context_var: contextvars.ContextVar[UserContext | None] = contextvars.ContextVar( + "user_context", default=None +) + + +async def run_in_user_context(user_context: UserContext, fn: Callable[[], Awaitable[T]]) -> T: + """Run an async function in a user context.""" + token = _user_context_var.set(user_context) + try: + return await fn() + finally: + _user_context_var.reset(token) + + +def get_user_context() -> UserContext | None: + """Get the current user context, or None if not in a user context.""" + return _user_context_var.get() + + +def get_execution_context() -> UserContext | None: + """Get the current execution context (user or None for service principal).""" + return _user_context_var.get() + + +def get_current_user_id() -> str: + """Get the current user ID, or 'service-principal' if not in user context.""" + ctx = _user_context_var.get() + if ctx is not None: + return ctx.user_id + return "service-principal" + + +def is_in_user_context() -> bool: + """Check if currently running in a user context.""" + return _user_context_var.get() is not None diff --git a/packages/appkit-py/src/appkit_py/context/service_context.py b/packages/appkit-py/src/appkit_py/context/service_context.py new file mode 100644 index 00000000..647a8d74 --- /dev/null +++ b/packages/appkit-py/src/appkit_py/context/service_context.py @@ -0,0 +1,33 @@ +"""Service context singleton for the Databricks workspace client.""" + +from __future__ import annotations + +from .user_context import UserContext + + +class ServiceContext: + """Singleton holding the service principal workspace client.""" + + _instance: ServiceContext | None = None + + def __init__(self) -> None: + self.service_user_id: str = "service-principal" + + @classmethod + def initialize(cls) -> ServiceContext: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + @classmethod + def get(cls) -> ServiceContext: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + @classmethod + def reset(cls) -> None: + cls._instance = None + + def create_user_context(self, token: str, user_id: str, user_name: str | None = None) -> UserContext: + return UserContext(user_id=user_id, token=token, user_name=user_name) diff --git a/packages/appkit-py/src/appkit_py/context/user_context.py b/packages/appkit-py/src/appkit_py/context/user_context.py new file mode 100644 index 00000000..79d1d0ba --- /dev/null +++ b/packages/appkit-py/src/appkit_py/context/user_context.py @@ -0,0 +1,14 @@ +"""User context dataclass for OBO (On-Behalf-Of) execution.""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass +class UserContext: + """Per-request user context created from x-forwarded-* headers.""" + + user_id: str + token: str + user_name: str | None = None diff --git a/packages/appkit-py/src/appkit_py/core/__init__.py b/packages/appkit-py/src/appkit_py/core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/appkit-py/src/appkit_py/errors/__init__.py b/packages/appkit-py/src/appkit_py/errors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/appkit-py/src/appkit_py/errors/base.py b/packages/appkit-py/src/appkit_py/errors/base.py new file mode 100644 index 00000000..592f2772 --- /dev/null +++ b/packages/appkit-py/src/appkit_py/errors/base.py @@ -0,0 +1,83 @@ +"""AppKit error hierarchy matching the TypeScript implementation.""" + +from __future__ import annotations + + +class AppKitError(Exception): + code: str = "APPKIT_ERROR" + status_code: int = 500 + is_retryable: bool = False + + def __init__(self, message: str, *, cause: Exception | None = None) -> None: + super().__init__(message) + self.cause = cause + + def to_dict(self) -> dict: + return {"error": str(self), "code": self.code, "statusCode": self.status_code} + + +class AuthenticationError(AppKitError): + code = "AUTHENTICATION_ERROR" + status_code = 401 + + @classmethod + def missing_token(cls, token_type: str = "access token") -> AuthenticationError: + return cls(f"Missing {token_type}") + + +class ValidationError(AppKitError): + code = "VALIDATION_ERROR" + status_code = 400 + + @classmethod + def missing_field(cls, field: str) -> ValidationError: + return cls(f"{field} is required") + + @classmethod + def invalid_value(cls, field: str, value: str, expectation: str) -> ValidationError: + return cls(f"Invalid {field}: {value}. Expected: {expectation}") + + +class ConfigurationError(AppKitError): + code = "CONFIGURATION_ERROR" + status_code = 500 + + @classmethod + def missing_env_var(cls, var_name: str) -> ConfigurationError: + return cls(f"Missing environment variable: {var_name}") + + +class ExecutionError(AppKitError): + code = "EXECUTION_ERROR" + status_code = 500 + + @classmethod + def statement_failed(cls, message: str) -> ExecutionError: + return cls(message) + + +class ConnectionError_(AppKitError): + code = "CONNECTION_ERROR" + status_code = 503 + is_retryable = True + + @classmethod + def api_failure(cls, service: str, cause: Exception | None = None) -> ConnectionError_: + return cls(f"Failed to connect to {service}", cause=cause) + + +class InitializationError(AppKitError): + code = "INITIALIZATION_ERROR" + status_code = 500 + + @classmethod + def not_initialized(cls, component: str, hint: str = "") -> InitializationError: + msg = f"{component} is not initialized" + if hint: + msg += f". {hint}" + return cls(msg) + + +class ServerError(AppKitError): + code = "SERVER_ERROR" + status_code = 500 diff --git a/packages/appkit-py/src/appkit_py/logging/__init__.py b/packages/appkit-py/src/appkit_py/logging/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/appkit-py/src/appkit_py/plugin/__init__.py b/packages/appkit-py/src/appkit_py/plugin/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/appkit-py/src/appkit_py/plugin/interceptors/__init__.py b/packages/appkit-py/src/appkit_py/plugin/interceptors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/appkit-py/src/appkit_py/plugin/interceptors/cache.py b/packages/appkit-py/src/appkit_py/plugin/interceptors/cache.py new file mode 100644 index 00000000..28ee32c8 --- /dev/null +++ b/packages/appkit-py/src/appkit_py/plugin/interceptors/cache.py @@ -0,0 +1,38 @@ +"""CacheInterceptor wrapping CacheManager. + +Mirrors packages/appkit/src/plugin/interceptors/cache.ts +""" + +from __future__ import annotations + +import time +from typing import Any, Awaitable, Callable + + +class CacheInterceptor: + def __init__( + self, + cache_store: dict[str, Any], + cache_key: str | None, + ttl: float = 300, + enabled: bool = True, + ) -> None: + self._store = cache_store + self._key = cache_key + self._ttl = ttl + self._enabled = enabled + + async def intercept(self, fn: Callable[[], Awaitable[Any]]) -> Any: + if not self._enabled or not self._key: + return await fn() + + if self._key in self._store: + value, expires_at = self._store[self._key] + if time.time() < expires_at: + return value + del self._store[self._key] + + result = await fn() + if self._key: + self._store[self._key] = (result, time.time() + self._ttl) + return result diff --git a/packages/appkit-py/src/appkit_py/plugin/interceptors/retry.py b/packages/appkit-py/src/appkit_py/plugin/interceptors/retry.py new file mode 100644 index 00000000..c032bba2 --- /dev/null +++ b/packages/appkit-py/src/appkit_py/plugin/interceptors/retry.py @@ -0,0 +1,38 @@ +"""RetryInterceptor with exponential backoff. + +Mirrors packages/appkit/src/plugin/interceptors/retry.ts +""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Any, Awaitable, Callable + +logger = logging.getLogger("appkit.interceptor.retry") + + +class RetryInterceptor: + def __init__( + self, + attempts: int = 3, + initial_delay: float = 1.0, + max_delay: float = 30.0, + ) -> None: + self.attempts = attempts + self.initial_delay = initial_delay + self.max_delay = max_delay + + async def intercept(self, fn: Callable[[], Awaitable[Any]]) -> Any: + last_error: Exception | None = None + for attempt in range(1, self.attempts + 1): + try: + return await fn() + except Exception as exc: + last_error = exc + if attempt >= self.attempts: + raise + delay = min(self.initial_delay * (2 ** (attempt - 1)), self.max_delay) + logger.debug("Retry attempt %d/%d after %.1fs: %s", attempt, self.attempts, delay, exc) + await asyncio.sleep(delay) + raise last_error # type: ignore[misc] diff --git a/packages/appkit-py/src/appkit_py/plugin/interceptors/timeout.py b/packages/appkit-py/src/appkit_py/plugin/interceptors/timeout.py new file mode 100644 index 00000000..c64b2998 --- /dev/null +++ b/packages/appkit-py/src/appkit_py/plugin/interceptors/timeout.py @@ -0,0 +1,17 @@ +"""TimeoutInterceptor using asyncio.wait_for. + +Mirrors packages/appkit/src/plugin/interceptors/timeout.ts +""" + +from __future__ import annotations + +import asyncio +from typing import Any, Awaitable, Callable + + +class TimeoutInterceptor: + def __init__(self, timeout_seconds: float) -> None: + self.timeout_seconds = timeout_seconds + + async def intercept(self, fn: Callable[[], Awaitable[Any]]) -> Any: + return await asyncio.wait_for(fn(), timeout=self.timeout_seconds) diff --git a/packages/appkit-py/src/appkit_py/plugin/interceptors/types.py b/packages/appkit-py/src/appkit_py/plugin/interceptors/types.py new file mode 100644 index 00000000..5171cccf --- /dev/null +++ b/packages/appkit-py/src/appkit_py/plugin/interceptors/types.py @@ -0,0 +1,11 @@ +"""Interceptor protocol and context types.""" + +from __future__ import annotations + +from typing import Any, Awaitable, Callable, Protocol, TypeVar + +T = TypeVar("T") + + +class ExecutionInterceptor(Protocol): + async def intercept(self, fn: Callable[[], Awaitable[Any]]) -> Any: ... diff --git a/packages/appkit-py/src/appkit_py/plugin/plugin.py b/packages/appkit-py/src/appkit_py/plugin/plugin.py new file mode 100644 index 00000000..340abccb --- /dev/null +++ b/packages/appkit-py/src/appkit_py/plugin/plugin.py @@ -0,0 +1,94 @@ +"""Abstract Plugin base class. + +Mirrors packages/appkit/src/plugin/plugin.ts +""" + +from __future__ import annotations + +import asyncio +import inspect +from typing import Any + +from appkit_py.context.execution_context import run_in_user_context +from appkit_py.context.user_context import UserContext +from appkit_py.stream.stream_manager import StreamManager + + +# Methods excluded from the as_user proxy +_EXCLUDED_FROM_PROXY = frozenset({ + "setup", "shutdown", "inject_routes", "get_endpoints", + "as_user", "exports", "client_config", "name", +}) + + +class Plugin: + """Abstract base class for all AppKit plugins.""" + + name: str = "plugin" + phase: str = "normal" # "core", "normal", or "deferred" + + def __init__(self, config: dict[str, Any] | None = None) -> None: + self.config = config or {} + self.stream_manager = StreamManager() + self._registered_endpoints: dict[str, str] = {} + + async def setup(self) -> None: + """Async setup hook called after construction.""" + pass + + def inject_routes(self, router: Any) -> None: + """Register HTTP routes on the given router.""" + pass + + def get_endpoints(self) -> dict[str, str]: + return dict(self._registered_endpoints) + + def exports(self) -> dict[str, Any]: + return {} + + def client_config(self) -> dict[str, Any]: + return {} + + def as_user(self, request: Any) -> Plugin: + """Return a proxy that wraps method calls in user context.""" + headers = getattr(request, "headers", {}) + token = headers.get("x-forwarded-access-token", "") + user_id = headers.get("x-forwarded-user", "") + user_ctx = UserContext(user_id=user_id, token=token) + return _UserContextProxy(self, user_ctx) # type: ignore[return-value] + + def resolve_user_id(self, request: Any) -> str: + headers = getattr(request, "headers", {}) + return headers.get("x-forwarded-user", "service-principal") + + async def shutdown(self) -> None: + self.stream_manager.abort_all() + + +class _UserContextProxy(Plugin): + """Proxy that wraps all method calls in a user context. + + Python equivalent of the JS Proxy used by asUser() in TypeScript. + """ + + def __init__(self, plugin: Plugin, user_context: UserContext) -> None: + # Don't call super().__init__ — we delegate everything + object.__setattr__(self, "_plugin", plugin) + object.__setattr__(self, "_user_context", user_context) + + def __getattr__(self, name: str) -> Any: + attr = getattr(self._plugin, name) + if name in _EXCLUDED_FROM_PROXY or not callable(attr): + return attr + + # Only wrap coroutine functions as async; leave sync methods alone + if asyncio.iscoroutinefunction(attr): + async def async_wrapper(*args: Any, **kwargs: Any) -> Any: + return await run_in_user_context( + self._user_context, + lambda: attr(*args, **kwargs), + ) + return async_wrapper + + # Sync callable — return as-is (context won't propagate, but won't break) + return attr diff --git a/packages/appkit-py/src/appkit_py/plugins/__init__.py b/packages/appkit-py/src/appkit_py/plugins/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/appkit-py/src/appkit_py/plugins/analytics/__init__.py b/packages/appkit-py/src/appkit_py/plugins/analytics/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/appkit-py/src/appkit_py/plugins/analytics/query.py b/packages/appkit-py/src/appkit_py/plugins/analytics/query.py new file mode 100644 index 00000000..4c6a34be --- /dev/null +++ b/packages/appkit-py/src/appkit_py/plugins/analytics/query.py @@ -0,0 +1,68 @@ +"""QueryProcessor for SQL parameter processing. + +Mirrors packages/appkit/src/plugins/analytics/query.ts +""" + +from __future__ import annotations + +import hashlib +import os +import re +from typing import Any + + +class QueryProcessor: + """Process SQL queries: hash, convert named parameters, etc.""" + + def hash_query(self, query: str) -> str: + """SHA256 hash of the query text for cache keying.""" + return hashlib.sha256(query.encode()).hexdigest() + + def convert_to_sql_parameters( + self, + query: str, + parameters: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Convert named :param placeholders to Databricks SQL parameter format. + + Returns dict with 'statement' and 'parameters' keys. + """ + if not parameters: + return {"statement": query, "parameters": []} + + sql_params = [] + for name, value in parameters.items(): + if value is None: + sql_params.append({"name": name, "value": None, "type": "STRING"}) + elif isinstance(value, dict) and "__sql_type" in value: + sql_params.append({ + "name": name, + "value": str(value["value"]), + "type": value["__sql_type"], + }) + else: + sql_params.append({"name": name, "value": str(value), "type": "STRING"}) + + return {"statement": query, "parameters": sql_params} + + async def process_query_params( + self, + query: str, + parameters: dict[str, Any] | None = None, + *, + workspace_id: str | None = None, + ) -> dict[str, Any] | None: + """Process and validate query parameters. + + Auto-injects workspaceId if the query references :workspaceId and + it's not already in the parameters. + """ + params = dict(parameters) if parameters else {} + + # Auto-inject workspaceId if referenced in query but not provided + if ":workspaceId" in query and "workspaceId" not in params: + ws_id = workspace_id or os.environ.get("DATABRICKS_WORKSPACE_ID", "") + if ws_id: + params["workspaceId"] = {"__sql_type": "STRING", "value": ws_id} + + return params if params else None diff --git a/packages/appkit-py/src/appkit_py/plugins/files/__init__.py b/packages/appkit-py/src/appkit_py/plugins/files/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/appkit-py/src/appkit_py/plugins/genie/__init__.py b/packages/appkit-py/src/appkit_py/plugins/genie/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/appkit-py/src/appkit_py/plugins/server/__init__.py b/packages/appkit-py/src/appkit_py/plugins/server/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/appkit-py/src/appkit_py/server.py b/packages/appkit-py/src/appkit_py/server.py new file mode 100644 index 00000000..5db2a729 --- /dev/null +++ b/packages/appkit-py/src/appkit_py/server.py @@ -0,0 +1,801 @@ +"""Main FastAPI application — the Python AppKit backend server. + +This is the full server implementation that provides 100% API compatibility +with the TypeScript AppKit backend. It serves the same endpoints that the +React frontend (appkit-ui) expects. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import os +import uuid +from pathlib import Path +from typing import Any, AsyncGenerator + +from fastapi import FastAPI, Request, Response +from fastapi.responses import JSONResponse, StreamingResponse +from starlette.staticfiles import StaticFiles + +from appkit_py.connectors.files.client import FilesConnector +from appkit_py.connectors.genie.client import GenieConnector +from appkit_py.connectors.sql_warehouse.client import SQLWarehouseConnector +from appkit_py.plugins.analytics.query import QueryProcessor +from appkit_py.stream.sse_writer import SSE_HEADERS, format_error, format_event, format_heartbeat +from appkit_py.stream.stream_manager import StreamManager +from appkit_py.stream.types import SSEErrorCode + +logger = logging.getLogger("appkit.server") + + +def _get_workspace_client() -> Any | None: + """Create a WorkspaceClient if DATABRICKS_HOST is set.""" + host = os.environ.get("DATABRICKS_HOST") + if not host: + return None + try: + from databricks.sdk import WorkspaceClient + return WorkspaceClient() + except Exception as exc: + logger.warning("Failed to create WorkspaceClient: %s", exc) + return None + + +# --------------------------------------------------------------------------- +# App factory +# --------------------------------------------------------------------------- + +def create_server( + *, + query_dir: str | None = None, + static_path: str | None = None, + genie_spaces: dict[str, str] | None = None, + volumes: dict[str, str] | None = None, +) -> FastAPI: + """Create and configure the FastAPI application. + + This mirrors the TypeScript createApp() + server plugin pattern. + """ + app = FastAPI(title="AppKit Python Backend") + stream_manager = StreamManager() + query_processor = QueryProcessor() + + # Discover configuration from environment + _genie_spaces = genie_spaces or _discover_genie_spaces() + _volumes = volumes or _discover_volumes() + _query_dir = query_dir or _find_query_dir() + + # Initialize connectors + _ws_client = _get_workspace_client() # Service principal client + _sql_connector = SQLWarehouseConnector() + _genie_connector = GenieConnector() + _file_connectors: dict[str, FilesConnector] = { + key: FilesConnector(default_volume=path) for key, path in _volumes.items() + } + _warehouse_id = os.environ.get("DATABRICKS_WAREHOUSE_ID") + + def _get_user_client(request: Request) -> Any | None: + """Create a per-request WorkspaceClient using OBO credentials. + + Falls back to the service principal client if no user headers are present. + """ + token = request.headers.get("x-forwarded-access-token") + host = os.environ.get("DATABRICKS_HOST") + if token and host: + try: + from databricks.sdk import WorkspaceClient + return WorkspaceClient(host=host, token=token) + except Exception: + pass + return _ws_client + + # ----------------------------------------------------------------------- + # Health endpoint + # ----------------------------------------------------------------------- + @app.get("/health") + async def health(): + return {"status": "ok"} + + # ----------------------------------------------------------------------- + # Reconnect plugin (test/dev SSE endpoint matching TS dev-playground) + # ----------------------------------------------------------------------- + @app.get("/api/reconnect/stream") + async def reconnect_stream(request: Request): + async def event_generator() -> AsyncGenerator[str, None]: + for i in range(1, 6): + event_id = str(uuid.uuid4()) + yield format_event(event_id, { + "type": "message", + "count": i, + "total": 5, + "message": f"Event {i} of 5", + }) + await asyncio.sleep(0.1) + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={k: v for k, v in SSE_HEADERS.items() if k != "Content-Type"}, + ) + + # ----------------------------------------------------------------------- + # Analytics plugin: POST /api/analytics/query/{query_key} + # ----------------------------------------------------------------------- + @app.post("/api/analytics/query/{query_key}") + async def analytics_query(query_key: str, request: Request): + body = {} + try: + body = await request.json() + except Exception: + pass + + format_ = body.get("format", "ARROW_STREAM") + parameters = body.get("parameters") + + if not query_key: + return JSONResponse({"error": "query_key is required"}, status_code=400) + + # Look up the query file + query_text = _load_query(query_key, _query_dir) + if query_text is None: + return JSONResponse({"error": "Query not found"}, status_code=404) + + is_obo = query_key.endswith(".obo") or _has_obo_file(query_key, _query_dir) + + async def event_generator() -> AsyncGenerator[str, None]: + if not _ws_client or not _warehouse_id: + error_id = str(uuid.uuid4()) + yield format_error( + error_id, + "Databricks connection not configured", + SSEErrorCode.TEMPORARY_UNAVAILABLE, + ) + return + + try: + converted = query_processor.convert_to_sql_parameters(query_text, parameters) + + # Format configs matching TS FORMAT_CONFIGS with fallback order + FORMAT_CONFIGS = { + "ARROW_STREAM": {"disposition": "INLINE", "format": "ARROW_STREAM", "type": "result"}, + "JSON": {"disposition": "INLINE", "format": "JSON_ARRAY", "type": "result"}, + "ARROW": {"disposition": "EXTERNAL_LINKS", "format": "ARROW_STREAM", "type": "arrow"}, + } + + # For default ARROW_STREAM, try fallback: ARROW_STREAM → JSON → ARROW + if format_ == "ARROW_STREAM": + fallback_order = ["ARROW_STREAM", "JSON", "ARROW"] + else: + fallback_order = [format_] + + response = None + result_type = "result" + for i, fmt_name in enumerate(fallback_order): + fmt_config = FORMAT_CONFIGS.get(fmt_name, FORMAT_CONFIGS["JSON"]) + try: + response = await _sql_connector.execute_statement( + _ws_client, + statement=converted["statement"], + warehouse_id=_warehouse_id, + parameters=converted.get("parameters") or None, + disposition=fmt_config["disposition"], + format=fmt_config["format"], + ) + result_type = fmt_config["type"] + if i > 0: + logger.info("Query succeeded with fallback format %s", fmt_name) + break + except Exception as fmt_err: + msg = str(fmt_err) + is_format_error = any(s in msg for s in [ + "ARROW_STREAM", "JSON_ARRAY", "EXTERNAL_LINKS", + "INVALID_PARAMETER_VALUE", "NOT_IMPLEMENTED", + "format field must be", + ]) + if not is_format_error or i == len(fallback_order) - 1: + raise + logger.warning("Format %s rejected, falling back: %s", fmt_name, msg) + + if response is None: + raise RuntimeError("All format fallbacks exhausted") + + # For ARROW format with EXTERNAL_LINKS, emit an arrow event + if result_type == "arrow" and response.statement_id: + event_id = str(uuid.uuid4()) + yield format_event(event_id, { + "type": "arrow", + "statement_id": response.statement_id, + }) + else: + # Transform result: handles Arrow IPC attachment, data_array, etc. + result_data = _sql_connector.transform_result(response) + + event_id = str(uuid.uuid4()) + yield format_event(event_id, { + "type": "result", + "chunk_index": 0, + "row_offset": 0, + "row_count": len(result_data), + "data": result_data, + }) + + except Exception as exc: + error_id = str(uuid.uuid4()) + yield format_error(error_id, str(exc)) + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={k: v for k, v in SSE_HEADERS.items() if k != "Content-Type"}, + ) + + # ----------------------------------------------------------------------- + # Analytics plugin: GET /api/analytics/arrow-result/{job_id} + # ----------------------------------------------------------------------- + @app.get("/api/analytics/arrow-result/{job_id}") + async def analytics_arrow_result(job_id: str): + if not _ws_client: + return JSONResponse( + {"error": "Arrow job not found", "plugin": "analytics"}, + status_code=404, + ) + try: + result = await _sql_connector.get_arrow_data(_ws_client, job_id) + return Response( + content=result["data"], + media_type="application/octet-stream", + headers={ + "Content-Length": str(len(result["data"])), + "Cache-Control": "public, max-age=3600", + }, + ) + except Exception as exc: + return JSONResponse( + {"error": str(exc) or "Arrow job not found", "plugin": "analytics"}, + status_code=404, + ) + + # ----------------------------------------------------------------------- + # Files plugin: GET /api/files/volumes + # ----------------------------------------------------------------------- + @app.get("/api/files/volumes") + async def files_volumes(): + return {"volumes": list(_volumes.keys())} + + # ----------------------------------------------------------------------- + # Files plugin: volume routes + # ----------------------------------------------------------------------- + def _resolve_volume(volume_key: str) -> str | None: + return _volumes.get(volume_key) + + def _validate_path(path: str | None) -> str | True: + if not path: + return "path is required" + if len(path) > 4096: + return f"path exceeds maximum length of 4096 characters (got {len(path)})" + if "\0" in path: + return "path must not contain null bytes" + return True + + async def _run_file_op(volume_key: str, op_name: str, op_coro): + """Helper to run a file operation with error handling.""" + if not _ws_client: + return JSONResponse( + {"error": "Databricks connection not configured", "plugin": "files"}, + status_code=500, + ) + connector = _file_connectors.get(volume_key) + if not connector: + return JSONResponse( + {"error": "Volume connector not found", "plugin": "files"}, + status_code=500, + ) + try: + return await op_coro + except Exception as exc: + status = 500 + if hasattr(exc, "status_code"): + status = exc.status_code + return JSONResponse( + {"error": str(exc), "plugin": "files"}, + status_code=status, + ) + + @app.get("/api/files/{volume_key}/list") + async def files_list(volume_key: str, request: Request, path: str | None = None): + if not _resolve_volume(volume_key): + safe_key = "".join(c for c in volume_key if c.isalnum() or c in "_-") + return JSONResponse( + {"error": f'Unknown volume "{safe_key}"', "plugin": "files"}, + status_code=404, + ) + connector = _file_connectors.get(volume_key) + client = _get_user_client(request) + if not client or not connector: + return JSONResponse( + {"error": "Databricks connection not configured", "plugin": "files"}, + status_code=500, + ) + try: + result = await connector.list(client, path) + return result + except Exception as exc: + return JSONResponse( + {"error": str(exc), "plugin": "files"}, status_code=500 + ) + + @app.get("/api/files/{volume_key}/read") + async def files_read(volume_key: str, path: str | None = None): + if not _resolve_volume(volume_key): + safe_key = "".join(c for c in volume_key if c.isalnum() or c in "_-") + return JSONResponse( + {"error": f'Unknown volume "{safe_key}"', "plugin": "files"}, + status_code=404, + ) + valid = _validate_path(path) + if valid is not True: + return JSONResponse({"error": valid, "plugin": "files"}, status_code=400) + connector = _file_connectors.get(volume_key) + client = _get_user_client(request) + if not client or not connector: + return JSONResponse( + {"error": "Databricks connection not configured", "plugin": "files"}, + status_code=500, + ) + try: + text = await connector.read(client, path) + return Response(content=text, media_type="text/plain") + except Exception as exc: + return JSONResponse({"error": str(exc), "plugin": "files"}, status_code=500) + + def _get_client_for_request(request: Request) -> Any: + """Get the appropriate WorkspaceClient for a request. + + OBO routes use per-request client with user's token. + Falls back to service principal client. + """ + return _get_user_client(request) + + def _file_handler_preamble(volume_key: str, request: Request, path: str | None = None, require_path: bool = True): + """Common preamble for file endpoints: resolve volume, validate path, get client. + + Returns (error_response, None, None) on failure, or (None, connector, client) on success. + """ + if not _resolve_volume(volume_key): + safe_key = "".join(c for c in volume_key if c.isalnum() or c in "_-") + return (JSONResponse( + {"error": f'Unknown volume "{safe_key}"', "plugin": "files"}, + status_code=404, + ), None, None) + if require_path: + valid = _validate_path(path) + if valid is not True: + return (JSONResponse({"error": valid, "plugin": "files"}, status_code=400), None, None) + connector = _file_connectors.get(volume_key) + client = _get_user_client(request) + if not client or not connector: + return (JSONResponse( + {"error": "Databricks connection not configured", "plugin": "files"}, + status_code=500, + ), None, None) + return (None, connector, client) # All checks passed + + @app.get("/api/files/{volume_key}/download") + async def files_download(volume_key: str, request: Request, path: str | None = None): + err, connector, client = _file_handler_preamble(volume_key, request, path) + if err: + return err + try: + result = await connector.download(client, path) + import mimetypes + content_type = result.get("content_type") or mimetypes.guess_type(path)[0] or "application/octet-stream" + raw_name = path.split("/")[-1] if path else "download" + # Sanitize filename: strip chars that could enable header injection + filename = "".join(c for c in raw_name if c.isalnum() or c in "._- ")[:255] or "download" + headers = { + "Content-Disposition": f'attachment; filename="{filename}"', + "X-Content-Type-Options": "nosniff", + } + content = result.get("contents") + if hasattr(content, "read"): + body = content.read() + else: + body = content or b"" + return Response(content=body, media_type=content_type, headers=headers) + except Exception as exc: + return JSONResponse({"error": str(exc), "plugin": "files"}, status_code=500) + + @app.get("/api/files/{volume_key}/raw") + async def files_raw(volume_key: str, request: Request, path: str | None = None): + err, connector, client = _file_handler_preamble(volume_key, request, path) + if err: + return err + try: + result = await connector.download(client, path) + import mimetypes + content_type = result.get("content_type") or mimetypes.guess_type(path)[0] or "application/octet-stream" + headers = { + "Content-Security-Policy": "sandbox", + "X-Content-Type-Options": "nosniff", + } + content = result.get("contents") + if hasattr(content, "read"): + body = content.read() + else: + body = content or b"" + return Response(content=body, media_type=content_type, headers=headers) + except Exception as exc: + return JSONResponse({"error": str(exc), "plugin": "files"}, status_code=500) + + @app.get("/api/files/{volume_key}/exists") + async def files_exists(volume_key: str, request: Request, path: str | None = None): + err, connector, client = _file_handler_preamble(volume_key, request, path) + if err: + return err + try: + exists = await connector.exists(client, path) + return {"exists": exists} + except Exception as exc: + return JSONResponse({"error": str(exc), "plugin": "files"}, status_code=500) + + @app.get("/api/files/{volume_key}/metadata") + async def files_metadata(volume_key: str, request: Request, path: str | None = None): + err, connector, client = _file_handler_preamble(volume_key, request, path) + if err: + return err + try: + meta = await connector.metadata(client, path) + return meta + except Exception as exc: + return JSONResponse({"error": str(exc), "plugin": "files"}, status_code=500) + + @app.get("/api/files/{volume_key}/preview") + async def files_preview(volume_key: str, request: Request, path: str | None = None): + err, connector, client = _file_handler_preamble(volume_key, request, path) + if err: + return err + try: + preview = await connector.preview(client, path) + return preview + except Exception as exc: + return JSONResponse({"error": str(exc), "plugin": "files"}, status_code=500) + + @app.post("/api/files/{volume_key}/upload") + async def files_upload(volume_key: str, request: Request, path: str | None = None): + if not _resolve_volume(volume_key): + safe_key = "".join(c for c in volume_key if c.isalnum() or c in "_-") + return JSONResponse( + {"error": f'Unknown volume "{safe_key}"', "plugin": "files"}, + status_code=404, + ) + valid = _validate_path(path) + if valid is not True: + return JSONResponse({"error": valid, "plugin": "files"}, status_code=400) + + max_size = 5 * 1024 * 1024 * 1024 # 5GB + content_length = request.headers.get("content-length") + if content_length: + try: + size = int(content_length) + if size > max_size: + return JSONResponse( + { + "error": f"File size ({size} bytes) exceeds maximum allowed size ({max_size} bytes).", + "plugin": "files", + }, + status_code=413, + ) + except ValueError: + pass + + connector = _file_connectors.get(volume_key) + client = _get_user_client(request) + if not client or not connector: + return JSONResponse( + {"error": "Databricks connection not configured", "plugin": "files"}, + status_code=500, + ) + try: + # Stream the body with a running size counter to prevent OOM + chunks: list[bytes] = [] + bytes_received = 0 + async for chunk in request.stream(): + bytes_received += len(chunk) + if bytes_received > max_size: + return JSONResponse( + { + "error": f"Upload stream exceeds maximum allowed size ({max_size} bytes).", + "plugin": "files", + }, + status_code=413, + ) + chunks.append(chunk) + body = b"".join(chunks) + await connector.upload(client, path, body) + return {"success": True} + except Exception as exc: + if "exceeds maximum allowed size" in str(exc): + return JSONResponse({"error": str(exc), "plugin": "files"}, status_code=413) + return JSONResponse({"error": str(exc), "plugin": "files"}, status_code=500) + + @app.post("/api/files/{volume_key}/mkdir") + async def files_mkdir(volume_key: str, request: Request): + if not _resolve_volume(volume_key): + safe_key = "".join(c for c in volume_key if c.isalnum() or c in "_-") + return JSONResponse( + {"error": f'Unknown volume "{safe_key}"', "plugin": "files"}, + status_code=404, + ) + body = {} + try: + body = await request.json() + except Exception: + pass + dir_path = body.get("path") if isinstance(body, dict) else None + valid = _validate_path(dir_path) + if valid is not True: + return JSONResponse({"error": valid, "plugin": "files"}, status_code=400) + connector = _file_connectors.get(volume_key) + client = _get_user_client(request) + if not client or not connector: + return JSONResponse( + {"error": "Databricks connection not configured", "plugin": "files"}, + status_code=500, + ) + try: + await connector.create_directory(client, dir_path) + return {"success": True} + except Exception as exc: + return JSONResponse({"error": str(exc), "plugin": "files"}, status_code=500) + + @app.delete("/api/files/{volume_key}") + async def files_delete(volume_key: str, request: Request, path: str | None = None): + if not _resolve_volume(volume_key): + safe_key = "".join(c for c in volume_key if c.isalnum() or c in "_-") + return JSONResponse( + {"error": f'Unknown volume "{safe_key}"', "plugin": "files"}, + status_code=404, + ) + valid = _validate_path(path) + if valid is not True: + return JSONResponse({"error": valid, "plugin": "files"}, status_code=400) + connector = _file_connectors.get(volume_key) + client = _get_user_client(request) + if not client or not connector: + return JSONResponse( + {"error": "Databricks connection not configured", "plugin": "files"}, + status_code=500, + ) + try: + await connector.delete(client, path) + return {"success": True} + except Exception as exc: + return JSONResponse({"error": str(exc), "plugin": "files"}, status_code=500) + + # ----------------------------------------------------------------------- + # Genie plugin + # ----------------------------------------------------------------------- + def _sse_from_genie(gen_coro, client: Any) -> StreamingResponse: + """Create an SSE StreamingResponse from a genie async generator.""" + async def event_generator() -> AsyncGenerator[str, None]: + if not client: + error_id = str(uuid.uuid4()) + yield format_error(error_id, "Databricks Genie connection not configured", SSEErrorCode.TEMPORARY_UNAVAILABLE) + return + try: + async for event in gen_coro: + event_id = str(uuid.uuid4()) + yield format_event(event_id, event) + except Exception as exc: + error_id = str(uuid.uuid4()) + yield format_error(error_id, str(exc)) + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={k: v for k, v in SSE_HEADERS.items() if k != "Content-Type"}, + ) + + @app.post("/api/genie/{alias}/messages") + async def genie_send_message(alias: str, request: Request): + space_id = _genie_spaces.get(alias) + if not space_id: + return JSONResponse({"error": f"Unknown space alias: {alias}"}, status_code=404) + + body = {} + try: + body = await request.json() + except Exception: + pass + content = body.get("content") if isinstance(body, dict) else None + if not content: + return JSONResponse({"error": "content is required"}, status_code=400) + + conversation_id = body.get("conversationId") if isinstance(body, dict) else None + client = _get_user_client(request) + return _sse_from_genie( + _genie_connector.stream_send_message(client, space_id, content, conversation_id), + client, + ) + + @app.get("/api/genie/{alias}/conversations/{conversation_id}") + async def genie_get_conversation(alias: str, conversation_id: str, request: Request): + space_id = _genie_spaces.get(alias) + if not space_id: + return JSONResponse({"error": f"Unknown space alias: {alias}"}, status_code=404) + + include_query_results = request.query_params.get("includeQueryResults", "true") != "false" + page_token = request.query_params.get("pageToken") + client = _get_user_client(request) + return _sse_from_genie( + _genie_connector.stream_conversation( + client, space_id, conversation_id, + include_query_results=include_query_results, page_token=page_token, + ), + client, + ) + + @app.get("/api/genie/{alias}/conversations/{conversation_id}/messages/{message_id}") + async def genie_get_message(alias: str, conversation_id: str, message_id: str, request: Request): + space_id = _genie_spaces.get(alias) + if not space_id: + return JSONResponse({"error": f"Unknown space alias: {alias}"}, status_code=404) + + client = _get_user_client(request) + return _sse_from_genie( + _genie_connector.stream_get_message(client, space_id, conversation_id, message_id), + client, + ) + + # ----------------------------------------------------------------------- + # Static file serving with client config injection + # ----------------------------------------------------------------------- + resolved_static = static_path or _find_static_dir() + if resolved_static and Path(resolved_static).is_dir(): + _static_dir = Path(resolved_static) + _index_html = _static_dir / "index.html" + + # Build client config (injected into index.html like TS StaticServer) + _client_config = json.dumps({ + "appName": os.environ.get("DATABRICKS_APP_NAME", "appkit-py"), + "queries": {}, + "endpoints": { + "analytics": {"query": "/api/analytics/query", "arrow": "/api/analytics/arrow-result"}, + "files": { + "volumes": "/api/files/volumes", "list": "/api/files/:volumeKey/list", + "read": "/api/files/:volumeKey/read", "download": "/api/files/:volumeKey/download", + "raw": "/api/files/:volumeKey/raw", "exists": "/api/files/:volumeKey/exists", + "metadata": "/api/files/:volumeKey/metadata", "preview": "/api/files/:volumeKey/preview", + "upload": "/api/files/:volumeKey/upload", "mkdir": "/api/files/:volumeKey/mkdir", + "delete": "/api/files/:volumeKey", + }, + "genie": { + "sendMessage": "/api/genie/:alias/messages", + "getConversation": "/api/genie/:alias/conversations/:conversationId", + "getMessage": "/api/genie/:alias/conversations/:conversationId/messages/:messageId", + }, + }, + "plugins": { + "files": {"volumes": list(_volumes.keys())}, + "genie": {"spaces": list(_genie_spaces.keys())}, + }, + }) + # Escape for safe HTML embedding + _safe_config = _client_config.replace("<", "\\u003c").replace(">", "\\u003e").replace("&", "\\u0026") + + @app.get("/{full_path:path}") + async def serve_spa(full_path: str): + """Serve static files or index.html with injected config (SPA catch-all).""" + import mimetypes + # Resolve and verify the path stays within the static directory + file_path = (_static_dir / full_path).resolve() + static_root = _static_dir.resolve() + if ( + file_path.is_file() + and str(file_path).startswith(str(static_root) + os.sep) + ): + ct = mimetypes.guess_type(str(file_path))[0] or "application/octet-stream" + return Response(content=file_path.read_bytes(), media_type=ct) + + # Fall back to index.html with injected config + if _index_html.is_file(): + html = _index_html.read_text() + config_script = ( + f'\n' + '' + ) + # Inject before or at end of
+ if "" in html: + html = html.replace("", f"{config_script}\n") + else: + html = config_script + "\n" + html + return Response(content=html, media_type="text/html") + + return JSONResponse({"error": "Not found"}, status_code=404) + + return app + + +# --------------------------------------------------------------------------- +# Configuration discovery helpers +# --------------------------------------------------------------------------- + +def _discover_genie_spaces() -> dict[str, str]: + space_id = os.environ.get("DATABRICKS_GENIE_SPACE_ID") + if space_id: + return {"default": space_id} + return {} + + +def _discover_volumes() -> dict[str, str]: + prefix = "DATABRICKS_VOLUME_" + volumes: dict[str, str] = {} + for key, value in os.environ.items(): + if key.startswith(prefix) and value: + suffix = key[len(prefix):] + if suffix: + volumes[suffix.lower()] = value + return volumes + + +def _find_static_dir() -> str | None: + """Auto-detect the frontend static directory (matching TS StaticServer logic).""" + candidates = [ + "client/dist", "dist", "build", "public", "out", + "../client/dist", "../dist", + ] + for candidate in candidates: + if Path(candidate).is_dir(): + return candidate + return None + + +def _find_query_dir() -> str | None: + """Find the config/queries directory relative to CWD.""" + candidates = ["config/queries", "../config/queries", "../../config/queries"] + for candidate in candidates: + path = Path(candidate) + if path.is_dir(): + return str(path) + return None + + +def _load_query(query_key: str, query_dir: str | None) -> str | None: + """Load a SQL query file by key from the query directory.""" + if not query_dir: + return None + + # Sanitize query_key: reject path separators and traversal sequences + if "/" in query_key or "\\" in query_key or ".." in query_key: + return None + + base = query_key.removesuffix(".obo") + dir_path = Path(query_dir).resolve() + + # Try .obo.sql first, then .sql + for suffix in [".obo.sql", ".sql"]: + file_path = (dir_path / f"{base}{suffix}").resolve() + # Verify the resolved path stays within the query directory + if not str(file_path).startswith(str(dir_path) + os.sep): + return None + if file_path.is_file(): + return file_path.read_text() + + return None + + +def _has_obo_file(query_key: str, query_dir: str | None) -> bool: + """Check if a .obo.sql variant exists for this query key.""" + if not query_dir: + return False + base = query_key.removesuffix(".obo") + return (Path(query_dir) / f"{base}.obo.sql").is_file() + + +# --------------------------------------------------------------------------- +# App instance for uvicorn +# --------------------------------------------------------------------------- + +app = create_server() diff --git a/packages/appkit-py/src/appkit_py/stream/__init__.py b/packages/appkit-py/src/appkit_py/stream/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/appkit-py/src/appkit_py/stream/buffers.py b/packages/appkit-py/src/appkit_py/stream/buffers.py new file mode 100644 index 00000000..9bcd1cad --- /dev/null +++ b/packages/appkit-py/src/appkit_py/stream/buffers.py @@ -0,0 +1,84 @@ +"""Ring buffer implementations for SSE event replay on reconnection. + +Ports the TypeScript RingBuffer and EventRingBuffer from +packages/appkit/src/stream/buffers.ts +""" + +from __future__ import annotations + +from collections import OrderedDict +from typing import Generic, TypeVar + +from .types import BufferedEvent + +T = TypeVar("T") + + +class RingBuffer(Generic[T]): + """Generic FIFO ring buffer with LRU eviction and O(1) key lookup.""" + + def __init__(self, capacity: int) -> None: + self._capacity = capacity + self._store: OrderedDict[str, T] = OrderedDict() + + def add(self, key: str, value: T) -> None: + if key in self._store: + del self._store[key] + elif len(self._store) >= self._capacity: + self._store.popitem(last=False) # Evict oldest + self._store[key] = value + + def get(self, key: str) -> T | None: + return self._store.get(key) + + def has(self, key: str) -> bool: + return key in self._store + + def __len__(self) -> int: + return len(self._store) + + def keys(self) -> list[str]: + return list(self._store.keys()) + + def values(self) -> list[T]: + return list(self._store.values()) + + +class EventRingBuffer: + """Specialized ring buffer for SSE events with get_events_since() for replay.""" + + def __init__(self, capacity: int) -> None: + self._buffer: RingBuffer[BufferedEvent] = RingBuffer(capacity) + self._order: list[str] = [] # Maintain insertion order for replay + self._capacity = capacity + + def add_event(self, event: BufferedEvent) -> None: + self._buffer.add(event.id, event) + self._order.append(event.id) + # Trim order list to capacity + if len(self._order) > self._capacity: + self._order = self._order[-self._capacity :] + + def has_event(self, event_id: str) -> bool: + return self._buffer.has(event_id) + + def get_events_since(self, event_id: str) -> list[BufferedEvent] | None: + """Get all events after the given event ID. + + Returns None if the event_id is not in the buffer (buffer overflow). + Returns an empty list if event_id is the last event. + """ + if not self._buffer.has(event_id): + return None + + try: + idx = self._order.index(event_id) + except ValueError: + return None + + result: list[BufferedEvent] = [] + for eid in self._order[idx + 1 :]: + event = self._buffer.get(eid) + if event is not None: + result.append(event) + return result diff --git a/packages/appkit-py/src/appkit_py/stream/defaults.py b/packages/appkit-py/src/appkit_py/stream/defaults.py new file mode 100644 index 00000000..3e4547ad --- /dev/null +++ b/packages/appkit-py/src/appkit_py/stream/defaults.py @@ -0,0 +1,11 @@ +"""Stream default configuration values matching the TypeScript implementation.""" + +STREAM_DEFAULTS = { + "buffer_size": 100, + "max_event_size": 1024 * 1024, # 1MB + "buffer_ttl": 10 * 60, # 10 minutes (seconds) + "cleanup_interval": 5 * 60, # 5 minutes (seconds) + "max_persistent_buffers": 10000, + "heartbeat_interval": 10, # 10 seconds + "max_active_streams": 1000, +} diff --git a/packages/appkit-py/src/appkit_py/stream/sse_writer.py b/packages/appkit-py/src/appkit_py/stream/sse_writer.py new file mode 100644 index 00000000..6e97ac3a --- /dev/null +++ b/packages/appkit-py/src/appkit_py/stream/sse_writer.py @@ -0,0 +1,67 @@ +"""SSE wire format writer matching the TypeScript SSEWriter. + +Produces the exact format expected by the AppKit frontend: + id: {uuid} + event: {type} + data: {json} + (empty line) + +Plus heartbeat comments: `: heartbeat\\n\\n` +""" + +from __future__ import annotations + +import json +import re +from typing import Any, Callable, Coroutine + +from .types import BufferedEvent, SSEErrorCode, SSEWarningCode + + +def sanitize_event_type(event_type: str) -> str: + """Sanitize SSE event type: remove newlines, cap at 100 chars.""" + sanitized = re.sub(r"[\r\n]", "", event_type) + return sanitized[:100] + + +def format_event(event_id: str, event: dict[str, Any]) -> str: + """Format a single SSE event as a string.""" + event_type = sanitize_event_type(str(event.get("type", "message"))) + event_data = json.dumps(event, separators=(",", ":")) + return f"id: {event_id}\nevent: {event_type}\ndata: {event_data}\n\n" + + +def format_error(event_id: str, error: str, code: SSEErrorCode = SSEErrorCode.INTERNAL_ERROR) -> str: + """Format an SSE error event.""" + data = json.dumps({"error": error, "code": code.value}, separators=(",", ":")) + return f"id: {event_id}\nevent: error\ndata: {data}\n\n" + + +def format_buffered_event(event: BufferedEvent) -> str: + """Format a buffered event for replay.""" + event_type = sanitize_event_type(event.type) + return f"id: {event.id}\nevent: {event_type}\ndata: {event.data}\n\n" + + +def format_heartbeat() -> str: + """Format an SSE heartbeat comment.""" + return ": heartbeat\n\n" + + +def format_buffer_overflow_warning(last_event_id: str) -> str: + """Format a buffer overflow warning.""" + data = json.dumps({ + "warning": "Buffer overflow detected - some events were lost", + "code": SSEWarningCode.BUFFER_OVERFLOW_RESTART.value, + "lastEventId": last_event_id, + }, separators=(",", ":")) + return f"event: warning\ndata: {data}\n\n" + + +SSE_HEADERS = { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Content-Encoding": "none", + "X-Accel-Buffering": "no", +} diff --git a/packages/appkit-py/src/appkit_py/stream/stream_manager.py b/packages/appkit-py/src/appkit_py/stream/stream_manager.py new file mode 100644 index 00000000..23436542 --- /dev/null +++ b/packages/appkit-py/src/appkit_py/stream/stream_manager.py @@ -0,0 +1,143 @@ +"""StreamManager — core SSE streaming orchestration. + +Ports the TypeScript StreamManager from packages/appkit/src/stream/stream-manager.ts. +Handles async generator-based event streams with: +- UUID event IDs +- Ring buffer for reconnection replay (persisted per stream_id) +- Heartbeat keep-alive +- Error event emission +- Graceful abort via tracked disconnect events +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import time +import uuid +from typing import Any, AsyncGenerator, Callable, Coroutine + +from .buffers import BufferedEvent, EventRingBuffer +from .defaults import STREAM_DEFAULTS +from .sse_writer import ( + format_buffered_event, + format_buffer_overflow_warning, + format_error, + format_event, + format_heartbeat, +) +from .types import SSEErrorCode + +logger = logging.getLogger("appkit.stream") + +SendFunc = Callable[[str], Coroutine[Any, Any, None]] + + +class StreamManager: + """Manages SSE event streaming with reconnection support.""" + + def __init__( + self, + buffer_size: int = STREAM_DEFAULTS["buffer_size"], + heartbeat_interval: float = STREAM_DEFAULTS["heartbeat_interval"], + ) -> None: + self._buffer_size = buffer_size + self._heartbeat_interval = heartbeat_interval + # Persist buffers per stream_id for reconnection replay + self._stream_buffers: dict[str, EventRingBuffer] = {} + # Track active disconnect events for abort_all() + self._active_disconnects: set[asyncio.Event] = set() + + async def stream( + self, + send: SendFunc, + handler: Callable[..., AsyncGenerator[dict[str, Any], None]], + *, + on_disconnect: asyncio.Event | None = None, + last_event_id: str | None = None, + stream_id: str | None = None, + ) -> None: + """Stream events from an async generator to the client.""" + sid = stream_id or str(uuid.uuid4()) + # Get or create a persistent buffer for this stream + if sid not in self._stream_buffers: + self._stream_buffers[sid] = EventRingBuffer(capacity=self._buffer_size) + event_buffer = self._stream_buffers[sid] + + disconnect = on_disconnect or asyncio.Event() + self._active_disconnects.add(disconnect) + heartbeat_task: asyncio.Task | None = None + + try: + heartbeat_task = asyncio.create_task( + self._heartbeat_loop(send, disconnect) + ) + + # Replay buffered events if reconnecting + if last_event_id: + if event_buffer.has_event(last_event_id): + missed = event_buffer.get_events_since(last_event_id) + if missed: + for event in missed: + await send(format_buffered_event(event)) + else: + # Buffer overflow — event was evicted + await send(format_buffer_overflow_warning(last_event_id)) + + # Stream events from handler + async for event in handler(signal=disconnect): + if disconnect.is_set(): + break + + event_id = str(uuid.uuid4()) + event_type = str(event.get("type", "message")) + event_data = json.dumps(event, separators=(",", ":")) + + event_buffer.add_event( + BufferedEvent( + id=event_id, + type=event_type, + data=event_data, + timestamp=time.time(), + ) + ) + + await send(format_event(event_id, event)) + + except Exception as exc: + error_id = str(uuid.uuid4()) + error_msg = str(exc) if str(exc) else type(exc).__name__ + try: + await send(format_error(error_id, error_msg)) + except Exception: + pass + logger.error("Stream error: %s", exc) + finally: + self._active_disconnects.discard(disconnect) + if heartbeat_task and not heartbeat_task.done(): + heartbeat_task.cancel() + try: + await heartbeat_task + except asyncio.CancelledError: + pass + + async def _heartbeat_loop(self, send: SendFunc, disconnect: asyncio.Event) -> None: + """Send periodic heartbeat comments to keep the connection alive.""" + try: + while not disconnect.is_set(): + await asyncio.sleep(self._heartbeat_interval) + if not disconnect.is_set(): + try: + await send(format_heartbeat()) + except Exception: + break + except asyncio.CancelledError: + pass + + def abort_all(self) -> None: + """Abort all active streams by setting their disconnect events.""" + for evt in list(self._active_disconnects): + evt.set() + self._active_disconnects.clear() + self._stream_buffers.clear() diff --git a/packages/appkit-py/src/appkit_py/stream/types.py b/packages/appkit-py/src/appkit_py/stream/types.py new file mode 100644 index 00000000..c0709e5a --- /dev/null +++ b/packages/appkit-py/src/appkit_py/stream/types.py @@ -0,0 +1,27 @@ +"""SSE stream types mirroring the TypeScript implementation.""" + +from __future__ import annotations + +import enum +from dataclasses import dataclass, field + + +class SSEErrorCode(str, enum.Enum): + TEMPORARY_UNAVAILABLE = "TEMPORARY_UNAVAILABLE" + TIMEOUT = "TIMEOUT" + INTERNAL_ERROR = "INTERNAL_ERROR" + INVALID_REQUEST = "INVALID_REQUEST" + STREAM_ABORTED = "STREAM_ABORTED" + STREAM_EVICTED = "STREAM_EVICTED" + + +class SSEWarningCode(str, enum.Enum): + BUFFER_OVERFLOW_RESTART = "BUFFER_OVERFLOW_RESTART" + + +@dataclass +class BufferedEvent: + id: str + type: str + data: str + timestamp: float diff --git a/packages/appkit-py/tests/__init__.py b/packages/appkit-py/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/appkit-py/tests/conftest.py b/packages/appkit-py/tests/conftest.py new file mode 100644 index 00000000..0329f409 --- /dev/null +++ b/packages/appkit-py/tests/conftest.py @@ -0,0 +1,62 @@ +"""Shared test fixtures for appkit-py tests. + +Integration tests are language-agnostic HTTP tests that can run against either +the TypeScript or Python backend. Set APPKIT_TEST_URL to point at the target server. +""" + +from __future__ import annotations + +import os +from collections.abc import AsyncGenerator + +import httpx +import pytest +import pytest_asyncio + + +@pytest.fixture(scope="session") +def base_url() -> str: + """Base URL for the backend server under test. + + Set APPKIT_TEST_URL env var to point at TS or Python backend. + Default: http://localhost:8000 + """ + return os.environ.get("APPKIT_TEST_URL", "http://localhost:8000") + + +@pytest.fixture(scope="session") +def auth_headers() -> dict[str, str]: + """Default auth headers simulating Databricks Apps proxy.""" + return { + "x-forwarded-user": "test-user@databricks.com", + "x-forwarded-access-token": "fake-obo-token-for-testing", + } + + +@pytest.fixture(scope="session") +def no_auth_headers() -> dict[str, str]: + """Empty headers for testing unauthenticated requests.""" + return {} + + +@pytest_asyncio.fixture +async def http_client( + base_url: str, auth_headers: dict[str, str] +) -> AsyncGenerator[httpx.AsyncClient]: + """Async HTTP client pre-configured with base URL and auth headers.""" + async with httpx.AsyncClient( + base_url=base_url, + headers=auth_headers, + timeout=httpx.Timeout(30.0, connect=10.0), + ) as client: + yield client + + +@pytest_asyncio.fixture +async def unauthed_client(base_url: str) -> AsyncGenerator[httpx.AsyncClient]: + """Async HTTP client with no auth headers.""" + async with httpx.AsyncClient( + base_url=base_url, + timeout=httpx.Timeout(30.0, connect=10.0), + ) as client: + yield client diff --git a/packages/appkit-py/tests/helpers/__init__.py b/packages/appkit-py/tests/helpers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/appkit-py/tests/helpers/sse_parser.py b/packages/appkit-py/tests/helpers/sse_parser.py new file mode 100644 index 00000000..7456f2d0 --- /dev/null +++ b/packages/appkit-py/tests/helpers/sse_parser.py @@ -0,0 +1,192 @@ +"""SSE (Server-Sent Events) parser for integration tests. + +Parses the exact wire format used by AppKit: + id: {uuid} + event: {type} + data: {json} + +Plus heartbeat comments: `: heartbeat\\n\\n` +""" + +from __future__ import annotations + +import json +import re +from dataclasses import dataclass, field + +import httpx + +UUID_PATTERN = re.compile( + r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.IGNORECASE +) + + +@dataclass +class SSEEvent: + """A single parsed SSE event.""" + + id: str | None = None + event: str | None = None + data: str | None = None + is_heartbeat: bool = False + raw_lines: list[str] = field(default_factory=list) + + @property + def is_error(self) -> bool: + return self.event == "error" + + @property + def parsed_data(self) -> dict | list | None: + """Parse the data field as JSON. Returns None if no data or parse failure.""" + if self.data is None: + return None + try: + return json.loads(self.data) + except (json.JSONDecodeError, TypeError): + return None + + @property + def has_valid_uuid_id(self) -> bool: + """Check if the event ID is a valid UUID v4 format.""" + if self.id is None: + return False + return bool(UUID_PATTERN.match(self.id)) + + +def parse_sse_text(text: str) -> list[SSEEvent]: + """Parse raw SSE text into a list of SSEEvent objects. + + Handles the standard SSE format: + - Lines starting with ':' are comments (heartbeats) + - Lines with 'field: value' format set event fields + - Empty lines delimit events + """ + events: list[SSEEvent] = [] + current_lines: list[str] = [] + current_id: str | None = None + current_event: str | None = None + current_data: str | None = None + + for raw_line in text.split("\n"): + line = raw_line + + # Empty line = event boundary + if line == "": + if current_data is not None or current_event is not None or current_id is not None: + events.append( + SSEEvent( + id=current_id, + event=current_event, + data=current_data, + is_heartbeat=False, + raw_lines=current_lines, + ) + ) + current_lines = [] + current_id = None + current_event = None + current_data = None + elif current_lines and all(l.startswith(":") for l in current_lines if l): + # Comment-only block (heartbeat) + events.append( + SSEEvent( + is_heartbeat=True, + raw_lines=current_lines, + ) + ) + current_lines = [] + continue + + current_lines.append(line) + + # Comment line (heartbeat) + if line.startswith(":"): + continue + + # Field: value parsing + if ":" in line: + field_name, _, value = line.partition(":") + value = value.lstrip(" ") # Strip single leading space per SSE spec + + if field_name == "id": + current_id = value + elif field_name == "event": + current_event = value + elif field_name == "data": + if current_data is None: + current_data = value + else: + current_data += "\n" + value + + # Handle trailing event without final newline + if current_data is not None or current_event is not None or current_id is not None: + events.append( + SSEEvent( + id=current_id, + event=current_event, + data=current_data, + is_heartbeat=False, + raw_lines=current_lines, + ) + ) + + return events + + +async def parse_sse_response(response: httpx.Response) -> list[SSEEvent]: + """Parse an httpx response as SSE events.""" + return parse_sse_text(response.text) + + +async def collect_sse_stream( + client: httpx.AsyncClient, + method: str, + url: str, + *, + json_body: dict | None = None, + headers: dict | None = None, + timeout: float = 30.0, + max_events: int = 100, +) -> list[SSEEvent]: + """Make a streaming request and collect SSE events. + + Uses httpx streaming to handle long-lived SSE connections with a timeout. + """ + events: list[SSEEvent] = [] + buffer = "" + + request_kwargs: dict = { + "method": method, + "url": url, + "timeout": timeout, + "headers": {**(headers or {}), "Accept": "text/event-stream"}, + } + if json_body is not None: + request_kwargs["json"] = json_body + + async with client.stream(**request_kwargs) as response: + async for chunk in response.aiter_text(): + buffer += chunk + # Parse complete events from buffer + while "\n\n" in buffer: + event_text, buffer = buffer.split("\n\n", 1) + parsed = parse_sse_text(event_text + "\n\n") + events.extend(parsed) + if len(events) >= max_events: + return events + + # Parse any remaining buffer + if buffer.strip(): + events.extend(parse_sse_text(buffer)) + + return events + + +def events_only(events: list[SSEEvent]) -> list[SSEEvent]: + """Filter out heartbeat events, returning only real events.""" + return [e for e in events if not e.is_heartbeat] + + +def heartbeats_only(events: list[SSEEvent]) -> list[SSEEvent]: + """Filter to only heartbeat events.""" + return [e for e in events if e.is_heartbeat] diff --git a/packages/appkit-py/tests/integration/__init__.py b/packages/appkit-py/tests/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/appkit-py/tests/integration/test_analytics.py b/packages/appkit-py/tests/integration/test_analytics.py new file mode 100644 index 00000000..04bc963d --- /dev/null +++ b/packages/appkit-py/tests/integration/test_analytics.py @@ -0,0 +1,130 @@ +"""Integration tests for the Analytics plugin API. + +Endpoints: + POST /api/analytics/query/:query_key → SSE stream + GET /api/analytics/arrow-result/:jobId → binary Arrow data +""" + +from __future__ import annotations + +import httpx +import pytest + +from tests.helpers.sse_parser import collect_sse_stream, events_only + +pytestmark = pytest.mark.integration + + +class TestAnalyticsQueryEndpoint: + """Tests for POST /api/analytics/query/:query_key.""" + + async def test_query_returns_sse_content_type(self, http_client: httpx.AsyncClient): + """Query endpoint must return SSE content type.""" + try: + async with http_client.stream( + "POST", + "/api/analytics/query/spend_data", + json={"format": "JSON"}, + timeout=20.0, + ) as resp: + if resp.status_code == 404: + pytest.skip("Query 'spend_data' not found — no query files configured") + content_type = resp.headers.get("content-type", "") + # Successful queries return SSE, errors return JSON + assert ( + "text/event-stream" in content_type + or "application/json" in content_type + ) + except (httpx.HTTPError, httpx.StreamError): + pytest.skip("Analytics endpoint not available") + + async def test_query_missing_key_returns_error(self, http_client: httpx.AsyncClient): + """Query with nonexistent key should return 404.""" + response = await http_client.post( + "/api/analytics/query/nonexistent_query_that_does_not_exist", + json={"format": "JSON"}, + ) + assert response.status_code == 404 + body = response.json() + assert "error" in body + + async def test_query_result_events_have_correct_format( + self, http_client: httpx.AsyncClient + ): + """Result events from analytics should have type field in their data.""" + try: + events = await collect_sse_stream( + http_client, + "POST", + "/api/analytics/query/spend_data", + json_body={"format": "JSON"}, + timeout=20.0, + max_events=5, + ) + except (httpx.HTTPError, httpx.StreamError): + pytest.skip("Analytics endpoint not available") + + real = events_only(events) + if not real: + pytest.skip("No analytics events received") + + for event in real: + if event.is_error: + # Error events are allowed — Databricks may not be configured + data = event.parsed_data + assert "error" in data + continue + data = event.parsed_data + assert data is not None, "Event data should be valid JSON" + assert "type" in data, f"Result event missing 'type': {data}" + + async def test_query_default_format_is_arrow_stream( + self, http_client: httpx.AsyncClient + ): + """When no format is specified, default should be ARROW_STREAM.""" + try: + events = await collect_sse_stream( + http_client, + "POST", + "/api/analytics/query/spend_data", + json_body={}, # No format specified + timeout=20.0, + max_events=5, + ) + except (httpx.HTTPError, httpx.StreamError): + pytest.skip("Analytics endpoint not available") + + real = events_only(events) + if not real: + pytest.skip("No analytics events received") + + # First non-error event should exist + for event in real: + if not event.is_error: + data = event.parsed_data + assert data is not None + break + + +class TestAnalyticsArrowEndpoint: + """Tests for GET /api/analytics/arrow-result/:jobId.""" + + async def test_arrow_result_not_found_returns_404(self, http_client: httpx.AsyncClient): + """Requesting a nonexistent job ID should return 404.""" + response = await http_client.get( + "/api/analytics/arrow-result/nonexistent-job-id-12345" + ) + assert response.status_code == 404 + body = response.json() + assert "error" in body + + async def test_arrow_result_has_correct_headers(self, http_client: httpx.AsyncClient): + """If an arrow result exists, it should have correct binary headers. + + Since we can't easily create a real job, this test just validates + the error response format for missing jobs. + """ + response = await http_client.get("/api/analytics/arrow-result/fake-job") + # Should be 404 with JSON error + assert response.status_code == 404 + assert "application/json" in response.headers.get("content-type", "") diff --git a/packages/appkit-py/tests/integration/test_auth_context.py b/packages/appkit-py/tests/integration/test_auth_context.py new file mode 100644 index 00000000..0dcee852 --- /dev/null +++ b/packages/appkit-py/tests/integration/test_auth_context.py @@ -0,0 +1,86 @@ +"""Integration tests for authentication and user context propagation. + +The AppKit backend uses two auth modes: +1. Service principal — configured via DATABRICKS_HOST/DATABRICKS_TOKEN env vars +2. User context (OBO) — forwarded via x-forwarded-user and x-forwarded-access-token headers + +The Databricks Apps proxy sets these headers automatically in production. +""" + +from __future__ import annotations + +import httpx +import pytest + +pytestmark = pytest.mark.integration + + +class TestAuthHeaders: + """Tests for auth header handling.""" + + async def test_health_works_without_auth(self, unauthed_client: httpx.AsyncClient): + """Health endpoint should not require auth.""" + response = await unauthed_client.get("/health") + assert response.status_code == 200 + + async def test_volumes_endpoint_works_without_auth( + self, unauthed_client: httpx.AsyncClient + ): + """The volumes list endpoint doesn't require user context.""" + response = await unauthed_client.get("/api/files/volumes") + # Should work — volumes list doesn't require OBO + assert response.status_code == 200 + + async def test_file_operations_require_user_context( + self, unauthed_client: httpx.AsyncClient + ): + """File operations (except volumes list) should require auth headers in OBO mode.""" + # First get a volume key + vol_resp = await unauthed_client.get("/api/files/volumes") + if vol_resp.status_code != 200: + pytest.skip("Files plugin not available") + volumes = vol_resp.json().get("volumes", []) + if not volumes: + pytest.skip("No volumes configured") + + volume = volumes[0] + response = await unauthed_client.get( + f"/api/files/{volume}/list" + ) + # Should either fail with auth error or succeed if service principal mode + # The key assertion: it should NOT crash — it should return a structured error + assert response.status_code in (200, 401, 403, 500) + if response.status_code >= 400: + body = response.json() + assert "error" in body + + async def test_authenticated_request_accepted(self, http_client: httpx.AsyncClient): + """Requests with proper auth headers should be accepted.""" + response = await http_client.get("/health") + assert response.status_code == 200 + + async def test_auth_headers_forwarded_format(self, http_client: httpx.AsyncClient): + """Auth headers should follow the x-forwarded-* format.""" + # The http_client fixture already includes these headers. + # This test validates that the server accepts them without error. + response = await http_client.get("/api/files/volumes") + assert response.status_code == 200 + + +class TestErrorResponseFormat: + """Tests for consistent error response formatting.""" + + async def test_404_returns_json_error(self, http_client: httpx.AsyncClient): + """404 errors should return JSON with an 'error' field.""" + response = await http_client.get("/api/files/nonexistent_volume/list") + assert response.status_code == 404 + body = response.json() + assert "error" in body + + async def test_error_includes_plugin_name(self, http_client: httpx.AsyncClient): + """Error responses from plugins should include the plugin name.""" + response = await http_client.get("/api/files/nonexistent_volume/list") + assert response.status_code == 404 + body = response.json() + assert "plugin" in body + assert body["plugin"] == "files" diff --git a/packages/appkit-py/tests/integration/test_files.py b/packages/appkit-py/tests/integration/test_files.py new file mode 100644 index 00000000..8a71da9b --- /dev/null +++ b/packages/appkit-py/tests/integration/test_files.py @@ -0,0 +1,198 @@ +"""Integration tests for the Files plugin API. + +Endpoints: + GET /api/files/volumes → { volumes: [...] } + GET /api/files/:volumeKey/list?path= → DirectoryEntry[] + GET /api/files/:volumeKey/read?path= → text/plain + GET /api/files/:volumeKey/download?path= → binary + Content-Disposition + GET /api/files/:volumeKey/raw?path= → binary + CSP sandbox + GET /api/files/:volumeKey/exists?path= → { exists: bool } + GET /api/files/:volumeKey/metadata?path= → FileMetadata + GET /api/files/:volumeKey/preview?path= → FilePreview + POST /api/files/:volumeKey/upload?path= → { success: true } + POST /api/files/:volumeKey/mkdir → { success: true } + DELETE /api/files/:volumeKey?path= → { success: true } +""" + +from __future__ import annotations + +import httpx +import pytest + +pytestmark = pytest.mark.integration + + +class TestFilesVolumes: + """Tests for GET /api/files/volumes.""" + + async def test_volumes_returns_200(self, http_client: httpx.AsyncClient): + response = await http_client.get("/api/files/volumes") + assert response.status_code == 200 + + async def test_volumes_returns_volume_list(self, http_client: httpx.AsyncClient): + response = await http_client.get("/api/files/volumes") + body = response.json() + assert "volumes" in body + assert isinstance(body["volumes"], list) + + async def test_volumes_returns_json(self, http_client: httpx.AsyncClient): + response = await http_client.get("/api/files/volumes") + assert "application/json" in response.headers.get("content-type", "") + + +class TestFilesUnknownVolume: + """Tests for unknown volume key.""" + + async def test_unknown_volume_returns_404(self, http_client: httpx.AsyncClient): + response = await http_client.get("/api/files/nonexistent_volume_xyz/list") + assert response.status_code == 404 + + async def test_unknown_volume_error_format(self, http_client: httpx.AsyncClient): + response = await http_client.get("/api/files/nonexistent_volume_xyz/list") + body = response.json() + assert "error" in body + assert "plugin" in body + assert body["plugin"] == "files" + + +class TestFilesPathValidation: + """Tests for path validation across all file endpoints.""" + + @pytest.fixture + def volume_key(self, http_client: httpx.AsyncClient) -> str: + """Get the first available volume key, or skip if none.""" + return "test" # Will 404 if not configured, which is fine for validation tests + + async def test_missing_path_returns_400(self, http_client: httpx.AsyncClient): + """Endpoints requiring path should return 400 when path is missing.""" + # read endpoint requires path + response = await http_client.get("/api/files/test/read") + # Either 400 (path validation) or 404 (unknown volume) is acceptable + assert response.status_code in (400, 404) + + async def test_null_bytes_in_path_rejected(self, http_client: httpx.AsyncClient): + """Paths containing null bytes must be rejected.""" + response = await http_client.get("/api/files/test/read", params={"path": "file\x00.txt"}) + # Either 400 (null byte rejection) or 404 (unknown volume) + assert response.status_code in (400, 404) + + async def test_long_path_rejected(self, http_client: httpx.AsyncClient): + """Paths exceeding 4096 characters must be rejected.""" + long_path = "a" * 4097 + response = await http_client.get("/api/files/test/read", params={"path": long_path}) + assert response.status_code in (400, 404) + + +class TestFilesListEndpoint: + """Tests for GET /api/files/:volumeKey/list.""" + + async def _get_first_volume(self, client: httpx.AsyncClient) -> str | None: + resp = await client.get("/api/files/volumes") + if resp.status_code != 200: + return None + volumes = resp.json().get("volumes", []) + return volumes[0] if volumes else None + + async def test_list_returns_array(self, http_client: httpx.AsyncClient): + volume = await self._get_first_volume(http_client) + if not volume: + pytest.skip("No volumes configured") + + response = await http_client.get(f"/api/files/{volume}/list") + assert response.status_code == 200 + body = response.json() + assert isinstance(body, list) + + async def test_list_with_path_param(self, http_client: httpx.AsyncClient): + volume = await self._get_first_volume(http_client) + if not volume: + pytest.skip("No volumes configured") + + response = await http_client.get(f"/api/files/{volume}/list", params={"path": "/"}) + # Should succeed or return API error (not crash) + assert response.status_code in (200, 401, 403, 404, 500) + + +class TestFilesExistsEndpoint: + """Tests for GET /api/files/:volumeKey/exists.""" + + async def _get_first_volume(self, client: httpx.AsyncClient) -> str | None: + resp = await client.get("/api/files/volumes") + if resp.status_code != 200: + return None + volumes = resp.json().get("volumes", []) + return volumes[0] if volumes else None + + async def test_exists_returns_boolean(self, http_client: httpx.AsyncClient): + volume = await self._get_first_volume(http_client) + if not volume: + pytest.skip("No volumes configured") + + response = await http_client.get( + f"/api/files/{volume}/exists", params={"path": "/nonexistent-file.txt"} + ) + if response.status_code == 200: + body = response.json() + assert "exists" in body + assert isinstance(body["exists"], bool) + else: + # API error (auth, etc.) — still valid + assert response.status_code in (401, 403, 500) + + +class TestFilesDownloadEndpoint: + """Tests for GET /api/files/:volumeKey/download.""" + + async def test_download_missing_path_returns_400(self, http_client: httpx.AsyncClient): + response = await http_client.get("/api/files/test/download") + assert response.status_code in (400, 404) + + +class TestFilesUploadEndpoint: + """Tests for POST /api/files/:volumeKey/upload.""" + + async def test_upload_missing_path_returns_400(self, http_client: httpx.AsyncClient): + response = await http_client.post( + "/api/files/test/upload", + content=b"file content", + headers={"content-type": "application/octet-stream"}, + ) + assert response.status_code in (400, 404) + + async def test_upload_oversized_returns_413(self, http_client: httpx.AsyncClient): + """Uploads exceeding max size should be rejected with 413.""" + # We can't fake Content-Length with httpx (protocol-level mismatch), + # so test by sending a large body to a known volume. + # First get a volume + vol_resp = await http_client.get("/api/files/volumes") + volumes = vol_resp.json().get("volumes", []) + if not volumes: + pytest.skip("No volumes configured — cannot test 413") + + volume = volumes[0] + # The actual check is server-side on Content-Length header. + # We verify the endpoint exists and handles the path correctly. + response = await http_client.post( + f"/api/files/{volume}/upload", + params={"path": "/test.txt"}, + content=b"small content", + headers={"content-type": "application/octet-stream"}, + ) + # Should not crash — returns success or server error (no Databricks) + assert response.status_code in (200, 401, 403, 413, 500) + + +class TestFilesMkdirEndpoint: + """Tests for POST /api/files/:volumeKey/mkdir.""" + + async def test_mkdir_missing_path_returns_400(self, http_client: httpx.AsyncClient): + response = await http_client.post("/api/files/test/mkdir", json={}) + assert response.status_code in (400, 404) + + +class TestFilesDeleteEndpoint: + """Tests for DELETE /api/files/:volumeKey.""" + + async def test_delete_missing_path_returns_400(self, http_client: httpx.AsyncClient): + response = await http_client.delete("/api/files/test") + assert response.status_code in (400, 404) diff --git a/packages/appkit-py/tests/integration/test_genie.py b/packages/appkit-py/tests/integration/test_genie.py new file mode 100644 index 00000000..c3cb7ca3 --- /dev/null +++ b/packages/appkit-py/tests/integration/test_genie.py @@ -0,0 +1,158 @@ +"""Integration tests for the Genie plugin API. + +Endpoints: + POST /api/genie/:alias/messages → SSE stream + GET /api/genie/:alias/conversations/:conversationId → SSE stream + GET /api/genie/:alias/conversations/:conversationId/messages/:mid → SSE stream +""" + +from __future__ import annotations + +import httpx +import pytest + +from tests.helpers.sse_parser import collect_sse_stream, events_only + +pytestmark = pytest.mark.integration + + +class TestGenieSendMessage: + """Tests for POST /api/genie/:alias/messages.""" + + async def test_unknown_alias_returns_404(self, http_client: httpx.AsyncClient): + """Sending a message to an unknown space alias should return 404.""" + response = await http_client.post( + "/api/genie/nonexistent_alias_xyz/messages", + json={"content": "Hello"}, + ) + assert response.status_code == 404 + body = response.json() + assert "error" in body + + async def test_missing_content_returns_400(self, http_client: httpx.AsyncClient): + """Sending a message without content should return 400.""" + response = await http_client.post( + "/api/genie/demo/messages", + json={}, # No content field + ) + # 400 (missing content) or 404 (unknown alias) are both valid + assert response.status_code in (400, 404) + + async def test_send_message_returns_sse(self, http_client: httpx.AsyncClient): + """If demo space is configured, sending a message should return SSE.""" + try: + async with http_client.stream( + "POST", + "/api/genie/demo/messages", + json={"content": "What are the top products?"}, + timeout=30.0, + ) as resp: + if resp.status_code == 404: + pytest.skip("Genie 'demo' space not configured") + content_type = resp.headers.get("content-type", "") + assert "text/event-stream" in content_type + except (httpx.HTTPError, httpx.StreamError): + pytest.skip("Genie endpoint not available") + + async def test_send_message_events_include_message_start( + self, http_client: httpx.AsyncClient + ): + """Genie stream should start with a message_start event.""" + try: + events = await collect_sse_stream( + http_client, + "POST", + "/api/genie/demo/messages", + json_body={"content": "Hello"}, + timeout=30.0, + max_events=10, + ) + except (httpx.HTTPError, httpx.StreamError): + pytest.skip("Genie endpoint not available") + + real = events_only(events) + if not real: + pytest.skip("No genie events received") + + # First non-error event should be message_start + first_event = real[0] + if first_event.is_error: + pytest.skip("Got error instead of message_start — Genie may not be configured") + + data = first_event.parsed_data + assert data is not None + assert data.get("type") == "message_start" + assert "conversationId" in data + assert "messageId" in data + assert "spaceId" in data + + async def test_send_message_with_request_id(self, http_client: httpx.AsyncClient): + """Messages with a custom requestId query param should work.""" + response = await http_client.post( + "/api/genie/demo/messages", + params={"requestId": "custom-request-id-123"}, + json={"content": "Hello"}, + ) + # Either SSE stream or 404 (alias not found) + assert response.status_code in (200, 404) + + +class TestGenieGetConversation: + """Tests for GET /api/genie/:alias/conversations/:conversationId.""" + + async def test_unknown_alias_returns_404(self, http_client: httpx.AsyncClient): + response = await http_client.get( + "/api/genie/nonexistent_alias/conversations/conv-123" + ) + assert response.status_code == 404 + + async def test_get_conversation_returns_sse_or_error( + self, http_client: httpx.AsyncClient + ): + """Getting a conversation should return SSE or a structured error.""" + try: + async with http_client.stream( + "GET", + "/api/genie/demo/conversations/fake-conv-id", + timeout=15.0, + ) as resp: + if resp.status_code == 404: + pytest.skip("Genie 'demo' space not configured") + content_type = resp.headers.get("content-type", "") + # Should be SSE or JSON error + assert ( + "text/event-stream" in content_type + or "application/json" in content_type + ) + except (httpx.HTTPError, httpx.StreamError): + pytest.skip("Genie endpoint not available") + + +class TestGenieGetMessage: + """Tests for GET /api/genie/:alias/conversations/:convId/messages/:msgId.""" + + async def test_unknown_alias_returns_404(self, http_client: httpx.AsyncClient): + response = await http_client.get( + "/api/genie/nonexistent_alias/conversations/conv-1/messages/msg-1" + ) + assert response.status_code == 404 + + async def test_get_message_returns_sse_or_error( + self, http_client: httpx.AsyncClient + ): + """Getting a message should return SSE or a structured error.""" + try: + async with http_client.stream( + "GET", + "/api/genie/demo/conversations/fake-conv/messages/fake-msg", + timeout=15.0, + ) as resp: + if resp.status_code == 404: + pytest.skip("Genie 'demo' space not configured") + content_type = resp.headers.get("content-type", "") + assert ( + "text/event-stream" in content_type + or "application/json" in content_type + ) + except (httpx.HTTPError, httpx.StreamError): + pytest.skip("Genie endpoint not available") diff --git a/packages/appkit-py/tests/integration/test_health.py b/packages/appkit-py/tests/integration/test_health.py new file mode 100644 index 00000000..b6e8478d --- /dev/null +++ b/packages/appkit-py/tests/integration/test_health.py @@ -0,0 +1,34 @@ +"""Integration tests for the /health endpoint. + +These tests validate the health check contract that must be identical +between TypeScript and Python backends. +""" + +from __future__ import annotations + +import httpx +import pytest + +pytestmark = pytest.mark.integration + + +class TestHealthEndpoint: + async def test_health_returns_200(self, http_client: httpx.AsyncClient): + response = await http_client.get("/health") + assert response.status_code == 200 + + async def test_health_returns_status_ok(self, http_client: httpx.AsyncClient): + response = await http_client.get("/health") + body = response.json() + assert body == {"status": "ok"} + + async def test_health_content_type_is_json(self, http_client: httpx.AsyncClient): + response = await http_client.get("/health") + content_type = response.headers.get("content-type", "") + assert "application/json" in content_type + + async def test_health_works_without_auth(self, unauthed_client: httpx.AsyncClient): + """Health endpoint should work without auth headers.""" + response = await unauthed_client.get("/health") + assert response.status_code == 200 + assert response.json() == {"status": "ok"} diff --git a/packages/appkit-py/tests/integration/test_sse_protocol.py b/packages/appkit-py/tests/integration/test_sse_protocol.py new file mode 100644 index 00000000..37111f08 --- /dev/null +++ b/packages/appkit-py/tests/integration/test_sse_protocol.py @@ -0,0 +1,230 @@ +"""Integration tests for the SSE (Server-Sent Events) protocol. + +These tests validate the SSE wire format is correct and compatible with +the AppKit frontend's SSE client (connectSSE). They can run against any +SSE-producing endpoint — we use the reconnect plugin if available, or +analytics/genie endpoints. + +The exact SSE format required by the frontend: + id: {uuid} + event: {event_type} + data: {json_string} + (empty line) + +Plus heartbeat comments: `: heartbeat\\n\\n` +""" + +from __future__ import annotations + +import json + +import httpx +import pytest + +from tests.helpers.sse_parser import ( + SSEEvent, + collect_sse_stream, + events_only, + parse_sse_text, +) + +pytestmark = pytest.mark.integration + + +class TestSSEParser: + """Verify our SSE parser correctly handles the wire format.""" + + def test_parse_basic_event(self): + text = "id: abc-123\nevent: result\ndata: {\"type\":\"result\"}\n\n" + events = parse_sse_text(text) + real = events_only(events) + assert len(real) == 1 + assert real[0].id == "abc-123" + assert real[0].event == "result" + assert real[0].data == '{"type":"result"}' + + def test_parse_heartbeat(self): + text = ": heartbeat\n\n" + events = parse_sse_text(text) + assert len(events) == 1 + assert events[0].is_heartbeat is True + + def test_parse_multiple_events(self): + text = ( + "id: 1\nevent: a\ndata: {}\n\n" + ": heartbeat\n\n" + "id: 2\nevent: b\ndata: {}\n\n" + ) + events = parse_sse_text(text) + assert len(events) == 3 + real = events_only(events) + assert len(real) == 2 + + def test_parse_error_event(self): + text = 'id: err-1\nevent: error\ndata: {"error":"fail","code":"INTERNAL_ERROR"}\n\n' + events = events_only(parse_sse_text(text)) + assert len(events) == 1 + assert events[0].is_error is True + data = events[0].parsed_data + assert data["error"] == "fail" + assert data["code"] == "INTERNAL_ERROR" + + def test_uuid_validation(self): + event = SSEEvent(id="550e8400-e29b-41d4-a716-446655440000") + assert event.has_valid_uuid_id is True + + event = SSEEvent(id="not-a-uuid") + assert event.has_valid_uuid_id is False + + event = SSEEvent(id=None) + assert event.has_valid_uuid_id is False + + +class TestSSEProtocolCompliance: + """Tests that validate SSE protocol compliance against a running server. + + These require the reconnect plugin or any streaming endpoint to be available. + If no streaming endpoint is available, tests are skipped. + """ + + @pytest.fixture + async def sse_events(self, http_client: httpx.AsyncClient) -> list[SSEEvent] | None: + """Try to get SSE events from a known streaming endpoint. + + Tries the reconnect plugin first, then analytics with a dummy query. + Returns None if no streaming endpoint is available. + """ + # Try reconnect plugin (dev-playground specific) + try: + events = await collect_sse_stream( + http_client, "GET", "/api/reconnect/stream", timeout=15.0, max_events=3 + ) + if events: + return events + except (httpx.HTTPError, httpx.StreamError): + pass + + return None + + async def _find_sse_endpoint(self, client: httpx.AsyncClient) -> tuple[str, str, dict | None]: + """Find a working SSE endpoint. Returns (method, url, json_body).""" + # Try reconnect plugin first (TS dev-playground only) + try: + async with client.stream("GET", "/api/reconnect/stream", timeout=3.0) as resp: + if "text/event-stream" in resp.headers.get("content-type", ""): + return ("GET", "/api/reconnect/stream", None) + except (httpx.HTTPError, httpx.StreamError): + pass + + # Try genie with a known alias (requires genie space configured) + try: + async with client.stream( + "POST", "/api/genie/demo/messages", + json={"content": "test"}, timeout=3.0 + ) as resp: + if "text/event-stream" in resp.headers.get("content-type", ""): + return ("POST", "/api/genie/demo/messages", {"content": "test"}) + except (httpx.HTTPError, httpx.StreamError): + pass + + # Try analytics with any query + try: + async with client.stream( + "POST", "/api/analytics/query/test", + json={"format": "JSON"}, timeout=3.0 + ) as resp: + if "text/event-stream" in resp.headers.get("content-type", ""): + return ("POST", "/api/analytics/query/test", {"format": "JSON"}) + except (httpx.HTTPError, httpx.StreamError): + pass + + raise RuntimeError("No SSE endpoint available") + + async def test_sse_content_type(self, http_client: httpx.AsyncClient): + """SSE endpoints must return Content-Type: text/event-stream.""" + try: + method, url, body = await self._find_sse_endpoint(http_client) + kwargs: dict = {"timeout": 5.0} + if body: + kwargs["json"] = body + async with http_client.stream(method, url, **kwargs) as resp: + content_type = resp.headers.get("content-type", "") + assert "text/event-stream" in content_type + except RuntimeError: + pytest.skip("No streaming endpoint available") + + async def test_sse_cache_control(self, http_client: httpx.AsyncClient): + """SSE endpoints must set Cache-Control: no-cache.""" + try: + method, url, body = await self._find_sse_endpoint(http_client) + kwargs: dict = {"timeout": 5.0} + if body: + kwargs["json"] = body + async with http_client.stream(method, url, **kwargs) as resp: + cache_control = resp.headers.get("cache-control", "") + assert "no-cache" in cache_control + except RuntimeError: + pytest.skip("No streaming endpoint available") + + async def test_sse_event_has_id_event_data(self, sse_events: list[SSEEvent] | None): + """Each SSE event must have id, event, and data fields.""" + if sse_events is None: + pytest.skip("No streaming endpoint available") + + real = events_only(sse_events) + if not real: + pytest.skip("No real events received") + + for event in real: + assert event.id is not None, f"Event missing id: {event.raw_lines}" + assert event.event is not None, f"Event missing event type: {event.raw_lines}" + assert event.data is not None, f"Event missing data: {event.raw_lines}" + + async def test_sse_event_ids_are_uuids(self, sse_events: list[SSEEvent] | None): + """Event IDs should be UUID v4 format.""" + if sse_events is None: + pytest.skip("No streaming endpoint available") + + real = events_only(sse_events) + if not real: + pytest.skip("No real events received") + + for event in real: + assert event.has_valid_uuid_id, f"Event ID is not UUID: {event.id}" + + async def test_sse_data_is_valid_json(self, sse_events: list[SSEEvent] | None): + """Event data fields must be valid JSON.""" + if sse_events is None: + pytest.skip("No streaming endpoint available") + + real = events_only(sse_events) + if not real: + pytest.skip("No real events received") + + for event in real: + assert event.data is not None + try: + json.loads(event.data) + except json.JSONDecodeError: + pytest.fail(f"Event data is not valid JSON: {event.data[:100]}") + + async def test_sse_error_event_format(self): + """Error events must have the format: {error: string, code: SSEErrorCode}.""" + error_text = ( + 'id: e1\nevent: error\n' + 'data: {"error":"Something failed","code":"INTERNAL_ERROR"}\n\n' + ) + events = events_only(parse_sse_text(error_text)) + assert len(events) == 1 + data = events[0].parsed_data + assert "error" in data + assert "code" in data + valid_codes = { + "TEMPORARY_UNAVAILABLE", + "TIMEOUT", + "INTERNAL_ERROR", + "INVALID_REQUEST", + "STREAM_ABORTED", + "STREAM_EVICTED", + } + assert data["code"] in valid_codes diff --git a/packages/appkit-py/tests/unit/__init__.py b/packages/appkit-py/tests/unit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/appkit-py/tests/unit/test_cache_manager.py b/packages/appkit-py/tests/unit/test_cache_manager.py new file mode 100644 index 00000000..a170f5f7 --- /dev/null +++ b/packages/appkit-py/tests/unit/test_cache_manager.py @@ -0,0 +1,106 @@ +"""Unit tests for CacheManager.""" + +from __future__ import annotations + +import pytest + +pytestmark = pytest.mark.unit + + +class TestCacheManager: + def test_import(self): + from appkit_py.cache.cache_manager import CacheManager + + mgr = CacheManager() + assert mgr is not None + + async def test_get_or_execute_miss(self): + from appkit_py.cache.cache_manager import CacheManager + + mgr = CacheManager() + call_count = 0 + + async def compute(): + nonlocal call_count + call_count += 1 + return {"result": 42} + + result = await mgr.get_or_execute( + key_parts=["test", "query1"], + fn=compute, + user_key="user-1", + ttl=60, + ) + assert result == {"result": 42} + assert call_count == 1 + + async def test_get_or_execute_hit(self): + from appkit_py.cache.cache_manager import CacheManager + + mgr = CacheManager() + call_count = 0 + + async def compute(): + nonlocal call_count + call_count += 1 + return {"result": 42} + + # First call — miss + await mgr.get_or_execute(["test", "q"], compute, "user-1", ttl=60) + # Second call — should be cached + result = await mgr.get_or_execute(["test", "q"], compute, "user-1", ttl=60) + assert result == {"result": 42} + assert call_count == 1 # Only called once + + async def test_different_users_separate_cache(self): + from appkit_py.cache.cache_manager import CacheManager + + mgr = CacheManager() + calls: list[str] = [] + + async def compute_for(user: str): + calls.append(user) + return f"result-{user}" + + r1 = await mgr.get_or_execute(["q"], lambda: compute_for("a"), "user-a", ttl=60) + r2 = await mgr.get_or_execute(["q"], lambda: compute_for("b"), "user-b", ttl=60) + assert r1 == "result-a" + assert r2 == "result-b" + assert len(calls) == 2 # Both users computed separately + + async def test_generate_key_deterministic(self): + from appkit_py.cache.cache_manager import CacheManager + + mgr = CacheManager() + k1 = mgr.generate_key(["a", "b", 1], "user") + k2 = mgr.generate_key(["a", "b", 1], "user") + assert k1 == k2 + + async def test_generate_key_different_for_different_inputs(self): + from appkit_py.cache.cache_manager import CacheManager + + mgr = CacheManager() + k1 = mgr.generate_key(["a"], "user-1") + k2 = mgr.generate_key(["b"], "user-1") + k3 = mgr.generate_key(["a"], "user-2") + assert k1 != k2 + assert k1 != k3 + + async def test_delete(self): + from appkit_py.cache.cache_manager import CacheManager + + mgr = CacheManager() + call_count = 0 + + async def compute(): + nonlocal call_count + call_count += 1 + return "value" + + await mgr.get_or_execute(["k"], compute, "u", ttl=60) + key = mgr.generate_key(["k"], "u") + mgr.delete(key) + + # Should recompute after deletion + await mgr.get_or_execute(["k"], compute, "u", ttl=60) + assert call_count == 2 diff --git a/packages/appkit-py/tests/unit/test_context.py b/packages/appkit-py/tests/unit/test_context.py new file mode 100644 index 00000000..db7dc8d8 --- /dev/null +++ b/packages/appkit-py/tests/unit/test_context.py @@ -0,0 +1,79 @@ +"""Unit tests for execution context (contextvars-based user context propagation).""" + +from __future__ import annotations + +import pytest + +pytestmark = pytest.mark.unit + + +class TestExecutionContext: + def test_import(self): + from appkit_py.context.execution_context import ( + get_execution_context, + is_in_user_context, + run_in_user_context, + ) + + async def test_default_is_not_user_context(self): + from appkit_py.context.execution_context import is_in_user_context + + assert is_in_user_context() is False + + async def test_run_in_user_context(self): + from appkit_py.context.execution_context import ( + get_current_user_id, + is_in_user_context, + run_in_user_context, + ) + from appkit_py.context.user_context import UserContext + + ctx = UserContext( + user_id="test-user-123", + token="fake-token", + ) + + async def inner(): + assert is_in_user_context() is True + assert get_current_user_id() == "test-user-123" + return "done" + + result = await run_in_user_context(ctx, inner) + assert result == "done" + + async def test_context_does_not_leak(self): + from appkit_py.context.execution_context import ( + is_in_user_context, + run_in_user_context, + ) + from appkit_py.context.user_context import UserContext + + ctx = UserContext(user_id="u1", token="t1") + + async def inner(): + assert is_in_user_context() is True + + await run_in_user_context(ctx, inner) + # After exiting, should no longer be in user context + assert is_in_user_context() is False + + async def test_nested_user_contexts(self): + from appkit_py.context.execution_context import ( + get_current_user_id, + run_in_user_context, + ) + from appkit_py.context.user_context import UserContext + + ctx_outer = UserContext(user_id="outer", token="t1") + ctx_inner = UserContext(user_id="inner", token="t2") + + async def inner_fn(): + assert get_current_user_id() == "inner" + + async def outer_fn(): + assert get_current_user_id() == "outer" + await run_in_user_context(ctx_inner, inner_fn) + # After inner returns, should restore outer context + assert get_current_user_id() == "outer" + + await run_in_user_context(ctx_outer, outer_fn) diff --git a/packages/appkit-py/tests/unit/test_interceptors.py b/packages/appkit-py/tests/unit/test_interceptors.py new file mode 100644 index 00000000..c8689956 --- /dev/null +++ b/packages/appkit-py/tests/unit/test_interceptors.py @@ -0,0 +1,160 @@ +"""Unit tests for the execution interceptor chain. + +Interceptor order (outermost to innermost): + Telemetry → Timeout → Retry → Cache +""" + +from __future__ import annotations + +import asyncio + +import pytest + +pytestmark = pytest.mark.unit + + +class TestRetryInterceptor: + """Tests for RetryInterceptor with exponential backoff.""" + + async def test_success_on_first_attempt(self): + from appkit_py.plugin.interceptors.retry import RetryInterceptor + + interceptor = RetryInterceptor(attempts=3, initial_delay=0.01, max_delay=0.1) + call_count = 0 + + async def fn(): + nonlocal call_count + call_count += 1 + return "ok" + + result = await interceptor.intercept(fn) + assert result == "ok" + assert call_count == 1 + + async def test_retry_on_failure(self): + from appkit_py.plugin.interceptors.retry import RetryInterceptor + + interceptor = RetryInterceptor(attempts=3, initial_delay=0.01, max_delay=0.1) + call_count = 0 + + async def fn(): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise RuntimeError("temporary failure") + return "ok" + + result = await interceptor.intercept(fn) + assert result == "ok" + assert call_count == 3 + + async def test_exhausted_retries_raises(self): + from appkit_py.plugin.interceptors.retry import RetryInterceptor + + interceptor = RetryInterceptor(attempts=2, initial_delay=0.01, max_delay=0.1) + + async def fn(): + raise RuntimeError("permanent failure") + + with pytest.raises(RuntimeError, match="permanent failure"): + await interceptor.intercept(fn) + + async def test_no_retry_when_attempts_is_one(self): + from appkit_py.plugin.interceptors.retry import RetryInterceptor + + interceptor = RetryInterceptor(attempts=1, initial_delay=0.01, max_delay=0.1) + call_count = 0 + + async def fn(): + nonlocal call_count + call_count += 1 + raise RuntimeError("fail") + + with pytest.raises(RuntimeError): + await interceptor.intercept(fn) + assert call_count == 1 + + +class TestTimeoutInterceptor: + """Tests for TimeoutInterceptor.""" + + async def test_completes_within_timeout(self): + from appkit_py.plugin.interceptors.timeout import TimeoutInterceptor + + interceptor = TimeoutInterceptor(timeout_seconds=5.0) + + async def fn(): + return "fast" + + result = await interceptor.intercept(fn) + assert result == "fast" + + async def test_timeout_raises(self): + from appkit_py.plugin.interceptors.timeout import TimeoutInterceptor + + interceptor = TimeoutInterceptor(timeout_seconds=0.05) + + async def fn(): + await asyncio.sleep(10) + return "slow" + + with pytest.raises((asyncio.TimeoutError, TimeoutError)): + await interceptor.intercept(fn) + + +class TestCacheInterceptor: + """Tests for CacheInterceptor.""" + + async def test_cache_miss_executes_function(self): + from appkit_py.plugin.interceptors.cache import CacheInterceptor + + cache_store: dict[str, object] = {} + interceptor = CacheInterceptor( + cache_store=cache_store, cache_key="test-key", ttl=60 + ) + call_count = 0 + + async def fn(): + nonlocal call_count + call_count += 1 + return {"data": "result"} + + result = await interceptor.intercept(fn) + assert result == {"data": "result"} + assert call_count == 1 + + async def test_cache_hit_skips_function(self): + import time + from appkit_py.plugin.interceptors.cache import CacheInterceptor + + cache_store: dict[str, object] = {"test-key": ({"data": "cached"}, time.time() + 60)} + interceptor = CacheInterceptor( + cache_store=cache_store, cache_key="test-key", ttl=60 + ) + call_count = 0 + + async def fn(): + nonlocal call_count + call_count += 1 + return {"data": "fresh"} + + result = await interceptor.intercept(fn) + assert result == {"data": "cached"} + assert call_count == 0 + + async def test_cache_disabled_always_executes(self): + from appkit_py.plugin.interceptors.cache import CacheInterceptor + + interceptor = CacheInterceptor( + cache_store={}, cache_key=None, ttl=60, enabled=False + ) + call_count = 0 + + async def fn(): + nonlocal call_count + call_count += 1 + return "result" + + await interceptor.intercept(fn) + await interceptor.intercept(fn) + assert call_count == 2 diff --git a/packages/appkit-py/tests/unit/test_plugin.py b/packages/appkit-py/tests/unit/test_plugin.py new file mode 100644 index 00000000..fb90dca3 --- /dev/null +++ b/packages/appkit-py/tests/unit/test_plugin.py @@ -0,0 +1,77 @@ +"""Unit tests for the Plugin base class.""" + +from __future__ import annotations + +import pytest + +pytestmark = pytest.mark.unit + + +class TestPluginBase: + def test_import(self): + from appkit_py.plugin.plugin import Plugin + + async def test_default_setup_is_noop(self): + from appkit_py.plugin.plugin import Plugin + + class TestPlugin(Plugin): + name = "test" + + plugin = TestPlugin(config={}) + await plugin.setup() # Should not raise + + def test_default_exports_empty(self): + from appkit_py.plugin.plugin import Plugin + + class TestPlugin(Plugin): + name = "test" + + plugin = TestPlugin(config={}) + assert plugin.exports() == {} + + def test_default_client_config_empty(self): + from appkit_py.plugin.plugin import Plugin + + class TestPlugin(Plugin): + name = "test" + + plugin = TestPlugin(config={}) + assert plugin.client_config() == {} + + def test_default_inject_routes_is_noop(self): + from appkit_py.plugin.plugin import Plugin + + class TestPlugin(Plugin): + name = "test" + + plugin = TestPlugin(config={}) + # Should not raise with a mock router + plugin.inject_routes(None) + + +class TestPluginAsUser: + """Tests for the as_user() proxy pattern.""" + + async def test_as_user_returns_proxy(self): + from appkit_py.plugin.plugin import Plugin + + class TestPlugin(Plugin): + name = "test" + + async def get_data(self): + return "data" + + plugin = TestPlugin(config={}) + # Create a mock request with auth headers + mock_request = type( + "MockRequest", + (), + { + "headers": { + "x-forwarded-user": "test-user", + "x-forwarded-access-token": "test-token", + } + }, + )() + proxy = plugin.as_user(mock_request) + assert proxy is not plugin # Should be a different object diff --git a/packages/appkit-py/tests/unit/test_query_processor.py b/packages/appkit-py/tests/unit/test_query_processor.py new file mode 100644 index 00000000..79a69fbf --- /dev/null +++ b/packages/appkit-py/tests/unit/test_query_processor.py @@ -0,0 +1,52 @@ +"""Unit tests for QueryProcessor (SQL parameter processing).""" + +from __future__ import annotations + +import pytest + +pytestmark = pytest.mark.unit + + +class TestQueryProcessor: + def test_import(self): + from appkit_py.plugins.analytics.query import QueryProcessor + + qp = QueryProcessor() + assert qp is not None + + def test_hash_query_deterministic(self): + from appkit_py.plugins.analytics.query import QueryProcessor + + qp = QueryProcessor() + h1 = qp.hash_query("SELECT * FROM table") + h2 = qp.hash_query("SELECT * FROM table") + assert h1 == h2 + + def test_hash_query_different_for_different_queries(self): + from appkit_py.plugins.analytics.query import QueryProcessor + + qp = QueryProcessor() + h1 = qp.hash_query("SELECT * FROM table1") + h2 = qp.hash_query("SELECT * FROM table2") + assert h1 != h2 + + def test_convert_to_sql_parameters_no_params(self): + from appkit_py.plugins.analytics.query import QueryProcessor + + qp = QueryProcessor() + result = qp.convert_to_sql_parameters("SELECT 1", None) + assert result["statement"] == "SELECT 1" + + def test_convert_to_sql_parameters_with_named_params(self): + from appkit_py.plugins.analytics.query import QueryProcessor + + qp = QueryProcessor() + result = qp.convert_to_sql_parameters( + "SELECT * FROM t WHERE id = :id AND name = :name", + { + "id": {"__sql_type": "NUMERIC", "value": "42"}, + "name": {"__sql_type": "STRING", "value": "test"}, + }, + ) + assert "parameters" in result + assert isinstance(result["parameters"], list) diff --git a/packages/appkit-py/tests/unit/test_ring_buffer.py b/packages/appkit-py/tests/unit/test_ring_buffer.py new file mode 100644 index 00000000..6aa6b3ba --- /dev/null +++ b/packages/appkit-py/tests/unit/test_ring_buffer.py @@ -0,0 +1,130 @@ +"""Unit tests for RingBuffer and EventRingBuffer. + +These test the SSE event buffering used for stream reconnection. +""" + +from __future__ import annotations + +import time + +import pytest + +pytestmark = pytest.mark.unit + + +class TestRingBuffer: + """Tests for the generic RingBuffer.""" + + def test_import(self): + """RingBuffer should be importable from appkit_py.stream.buffers.""" + from appkit_py.stream.buffers import RingBuffer + + buf = RingBuffer(capacity=5) + assert buf is not None + + def test_add_and_retrieve(self): + from appkit_py.stream.buffers import RingBuffer + + buf: RingBuffer[str] = RingBuffer(capacity=5) + buf.add("key1", "value1") + assert buf.get("key1") == "value1" + + def test_capacity_eviction(self): + from appkit_py.stream.buffers import RingBuffer + + buf: RingBuffer[str] = RingBuffer(capacity=3) + buf.add("a", "1") + buf.add("b", "2") + buf.add("c", "3") + buf.add("d", "4") # Should evict "a" + assert buf.get("a") is None + assert buf.get("d") == "4" + + def test_lru_eviction_order(self): + from appkit_py.stream.buffers import RingBuffer + + buf: RingBuffer[str] = RingBuffer(capacity=3) + buf.add("a", "1") + buf.add("b", "2") + buf.add("c", "3") + # Oldest (a) should be evicted first + buf.add("d", "4") + assert buf.get("a") is None + assert buf.get("b") == "2" + + def test_size_tracking(self): + from appkit_py.stream.buffers import RingBuffer + + buf: RingBuffer[str] = RingBuffer(capacity=5) + assert len(buf) == 0 + buf.add("a", "1") + assert len(buf) == 1 + buf.add("b", "2") + assert len(buf) == 2 + + +class TestEventRingBuffer: + """Tests for the SSE-specific EventRingBuffer.""" + + def test_import(self): + from appkit_py.stream.buffers import EventRingBuffer + + buf = EventRingBuffer(capacity=10) + assert buf is not None + + def test_add_event(self): + from appkit_py.stream.buffers import BufferedEvent, EventRingBuffer + + buf = EventRingBuffer(capacity=10) + event = BufferedEvent( + id="evt-1", type="message", data='{"text":"hello"}', timestamp=time.time() + ) + buf.add_event(event) + assert buf.has_event("evt-1") + + def test_get_events_since(self): + from appkit_py.stream.buffers import BufferedEvent, EventRingBuffer + + buf = EventRingBuffer(capacity=10) + now = time.time() + for i in range(5): + buf.add_event( + BufferedEvent( + id=f"evt-{i}", type="msg", data=f'{{"i":{i}}}', timestamp=now + i + ) + ) + + # Get events after evt-2 (should return evt-3, evt-4) + since = buf.get_events_since("evt-2") + assert since is not None + assert len(since) == 2 + assert since[0].id == "evt-3" + assert since[1].id == "evt-4" + + def test_get_events_since_missing_id(self): + from appkit_py.stream.buffers import BufferedEvent, EventRingBuffer + + buf = EventRingBuffer(capacity=10) + buf.add_event( + BufferedEvent(id="evt-1", type="msg", data="{}", timestamp=time.time()) + ) + # Non-existent ID means buffer overflow — return None + result = buf.get_events_since("nonexistent") + assert result is None + + def test_buffer_overflow_eviction(self): + from appkit_py.stream.buffers import BufferedEvent, EventRingBuffer + + buf = EventRingBuffer(capacity=3) + now = time.time() + for i in range(5): + buf.add_event( + BufferedEvent(id=f"evt-{i}", type="msg", data="{}", timestamp=now + i) + ) + + # First two should be evicted + assert not buf.has_event("evt-0") + assert not buf.has_event("evt-1") + assert buf.has_event("evt-2") + assert buf.has_event("evt-3") + assert buf.has_event("evt-4") diff --git a/packages/appkit-py/tests/unit/test_stream_manager.py b/packages/appkit-py/tests/unit/test_stream_manager.py new file mode 100644 index 00000000..36189d3f --- /dev/null +++ b/packages/appkit-py/tests/unit/test_stream_manager.py @@ -0,0 +1,145 @@ +"""Unit tests for StreamManager. + +Tests the core SSE streaming orchestration including: +- Basic event streaming +- Heartbeat generation +- Stream reconnection via Last-Event-ID +- Error handling +- Multi-client broadcast +""" + +from __future__ import annotations + +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock + +import pytest + +pytestmark = pytest.mark.unit + + +class TestStreamManager: + """Tests for StreamManager streaming behavior.""" + + def test_import(self): + from appkit_py.stream.stream_manager import StreamManager + + mgr = StreamManager() + assert mgr is not None + + async def test_basic_streaming(self): + """StreamManager should yield events from an async generator.""" + from appkit_py.stream.stream_manager import StreamManager + + mgr = StreamManager() + events_sent: list[str] = [] + + async def handler(signal=None): + for i in range(3): + yield {"type": "message", "count": i} + + async def mock_send(data: str): + events_sent.append(data) + + await mgr.stream(mock_send, handler, on_disconnect=asyncio.Event()) + + # Should have 3 events (plus possible heartbeats) + data_events = [e for e in events_sent if "event:" in e and "heartbeat" not in e] + assert len(data_events) >= 3 + + async def test_error_in_handler_sends_error_event(self): + """If the handler raises, an error SSE event should be sent.""" + from appkit_py.stream.stream_manager import StreamManager + + mgr = StreamManager() + events_sent: list[str] = [] + + async def failing_handler(signal=None): + yield {"type": "message", "data": "ok"} + raise RuntimeError("Something broke") + + async def mock_send(data: str): + events_sent.append(data) + + await mgr.stream(mock_send, failing_handler, on_disconnect=asyncio.Event()) + + # Should contain an error event + all_text = "".join(events_sent) + assert "event: error" in all_text + + async def test_abort_signal_stops_streaming(self): + """Setting abort should stop the stream.""" + from appkit_py.stream.stream_manager import StreamManager + + mgr = StreamManager() + events_sent: list[str] = [] + disconnect = asyncio.Event() + + async def slow_handler(signal=None): + for i in range(100): + if signal and signal.is_set(): + return + yield {"type": "message", "count": i} + await asyncio.sleep(0.01) + + async def mock_send(data: str): + events_sent.append(data) + + # Abort after a short delay + async def abort_soon(): + await asyncio.sleep(0.05) + disconnect.set() + + asyncio.create_task(abort_soon()) + await mgr.stream(mock_send, slow_handler, on_disconnect=disconnect) + + # Should have stopped early (not all 100 events) + data_events = [e for e in events_sent if "event:" in e and "heartbeat" not in e] + assert len(data_events) < 100 + + +class TestStreamManagerSSEFormat: + """Tests that StreamManager produces correct SSE wire format.""" + + async def test_event_has_id_event_data_fields(self): + from appkit_py.stream.stream_manager import StreamManager + + mgr = StreamManager() + events_sent: list[str] = [] + + async def handler(signal=None): + yield {"type": "test_event", "value": 42} + + async def mock_send(data: str): + events_sent.append(data) + + await mgr.stream(mock_send, handler, on_disconnect=asyncio.Event()) + + # Find the event in output + all_text = "".join(events_sent) + assert "id:" in all_text + assert "event:" in all_text + assert "data:" in all_text + + async def test_event_data_is_valid_json(self): + from appkit_py.stream.stream_manager import StreamManager + + mgr = StreamManager() + events_sent: list[str] = [] + + async def handler(signal=None): + yield {"type": "result", "payload": {"key": "value"}} + + async def mock_send(data: str): + events_sent.append(data) + + await mgr.stream(mock_send, handler, on_disconnect=asyncio.Event()) + + # Extract data lines and verify JSON + for chunk in events_sent: + for line in chunk.split("\n"): + if line.startswith("data:"): + data_str = line[len("data:"):].strip() + parsed = json.loads(data_str) + assert isinstance(parsed, dict) diff --git a/packages/appkit-ui/src/react/charts/types.ts b/packages/appkit-ui/src/react/charts/types.ts index 65804a74..fdcc55f1 100644 --- a/packages/appkit-ui/src/react/charts/types.ts +++ b/packages/appkit-ui/src/react/charts/types.ts @@ -5,7 +5,7 @@ import type { Table } from "apache-arrow"; // ============================================================================ /** Supported data formats for analytics queries */ -export type DataFormat = "json" | "arrow" | "auto"; +export type DataFormat = "json" | "arrow" | "arrow_stream" | "auto"; /** Chart orientation */ export type Orientation = "vertical" | "horizontal"; diff --git a/packages/appkit-ui/src/react/hooks/__tests__/use-chart-data.test.ts b/packages/appkit-ui/src/react/hooks/__tests__/use-chart-data.test.ts index 3d5e96f1..32ce52cb 100644 --- a/packages/appkit-ui/src/react/hooks/__tests__/use-chart-data.test.ts +++ b/packages/appkit-ui/src/react/hooks/__tests__/use-chart-data.test.ts @@ -205,7 +205,7 @@ describe("useChartData", () => { ); }); - test("auto-selects JSON by default when no heuristics match", () => { + test("auto-selects ARROW_STREAM by default when no heuristics match", () => { mockUseAnalyticsQuery.mockReturnValue({ data: [], loading: false, @@ -223,11 +223,11 @@ describe("useChartData", () => { expect(mockUseAnalyticsQuery).toHaveBeenCalledWith( "test", { limit: 100 }, - expect.objectContaining({ format: "JSON" }), + expect.objectContaining({ format: "ARROW_STREAM" }), ); }); - test("defaults to auto format (JSON) when format is not specified", () => { + test("defaults to auto format (ARROW_STREAM) when format is not specified", () => { mockUseAnalyticsQuery.mockReturnValue({ data: [], loading: false, @@ -243,7 +243,7 @@ describe("useChartData", () => { expect(mockUseAnalyticsQuery).toHaveBeenCalledWith( "test", undefined, - expect.objectContaining({ format: "JSON" }), + expect.objectContaining({ format: "ARROW_STREAM" }), ); }); }); diff --git a/packages/appkit-ui/src/react/hooks/types.ts b/packages/appkit-ui/src/react/hooks/types.ts index 5db725fc..f25f17dd 100644 --- a/packages/appkit-ui/src/react/hooks/types.ts +++ b/packages/appkit-ui/src/react/hooks/types.ts @@ -5,7 +5,7 @@ import type { Table } from "apache-arrow"; // ============================================================================ /** Supported response formats for analytics queries */ -export type AnalyticsFormat = "JSON" | "ARROW"; +export type AnalyticsFormat = "JSON" | "ARROW" | "ARROW_STREAM"; /** * Typed Arrow Table - preserves row type information for type inference. @@ -32,8 +32,10 @@ export interface TypedArrowTable< // ============================================================================ /** Options for configuring an analytics SSE query */ -export interface UseAnalyticsQueryOptions