diff --git a/.gitignore b/.gitignore index 813bc48..dd3225a 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ _cgo_gotypes.go _cgo_export.* _obj/ _test/ +.gocache/ # ─── Node / TypeScript (sdk, cli) ───────────────────────────────────────────── **/node_modules/ diff --git a/gateway/internal/inject/inject.go b/gateway/internal/inject/inject.go index d88624d..2003dc0 100644 --- a/gateway/internal/inject/inject.go +++ b/gateway/internal/inject/inject.go @@ -65,6 +65,7 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { rec := &responseRecorder{ ResponseWriter: w, body: &bytes.Buffer{}, + header: make(http.Header), statusCode: http.StatusOK, } @@ -72,9 +73,10 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { m.next.ServeHTTP(rec, r) // Check if the response is HTML. - contentType := rec.Header().Get("Content-Type") + contentType := rec.header.Get("Content-Type") if !isHTML(contentType) { // Not HTML — write the response as-is. + copyHeaders(w.Header(), rec.header) w.WriteHeader(rec.statusCode) if _, err := w.Write(rec.body.Bytes()); err != nil { m.logger.Debug("rep.inject.write_error", "path", r.URL.Path, "error", err) @@ -84,7 +86,7 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Decompress the body if the upstream ignored our Accept-Encoding removal. body := rec.body.Bytes() - encoding := rec.Header().Get("Content-Encoding") + encoding := rec.header.Get("Content-Encoding") if encoding != "" { decompressed, err := decompressBody(body, encoding) if err != nil { @@ -93,6 +95,7 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { "path", r.URL.Path, "reason", "unsupported Content-Encoding: "+encoding, ) + copyHeaders(w.Header(), rec.header) w.WriteHeader(rec.statusCode) if _, err := w.Write(body); err != nil { m.logger.Debug("rep.inject.write_error", "path", r.URL.Path, "error", err) @@ -111,6 +114,8 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Inject the REP script tag into the HTML. injected := injectIntoHTML(body, tag) + copyHeaders(w.Header(), rec.header) + // Update Content-Length to reflect the injected content. w.Header().Set("Content-Length", strconv.Itoa(len(injected))) @@ -260,11 +265,16 @@ func isHTML(contentType string) bool { // responseRecorder captures the upstream response for inspection. type responseRecorder struct { http.ResponseWriter + header http.Header body *bytes.Buffer statusCode int wroteHeader bool } +func (r *responseRecorder) Header() http.Header { + return r.header +} + func (r *responseRecorder) WriteHeader(code int) { r.statusCode = code r.wroteHeader = true @@ -277,12 +287,23 @@ func (r *responseRecorder) Write(b []byte) (int, error) { // Flush implements http.Flusher for streaming support. func (r *responseRecorder) Flush() { - if f, ok := r.ResponseWriter.(http.Flusher); ok { - f.Flush() - } + // Intentionally do nothing. The middleware buffers the full upstream response + // before deciding whether to inject, so flushing here would prematurely commit + // headers/body to the client. } // ReadFrom implements io.ReaderFrom for efficient copies. func (r *responseRecorder) ReadFrom(src io.Reader) (int64, error) { return r.body.ReadFrom(src) } + +func copyHeaders(dst, src http.Header) { + for k := range dst { + dst.Del(k) + } + for k, values := range src { + for _, value := range values { + dst.Add(k, value) + } + } +} diff --git a/gateway/internal/inject/inject_test.go b/gateway/internal/inject/inject_test.go index ad6ae8c..45d390a 100644 --- a/gateway/internal/inject/inject_test.go +++ b/gateway/internal/inject/inject_test.go @@ -191,6 +191,57 @@ func TestMiddleware_ContentLengthUpdated(t *testing.T) { _ = expectedLen // The header value is set by the middleware. } +func TestMiddleware_UpstreamFlushDoesNotCommitEarly(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusCreated) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + _, _ = w.Write([]byte(`flushed`)) + }) + + m := New(upstream, testScriptTag, slog.Default()) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + m.ServeHTTP(rec, req) + + if rec.Code != http.StatusCreated { + t.Fatalf("expected status %d, got %d", http.StatusCreated, rec.Code) + } + body := rec.Body.String() + if !strings.Contains(body, testScriptTag) { + t.Fatal("expected injected script tag in flushed HTML response") + } + if !strings.Contains(body, "flushed") { + t.Fatal("expected original body content to be preserved after flush") + } +} + +func TestMiddleware_BuffersHeadersUntilWriteback(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.Header().Set("X-REP-Test", "buffered") + _, _ = w.Write([]byte(``)) + }) + + m := New(upstream, testScriptTag, slog.Default()) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + m.ServeHTTP(rec, req) + + if got := rec.Header().Get("X-REP-Test"); got != "buffered" { + t.Fatalf("expected buffered header to be copied back, got %q", got) + } + if got := rec.Header().Get("Content-Type"); got != "text/html" { + t.Fatalf("expected content type to survive buffering, got %q", got) + } +} + func TestIsHTML(t *testing.T) { tests := []struct { ct string