diff --git a/src/google/adk/auth/auth_credential.py b/src/google/adk/auth/auth_credential.py index 6160edcc02..cfe6fb9a11 100644 --- a/src/google/adk/auth/auth_credential.py +++ b/src/google/adk/auth/auth_credential.py @@ -84,6 +84,8 @@ class OAuth2Auth(BaseModelWithConfig): expires_at: Optional[int] = None expires_in: Optional[int] = None audience: Optional[str] = None + code_challenge_method: Optional[str] = None + code_verifier: Optional[str] = None token_endpoint_auth_method: Optional[ Literal[ "client_secret_basic", diff --git a/src/google/adk/auth/auth_handler.py b/src/google/adk/auth/auth_handler.py index ec7c75716c..6cce82644a 100644 --- a/src/google/adk/auth/auth_handler.py +++ b/src/google/adk/auth/auth_handler.py @@ -14,8 +14,10 @@ from __future__ import annotations +import secrets from typing import TYPE_CHECKING +from authlib.oauth2.rfc7636 import create_s256_code_challenge from fastapi.openapi.models import SecurityBase from .auth_credential import AuthCredential @@ -69,7 +71,7 @@ async def parse_and_store_auth_response(self, state: State) -> None: state[credential_key] = await self.exchange_auth_token() def _validate(self) -> None: - if not self.auth_scheme: + if not self.auth_config.auth_scheme: raise ValueError("auth_scheme is empty.") def get_auth_response(self, state: State) -> AuthCredential: @@ -158,6 +160,9 @@ def generate_auth_uri( auth_scheme = self.auth_config.auth_scheme auth_credential = self.auth_config.raw_auth_credential + if not auth_credential or not auth_credential.oauth2: + raise ValueError("OAuth2 auth_credential with oauth2 config is required.") + oauth2_credential = auth_credential.oauth2 if isinstance(auth_scheme, OpenIdConnectWithConfig): authorization_endpoint = auth_scheme.authorization_endpoint @@ -186,23 +191,36 @@ def generate_auth_uri( scopes = list(scopes.keys()) client = OAuth2Session( - auth_credential.oauth2.client_id, - auth_credential.oauth2.client_secret, + oauth2_credential.client_id, + oauth2_credential.client_secret, scope=" ".join(scopes), - redirect_uri=auth_credential.oauth2.redirect_uri, + redirect_uri=oauth2_credential.redirect_uri, ) params = { "access_type": "offline", "prompt": "consent", } - if auth_credential.oauth2.audience: - params["audience"] = auth_credential.oauth2.audience + if oauth2_credential.audience: + params["audience"] = oauth2_credential.audience + code_challenge_method = oauth2_credential.code_challenge_method + if code_challenge_method: + if not oauth2_credential.code_verifier: + oauth2_credential.code_verifier = secrets.token_urlsafe(64) + params["code_challenge_method"] = code_challenge_method + if code_challenge_method == "S256": + params["code_challenge"] = create_s256_code_challenge( + oauth2_credential.code_verifier + ) + else: + params["code_challenge"] = oauth2_credential.code_verifier + uri, state = client.create_authorization_url( url=authorization_endpoint, **params ) exchanged_auth_credential = auth_credential.model_copy(deep=True) - exchanged_auth_credential.oauth2.auth_uri = uri - exchanged_auth_credential.oauth2.state = state + if exchanged_auth_credential.oauth2: + exchanged_auth_credential.oauth2.auth_uri = uri + exchanged_auth_credential.oauth2.state = state return exchanged_auth_credential diff --git a/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py b/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py index 02365f3026..3b2ea3982a 100644 --- a/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py +++ b/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py @@ -193,15 +193,20 @@ async def _exchange_authorization_code( return ExchangeResult(auth_credential, False) try: - tokens = client.fetch_token( - token_endpoint, - authorization_response=self._normalize_auth_uri( + fetch_token_kwargs = { + "authorization_response": self._normalize_auth_uri( auth_credential.oauth2.auth_response_uri ), - code=auth_credential.oauth2.auth_code, - grant_type=OAuthGrantType.AUTHORIZATION_CODE, - client_id=auth_credential.oauth2.client_id, - ) + "code": auth_credential.oauth2.auth_code, + "grant_type": OAuthGrantType.AUTHORIZATION_CODE, + "client_id": auth_credential.oauth2.client_id, + } + if auth_credential.oauth2.code_verifier: + fetch_token_kwargs["code_verifier"] = ( + auth_credential.oauth2.code_verifier + ) + + tokens = client.fetch_token(token_endpoint, **fetch_token_kwargs) update_credential_with_tokens(auth_credential, tokens) logger.debug("Successfully exchanged authorization code for access token") except Exception as e: diff --git a/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py b/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py index 574a0922b8..dec92d6d36 100644 --- a/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py +++ b/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py @@ -105,6 +105,43 @@ async def test_exchange_success(self, mock_oauth2_session): assert exchange_result.was_exchanged mock_client.fetch_token.assert_called_once() + @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") + async def test_exchange_authorization_code_with_code_verifier( + self, mock_oauth2_session + ): + """Test authorization code exchange passes PKCE verifier.""" + mock_client = Mock() + mock_oauth2_session.return_value = mock_client + mock_client.fetch_token.return_value = OAuth2Token({ + "access_token": "new_access_token", + }) + + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + auth_response_uri="https://example.com/callback?code=auth_code", + auth_code="auth_code", + code_verifier="test-code-verifier", + ), + ) + + exchanger = OAuth2CredentialExchanger() + await exchanger.exchange(credential, scheme) + + fetch_kwargs = mock_client.fetch_token.call_args.kwargs + assert fetch_kwargs["code_verifier"] == "test-code-verifier" + async def test_exchange_missing_auth_scheme(self): """Test exchange with missing auth_scheme raises ValueError.""" credential = AuthCredential( diff --git a/tests/unittests/auth/test_auth_handler.py b/tests/unittests/auth/test_auth_handler.py index 2faeeb158e..ed8faa5877 100644 --- a/tests/unittests/auth/test_auth_handler.py +++ b/tests/unittests/auth/test_auth_handler.py @@ -64,6 +64,12 @@ def create_authorization_url(self, url, **kwargs): params = f"client_id={self.client_id}&scope={self.scope}" if kwargs.get("audience"): params += f"&audience={kwargs.get('audience')}" + if kwargs.get("code_challenge_method"): + params += f"&code_challenge_method={kwargs.get('code_challenge_method')}" + if kwargs.get("code_challenge"): + params += f"&code_challenge={kwargs.get('code_challenge')}" + if kwargs.get("code_verifier"): + params += f"&code_verifier={kwargs.get('code_verifier')}" return f"{url}?{params}", "mock_state" def fetch_token( @@ -249,6 +255,19 @@ def test_generate_auth_uri_with_audience_and_prompt( assert "audience=test_audience" in result.oauth2.auth_uri + @patch("google.adk.auth.auth_handler.OAuth2Session", MockOAuth2Session) + def test_generate_auth_uri_with_pkce(self, auth_config): + """Test generating an auth URI with PKCE enabled.""" + auth_config.raw_auth_credential.oauth2.code_challenge_method = "S256" + handler = AuthHandler(auth_config) + + result = handler.generate_auth_uri() + + assert "code_challenge_method=S256" in result.oauth2.auth_uri + assert "code_challenge=" in result.oauth2.auth_uri + assert "code_verifier=" not in result.oauth2.auth_uri + assert result.oauth2.code_verifier + @patch("google.adk.auth.auth_handler.OAuth2Session", MockOAuth2Session) def test_generate_auth_uri_openid( self, openid_auth_scheme, oauth2_credentials