diff --git a/cli/serve/app.py b/cli/serve/app.py index 583b28c01..31b0f2da3 100644 --- a/cli/serve/app.py +++ b/cli/serve/app.py @@ -7,6 +7,7 @@ import sys import time import uuid +from typing import Any try: import typer @@ -14,6 +15,7 @@ from fastapi import FastAPI, Request from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, StreamingResponse + from pydantic import BaseModel except ImportError as e: raise ImportError( "The 'm serve' command requires extra dependencies. " @@ -31,7 +33,9 @@ OpenAIError, OpenAIErrorResponse, ) +from .schema_converter import json_schema_to_pydantic from .streaming import stream_chat_completion_chunks +from .utils import extract_finish_reason app = FastAPI( title="M serve OpenAI API Compatible Server", @@ -108,7 +112,7 @@ def _build_model_options(request: ChatCompletionRequest) -> dict: "presence_penalty", # Presence penalty - not yet implemented "frequency_penalty", # Frequency penalty - not yet implemented "logit_bias", # Logit bias - not yet implemented - "response_format", # Response format (json_object) - not yet implemented + "response_format", # Response format - handled separately "functions", # Legacy function calling - not yet implemented "function_call", # Legacy function calling - not yet implemented "tools", # Tool calling - not yet implemented @@ -137,6 +141,10 @@ def _build_model_options(request: ChatCompletionRequest) -> dict: def make_chat_endpoint(module): """Makes a chat endpoint using a custom module.""" + # Inspect serve function once at endpoint creation time + serve_sig = inspect.signature(module.serve) + accepts_format = "format" in serve_sig.parameters + is_async = inspect.iscoroutinefunction(module.serve) async def endpoint(request: ChatCompletionRequest): try: @@ -154,22 +162,49 @@ async def endpoint(request: ChatCompletionRequest): model_options = _build_model_options(request) + # Handle response_format + format_model: type[BaseModel] | None = None + if request.response_format is not None: + if request.response_format.type == "json_schema": + if request.response_format.json_schema is None: + return create_openai_error_response( + status_code=400, + message="json_schema field is required when response_format.type is 'json_schema'", + error_type="invalid_request_error", + param="response_format.json_schema", + ) + try: + format_model = json_schema_to_pydantic( + request.response_format.json_schema.schema_, + request.response_format.json_schema.name, + ) + except ValueError as e: + return create_openai_error_response( + status_code=400, + message=f"Invalid JSON schema: {e!s}", + error_type="invalid_request_error", + param="response_format.json_schema.schema", + ) + # For "json_object" and "text", format_model remains None + # Note: "json_object" mode is not yet implemented - the backend + # receives no signal to produce JSON output (same as "text" mode) + + # Build kwargs for serve call + serve_kwargs: dict[str, Any] = { + "input": request.messages, + "requirements": request.requirements, + "model_options": model_options, + } + if accepts_format: + serve_kwargs["format"] = format_model + # Detect if serve is async or sync and handle accordingly - if inspect.iscoroutinefunction(module.serve): + if is_async: # It's async, await it directly - output = await module.serve( - input=request.messages, - requirements=request.requirements, - model_options=model_options, - ) + output = await module.serve(**serve_kwargs) else: # It's sync, run in thread pool to avoid blocking event loop - output = 
await asyncio.to_thread( - module.serve, - input=request.messages, - requirements=request.requirements, - model_options=model_options, - ) + output = await asyncio.to_thread(module.serve, **serve_kwargs) # system_fingerprint represents backend config hash, not model name # The model name is already in response.model (line 73) @@ -200,7 +235,7 @@ async def endpoint(request: ChatCompletionRequest): message=ChatCompletionMessage( content=output.value, role="assistant" ), - finish_reason="stop", + finish_reason=extract_finish_reason(output), ) ], object="chat.completion", # type: ignore diff --git a/cli/serve/models.py b/cli/serve/models.py index 7e738730e..64ee1e11f 100644 --- a/cli/serve/models.py +++ b/cli/serve/models.py @@ -29,8 +29,26 @@ class ToolFunction(BaseModel): function: FunctionDefinition +class JsonSchemaFormat(BaseModel): + """JSON Schema definition for structured output.""" + + name: str + """Name of the schema.""" + + schema_: dict[str, Any] = Field(alias="schema") + """JSON Schema definition.""" + + strict: bool | None = None + """Accepted for OpenAI compatibility; currently ignored by ``m serve``.""" + + model_config = {"populate_by_name": True} + + class ResponseFormat(BaseModel): - type: Literal["text", "json_object"] + type: Literal["text", "json_object", "json_schema"] + + json_schema: JsonSchemaFormat | None = None + """JSON Schema definition when type is 'json_schema'.""" class StreamOptions(BaseModel): diff --git a/cli/serve/schema_converter.py b/cli/serve/schema_converter.py new file mode 100644 index 000000000..866e9341d --- /dev/null +++ b/cli/serve/schema_converter.py @@ -0,0 +1,407 @@ +"""Helpers for converting OpenAI-style JSON Schema response formats.""" + +from enum import Enum +from typing import Annotated, Any, Literal, cast + +from pydantic import BaseModel, ConfigDict, Strict, create_model + + +def json_schema_to_pydantic( + schema: dict[str, Any], model_name: str = "DynamicModel" ) -> type[BaseModel]: + """Convert a practical subset of JSON Schema to a Pydantic model dynamically. + + This converter targets the OpenAI-style structured output schemas used by + ``m serve``. It intentionally maps JSON Schema features into Python typing + and Pydantic model semantics rather than attempting to preserve every JSON + Schema validation rule exactly. + + Supported features: + - top-level and nested ``object`` schemas with ``properties`` and ``required`` + - primitive types: ``string``, ``integer``, ``number``, ``boolean`` + - arrays via ``type: "array"`` with supported ``items`` + - string or primitive enums via ``enum`` + - nullable fields via ``type: ["<type>", "null"]`` + - local ``$ref`` into ``$defs`` / ``definitions`` + - simple ``allOf`` merging for object-like schemas + - simple ``anyOf`` / ``oneOf`` unions when each branch is representable + - boolean and schema-valued ``additionalProperties`` + + Behavior notes: + - ``additionalProperties: false`` maps to ``extra="forbid"`` + - ``additionalProperties: true`` maps to ``extra="ignore"`` + - schema-valued ``additionalProperties`` maps to ``dict[str, ValueType]`` + only for open-ended object maps. It cannot be combined with named + ``properties`` because that is not representable as a single standard + Pydantic field shape without custom validators. 
+ - sibling keywords next to ``$ref`` are merged over the resolved target, + matching common JSON Schema practice for OpenAI-compatible schemas + + Still unsupported and will raise ``ValueError``: + - non-local refs + - tuple-style array schemas + - object schemas without ``properties`` unless they are pure + ``additionalProperties`` maps + - schema constraints beyond representable typing/extra handling + + Args: + schema: JSON Schema definition (must have top-level ``type: "object"``). + model_name: Name for the generated Pydantic model. + + Returns: + A dynamically created Pydantic model class. + + Raises: + ValueError: If the schema is invalid or unsupported. + """ + if not isinstance(schema, dict): + raise ValueError("Schema must be a dictionary") + + defs = schema.get("$defs") + if defs is None: + defs = schema.get("definitions", {}) + if defs is None: + defs = {} + if not isinstance(defs, dict): + raise ValueError("Schema '$defs' must be an object") + + ref_cache: dict[str, Any] = {} + model_cache: dict[str, type[BaseModel]] = {} + + def _sanitize_model_name(name: str) -> str: + sanitized = "".join(ch if ch.isalnum() else "_" for ch in name).strip("_") + return sanitized or "DynamicModel" + + def _format_path(path: str) -> str: + return path or "<root>" + + def _resolve_ref(ref: str) -> dict[str, Any]: + if ref in ref_cache: + resolved = ref_cache[ref] + if not isinstance(resolved, dict): + raise ValueError(f"Resolved ref is invalid: {ref}") + return resolved + + prefixes = ("#/$defs/", "#/definitions/") + for prefix in prefixes: + if ref.startswith(prefix): + key = ref[len(prefix) :] + if key not in defs: + raise ValueError(f"Unresolved local ref: {ref}") + target = defs[key] + if not isinstance(target, dict): + raise ValueError(f"Ref target must be an object: {ref}") + ref_cache[ref] = target + return target + + raise ValueError( + f"Only local $ref values into $defs/definitions are supported: {ref}" + ) + + def _merge_nullable(annotation: Any, is_nullable: bool) -> Any: + """Add ``None`` to an annotation when the source schema is nullable.""" + if is_nullable: + return annotation | None + return annotation + + def _enum_annotation(enum_values: list[Any], path: str) -> Any: + """Convert JSON Schema enum values into a Python typing annotation.""" + if not enum_values: + raise ValueError(f"{_format_path(path)} enum must not be empty") + + value_types = {type(value) for value in enum_values} + if len(value_types) != 1: + raise ValueError( + f"{_format_path(path)} enum values must all have the same primitive type" + ) + + value_type = value_types.pop() + allowed_types = {str, int, float, bool} + if value_type not in allowed_types: + raise ValueError( + f"{_format_path(path)} enum values must be string, integer, number, or boolean" + ) + + if value_type is str: + enum_name = _sanitize_model_name( + path.replace(".", "_").replace("[", "_").replace("]", "") + ) + members = { + ( + value.upper() if value and value[0].isalpha() else f"VALUE_{index}" + ): value + for index, value in enumerate(enum_values) + } + return Enum(enum_name or "GeneratedEnum", members) + + return Literal[tuple(enum_values)] + + def _merge_object_schemas( + schemas: list[dict[str, Any]], path: str + ) -> dict[str, Any]: + """Merge simple object schemas for ``allOf``. + + This supports the common OpenAI-compatible case where ``allOf`` is used + to compose object fragments. Conflicting keywords are rejected rather + than silently guessed. 
+ """ + merged: dict[str, Any] = {"type": "object", "properties": {}, "required": []} + merged_required: set[str] = set() + merged_additional_properties: bool | dict[str, Any] = True + + for index, branch in enumerate(schemas): + resolved_branch = _normalize_schema(branch, f"{path}.allOf[{index}]") + branch_type = resolved_branch.get("type", "object") + if branch_type != "object": + raise ValueError( + f"{_format_path(path)} allOf only supports object branches" + ) + + branch_properties = resolved_branch.get("properties", {}) + if not isinstance(branch_properties, dict): + raise ValueError( + f"{_format_path(path)} allOf branch properties must be an object" + ) + + for property_name, property_schema in branch_properties.items(): + if property_name in merged["properties"]: + raise ValueError( + f"{_format_path(path)} allOf has conflicting property " + f"definitions for '{property_name}'" + ) + cast(dict[str, Any], merged["properties"])[property_name] = ( + property_schema + ) + + branch_required = resolved_branch.get("required", []) + if not isinstance(branch_required, list): + raise ValueError( + f"{_format_path(path)} allOf branch 'required' must be an array" + ) + merged_required.update( + field_name + for field_name in branch_required + if isinstance(field_name, str) + ) + + branch_additional_properties = resolved_branch.get( + "additionalProperties", True + ) + if branch_additional_properties is False: + merged_additional_properties = False + elif isinstance(branch_additional_properties, dict): + if merged_additional_properties is True: + merged_additional_properties = branch_additional_properties + elif merged_additional_properties is False: + continue + elif merged_additional_properties != branch_additional_properties: + raise ValueError( + f"{_format_path(path)} allOf has conflicting " + "additionalProperties schemas" + ) + + merged["required"] = sorted(merged_required) + merged["additionalProperties"] = merged_additional_properties + return merged + + def _union_annotation( + keyword: str, union_schemas: list[dict[str, Any]], path: str + ) -> Any: + """Convert ``anyOf``/``oneOf`` branches into a Python union annotation.""" + if not union_schemas: + raise ValueError(f"{_format_path(path)} {keyword} must not be empty") + + annotations: list[Any] = [] + for index, branch in enumerate(union_schemas): + annotations.append( + _schema_to_type(branch, f"{path}.{keyword}[{index}]", in_union=True) + ) + + annotation = annotations[0] + for branch_annotation in annotations[1:]: + annotation = annotation | branch_annotation + return annotation + + def _normalize_schema(field_schema: dict[str, Any], path: str) -> dict[str, Any]: + """Resolve refs and simple combinators into a normalized schema object.""" + if not isinstance(field_schema, dict): + raise ValueError(f"{_format_path(path)} schema must be an object") + + normalized = dict(field_schema) + + if "$ref" in normalized: + ref = normalized["$ref"] + if not isinstance(ref, str): + raise ValueError(f"{_format_path(path)} $ref must be a string") + resolved = _resolve_ref(ref) + sibling_keywords = {k: v for k, v in normalized.items() if k != "$ref"} + if sibling_keywords: + merged = dict(resolved) + merged.update(sibling_keywords) + normalized = merged + else: + normalized = dict(resolved) + + if "allOf" in normalized: + all_of = normalized.pop("allOf") + if not isinstance(all_of, list): + raise ValueError(f"{_format_path(path)} allOf must be an array") + merged = _merge_object_schemas(all_of, path) + merged.update(normalized) + normalized = 
merged + + return normalized + + def _schema_to_type( + field_schema: dict[str, Any], path: str, in_union: bool = False + ) -> Any: + """Convert a JSON Schema node into a Python typing annotation.""" + normalized_schema = _normalize_schema(field_schema, path) + + for keyword in ("anyOf", "oneOf"): + if keyword in normalized_schema: + union_schemas = normalized_schema[keyword] + if not isinstance(union_schemas, list): + raise ValueError(f"{_format_path(path)} {keyword} must be an array") + sibling_keywords = { + key: value + for key, value in normalized_schema.items() + if key != keyword + } + branch_schemas: list[dict[str, Any]] = [] + for branch in union_schemas: + if not isinstance(branch, dict): + raise ValueError( + f"{_format_path(path)} {keyword} branches must be objects" + ) + merged_branch = dict(branch) + for sibling_key, sibling_value in sibling_keywords.items(): + merged_branch.setdefault(sibling_key, sibling_value) + branch_schemas.append(merged_branch) + return _union_annotation(keyword, branch_schemas, path) + + if "enum" in normalized_schema: + enum_values = normalized_schema["enum"] + if not isinstance(enum_values, list): + raise ValueError(f"{_format_path(path)} enum must be an array") + return _enum_annotation(enum_values, path) + + field_type = normalized_schema.get("type", "string") + is_nullable = False + if isinstance(field_type, list): + non_null_types = [item for item in field_type if item != "null"] + null_count = len(field_type) - len(non_null_types) + if null_count > 1 or len(non_null_types) != 1: + raise ValueError( + f"{_format_path(path)} uses unsupported multi-type schema: {field_type}" + ) + if null_count == 1: + is_nullable = True + field_type = non_null_types[0] + + primitive_type_mapping = { + "string": str, + "integer": int, + "number": float, + "boolean": bool, + } + + if field_type in primitive_type_mapping: + base_type = primitive_type_mapping[field_type] + if in_union: + annotated_type = Annotated[base_type, Strict()] # type: ignore[valid-type] + return _merge_nullable(annotated_type, is_nullable) + return _merge_nullable(base_type, is_nullable) + + if field_type == "object": + properties = normalized_schema.get("properties") + additional_properties = normalized_schema.get("additionalProperties", True) + + if properties is None and isinstance(additional_properties, dict): + value_annotation = _schema_to_type(additional_properties, f"{path}.*") + dict_type = dict[str, value_annotation] # type: ignore[valid-type] + return _merge_nullable(dict_type, is_nullable) + + nested_name = _sanitize_model_name(f"{model_name}_{path.replace('.', '_')}") + nested_model = _object_schema_to_model(normalized_schema, nested_name, path) + return _merge_nullable(nested_model, is_nullable) + + if field_type == "array": + items_schema = normalized_schema.get("items") + item_annotation: Any + if items_schema is None: + item_annotation = Any + elif isinstance(items_schema, list): + raise ValueError( + f"{_format_path(path)} uses unsupported tuple-style array schema" + ) + elif isinstance(items_schema, dict): + item_annotation = _schema_to_type(items_schema, f"{path}[]") + else: + raise ValueError(f"{_format_path(path)} items must be an object") + # Construct list type at runtime to avoid mypy subscript error. 
+ list_type = list[item_annotation] # type: ignore[valid-type] + return _merge_nullable(list_type, is_nullable) + + raise ValueError( + f"{_format_path(path)} uses unsupported JSON schema type: {field_type}" + ) + + def _object_schema_to_model( + object_schema: dict[str, Any], current_model_name: str, path: str + ) -> type[BaseModel]: + normalized_schema = _normalize_schema(object_schema, path) + if normalized_schema.get("type") != "object": + raise ValueError(f"{_format_path(path)} must be an object schema") + + cache_key = f"{current_model_name}:{id(object_schema)}" + cached = model_cache.get(cache_key) + if cached is not None: + return cached + + properties = normalized_schema.get("properties", {}) + required = normalized_schema.get("required", []) + additional_properties = normalized_schema.get("additionalProperties", True) + + if not isinstance(required, list): + raise ValueError(f"{_format_path(path)} 'required' must be an array") + + if not isinstance(properties, dict): + raise ValueError(f"{_format_path(path)} 'properties' must be an object") + + if not properties: + if isinstance(additional_properties, dict): + raise ValueError( + f"{_format_path(path)} is a pure additionalProperties map and should " + "be used as a field type, not as a model root" + ) + raise ValueError( + f"{_format_path(path)} must have a non-empty 'properties' object" + ) + + field_definitions: dict[str, Any] = {} + for field_name, field_schema in properties.items(): + child_path = f"{path}.{field_name}" if path else field_name + annotation = _schema_to_type(field_schema, child_path) + if field_name in required: + field_definitions[field_name] = (annotation, ...) + else: + field_definitions[field_name] = (annotation | None, None) + + if additional_properties not in (True, False): + raise ValueError( + f"{_format_path(path)} only supports boolean additionalProperties " + "when combined with named properties" + ) + + model_config = ConfigDict( + extra="forbid" if additional_properties is False else "ignore", + use_enum_values=True, + ) + dynamic_model = create_model( + current_model_name, __config__=model_config, **field_definitions + ) + model_cache[cache_key] = dynamic_model + return dynamic_model + + return _object_schema_to_model(schema, _sanitize_model_name(model_name), "") diff --git a/cli/serve/streaming.py b/cli/serve/streaming.py index 51ff33c3c..6b75f9c4b 100644 --- a/cli/serve/streaming.py +++ b/cli/serve/streaming.py @@ -14,6 +14,7 @@ OpenAIErrorResponse, StreamOptions, ) +from .utils import extract_finish_reason async def stream_chat_completion_chunks( @@ -26,6 +27,11 @@ ) -> AsyncGenerator[str, None]: """Generate OpenAI-compatible SSE chat completion chunks from a model output. + This function acts as a pass-through streaming layer, forwarding chunks directly + from the backend to the client without buffering or validation. Format validation + for structured outputs happens at the module level (in the serve function) and + on the client side, not in this streaming layer. + Args: output: The model output object to stream. completion_id: Unique identifier for this completion. 
@@ -112,7 +118,7 @@ async def stream_chat_completion_chunks( ChatCompletionChunkChoice( index=0, delta=ChatCompletionChunkDelta(content=None), - finish_reason="stop", + finish_reason=extract_finish_reason(output), ) ], object="chat.completion.chunk", diff --git a/cli/serve/utils.py b/cli/serve/utils.py new file mode 100644 index 000000000..81fd39731 --- /dev/null +++ b/cli/serve/utils.py @@ -0,0 +1,49 @@ +from typing import Any, Literal + +FinishReason = Literal[ + "stop", "length", "content_filter", "tool_calls", "function_call" +] + + +def extract_finish_reason(output: Any) -> FinishReason: + """Extract finish_reason from ModelOutputThunk metadata. + + Args: + output: The model output thunk containing response metadata. + + Returns: + The finish_reason from the backend response, defaulting to "stop" if unavailable. + Possible values: "stop", "length", "content_filter", "tool_calls", "function_call". + """ + # Valid finish_reason values per OpenAI spec + valid_reasons: set[FinishReason] = { + "stop", + "length", + "content_filter", + "tool_calls", + "function_call", + } + + # Try to get finish_reason from the response metadata + # Different backends store this in different places + if hasattr(output, "_meta") and output._meta: + # Ollama backend stores response in chat_response with done_reason field + # (ollama.ChatResponse object with done_reason attribute) + chat_response = output._meta.get("chat_response") + if chat_response and hasattr(chat_response, "done_reason"): + done_reason = chat_response.done_reason + if done_reason in valid_reasons: + return done_reason + + # OpenAI backend stores full response dict in oai_chat_response + # (from chunk.model_dump() which includes choices array) + oai_response = output._meta.get("oai_chat_response") + if oai_response and isinstance(oai_response, dict): + choices = oai_response.get("choices", []) + if choices and len(choices) > 0: + finish_reason = choices[0].get("finish_reason") + if finish_reason in valid_reasons: + return finish_reason + + # Default to "stop" per OpenAI spec + return "stop" diff --git a/docs/examples/m_serve/README.md b/docs/examples/m_serve/README.md index 70fcb5f5e..c65ba8819 100644 --- a/docs/examples/m_serve/README.md +++ b/docs/examples/m_serve/README.md @@ -19,6 +19,14 @@ A dedicated streaming example for `m serve` that supports both modes: - `stream=True` returns an uncomputed thunk so the server can emit incremental Server-Sent Events (SSE) chunks +### m_serve_example_response_format.py +Example demonstrating structured output with the `response_format` parameter. + +**Key Features:** +- Support for the `format` parameter in serve functions +- Structured output validation with JSON schemas +- Three format types: `text`, `json_object`, `json_schema` + ### pii_serve.py Example of serving a PII (Personally Identifiable Information) detection service. @@ -29,6 +37,9 @@ Client code for testing the served API endpoints with non-streaming requests. Client code demonstrating streaming responses using Server-Sent Events (SSE) against `m_serve_example_streaming.py`. +### client_response_format.py +Client code demonstrating all three `response_format` types with examples. + ## Concepts Demonstrated - **API Deployment**: Exposing Mellea programs as REST APIs @@ -37,6 +48,7 @@ against `m_serve_example_streaming.py`. 
- **Validation in Production**: Using requirements in deployed services - **Model Options**: Passing model configuration through API - **Streaming Responses**: Real-time token streaming via Server-Sent Events (SSE) +- **Structured Output**: Using `response_format` for JSON schema validation ## Basic Pattern @@ -84,6 +96,85 @@ m serve docs/examples/m_serve/m_serve_example_streaming.py python docs/examples/m_serve/client_streaming.py ``` +### Response Format + +```bash +# Start the response_format example server +m serve docs/examples/m_serve/m_serve_example_response_format.py + +# In another terminal, test with the response_format client +python docs/examples/m_serve/client_response_format.py +``` + +## Response Format Support + +The server supports structured output via the `response_format` parameter, which allows you to control the format of the model's response. This is compatible with OpenAI's response format API. + +**Three Format Types:** + +1. **`text`** (default): Plain text output +2. **`json_object`**: Unstructured JSON output (model decides the schema; JSON is requested but not strictly enforced) +3. **`json_schema`**: Structured output validated against a JSON schema + +**Key Features:** +- Automatic JSON schema to Pydantic model conversion +- Schema validation for structured outputs +- OpenAI-compatible API +- Works with the `format` parameter in serve functions + +**Example - JSON Schema:** +```python +import openai + +client = openai.OpenAI(api_key="na", base_url="http://0.0.0.0:8080/v1") + +# Define a schema for structured output +person_schema = { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "email": {"type": "string"}, + }, + "required": ["name", "age", "email"], +} + +response = client.chat.completions.create( + messages=[{"role": "user", "content": "Generate a person named Alice"}], + model="granite4:micro-h", + response_format={ + "type": "json_schema", + "json_schema": { + "name": "Person", + "schema": person_schema, + "strict": True, + }, + }, +) + +# Response will be valid JSON matching the schema +print(response.choices[0].message.content) +``` + +**Server Implementation:** +Your serve function must accept a `format` parameter to support `json_schema`: + +```python +def serve( + input: list[ChatMessage], + requirements: list[str] | None = None, + model_options: dict | None = None, + format: type | None = None, # Add this parameter +) -> ModelOutputThunk: + result = session.instruct( + description=input[-1].content, + requirements=requirements, + model_options=model_options, + format=format, # Pass to instruct() + ) + return result +``` + ## Streaming Support The server supports streaming responses via Server-Sent Events (SSE) when the diff --git a/docs/examples/m_serve/client_response_format.py b/docs/examples/m_serve/client_response_format.py new file mode 100644 index 000000000..a51f371b1 --- /dev/null +++ b/docs/examples/m_serve/client_response_format.py @@ -0,0 +1,254 @@ +# pytest: skip_always +"""Client demonstrating response_format parameter with m serve. + +This example shows how to use the three response_format types: +1. text - Plain text output (default) +2. json_object - Unstructured JSON output +3. 
json_schema - Structured output validated against a JSON schema + +Prerequisites: + Start the server first: + m serve docs/examples/m_serve/m_serve_example_response_format.py + + Then run this client: + python docs/examples/m_serve/client_response_format.py +""" + +import json + +import openai + +PORT = 8080 +BASE_URL = f"http://0.0.0.0:{PORT}/v1" + +# Create OpenAI client pointing to our m serve endpoint +client = openai.OpenAI(api_key="not-needed", base_url=BASE_URL) + + +def example_text_format(): + """Example 1: Plain text output (default behavior).""" + print("\n" + "=" * 60) + print("Example 1: Text Format (default)") + print("=" * 60) + + response = client.chat.completions.create( + model="granite4:micro-h", + messages=[{"role": "user", "content": "Write a haiku about programming."}], + response_format={"type": "text"}, + ) + + print(f"Response: {response.choices[0].message.content}") + + +def example_json_object(): + """Example 2: Unstructured JSON output. + + Note: json_object format requests JSON but doesn't enforce it strictly. + The model may wrap JSON in markdown or add explanatory text. + For strict JSON validation, use json_schema instead. + """ + print("\n" + "=" * 60) + print("Example 2: JSON Object Format") + print("=" * 60) + + response = client.chat.completions.create( + model="granite4:micro-h", + messages=[ + { + "role": "user", + "content": "Generate a JSON object with information about a fictional person. Include name, age, and city. Return ONLY the JSON, no markdown formatting.", + } + ], + response_format={"type": "json_object"}, + ) + + content = response.choices[0].message.content or "" + print(f"Response: {content}") + + # First, try to parse as-is (valid JSON) + try: + data = json.loads(content) + print("\n✓ Valid JSON received") + print(f"\nParsed JSON:\n{json.dumps(data, indent=2)}") + return + except json.JSONDecodeError: + # Not valid JSON, try to extract from markdown + print("\n⚠ Response is not valid JSON, attempting to extract from markdown...") + + # Fallback: Try to extract JSON from markdown code blocks + json_content = content + if "```json" in content: + # Extract JSON from markdown code block + start = content.find("```json") + 7 + end = content.find("```", start) + if end > start: + json_content = content[start:end].strip() + print("Extracted from ```json block") + elif "```" in content: + # Generic code block + start = content.find("```") + 3 + end = content.find("```", start) + if end > start: + json_content = content[start:end].strip() + print("Extracted from ``` block") + + # Try parsing the extracted content + try: + data = json.loads(json_content) + print( + f"\n✓ Successfully extracted and parsed JSON:\n{json.dumps(data, indent=2)}" + ) + except json.JSONDecodeError as e: + print("\n✗ Failed to parse JSON even after extraction") + print("Note: json_object format doesn't enforce strict JSON.") + print("For guaranteed JSON output, use json_schema format instead.") + print(f"Parse error: {e}") + + +def example_json_schema_person(): + """Example 3: Structured output with JSON schema validation.""" + print("\n" + "=" * 60) + print("Example 3: JSON Schema Format - Person") + print("=" * 60) + + # Define a JSON schema for a person + person_schema = { + "type": "object", + "properties": { + "name": {"type": "string", "description": "The person's full name"}, + "age": {"type": "integer", "description": "The person's age in years"}, + "email": {"type": "string", "description": "The person's email address"}, + "city": { + "type": "string", + 
"description": "The city where the person lives", + }, + }, + "required": ["name", "age", "email"], + "additionalProperties": False, + } + + response = client.chat.completions.create( + model="granite4:micro-h", + messages=[ + { + "role": "user", + "content": "Generate information about a software engineer named Alice.", + } + ], + response_format={ + "type": "json_schema", + "json_schema": {"name": "Person", "schema": person_schema, "strict": True}, + }, + ) + + content = response.choices[0].message.content + print(f"Response: {content}") + + # Parse and validate the structured output + try: + data = json.loads(content or "{}") + print(f"\nParsed structured output:\n{json.dumps(data, indent=2)}") + + # Verify required fields + assert "name" in data, "Missing required field: name" + assert "age" in data, "Missing required field: age" + assert "email" in data, "Missing required field: email" + print("\n✓ All required fields present") + + except json.JSONDecodeError as e: + print(f"Failed to parse JSON: {e}") + except AssertionError as e: + print(f"Validation error: {e}") + + +def example_json_schema_product(): + """Example 4: Structured output for a product catalog.""" + print("\n" + "=" * 60) + print("Example 4: JSON Schema Format - Product") + print("=" * 60) + + # Define a JSON schema for a product + product_schema = { + "type": "object", + "properties": { + "name": {"type": "string", "description": "Product name"}, + "price": {"type": "number", "description": "Price in USD"}, + "category": { + "type": "string", + "enum": ["electronics", "clothing", "food", "books"], + "description": "Product category", + }, + "in_stock": { + "type": "boolean", + "description": "Whether the product is in stock", + }, + "description": {"type": "string", "description": "Product description"}, + }, + "required": ["name", "price", "category", "in_stock"], + "additionalProperties": False, + } + + response = client.chat.completions.create( + model="granite4:micro-h", + messages=[ + { + "role": "user", + "content": "Generate a product listing for a laptop computer.", + } + ], + response_format={ + "type": "json_schema", + "json_schema": { + "name": "Product", + "schema": product_schema, + "strict": True, + }, + }, + ) + + content = response.choices[0].message.content + print(f"Response: {content}") + + # Parse and display the structured output + try: + data = json.loads(content or "{}") + print(f"\nParsed product data:\n{json.dumps(data, indent=2)}") + + # Verify the category is valid + valid_categories = ["electronics", "clothing", "food", "books"] + if data.get("category") in valid_categories: + print(f"\n✓ Valid category: {data['category']}") + + except json.JSONDecodeError as e: + print(f"Failed to parse JSON: {e}") + + +def main(): + """Run all examples.""" + print("\n" + "=" * 60) + print("RESPONSE_FORMAT EXAMPLES") + print("=" * 60) + print(f"Connecting to: {BASE_URL}") + print("=" * 60) + + try: + # Run all examples + example_text_format() + example_json_object() + example_json_schema_person() + example_json_schema_product() + + print("\n" + "=" * 60) + print("ALL EXAMPLES COMPLETED") + print("=" * 60) + + except Exception as e: + print(f"\nError: {e}") + print("\nMake sure the server is running:") + print( + f" m serve docs/examples/m_serve/m_serve_example_response_format.py --port {PORT}" + ) + + +if __name__ == "__main__": + main() diff --git a/docs/examples/m_serve/m_serve_example_response_format.py b/docs/examples/m_serve/m_serve_example_response_format.py new file mode 100644 index 
000000000..4d2bc6b5c --- /dev/null +++ b/docs/examples/m_serve/m_serve_example_response_format.py @@ -0,0 +1,56 @@ +# pytest: ollama, e2e + +"""Example demonstrating response_format with m serve. + +This example shows how to use the response_format parameter to get structured +output from the model. The server supports three format types: +- text: Plain text output (default) +- json_object: Unstructured JSON output +- json_schema: Structured output validated against a JSON schema + +Run the server: + m serve docs/examples/m_serve/m_serve_example_response_format.py + +Test with the client: + python docs/examples/m_serve/client_response_format.py +""" + +from typing import Any + +import mellea +from cli.serve.models import ChatMessage +from mellea.core import ModelOutputThunk +from mellea.stdlib.context import ChatContext + +session = mellea.start_session(ctx=ChatContext()) + + +def serve( + input: list[ChatMessage], + requirements: list[str] | None = None, + model_options: dict[str, Any] | None = None, + format: type | None = None, +) -> ModelOutputThunk: + """Serve function that supports response_format parameter. + + Args: + input: List of chat messages from the client + requirements: Optional list of requirement strings + model_options: Optional model configuration parameters + format: Optional Pydantic model for structured output (from response_format) + + Returns: + ModelOutputThunk with the generated response + """ + message = input[-1].content or "No message provided" + + # When format is provided (from json_schema response_format), + # pass it to instruct() to get structured output + result = session.instruct( + description=message, + requirements=requirements, # type: ignore + model_options=model_options, + format=format, # This enables structured output validation + ) + + return result diff --git a/test/cli/test_schema_converter.py b/test/cli/test_schema_converter.py new file mode 100644 index 000000000..78004292a --- /dev/null +++ b/test/cli/test_schema_converter.py @@ -0,0 +1,207 @@ +"""Unit tests for JSON Schema to Pydantic conversion.""" + +import pytest + +from cli.serve.schema_converter import json_schema_to_pydantic + + +def test_json_schema_supports_enum_field(): + """Test that enum constraints are converted to a narrower Pydantic type.""" + model = json_schema_to_pydantic( + { + "type": "object", + "properties": {"status": {"type": "string", "enum": ["open", "closed"]}}, + "required": ["status"], + }, + "EnumExample", + ) + + parsed = model.model_validate({"status": "open"}) + assert parsed.model_dump()["status"] == "open" + + with pytest.raises(Exception): + model.model_validate({"status": "pending"}) + + +def test_json_schema_supports_nested_object_field(): + """Test that nested object schemas are converted recursively.""" + model = json_schema_to_pydantic( + { + "type": "object", + "properties": { + "user": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name"], + "additionalProperties": False, + } + }, + "required": ["user"], + }, + "NestedObjectExample", + ) + + parsed = model.model_validate({"user": {"name": "Alice", "age": 30}}) + parsed_user = parsed.model_dump()["user"] + assert parsed_user["name"] == "Alice" + assert parsed_user["age"] == 30 + + with pytest.raises(Exception): + model.model_validate({"user": {"name": "Alice", "extra": True}}) + + +def test_json_schema_supports_array_items_schema(): + """Test that arrays validate their item schemas.""" + model = json_schema_to_pydantic( + { + 
"type": "object", + "properties": {"tags": {"type": "array", "items": {"type": "string"}}}, + "required": ["tags"], + }, + "ArrayExample", + ) + + parsed = model.model_validate({"tags": ["a", "b"]}) + assert parsed.model_dump()["tags"] == ["a", "b"] + + with pytest.raises(Exception): + model.model_validate({"tags": ["a", 1]}) + + +def test_json_schema_supports_top_level_ref(): + """Test that local refs are resolved from $defs.""" + model = json_schema_to_pydantic( + { + "type": "object", + "$defs": { + "User": { + "type": "object", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + } + }, + "properties": {"user": {"$ref": "#/$defs/User"}}, + "required": ["user"], + }, + "RefExample", + ) + + parsed = model.model_validate({"user": {"name": "Alice"}}) + assert parsed.model_dump()["user"]["name"] == "Alice" + + with pytest.raises(Exception): + model.model_validate({"user": {}}) + + +def test_json_schema_supports_anyof_field(): + """Test that representable anyOf branches are converted to unions.""" + model = json_schema_to_pydantic( + { + "type": "object", + "properties": { + "value": {"anyOf": [{"type": "string"}, {"type": "integer"}]} + }, + "required": ["value"], + }, + "AnyOfExample", + ) + + parsed_string = model.model_validate({"value": "hello"}) + assert parsed_string.model_dump()["value"] == "hello" + + parsed_integer = model.model_validate({"value": 7}) + assert parsed_integer.model_dump()["value"] == 7 + + with pytest.raises(Exception): + model.model_validate({"value": True}) + + +def test_json_schema_supports_allof_object_merge(): + """Test that allOf merges object fragments into one model.""" + model = json_schema_to_pydantic( + { + "type": "object", + "properties": { + "user": { + "allOf": [ + { + "type": "object", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + }, + { + "type": "object", + "properties": {"age": {"type": "integer"}}, + "required": ["age"], + "additionalProperties": False, + }, + ] + } + }, + "required": ["user"], + }, + "AllOfExample", + ) + + parsed = model.model_validate({"user": {"name": "Alice", "age": 30}}) + parsed_user = parsed.model_dump()["user"] + assert parsed_user["name"] == "Alice" + assert parsed_user["age"] == 30 + + with pytest.raises(Exception): + model.model_validate({"user": {"name": "Alice"}}) + + with pytest.raises(Exception): + model.model_validate({"user": {"name": "Alice", "age": 30, "extra": True}}) + + +def test_json_schema_supports_additional_properties_schema_map(): + """Test schema-valued additionalProperties as a typed dict field.""" + model = json_schema_to_pydantic( + { + "type": "object", + "properties": { + "metadata": { + "type": "object", + "additionalProperties": {"type": "integer"}, + } + }, + "required": ["metadata"], + }, + "AdditionalPropertiesMapExample", + ) + + parsed = model.model_validate({"metadata": {"a": 1, "b": 2}}) + assert parsed.model_dump()["metadata"] == {"a": 1, "b": 2} + + with pytest.raises(Exception): + model.model_validate({"metadata": {"a": "bad"}}) + + +def test_json_schema_supports_nested_ref_in_array_items(): + """Test local refs nested under array items.""" + model = json_schema_to_pydantic( + { + "type": "object", + "$defs": { + "Tag": { + "type": "object", + "properties": {"label": {"type": "string"}}, + "required": ["label"], + "additionalProperties": False, + } + }, + "properties": {"tags": {"type": "array", "items": {"$ref": "#/$defs/Tag"}}}, + "required": ["tags"], + }, + "NestedRefArrayExample", + ) + + parsed = model.model_validate({"tags": 
[{"label": "alpha"}]}) + assert parsed.model_dump()["tags"][0]["label"] == "alpha" + + with pytest.raises(Exception): + model.model_validate({"tags": [{"label": "alpha", "extra": True}]}) diff --git a/test/cli/test_serve.py b/test/cli/test_serve.py index 515cc82f2..dbe0fe751 100644 --- a/test/cli/test_serve.py +++ b/test/cli/test_serve.py @@ -1,19 +1,28 @@ """Tests for the m serve OpenAI-compatible API server.""" +import json from unittest.mock import Mock import pytest from fastapi import FastAPI from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse, StreamingResponse from fastapi.testclient import TestClient +from pydantic import BaseModel, ValidationError -from cli.serve.app import make_chat_endpoint +from cli.serve.app import make_chat_endpoint, validation_exception_handler from cli.serve.models import ( ChatCompletion, ChatCompletionRequest, ChatMessage, CompletionUsage, + FunctionDefinition, + FunctionParameters, + JsonSchemaFormat, + ResponseFormat, + ToolFunction, ) +from mellea.backends.model_options import ModelOption from mellea.core.base import ModelOutputThunk @@ -125,8 +134,6 @@ async def test_system_fingerprint_always_none(self, mock_module, sample_request) @pytest.mark.asyncio async def test_model_options_passed_correctly(self, mock_module, sample_request): """Test that model options are passed to serve function correctly.""" - from mellea.backends.model_options import ModelOption - mock_output = ModelOutputThunk("Test response") mock_module.serve.return_value = mock_output @@ -233,9 +240,6 @@ async def test_all_fields_together(self, mock_module, sample_request): @pytest.mark.asyncio async def test_n_greater_than_1_rejected(self, mock_module): """Test that requests with n > 1 are rejected with appropriate error.""" - import json - - from fastapi.responses import JSONResponse request = ChatCompletionRequest( model="test-model", @@ -293,7 +297,6 @@ async def test_n_less_than_1_rejected_by_pydantic(self, mock_module): so n=0 or negative values will be caught by the framework, not our code. This test documents that behavior. """ - from pydantic import ValidationError # Pydantic validation happens before the endpoint is called with pytest.raises(ValidationError) as exc_info: @@ -320,8 +323,6 @@ def test_n_zero_rejected_at_http_level(self, mock_module): and converted to OpenAI-compatible 400 errors (not FastAPI's default 422). 
""" # Setup a test app with the exception handler - from cli.serve.app import validation_exception_handler - app = FastAPI() app.add_exception_handler(RequestValidationError, validation_exception_handler) app.add_api_route( @@ -439,8 +440,6 @@ async def test_unsupported_params_excluded_from_model_options(self, mock_module) model_options = call_args.kwargs["model_options"] # Supported parameters should be present - from mellea.backends.model_options import ModelOption - assert ModelOption.TEMPERATURE in model_options assert model_options[ModelOption.TEMPERATURE] == 0.7 assert ModelOption.MAX_NEW_TOKENS in model_options @@ -457,12 +456,6 @@ async def test_unsupported_params_excluded_from_model_options(self, mock_module) @pytest.mark.asyncio async def test_tool_params_excluded_from_model_options(self, mock_module): """Test that tool-related parameters are excluded from model_options.""" - from cli.serve.models import ( - FunctionDefinition, - FunctionParameters, - ToolFunction, - ) - request = ChatCompletionRequest( model="test-model", messages=[ChatMessage(role="user", content="Hello")], @@ -511,7 +504,6 @@ async def test_tool_params_excluded_from_model_options(self, mock_module): @pytest.mark.asyncio async def test_response_format_excluded_from_model_options(self, mock_module): """Test that response_format parameter is excluded from model_options.""" - from cli.serve.models import ResponseFormat request = ChatCompletionRequest( model="test-model", @@ -535,3 +527,426 @@ async def test_response_format_excluded_from_model_options(self, mock_module): # response_format should NOT be in model_options assert "response_format" not in model_options + + +class TestResponseFormat: + """Tests for response_format parameter handling.""" + + @pytest.mark.asyncio + async def test_json_schema_format_passed_to_serve(self): + """Test that json_schema response_format is converted to Pydantic model and passed to serve.""" + + # Create a mock module with serve that accepts format parameter + mock_module = Mock() + mock_module.__name__ = "test_module" + + # Track calls manually + captured_format = None + + def mock_serve(input, requirements=None, model_options=None, format=None): + nonlocal captured_format + captured_format = format + return ModelOutputThunk('{"name": "Alice", "age": 30}') + + # Assign the real function so signature inspection works + mock_module.serve = mock_serve + + # Create a request with json_schema response_format + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Generate a person")], + response_format=ResponseFormat( + type="json_schema", + json_schema=JsonSchemaFormat( + name="Person", + schema={ + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name", "age"], + }, + ), + ), + ) + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Verify format was passed + assert captured_format is not None + assert issubclass(captured_format, BaseModel) + assert "name" in captured_format.model_fields + assert "age" in captured_format.model_fields + + # Verify response is successful + assert isinstance(response, ChatCompletion) + assert response.choices[0].message.content == '{"name": "Alice", "age": 30}' + + @pytest.mark.asyncio + async def test_json_object_format_no_schema(self, mock_module): + """Test that json_object response_format doesn't pass a format model.""" + + request = ChatCompletionRequest( + model="test-model", + 
messages=[ChatMessage(role="user", content="Generate JSON")], + response_format=ResponseFormat(type="json_object"), + ) + + mock_output = ModelOutputThunk('{"result": "success"}') + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Verify serve was called + call_args = mock_module.serve.call_args + assert call_args is not None + + # For json_object, format should be None (no specific schema) + if "format" in call_args.kwargs: + assert call_args.kwargs["format"] is None + + # Verify response is successful + assert isinstance(response, ChatCompletion) + + @pytest.mark.asyncio + async def test_text_format_no_schema(self, mock_module): + """Test that text response_format doesn't pass a format model.""" + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Hello")], + response_format=ResponseFormat(type="text"), + ) + + mock_output = ModelOutputThunk("Hello there!") + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Verify serve was called + call_args = mock_module.serve.call_args + assert call_args is not None + + # For text, format should be None + if "format" in call_args.kwargs: + assert call_args.kwargs["format"] is None + + # Verify response is successful + assert isinstance(response, ChatCompletion) + + @pytest.mark.asyncio + async def test_json_schema_missing_schema_field(self, mock_module): + """Test that json_schema without schema field returns error.""" + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Generate")], + response_format=ResponseFormat( + type="json_schema", + json_schema=None, # Missing schema + ), + ) + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Should return error + assert isinstance(response, JSONResponse) + assert response.status_code == 400 + + body_bytes = response.body + if isinstance(body_bytes, memoryview): + body_bytes = bytes(body_bytes) + error_data = json.loads(body_bytes.decode("utf-8")) + assert "error" in error_data + assert error_data["error"]["type"] == "invalid_request_error" + assert "json_schema" in error_data["error"]["message"].lower() + + @pytest.mark.asyncio + async def test_json_schema_invalid_schema(self, mock_module): + """Test that invalid JSON schema returns error.""" + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Generate")], + response_format=ResponseFormat( + type="json_schema", + json_schema=JsonSchemaFormat( + name="Invalid", + schema={ + "type": "array", # Not supported (only object) + "items": {"type": "string"}, + }, + ), + ), + ) + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Should return error + assert isinstance(response, JSONResponse) + assert response.status_code == 400 + + body_bytes = response.body + if isinstance(body_bytes, memoryview): + body_bytes = bytes(body_bytes) + error_data = json.loads(body_bytes.decode("utf-8")) + assert "error" in error_data + assert error_data["error"]["type"] == "invalid_request_error" + assert "schema" in error_data["error"]["message"].lower() + + @pytest.mark.asyncio + async def test_serve_without_format_parameter(self, mock_module): + """Test that serve functions without format parameter still work.""" + + # Create a serve function that doesn't accept format + def 
serve_no_format(input, requirements=None, model_options=None): + return ModelOutputThunk("Response without format") + + mock_module.serve = serve_no_format + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Hello")], + response_format=ResponseFormat( + type="json_schema", + json_schema=JsonSchemaFormat( + name="Test", + schema={ + "type": "object", + "properties": {"result": {"type": "string"}}, + "required": ["result"], + }, + ), + ), + ) + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Should succeed even though serve doesn't accept format + assert isinstance(response, ChatCompletion) + assert response.choices[0].message.content == "Response without format" + + @pytest.mark.asyncio + async def test_json_schema_with_optional_fields(self): + """Test that JSON schema with optional fields is handled correctly.""" + + # Create a mock module with serve that accepts format parameter + mock_module = Mock() + mock_module.__name__ = "test_module" + + # Track calls manually + captured_format = None + + def mock_serve(input, requirements=None, model_options=None, format=None): + nonlocal captured_format + captured_format = format + return ModelOutputThunk('{"name": "Widget", "price": 9.99}') + + # Assign the real function so signature inspection works + mock_module.serve = mock_serve + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Generate")], + response_format=ResponseFormat( + type="json_schema", + json_schema=JsonSchemaFormat( + name="Product", + schema={ + "type": "object", + "properties": { + "name": {"type": "string"}, + "price": {"type": "number"}, + "description": {"type": "string"}, + }, + "required": ["name", "price"], # description is optional + }, + ), + ), + ) + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Verify format model was created correctly + assert captured_format is not None + assert issubclass(captured_format, BaseModel) + assert "name" in captured_format.model_fields + assert "price" in captured_format.model_fields + assert "description" in captured_format.model_fields + + # Verify response is successful + assert isinstance(response, ChatCompletion) + + @pytest.mark.asyncio + async def test_json_schema_rejects_non_local_ref(self, mock_module): + """Test that non-local refs still return a request error.""" + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Generate")], + response_format=ResponseFormat( + type="json_schema", + json_schema=JsonSchemaFormat( + name="RemoteRefExample", + schema={ + "type": "object", + "properties": { + "value": {"$ref": "https://example.com/schemas/value.json"} + }, + "required": ["value"], + }, + ), + ), + ) + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + assert isinstance(response, JSONResponse) + assert response.status_code == 400 + + body_bytes = response.body + if isinstance(body_bytes, memoryview): + body_bytes = bytes(body_bytes) + error_data = json.loads(body_bytes.decode("utf-8")) + assert error_data["error"]["type"] == "invalid_request_error" + assert "local" in error_data["error"]["message"].lower() + assert "$ref" in error_data["error"]["message"].lower() + + +class TestResponseFormatStreaming: + """Tests for response_format parameter with streaming enabled.""" + + @pytest.mark.asyncio + async def test_json_schema_format_with_streaming(self): + """Test 
that json_schema response_format works with stream=True.""" + + # Create a mock module with serve that accepts format parameter + mock_module = Mock() + mock_module.__name__ = "test_module" + + # Create a mock output that supports streaming + mock_output = ModelOutputThunk('{"name": "Alice", "age": 30}') + mock_output._computed = True # Mark as pre-computed + + def mock_serve(input, requirements=None, model_options=None, format=None): + return mock_output + + mock_module.serve = mock_serve + + # Create a request with json_schema response_format and streaming + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Generate a person")], + stream=True, + response_format=ResponseFormat( + type="json_schema", + json_schema=JsonSchemaFormat( + name="Person", + schema={ + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name", "age"], + }, + ), + ), + ) + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + # Verify it's a streaming response + assert isinstance(response, StreamingResponse) + + # Consume the stream and verify chunks + chunks = [] + async for chunk in response.body_iterator: + chunks.append(chunk) + + # Should have multiple chunks including initial, content, final, and [DONE] + assert len(chunks) > 0 + + # Verify no error chunks (all should start with "data: ") + for chunk in chunks: + chunk_str = chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk + assert chunk_str.startswith("data: ") + + @pytest.mark.asyncio + async def test_json_object_format_with_streaming(self): + """Test that json_object response_format works with stream=True.""" + + mock_module = Mock() + mock_module.__name__ = "test_module" + + # Valid JSON output + mock_output = ModelOutputThunk('{"result": "success"}') + mock_output._computed = True + mock_module.serve.return_value = mock_output + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Generate JSON")], + stream=True, + response_format=ResponseFormat(type="json_object"), + ) + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + assert isinstance(response, StreamingResponse) + + # Consume the stream + chunks = [] + async for chunk in response.body_iterator: + chunks.append(chunk) + + # Should complete successfully without errors + assert len(chunks) > 0 + # Verify no error chunks + for chunk in chunks: + chunk_str = chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk + assert "error" not in chunk_str.lower() or chunk_str.startswith( + "data: [DONE]" + ) + + @pytest.mark.asyncio + async def test_text_format_with_streaming(self): + """Test that text response_format works with stream=True.""" + + mock_module = Mock() + mock_module.__name__ = "test_module" + + mock_output = ModelOutputThunk("Plain text response") + mock_output._computed = True + mock_module.serve.return_value = mock_output + + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Hello")], + stream=True, + response_format=ResponseFormat(type="text"), + ) + + endpoint = make_chat_endpoint(mock_module) + response = await endpoint(request) + + assert isinstance(response, StreamingResponse) + + # Consume the stream + chunks = [] + async for chunk in response.body_iterator: + chunks.append(chunk) + + # Should complete successfully + assert len(chunks) > 0 diff --git a/test/cli/test_serve_utils.py 
diff --git a/test/cli/test_serve_utils.py b/test/cli/test_serve_utils.py
new file mode 100644
index 000000000..e0c44f89f
--- /dev/null
+++ b/test/cli/test_serve_utils.py
@@ -0,0 +1,185 @@
+"""Unit tests for finish_reason extraction in cli/serve/utils.py."""
+
+from unittest.mock import Mock
+
+import pytest
+
+from cli.serve.utils import extract_finish_reason
+from mellea.core.base import ModelOutputThunk
+
+
+class TestExtractFinishReason:
+    """Tests for extract_finish_reason function."""
+
+    def test_default_finish_reason_when_no_meta(self):
+        """Test that 'stop' is returned when output has no _meta attribute."""
+        output = ModelOutputThunk("test response")
+        # Don't set _meta attribute
+        assert extract_finish_reason(output) == "stop"
+
+    def test_default_finish_reason_when_meta_is_none(self):
+        """Test that 'stop' is returned when _meta is None."""
+        output = ModelOutputThunk("test response")
+        output._meta = None
+        assert extract_finish_reason(output) == "stop"
+
+    def test_default_finish_reason_when_meta_is_empty(self):
+        """Test that 'stop' is returned when _meta is an empty dict."""
+        output = ModelOutputThunk("test response")
+        output._meta = {}
+        assert extract_finish_reason(output) == "stop"
+
+    def test_ollama_done_reason_stop(self):
+        """Test extraction of 'stop' from Ollama chat_response.done_reason."""
+        output = ModelOutputThunk("test response")
+        chat_response = Mock()
+        chat_response.done_reason = "stop"
+        output._meta = {"chat_response": chat_response}
+        assert extract_finish_reason(output) == "stop"
+
+    def test_ollama_done_reason_length(self):
+        """Test extraction of 'length' from Ollama chat_response.done_reason."""
+        output = ModelOutputThunk("test response")
+        chat_response = Mock()
+        chat_response.done_reason = "length"
+        output._meta = {"chat_response": chat_response}
+        assert extract_finish_reason(output) == "length"
+
+    def test_ollama_done_reason_none(self):
+        """Test that default 'stop' is returned when done_reason is None."""
+        output = ModelOutputThunk("test response")
+        chat_response = Mock()
+        chat_response.done_reason = None
+        output._meta = {"chat_response": chat_response}
+        assert extract_finish_reason(output) == "stop"
+
+    def test_ollama_chat_response_without_done_reason(self):
+        """Test that default 'stop' is returned when chat_response lacks done_reason."""
+        output = ModelOutputThunk("test response")
+        chat_response = Mock(spec=[])  # Mock without done_reason attribute
+        output._meta = {"chat_response": chat_response}
+        assert extract_finish_reason(output) == "stop"
+
+    def test_openai_finish_reason_stop(self):
+        """Test extraction of 'stop' from OpenAI oai_chat_response."""
+        output = ModelOutputThunk("test response")
+        output._meta = {
+            "oai_chat_response": {"choices": [{"finish_reason": "stop", "index": 0}]}
+        }
+        assert extract_finish_reason(output) == "stop"
+
+    def test_openai_finish_reason_length(self):
+        """Test extraction of 'length' from OpenAI oai_chat_response."""
+        output = ModelOutputThunk("test response")
+        output._meta = {
+            "oai_chat_response": {"choices": [{"finish_reason": "length", "index": 0}]}
+        }
+        assert extract_finish_reason(output) == "length"
+
+    def test_openai_finish_reason_content_filter(self):
+        """Test extraction of 'content_filter' from OpenAI oai_chat_response."""
+        output = ModelOutputThunk("test response")
+        output._meta = {
+            "oai_chat_response": {
+                "choices": [{"finish_reason": "content_filter", "index": 0}]
+            }
+        }
+        assert extract_finish_reason(output) == "content_filter"
+
+    def test_openai_finish_reason_tool_calls(self):
+        """Test extraction of 'tool_calls' from OpenAI oai_chat_response."""
+        output = ModelOutputThunk("test response")
+        output._meta = {
+            "oai_chat_response": {
+                "choices": [{"finish_reason": "tool_calls", "index": 0}]
+            }
+        }
+        assert extract_finish_reason(output) == "tool_calls"
+
+    def test_openai_finish_reason_function_call(self):
+        """Test extraction of 'function_call' from OpenAI oai_chat_response."""
+        output = ModelOutputThunk("test response")
+        output._meta = {
+            "oai_chat_response": {
+                "choices": [{"finish_reason": "function_call", "index": 0}]
+            }
+        }
+        assert extract_finish_reason(output) == "function_call"
+
+    def test_openai_empty_choices_array(self):
+        """Test that default 'stop' is returned when the choices array is empty."""
+        output = ModelOutputThunk("test response")
+        output._meta = {"oai_chat_response": {"choices": []}}
+        assert extract_finish_reason(output) == "stop"
+
+    def test_openai_missing_choices_key(self):
+        """Test that default 'stop' is returned when the choices key is missing."""
+        output = ModelOutputThunk("test response")
+        output._meta = {"oai_chat_response": {}}
+        assert extract_finish_reason(output) == "stop"
+
+    def test_openai_finish_reason_none(self):
+        """Test that default 'stop' is returned when finish_reason is None."""
+        output = ModelOutputThunk("test response")
+        output._meta = {
+            "oai_chat_response": {"choices": [{"finish_reason": None, "index": 0}]}
+        }
+        assert extract_finish_reason(output) == "stop"
+
+    def test_openai_non_dict_response(self):
+        """Test that default 'stop' is returned when oai_chat_response is not a dict."""
+        output = ModelOutputThunk("test response")
+        output._meta = {"oai_chat_response": "not a dict"}
+        assert extract_finish_reason(output) == "stop"
+
+    def test_ollama_takes_precedence_over_openai(self):
+        """Test that Ollama done_reason is checked before OpenAI finish_reason."""
+        output = ModelOutputThunk("test response")
+        chat_response = Mock()
+        chat_response.done_reason = "length"
+        output._meta = {
+            "chat_response": chat_response,
+            "oai_chat_response": {"choices": [{"finish_reason": "stop", "index": 0}]},
+        }
+        # Should return Ollama's done_reason, not OpenAI's finish_reason
+        assert extract_finish_reason(output) == "length"
+
+    def test_openai_used_when_ollama_missing(self):
+        """Test that OpenAI finish_reason is used when Ollama data is missing."""
+        output = ModelOutputThunk("test response")
+        output._meta = {
+            "oai_chat_response": {
+                "choices": [{"finish_reason": "content_filter", "index": 0}]
+            }
+        }
+        assert extract_finish_reason(output) == "content_filter"
+
+    def test_multiple_choices_uses_first(self):
+        """Test that the first choice is used when multiple choices exist."""
+        output = ModelOutputThunk("test response")
+        output._meta = {
+            "oai_chat_response": {
+                "choices": [
+                    {"finish_reason": "stop", "index": 0},
+                    {"finish_reason": "length", "index": 1},
+                ]
+            }
+        }
+        assert extract_finish_reason(output) == "stop"
+
+    def test_other_meta_keys_ignored(self):
+        """Test that unrelated _meta keys don't interfere."""
+        output = ModelOutputThunk("test response")
+        output._meta = {
+            "model": "gpt-4",
+            "provider": "openai",
+            "usage": {"total_tokens": 100},
+            "random_key": "random_value",
+        }
+        assert extract_finish_reason(output) == "stop"
+
+    def test_output_without_meta_attribute(self):
+        """Test handling of output objects that don't have a _meta attribute at all."""
+        # Create a simple object without _meta
+        output = Mock(spec=[])
+        assert extract_finish_reason(output) == "stop"
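The new test file pins down the fallback behavior of extract_finish_reason, but cli/serve/utils.py itself does not appear in this diff. For reference, a minimal implementation consistent with the tests above might look like the sketch below; it is inferred from the test expectations, not copied from the PR:

    # Sketch of cli/serve/utils.py inferred from the tests above; the real
    # implementation in this PR may differ.
    from typing import Any


    def extract_finish_reason(output: Any) -> str:
        """Map backend metadata on an output thunk to an OpenAI finish_reason.

        Falls back to "stop" whenever no usable metadata is present.
        """
        meta = getattr(output, "_meta", None)
        if not meta:
            return "stop"

        # Ollama backend: chat_response.done_reason takes precedence.
        chat_response = meta.get("chat_response")
        done_reason = getattr(chat_response, "done_reason", None)
        if done_reason is not None:
            return done_reason

        # OpenAI-compatible backend: use the first choice's finish_reason.
        oai_response = meta.get("oai_chat_response")
        if isinstance(oai_response, dict):
            choices = oai_response.get("choices") or []
            if choices and choices[0].get("finish_reason") is not None:
                return choices[0]["finish_reason"]

        return "stop"

Every branch above corresponds to one of the test cases: missing or empty _meta, a chat_response without done_reason, a non-dict or choice-less oai_chat_response, and a None finish_reason all fall through to the "stop" default.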