http_proxy.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. """
  2. Basic HTTP Proxy
  3. ================
  4. .. autoclass:: ProxyMiddleware
  5. :copyright: 2007 Pallets
  6. :license: BSD-3-Clause
  7. """
  8. import typing as t
  9. from http import client
  10. from ..datastructures import EnvironHeaders
  11. from ..http import is_hop_by_hop_header
  12. from ..urls import url_parse
  13. from ..urls import url_quote
  14. from ..wsgi import get_input_stream
  15. if t.TYPE_CHECKING:
  16. from _typeshed.wsgi import StartResponse
  17. from _typeshed.wsgi import WSGIApplication
  18. from _typeshed.wsgi import WSGIEnvironment
  19. class ProxyMiddleware:
  20. """Proxy requests under a path to an external server, routing other
  21. requests to the app.
  22. This middleware can only proxy HTTP requests, as HTTP is the only
  23. protocol handled by the WSGI server. Other protocols, such as
  24. WebSocket requests, cannot be proxied at this layer. This should
  25. only be used for development, in production a real proxy server
  26. should be used.
  27. The middleware takes a dict mapping a path prefix to a dict
  28. describing the host to be proxied to::
  29. app = ProxyMiddleware(app, {
  30. "/static/": {
  31. "target": "http://127.0.0.1:5001/",
  32. }
  33. })
  34. Each host has the following options:
  35. ``target``:
  36. The target URL to dispatch to. This is required.
  37. ``remove_prefix``:
  38. Whether to remove the prefix from the URL before dispatching it
  39. to the target. The default is ``False``.
  40. ``host``:
  41. ``"<auto>"`` (default):
  42. The host header is automatically rewritten to the URL of the
  43. target.
  44. ``None``:
  45. The host header is unmodified from the client request.
  46. Any other value:
  47. The host header is overwritten with the value.
  48. ``headers``:
  49. A dictionary of headers to be sent with the request to the
  50. target. The default is ``{}``.
  51. ``ssl_context``:
  52. A :class:`ssl.SSLContext` defining how to verify requests if the
  53. target is HTTPS. The default is ``None``.
  54. In the example above, everything under ``"/static/"`` is proxied to
  55. the server on port 5001. The host header is rewritten to the target,
  56. and the ``"/static/"`` prefix is removed from the URLs.
  57. :param app: The WSGI application to wrap.
  58. :param targets: Proxy target configurations. See description above.
  59. :param chunk_size: Size of chunks to read from input stream and
  60. write to target.
  61. :param timeout: Seconds before an operation to a target fails.
  62. .. versionadded:: 0.14
  63. """
  64. def __init__(
  65. self,
  66. app: "WSGIApplication",
  67. targets: t.Mapping[str, t.Dict[str, t.Any]],
  68. chunk_size: int = 2 << 13,
  69. timeout: int = 10,
  70. ) -> None:
  71. def _set_defaults(opts: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
  72. opts.setdefault("remove_prefix", False)
  73. opts.setdefault("host", "<auto>")
  74. opts.setdefault("headers", {})
  75. opts.setdefault("ssl_context", None)
  76. return opts
  77. self.app = app
  78. self.targets = {
  79. f"/{k.strip('/')}/": _set_defaults(v) for k, v in targets.items()
  80. }
  81. self.chunk_size = chunk_size
  82. self.timeout = timeout
  83. def proxy_to(
  84. self, opts: t.Dict[str, t.Any], path: str, prefix: str
  85. ) -> "WSGIApplication":
  86. target = url_parse(opts["target"])
  87. host = t.cast(str, target.ascii_host)
  88. def application(
  89. environ: "WSGIEnvironment", start_response: "StartResponse"
  90. ) -> t.Iterable[bytes]:
  91. headers = list(EnvironHeaders(environ).items())
  92. headers[:] = [
  93. (k, v)
  94. for k, v in headers
  95. if not is_hop_by_hop_header(k)
  96. and k.lower() not in ("content-length", "host")
  97. ]
  98. headers.append(("Connection", "close"))
  99. if opts["host"] == "<auto>":
  100. headers.append(("Host", host))
  101. elif opts["host"] is None:
  102. headers.append(("Host", environ["HTTP_HOST"]))
  103. else:
  104. headers.append(("Host", opts["host"]))
  105. headers.extend(opts["headers"].items())
  106. remote_path = path
  107. if opts["remove_prefix"]:
  108. remote_path = remote_path[len(prefix) :].lstrip("/")
  109. remote_path = f"{target.path.rstrip('/')}/{remote_path}"
  110. content_length = environ.get("CONTENT_LENGTH")
  111. chunked = False
  112. if content_length not in ("", None):
  113. headers.append(("Content-Length", content_length)) # type: ignore
  114. elif content_length is not None:
  115. headers.append(("Transfer-Encoding", "chunked"))
  116. chunked = True
  117. try:
  118. if target.scheme == "http":
  119. con = client.HTTPConnection(
  120. host, target.port or 80, timeout=self.timeout
  121. )
  122. elif target.scheme == "https":
  123. con = client.HTTPSConnection(
  124. host,
  125. target.port or 443,
  126. timeout=self.timeout,
  127. context=opts["ssl_context"],
  128. )
  129. else:
  130. raise RuntimeError(
  131. "Target scheme must be 'http' or 'https', got"
  132. f" {target.scheme!r}."
  133. )
  134. con.connect()
  135. remote_url = url_quote(remote_path)
  136. querystring = environ["QUERY_STRING"]
  137. if querystring:
  138. remote_url = f"{remote_url}?{querystring}"
  139. con.putrequest(environ["REQUEST_METHOD"], remote_url, skip_host=True)
  140. for k, v in headers:
  141. if k.lower() == "connection":
  142. v = "close"
  143. con.putheader(k, v)
  144. con.endheaders()
  145. stream = get_input_stream(environ)
  146. while True:
  147. data = stream.read(self.chunk_size)
  148. if not data:
  149. break
  150. if chunked:
  151. con.send(b"%x\r\n%s\r\n" % (len(data), data))
  152. else:
  153. con.send(data)
  154. resp = con.getresponse()
  155. except OSError:
  156. from ..exceptions import BadGateway
  157. return BadGateway()(environ, start_response)
  158. start_response(
  159. f"{resp.status} {resp.reason}",
  160. [
  161. (k.title(), v)
  162. for k, v in resp.getheaders()
  163. if not is_hop_by_hop_header(k)
  164. ],
  165. )
  166. def read() -> t.Iterator[bytes]:
  167. while True:
  168. try:
  169. data = resp.read(self.chunk_size)
  170. except OSError:
  171. break
  172. if not data:
  173. break
  174. yield data
  175. return read()
  176. return application
  177. def __call__(
  178. self, environ: "WSGIEnvironment", start_response: "StartResponse"
  179. ) -> t.Iterable[bytes]:
  180. path = environ["PATH_INFO"]
  181. app = self.app
  182. for prefix, opts in self.targets.items():
  183. if path.startswith(prefix):
  184. app = self.proxy_to(opts, path, prefix)
  185. break
  186. return app(environ, start_response)