diff --git a/internal/http3/server.go b/internal/http3/server.go
index 2501b19..b9d053a 100644
--- a/internal/http3/server.go
+++ b/internal/http3/server.go
@@ -327,6 +327,7 @@
statusCode int // Status of the response that will be sent in HEADERS frame.
statusCodeSet bool // Status of the response has been set via a call to WriteHeader.
cannotHaveBody bool // Response should not have a body (e.g. response to a HEAD request).
+ bodyLenLeft int // How much of the content body is left to be sent, set via "Content-Length" header. -1 if unknown.
}
func (rw *responseWriter) Header() http.Header {
@@ -400,15 +401,41 @@
}
rw.statusCodeSet = true
rw.statusCode = statusCode
+
+ if n, err := strconv.Atoi(rw.Header().Get("Content-Length")); err == nil {
+ rw.bodyLenLeft = n
+ } else {
+ rw.bodyLenLeft = -1 // Unknown.
+ }
}
-func (rw *responseWriter) Write(b []byte) (int, error) {
+// trimWriteLocked returns the prefix of b that fits within the remaining
+// declared Content-Length, decrementing rw.bodyLenLeft by the bytes kept, and
+// reports whether any of b was dropped. If the length is unknown
+// (rw.bodyLenLeft < 0), b is returned unchanged. Caller must hold rw.mu.
+func (rw *responseWriter) trimWriteLocked(b []byte) ([]byte, bool) {
+	if rw.bodyLenLeft < 0 {
+		// No usable Content-Length was declared; nothing to enforce.
+		return b, false
+	}
+	n := min(len(b), rw.bodyLenLeft)
+	rw.bodyLenLeft -= n
+	return b[:n], n != len(b)
+}
+
+func (rw *responseWriter) Write(b []byte) (n int, err error) {
// Calling Write implicitly calls WriteHeader(200) if WriteHeader has not
// been called before.
rw.WriteHeader(http.StatusOK)
rw.mu.Lock()
defer rw.mu.Unlock()
+ b, trimmed := rw.trimWriteLocked(b)
+ if trimmed {
+ defer func() {
+ err = http.ErrContentLength
+ }()
+ }
+
// If b fits entirely in our body buffer, save it to the buffer and return
// early so we can coalesce small writes.
// As a special case, we always want to save b to the buffer even when b is
diff --git a/internal/http3/server_test.go b/internal/http3/server_test.go
index a56581b..0318ab4 100644
--- a/internal/http3/server_test.go
+++ b/internal/http3/server_test.go
@@ -304,6 +304,75 @@
})
}
+func TestServerHandlerTrimsContentBody(t *testing.T) {
+	// Each case declares a Content-Length, has the handler write bodyLen
+	// bytes one at a time, and checks how much of the body is delivered.
+	cases := []struct {
+		name        string
+		declared    int  // Value placed in the Content-Length header.
+		invalid     bool // Replace the header with an unparsable value.
+		bodyLen     int  // Number of one-byte writes the handler performs.
+		wantTrimmed bool // Whether the final Write should report an error.
+	}{{
+		name:     "declared accurate content length",
+		declared: 100,
+		bodyLen:  100,
+	}, {
+		name:     "declared larger content length",
+		declared: 100,
+		bodyLen:  10,
+	}, {
+		name:        "declared smaller content length",
+		declared:    10,
+		bodyLen:     100,
+		wantTrimmed: true,
+	}, {
+		name:    "declared invalid content length",
+		invalid: true,
+		bodyLen: 100,
+	}}
+
+	for _, test := range cases {
+		// Only a valid declared length smaller than the body caps it.
+		wantLen := test.bodyLen
+		if !test.invalid && test.declared < wantLen {
+			wantLen = test.declared
+		}
+		synctestSubtest(t, test.name, func(t *testing.T) {
+			handler := func(w http.ResponseWriter, r *http.Request) {
+				contentLen := strconv.Itoa(test.declared)
+				if test.invalid {
+					contentLen = "not a number, should be ignored"
+				}
+				w.Header().Set("Content-Length", contentLen)
+
+				// Write one byte at a time, keeping only the final error.
+				total := 0
+				var err error
+				for i := 0; i < test.bodyLen; i++ {
+					var n int
+					n, err = w.Write([]byte("a"))
+					total += n
+				}
+				if (err != nil) != test.wantTrimmed {
+					t.Errorf("got %v error when writing response body, even though wantTrimmed is %v", err, test.wantTrimmed)
+				}
+				if total != wantLen {
+					t.Errorf("got %v bytes written by the server, want %v bytes", total, wantLen)
+				}
+			}
+			ts := newTestServer(t, http.HandlerFunc(handler))
+			conn := ts.connect()
+			conn.greet()
+
+			stream := conn.newStream(streamTypeRequest)
+			stream.writeHeaders(requestHeader(nil))
+			synctest.Wait()
+			stream.wantHeaders(nil)
+			stream.wantData(slices.Repeat([]byte("a"), wantLen))
+			stream.wantClosed("request is complete")
+		})
+	}
+}
+
func TestServerExpect100Continue(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
streamIdle := make(chan bool)