diff --git a/auth/client/iam/client.go b/auth/client/iam/client.go index 3b8e8c17b..348adaf35 100644 --- a/auth/client/iam/client.go +++ b/auth/client/iam/client.go @@ -24,16 +24,17 @@ import ( "encoding/json" "errors" "fmt" - "github.com/lestrrat-go/jwx/v2/jws" - "github.com/lestrrat-go/jwx/v2/jwt" - "github.com/nuts-foundation/nuts-node/crypto" - "github.com/nuts-foundation/nuts-node/vdr/resolver" "io" "net/http" "net/url" "strings" "time" + "github.com/lestrrat-go/jwx/v2/jws" + "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/nuts-foundation/nuts-node/crypto" + "github.com/nuts-foundation/nuts-node/vdr/resolver" + "github.com/nuts-foundation/go-did/vc" "github.com/nuts-foundation/nuts-node/auth/log" "github.com/nuts-foundation/nuts-node/auth/oauth" @@ -190,19 +191,17 @@ func (hb HTTPClient) AccessToken(ctx context.Context, tokenEndpoint string, data return token, fmt.Errorf("failed to call endpoint: %w", err) } if err = core.TestResponseCode(http.StatusOK, response); err != nil { - // check for oauth error - if innerErr := core.TestResponseCode(http.StatusBadRequest, response); innerErr != nil { - // a non oauth error, the response body could contain a lot of stuff. We'll log and return the entire error - log.Logger().Debugf("authorization server token endpoint returned non oauth error (statusCode=%d)", response.StatusCode) - return token, err + httpErr, ok := errors.AsType[core.HttpError](err) + if ok && strings.Contains(response.Header.Get("Content-Type"), "application/json") { + // If the response is JSON and looks like an OAuth error, return it as such (regardless of status code). + oauthError := oauth.OAuth2Error{} + if jsonErr := json.Unmarshal(httpErr.ResponseBody, &oauthError); jsonErr == nil && oauthError.Code != "" { + return token, oauth.RemoteOAuthError{Cause: oauthError} + } } - httpErr := err.(core.HttpError) - oauthError := oauth.OAuth2Error{} - if err := json.Unmarshal(httpErr.ResponseBody, &oauthError); err != nil { - return token, fmt.Errorf("unable to unmarshal OAuth error response: %w", err) - } - - return token, oauth.RemoteOAuthError{Cause: oauthError} + // Not an OAuth error, the response body could contain a lot of stuff. We'll log and return the entire error + log.Logger().Debugf("authorization server token endpoint returned non oauth error (statusCode=%d)", response.StatusCode) + return token, err } var responseData []byte diff --git a/auth/client/iam/client_test.go b/auth/client/iam/client_test.go index c51a6cb3c..560715c82 100644 --- a/auth/client/iam/client_test.go +++ b/auth/client/iam/client_test.go @@ -188,6 +188,44 @@ func TestHTTPClient_AccessToken(t *testing.T) { require.True(t, ok) assert.Equal(t, "offline", string(httpError.ResponseBody)) }) + t.Run("error - oauth error with non-400 status", func(t *testing.T) { + // Some authorization servers return non-400 status codes for OAuth errors (e.g. 401, 500). + // The client should still recognize a JSON body with an "error" field as an OAuth error. + handler := http2.Handler{StatusCode: http.StatusInternalServerError, ResponseData: oauth.OAuth2Error{Code: oauth.InvalidRequest}} + tlsServer, client := testServerAndClient(t, &handler) + + _, err := client.AccessToken(ctx, tlsServer.URL, data, dpopHeader) + + require.Error(t, err) + var oauthError oauth.OAuth2Error + require.ErrorAs(t, err, &oauthError) + assert.Equal(t, oauth.InvalidRequest, oauthError.Code) + require.ErrorAs(t, err, new(oauth.RemoteOAuthError)) + }) + t.Run("error - non-JSON response with non-OK status", func(t *testing.T) { + // Not JSON, so must not be treated as an OAuth error. + handler := http2.Handler{StatusCode: http.StatusBadRequest, ResponseData: "not json"} + tlsServer, client := testServerAndClient(t, &handler) + + _, err := client.AccessToken(ctx, tlsServer.URL, data, dpopHeader) + + require.Error(t, err) + httpError, ok := err.(core.HttpError) + require.True(t, ok) + assert.Equal(t, "not json", string(httpError.ResponseBody)) + }) + t.Run("error - JSON response without OAuth error code", func(t *testing.T) { + // JSON, but without an "error" field — must not be treated as an OAuth error. + handler := http2.Handler{StatusCode: http.StatusBadRequest, ResponseData: map[string]string{"message": "something went wrong"}} + tlsServer, client := testServerAndClient(t, &handler) + + _, err := client.AccessToken(ctx, tlsServer.URL, data, dpopHeader) + + require.Error(t, err) + _, ok := err.(core.HttpError) + require.True(t, ok) + require.NotErrorAs(t, err, new(oauth.RemoteOAuthError)) + }) t.Run("error - invalid response", func(t *testing.T) { handler := http2.Handler{StatusCode: http.StatusOK, ResponseData: "}"} tlsServer, client := testServerAndClient(t, &handler) diff --git a/core/http_client.go b/core/http_client.go index f5acec4ed..d2ec64874 100644 --- a/core/http_client.go +++ b/core/http_client.go @@ -37,6 +37,11 @@ var TracingHTTPTransport func(http.RoundTripper) http.RoundTripper // If the response body is longer than this, it will be truncated. const HttpResponseBodyLogClipAt = 200 +// HttpResponseBodyMaxSize is the maximum number of bytes read from an unexpected HTTP response body. +// It prevents DoS attacks where a malicious server returns a very large response body. +// Only applied to error responses, so 1 MB is more than enough. +const HttpResponseBodyMaxSize = 1024 * 1024 + // HttpError describes an error returned when invoking a remote server. type HttpError struct { error @@ -54,7 +59,14 @@ func TestResponseCode(expectedStatusCode int, response *http.Response) error { // It logs using the given logger, unless nil is passed. func TestResponseCodeWithLog(expectedStatusCode int, response *http.Response, log *logrus.Entry) error { if response.StatusCode != expectedStatusCode { - responseData, _ := io.ReadAll(response.Body) + // Read at most HttpResponseBodyMaxSize bytes to prevent DoS. Read one extra byte to detect truncation. + responseData, _ := io.ReadAll(io.LimitReader(response.Body, HttpResponseBodyMaxSize+1)) + truncated := len(responseData) > HttpResponseBodyMaxSize + if truncated { + responseData = responseData[:HttpResponseBodyMaxSize] + logrus.WithField("http_request_path", response.Request.URL.Path). + Warnf("HTTP response body exceeds %d bytes, truncating", HttpResponseBodyMaxSize) + } if log != nil { // Cut off the response body to 100 characters max to prevent logging of large responses responseBodyString := string(responseData) diff --git a/core/http_client_test.go b/core/http_client_test.go index 1bc9e3183..adb759867 100644 --- a/core/http_client_test.go +++ b/core/http_client_test.go @@ -19,6 +19,7 @@ package core import ( + "bytes" "context" "errors" "github.com/sirupsen/logrus" @@ -144,6 +145,17 @@ func TestTestResponseCodeWithLog(t *testing.T) { assert.Equal(t, "Unexpected HTTP response (len=201): "+strings.Repeat("a", HttpResponseBodyLogClipAt)+"...(clipped)", hook.LastEntry().Message) }) + t.Run("response body exceeding max size is truncated", func(t *testing.T) { + data := bytes.Repeat([]byte("a"), HttpResponseBodyMaxSize+1024) + status := stdHttp.StatusUnauthorized + requestURL, _ := url.Parse("/foo") + request := &stdHttp.Request{URL: requestURL} + + err := TestResponseCodeWithLog(stdHttp.StatusOK, &stdHttp.Response{StatusCode: status, Body: io.NopCloser(bytes.NewReader(data)), Request: request}, nil) + + require.ErrorAs(t, err, new(HttpError)) + assert.Len(t, err.(HttpError).ResponseBody, HttpResponseBodyMaxSize) + }) } type readCloser []byte