diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index aa99e7c88..4661c14ad 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -626,14 +626,14 @@ async def sse_writer(): # pragma: lax no cover # Then send the message to be processed by the server session_message = self._create_session_message(message, request, request_id, protocol_version) await writer.send(session_message) - except Exception: # pragma: no cover + except Exception: # pragma: lax no cover logger.exception("SSE response error") await sse_stream_writer.aclose() await self._clean_up_memory_streams(request_id) finally: await sse_stream_reader.aclose() - except Exception as err: # pragma: no cover + except Exception as err: # pragma: lax no cover logger.exception("Error handling POST request") response = self._create_error_response( f"Error handling POST request: {err}", @@ -816,7 +816,7 @@ async def _validate_request_headers(self, request: Request, send: Send) -> bool: async def _validate_session(self, request: Request, send: Send) -> bool: """Validate the session ID in the request.""" - if not self.mcp_session_id: # pragma: no cover + if not self.mcp_session_id: # pragma: lax no cover # If we're not using session IDs, return True return True @@ -849,7 +849,7 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER) # If no protocol version provided, assume default version - if protocol_version is None: # pragma: no cover + if protocol_version is None: # pragma: lax no cover protocol_version = DEFAULT_NEGOTIATED_VERSION # Check if the protocol version is supported @@ -1026,7 +1026,7 @@ async def message_router(): ) except anyio.ClosedResourceError: if self._terminated: - logger.debug("Read stream closed by client") + logger.debug("Read stream closed by client") # pragma: lax no cover else: logger.exception("Unexpected closure of read stream in message router") except Exception: # pragma: lax no cover diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index c25314eab..56f982e52 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -90,6 +90,9 @@ def __init__( self._session_creation_lock = anyio.Lock() self._server_instances: dict[str, StreamableHTTPServerTransport] = {} + # Track in-flight stateless transports for graceful shutdown + self._stateless_transports: set[StreamableHTTPServerTransport] = set() + # The task group will be set during lifespan self._task_group = None # Thread-safe tracking of run() calls @@ -130,11 +133,28 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]: yield # Let the application run finally: logger.info("StreamableHTTP session manager shutting down") + + # Terminate all active transports before cancelling the task + # group. This closes their in-memory streams, which lets + # EventSourceResponse send a final ``more_body=False`` chunk + # — a clean HTTP close instead of a connection reset. + for transport in list(self._server_instances.values()): + try: + await transport.terminate() + except Exception: # pragma: no cover + logger.debug("Error terminating transport during shutdown", exc_info=True) + for transport in list(self._stateless_transports): + try: + await transport.terminate() + except Exception: # pragma: no cover + logger.debug("Error terminating stateless transport during shutdown", exc_info=True) + # Cancel task group to stop all spawned tasks tg.cancel_scope.cancel() self._task_group = None # Clear any remaining server instances self._server_instances.clear() + self._stateless_transports.clear() async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: """Process ASGI request with proper session handling and transport setup. @@ -151,7 +171,12 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No await self._handle_stateful_request(scope, receive, send) async def _handle_stateless_request(self, scope: Scope, receive: Receive, send: Send) -> None: - """Process request in stateless mode - creating a new transport for each request.""" + """Process request in stateless mode - creating a new transport for each request. + + Uses a request-scoped task group so the server task is automatically + cancelled when the request completes, preventing task accumulation in + the manager's global task group. + """ logger.debug("Stateless mode: Creating new transport for this request") # No session ID needed in stateless mode http_transport = StreamableHTTPServerTransport( @@ -161,6 +186,9 @@ async def _handle_stateless_request(self, scope: Scope, receive: Receive, send: security_settings=self.security_settings, ) + # Track for graceful shutdown + self._stateless_transports.add(http_transport) + # Start server in a new task async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED): async with http_transport.connect() as streams: @@ -173,18 +201,27 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA self.app.create_initialization_options(), stateless=True, ) - except Exception: # pragma: no cover + except Exception: # pragma: lax no cover logger.exception("Stateless session crashed") - # Assert task group is not None for type checking - assert self._task_group is not None - # Start the server task - await self._task_group.start(run_stateless_server) - - # Handle the HTTP request and return the response - await http_transport.handle_request(scope, receive, send) - - # Terminate the transport after the request is handled + # Use a request-scoped task group instead of the global one. + # This ensures the server task is cancelled when the request + # finishes, preventing zombie tasks from accumulating. + # See: https://github.com/modelcontextprotocol/python-sdk/issues/1764 + try: + async with anyio.create_task_group() as request_tg: + await request_tg.start(run_stateless_server) + # Handle the HTTP request directly in the caller's context + # (not as a child task) so execution flows back naturally. + await http_transport.handle_request(scope, receive, send) + # Cancel the request-scoped task group to stop the server task. + request_tg.cancel_scope.cancel() + finally: + self._stateless_transports.discard(http_transport) + + # Terminate after the task group exits — the server task is already + # cancelled at this point, so this is just cleanup (sets _terminated + # flag and closes any remaining streams). await http_transport.terminate() async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: Send) -> None: diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 9364abb73..42ff12ba8 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -423,7 +423,7 @@ async def _receive_loop(self) -> None: try: await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error)) await stream.aclose() - except Exception: # pragma: no cover + except Exception: # pragma: lax no cover # Stream might already be closed pass self._response_streams.clear() diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 47cfbf14a..60b07f6a7 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -8,13 +8,13 @@ import anyio import httpx import pytest -from starlette.types import Message +from starlette.types import ASGIApp, Message, Receive, Scope, Send -from mcp import Client +from mcp import Client, types from mcp.client.streamable_http import streamable_http_client from mcp.server import Server, ServerRequestContext, streamable_http_manager from mcp.server.streamable_http import MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport -from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager from mcp.types import INVALID_REQUEST, ListToolsResult, PaginatedRequestParams @@ -269,6 +269,83 @@ async def mock_receive(): assert len(transport._request_streams) == 0, "Transport should have no active request streams" +@pytest.mark.anyio +async def test_stateless_requests_task_leak_on_client_disconnect(): + """Test that stateless tasks don't leak when clients disconnect mid-request. + + Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/1764 + + Reproduces the production memory leak: a client sends a tool call, the tool + handler takes some time, and the client disconnects before the response is + delivered. The SSE response pipeline detects the disconnect but app.run() + continues in the background. After the tool finishes, the response has + nowhere to go, and app.run() blocks on ``async for message in + session.incoming_messages`` forever — leaking the task in the global + task group. + + The test uses real Server.run() with a real tool handler, real SSE streaming + via httpx.ASGITransport, and simulates client disconnect by cancelling the + request task. + """ + from mcp.types import CallToolResult, TextContent + + tool_started = anyio.Event() + tool_gate = anyio.Event() + + async def handle_call_tool(ctx: ServerRequestContext, params: Any) -> CallToolResult: + tool_started.set() + await tool_gate.wait() + return CallToolResult(content=[TextContent(type="text", text="done")]) # pragma: no cover + + app = Server( + "test-stateless-leak", + on_call_tool=handle_call_tool, + ) + + host = "testserver" + mcp_app = app.streamable_http_app(host=host, stateless_http=True) + + async with ( + mcp_app.router.lifespan_context(mcp_app), + httpx.ASGITransport(mcp_app) as transport, + ): + session_manager = app._session_manager + assert session_manager is not None + + async def make_and_abandon_tool_call(): + async with httpx.AsyncClient(transport=transport, base_url=f"http://{host}", timeout=30.0) as http_client: + async with Client(streamable_http_client(f"http://{host}/mcp", http_client=http_client)) as client: + # Start tool call — this will block until tool completes + # We'll cancel it from outside to simulate disconnect + await client.call_tool("slow_tool", {}) + + num_requests = 3 + for _ in range(num_requests): + async with anyio.create_task_group() as tg: + tg.start_soon(make_and_abandon_tool_call) + # Wait for the tool handler to actually start + await tool_started.wait() + tool_started = anyio.Event() # Reset for next iteration + # Simulate client disconnect by cancelling the request + tg.cancel_scope.cancel() + + # Let the tool finish now (response has nowhere to go) + tool_gate.set() + tool_gate = anyio.Event() # Reset for next iteration + + # Give tasks a chance to settle + await anyio.sleep(0.1) + + # Check for leaked tasks in the session manager's global task group + await anyio.sleep(0.1) + assert session_manager._task_group is not None + leaked = len(session_manager._task_group._tasks) # type: ignore[attr-defined] + + assert leaked == 0, ( + f"Expected 0 lingering tasks but found {leaked}. Stateless request tasks are leaking after client disconnect." + ) + + @pytest.mark.anyio async def test_unknown_session_id_returns_404(caplog: pytest.LogCaptureFixture): """Test that requests with unknown session IDs return HTTP 404 per MCP spec.""" @@ -413,3 +490,165 @@ def test_session_idle_timeout_rejects_non_positive(): def test_session_idle_timeout_rejects_stateless(): with pytest.raises(RuntimeError, match="not supported in stateless"): StreamableHTTPSessionManager(app=Server("test"), session_idle_timeout=30, stateless=True) + + +MCP_HEADERS = { + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", +} + +_INITIALIZE_REQUEST = { + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2025-03-26", + "capabilities": {}, + "clientInfo": {"name": "test", "version": "0.1"}, + }, +} + +_INITIALIZED_NOTIFICATION = { + "jsonrpc": "2.0", + "method": "notifications/initialized", +} + +_TOOL_CALL_REQUEST = { + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": {"name": "slow_tool", "arguments": {"message": "hello"}}, +} + + +def _make_slow_tool_server() -> tuple[Server, anyio.Event]: + """Create an MCP server with a tool that blocks forever, returning + the server and an event that fires when the tool starts executing.""" + tool_started = anyio.Event() + + async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + tool_started.set() + await anyio.sleep_forever() + return types.CallToolResult( # pragma: no cover + content=[types.TextContent(type="text", text="never reached")] + ) + + async def handle_list_tools( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListToolsResult: # pragma: no cover + return ListToolsResult( + tools=[ + types.Tool( + name="slow_tool", + description="A tool that blocks forever", + input_schema={"type": "object", "properties": {"message": {"type": "string"}}}, + ) + ] + ) + + app = Server("test-graceful-shutdown", on_call_tool=handle_call_tool, on_list_tools=handle_list_tools) + return app, tool_started + + +class SSECloseTracker: + """ASGI middleware that tracks whether SSE responses close cleanly. + + In HTTP, a clean close means sending a final empty chunk (``0\\r\\n\\r\\n``). + At the ASGI protocol level this corresponds to a + ``{"type": "http.response.body", "more_body": False}`` message. + + Without graceful drain, the server task is cancelled but nothing closes + the stateless transport's streams — the SSE response hangs indefinitely + and never sends the final body. A reverse proxy (e.g. nginx) would log + "upstream prematurely closed connection while reading upstream". + """ + + def __init__(self, app: ASGIApp) -> None: + self.app = app + self.sse_streams_opened = 0 + self.sse_streams_closed_cleanly = 0 + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + is_sse = False + + async def tracking_send(message: Message) -> None: + nonlocal is_sse + if message["type"] == "http.response.start": + for name, value in message.get("headers", []): + if name == b"content-type" and b"text/event-stream" in value: + is_sse = True + self.sse_streams_opened += 1 + break + elif message["type"] == "http.response.body" and is_sse: + if not message.get("more_body", False): + self.sse_streams_closed_cleanly += 1 + await send(message) + + await self.app(scope, receive, tracking_send) + + +@pytest.mark.anyio +async def test_graceful_shutdown_closes_sse_streams_cleanly(): + """Verify that shutting down the session manager closes in-flight SSE + streams with a proper ``more_body=False`` ASGI message. + + This is the ASGI equivalent of sending the final HTTP chunk — the signal + that reverse proxies like nginx use to distinguish a clean close from a + connection reset ("upstream prematurely closed connection"). + + Without the graceful-drain fix, stateless transports are not tracked by + the session manager. On shutdown nothing calls ``terminate()`` on them, + so SSE responses hang indefinitely and never send the final body. With + the fix, ``run()``'s finally block iterates ``_stateless_transports`` and + terminates each one, closing the underlying memory streams and letting + ``EventSourceResponse`` complete normally. + """ + app, tool_started = _make_slow_tool_server() + manager = StreamableHTTPSessionManager(app=app, stateless=True) + + tracker = SSECloseTracker(StreamableHTTPASGIApp(manager)) + + manager_ready = anyio.Event() + + with anyio.fail_after(10): + async with anyio.create_task_group() as tg: + + async def run_lifespan_and_shutdown() -> None: + async with manager.run(): + manager_ready.set() + with anyio.fail_after(5): + await tool_started.wait() + # manager.run() exits — graceful shutdown runs here + + async def make_requests() -> None: + with anyio.fail_after(5): + await manager_ready.wait() + async with ( + httpx.ASGITransport(tracker, raise_app_exceptions=False) as transport, + httpx.AsyncClient(transport=transport, base_url="http://testserver") as client, + ): + # Initialize + resp = await client.post("/mcp/", json=_INITIALIZE_REQUEST, headers=MCP_HEADERS) + resp.raise_for_status() + + # Send initialized notification + resp = await client.post("/mcp/", json=_INITIALIZED_NOTIFICATION, headers=MCP_HEADERS) + assert resp.status_code == 202 + + # Send slow tool call — returns an SSE stream that blocks + # until shutdown terminates it + await client.post( + "/mcp/", + json=_TOOL_CALL_REQUEST, + headers=MCP_HEADERS, + timeout=httpx.Timeout(10, connect=5), + ) + + tg.start_soon(run_lifespan_and_shutdown) + tg.start_soon(make_requests) + + assert tracker.sse_streams_opened > 0, "Test should have opened at least one SSE stream" + assert tracker.sse_streams_closed_cleanly == tracker.sse_streams_opened, ( + f"All {tracker.sse_streams_opened} SSE stream(s) should have closed with " + f"more_body=False, but only {tracker.sse_streams_closed_cleanly} did" + )