diff --git a/internal/http3/roundtrip.go b/internal/http3/roundtrip.go
index 58253d0..7d05686 100644
--- a/internal/http3/roundtrip.go
+++ b/internal/http3/roundtrip.go
@@ -26,7 +26,7 @@
reqBodyWriter bodyWriter
// Response.Body, provided to the caller.
- respBody bodyReader
+ respBody io.ReadCloser // a *bodyReader, or http.NoBody when the response is known to have no content
errOnce sync.Once
err error
@@ -126,6 +126,11 @@
encr.HasBody = false
go copyRequestBody(rt)
}
+ } else {
+ // If we have no body to send, close the write direction of the stream
+ // as soon as we have sent our HEADERS. That way, servers will know
+ // that there are no DATA frames incoming.
+ rt.st.stream.CloseWrite()
}
// Read the response headers.
@@ -164,8 +169,14 @@
if err != nil {
return nil, err
}
- rt.respBody.st = st
- rt.respBody.remain = contentLength
+ if contentLength != 0 && req.Method != http.MethodHead { // a response to HEAD never has content, whatever its Content-Length says
+ rt.respBody = &bodyReader{
+ st: st,
+ remain: contentLength,
+ }
+ } else { // NOTE(review): 204 and 304 responses are bodyless too (RFC 9110 §6.4.1); consider checking the status here as well
+ rt.respBody = http.NoBody
+ }
resp := &http.Response{
Proto: "HTTP/3.0",
ProtoMajor: 3,
diff --git a/internal/http3/roundtrip_test.go b/internal/http3/roundtrip_test.go
index efbe105..b6137c2 100644
--- a/internal/http3/roundtrip_test.go
+++ b/internal/http3/roundtrip_test.go
@@ -419,3 +419,63 @@
rt.wantBody(serverBody)
})
}
+
+func TestRoundTripNoBodyClosesStream(t *testing.T) { // a request with no body must close the stream's write side right after HEADERS
+ synctest.Test(t, func(t *testing.T) {
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("PUT", "https://example.tld/", nil) // nil body; error ignored because the URL is a valid constant
+ tc.roundTrip(req)
+ st := tc.wantStream(streamTypeRequest)
+
+ st.wantHeaders(nil)
+ st.wantClosed("no DATA frames to send")
+ })
+}
+
+func TestRoundTripReadRespWithNoBody(t *testing.T) { // the client must report an empty response body in each case below
+ synctest.Test(t, func(t *testing.T) {
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ // Case 1: we know response body is empty because the server closes the
+ // write direction of the stream.
+ req, _ := http.NewRequest("GET", "https://example.tld/", nil)
+ rt := tc.roundTrip(req)
+ st := tc.wantStream(streamTypeRequest)
+ st.wantHeaders(nil)
+ st.writeHeaders(http.Header{
+ ":status": {"200"},
+ })
+ st.stream.stream.CloseWrite()
+ rt.wantStatus(200)
+ rt.wantBody(make([]byte, 0))
+
+ // Case 2: we know response body is empty because the server indicates
+ // a Content-Length of 0.
+ req, _ = http.NewRequest("GET", "https://example.tld/", nil)
+ rt = tc.roundTrip(req)
+ st = tc.wantStream(streamTypeRequest)
+ st.wantHeaders(nil)
+ st.writeHeaders(http.Header{
+ ":status": {"200"},
+ "Content-Length": {"0"},
+ })
+ rt.wantStatus(200)
+ rt.wantBody(make([]byte, 0))
+
+ // Case 3: we know response body is empty because we sent a HEAD
+ // request.
+ req, _ = http.NewRequest("HEAD", "https://example.tld/", nil)
+ rt = tc.roundTrip(req)
+ st = tc.wantStream(streamTypeRequest)
+ st.wantHeaders(nil)
+ st.writeHeaders(http.Header{
+ ":status": {"200"},
+ "Content-Length": {"1000"}, // must be ignored: a HEAD response carries no content
+ })
+ rt.wantStatus(200)
+ rt.wantBody(make([]byte, 0))
+ })
+}
diff --git a/internal/http3/server.go b/internal/http3/server.go
index 5285d4c..7fe1245 100644
--- a/internal/http3/server.go
+++ b/internal/http3/server.go
@@ -6,6 +6,7 @@
import (
"context"
+ "io"
"net/http"
"strconv"
"sync"
@@ -217,6 +218,24 @@
message: reqInfo.InvalidReason,
}
}
+
+ // Content-Length 0 means no DATA frames follow, so hand the handler http.NoBody.
+ // A missing or malformed value (including any sign) leaves the length unknown (-1).
+ // ParseUint with bitSize 63 rejects signs and caps the value at math.MaxInt64.
+ var body io.ReadCloser
+ contentLength := int64(-1)
+ if n, err := strconv.ParseUint(header.Get("Content-Length"), 10, 63); err == nil {
+ contentLength = int64(n)
+ }
+ if contentLength != 0 {
+ body = &bodyReader{
+ st: st,
+ remain: contentLength,
+ }
+ } else {
+ body = http.NoBody
+ }
+
req := &http.Request{
Proto: "HTTP/3.0",
Method: pHeader.method,
@@ -226,11 +245,8 @@
Trailer: reqInfo.Trailer,
ProtoMajor: 3,
RemoteAddr: sc.qconn.RemoteAddr().String(),
- Body: &bodyReader{
- st: st,
- remain: -1,
- },
- Header: header,
+ Body: body,
+ Header: header,
}
defer req.Body.Close()
diff --git a/internal/http3/server_test.go b/internal/http3/server_test.go
index 2b6de06..805a8b3 100644
--- a/internal/http3/server_test.go
+++ b/internal/http3/server_test.go
@@ -373,6 +373,41 @@
})
}
+func TestServerHandlerReadReqWithNoBody(t *testing.T) { // the handler must read an empty request body without error, then respond normally
+ synctest.Test(t, func(t *testing.T) {
+ serverBody := []byte("hello from server!")
+ ts := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if _, err := io.ReadAll(r.Body); err != nil {
+ t.Errorf("got %v err when reading from an empty request body, want nil", err)
+ }
+ w.Write(serverBody)
+ }))
+ tc := ts.connect()
+ tc.greet()
+
+ // Case 1: we know that there is no body / DATA frame because the
+ // client closes the write direction of the stream.
+ reqStream := tc.newStream(streamTypeRequest)
+ reqStream.writeHeaders(requestHeader(nil))
+ reqStream.stream.stream.CloseWrite() // closing the write side is the only end-of-body signal in this case
+ synctest.Wait()
+ reqStream.wantHeaders(http.Header{":status": {"200"}})
+ reqStream.wantData(serverBody)
+ reqStream.wantClosed("request is complete")
+
+ // Case 2: we know that there is no body / DATA frame because the
+ // client indicates a Content-Length of 0.
+ reqStream = tc.newStream(streamTypeRequest)
+ reqStream.writeHeaders(requestHeader(http.Header{
+ "Content-Length": {"0"},
+ }))
+ synctest.Wait()
+ reqStream.wantHeaders(http.Header{":status": {"200"}})
+ reqStream.wantData(serverBody)
+ reqStream.wantClosed("request is complete")
+ })
+}
+
type testServer struct {
t testing.TB
s *Server