Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,14 +626,14 @@ async def sse_writer(): # pragma: lax no cover
# Then send the message to be processed by the server
session_message = self._create_session_message(message, request, request_id, protocol_version)
await writer.send(session_message)
except Exception: # pragma: no cover
except Exception: # pragma: lax no cover
logger.exception("SSE response error")
await sse_stream_writer.aclose()
await self._clean_up_memory_streams(request_id)
finally:
await sse_stream_reader.aclose()

except Exception as err: # pragma: no cover
except Exception as err: # pragma: lax no cover
logger.exception("Error handling POST request")
response = self._create_error_response(
f"Error handling POST request: {err}",
Expand Down Expand Up @@ -816,7 +816,7 @@ async def _validate_request_headers(self, request: Request, send: Send) -> bool:

async def _validate_session(self, request: Request, send: Send) -> bool:
"""Validate the session ID in the request."""
if not self.mcp_session_id: # pragma: no cover
if not self.mcp_session_id: # pragma: lax no cover
# If we're not using session IDs, return True
return True

Expand Down Expand Up @@ -849,7 +849,7 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool
protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER)

# If no protocol version provided, assume default version
if protocol_version is None: # pragma: no cover
if protocol_version is None: # pragma: lax no cover
protocol_version = DEFAULT_NEGOTIATED_VERSION

# Check if the protocol version is supported
Expand Down Expand Up @@ -1026,7 +1026,7 @@ async def message_router():
)
except anyio.ClosedResourceError:
if self._terminated:
logger.debug("Read stream closed by client")
logger.debug("Read stream closed by client") # pragma: lax no cover
else:
logger.exception("Unexpected closure of read stream in message router")
except Exception: # pragma: lax no cover
Expand Down
59 changes: 48 additions & 11 deletions src/mcp/server/streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ def __init__(
self._session_creation_lock = anyio.Lock()
self._server_instances: dict[str, StreamableHTTPServerTransport] = {}

# Track in-flight stateless transports for graceful shutdown
self._stateless_transports: set[StreamableHTTPServerTransport] = set()

# The task group will be set during lifespan
self._task_group = None
# Thread-safe tracking of run() calls
Expand Down Expand Up @@ -130,11 +133,28 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]:
yield # Let the application run
finally:
logger.info("StreamableHTTP session manager shutting down")

# Terminate all active transports before cancelling the task
# group. This closes their in-memory streams, which lets
# EventSourceResponse send a final ``more_body=False`` chunk
# — a clean HTTP close instead of a connection reset.
for transport in list(self._server_instances.values()):
try:
await transport.terminate()
except Exception: # pragma: no cover
logger.debug("Error terminating transport during shutdown", exc_info=True)
for transport in list(self._stateless_transports):
try:
await transport.terminate()
except Exception: # pragma: no cover
logger.debug("Error terminating stateless transport during shutdown", exc_info=True)

# Cancel task group to stop all spawned tasks
tg.cancel_scope.cancel()
self._task_group = None
# Clear any remaining server instances
self._server_instances.clear()
self._stateless_transports.clear()

async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Process ASGI request with proper session handling and transport setup.
Expand All @@ -151,7 +171,12 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No
await self._handle_stateful_request(scope, receive, send)

async def _handle_stateless_request(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Process request in stateless mode - creating a new transport for each request."""
"""Process request in stateless mode - creating a new transport for each request.

Uses a request-scoped task group so the server task is automatically
cancelled when the request completes, preventing task accumulation in
the manager's global task group.
"""
logger.debug("Stateless mode: Creating new transport for this request")
# No session ID needed in stateless mode
http_transport = StreamableHTTPServerTransport(
Expand All @@ -161,6 +186,9 @@ async def _handle_stateless_request(self, scope: Scope, receive: Receive, send:
security_settings=self.security_settings,
)

# Track for graceful shutdown
self._stateless_transports.add(http_transport)

# Start server in a new task
async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED):
async with http_transport.connect() as streams:
Expand All @@ -173,18 +201,27 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA
self.app.create_initialization_options(),
stateless=True,
)
except Exception: # pragma: no cover
except Exception: # pragma: lax no cover
logger.exception("Stateless session crashed")

# Assert task group is not None for type checking
assert self._task_group is not None
# Start the server task
await self._task_group.start(run_stateless_server)

# Handle the HTTP request and return the response
await http_transport.handle_request(scope, receive, send)

# Terminate the transport after the request is handled
# Use a request-scoped task group instead of the global one.
# This ensures the server task is cancelled when the request
# finishes, preventing zombie tasks from accumulating.
# See: https://github.com/modelcontextprotocol/python-sdk/issues/1764
try:
async with anyio.create_task_group() as request_tg:
await request_tg.start(run_stateless_server)
# Handle the HTTP request directly in the caller's context
# (not as a child task) so execution flows back naturally.
await http_transport.handle_request(scope, receive, send)
# Cancel the request-scoped task group to stop the server task.
request_tg.cancel_scope.cancel()
finally:
self._stateless_transports.discard(http_transport)

# Terminate after the task group exits — the server task is already
# cancelled at this point, so this is just cleanup (sets _terminated
# flag and closes any remaining streams).
await http_transport.terminate()

async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: Send) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ async def _receive_loop(self) -> None:
try:
await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error))
await stream.aclose()
except Exception: # pragma: no cover
except Exception: # pragma: lax no cover
# Stream might already be closed
pass
self._response_streams.clear()
Expand Down
Loading
Loading