reverseproxy_test.go 44 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539
  1. // Copyright 2011 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // Reverse proxy tests.
  5. package httputil
  6. import (
  7. "bufio"
  8. "bytes"
  9. "context"
  10. "errors"
  11. "fmt"
  12. "io"
  13. "log"
  14. "net/http"
  15. "net/http/httptest"
  16. "net/http/internal/ascii"
  17. "net/url"
  18. "os"
  19. "reflect"
  20. "sort"
  21. "strconv"
  22. "strings"
  23. "sync"
  24. "testing"
  25. "time"
  26. )
  27. const fakeHopHeader = "X-Fake-Hop-Header-For-Test"
  28. func init() {
  29. inOurTests = true
  30. hopHeaders = append(hopHeaders, fakeHopHeader)
  31. }
  32. func TestReverseProxy(t *testing.T) {
  33. const backendResponse = "I am the backend"
  34. const backendStatus = 404
  35. backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  36. if r.Method == "GET" && r.FormValue("mode") == "hangup" {
  37. c, _, _ := w.(http.Hijacker).Hijack()
  38. c.Close()
  39. return
  40. }
  41. if len(r.TransferEncoding) > 0 {
  42. t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding)
  43. }
  44. if r.Header.Get("X-Forwarded-For") == "" {
  45. t.Errorf("didn't get X-Forwarded-For header")
  46. }
  47. if c := r.Header.Get("Connection"); c != "" {
  48. t.Errorf("handler got Connection header value %q", c)
  49. }
  50. if c := r.Header.Get("Te"); c != "trailers" {
  51. t.Errorf("handler got Te header value %q; want 'trailers'", c)
  52. }
  53. if c := r.Header.Get("Upgrade"); c != "" {
  54. t.Errorf("handler got Upgrade header value %q", c)
  55. }
  56. if c := r.Header.Get("Proxy-Connection"); c != "" {
  57. t.Errorf("handler got Proxy-Connection header value %q", c)
  58. }
  59. if g, e := r.Host, "some-name"; g != e {
  60. t.Errorf("backend got Host header %q, want %q", g, e)
  61. }
  62. w.Header().Set("Trailers", "not a special header field name")
  63. w.Header().Set("Trailer", "X-Trailer")
  64. w.Header().Set("X-Foo", "bar")
  65. w.Header().Set("Upgrade", "foo")
  66. w.Header().Set(fakeHopHeader, "foo")
  67. w.Header().Add("X-Multi-Value", "foo")
  68. w.Header().Add("X-Multi-Value", "bar")
  69. http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"})
  70. w.WriteHeader(backendStatus)
  71. w.Write([]byte(backendResponse))
  72. w.Header().Set("X-Trailer", "trailer_value")
  73. w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
  74. }))
  75. defer backend.Close()
  76. backendURL, err := url.Parse(backend.URL)
  77. if err != nil {
  78. t.Fatal(err)
  79. }
  80. proxyHandler := NewSingleHostReverseProxy(backendURL)
  81. proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  82. frontend := httptest.NewServer(proxyHandler)
  83. defer frontend.Close()
  84. frontendClient := frontend.Client()
  85. getReq, _ := http.NewRequest("GET", frontend.URL, nil)
  86. getReq.Host = "some-name"
  87. getReq.Header.Set("Connection", "close, TE")
  88. getReq.Header.Add("Te", "foo")
  89. getReq.Header.Add("Te", "bar, trailers")
  90. getReq.Header.Set("Proxy-Connection", "should be deleted")
  91. getReq.Header.Set("Upgrade", "foo")
  92. getReq.Close = true
  93. res, err := frontendClient.Do(getReq)
  94. if err != nil {
  95. t.Fatalf("Get: %v", err)
  96. }
  97. if g, e := res.StatusCode, backendStatus; g != e {
  98. t.Errorf("got res.StatusCode %d; expected %d", g, e)
  99. }
  100. if g, e := res.Header.Get("X-Foo"), "bar"; g != e {
  101. t.Errorf("got X-Foo %q; expected %q", g, e)
  102. }
  103. if c := res.Header.Get(fakeHopHeader); c != "" {
  104. t.Errorf("got %s header value %q", fakeHopHeader, c)
  105. }
  106. if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e {
  107. t.Errorf("header Trailers = %q; want %q", g, e)
  108. }
  109. if g, e := len(res.Header["X-Multi-Value"]), 2; g != e {
  110. t.Errorf("got %d X-Multi-Value header values; expected %d", g, e)
  111. }
  112. if g, e := len(res.Header["Set-Cookie"]), 1; g != e {
  113. t.Fatalf("got %d SetCookies, want %d", g, e)
  114. }
  115. if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) {
  116. t.Errorf("before reading body, Trailer = %#v; want %#v", g, e)
  117. }
  118. if cookie := res.Cookies()[0]; cookie.Name != "flavor" {
  119. t.Errorf("unexpected cookie %q", cookie.Name)
  120. }
  121. bodyBytes, _ := io.ReadAll(res.Body)
  122. if g, e := string(bodyBytes), backendResponse; g != e {
  123. t.Errorf("got body %q; expected %q", g, e)
  124. }
  125. if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e {
  126. t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e)
  127. }
  128. if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e {
  129. t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e)
  130. }
  131. // Test that a backend failing to be reached or one which doesn't return
  132. // a response results in a StatusBadGateway.
  133. getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil)
  134. getReq.Close = true
  135. res, err = frontendClient.Do(getReq)
  136. if err != nil {
  137. t.Fatal(err)
  138. }
  139. res.Body.Close()
  140. if res.StatusCode != http.StatusBadGateway {
  141. t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status)
  142. }
  143. }
  144. // Issue 16875: remove any proxied headers mentioned in the "Connection"
  145. // header value.
  146. func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
  147. const fakeConnectionToken = "X-Fake-Connection-Token"
  148. const backendResponse = "I am the backend"
  149. // someConnHeader is some arbitrary header to be declared as a hop-by-hop header
  150. // in the Request's Connection header.
  151. const someConnHeader = "X-Some-Conn-Header"
  152. backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  153. if c := r.Header.Get("Connection"); c != "" {
  154. t.Errorf("handler got header %q = %q; want empty", "Connection", c)
  155. }
  156. if c := r.Header.Get(fakeConnectionToken); c != "" {
  157. t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
  158. }
  159. if c := r.Header.Get(someConnHeader); c != "" {
  160. t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
  161. }
  162. w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken)
  163. w.Header().Add("Connection", someConnHeader)
  164. w.Header().Set(someConnHeader, "should be deleted")
  165. w.Header().Set(fakeConnectionToken, "should be deleted")
  166. io.WriteString(w, backendResponse)
  167. }))
  168. defer backend.Close()
  169. backendURL, err := url.Parse(backend.URL)
  170. if err != nil {
  171. t.Fatal(err)
  172. }
  173. proxyHandler := NewSingleHostReverseProxy(backendURL)
  174. frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  175. proxyHandler.ServeHTTP(w, r)
  176. if c := r.Header.Get(someConnHeader); c != "should be deleted" {
  177. t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
  178. }
  179. if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" {
  180. t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted")
  181. }
  182. c := r.Header["Connection"]
  183. var cf []string
  184. for _, f := range c {
  185. for _, sf := range strings.Split(f, ",") {
  186. if sf = strings.TrimSpace(sf); sf != "" {
  187. cf = append(cf, sf)
  188. }
  189. }
  190. }
  191. sort.Strings(cf)
  192. expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken}
  193. sort.Strings(expectedValues)
  194. if !reflect.DeepEqual(cf, expectedValues) {
  195. t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues)
  196. }
  197. }))
  198. defer frontend.Close()
  199. getReq, _ := http.NewRequest("GET", frontend.URL, nil)
  200. getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken)
  201. getReq.Header.Add("Connection", someConnHeader)
  202. getReq.Header.Set(someConnHeader, "should be deleted")
  203. getReq.Header.Set(fakeConnectionToken, "should be deleted")
  204. res, err := frontend.Client().Do(getReq)
  205. if err != nil {
  206. t.Fatalf("Get: %v", err)
  207. }
  208. defer res.Body.Close()
  209. bodyBytes, err := io.ReadAll(res.Body)
  210. if err != nil {
  211. t.Fatalf("reading body: %v", err)
  212. }
  213. if got, want := string(bodyBytes), backendResponse; got != want {
  214. t.Errorf("got body %q; want %q", got, want)
  215. }
  216. if c := res.Header.Get("Connection"); c != "" {
  217. t.Errorf("handler got header %q = %q; want empty", "Connection", c)
  218. }
  219. if c := res.Header.Get(someConnHeader); c != "" {
  220. t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
  221. }
  222. if c := res.Header.Get(fakeConnectionToken); c != "" {
  223. t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
  224. }
  225. }
  226. func TestReverseProxyStripEmptyConnection(t *testing.T) {
  227. // See Issue 46313.
  228. const backendResponse = "I am the backend"
  229. // someConnHeader is some arbitrary header to be declared as a hop-by-hop header
  230. // in the Request's Connection header.
  231. const someConnHeader = "X-Some-Conn-Header"
  232. backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  233. if c := r.Header.Values("Connection"); len(c) != 0 {
  234. t.Errorf("handler got header %q = %v; want empty", "Connection", c)
  235. }
  236. if c := r.Header.Get(someConnHeader); c != "" {
  237. t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
  238. }
  239. w.Header().Add("Connection", "")
  240. w.Header().Add("Connection", someConnHeader)
  241. w.Header().Set(someConnHeader, "should be deleted")
  242. io.WriteString(w, backendResponse)
  243. }))
  244. defer backend.Close()
  245. backendURL, err := url.Parse(backend.URL)
  246. if err != nil {
  247. t.Fatal(err)
  248. }
  249. proxyHandler := NewSingleHostReverseProxy(backendURL)
  250. frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  251. proxyHandler.ServeHTTP(w, r)
  252. if c := r.Header.Get(someConnHeader); c != "should be deleted" {
  253. t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
  254. }
  255. }))
  256. defer frontend.Close()
  257. getReq, _ := http.NewRequest("GET", frontend.URL, nil)
  258. getReq.Header.Add("Connection", "")
  259. getReq.Header.Add("Connection", someConnHeader)
  260. getReq.Header.Set(someConnHeader, "should be deleted")
  261. res, err := frontend.Client().Do(getReq)
  262. if err != nil {
  263. t.Fatalf("Get: %v", err)
  264. }
  265. defer res.Body.Close()
  266. bodyBytes, err := io.ReadAll(res.Body)
  267. if err != nil {
  268. t.Fatalf("reading body: %v", err)
  269. }
  270. if got, want := string(bodyBytes), backendResponse; got != want {
  271. t.Errorf("got body %q; want %q", got, want)
  272. }
  273. if c := res.Header.Get("Connection"); c != "" {
  274. t.Errorf("handler got header %q = %q; want empty", "Connection", c)
  275. }
  276. if c := res.Header.Get(someConnHeader); c != "" {
  277. t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
  278. }
  279. }
  280. func TestXForwardedFor(t *testing.T) {
  281. const prevForwardedFor = "client ip"
  282. const backendResponse = "I am the backend"
  283. const backendStatus = 404
  284. backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  285. if r.Header.Get("X-Forwarded-For") == "" {
  286. t.Errorf("didn't get X-Forwarded-For header")
  287. }
  288. if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) {
  289. t.Errorf("X-Forwarded-For didn't contain prior data")
  290. }
  291. w.WriteHeader(backendStatus)
  292. w.Write([]byte(backendResponse))
  293. }))
  294. defer backend.Close()
  295. backendURL, err := url.Parse(backend.URL)
  296. if err != nil {
  297. t.Fatal(err)
  298. }
  299. proxyHandler := NewSingleHostReverseProxy(backendURL)
  300. frontend := httptest.NewServer(proxyHandler)
  301. defer frontend.Close()
  302. getReq, _ := http.NewRequest("GET", frontend.URL, nil)
  303. getReq.Host = "some-name"
  304. getReq.Header.Set("Connection", "close")
  305. getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
  306. getReq.Close = true
  307. res, err := frontend.Client().Do(getReq)
  308. if err != nil {
  309. t.Fatalf("Get: %v", err)
  310. }
  311. if g, e := res.StatusCode, backendStatus; g != e {
  312. t.Errorf("got res.StatusCode %d; expected %d", g, e)
  313. }
  314. bodyBytes, _ := io.ReadAll(res.Body)
  315. if g, e := string(bodyBytes), backendResponse; g != e {
  316. t.Errorf("got body %q; expected %q", g, e)
  317. }
  318. }
  319. // Issue 38079: don't append to X-Forwarded-For if it's present but nil
  320. func TestXForwardedFor_Omit(t *testing.T) {
  321. backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  322. if v := r.Header.Get("X-Forwarded-For"); v != "" {
  323. t.Errorf("got X-Forwarded-For header: %q", v)
  324. }
  325. w.Write([]byte("hi"))
  326. }))
  327. defer backend.Close()
  328. backendURL, err := url.Parse(backend.URL)
  329. if err != nil {
  330. t.Fatal(err)
  331. }
  332. proxyHandler := NewSingleHostReverseProxy(backendURL)
  333. frontend := httptest.NewServer(proxyHandler)
  334. defer frontend.Close()
  335. oldDirector := proxyHandler.Director
  336. proxyHandler.Director = func(r *http.Request) {
  337. r.Header["X-Forwarded-For"] = nil
  338. oldDirector(r)
  339. }
  340. getReq, _ := http.NewRequest("GET", frontend.URL, nil)
  341. getReq.Host = "some-name"
  342. getReq.Close = true
  343. res, err := frontend.Client().Do(getReq)
  344. if err != nil {
  345. t.Fatalf("Get: %v", err)
  346. }
  347. res.Body.Close()
  348. }
  349. var proxyQueryTests = []struct {
  350. baseSuffix string // suffix to add to backend URL
  351. reqSuffix string // suffix to add to frontend's request URL
  352. want string // what backend should see for final request URL (without ?)
  353. }{
  354. {"", "", ""},
  355. {"?sta=tic", "?us=er", "sta=tic&us=er"},
  356. {"", "?us=er", "us=er"},
  357. {"?sta=tic", "", "sta=tic"},
  358. }
  359. func TestReverseProxyQuery(t *testing.T) {
  360. backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  361. w.Header().Set("X-Got-Query", r.URL.RawQuery)
  362. w.Write([]byte("hi"))
  363. }))
  364. defer backend.Close()
  365. for i, tt := range proxyQueryTests {
  366. backendURL, err := url.Parse(backend.URL + tt.baseSuffix)
  367. if err != nil {
  368. t.Fatal(err)
  369. }
  370. frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL))
  371. req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil)
  372. req.Close = true
  373. res, err := frontend.Client().Do(req)
  374. if err != nil {
  375. t.Fatalf("%d. Get: %v", i, err)
  376. }
  377. if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e {
  378. t.Errorf("%d. got query %q; expected %q", i, g, e)
  379. }
  380. res.Body.Close()
  381. frontend.Close()
  382. }
  383. }
  384. func TestReverseProxyFlushInterval(t *testing.T) {
  385. const expected = "hi"
  386. backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  387. w.Write([]byte(expected))
  388. }))
  389. defer backend.Close()
  390. backendURL, err := url.Parse(backend.URL)
  391. if err != nil {
  392. t.Fatal(err)
  393. }
  394. proxyHandler := NewSingleHostReverseProxy(backendURL)
  395. proxyHandler.FlushInterval = time.Microsecond
  396. frontend := httptest.NewServer(proxyHandler)
  397. defer frontend.Close()
  398. req, _ := http.NewRequest("GET", frontend.URL, nil)
  399. req.Close = true
  400. res, err := frontend.Client().Do(req)
  401. if err != nil {
  402. t.Fatalf("Get: %v", err)
  403. }
  404. defer res.Body.Close()
  405. if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected {
  406. t.Errorf("got body %q; expected %q", bodyBytes, expected)
  407. }
  408. }
  409. func TestReverseProxyFlushIntervalHeaders(t *testing.T) {
  410. const expected = "hi"
  411. stopCh := make(chan struct{})
  412. backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  413. w.Header().Add("MyHeader", expected)
  414. w.WriteHeader(200)
  415. w.(http.Flusher).Flush()
  416. <-stopCh
  417. }))
  418. defer backend.Close()
  419. defer close(stopCh)
  420. backendURL, err := url.Parse(backend.URL)
  421. if err != nil {
  422. t.Fatal(err)
  423. }
  424. proxyHandler := NewSingleHostReverseProxy(backendURL)
  425. proxyHandler.FlushInterval = time.Microsecond
  426. frontend := httptest.NewServer(proxyHandler)
  427. defer frontend.Close()
  428. req, _ := http.NewRequest("GET", frontend.URL, nil)
  429. req.Close = true
  430. ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second)
  431. defer cancel()
  432. req = req.WithContext(ctx)
  433. res, err := frontend.Client().Do(req)
  434. if err != nil {
  435. t.Fatalf("Get: %v", err)
  436. }
  437. defer res.Body.Close()
  438. if res.Header.Get("MyHeader") != expected {
  439. t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected)
  440. }
  441. }
  442. func TestReverseProxyCancellation(t *testing.T) {
  443. const backendResponse = "I am the backend"
  444. reqInFlight := make(chan struct{})
  445. backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  446. close(reqInFlight) // cause the client to cancel its request
  447. select {
  448. case <-time.After(10 * time.Second):
  449. // Note: this should only happen in broken implementations, and the
  450. // closenotify case should be instantaneous.
  451. t.Error("Handler never saw CloseNotify")
  452. return
  453. case <-w.(http.CloseNotifier).CloseNotify():
  454. }
  455. w.WriteHeader(http.StatusOK)
  456. w.Write([]byte(backendResponse))
  457. }))
  458. defer backend.Close()
  459. backend.Config.ErrorLog = log.New(io.Discard, "", 0)
  460. backendURL, err := url.Parse(backend.URL)
  461. if err != nil {
  462. t.Fatal(err)
  463. }
  464. proxyHandler := NewSingleHostReverseProxy(backendURL)
  465. // Discards errors of the form:
  466. // http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection
  467. proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
  468. frontend := httptest.NewServer(proxyHandler)
  469. defer frontend.Close()
  470. frontendClient := frontend.Client()
  471. getReq, _ := http.NewRequest("GET", frontend.URL, nil)
  472. go func() {
  473. <-reqInFlight
  474. frontendClient.Transport.(*http.Transport).CancelRequest(getReq)
  475. }()
  476. res, err := frontendClient.Do(getReq)
  477. if res != nil {
  478. t.Errorf("got response %v; want nil", res.Status)
  479. }
  480. if err == nil {
  481. // This should be an error like:
  482. // Get "http://127.0.0.1:58079": read tcp 127.0.0.1:58079:
  483. // use of closed network connection
  484. t.Error("Server.Client().Do() returned nil error; want non-nil error")
  485. }
  486. }
  487. func req(t *testing.T, v string) *http.Request {
  488. req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v)))
  489. if err != nil {
  490. t.Fatal(err)
  491. }
  492. return req
  493. }
  494. // Issue 12344
  495. func TestNilBody(t *testing.T) {
  496. backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  497. w.Write([]byte("hi"))
  498. }))
  499. defer backend.Close()
  500. frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
  501. backURL, _ := url.Parse(backend.URL)
  502. rp := NewSingleHostReverseProxy(backURL)
  503. r := req(t, "GET / HTTP/1.0\r\n\r\n")
  504. r.Body = nil // this accidentally worked in Go 1.4 and below, so keep it working
  505. rp.ServeHTTP(w, r)
  506. }))
  507. defer frontend.Close()
  508. res, err := http.Get(frontend.URL)
  509. if err != nil {
  510. t.Fatal(err)
  511. }
  512. defer res.Body.Close()
  513. slurp, err := io.ReadAll(res.Body)
  514. if err != nil {
  515. t.Fatal(err)
  516. }
  517. if string(slurp) != "hi" {
  518. t.Errorf("Got %q; want %q", slurp, "hi")
  519. }
  520. }
  521. // Issue 15524
  522. func TestUserAgentHeader(t *testing.T) {
  523. const explicitUA = "explicit UA"
  524. backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  525. if r.URL.Path == "/noua" {
  526. if c := r.Header.Get("User-Agent"); c != "" {
  527. t.Errorf("handler got non-empty User-Agent header %q", c)
  528. }
  529. return
  530. }
  531. if c := r.Header.Get("User-Agent"); c != explicitUA {
  532. t.Errorf("handler got unexpected User-Agent header %q", c)
  533. }
  534. }))
  535. defer backend.Close()
  536. backendURL, err := url.Parse(backend.URL)
  537. if err != nil {
  538. t.Fatal(err)
  539. }
  540. proxyHandler := NewSingleHostReverseProxy(backendURL)
  541. proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  542. frontend := httptest.NewServer(proxyHandler)
  543. defer frontend.Close()
  544. frontendClient := frontend.Client()
  545. getReq, _ := http.NewRequest("GET", frontend.URL, nil)
  546. getReq.Header.Set("User-Agent", explicitUA)
  547. getReq.Close = true
  548. res, err := frontendClient.Do(getReq)
  549. if err != nil {
  550. t.Fatalf("Get: %v", err)
  551. }
  552. res.Body.Close()
  553. getReq, _ = http.NewRequest("GET", frontend.URL+"/noua", nil)
  554. getReq.Header.Set("User-Agent", "")
  555. getReq.Close = true
  556. res, err = frontendClient.Do(getReq)
  557. if err != nil {
  558. t.Fatalf("Get: %v", err)
  559. }
  560. res.Body.Close()
  561. }
  562. type bufferPool struct {
  563. get func() []byte
  564. put func([]byte)
  565. }
  566. func (bp bufferPool) Get() []byte { return bp.get() }
  567. func (bp bufferPool) Put(v []byte) { bp.put(v) }
  568. func TestReverseProxyGetPutBuffer(t *testing.T) {
  569. const msg = "hi"
  570. backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  571. io.WriteString(w, msg)
  572. }))
  573. defer backend.Close()
  574. backendURL, err := url.Parse(backend.URL)
  575. if err != nil {
  576. t.Fatal(err)
  577. }
  578. var (
  579. mu sync.Mutex
  580. log []string
  581. )
  582. addLog := func(event string) {
  583. mu.Lock()
  584. defer mu.Unlock()
  585. log = append(log, event)
  586. }
  587. rp := NewSingleHostReverseProxy(backendURL)
  588. const size = 1234
  589. rp.BufferPool = bufferPool{
  590. get: func() []byte {
  591. addLog("getBuf")
  592. return make([]byte, size)
  593. },
  594. put: func(p []byte) {
  595. addLog("putBuf-" + strconv.Itoa(len(p)))
  596. },
  597. }
  598. frontend := httptest.NewServer(rp)
  599. defer frontend.Close()
  600. req, _ := http.NewRequest("GET", frontend.URL, nil)
  601. req.Close = true
  602. res, err := frontend.Client().Do(req)
  603. if err != nil {
  604. t.Fatalf("Get: %v", err)
  605. }
  606. slurp, err := io.ReadAll(res.Body)
  607. res.Body.Close()
  608. if err != nil {
  609. t.Fatalf("reading body: %v", err)
  610. }
  611. if string(slurp) != msg {
  612. t.Errorf("msg = %q; want %q", slurp, msg)
  613. }
  614. wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)}
  615. mu.Lock()
  616. defer mu.Unlock()
  617. if !reflect.DeepEqual(log, wantLog) {
  618. t.Errorf("Log events = %q; want %q", log, wantLog)
  619. }
  620. }
  621. func TestReverseProxy_Post(t *testing.T) {
  622. const backendResponse = "I am the backend"
  623. const backendStatus = 200
  624. var requestBody = bytes.Repeat([]byte("a"), 1<<20)
  625. backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  626. slurp, err := io.ReadAll(r.Body)
  627. if err != nil {
  628. t.Errorf("Backend body read = %v", err)
  629. }
  630. if len(slurp) != len(requestBody) {
  631. t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
  632. }
  633. if !bytes.Equal(slurp, requestBody) {
  634. t.Error("Backend read wrong request body.") // 1MB; omitting details
  635. }
  636. w.Write([]byte(backendResponse))
  637. }))
  638. defer backend.Close()
  639. backendURL, err := url.Parse(backend.URL)
  640. if err != nil {
  641. t.Fatal(err)
  642. }
  643. proxyHandler := NewSingleHostReverseProxy(backendURL)
  644. frontend := httptest.NewServer(proxyHandler)
  645. defer frontend.Close()
  646. postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody))
  647. res, err := frontend.Client().Do(postReq)
  648. if err != nil {
  649. t.Fatalf("Do: %v", err)
  650. }
  651. if g, e := res.StatusCode, backendStatus; g != e {
  652. t.Errorf("got res.StatusCode %d; expected %d", g, e)
  653. }
  654. bodyBytes, _ := io.ReadAll(res.Body)
  655. if g, e := string(bodyBytes), backendResponse; g != e {
  656. t.Errorf("got body %q; expected %q", g, e)
  657. }
  658. }
  659. type RoundTripperFunc func(*http.Request) (*http.Response, error)
  660. func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
  661. return fn(req)
  662. }
  663. // Issue 16036: send a Request with a nil Body when possible
  664. func TestReverseProxy_NilBody(t *testing.T) {
  665. backendURL, _ := url.Parse("http://fake.tld/")
  666. proxyHandler := NewSingleHostReverseProxy(backendURL)
  667. proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  668. proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
  669. if req.Body != nil {
  670. t.Error("Body != nil; want a nil Body")
  671. }
  672. return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
  673. })
  674. frontend := httptest.NewServer(proxyHandler)
  675. defer frontend.Close()
  676. res, err := frontend.Client().Get(frontend.URL)
  677. if err != nil {
  678. t.Fatal(err)
  679. }
  680. defer res.Body.Close()
  681. if res.StatusCode != 502 {
  682. t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status)
  683. }
  684. }
  685. // Issue 33142: always allocate the request headers
  686. func TestReverseProxy_AllocatedHeader(t *testing.T) {
  687. proxyHandler := new(ReverseProxy)
  688. proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  689. proxyHandler.Director = func(*http.Request) {} // noop
  690. proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
  691. if req.Header == nil {
  692. t.Error("Header == nil; want a non-nil Header")
  693. }
  694. return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
  695. })
  696. proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{
  697. Method: "GET",
  698. URL: &url.URL{Scheme: "http", Host: "fake.tld", Path: "/"},
  699. Proto: "HTTP/1.0",
  700. ProtoMajor: 1,
  701. })
  702. }
  703. // Issue 14237. Test ModifyResponse and that an error from it
  704. // causes the proxy to return StatusBadGateway, or StatusOK otherwise.
  705. func TestReverseProxyModifyResponse(t *testing.T) {
  706. backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  707. w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod"))
  708. }))
  709. defer backendServer.Close()
  710. rpURL, _ := url.Parse(backendServer.URL)
  711. rproxy := NewSingleHostReverseProxy(rpURL)
  712. rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  713. rproxy.ModifyResponse = func(resp *http.Response) error {
  714. if resp.Header.Get("X-Hit-Mod") != "true" {
  715. return fmt.Errorf("tried to by-pass proxy")
  716. }
  717. return nil
  718. }
  719. frontendProxy := httptest.NewServer(rproxy)
  720. defer frontendProxy.Close()
  721. tests := []struct {
  722. url string
  723. wantCode int
  724. }{
  725. {frontendProxy.URL + "/mod", http.StatusOK},
  726. {frontendProxy.URL + "/schedule", http.StatusBadGateway},
  727. }
  728. for i, tt := range tests {
  729. resp, err := http.Get(tt.url)
  730. if err != nil {
  731. t.Fatalf("failed to reach proxy: %v", err)
  732. }
  733. if g, e := resp.StatusCode, tt.wantCode; g != e {
  734. t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e)
  735. }
  736. resp.Body.Close()
  737. }
  738. }
  739. type failingRoundTripper struct{}
  740. func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
  741. return nil, errors.New("some error")
  742. }
  743. type staticResponseRoundTripper struct{ res *http.Response }
  744. func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
  745. return rt.res, nil
  746. }
  747. func TestReverseProxyErrorHandler(t *testing.T) {
  748. tests := []struct {
  749. name string
  750. wantCode int
  751. errorHandler func(http.ResponseWriter, *http.Request, error)
  752. transport http.RoundTripper // defaults to failingRoundTripper
  753. modifyResponse func(*http.Response) error
  754. }{
  755. {
  756. name: "default",
  757. wantCode: http.StatusBadGateway,
  758. },
  759. {
  760. name: "errorhandler",
  761. wantCode: http.StatusTeapot,
  762. errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
  763. },
  764. {
  765. name: "modifyresponse_noerr",
  766. transport: staticResponseRoundTripper{
  767. &http.Response{StatusCode: 345, Body: http.NoBody},
  768. },
  769. modifyResponse: func(res *http.Response) error {
  770. res.StatusCode++
  771. return nil
  772. },
  773. errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
  774. wantCode: 346,
  775. },
  776. {
  777. name: "modifyresponse_err",
  778. transport: staticResponseRoundTripper{
  779. &http.Response{StatusCode: 345, Body: http.NoBody},
  780. },
  781. modifyResponse: func(res *http.Response) error {
  782. res.StatusCode++
  783. return errors.New("some error to trigger errorHandler")
  784. },
  785. errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
  786. wantCode: http.StatusTeapot,
  787. },
  788. }
  789. for _, tt := range tests {
  790. t.Run(tt.name, func(t *testing.T) {
  791. target := &url.URL{
  792. Scheme: "http",
  793. Host: "dummy.tld",
  794. Path: "/",
  795. }
  796. rproxy := NewSingleHostReverseProxy(target)
  797. rproxy.Transport = tt.transport
  798. rproxy.ModifyResponse = tt.modifyResponse
  799. if rproxy.Transport == nil {
  800. rproxy.Transport = failingRoundTripper{}
  801. }
  802. rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  803. if tt.errorHandler != nil {
  804. rproxy.ErrorHandler = tt.errorHandler
  805. }
  806. frontendProxy := httptest.NewServer(rproxy)
  807. defer frontendProxy.Close()
  808. resp, err := http.Get(frontendProxy.URL + "/test")
  809. if err != nil {
  810. t.Fatalf("failed to reach proxy: %v", err)
  811. }
  812. if g, e := resp.StatusCode, tt.wantCode; g != e {
  813. t.Errorf("got res.StatusCode %d; expected %d", g, e)
  814. }
  815. resp.Body.Close()
  816. })
  817. }
  818. }
  819. // Issue 16659: log errors from short read
  820. func TestReverseProxy_CopyBuffer(t *testing.T) {
  821. backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  822. out := "this call was relayed by the reverse proxy"
  823. // Coerce a wrong content length to induce io.UnexpectedEOF
  824. w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
  825. fmt.Fprintln(w, out)
  826. }))
  827. defer backendServer.Close()
  828. rpURL, err := url.Parse(backendServer.URL)
  829. if err != nil {
  830. t.Fatal(err)
  831. }
  832. var proxyLog bytes.Buffer
  833. rproxy := NewSingleHostReverseProxy(rpURL)
  834. rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile)
  835. donec := make(chan bool, 1)
  836. frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  837. defer func() { donec <- true }()
  838. rproxy.ServeHTTP(w, r)
  839. }))
  840. defer frontendProxy.Close()
  841. if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil {
  842. t.Fatalf("want non-nil error")
  843. }
  844. // The race detector complains about the proxyLog usage in logf in copyBuffer
  845. // and our usage below with proxyLog.Bytes() so we're explicitly using a
  846. // channel to ensure that the ReverseProxy's ServeHTTP is done before we
  847. // continue after Get.
  848. <-donec
  849. expected := []string{
  850. "EOF",
  851. "read",
  852. }
  853. for _, phrase := range expected {
  854. if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) {
  855. t.Errorf("expected log to contain phrase %q", phrase)
  856. }
  857. }
  858. }
  859. type staticTransport struct {
  860. res *http.Response
  861. }
  862. func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) {
  863. return t.res, nil
  864. }
  865. func BenchmarkServeHTTP(b *testing.B) {
  866. res := &http.Response{
  867. StatusCode: 200,
  868. Body: io.NopCloser(strings.NewReader("")),
  869. }
  870. proxy := &ReverseProxy{
  871. Director: func(*http.Request) {},
  872. Transport: &staticTransport{res},
  873. }
  874. w := httptest.NewRecorder()
  875. r := httptest.NewRequest("GET", "/", nil)
  876. b.ReportAllocs()
  877. for i := 0; i < b.N; i++ {
  878. proxy.ServeHTTP(w, r)
  879. }
  880. }
  881. func TestServeHTTPDeepCopy(t *testing.T) {
  882. backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  883. w.Write([]byte("Hello Gopher!"))
  884. }))
  885. defer backend.Close()
  886. backendURL, err := url.Parse(backend.URL)
  887. if err != nil {
  888. t.Fatal(err)
  889. }
  890. type result struct {
  891. before, after string
  892. }
  893. resultChan := make(chan result, 1)
  894. proxyHandler := NewSingleHostReverseProxy(backendURL)
  895. frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  896. before := r.URL.String()
  897. proxyHandler.ServeHTTP(w, r)
  898. after := r.URL.String()
  899. resultChan <- result{before: before, after: after}
  900. }))
  901. defer frontend.Close()
  902. want := result{before: "/", after: "/"}
  903. res, err := frontend.Client().Get(frontend.URL)
  904. if err != nil {
  905. t.Fatalf("Do: %v", err)
  906. }
  907. res.Body.Close()
  908. got := <-resultChan
  909. if got != want {
  910. t.Errorf("got = %+v; want = %+v", got, want)
  911. }
  912. }
  913. // Issue 18327: verify we always do a deep copy of the Request.Header map
  914. // before any mutations.
  915. func TestClonesRequestHeaders(t *testing.T) {
  916. log.SetOutput(io.Discard)
  917. defer log.SetOutput(os.Stderr)
  918. req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
  919. req.RemoteAddr = "1.2.3.4:56789"
  920. rp := &ReverseProxy{
  921. Director: func(req *http.Request) {
  922. req.Header.Set("From-Director", "1")
  923. },
  924. Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
  925. if v := req.Header.Get("From-Director"); v != "1" {
  926. t.Errorf("From-Directory value = %q; want 1", v)
  927. }
  928. return nil, io.EOF
  929. }),
  930. }
  931. rp.ServeHTTP(httptest.NewRecorder(), req)
  932. if req.Header.Get("From-Director") == "1" {
  933. t.Error("Director header mutation modified caller's request")
  934. }
  935. if req.Header.Get("X-Forwarded-For") != "" {
  936. t.Error("X-Forward-For header mutation modified caller's request")
  937. }
  938. }
  939. type roundTripperFunc func(req *http.Request) (*http.Response, error)
  940. func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
  941. return fn(req)
  942. }
  943. func TestModifyResponseClosesBody(t *testing.T) {
  944. req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
  945. req.RemoteAddr = "1.2.3.4:56789"
  946. closeCheck := new(checkCloser)
  947. logBuf := new(bytes.Buffer)
  948. outErr := errors.New("ModifyResponse error")
  949. rp := &ReverseProxy{
  950. Director: func(req *http.Request) {},
  951. Transport: &staticTransport{&http.Response{
  952. StatusCode: 200,
  953. Body: closeCheck,
  954. }},
  955. ErrorLog: log.New(logBuf, "", 0),
  956. ModifyResponse: func(*http.Response) error {
  957. return outErr
  958. },
  959. }
  960. rec := httptest.NewRecorder()
  961. rp.ServeHTTP(rec, req)
  962. res := rec.Result()
  963. if g, e := res.StatusCode, http.StatusBadGateway; g != e {
  964. t.Errorf("got res.StatusCode %d; expected %d", g, e)
  965. }
  966. if !closeCheck.closed {
  967. t.Errorf("body should have been closed")
  968. }
  969. if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) {
  970. t.Errorf("ErrorLog %q does not contain %q", g, e)
  971. }
  972. }
  973. type checkCloser struct {
  974. closed bool
  975. }
  976. func (cc *checkCloser) Close() error {
  977. cc.closed = true
  978. return nil
  979. }
  980. func (cc *checkCloser) Read(b []byte) (int, error) {
  981. return len(b), nil
  982. }
  983. // Issue 23643: panic on body copy error
  984. func TestReverseProxy_PanicBodyError(t *testing.T) {
  985. log.SetOutput(io.Discard)
  986. defer log.SetOutput(os.Stderr)
  987. backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  988. out := "this call was relayed by the reverse proxy"
  989. // Coerce a wrong content length to induce io.ErrUnexpectedEOF
  990. w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
  991. fmt.Fprintln(w, out)
  992. }))
  993. defer backendServer.Close()
  994. rpURL, err := url.Parse(backendServer.URL)
  995. if err != nil {
  996. t.Fatal(err)
  997. }
  998. rproxy := NewSingleHostReverseProxy(rpURL)
  999. // Ensure that the handler panics when the body read encounters an
  1000. // io.ErrUnexpectedEOF
  1001. defer func() {
  1002. err := recover()
  1003. if err == nil {
  1004. t.Fatal("handler should have panicked")
  1005. }
  1006. if err != http.ErrAbortHandler {
  1007. t.Fatal("expected ErrAbortHandler, got", err)
  1008. }
  1009. }()
  1010. req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
  1011. rproxy.ServeHTTP(httptest.NewRecorder(), req)
  1012. }
  1013. // Issue #46866: panic without closing incoming request body causes a panic
  1014. func TestReverseProxy_PanicClosesIncomingBody(t *testing.T) {
  1015. backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1016. out := "this call was relayed by the reverse proxy"
  1017. // Coerce a wrong content length to induce io.ErrUnexpectedEOF
  1018. w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
  1019. fmt.Fprintln(w, out)
  1020. }))
  1021. defer backend.Close()
  1022. backendURL, err := url.Parse(backend.URL)
  1023. if err != nil {
  1024. t.Fatal(err)
  1025. }
  1026. proxyHandler := NewSingleHostReverseProxy(backendURL)
  1027. proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1028. frontend := httptest.NewServer(proxyHandler)
  1029. defer frontend.Close()
  1030. frontendClient := frontend.Client()
  1031. var wg sync.WaitGroup
  1032. for i := 0; i < 2; i++ {
  1033. wg.Add(1)
  1034. go func() {
  1035. defer wg.Done()
  1036. for j := 0; j < 10; j++ {
  1037. const reqLen = 6 * 1024 * 1024
  1038. req, _ := http.NewRequest("POST", frontend.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
  1039. req.ContentLength = reqLen
  1040. resp, _ := frontendClient.Transport.RoundTrip(req)
  1041. if resp != nil {
  1042. io.Copy(io.Discard, resp.Body)
  1043. resp.Body.Close()
  1044. }
  1045. }
  1046. }()
  1047. }
  1048. wg.Wait()
  1049. }
  1050. func TestSelectFlushInterval(t *testing.T) {
  1051. tests := []struct {
  1052. name string
  1053. p *ReverseProxy
  1054. res *http.Response
  1055. want time.Duration
  1056. }{
  1057. {
  1058. name: "default",
  1059. res: &http.Response{},
  1060. p: &ReverseProxy{FlushInterval: 123},
  1061. want: 123,
  1062. },
  1063. {
  1064. name: "server-sent events overrides non-zero",
  1065. res: &http.Response{
  1066. Header: http.Header{
  1067. "Content-Type": {"text/event-stream"},
  1068. },
  1069. },
  1070. p: &ReverseProxy{FlushInterval: 123},
  1071. want: -1,
  1072. },
  1073. {
  1074. name: "server-sent events overrides zero",
  1075. res: &http.Response{
  1076. Header: http.Header{
  1077. "Content-Type": {"text/event-stream"},
  1078. },
  1079. },
  1080. p: &ReverseProxy{FlushInterval: 0},
  1081. want: -1,
  1082. },
  1083. {
  1084. name: "server-sent events with media-type parameters overrides non-zero",
  1085. res: &http.Response{
  1086. Header: http.Header{
  1087. "Content-Type": {"text/event-stream;charset=utf-8"},
  1088. },
  1089. },
  1090. p: &ReverseProxy{FlushInterval: 123},
  1091. want: -1,
  1092. },
  1093. {
  1094. name: "server-sent events with media-type parameters overrides zero",
  1095. res: &http.Response{
  1096. Header: http.Header{
  1097. "Content-Type": {"text/event-stream;charset=utf-8"},
  1098. },
  1099. },
  1100. p: &ReverseProxy{FlushInterval: 0},
  1101. want: -1,
  1102. },
  1103. {
  1104. name: "Content-Length: -1, overrides non-zero",
  1105. res: &http.Response{
  1106. ContentLength: -1,
  1107. },
  1108. p: &ReverseProxy{FlushInterval: 123},
  1109. want: -1,
  1110. },
  1111. {
  1112. name: "Content-Length: -1, overrides zero",
  1113. res: &http.Response{
  1114. ContentLength: -1,
  1115. },
  1116. p: &ReverseProxy{FlushInterval: 0},
  1117. want: -1,
  1118. },
  1119. }
  1120. for _, tt := range tests {
  1121. t.Run(tt.name, func(t *testing.T) {
  1122. got := tt.p.flushInterval(tt.res)
  1123. if got != tt.want {
  1124. t.Errorf("flushLatency = %v; want %v", got, tt.want)
  1125. }
  1126. })
  1127. }
  1128. }
  1129. func TestReverseProxyWebSocket(t *testing.T) {
  1130. backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1131. if upgradeType(r.Header) != "websocket" {
  1132. t.Error("unexpected backend request")
  1133. http.Error(w, "unexpected request", 400)
  1134. return
  1135. }
  1136. c, _, err := w.(http.Hijacker).Hijack()
  1137. if err != nil {
  1138. t.Error(err)
  1139. return
  1140. }
  1141. defer c.Close()
  1142. io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n")
  1143. bs := bufio.NewScanner(c)
  1144. if !bs.Scan() {
  1145. t.Errorf("backend failed to read line from client: %v", bs.Err())
  1146. return
  1147. }
  1148. fmt.Fprintf(c, "backend got %q\n", bs.Text())
  1149. }))
  1150. defer backendServer.Close()
  1151. backURL, _ := url.Parse(backendServer.URL)
  1152. rproxy := NewSingleHostReverseProxy(backURL)
  1153. rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1154. rproxy.ModifyResponse = func(res *http.Response) error {
  1155. res.Header.Add("X-Modified", "true")
  1156. return nil
  1157. }
  1158. handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
  1159. rw.Header().Set("X-Header", "X-Value")
  1160. rproxy.ServeHTTP(rw, req)
  1161. if got, want := rw.Header().Get("X-Modified"), "true"; got != want {
  1162. t.Errorf("response writer X-Modified header = %q; want %q", got, want)
  1163. }
  1164. })
  1165. frontendProxy := httptest.NewServer(handler)
  1166. defer frontendProxy.Close()
  1167. req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
  1168. req.Header.Set("Connection", "Upgrade")
  1169. req.Header.Set("Upgrade", "websocket")
  1170. c := frontendProxy.Client()
  1171. res, err := c.Do(req)
  1172. if err != nil {
  1173. t.Fatal(err)
  1174. }
  1175. if res.StatusCode != 101 {
  1176. t.Fatalf("status = %v; want 101", res.Status)
  1177. }
  1178. got := res.Header.Get("X-Header")
  1179. want := "X-Value"
  1180. if got != want {
  1181. t.Errorf("Header(XHeader) = %q; want %q", got, want)
  1182. }
  1183. if !ascii.EqualFold(upgradeType(res.Header), "websocket") {
  1184. t.Fatalf("not websocket upgrade; got %#v", res.Header)
  1185. }
  1186. rwc, ok := res.Body.(io.ReadWriteCloser)
  1187. if !ok {
  1188. t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body)
  1189. }
  1190. defer rwc.Close()
  1191. if got, want := res.Header.Get("X-Modified"), "true"; got != want {
  1192. t.Errorf("response X-Modified header = %q; want %q", got, want)
  1193. }
  1194. io.WriteString(rwc, "Hello\n")
  1195. bs := bufio.NewScanner(rwc)
  1196. if !bs.Scan() {
  1197. t.Fatalf("Scan: %v", bs.Err())
  1198. }
  1199. got = bs.Text()
  1200. want = `backend got "Hello"`
  1201. if got != want {
  1202. t.Errorf("got %#q, want %#q", got, want)
  1203. }
  1204. }
  1205. func TestReverseProxyWebSocketCancellation(t *testing.T) {
  1206. n := 5
  1207. triggerCancelCh := make(chan bool, n)
  1208. nthResponse := func(i int) string {
  1209. return fmt.Sprintf("backend response #%d\n", i)
  1210. }
  1211. terminalMsg := "final message"
  1212. cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1213. if g, ws := upgradeType(r.Header), "websocket"; g != ws {
  1214. t.Errorf("Unexpected upgrade type %q, want %q", g, ws)
  1215. http.Error(w, "Unexpected request", 400)
  1216. return
  1217. }
  1218. conn, bufrw, err := w.(http.Hijacker).Hijack()
  1219. if err != nil {
  1220. t.Error(err)
  1221. return
  1222. }
  1223. defer conn.Close()
  1224. upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n"
  1225. if _, err := io.WriteString(conn, upgradeMsg); err != nil {
  1226. t.Error(err)
  1227. return
  1228. }
  1229. if _, _, err := bufrw.ReadLine(); err != nil {
  1230. t.Errorf("Failed to read line from client: %v", err)
  1231. return
  1232. }
  1233. for i := 0; i < n; i++ {
  1234. if _, err := bufrw.WriteString(nthResponse(i)); err != nil {
  1235. select {
  1236. case <-triggerCancelCh:
  1237. default:
  1238. t.Errorf("Writing response #%d failed: %v", i, err)
  1239. }
  1240. return
  1241. }
  1242. bufrw.Flush()
  1243. time.Sleep(time.Second)
  1244. }
  1245. if _, err := bufrw.WriteString(terminalMsg); err != nil {
  1246. select {
  1247. case <-triggerCancelCh:
  1248. default:
  1249. t.Errorf("Failed to write terminal message: %v", err)
  1250. }
  1251. }
  1252. bufrw.Flush()
  1253. }))
  1254. defer cst.Close()
  1255. backendURL, _ := url.Parse(cst.URL)
  1256. rproxy := NewSingleHostReverseProxy(backendURL)
  1257. rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1258. rproxy.ModifyResponse = func(res *http.Response) error {
  1259. res.Header.Add("X-Modified", "true")
  1260. return nil
  1261. }
  1262. handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
  1263. rw.Header().Set("X-Header", "X-Value")
  1264. ctx, cancel := context.WithCancel(req.Context())
  1265. go func() {
  1266. <-triggerCancelCh
  1267. cancel()
  1268. }()
  1269. rproxy.ServeHTTP(rw, req.WithContext(ctx))
  1270. })
  1271. frontendProxy := httptest.NewServer(handler)
  1272. defer frontendProxy.Close()
  1273. req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
  1274. req.Header.Set("Connection", "Upgrade")
  1275. req.Header.Set("Upgrade", "websocket")
  1276. res, err := frontendProxy.Client().Do(req)
  1277. if err != nil {
  1278. t.Fatalf("Dialing to frontend proxy: %v", err)
  1279. }
  1280. defer res.Body.Close()
  1281. if g, w := res.StatusCode, 101; g != w {
  1282. t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w)
  1283. }
  1284. if g, w := res.Header.Get("X-Header"), "X-Value"; g != w {
  1285. t.Errorf("X-Header mismatch\n\tgot: %q\n\twant: %q", g, w)
  1286. }
  1287. if g, w := upgradeType(res.Header), "websocket"; !ascii.EqualFold(g, w) {
  1288. t.Fatalf("Upgrade header mismatch\n\tgot: %q\n\twant: %q", g, w)
  1289. }
  1290. rwc, ok := res.Body.(io.ReadWriteCloser)
  1291. if !ok {
  1292. t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body)
  1293. }
  1294. if got, want := res.Header.Get("X-Modified"), "true"; got != want {
  1295. t.Errorf("response X-Modified header = %q; want %q", got, want)
  1296. }
  1297. if _, err := io.WriteString(rwc, "Hello\n"); err != nil {
  1298. t.Fatalf("Failed to write first message: %v", err)
  1299. }
  1300. // Read loop.
  1301. br := bufio.NewReader(rwc)
  1302. for {
  1303. line, err := br.ReadString('\n')
  1304. switch {
  1305. case line == terminalMsg: // this case before "err == io.EOF"
  1306. t.Fatalf("The websocket request was not canceled, unfortunately!")
  1307. case err == io.EOF:
  1308. return
  1309. case err != nil:
  1310. t.Fatalf("Unexpected error: %v", err)
  1311. case line == nthResponse(0): // We've gotten the first response back
  1312. // Let's trigger a cancel.
  1313. close(triggerCancelCh)
  1314. }
  1315. }
  1316. }
  1317. func TestUnannouncedTrailer(t *testing.T) {
  1318. backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1319. w.WriteHeader(http.StatusOK)
  1320. w.(http.Flusher).Flush()
  1321. w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
  1322. }))
  1323. defer backend.Close()
  1324. backendURL, err := url.Parse(backend.URL)
  1325. if err != nil {
  1326. t.Fatal(err)
  1327. }
  1328. proxyHandler := NewSingleHostReverseProxy(backendURL)
  1329. proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1330. frontend := httptest.NewServer(proxyHandler)
  1331. defer frontend.Close()
  1332. frontendClient := frontend.Client()
  1333. res, err := frontendClient.Get(frontend.URL)
  1334. if err != nil {
  1335. t.Fatalf("Get: %v", err)
  1336. }
  1337. io.ReadAll(res.Body)
  1338. if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w {
  1339. t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w)
  1340. }
  1341. }
  1342. func TestSingleJoinSlash(t *testing.T) {
  1343. tests := []struct {
  1344. slasha string
  1345. slashb string
  1346. expected string
  1347. }{
  1348. {"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"},
  1349. {"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"},
  1350. {"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"},
  1351. {"https://www.google.com", "", "https://www.google.com/"},
  1352. {"", "favicon.ico", "/favicon.ico"},
  1353. }
  1354. for _, tt := range tests {
  1355. if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected {
  1356. t.Errorf("singleJoiningSlash(%q,%q) want %q got %q",
  1357. tt.slasha,
  1358. tt.slashb,
  1359. tt.expected,
  1360. got)
  1361. }
  1362. }
  1363. }
  1364. func TestJoinURLPath(t *testing.T) {
  1365. tests := []struct {
  1366. a *url.URL
  1367. b *url.URL
  1368. wantPath string
  1369. wantRaw string
  1370. }{
  1371. {&url.URL{Path: "/a/b"}, &url.URL{Path: "/c"}, "/a/b/c", ""},
  1372. {&url.URL{Path: "/a/b", RawPath: "badpath"}, &url.URL{Path: "c"}, "/a/b/c", "/a/b/c"},
  1373. {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
  1374. {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
  1375. {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb%2F"}, &url.URL{Path: "c"}, "/a/b//c", "/a%2Fb%2F/c"},
  1376. {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb/"}, &url.URL{Path: "/c/d", RawPath: "/c%2Fd"}, "/a/b/c/d", "/a%2Fb/c%2Fd"},
  1377. }
  1378. for _, tt := range tests {
  1379. p, rp := joinURLPath(tt.a, tt.b)
  1380. if p != tt.wantPath || rp != tt.wantRaw {
  1381. t.Errorf("joinURLPath(URL(%q,%q),URL(%q,%q)) want (%q,%q) got (%q,%q)",
  1382. tt.a.Path, tt.a.RawPath,
  1383. tt.b.Path, tt.b.RawPath,
  1384. tt.wantPath, tt.wantRaw,
  1385. p, rp)
  1386. }
  1387. }
  1388. }