Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ _cgo_gotypes.go
_cgo_export.*
_obj/
_test/
.gocache/

# ─── Node / TypeScript (sdk, cli) ─────────────────────────────────────────────
**/node_modules/
Expand Down
31 changes: 26 additions & 5 deletions gateway/internal/inject/inject.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,18 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
rec := &responseRecorder{
ResponseWriter: w,
body: &bytes.Buffer{},
header: make(http.Header),
statusCode: http.StatusOK,
}

// Serve the request to the upstream handler.
m.next.ServeHTTP(rec, r)

// Check if the response is HTML.
contentType := rec.Header().Get("Content-Type")
contentType := rec.header.Get("Content-Type")
if !isHTML(contentType) {
// Not HTML — write the response as-is.
copyHeaders(w.Header(), rec.header)
w.WriteHeader(rec.statusCode)
if _, err := w.Write(rec.body.Bytes()); err != nil {
m.logger.Debug("rep.inject.write_error", "path", r.URL.Path, "error", err)
Expand All @@ -84,7 +86,7 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {

// Decompress the body if the upstream ignored our Accept-Encoding removal.
body := rec.body.Bytes()
encoding := rec.Header().Get("Content-Encoding")
encoding := rec.header.Get("Content-Encoding")
if encoding != "" {
decompressed, err := decompressBody(body, encoding)
if err != nil {
Expand All @@ -93,6 +95,7 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
"path", r.URL.Path,
"reason", "unsupported Content-Encoding: "+encoding,
)
copyHeaders(w.Header(), rec.header)
w.WriteHeader(rec.statusCode)
if _, err := w.Write(body); err != nil {
m.logger.Debug("rep.inject.write_error", "path", r.URL.Path, "error", err)
Expand All @@ -111,6 +114,8 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Inject the REP script tag into the HTML.
injected := injectIntoHTML(body, tag)

copyHeaders(w.Header(), rec.header)

// Update Content-Length to reflect the injected content.
w.Header().Set("Content-Length", strconv.Itoa(len(injected)))

Expand Down Expand Up @@ -260,11 +265,16 @@ func isHTML(contentType string) bool {
// responseRecorder captures the upstream response for inspection.
type responseRecorder struct {
http.ResponseWriter
header http.Header
body *bytes.Buffer
statusCode int
wroteHeader bool
}

func (r *responseRecorder) Header() http.Header {
return r.header
}

func (r *responseRecorder) WriteHeader(code int) {
r.statusCode = code
r.wroteHeader = true
Expand All @@ -277,12 +287,23 @@ func (r *responseRecorder) Write(b []byte) (int, error) {

// Flush implements http.Flusher for streaming support.
func (r *responseRecorder) Flush() {
if f, ok := r.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
// Intentionally do nothing. The middleware buffers the full upstream response
// before deciding whether to inject, so flushing here would prematurely commit
// headers/body to the client.
}

// ReadFrom implements io.ReaderFrom for efficient copies.
func (r *responseRecorder) ReadFrom(src io.Reader) (int64, error) {
return r.body.ReadFrom(src)
}

func copyHeaders(dst, src http.Header) {
for k := range dst {
dst.Del(k)
}
for k, values := range src {
for _, value := range values {
dst.Add(k, value)
}
}
}
51 changes: 51 additions & 0 deletions gateway/internal/inject/inject_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,57 @@ func TestMiddleware_ContentLengthUpdated(t *testing.T) {
_ = expectedLen // The header value is set by the middleware.
}

func TestMiddleware_UpstreamFlushDoesNotCommitEarly(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusCreated)
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
_, _ = w.Write([]byte(`<html><head></head><body>flushed</body></html>`))
})

m := New(upstream, testScriptTag, slog.Default())

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

m.ServeHTTP(rec, req)

if rec.Code != http.StatusCreated {
t.Fatalf("expected status %d, got %d", http.StatusCreated, rec.Code)
}
body := rec.Body.String()
if !strings.Contains(body, testScriptTag) {
t.Fatal("expected injected script tag in flushed HTML response")
}
if !strings.Contains(body, "flushed") {
t.Fatal("expected original body content to be preserved after flush")
}
}

func TestMiddleware_BuffersHeadersUntilWriteback(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.Header().Set("X-REP-Test", "buffered")
_, _ = w.Write([]byte(`<html><head></head><body></body></html>`))
})

m := New(upstream, testScriptTag, slog.Default())

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

m.ServeHTTP(rec, req)

if got := rec.Header().Get("X-REP-Test"); got != "buffered" {
t.Fatalf("expected buffered header to be copied back, got %q", got)
}
if got := rec.Header().Get("Content-Type"); got != "text/html" {
t.Fatalf("expected content type to survive buffering, got %q", got)
}
}

func TestIsHTML(t *testing.T) {
tests := []struct {
ct string
Expand Down
Loading