Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 15 additions & 16 deletions auth/client/iam/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,17 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/lestrrat-go/jwx/v2/jws"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/nuts-foundation/nuts-node/crypto"
"github.com/nuts-foundation/nuts-node/vdr/resolver"
"io"
"net/http"
"net/url"
"strings"
"time"

"github.com/lestrrat-go/jwx/v2/jws"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/nuts-foundation/nuts-node/crypto"
"github.com/nuts-foundation/nuts-node/vdr/resolver"

"github.com/nuts-foundation/go-did/vc"
"github.com/nuts-foundation/nuts-node/auth/log"
"github.com/nuts-foundation/nuts-node/auth/oauth"
Expand Down Expand Up @@ -190,19 +191,17 @@ func (hb HTTPClient) AccessToken(ctx context.Context, tokenEndpoint string, data
return token, fmt.Errorf("failed to call endpoint: %w", err)
}
if err = core.TestResponseCode(http.StatusOK, response); err != nil {
// check for oauth error
if innerErr := core.TestResponseCode(http.StatusBadRequest, response); innerErr != nil {
// a non oauth error, the response body could contain a lot of stuff. We'll log and return the entire error
log.Logger().Debugf("authorization server token endpoint returned non oauth error (statusCode=%d)", response.StatusCode)
return token, err
httpErr, ok := errors.AsType[core.HttpError](err)
if ok && strings.Contains(response.Header.Get("Content-Type"), "application/json") {
// If the response is JSON and looks like an OAuth error, return it as such (regardless of status code).
oauthError := oauth.OAuth2Error{}
if jsonErr := json.Unmarshal(httpErr.ResponseBody, &oauthError); jsonErr == nil && oauthError.Code != "" {
return token, oauth.RemoteOAuthError{Cause: oauthError}
}
}
httpErr := err.(core.HttpError)
oauthError := oauth.OAuth2Error{}
if err := json.Unmarshal(httpErr.ResponseBody, &oauthError); err != nil {
return token, fmt.Errorf("unable to unmarshal OAuth error response: %w", err)
}

return token, oauth.RemoteOAuthError{Cause: oauthError}
// Not an OAuth error, the response body could contain a lot of stuff. We'll log and return the entire error
log.Logger().Debugf("authorization server token endpoint returned non oauth error (statusCode=%d)", response.StatusCode)
return token, err
}

var responseData []byte
Expand Down
38 changes: 38 additions & 0 deletions auth/client/iam/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,44 @@ func TestHTTPClient_AccessToken(t *testing.T) {
require.True(t, ok)
assert.Equal(t, "offline", string(httpError.ResponseBody))
})
t.Run("error - oauth error with non-400 status", func(t *testing.T) {
// Some authorization servers return non-400 status codes for OAuth errors (e.g. 401, 500).
// The client should still recognize a JSON body with an "error" field as an OAuth error.
handler := http2.Handler{StatusCode: http.StatusInternalServerError, ResponseData: oauth.OAuth2Error{Code: oauth.InvalidRequest}}
tlsServer, client := testServerAndClient(t, &handler)

_, err := client.AccessToken(ctx, tlsServer.URL, data, dpopHeader)

require.Error(t, err)
var oauthError oauth.OAuth2Error
require.ErrorAs(t, err, &oauthError)
assert.Equal(t, oauth.InvalidRequest, oauthError.Code)
require.ErrorAs(t, err, new(oauth.RemoteOAuthError))
})
t.Run("error - non-JSON response with non-OK status", func(t *testing.T) {
// Not JSON, so must not be treated as an OAuth error.
handler := http2.Handler{StatusCode: http.StatusBadRequest, ResponseData: "not json"}
tlsServer, client := testServerAndClient(t, &handler)

_, err := client.AccessToken(ctx, tlsServer.URL, data, dpopHeader)

require.Error(t, err)
httpError, ok := err.(core.HttpError)
require.True(t, ok)
assert.Equal(t, "not json", string(httpError.ResponseBody))
})
t.Run("error - JSON response without OAuth error code", func(t *testing.T) {
// JSON, but without an "error" field — must not be treated as an OAuth error.
handler := http2.Handler{StatusCode: http.StatusBadRequest, ResponseData: map[string]string{"message": "something went wrong"}}
tlsServer, client := testServerAndClient(t, &handler)

_, err := client.AccessToken(ctx, tlsServer.URL, data, dpopHeader)

require.Error(t, err)
_, ok := err.(core.HttpError)
require.True(t, ok)
require.NotErrorAs(t, err, new(oauth.RemoteOAuthError))
})
t.Run("error - invalid response", func(t *testing.T) {
handler := http2.Handler{StatusCode: http.StatusOK, ResponseData: "}"}
tlsServer, client := testServerAndClient(t, &handler)
Expand Down
14 changes: 13 additions & 1 deletion core/http_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ var TracingHTTPTransport func(http.RoundTripper) http.RoundTripper
// If the response body is longer than this, it will be truncated.
const HttpResponseBodyLogClipAt = 200

// HttpResponseBodyMaxSize is the maximum number of bytes read from an unexpected HTTP response body.
// It prevents DoS attacks where a malicious server returns a very large response body.
// Only applied to error responses, so 1 MB is more than enough.
const HttpResponseBodyMaxSize = 1024 * 1024

// HttpError describes an error returned when invoking a remote server.
type HttpError struct {
error
Expand All @@ -54,7 +59,14 @@ func TestResponseCode(expectedStatusCode int, response *http.Response) error {
// It logs using the given logger, unless nil is passed.
func TestResponseCodeWithLog(expectedStatusCode int, response *http.Response, log *logrus.Entry) error {
if response.StatusCode != expectedStatusCode {
responseData, _ := io.ReadAll(response.Body)
// Read at most HttpResponseBodyMaxSize bytes to prevent DoS. Read one extra byte to detect truncation.
responseData, _ := io.ReadAll(io.LimitReader(response.Body, HttpResponseBodyMaxSize+1))
truncated := len(responseData) > HttpResponseBodyMaxSize
if truncated {
responseData = responseData[:HttpResponseBodyMaxSize]
logrus.WithField("http_request_path", response.Request.URL.Path).
Warnf("HTTP response body exceeds %d bytes, truncating", HttpResponseBodyMaxSize)
}
if log != nil {
// Cut off the response body to 100 characters max to prevent logging of large responses
responseBodyString := string(responseData)
Expand Down
12 changes: 12 additions & 0 deletions core/http_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package core

import (
"bytes"
"context"
"errors"
"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -144,6 +145,17 @@ func TestTestResponseCodeWithLog(t *testing.T) {

assert.Equal(t, "Unexpected HTTP response (len=201): "+strings.Repeat("a", HttpResponseBodyLogClipAt)+"...(clipped)", hook.LastEntry().Message)
})
t.Run("response body exceeding max size is truncated", func(t *testing.T) {
data := bytes.Repeat([]byte("a"), HttpResponseBodyMaxSize+1024)
status := stdHttp.StatusUnauthorized
requestURL, _ := url.Parse("/foo")
request := &stdHttp.Request{URL: requestURL}

err := TestResponseCodeWithLog(stdHttp.StatusOK, &stdHttp.Response{StatusCode: status, Body: io.NopCloser(bytes.NewReader(data)), Request: request}, nil)

require.ErrorAs(t, err, new(HttpError))
assert.Len(t, err.(HttpError).ResponseBody, HttpResponseBodyMaxSize)
})
}

type readCloser []byte
Expand Down
Loading