diff --git a/common/httpx/httpx.go b/common/httpx/httpx.go index 27f4d63b..c6d46c4d 100644 --- a/common/httpx/httpx.go +++ b/common/httpx/httpx.go @@ -241,13 +241,19 @@ get_response: resp.Input = req.Host resp.Headers = httpresp.Header.Clone() + // body shouldn't be read with the following status codes + // 101 - Switching Protocols => websockets don't have a readable body + // 304 - Not Modified => no body the response terminates with latest header newline + shouldSkipBodyRead := generic.EqualsAny(httpresp.StatusCode, http.StatusSwitchingProtocols, http.StatusNotModified) if h.Options.MaxResponseBodySizeToRead > 0 { httpresp.Body = io.NopCloser(io.LimitReader(httpresp.Body, h.Options.MaxResponseBodySizeToRead)) - defer func() { - _, _ = io.Copy(io.Discard, httpresp.Body) - _ = httpresp.Body.Close() - }() + if !shouldSkipBodyRead { + defer func() { + _, _ = io.Copy(io.Discard, httpresp.Body) + _ = httpresp.Body.Close() + }() + } } // httputil.DumpResponse does not handle websockets @@ -272,10 +278,7 @@ get_response: resp.Raw = string(rawResp) resp.RawHeaders = string(headers) var respbody []byte - // body shouldn't be read with the following status codes - // 101 - Switching Protocols => websockets don't have a readable body - // 304 - Not Modified => no body the response terminates with latest header newline - if !generic.EqualsAny(httpresp.StatusCode, http.StatusSwitchingProtocols, http.StatusNotModified) { + if !shouldSkipBodyRead { var err error respbody, err = io.ReadAll(io.LimitReader(httpresp.Body, h.Options.MaxResponseBodySizeToRead)) if err != nil && !shouldIgnoreBodyErrors { diff --git a/common/httpx/httpx_test.go b/common/httpx/httpx_test.go index 0dd9cbbc..a58d6101 100644 --- a/common/httpx/httpx_test.go +++ b/common/httpx/httpx_test.go @@ -3,6 +3,7 @@ package httpx import ( "net/http" "testing" + "time" "github.com/projectdiscovery/retryablehttp-go" "github.com/stretchr/testify/require" @@ -120,3 +121,60 @@ func TestDefaultProtocolKeepsRetryableHTTP2FallbackClient(t *testing.T) { require.NotNil(t, ht.client) require.NotSame(t, ht.client.HTTPClient, ht.client.HTTPClient2) } + +type blockingReadCloser struct{} + +func (*blockingReadCloser) Read([]byte) (int, error) { + select {} +} + +func (*blockingReadCloser) Close() error { + return nil +} + +type switchingProtocolsRoundTripper struct{} + +func (switchingProtocolsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return &http.Response{ + Status: "101 Switching Protocols", + StatusCode: http.StatusSwitchingProtocols, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{ + "Upgrade": {"websocket"}, + "Connection": {"Upgrade"}, + }, + Body: &blockingReadCloser{}, + Request: req, + }, nil +} + +func TestDoSwitchingProtocolsDoesNotHang(t *testing.T) { + options := DefaultOptions + options.CdnCheck = "false" + options.Timeout = 2 * time.Second + options.RetryMax = 0 + + ht, err := New(&options) + require.NoError(t, err) + + rt := switchingProtocolsRoundTripper{} + ht.client.HTTPClient.Transport = rt + ht.client.HTTPClient2.Transport = rt + + req, err := retryablehttp.NewRequest(http.MethodGet, "http://example.com", nil) + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + _, _ = ht.Do(req, UnsafeOptions{}) + close(done) + }() + + select { + case <-done: + case <-time.After(4 * time.Second): + t.Fatal("Do hung on 101 Switching Protocols response") + } +}