Skip to content

Commit 5416204

Browse files
committed
net/http: make Transport return Writable Response.Body on protocol switch
Updates #26937 Updates #17227 Change-Id: I79865938b05c219e1947822e60e4f52bb2604b70 Reviewed-on: https://go-review.googlesource.com/131279 Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
1 parent 30b080e commit 5416204

File tree

3 files changed

+126
-2
lines changed

3 files changed

+126
-2
lines changed

src/net/http/response.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"crypto/tls"
1313
"errors"
1414
"fmt"
15+
"golang_org/x/net/http/httpguts"
1516
"io"
1617
"net/textproto"
1718
"net/url"
@@ -63,6 +64,10 @@ type Response struct {
6364
//
6465
// The Body is automatically dechunked if the server replied
6566
// with a "chunked" Transfer-Encoding.
67+
//
68+
// As of Go 1.12, the Body will be also implement io.Writer
69+
// on a successful "101 Switching Protocols" responses,
70+
// as used by WebSockets and HTTP/2's "h2c" mode.
6671
Body io.ReadCloser
6772

6873
// ContentLength records the length of the associated content. The
@@ -333,3 +338,23 @@ func (r *Response) closeBody() {
333338
r.Body.Close()
334339
}
335340
}
341+
342+
// bodyIsWritable reports whether the Body supports writing. The
343+
// Transport returns Writable bodies for 101 Switching Protocols
344+
// responses.
345+
// The Transport uses this method to determine whether a persistent
346+
// connection is done being managed from its perspective. Once we
347+
// return a writable response body to a user, the net/http package is
348+
// done managing that connection.
349+
func (r *Response) bodyIsWritable() bool {
350+
_, ok := r.Body.(io.Writer)
351+
return ok
352+
}
353+
354+
// isProtocolSwitch reports whether r is a response to a successful
355+
// protocol upgrade.
356+
func (r *Response) isProtocolSwitch() bool {
357+
return r.StatusCode == StatusSwitchingProtocols &&
358+
r.Header.Get("Upgrade") != "" &&
359+
httpguts.HeaderValuesContainsToken(r.Header["Connection"], "Upgrade")
360+
}

src/net/http/transport.go

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1607,6 +1607,11 @@ func (pc *persistConn) mapRoundTripError(req *transportRequest, startBytesWritte
16071607
return err
16081608
}
16091609

1610+
// errCallerOwnsConn is an internal sentinel error used when we hand
1611+
// off a writable response.Body to the caller. We use this to prevent
1612+
// closing a net.Conn that is now owned by the caller.
1613+
var errCallerOwnsConn = errors.New("read loop ending; caller owns writable underlying conn")
1614+
16101615
func (pc *persistConn) readLoop() {
16111616
closeErr := errReadLoopExiting // default value, if not changed below
16121617
defer func() {
@@ -1681,9 +1686,10 @@ func (pc *persistConn) readLoop() {
16811686
pc.numExpectedResponses--
16821687
pc.mu.Unlock()
16831688

1689+
bodyWritable := resp.bodyIsWritable()
16841690
hasBody := rc.req.Method != "HEAD" && resp.ContentLength != 0
16851691

1686-
if resp.Close || rc.req.Close || resp.StatusCode <= 199 {
1692+
if resp.Close || rc.req.Close || resp.StatusCode <= 199 || bodyWritable {
16871693
// Don't do keep-alive on error if either party requested a close
16881694
// or we get an unexpected informational (1xx) response.
16891695
// StatusCode 100 is already handled above.
@@ -1704,6 +1710,10 @@ func (pc *persistConn) readLoop() {
17041710
pc.wroteRequest() &&
17051711
tryPutIdleConn(trace)
17061712

1713+
if bodyWritable {
1714+
closeErr = errCallerOwnsConn
1715+
}
1716+
17071717
select {
17081718
case rc.ch <- responseAndError{res: resp}:
17091719
case <-rc.callerGone:
@@ -1848,6 +1858,10 @@ func (pc *persistConn) readResponse(rc requestAndChan, trace *httptrace.ClientTr
18481858
}
18491859
break
18501860
}
1861+
if resp.isProtocolSwitch() {
1862+
resp.Body = newReadWriteCloserBody(pc.br, pc.conn)
1863+
}
1864+
18511865
resp.TLS = pc.tlsState
18521866
return
18531867
}
@@ -1874,6 +1888,38 @@ func (pc *persistConn) waitForContinue(continueCh <-chan struct{}) func() bool {
18741888
}
18751889
}
18761890

1891+
func newReadWriteCloserBody(br *bufio.Reader, rwc io.ReadWriteCloser) io.ReadWriteCloser {
1892+
body := &readWriteCloserBody{ReadWriteCloser: rwc}
1893+
if br.Buffered() != 0 {
1894+
body.br = br
1895+
}
1896+
return body
1897+
}
1898+
1899+
// readWriteCloserBody is the Response.Body type used when we want to
1900+
// give users write access to the Body through the underlying
1901+
// connection (TCP, unless using custom dialers). This is then
1902+
// the concrete type for a Response.Body on the 101 Switching
1903+
// Protocols response, as used by WebSockets, h2c, etc.
1904+
type readWriteCloserBody struct {
1905+
br *bufio.Reader // used until empty
1906+
io.ReadWriteCloser
1907+
}
1908+
1909+
func (b *readWriteCloserBody) Read(p []byte) (n int, err error) {
1910+
if b.br != nil {
1911+
if n := b.br.Buffered(); len(p) > n {
1912+
p = p[:n]
1913+
}
1914+
n, err = b.br.Read(p)
1915+
if b.br.Buffered() == 0 {
1916+
b.br = nil
1917+
}
1918+
return n, err
1919+
}
1920+
return b.ReadWriteCloser.Read(p)
1921+
}
1922+
18771923
// nothingWrittenError wraps a write errors which ended up writing zero bytes.
18781924
type nothingWrittenError struct {
18791925
error
@@ -2193,7 +2239,9 @@ func (pc *persistConn) closeLocked(err error) {
21932239
// freelist for http2. That's done by the
21942240
// alternate protocol's RoundTripper.
21952241
} else {
2196-
pc.conn.Close()
2242+
if err != errCallerOwnsConn {
2243+
pc.conn.Close()
2244+
}
21972245
close(pc.closech)
21982246
}
21992247
}

src/net/http/transport_test.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4836,3 +4836,54 @@ func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) {
48364836
t.Fatal("timeout")
48374837
}
48384838
}
4839+
4840+
func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) {
4841+
setParallel(t)
4842+
defer afterTest(t)
4843+
done := make(chan struct{})
4844+
defer close(done)
4845+
cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4846+
conn, _, err := w.(Hijacker).Hijack()
4847+
if err != nil {
4848+
t.Error(err)
4849+
return
4850+
}
4851+
defer conn.Close()
4852+
io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n")
4853+
bs := bufio.NewScanner(conn)
4854+
bs.Scan()
4855+
fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text()))
4856+
<-done
4857+
}))
4858+
defer cst.close()
4859+
4860+
req, _ := NewRequest("GET", cst.ts.URL, nil)
4861+
req.Header.Set("Upgrade", "foo")
4862+
req.Header.Set("Connection", "upgrade")
4863+
res, err := cst.c.Do(req)
4864+
if err != nil {
4865+
t.Fatal(err)
4866+
}
4867+
if res.StatusCode != 101 {
4868+
t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header)
4869+
}
4870+
rwc, ok := res.Body.(io.ReadWriteCloser)
4871+
if !ok {
4872+
t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body)
4873+
}
4874+
defer rwc.Close()
4875+
bs := bufio.NewScanner(rwc)
4876+
if !bs.Scan() {
4877+
t.Fatalf("expected readable input")
4878+
}
4879+
if got, want := bs.Text(), "Some buffered data"; got != want {
4880+
t.Errorf("read %q; want %q", got, want)
4881+
}
4882+
io.WriteString(rwc, "echo\n")
4883+
if !bs.Scan() {
4884+
t.Fatalf("expected another line")
4885+
}
4886+
if got, want := bs.Text(), "ECHO"; got != want {
4887+
t.Errorf("read %q; want %q", got, want)
4888+
}
4889+
}

0 commit comments

Comments
 (0)