Skip to content

Commit f57edfa

Browse files
ShivamShivam
authored andcommitted
feat: add subject and claims fields to AccessToken
Add two optional fields to AccessToken: - subject: str | None — stores the JWT sub claim (user ID) - claims: dict[str, Any] | None — stores arbitrary custom JWT claims Also add Context.subject property so tool handlers can read the authenticated user's subject via ctx.subject without importing get_access_token directly. Both fields default to None, preserving full backward compatibility. Closes #1038
1 parent 528abfa commit f57edfa

File tree

4 files changed

+127
-1
lines changed

4 files changed

+127
-1
lines changed

src/mcp/server/auth/provider.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass
2-
from typing import Generic, Literal, Protocol, TypeVar
2+
from typing import Any, Generic, Literal, Protocol, TypeVar
33
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
44

55
from pydantic import AnyUrl, BaseModel
@@ -40,6 +40,8 @@ class AccessToken(BaseModel):
4040
scopes: list[str]
4141
expires_at: int | None = None
4242
resource: str | None = None # RFC 8707 resource indicator
43+
subject: str | None = None # JWT sub claim (user ID)
44+
claims: dict[str, Any] | None = None # Additional JWT claims beyond reserved fields
4345

4446

4547
RegistrationErrorCode = Literal[

src/mcp/server/mcpserver/context.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from pydantic import AnyUrl, BaseModel
77

8+
from mcp.server.auth.middleware.auth_context import get_access_token
89
from mcp.server.context import LifespanContextT, RequestT, ServerRequestContext
910
from mcp.server.elicitation import (
1011
ElicitationResult,
@@ -218,6 +219,26 @@ def client_id(self) -> str | None:
218219
"""Get the client ID if available."""
219220
return self.request_context.meta.get("client_id") if self.request_context.meta else None # pragma: no cover
220221

222+
@property
223+
def subject(self) -> str | None:
224+
"""Get the authenticated user's subject (JWT sub claim), if available.
225+
226+
Returns the ``subject`` field from the current request's access token.
227+
This is typically the user ID set by the token verifier when the token
228+
is validated. Returns ``None`` when the request is unauthenticated or
229+
the token verifier did not populate the field.
230+
231+
Example::
232+
233+
@server.tool()
234+
async def my_tool(ctx: Context) -> str:
235+
if ctx.subject is None:
236+
return "unauthenticated"
237+
return f"Hello, {ctx.subject}"
238+
"""
239+
token = get_access_token()
240+
return token.subject if token else None
241+
221242
@property
222243
def request_id(self) -> str:
223244
"""Get the unique ID for this request."""

tests/server/auth/middleware/test_bearer_auth.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,64 @@ def no_expiry_access_token() -> AccessToken:
102102
)
103103

104104

105+
class TestAccessTokenFields:
106+
"""Tests for AccessToken model fields including subject and claims."""
107+
108+
def test_backward_compat_without_subject_and_claims(self):
109+
"""Existing code that omits subject/claims should still work."""
110+
token = AccessToken(
111+
token="tok",
112+
client_id="client",
113+
scopes=["read"],
114+
)
115+
assert token.subject is None
116+
assert token.claims is None
117+
118+
def test_subject_field(self):
119+
"""subject stores the JWT sub claim."""
120+
token = AccessToken(
121+
token="tok",
122+
client_id="client",
123+
scopes=["read"],
124+
subject="user-123",
125+
)
126+
assert token.subject == "user-123"
127+
128+
def test_claims_field(self):
129+
"""claims stores arbitrary additional JWT claims."""
130+
custom_claims = {"org": "acme", "role": "admin", "tier": 2}
131+
token = AccessToken(
132+
token="tok",
133+
client_id="client",
134+
scopes=["read"],
135+
claims=custom_claims,
136+
)
137+
assert token.claims == custom_claims
138+
139+
def test_subject_and_claims_together(self):
140+
"""subject and claims can both be set simultaneously."""
141+
token = AccessToken(
142+
token="tok",
143+
client_id="client",
144+
scopes=["read"],
145+
subject="user-456",
146+
claims={"org": "acme"},
147+
)
148+
assert token.subject == "user-456"
149+
assert token.claims == {"org": "acme"}
150+
151+
def test_subject_flows_through_authenticated_user(self):
152+
"""AuthenticatedUser carries the subject via its access_token attribute."""
153+
token = AccessToken(
154+
token="tok",
155+
client_id="client",
156+
scopes=["read"],
157+
subject="user-789",
158+
)
159+
user = AuthenticatedUser(token)
160+
assert user.access_token.subject == "user-789"
161+
162+
105163
@pytest.mark.anyio
106164
class TestBearerAuthBackend:
107165
"""Tests for the BearerAuthBackend class."""
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""Tests for the mcpserver Context class."""
2+
3+
from mcp.server.auth.middleware.auth_context import auth_context_var
4+
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
5+
from mcp.server.auth.provider import AccessToken
6+
from mcp.server.mcpserver import Context
7+
8+
9+
class TestContextSubject:
10+
"""Tests for Context.subject property."""
11+
12+
def test_subject_returns_none_when_unauthenticated(self):
13+
ctx = Context()
14+
assert ctx.subject is None
15+
16+
def test_subject_returns_none_when_token_has_no_subject(self):
17+
user = AuthenticatedUser(AccessToken(token="tok", client_id="client", scopes=["read"]))
18+
token = auth_context_var.set(user)
19+
try:
20+
ctx = Context()
21+
assert ctx.subject is None
22+
finally:
23+
auth_context_var.reset(token)
24+
25+
def test_subject_returns_value_from_access_token(self):
26+
user = AuthenticatedUser(AccessToken(token="tok", client_id="client", scopes=["read"], subject="user-123"))
27+
token = auth_context_var.set(user)
28+
try:
29+
ctx = Context()
30+
assert ctx.subject == "user-123"
31+
finally:
32+
auth_context_var.reset(token)
33+
34+
def test_subject_reflects_current_context(self):
35+
ctx = Context()
36+
assert ctx.subject is None
37+
38+
user = AuthenticatedUser(AccessToken(token="a", client_id="c", scopes=[], subject="alice"))
39+
cv_token = auth_context_var.set(user)
40+
try:
41+
assert ctx.subject == "alice"
42+
finally:
43+
auth_context_var.reset(cv_token)
44+
45+
assert ctx.subject is None

0 commit comments

Comments
 (0)