testing.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479
  1. import contextlib
  2. import io
  3. import os
  4. import shlex
  5. import shutil
  6. import sys
  7. import tempfile
  8. import typing as t
  9. from types import TracebackType
  10. from . import formatting
  11. from . import termui
  12. from . import utils
  13. from ._compat import _find_binary_reader
  14. if t.TYPE_CHECKING:
  15. from .core import BaseCommand
  16. class EchoingStdin:
  17. def __init__(self, input: t.BinaryIO, output: t.BinaryIO) -> None:
  18. self._input = input
  19. self._output = output
  20. self._paused = False
  21. def __getattr__(self, x: str) -> t.Any:
  22. return getattr(self._input, x)
  23. def _echo(self, rv: bytes) -> bytes:
  24. if not self._paused:
  25. self._output.write(rv)
  26. return rv
  27. def read(self, n: int = -1) -> bytes:
  28. return self._echo(self._input.read(n))
  29. def read1(self, n: int = -1) -> bytes:
  30. return self._echo(self._input.read1(n)) # type: ignore
  31. def readline(self, n: int = -1) -> bytes:
  32. return self._echo(self._input.readline(n))
  33. def readlines(self) -> t.List[bytes]:
  34. return [self._echo(x) for x in self._input.readlines()]
  35. def __iter__(self) -> t.Iterator[bytes]:
  36. return iter(self._echo(x) for x in self._input)
  37. def __repr__(self) -> str:
  38. return repr(self._input)
  39. @contextlib.contextmanager
  40. def _pause_echo(stream: t.Optional[EchoingStdin]) -> t.Iterator[None]:
  41. if stream is None:
  42. yield
  43. else:
  44. stream._paused = True
  45. yield
  46. stream._paused = False
  47. class _NamedTextIOWrapper(io.TextIOWrapper):
  48. def __init__(
  49. self, buffer: t.BinaryIO, name: str, mode: str, **kwargs: t.Any
  50. ) -> None:
  51. super().__init__(buffer, **kwargs)
  52. self._name = name
  53. self._mode = mode
  54. @property
  55. def name(self) -> str:
  56. return self._name
  57. @property
  58. def mode(self) -> str:
  59. return self._mode
  60. def make_input_stream(
  61. input: t.Optional[t.Union[str, bytes, t.IO]], charset: str
  62. ) -> t.BinaryIO:
  63. # Is already an input stream.
  64. if hasattr(input, "read"):
  65. rv = _find_binary_reader(t.cast(t.IO, input))
  66. if rv is not None:
  67. return rv
  68. raise TypeError("Could not find binary reader for input stream.")
  69. if input is None:
  70. input = b""
  71. elif isinstance(input, str):
  72. input = input.encode(charset)
  73. return io.BytesIO(t.cast(bytes, input))
  74. class Result:
  75. """Holds the captured result of an invoked CLI script."""
  76. def __init__(
  77. self,
  78. runner: "CliRunner",
  79. stdout_bytes: bytes,
  80. stderr_bytes: t.Optional[bytes],
  81. return_value: t.Any,
  82. exit_code: int,
  83. exception: t.Optional[BaseException],
  84. exc_info: t.Optional[
  85. t.Tuple[t.Type[BaseException], BaseException, TracebackType]
  86. ] = None,
  87. ):
  88. #: The runner that created the result
  89. self.runner = runner
  90. #: The standard output as bytes.
  91. self.stdout_bytes = stdout_bytes
  92. #: The standard error as bytes, or None if not available
  93. self.stderr_bytes = stderr_bytes
  94. #: The value returned from the invoked command.
  95. #:
  96. #: .. versionadded:: 8.0
  97. self.return_value = return_value
  98. #: The exit code as integer.
  99. self.exit_code = exit_code
  100. #: The exception that happened if one did.
  101. self.exception = exception
  102. #: The traceback
  103. self.exc_info = exc_info
  104. @property
  105. def output(self) -> str:
  106. """The (standard) output as unicode string."""
  107. return self.stdout
  108. @property
  109. def stdout(self) -> str:
  110. """The standard output as unicode string."""
  111. return self.stdout_bytes.decode(self.runner.charset, "replace").replace(
  112. "\r\n", "\n"
  113. )
  114. @property
  115. def stderr(self) -> str:
  116. """The standard error as unicode string."""
  117. if self.stderr_bytes is None:
  118. raise ValueError("stderr not separately captured")
  119. return self.stderr_bytes.decode(self.runner.charset, "replace").replace(
  120. "\r\n", "\n"
  121. )
  122. def __repr__(self) -> str:
  123. exc_str = repr(self.exception) if self.exception else "okay"
  124. return f"<{type(self).__name__} {exc_str}>"
  125. class CliRunner:
  126. """The CLI runner provides functionality to invoke a Click command line
  127. script for unittesting purposes in a isolated environment. This only
  128. works in single-threaded systems without any concurrency as it changes the
  129. global interpreter state.
  130. :param charset: the character set for the input and output data.
  131. :param env: a dictionary with environment variables for overriding.
  132. :param echo_stdin: if this is set to `True`, then reading from stdin writes
  133. to stdout. This is useful for showing examples in
  134. some circumstances. Note that regular prompts
  135. will automatically echo the input.
  136. :param mix_stderr: if this is set to `False`, then stdout and stderr are
  137. preserved as independent streams. This is useful for
  138. Unix-philosophy apps that have predictable stdout and
  139. noisy stderr, such that each may be measured
  140. independently
  141. """
  142. def __init__(
  143. self,
  144. charset: str = "utf-8",
  145. env: t.Optional[t.Mapping[str, t.Optional[str]]] = None,
  146. echo_stdin: bool = False,
  147. mix_stderr: bool = True,
  148. ) -> None:
  149. self.charset = charset
  150. self.env = env or {}
  151. self.echo_stdin = echo_stdin
  152. self.mix_stderr = mix_stderr
  153. def get_default_prog_name(self, cli: "BaseCommand") -> str:
  154. """Given a command object it will return the default program name
  155. for it. The default is the `name` attribute or ``"root"`` if not
  156. set.
  157. """
  158. return cli.name or "root"
  159. def make_env(
  160. self, overrides: t.Optional[t.Mapping[str, t.Optional[str]]] = None
  161. ) -> t.Mapping[str, t.Optional[str]]:
  162. """Returns the environment overrides for invoking a script."""
  163. rv = dict(self.env)
  164. if overrides:
  165. rv.update(overrides)
  166. return rv
  167. @contextlib.contextmanager
  168. def isolation(
  169. self,
  170. input: t.Optional[t.Union[str, bytes, t.IO]] = None,
  171. env: t.Optional[t.Mapping[str, t.Optional[str]]] = None,
  172. color: bool = False,
  173. ) -> t.Iterator[t.Tuple[io.BytesIO, t.Optional[io.BytesIO]]]:
  174. """A context manager that sets up the isolation for invoking of a
  175. command line tool. This sets up stdin with the given input data
  176. and `os.environ` with the overrides from the given dictionary.
  177. This also rebinds some internals in Click to be mocked (like the
  178. prompt functionality).
  179. This is automatically done in the :meth:`invoke` method.
  180. :param input: the input stream to put into sys.stdin.
  181. :param env: the environment overrides as dictionary.
  182. :param color: whether the output should contain color codes. The
  183. application can still override this explicitly.
  184. .. versionchanged:: 8.0
  185. ``stderr`` is opened with ``errors="backslashreplace"``
  186. instead of the default ``"strict"``.
  187. .. versionchanged:: 4.0
  188. Added the ``color`` parameter.
  189. """
  190. bytes_input = make_input_stream(input, self.charset)
  191. echo_input = None
  192. old_stdin = sys.stdin
  193. old_stdout = sys.stdout
  194. old_stderr = sys.stderr
  195. old_forced_width = formatting.FORCED_WIDTH
  196. formatting.FORCED_WIDTH = 80
  197. env = self.make_env(env)
  198. bytes_output = io.BytesIO()
  199. if self.echo_stdin:
  200. bytes_input = echo_input = t.cast(
  201. t.BinaryIO, EchoingStdin(bytes_input, bytes_output)
  202. )
  203. sys.stdin = text_input = _NamedTextIOWrapper(
  204. bytes_input, encoding=self.charset, name="<stdin>", mode="r"
  205. )
  206. if self.echo_stdin:
  207. # Force unbuffered reads, otherwise TextIOWrapper reads a
  208. # large chunk which is echoed early.
  209. text_input._CHUNK_SIZE = 1 # type: ignore
  210. sys.stdout = _NamedTextIOWrapper(
  211. bytes_output, encoding=self.charset, name="<stdout>", mode="w"
  212. )
  213. bytes_error = None
  214. if self.mix_stderr:
  215. sys.stderr = sys.stdout
  216. else:
  217. bytes_error = io.BytesIO()
  218. sys.stderr = _NamedTextIOWrapper(
  219. bytes_error,
  220. encoding=self.charset,
  221. name="<stderr>",
  222. mode="w",
  223. errors="backslashreplace",
  224. )
  225. @_pause_echo(echo_input) # type: ignore
  226. def visible_input(prompt: t.Optional[str] = None) -> str:
  227. sys.stdout.write(prompt or "")
  228. val = text_input.readline().rstrip("\r\n")
  229. sys.stdout.write(f"{val}\n")
  230. sys.stdout.flush()
  231. return val
  232. @_pause_echo(echo_input) # type: ignore
  233. def hidden_input(prompt: t.Optional[str] = None) -> str:
  234. sys.stdout.write(f"{prompt or ''}\n")
  235. sys.stdout.flush()
  236. return text_input.readline().rstrip("\r\n")
  237. @_pause_echo(echo_input) # type: ignore
  238. def _getchar(echo: bool) -> str:
  239. char = sys.stdin.read(1)
  240. if echo:
  241. sys.stdout.write(char)
  242. sys.stdout.flush()
  243. return char
  244. default_color = color
  245. def should_strip_ansi(
  246. stream: t.Optional[t.IO] = None, color: t.Optional[bool] = None
  247. ) -> bool:
  248. if color is None:
  249. return not default_color
  250. return not color
  251. old_visible_prompt_func = termui.visible_prompt_func
  252. old_hidden_prompt_func = termui.hidden_prompt_func
  253. old__getchar_func = termui._getchar
  254. old_should_strip_ansi = utils.should_strip_ansi # type: ignore
  255. termui.visible_prompt_func = visible_input
  256. termui.hidden_prompt_func = hidden_input
  257. termui._getchar = _getchar
  258. utils.should_strip_ansi = should_strip_ansi # type: ignore
  259. old_env = {}
  260. try:
  261. for key, value in env.items():
  262. old_env[key] = os.environ.get(key)
  263. if value is None:
  264. try:
  265. del os.environ[key]
  266. except Exception:
  267. pass
  268. else:
  269. os.environ[key] = value
  270. yield (bytes_output, bytes_error)
  271. finally:
  272. for key, value in old_env.items():
  273. if value is None:
  274. try:
  275. del os.environ[key]
  276. except Exception:
  277. pass
  278. else:
  279. os.environ[key] = value
  280. sys.stdout = old_stdout
  281. sys.stderr = old_stderr
  282. sys.stdin = old_stdin
  283. termui.visible_prompt_func = old_visible_prompt_func
  284. termui.hidden_prompt_func = old_hidden_prompt_func
  285. termui._getchar = old__getchar_func
  286. utils.should_strip_ansi = old_should_strip_ansi # type: ignore
  287. formatting.FORCED_WIDTH = old_forced_width
  288. def invoke(
  289. self,
  290. cli: "BaseCommand",
  291. args: t.Optional[t.Union[str, t.Sequence[str]]] = None,
  292. input: t.Optional[t.Union[str, bytes, t.IO]] = None,
  293. env: t.Optional[t.Mapping[str, t.Optional[str]]] = None,
  294. catch_exceptions: bool = True,
  295. color: bool = False,
  296. **extra: t.Any,
  297. ) -> Result:
  298. """Invokes a command in an isolated environment. The arguments are
  299. forwarded directly to the command line script, the `extra` keyword
  300. arguments are passed to the :meth:`~clickpkg.Command.main` function of
  301. the command.
  302. This returns a :class:`Result` object.
  303. :param cli: the command to invoke
  304. :param args: the arguments to invoke. It may be given as an iterable
  305. or a string. When given as string it will be interpreted
  306. as a Unix shell command. More details at
  307. :func:`shlex.split`.
  308. :param input: the input data for `sys.stdin`.
  309. :param env: the environment overrides.
  310. :param catch_exceptions: Whether to catch any other exceptions than
  311. ``SystemExit``.
  312. :param extra: the keyword arguments to pass to :meth:`main`.
  313. :param color: whether the output should contain color codes. The
  314. application can still override this explicitly.
  315. .. versionchanged:: 8.0
  316. The result object has the ``return_value`` attribute with
  317. the value returned from the invoked command.
  318. .. versionchanged:: 4.0
  319. Added the ``color`` parameter.
  320. .. versionchanged:: 3.0
  321. Added the ``catch_exceptions`` parameter.
  322. .. versionchanged:: 3.0
  323. The result object has the ``exc_info`` attribute with the
  324. traceback if available.
  325. """
  326. exc_info = None
  327. with self.isolation(input=input, env=env, color=color) as outstreams:
  328. return_value = None
  329. exception: t.Optional[BaseException] = None
  330. exit_code = 0
  331. if isinstance(args, str):
  332. args = shlex.split(args)
  333. try:
  334. prog_name = extra.pop("prog_name")
  335. except KeyError:
  336. prog_name = self.get_default_prog_name(cli)
  337. try:
  338. return_value = cli.main(args=args or (), prog_name=prog_name, **extra)
  339. except SystemExit as e:
  340. exc_info = sys.exc_info()
  341. e_code = t.cast(t.Optional[t.Union[int, t.Any]], e.code)
  342. if e_code is None:
  343. e_code = 0
  344. if e_code != 0:
  345. exception = e
  346. if not isinstance(e_code, int):
  347. sys.stdout.write(str(e_code))
  348. sys.stdout.write("\n")
  349. e_code = 1
  350. exit_code = e_code
  351. except Exception as e:
  352. if not catch_exceptions:
  353. raise
  354. exception = e
  355. exit_code = 1
  356. exc_info = sys.exc_info()
  357. finally:
  358. sys.stdout.flush()
  359. stdout = outstreams[0].getvalue()
  360. if self.mix_stderr:
  361. stderr = None
  362. else:
  363. stderr = outstreams[1].getvalue() # type: ignore
  364. return Result(
  365. runner=self,
  366. stdout_bytes=stdout,
  367. stderr_bytes=stderr,
  368. return_value=return_value,
  369. exit_code=exit_code,
  370. exception=exception,
  371. exc_info=exc_info, # type: ignore
  372. )
  373. @contextlib.contextmanager
  374. def isolated_filesystem(
  375. self, temp_dir: t.Optional[t.Union[str, os.PathLike]] = None
  376. ) -> t.Iterator[str]:
  377. """A context manager that creates a temporary directory and
  378. changes the current working directory to it. This isolates tests
  379. that affect the contents of the CWD to prevent them from
  380. interfering with each other.
  381. :param temp_dir: Create the temporary directory under this
  382. directory. If given, the created directory is not removed
  383. when exiting.
  384. .. versionchanged:: 8.0
  385. Added the ``temp_dir`` parameter.
  386. """
  387. cwd = os.getcwd()
  388. dt = tempfile.mkdtemp(dir=temp_dir) # type: ignore[type-var]
  389. os.chdir(dt)
  390. try:
  391. yield t.cast(str, dt)
  392. finally:
  393. os.chdir(cwd)
  394. if temp_dir is None:
  395. try:
  396. shutil.rmtree(dt)
  397. except OSError: # noqa: B014
  398. pass