diff --git a/python/PACKAGE_STATUS.md b/python/PACKAGE_STATUS.md index 1f336f1cd8..7e9270392e 100644 --- a/python/PACKAGE_STATUS.md +++ b/python/PACKAGE_STATUS.md @@ -34,6 +34,7 @@ Status is grouped into these buckets: | `agent-framework-foundry-local` | `python/packages/foundry_local` | `beta` | | `agent-framework-gemini` | `python/packages/gemini` | `alpha` | | `agent-framework-github-copilot` | `python/packages/github_copilot` | `beta` | +| `agent-framework-hosting-discord` | `python/packages/hosting-discord` | `alpha` | | `agent-framework-hyperlight` | `python/packages/hyperlight` | `beta` | | `agent-framework-lab` | `python/packages/lab` | `beta` | | `agent-framework-mem0` | `python/packages/mem0` | `beta` | diff --git a/python/packages/hosting-discord/LICENSE b/python/packages/hosting-discord/LICENSE new file mode 100644 index 0000000000..331750f625 --- /dev/null +++ b/python/packages/hosting-discord/LICENSE @@ -0,0 +1,22 @@ + MIT License + + Copyright (c) Microsoft Corporation. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE + diff --git a/python/packages/hosting-discord/README.md b/python/packages/hosting-discord/README.md new file mode 100644 index 0000000000..d0e1e82223 --- /dev/null +++ b/python/packages/hosting-discord/README.md @@ -0,0 +1,78 @@ +# agent-framework-hosting-discord + +Discord HTTP Interactions channel for [agent-framework-hosting](../hosting). +The channel exposes a signed Starlette route for Discord slash commands, maps a +configurable slash command to the hosted agent, maps `ChannelCommand` instances +to native Discord commands, and supports push to Discord channel ids. + +## Usage + +```python +from agent_framework_hosting import AgentFrameworkHost +from agent_framework_hosting_discord import DiscordChannel + +host = AgentFrameworkHost( + target=my_agent, + channels=[ + DiscordChannel( + application_id="", + public_key="", + bot_token="", + guild_id="", + ) + ], +) +host.serve() +``` + +Configure the Discord Developer Portal interaction endpoint as: + +```text +https:///discord/interactions +``` + +The channel verifies Discord's `X-Signature-Ed25519` header against the raw +request body before parsing JSON. `skip_signature_verification=True` exists only +for local tests and should not be used on a public endpoint. + +## Slash commands + +By default, `/ask prompt:` invokes the hosted agent. Additional +`ChannelCommand` instances are registered as Discord slash commands with an +optional `input` string option: + +```python +from agent_framework_hosting import ChannelCommand + +async def reset(ctx): + await ctx.reply("Reset acknowledged") + +DiscordChannel( + application_id="...", + public_key="...", + bot_token="...", + commands=[ChannelCommand("reset", "Reset the conversation", reset)], +) +``` + +When `guild_id` is set, commands are registered only for that guild and usually +appear quickly. Global command registration can take much longer to propagate. +If `register_commands=True` but `bot_token` is omitted, the channel logs a +warning and assumes commands were registered outside the host. + +## Identity, sessions, and push + +The default isolation key is `discord:::`, +which keeps each user private inside a Discord channel or thread. Pass +`isolation_key_factory=` to use a different scope. + +`ChannelIdentity.native_id` is the Discord user id. Push requires +`identity.attributes["channel_id"]`; the first slice intentionally does not +create DM channels as a fallback. + +## Streaming + +Set `streaming=True` to consume the host stream and edit the original Discord +interaction response as text accumulates. Edits are debounced with +`edit_interval` to avoid excessive Discord REST calls. + diff --git a/python/packages/hosting-discord/agent_framework_hosting_discord/__init__.py b/python/packages/hosting-discord/agent_framework_hosting_discord/__init__.py new file mode 100644 index 0000000000..ed048b20f3 --- /dev/null +++ b/python/packages/hosting-discord/agent_framework_hosting_discord/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Discord channel for ``agent-framework-hosting``.""" + +import importlib.metadata + +from ._channel import DiscordChannel, DiscordIsolationKeyFactory, discord_isolation_key + +try: + __version__ = importlib.metadata.version(__name__) +except importlib.metadata.PackageNotFoundError: + __version__ = "0.0.0" + +__all__ = [ + "DiscordChannel", + "DiscordIsolationKeyFactory", + "__version__", + "discord_isolation_key", +] diff --git a/python/packages/hosting-discord/agent_framework_hosting_discord/_channel.py b/python/packages/hosting-discord/agent_framework_hosting_discord/_channel.py new file mode 100644 index 0000000000..c7125ff91f --- /dev/null +++ b/python/packages/hosting-discord/agent_framework_hosting_discord/_channel.py @@ -0,0 +1,657 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Discord HTTP Interactions channel.""" + +from __future__ import annotations + +import asyncio +import json +import logging +import re +import time +from collections.abc import Awaitable, Callable, Coroutine, Mapping, Sequence +from typing import Any, cast + +import httpx +from agent_framework import AgentResponse, AgentResponseUpdate, Content, Message, ResponseStream +from agent_framework_hosting import ( + ChannelCommand, + ChannelCommandContext, + ChannelContext, + ChannelContribution, + ChannelIdentity, + ChannelRequest, + ChannelResponseHook, + ChannelRunHook, + ChannelSession, + ChannelStreamTransformHook, + HostedRunResult, + apply_channel_response_hook, + apply_run_hook, +) +from nacl.exceptions import BadSignatureError +from nacl.signing import VerifyKey +from starlette.requests import Request +from starlette.responses import JSONResponse, Response +from starlette.routing import Route + +logger = logging.getLogger("agent_framework.hosting.discord") + +DiscordInteraction = Mapping[str, Any] +DiscordIsolationKeyFactory = Callable[[DiscordInteraction], str] + +_DISCORD_API_BASE = "https://discord.com/api/v10" +_DISCORD_MAX_BODY_BYTES = 1024 * 1024 +_DISCORD_MAX_CONTENT_LEN = 2000 +_INTERACTION_PING = 1 +_INTERACTION_APPLICATION_COMMAND = 2 +_RESPONSE_PONG = 1 +_RESPONSE_DEFERRED_CHANNEL_MESSAGE_WITH_SOURCE = 5 +_OPTION_STRING = 3 +_APPLICATION_COMMAND_CHAT_INPUT = 1 +_COMMAND_NAME_RE = re.compile(r"^[a-z0-9_-]{1,32}$") + + +def discord_isolation_key(guild_id: str | None, channel_id: str, user_id: str) -> str: + """Build the default Discord isolation key. + + Args: + guild_id: Discord guild id, or ``None`` for a DM interaction. + channel_id: Discord channel or thread id. + user_id: Discord user id. + + Returns: + A stable host isolation key scoped to guild/channel/user. + """ + scope = guild_id or "dm" + return f"discord:{scope}:{channel_id}:{user_id}" + + +def _default_isolation_key(interaction: DiscordInteraction) -> str: + user = _user_from_interaction(interaction) + user_id = _require_string(user.get("id"), "interaction user id") + channel_id = _require_string(interaction.get("channel_id"), "interaction channel_id") + guild_id = _string_or_none(interaction.get("guild_id")) + return discord_isolation_key(guild_id, channel_id, user_id) + + +def _text_result(text: str) -> HostedRunResult[AgentResponse]: + """Build a host delivery payload from text accumulated by this channel.""" + return HostedRunResult(AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text(text=text)])])) + + +class DiscordChannel: + """Discord channel backed by signed HTTP Interactions.""" + + name = "discord" + + def __init__( + self, + *, + application_id: str, + public_key: str, + bot_token: str | None = None, + guild_id: str | None = None, + path: str = "/discord", + agent_command: str = "ask", + agent_command_description: str = "Ask the agent", + agent_command_option: str = "prompt", + register_commands: bool = True, + commands: Sequence[ChannelCommand] | None = None, + run_hook: ChannelRunHook | None = None, + response_hook: ChannelResponseHook | None = None, + stream_transform_hook: ChannelStreamTransformHook | None = None, + streaming: bool = False, + isolation_key_factory: DiscordIsolationKeyFactory | None = None, + skip_signature_verification: bool = False, + edit_interval: float = 1.0, + max_body_bytes: int = _DISCORD_MAX_BODY_BYTES, + api_base_url: str = _DISCORD_API_BASE, + ) -> None: + """Configure the Discord channel. + + Keyword Args: + application_id: Discord application id. + public_key: Discord application public key as lowercase or + uppercase hex. Used to verify interaction signatures. + bot_token: Bot token used to register slash commands and push + messages to Discord channel ids. Interaction webhook replies + do not require this token. + guild_id: Optional guild id for guild-scoped slash command + registration. Recommended for development because global + command registration can take a long time to propagate. + path: Host mount path. The interaction route is contributed as + ``/interactions`` below this path. + agent_command: Slash command name that invokes the hosted agent. + agent_command_description: Description for the agent slash command. + agent_command_option: String option name that carries the prompt. + register_commands: Whether startup should register slash commands + through Discord REST when ``bot_token`` is configured. + commands: Additional host ``ChannelCommand`` instances to expose + as Discord slash commands. + run_hook: Optional hook that can rewrite the channel request before + it reaches the host. + response_hook: Optional hook that can rewrite the hosted result + before the originating Discord response is serialized. + stream_transform_hook: Optional per-update transform hook applied + while streaming. + streaming: Whether the agent command should call ``run_stream`` + and edit the original interaction response as deltas arrive. + isolation_key_factory: Optional callable that receives the raw + Discord interaction and returns a host isolation key. + skip_signature_verification: Disable Ed25519 verification. Use + only for local tests; never expose publicly with this enabled. + edit_interval: Minimum seconds between streaming edits to the + original Discord interaction response. + max_body_bytes: Maximum raw interaction request body size. + api_base_url: Discord API base URL. Primarily useful for tests. + + Raises: + ValueError: If public key hex or command names are invalid, or if + command names collide. + """ + self.application_id = application_id + self.public_key = public_key + self.bot_token = bot_token + self.guild_id = guild_id + self.path = path + self.agent_command = agent_command + self.agent_command_description = agent_command_description + self.agent_command_option = agent_command_option + self.register_commands = register_commands + self._commands: set[ChannelCommand] = set(commands) or {} # type: ignore + self._command_by_name = {command.name: command for command in self._commands} + self._run_hook = run_hook + self.response_hook = response_hook + self._stream_transform_hook = stream_transform_hook + self._streaming = streaming + self._isolation_key_factory = isolation_key_factory or _default_isolation_key + self._skip_signature_verification = skip_signature_verification + self._edit_interval = edit_interval + self._max_body_bytes = max_body_bytes + self._api_base_url = api_base_url.rstrip("/") + self._ctx: ChannelContext | None = None + self._http: httpx.AsyncClient | None = None + self._tasks: set[asyncio.Task[None]] = set() + + self._validate_configuration() + try: + self._verify_key = VerifyKey(bytes.fromhex(public_key)) + except ValueError as exc: + raise ValueError("DiscordChannel public_key must be a valid Ed25519 public key hex string") from exc + + def contribute(self, context: ChannelContext) -> ChannelContribution: + """Register the Discord interaction route and lifecycle hooks.""" + self._ctx = context + return ChannelContribution( + routes=[Route("/interactions", self._handle, methods=["POST"])], + commands=self._commands, + on_startup=[self._on_startup], + on_shutdown=[self._on_shutdown], + ) + + async def push(self, identity: ChannelIdentity, payload: HostedRunResult[Any]) -> None: + """Push a hosted result to a Discord channel. + + Args: + identity: Destination identity. ``identity.attributes`` must carry + ``channel_id``. + payload: Hosted run result to render as Discord message text. + + Raises: + RuntimeError: If the channel has no bot token for Discord REST. + ValueError: If ``channel_id`` is missing from the identity. + """ + channel_id = _string_or_none(identity.attributes.get("channel_id")) + if channel_id is None: + raise ValueError("Discord push requires identity.attributes['channel_id']") + if self.bot_token is None: + raise RuntimeError("DiscordChannel.push requires bot_token to send channel messages") + await self._send_channel_messages(channel_id, _payload_text(payload)) + + async def _on_startup(self) -> None: + """Open the Discord REST client and optionally register slash commands.""" + self._ensure_http() + if self._skip_signature_verification: + logger.warning( + "DiscordChannel running with skip_signature_verification=True. " + "Use only for local tests; public Discord endpoints must verify signatures." + ) + if not self.register_commands: + return + if self.bot_token is None: + logger.warning( + "DiscordChannel register_commands=True but bot_token is not configured; " + "slash commands must be registered outside the host." + ) + return + if self.guild_id is None: + logger.warning( + "DiscordChannel registering global slash commands; Discord can take a long time " + "to propagate global command changes. Set guild_id for faster development updates." + ) + try: + await self._register_commands() + except (RuntimeError, httpx.HTTPError): + logger.exception("DiscordChannel slash command registration failed; continuing startup") + + async def _on_shutdown(self) -> None: + """Drain in-flight interaction tasks and close the Discord REST client.""" + if self._tasks: + await asyncio.gather(*self._tasks, return_exceptions=True) + if self._http is not None: + await self._http.aclose() + self._http = None + + async def _handle(self, request: Request) -> Response: + """Handle one Discord interaction webhook request.""" + raw_body = await request.body() + if len(raw_body) > self._max_body_bytes: + return JSONResponse({"error": "request body too large"}, status_code=413) + if not self._skip_signature_verification and not self._verify_signature(request, raw_body): + return JSONResponse({"error": "invalid signature"}, status_code=401) + try: + body = json.loads(raw_body.decode("utf-8")) + except json.JSONDecodeError: + return JSONResponse({"error": "invalid JSON"}, status_code=400) + if not isinstance(body, Mapping): + return JSONResponse({"error": "interaction body must be a JSON object"}, status_code=400) + interaction = cast("DiscordInteraction", body) + + interaction_type = interaction.get("type") + if interaction_type == _INTERACTION_PING: + return JSONResponse({"type": _RESPONSE_PONG}) + if interaction_type != _INTERACTION_APPLICATION_COMMAND: + return JSONResponse({"error": f"unsupported interaction type: {interaction_type!r}"}, status_code=400) + + self._schedule(self._dispatch_application_command(interaction)) + return JSONResponse({"type": _RESPONSE_DEFERRED_CHANNEL_MESSAGE_WITH_SOURCE}) + + async def _dispatch_application_command(self, interaction: DiscordInteraction) -> None: + token = _require_string(interaction.get("token"), "interaction token") + try: + name = _application_command_name(interaction) + if name == self.agent_command: + await self._run_agent_command(interaction, token) + return + command = self._command_by_name.get(name) + if command is None: + await self._edit_original(token, f"Unknown Discord command: {name}") + return + await self._run_channel_command(command, interaction, token) + except Exception: + logger.exception("DiscordChannel interaction handling failed") + await self._try_edit_original(token, "Sorry, something went wrong while handling that Discord command.") + raise + + async def _run_agent_command(self, interaction: DiscordInteraction, token: str) -> None: + if self._ctx is None: + raise RuntimeError("DiscordChannel was not contributed to a host.") + prompt = _string_option(interaction, self.agent_command_option) + if prompt is None: + await self._edit_original(token, f"Missing required `{self.agent_command_option}` option.") + return + request = self._build_request( + interaction, + operation="message.create", + input_value=prompt, + stream=self._streaming, + ) + if self._run_hook is not None: + request = await apply_run_hook( + self._run_hook, + request, + target=self._ctx.target, + protocol_request=interaction, + ) + if request.stream: + await self._run_streaming(request, token) + return + result = await self._ctx.run(request) + include_originating = await self._ctx.deliver_response(request, result) + if include_originating: + result = await apply_channel_response_hook(self, result, request=request, originating=True) + await self._edit_original_with_result(token, result) + else: + await self._edit_original(token, "Sent.") + + async def _run_channel_command( + self, + command: ChannelCommand, + interaction: DiscordInteraction, + token: str, + ) -> None: + command_input = _string_option(interaction, "input") + request = self._build_request( + interaction, + operation="command.invoke", + input_value=f"/{command.name}" if command_input is None else f"/{command.name} {command_input}", + stream=False, + ) + reply = _DiscordInteractionReply(self, token) + await command.handle(ChannelCommandContext(request=request, reply=reply)) + if not reply.sent: + await self._edit_original(token, "Done.") + + async def _run_streaming(self, request: ChannelRequest, token: str) -> None: + if self._ctx is None: + raise RuntimeError("DiscordChannel was not contributed to a host.") + stream: ResponseStream[AgentResponseUpdate, AgentResponse] = self._ctx.run_stream(request) + accumulated: list[str] = [] + last_edit = 0.0 + async for update in stream: + transformed: AgentResponseUpdate | None = update + if self._stream_transform_hook is not None: + maybe = self._stream_transform_hook(update) + if isinstance(maybe, Awaitable): + transformed = await cast("Awaitable[AgentResponseUpdate | None]", maybe) + else: + transformed = maybe + if transformed is None: + continue + chunk = _update_text(transformed) + if not chunk: + continue + accumulated.append(chunk) + now = time.monotonic() + if self._edit_interval <= 0 or now - last_edit >= self._edit_interval: + await self._edit_original(token, _stream_preview_content("".join(accumulated))) + last_edit = now + + final = _text_result("".join(accumulated)) + include_originating = await self._ctx.deliver_response(request, final) + if include_originating: + final = await apply_channel_response_hook(self, final, request=request, originating=True) + await self._edit_original_with_result(token, final) + else: + await self._edit_original(token, "Sent.") + + def _build_request( + self, + interaction: DiscordInteraction, + *, + operation: str, + input_value: Any, + stream: bool, + ) -> ChannelRequest: + identity = self._identity_from_interaction(interaction) + command_name = _application_command_name(interaction) + metadata = { + "interaction_id": _string_or_none(interaction.get("id")), + "application_id": self.application_id, + "guild_id": _string_or_none(interaction.get("guild_id")), + "channel_id": _string_or_none(interaction.get("channel_id")), + "user_id": identity.native_id, + "command": command_name, + } + clean_metadata = {key: value for key, value in metadata.items() if value is not None} + return ChannelRequest( + channel=self.name, + operation=operation, + input=input_value, + session=ChannelSession(isolation_key=self._isolation_key_factory(interaction)), + metadata=clean_metadata, + attributes=clean_metadata, + stream=stream, + identity=identity, + ) + + def _identity_from_interaction(self, interaction: DiscordInteraction) -> ChannelIdentity: + user = _user_from_interaction(interaction) + user_id = _require_string(user.get("id"), "interaction user id") + attributes = { + "username": _string_or_none(user.get("username")), + "global_name": _string_or_none(user.get("global_name")), + "guild_id": _string_or_none(interaction.get("guild_id")), + "channel_id": _string_or_none(interaction.get("channel_id")), + "application_id": self.application_id, + } + return ChannelIdentity( + channel=self.name, + native_id=user_id, + attributes={key: value for key, value in attributes.items() if value is not None}, + ) + + def _verify_signature(self, request: Request, raw_body: bytes) -> bool: + signature = request.headers.get("x-signature-ed25519") + timestamp = request.headers.get("x-signature-timestamp") + if not signature or not timestamp: + return False + try: + self._verify_key.verify(timestamp.encode("utf-8") + raw_body, bytes.fromhex(signature)) + except (BadSignatureError, ValueError): + return False + return True + + def _schedule(self, coro: Coroutine[Any, Any, None]) -> None: + task = asyncio.create_task(coro) + self._tasks.add(task) + task.add_done_callback(self._on_task_done) + + def _on_task_done(self, task: asyncio.Task[None]) -> None: + self._tasks.discard(task) + try: + task.result() + except asyncio.CancelledError: + return + except Exception: + logger.exception("DiscordChannel background task failed") + + def _ensure_http(self) -> httpx.AsyncClient: + if self._http is None: + self._http = httpx.AsyncClient(base_url=self._api_base_url, timeout=30.0) + return self._http + + async def _register_commands(self) -> None: + http = self._ensure_http() + path = f"/applications/{self.application_id}/commands" + if self.guild_id is not None: + path = f"/applications/{self.application_id}/guilds/{self.guild_id}/commands" + response = await http.put(path, headers=self._bot_headers(), json=self._command_payloads()) + _raise_for_discord_error(response, "register slash commands") + + async def _edit_original_with_result(self, token: str, payload: HostedRunResult[Any]) -> None: + chunks = _split_content(_payload_text(payload)) + await self._edit_original(token, chunks[0]) + for chunk in chunks[1:]: + await self._send_followup(token, chunk) + + async def _edit_original(self, token: str, content: str) -> None: + http = self._ensure_http() + response = await http.patch( + f"/webhooks/{self.application_id}/{token}/messages/@original", + json={"content": _normalize_content(content)}, + ) + _raise_for_discord_error(response, "edit interaction response") + + async def _try_edit_original(self, token: str, content: str) -> None: + try: + await self._edit_original(token, content) + except (RuntimeError, httpx.HTTPError): + logger.exception("DiscordChannel failed to edit interaction error response") + + async def _send_followup(self, token: str, content: str) -> None: + http = self._ensure_http() + response = await http.post( + f"/webhooks/{self.application_id}/{token}", + json={"content": _normalize_content(content)}, + ) + _raise_for_discord_error(response, "send interaction follow-up") + + async def _send_channel_messages(self, channel_id: str, content: str) -> None: + http = self._ensure_http() + for chunk in _split_content(content): + response = await http.post( + f"/channels/{channel_id}/messages", + headers=self._bot_headers(), + json={"content": chunk}, + ) + _raise_for_discord_error(response, "send channel message") + + def _bot_headers(self) -> dict[str, str]: + if self.bot_token is None: + raise RuntimeError("Discord bot token is required for this operation") + return {"Authorization": f"Bot {self.bot_token}"} + + def _command_payloads(self) -> list[dict[str, Any]]: + payloads = [ + { + "type": _APPLICATION_COMMAND_CHAT_INPUT, + "name": self.agent_command, + "description": self.agent_command_description, + "options": [ + { + "type": _OPTION_STRING, + "name": self.agent_command_option, + "description": "Prompt for the agent.", + "required": True, + } + ], + } + ] + for command in self._commands: + payloads.append({ + "type": _APPLICATION_COMMAND_CHAT_INPUT, + "name": command.name, + "description": command.description, + "options": [ + { + "type": _OPTION_STRING, + "name": "input", + "description": "Optional command input.", + "required": False, + } + ], + }) + return payloads + + def _validate_configuration(self) -> None: + names = [self.agent_command, *(command.name for command in self._commands)] + for name in names: + if not _COMMAND_NAME_RE.fullmatch(name): + raise ValueError( + "Discord command names must be lowercase ASCII letters, numbers, hyphen, " + f"or underscore, and 1-32 characters long: {name!r}" + ) + if not _COMMAND_NAME_RE.fullmatch(self.agent_command_option): + raise ValueError( + "Discord agent_command_option must be lowercase ASCII letters, numbers, hyphen, " + f"or underscore, and 1-32 characters long: {self.agent_command_option!r}" + ) + if len(set(names)) != len(names): + raise ValueError("Discord command names must be unique; agent_command cannot collide with commands") + if self._edit_interval < 0: + raise ValueError("edit_interval must be >= 0") + if self._max_body_bytes <= 0: + raise ValueError("max_body_bytes must be > 0") + + +class _DiscordInteractionReply: + """Reply helper that edits the deferred response first, then sends follow-ups.""" + + def __init__(self, channel: DiscordChannel, token: str) -> None: + self._channel = channel + self._token = token + self.sent = False + + async def __call__(self, body: str) -> None: + chunks = _split_content(body) + if not self.sent: + await self._channel._edit_original(self._token, chunks[0]) # pyright: ignore[reportPrivateUsage] + self.sent = True + for chunk in chunks[1:]: + await self._channel._send_followup(self._token, chunk) # pyright: ignore[reportPrivateUsage] + return + for chunk in chunks: + await self._channel._send_followup(self._token, chunk) # pyright: ignore[reportPrivateUsage] + + +def _user_from_interaction(interaction: DiscordInteraction) -> Mapping[str, Any]: + member = interaction.get("member") + if isinstance(member, Mapping): + member_user = member.get("user") + if isinstance(member_user, Mapping): + return member_user + user = interaction.get("user") + if isinstance(user, Mapping): + return user + raise ValueError("Discord interaction is missing user information") + + +def _application_command_name(interaction: DiscordInteraction) -> str: + data = interaction.get("data") + if not isinstance(data, Mapping): + raise ValueError("Discord application command interaction is missing data") + return _require_string(data.get("name"), "application command name") + + +def _string_option(interaction: DiscordInteraction, name: str) -> str | None: + data = interaction.get("data") + if not isinstance(data, Mapping): + return None + options = data.get("options") + if not isinstance(options, Sequence) or isinstance(options, (str, bytes)): + return None + for option in options: + if not isinstance(option, Mapping): + continue + if option.get("name") != name: + continue + value = option.get("value") + if value is None: + return None + return str(value) + return None + + +def _payload_text(payload: HostedRunResult[Any]) -> str: + text = getattr(payload.result, "text", None) + if isinstance(text, str) and text: + return text + messages = getattr(payload.result, "messages", None) + if isinstance(messages, Sequence): + for message in reversed(messages): + message_text = getattr(message, "text", None) + if isinstance(message_text, str) and message_text: + return message_text + return "(no response)" + + +def _update_text(update: AgentResponseUpdate) -> str: + parts: list[str] = [] + for content in update.contents: + text = getattr(content, "text", None) + if isinstance(text, str) and text: + parts.append(text) + return "".join(parts) + + +def _split_content(content: str) -> list[str]: + normalized = _normalize_content(content) + return [normalized[i : i + _DISCORD_MAX_CONTENT_LEN] for i in range(0, len(normalized), _DISCORD_MAX_CONTENT_LEN)] + + +def _stream_preview_content(content: str) -> str: + return _split_content(content)[0] + + +def _normalize_content(content: str) -> str: + return content if content else "(no response)" + + +def _string_or_none(value: Any) -> str | None: + return value if isinstance(value, str) and value else None + + +def _require_string(value: Any, field_name: str) -> str: + if isinstance(value, str) and value: + return value + raise ValueError(f"Discord {field_name} must be a non-empty string") + + +def _raise_for_discord_error(response: httpx.Response, action: str) -> None: + try: + response.raise_for_status() + except httpx.HTTPStatusError as exc: + body = response.text[:500] + raise RuntimeError(f"Discord {action} failed with HTTP {response.status_code}: {body}") from exc diff --git a/python/packages/hosting-discord/agent_framework_hosting_discord/py.typed b/python/packages/hosting-discord/agent_framework_hosting_discord/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/packages/hosting-discord/pyproject.toml b/python/packages/hosting-discord/pyproject.toml new file mode 100644 index 0000000000..23948b660f --- /dev/null +++ b/python/packages/hosting-discord/pyproject.toml @@ -0,0 +1,107 @@ +[project] +name = "agent-framework-hosting-discord" +description = "Discord channel for agent-framework-hosting." +authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] +readme = "README.md" +requires-python = ">=3.10" +version = "1.0.0a260526" +license-files = ["LICENSE"] +urls.homepage = "https://aka.ms/agent-framework" +urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" +urls.release_notes = "https://github.com/microsoft/agent-framework/releases?q=tag%3Apython-1&expanded=true" +urls.issues = "https://github.com/microsoft/agent-framework/issues" +classifiers = [ + "License :: OSI Approved :: MIT License", + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "Typing :: Typed", +] +dependencies = [ + "agent-framework-core>=1.2.0,<2", + "agent-framework-hosting>=1.0.0a260424,<2", + "httpx>=0.27,<1", + "PyNaCl>=1.2.0,<2", +] + +[tool.uv] +prerelease = "if-necessary-or-explicit" +environments = [ + "sys_platform == 'darwin'", + "sys_platform == 'linux'", + "sys_platform == 'win32'" +] + +[tool.uv-dynamic-versioning] +fallback-version = "0.0.0" + +[tool.pytest.ini_options] +testpaths = 'tests' +addopts = "-ra -q -r fEX" +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" +filterwarnings = [] +timeout = 120 +markers = [ + "integration: marks tests as integration tests that require external services", +] + +[tool.ruff] +extend = "../../pyproject.toml" + +[tool.coverage.run] +omit = [ + "**/__init__.py" +] + +[tool.pyright] +extends = "../../pyproject.toml" +include = ["agent_framework_hosting_discord"] +exclude = ['tests'] +# Discord interactions arrive as loosely-typed JSON maps. Runtime guards narrow +# payloads where needed; strict Unknown reporting on every `.get()` is noisy. +reportUnknownArgumentType = "none" +reportUnknownMemberType = "none" +reportUnknownVariableType = "none" +reportUnknownLambdaType = "none" +reportOptionalMemberAccess = "none" + +[tool.mypy] +plugins = ['pydantic.mypy'] +strict = true +python_version = "3.10" +ignore_missing_imports = true +disallow_untyped_defs = true +no_implicit_optional = true +check_untyped_defs = true +warn_return_any = true +show_error_codes = true +warn_unused_ignores = false +disallow_incomplete_defs = true +disallow_untyped_decorators = true + +[tool.bandit] +targets = ["agent_framework_hosting_discord"] +exclude_dirs = ["tests"] + +[tool.poe] +executor.type = "uv" +include = "../../shared_tasks.toml" + +[tool.poe.tasks.mypy] +help = "Run MyPy for this package." +cmd = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_hosting_discord" + +[tool.poe.tasks.test] +help = "Run the default unit test suite for this package." +cmd = 'pytest -m "not integration" --cov=agent_framework_hosting_discord --cov-report=term-missing:skip-covered tests' + +[build-system] +requires = ["flit-core >= 3.11,<4.0"] +build-backend = "flit_core.buildapi" + diff --git a/python/packages/hosting-discord/tests/discord/test_channel.py b/python/packages/hosting-discord/tests/discord/test_channel.py new file mode 100644 index 0000000000..868f905a44 --- /dev/null +++ b/python/packages/hosting-discord/tests/discord/test_channel.py @@ -0,0 +1,680 @@ +# Copyright (c) Microsoft. All rights reserved. + +from __future__ import annotations + +import json +from collections.abc import AsyncIterator +from typing import Any + +import httpx +import pytest +from agent_framework import AgentResponse, AgentResponseUpdate, Content, Message +from agent_framework_hosting import ( + ChannelCommand, + ChannelCommandContext, + ChannelRequest, + ChannelResponseContext, + HostedRunResult, +) +from nacl.signing import SigningKey +from starlette.applications import Starlette +from starlette.testclient import TestClient + +from agent_framework_hosting_discord import DiscordChannel, discord_isolation_key + + +def _run_result(text: str) -> HostedRunResult[AgentResponse]: + return HostedRunResult(AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text(text=text)])])) + + +def _interaction(command: str = "ask", *, prompt: str = "hello", token: str = "token") -> dict[str, Any]: + return { + "id": "interaction-1", + "type": 2, + "application_id": "app-1", + "token": token, + "guild_id": "guild-1", + "channel_id": "channel-1", + "member": { + "user": { + "id": "user-1", + "username": "ada", + "global_name": "Ada", + } + }, + "data": { + "name": command, + "options": [{"name": "prompt", "type": 3, "value": prompt}], + }, + } + + +def _headers(signing_key: SigningKey, body: bytes) -> dict[str, str]: + timestamp = "1234567890" + signature = signing_key.sign(timestamp.encode("utf-8") + body).signature.hex() + return { + "x-signature-ed25519": signature, + "x-signature-timestamp": timestamp, + "content-type": "application/json", + } + + +class _FakeContext: + def __init__(self, *, text: str = "agent reply", include_originating: bool = True) -> None: + self.target = object() + self.text = text + self.include_originating = include_originating + self.requests: list[ChannelRequest] = [] + self.delivered: list[tuple[ChannelRequest, HostedRunResult[Any]]] = [] + self.stream: _FakeStream | None = None + + async def run(self, request: ChannelRequest) -> HostedRunResult[AgentResponse]: + self.requests.append(request) + return _run_result(self.text) + + def run_stream(self, request: ChannelRequest) -> _FakeStream: + self.requests.append(request) + if self.stream is None: + self.stream = _FakeStream(["a", "b"]) + return self.stream + + async def deliver_response(self, request: ChannelRequest, payload: HostedRunResult[Any]) -> bool: + self.delivered.append((request, payload)) + return self.include_originating + + +class _FakeStream: + def __init__(self, chunks: list[str]) -> None: + self._chunks = chunks + + def __aiter__(self) -> AsyncIterator[AgentResponseUpdate]: + return self._iter() + + async def _iter(self) -> AsyncIterator[AgentResponseUpdate]: + for chunk in self._chunks: + yield AgentResponseUpdate(contents=[Content.from_text(text=chunk)], role="assistant") + + +class _DiscordRecorder: + def __init__(self) -> None: + self.requests: list[httpx.Request] = [] + self.json_payloads: list[Any] = [] + + def transport(self) -> httpx.MockTransport: + def handler(request: httpx.Request) -> httpx.Response: + self.requests.append(request) + if request.content: + self.json_payloads.append(json.loads(request.content.decode("utf-8"))) + return httpx.Response(200, json={"ok": True}) + + return httpx.MockTransport(handler) + + +def test_discord_isolation_key_scopes_to_guild_channel_user() -> None: + assert discord_isolation_key("guild", "channel", "user") == "discord:guild:channel:user" + assert discord_isolation_key(None, "dm-channel", "user") == "discord:dm:dm-channel:user" + + +def test_ping_requires_valid_signature_and_returns_pong() -> None: + signing_key = SigningKey.generate() + channel = DiscordChannel( + application_id="app-1", + public_key=signing_key.verify_key.encode().hex(), + register_commands=False, + ) + app = Starlette(routes=list(channel.contribute(_FakeContext()).routes)) # type: ignore[arg-type] + body = json.dumps({"type": 1}).encode("utf-8") + + with TestClient(app) as client: + ok = client.post("/interactions", content=body, headers=_headers(signing_key, body)) + bad = client.post( + "/interactions", + content=body, + headers={ + **_headers(signing_key, body), + "x-signature-ed25519": "00" * 64, + }, + ) + + assert ok.status_code == 200 + assert ok.json() == {"type": 1} + assert bad.status_code == 401 + + +def test_request_validation_errors() -> None: + channel = DiscordChannel( + application_id="app-1", + public_key=SigningKey.generate().verify_key.encode().hex(), + register_commands=False, + skip_signature_verification=True, + max_body_bytes=2, + ) + app = Starlette(routes=list(channel.contribute(_FakeContext()).routes)) # type: ignore[arg-type] + unsupported_channel = DiscordChannel( + application_id="app-1", + public_key=SigningKey.generate().verify_key.encode().hex(), + register_commands=False, + skip_signature_verification=True, + ) + unsupported_app = Starlette(routes=list(unsupported_channel.contribute(_FakeContext()).routes)) # type: ignore[arg-type] + + with TestClient(app) as client: + too_large = client.post("/interactions", content=b"{}x") + invalid_json = client.post("/interactions", content=b"{") + with TestClient(unsupported_app) as client: + non_object = client.post("/interactions", json=[]) + unsupported = client.post("/interactions", json={"type": 99}) + + assert too_large.status_code == 413 + assert invalid_json.status_code == 400 + assert non_object.status_code == 400 + assert unsupported.status_code == 400 + + +def test_constructor_validates_discord_configuration() -> None: + public_key = SigningKey.generate().verify_key.encode().hex() + + with pytest.raises(ValueError, match="public_key"): + DiscordChannel(application_id="app-1", public_key="not-hex") + with pytest.raises(ValueError, match="command names"): + DiscordChannel(application_id="app-1", public_key=public_key, agent_command="Ask") + with pytest.raises(ValueError, match="unique"): + DiscordChannel( + application_id="app-1", + public_key=public_key, + commands=[ChannelCommand(name="ask", description="Ask again", handle=lambda _ctx: _noop())], + ) + with pytest.raises(ValueError, match="edit_interval"): + DiscordChannel(application_id="app-1", public_key=public_key, edit_interval=-1) + with pytest.raises(ValueError, match="max_body_bytes"): + DiscordChannel(application_id="app-1", public_key=public_key, max_body_bytes=0) + + +async def test_agent_command_runs_host_and_edits_original_response() -> None: + recorder = _DiscordRecorder() + context = _FakeContext(text="agent says hi") + channel = DiscordChannel( + application_id="app-1", + public_key=SigningKey.generate().verify_key.encode().hex(), + register_commands=False, + skip_signature_verification=True, + api_base_url="https://discord.test", + ) + channel.contribute(context) # type: ignore[arg-type] + channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport()) + + await channel._run_agent_command(_interaction(prompt="what now?"), "token") + + assert context.requests[0].operation == "message.create" + assert context.requests[0].input == "what now?" + assert context.requests[0].session is not None + assert context.requests[0].session.isolation_key == "discord:guild-1:channel-1:user-1" + assert context.requests[0].identity is not None + assert context.requests[0].identity.native_id == "user-1" + assert context.requests[0].identity.attributes["channel_id"] == "channel-1" + assert len(context.delivered) == 1 + assert recorder.requests[0].method == "PATCH" + assert recorder.requests[0].url.path == "/webhooks/app-1/token/messages/@original" + assert recorder.json_payloads[0] == {"content": "agent says hi"} + + +async def test_run_hook_can_rewrite_agent_request() -> None: + recorder = _DiscordRecorder() + context = _FakeContext(text="agent says hi") + + async def hook(request: ChannelRequest, **_: Any) -> ChannelRequest: + return ChannelRequest( + channel=request.channel, + operation=request.operation, + input="rewritten", + session=request.session, + metadata=request.metadata, + attributes=request.attributes, + stream=request.stream, + identity=request.identity, + response_target=request.response_target, + ) + + channel = DiscordChannel( + application_id="app-1", + public_key=SigningKey.generate().verify_key.encode().hex(), + register_commands=False, + run_hook=hook, + api_base_url="https://discord.test", + ) + channel.contribute(context) # type: ignore[arg-type] + channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport()) + + await channel._run_agent_command(_interaction(prompt="original"), "token") + + assert context.requests[0].input == "rewritten" + + +async def test_response_hook_rewrites_originating_reply() -> None: + recorder = _DiscordRecorder() + context = _FakeContext(text="original") + + async def hook(result: HostedRunResult[Any], *, context: ChannelResponseContext) -> HostedRunResult[Any]: + assert context.originating is True + assert result.result.text == "original" + return _run_result("rewritten") + + channel = DiscordChannel( + application_id="app-1", + public_key=SigningKey.generate().verify_key.encode().hex(), + register_commands=False, + response_hook=hook, + api_base_url="https://discord.test", + ) + channel.contribute(context) # type: ignore[arg-type] + channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport()) + + await channel._run_agent_command(_interaction(), "token") + + assert recorder.json_payloads[-1] == {"content": "rewritten"} + + +async def test_deliver_response_false_acknowledges_without_originating_payload() -> None: + recorder = _DiscordRecorder() + context = _FakeContext(text="fanout only", include_originating=False) + channel = DiscordChannel( + application_id="app-1", + public_key=SigningKey.generate().verify_key.encode().hex(), + register_commands=False, + api_base_url="https://discord.test", + ) + channel.contribute(context) # type: ignore[arg-type] + channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport()) + + await channel._run_agent_command(_interaction(), "token") + + assert recorder.json_payloads[-1] == {"content": "Sent."} + + +async def test_missing_prompt_edits_original_without_calling_host() -> None: + recorder = _DiscordRecorder() + context = _FakeContext(text="should not run") + channel = DiscordChannel( + application_id="app-1", + public_key=SigningKey.generate().verify_key.encode().hex(), + register_commands=False, + api_base_url="https://discord.test", + ) + channel.contribute(context) # type: ignore[arg-type] + channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport()) + interaction = _interaction() + interaction["data"]["options"] = [] + + await channel._run_agent_command(interaction, "token") + + assert context.requests == [] + assert recorder.json_payloads[-1] == {"content": "Missing required `prompt` option."} + + +async def test_dispatch_application_command_routes_agent_command() -> None: + recorder = _DiscordRecorder() + context = _FakeContext(text="dispatched") + channel = DiscordChannel( + application_id="app-1", + public_key=SigningKey.generate().verify_key.encode().hex(), + register_commands=False, + api_base_url="https://discord.test", + ) + channel.contribute(context) # type: ignore[arg-type] + channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport()) + + await channel._dispatch_application_command(_interaction(command="ask")) + + assert context.requests[0].operation == "message.create" + assert recorder.json_payloads[-1] == {"content": "dispatched"} + + +async def test_channel_command_handler_receives_context_and_replies() -> None: + recorder = _DiscordRecorder() + captured: list[ChannelCommandContext] = [] + + async def handler(ctx: ChannelCommandContext) -> None: + captured.append(ctx) + await ctx.reply("reset done") + + command = ChannelCommand(name="reset", description="Reset", handle=handler) + context = _FakeContext() + channel = DiscordChannel( + application_id="app-1", + public_key=SigningKey.generate().verify_key.encode().hex(), + register_commands=False, + commands=[command], + api_base_url="https://discord.test", + ) + channel.contribute(context) # type: ignore[arg-type] + channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport()) + interaction = _interaction(command="reset") + interaction["data"]["options"] = [{"name": "input", "type": 3, "value": "please"}] + + await channel._run_channel_command(command, interaction, "token") + + assert captured + assert captured[0].request.operation == "command.invoke" + assert captured[0].request.input == "/reset please" + assert recorder.json_payloads == [{"content": "reset done"}] + + +async def test_channel_command_reply_sends_followups_after_first_edit() -> None: + recorder = _DiscordRecorder() + + async def handler(ctx: ChannelCommandContext) -> None: + await ctx.reply("first") + await ctx.reply("second") + + command = ChannelCommand(name="reset", description="Reset", handle=handler) + context = _FakeContext() + channel = DiscordChannel( + application_id="app-1", + public_key=SigningKey.generate().verify_key.encode().hex(), + register_commands=False, + commands=[command], + api_base_url="https://discord.test", + ) + channel.contribute(context) # type: ignore[arg-type] + channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport()) + + await channel._run_channel_command(command, _interaction(command="reset"), "token") + + assert [request.method for request in recorder.requests] == ["PATCH", "POST"] + assert recorder.json_payloads == [{"content": "first"}, {"content": "second"}] + + +async def test_channel_command_reply_chunks_long_content() -> None: + recorder = _DiscordRecorder() + + async def handler(ctx: ChannelCommandContext) -> None: + await ctx.reply("a" * 2001) + + command = ChannelCommand(name="reset", description="Reset", handle=handler) + context = _FakeContext() + channel = DiscordChannel( + application_id="app-1", + public_key=SigningKey.generate().verify_key.encode().hex(), + register_commands=False, + commands=[command], + api_base_url="https://discord.test", + ) + channel.contribute(context) # type: ignore[arg-type] + channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport()) + + await channel._run_channel_command(command, _interaction(command="reset"), "token") + + assert [request.method for request in recorder.requests] == ["PATCH", "POST"] + assert [len(payload["content"]) for payload in recorder.json_payloads] == [2000, 1] + + +async def test_channel_command_edits_done_when_handler_does_not_reply() -> None: + recorder = _DiscordRecorder() + + async def handler(_ctx: ChannelCommandContext) -> None: + return None + + command = ChannelCommand(name="reset", description="Reset", handle=handler) + context = _FakeContext() + channel = DiscordChannel( + application_id="app-1", + public_key=SigningKey.generate().verify_key.encode().hex(), + register_commands=False, + commands=[command], + api_base_url="https://discord.test", + ) + channel.contribute(context) # type: ignore[arg-type] + channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport()) + + await channel._run_channel_command(command, _interaction(command="reset"), "token") + + assert recorder.json_payloads == [{"content": "Done."}] + + +async def test_unknown_command_edits_error_response() -> None: + recorder = _DiscordRecorder() + channel = DiscordChannel( + application_id="app-1", + public_key=SigningKey.generate().verify_key.encode().hex(), + register_commands=False, + api_base_url="https://discord.test", + ) + channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport()) + + await channel._dispatch_application_command(_interaction(command="missing")) + + assert recorder.json_payloads == [{"content": "Unknown Discord command: missing"}] + + +async def test_startup_bulk_registers_guild_commands() -> None: + recorder = _DiscordRecorder() + command = ChannelCommand(name="reset", description="Reset", handle=lambda _ctx: _noop()) + channel = DiscordChannel( + application_id="app-1", + public_key=SigningKey.generate().verify_key.encode().hex(), + bot_token="bot-token", + guild_id="guild-1", + commands=[command], + api_base_url="https://discord.test", + ) + channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport()) + + await channel._on_startup() + + assert recorder.requests[0].method == "PUT" + assert recorder.requests[0].url.path == "/applications/app-1/guilds/guild-1/commands" + assert recorder.requests[0].headers["authorization"] == "Bot bot-token" + assert [payload["name"] for payload in recorder.json_payloads[0]] == ["ask", "reset"] + + +async def test_global_startup_registration_warns_about_propagation(caplog: pytest.LogCaptureFixture) -> None: + recorder = _DiscordRecorder() + channel = DiscordChannel( + application_id="app-1", + public_key=SigningKey.generate().verify_key.encode().hex(), + bot_token="bot-token", + api_base_url="https://discord.test", + ) + channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport()) + + await channel._on_startup() + + assert recorder.requests[0].url.path == "/applications/app-1/commands" + assert "global slash commands" in caplog.text + + +async def test_startup_warns_when_registration_has_no_bot_token(caplog: pytest.LogCaptureFixture) -> None: + channel = DiscordChannel( + application_id="app-1", + public_key=SigningKey.generate().verify_key.encode().hex(), + ) + + await channel._on_startup() + await channel._on_shutdown() + + assert "slash commands must be registered outside the host" in caplog.text + + +async def test_originating_reply_sends_followup_chunks() -> None: + recorder = _DiscordRecorder() + context = _FakeContext(text="a" * 2001) + channel = DiscordChannel( + application_id="app-1", + public_key=SigningKey.generate().verify_key.encode().hex(), + register_commands=False, + api_base_url="https://discord.test", + ) + channel.contribute(context) # type: ignore[arg-type] + channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport()) + + await channel._run_agent_command(_interaction(), "token") + + assert [request.method for request in recorder.requests] == ["PATCH", "POST"] + assert [len(payload["content"]) for payload in recorder.json_payloads] == [2000, 1] + + +async def test_push_requires_channel_id_and_sends_chunked_messages() -> None: + recorder = _DiscordRecorder() + channel = DiscordChannel( + application_id="app-1", + public_key=SigningKey.generate().verify_key.encode().hex(), + bot_token="bot-token", + register_commands=False, + api_base_url="https://discord.test", + ) + channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport()) + + await channel.push( + identity=channel._identity_from_interaction(_interaction()), # pyright: ignore[reportPrivateUsage] + payload=_run_result("a" * 2001), + ) + + assert [request.url.path for request in recorder.requests] == [ + "/channels/channel-1/messages", + "/channels/channel-1/messages", + ] + assert [len(payload["content"]) for payload in recorder.json_payloads] == [2000, 1] + + +async def test_push_renders_no_response_for_unknown_payload_shape() -> None: + recorder = _DiscordRecorder() + channel = DiscordChannel( + application_id="app-1", + public_key=SigningKey.generate().verify_key.encode().hex(), + bot_token="bot-token", + register_commands=False, + api_base_url="https://discord.test", + ) + channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport()) + + await channel.push( + identity=channel._identity_from_interaction(_interaction()), # pyright: ignore[reportPrivateUsage] + payload=HostedRunResult(object()), + ) + + assert recorder.json_payloads == [{"content": "(no response)"}] + + +async def test_push_requires_bot_token_and_channel_id() -> None: + identity = DiscordChannel( + application_id="app-1", + public_key=SigningKey.generate().verify_key.encode().hex(), + register_commands=False, + )._identity_from_interaction(_interaction()) # pyright: ignore[reportPrivateUsage] + no_bot_token = DiscordChannel( + application_id="app-1", + public_key=SigningKey.generate().verify_key.encode().hex(), + register_commands=False, + ) + no_channel_id = DiscordChannel( + application_id="app-1", + public_key=SigningKey.generate().verify_key.encode().hex(), + bot_token="bot-token", + register_commands=False, + ) + + with pytest.raises(RuntimeError, match="bot_token"): + await no_bot_token.push(identity=identity, payload=_run_result("hello")) + with pytest.raises(ValueError, match="channel_id"): + await no_channel_id.push( + identity=type(identity)(channel=identity.channel, native_id=identity.native_id, attributes={}), + payload=_run_result("hello"), + ) + + +async def test_streaming_edits_original_and_delivers_final_response() -> None: + recorder = _DiscordRecorder() + context = _FakeContext() + channel = DiscordChannel( + application_id="app-1", + public_key=SigningKey.generate().verify_key.encode().hex(), + register_commands=False, + streaming=True, + edit_interval=0, + api_base_url="https://discord.test", + ) + channel.contribute(context) # type: ignore[arg-type] + channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport()) + + await channel._run_agent_command(_interaction(), "token") + + assert [payload["content"] for payload in recorder.json_payloads] == ["a", "ab", "ab"] + assert len(context.delivered) == 1 + assert context.delivered[0][1].result.text == "ab" + + +async def test_streaming_preview_is_limited_and_final_reply_is_chunked() -> None: + recorder = _DiscordRecorder() + context = _FakeContext() + context.stream = _FakeStream(["a" * 2001]) + channel = DiscordChannel( + application_id="app-1", + public_key=SigningKey.generate().verify_key.encode().hex(), + register_commands=False, + streaming=True, + edit_interval=0, + api_base_url="https://discord.test", + ) + channel.contribute(context) # type: ignore[arg-type] + channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport()) + + await channel._run_agent_command(_interaction(), "token") + + assert [request.method for request in recorder.requests] == ["PATCH", "PATCH", "POST"] + assert [len(payload["content"]) for payload in recorder.json_payloads] == [2000, 2000, 1] + assert len(context.delivered[0][1].result.text) == 2001 + + +async def test_stream_transform_hook_can_drop_updates_and_disable_originating_reply() -> None: + recorder = _DiscordRecorder() + context = _FakeContext(include_originating=False) + + async def hook(update: AgentResponseUpdate) -> AgentResponseUpdate | None: + if update.text == "a": + return None + return update + + channel = DiscordChannel( + application_id="app-1", + public_key=SigningKey.generate().verify_key.encode().hex(), + register_commands=False, + streaming=True, + stream_transform_hook=hook, + edit_interval=0, + api_base_url="https://discord.test", + ) + channel.contribute(context) # type: ignore[arg-type] + channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport()) + + await channel._run_agent_command(_interaction(), "token") + + assert [payload["content"] for payload in recorder.json_payloads] == ["b", "Sent."] + assert context.delivered[0][1].result.text == "b" + + +async def test_stream_transform_hook_can_synchronously_rewrite_updates() -> None: + recorder = _DiscordRecorder() + context = _FakeContext() + + def hook(_update: AgentResponseUpdate) -> AgentResponseUpdate: + return AgentResponseUpdate(contents=[Content.from_text(text="x")], role="assistant") + + channel = DiscordChannel( + application_id="app-1", + public_key=SigningKey.generate().verify_key.encode().hex(), + register_commands=False, + streaming=True, + stream_transform_hook=hook, + edit_interval=0, + api_base_url="https://discord.test", + ) + channel.contribute(context) # type: ignore[arg-type] + channel._http = httpx.AsyncClient(base_url="https://discord.test", transport=recorder.transport()) + + await channel._run_agent_command(_interaction(), "token") + + assert [payload["content"] for payload in recorder.json_payloads] == ["x", "xx", "xx"] + + +async def _noop() -> None: + return None diff --git a/python/packages/hosting/agent_framework_hosting/__init__.py b/python/packages/hosting/agent_framework_hosting/__init__.py index 72553d1aef..aa91f654b0 100644 --- a/python/packages/hosting/agent_framework_hosting/__init__.py +++ b/python/packages/hosting/agent_framework_hosting/__init__.py @@ -71,6 +71,7 @@ RetryPolicy, TaskHandle, TaskStatus, + apply_channel_response_hook, apply_response_hook, apply_run_hook, ) @@ -134,6 +135,7 @@ "TaskHandle", "TaskStatus", "__version__", + "apply_channel_response_hook", "apply_response_hook", "apply_run_hook", "get_current_isolation_keys", diff --git a/python/packages/hosting/agent_framework_hosting/_host.py b/python/packages/hosting/agent_framework_hosting/_host.py index a64d3071dd..05ea0fabb9 100644 --- a/python/packages/hosting/agent_framework_hosting/_host.py +++ b/python/packages/hosting/agent_framework_hosting/_host.py @@ -81,15 +81,13 @@ ChannelPush, ChannelPushCodec, ChannelRequest, - ChannelResponseContext, - ChannelResponseHook, DurableTaskPayloadMode, DurableTaskRunner, HostedRunResult, HostStatePaths, PushPayloadNotSerializable, ResponseTargetKind, - apply_response_hook, + apply_channel_response_hook, ) if TYPE_CHECKING: @@ -1896,17 +1894,15 @@ async def _deliver_payload_to_channel( contract; richer surfaces stay attribute-level so adding hook support to a new channel does not require updating the Protocol. """ - shaped: HostedRunResult[Any] = payload.replace() - hook = cast(ChannelResponseHook | None, getattr(channel, "response_hook", None)) - if callable(hook): - ctx = ChannelResponseContext( - request=request, - channel_name=channel.name, - destination_identity=identity, - originating=False, - is_echo=is_echo, - ) - shaped = await apply_response_hook(hook, shaped, context=ctx) + shaped = await apply_channel_response_hook( + channel, + payload, + request=request, + destination_identity=identity, + originating=False, + is_echo=is_echo, + clone=True, + ) await channel.push(identity, shaped) return shaped diff --git a/python/packages/hosting/agent_framework_hosting/_types.py b/python/packages/hosting/agent_framework_hosting/_types.py index 30ad79cdd5..64ac9258f8 100644 --- a/python/packages/hosting/agent_framework_hosting/_types.py +++ b/python/packages/hosting/agent_framework_hosting/_types.py @@ -25,7 +25,7 @@ from collections.abc import Awaitable, Callable, Mapping, Sequence from dataclasses import dataclass, field from enum import Enum -from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, TypedDict, TypeVar, runtime_checkable +from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, TypedDict, TypeVar, cast, runtime_checkable from agent_framework import ( AgentResponse, @@ -760,6 +760,52 @@ class ChannelPush(Protocol): async def push(self, identity: ChannelIdentity, payload: HostedRunResult[Any]) -> None: ... +async def apply_channel_response_hook( + channel: Channel | ChannelPush, + result: HostedRunResult[Any], + *, + request: ChannelRequest, + originating: bool, + destination_identity: ChannelIdentity | None = None, + is_echo: bool = False, + clone: bool = False, +) -> HostedRunResult[Any]: + """Apply a channel's optional response hook with the standard context. + + Channels and the host call this helper when they need to shape a + :class:`HostedRunResult` for one destination. The helper centralizes the + response-hook convention: hooks are discovered from a duck-typed + ``response_hook`` attribute, called through :func:`apply_response_hook`, + and receive a :class:`ChannelResponseContext` that identifies the channel, + destination identity, originating-vs-push phase, and echo phase. + + Args: + channel: Channel whose ``response_hook`` attribute may shape the payload. + result: Hosted run result to pass to the hook. + request: Originating channel request. + originating: Whether this is the originating channel's synchronous reply. + destination_identity: Destination identity for non-originating pushes, or + ``None`` for originating replies. + is_echo: Whether the payload is an echo of the user input. + clone: Whether to shallow-clone ``result`` before applying the hook. + + Returns: + The original, cloned, or hook-shaped hosted run result. + """ + shaped = result.replace() if clone else result + hook = cast(ChannelResponseHook | None, getattr(channel, "response_hook", None)) + if not callable(hook): + return shaped + context = ChannelResponseContext( + request=request, + channel_name=channel.name, + destination_identity=destination_identity, + originating=originating, + is_echo=is_echo, + ) + return await apply_response_hook(hook, shaped, context=context) + + # --------------------------------------------------------------------------- # # Durable task runner — pluggable seam for non-originating push fan-out and # (in v1 fast-follow) background runs. See spec §"Durable task runner". @@ -910,6 +956,7 @@ async def get(self, handle: TaskHandle) -> TaskStatus | None: "RetryPolicy", "TaskHandle", "TaskStatus", + "apply_channel_response_hook", "apply_response_hook", "apply_run_hook", ] diff --git a/python/packages/hosting/tests/test_types.py b/python/packages/hosting/tests/test_types.py index e502e16dca..3253c50905 100644 --- a/python/packages/hosting/tests/test_types.py +++ b/python/packages/hosting/tests/test_types.py @@ -7,12 +7,16 @@ from typing import Any from agent_framework_hosting import ( + ChannelContribution, ChannelIdentity, ChannelRequest, + ChannelResponseContext, ChannelSession, DurableTaskPayloadMode, + HostedRunResult, ResponseTarget, ResponseTargetKind, + apply_channel_response_hook, apply_run_hook, ) @@ -117,6 +121,88 @@ class _DummyTarget: """ +class _DummyChannel: + name = "dummy" + path = "/dummy" + + def contribute(self, _context: Any) -> ChannelContribution: + return ChannelContribution() + + +class TestApplyChannelResponseHook: + async def test_originating_hook_receives_standard_context(self) -> None: + request = ChannelRequest(channel="discord", operation="message.create", input="hi") + payload = HostedRunResult("original") + captured: list[ChannelResponseContext] = [] + + async def hook( + result: HostedRunResult[Any], + *, + context: ChannelResponseContext, + ) -> HostedRunResult[Any]: + captured.append(context) + return result.replace(result="hooked") + + channel = _DummyChannel() + channel.response_hook = hook # type: ignore[attr-defined] + + shaped = await apply_channel_response_hook(channel, payload, request=request, originating=True) + + assert shaped.result == "hooked" + assert captured[0].request is request + assert captured[0].channel_name == "dummy" + assert captured[0].destination_identity is None + assert captured[0].originating is True + assert captured[0].is_echo is False + + async def test_non_originating_hook_can_clone_before_shaping(self) -> None: + request = ChannelRequest(channel="responses", operation="message.create", input="hi") + identity = ChannelIdentity(channel="dummy", native_id="user-1") + payload = HostedRunResult("original") + seen_payloads: list[HostedRunResult[Any]] = [] + seen_contexts: list[ChannelResponseContext] = [] + + def hook( + result: HostedRunResult[Any], + *, + context: ChannelResponseContext, + ) -> HostedRunResult[Any]: + seen_payloads.append(result) + seen_contexts.append(context) + return result.replace(result="hooked") + + channel = _DummyChannel() + channel.response_hook = hook # type: ignore[attr-defined] + + shaped = await apply_channel_response_hook( + channel, + payload, + request=request, + destination_identity=identity, + originating=False, + is_echo=True, + clone=True, + ) + + assert seen_payloads[0] is not payload + assert shaped.result == "hooked" + assert seen_contexts[0].destination_identity is identity + assert seen_contexts[0].originating is False + assert seen_contexts[0].is_echo is True + + async def test_missing_hook_returns_payload_or_clone(self) -> None: + request = ChannelRequest(channel="responses", operation="message.create", input="hi") + payload = HostedRunResult("original") + channel = _DummyChannel() + + same = await apply_channel_response_hook(channel, payload, request=request, originating=True) + cloned = await apply_channel_response_hook(channel, payload, request=request, originating=True, clone=True) + + assert same is payload + assert cloned is not payload + assert cloned.result == payload.result + + class TestApplyRunHook: """`apply_run_hook` is the channel-side helper that invokes a `ChannelRunHook` with the standard kwargs (`request` positional, diff --git a/python/pyproject.toml b/python/pyproject.toml index 4beeb9d5c6..1c695ec8a9 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -90,6 +90,7 @@ agent-framework-hosting-invocations = { workspace = true } agent-framework-hosting-telegram = { workspace = true } agent-framework-hosting-activity-protocol = { workspace = true } agent-framework-hosting-entra = { workspace = true } +agent-framework-hosting-discord = { workspace = true } agent-framework-hyperlight = { workspace = true } agent-framework-lab = { workspace = true } agent-framework-mem0 = { workspace = true } diff --git a/python/uv.lock b/python/uv.lock index 15df562b43..f3044b9a68 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -48,11 +48,12 @@ members = [ "agent-framework-gemini", "agent-framework-github-copilot", "agent-framework-hosting", - "agent-framework-hosting-responses", - "agent-framework-hosting-invocations", - "agent-framework-hosting-telegram", "agent-framework-hosting-activity-protocol", + "agent-framework-hosting-discord", "agent-framework-hosting-entra", + "agent-framework-hosting-invocations", + "agent-framework-hosting-responses", + "agent-framework-hosting-telegram", "agent-framework-hyperlight", "agent-framework-lab", "agent-framework-mem0", @@ -648,13 +649,41 @@ provides-extras = ["serve", "disk"] dev = [{ name = "httpx", specifier = ">=0.28.1" }] [[package]] -name = "agent-framework-hosting-telegram" +name = "agent-framework-hosting-activity-protocol" version = "1.0.0a260424" -source = { editable = "packages/hosting-telegram" } +source = { editable = "packages/hosting-activity-protocol" } +dependencies = [ + { name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "agent-framework-hosting", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "azure-identity", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, +] + +[package.metadata] +requires-dist = [ + { name = "agent-framework-core", editable = "packages/core" }, + { name = "agent-framework-hosting", editable = "packages/hosting" }, + { name = "azure-identity", specifier = ">=1.20,<2" }, + { name = "httpx", specifier = ">=0.27,<1" }, +] + +[[package]] +name = "agent-framework-hosting-discord" +version = "1.0.0a260526" +source = { editable = "packages/hosting-discord" } dependencies = [ { name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "agent-framework-hosting", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "pynacl", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, +] + +[package.metadata] +requires-dist = [ + { name = "agent-framework-core", editable = "packages/core" }, + { name = "agent-framework-hosting", editable = "packages/hosting" }, + { name = "httpx", specifier = ">=0.27,<1" }, + { name = "pynacl", specifier = ">=1.2.0,<2" }, ] [[package]] @@ -669,46 +698,62 @@ dependencies = [ { name = "msal", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, ] +[package.metadata] +requires-dist = [ + { name = "agent-framework-core", editable = "packages/core" }, + { name = "agent-framework-hosting", editable = "packages/hosting" }, + { name = "cryptography", specifier = ">=42" }, + { name = "httpx", specifier = ">=0.27,<1" }, + { name = "msal", specifier = ">=1.28,<2" }, +] + [[package]] -name = "agent-framework-hosting-responses" +name = "agent-framework-hosting-invocations" version = "1.0.0a260424" -source = { editable = "packages/hosting-responses" } +source = { editable = "packages/hosting-invocations" } dependencies = [ { name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "agent-framework-hosting", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, - { name = "openai", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, - +] [package.metadata] requires-dist = [ { name = "agent-framework-core", editable = "packages/core" }, { name = "agent-framework-hosting", editable = "packages/hosting" }, - { name = "openai", specifier = ">=1.99.0,<3" }, ] [[package]] -name = "agent-framework-hosting-activity-protocol" +name = "agent-framework-hosting-responses" version = "1.0.0a260424" -source = { editable = "packages/hosting-activity-protocol" } +source = { editable = "packages/hosting-responses" } dependencies = [ { name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "agent-framework-hosting", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, - { name = "azure-identity", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, - { name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "openai", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, +] + +[package.metadata] +requires-dist = [ + { name = "agent-framework-core", editable = "packages/core" }, + { name = "agent-framework-hosting", editable = "packages/hosting" }, + { name = "openai", specifier = ">=1.99.0,<3" }, ] [[package]] -name = "agent-framework-hosting-invocations" +name = "agent-framework-hosting-telegram" version = "1.0.0a260424" -source = { editable = "packages/hosting-invocations" } +source = { editable = "packages/hosting-telegram" } dependencies = [ { name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "agent-framework-hosting", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, - { name = "azure-identity", specifier = ">=1.20,<2" }, - { name = "httpx", specifier = ">=0.27,<1" }, - { name = "cryptography", specifier = ">=42" }, + { name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, +] + +[package.metadata] +requires-dist = [ + { name = "agent-framework-core", editable = "packages/core" }, + { name = "agent-framework-hosting", editable = "packages/hosting" }, { name = "httpx", specifier = ">=0.27,<1" }, - { name = "msal", specifier = ">=1.28,<2" }, ] [[package]]