# _reloader.py
  1. import fnmatch
  2. import os
  3. import subprocess
  4. import sys
  5. import threading
  6. import time
  7. import typing as t
  8. from itertools import chain
  9. from pathlib import PurePath
  10. from ._internal import _log
  11. # The various system prefixes where imports are found. Base values are
  12. # different when running in a virtualenv. All reloaders will ignore the
  13. # base paths (usually the system installation). The stat reloader won't
  14. # scan the virtualenv paths, it will only include modules that are
  15. # already imported.
  16. _ignore_always = tuple({sys.base_prefix, sys.base_exec_prefix})
  17. prefix = {*_ignore_always, sys.prefix, sys.exec_prefix}
  18. if hasattr(sys, "real_prefix"):
  19. # virtualenv < 20
  20. prefix.add(sys.real_prefix) # type: ignore[attr-defined]
  21. _stat_ignore_scan = tuple(prefix)
  22. del prefix
  23. _ignore_common_dirs = {
  24. "__pycache__",
  25. ".git",
  26. ".hg",
  27. ".tox",
  28. ".nox",
  29. ".pytest_cache",
  30. ".mypy_cache",
  31. }
  32. def _iter_module_paths() -> t.Iterator[str]:
  33. """Find the filesystem paths associated with imported modules."""
  34. # List is in case the value is modified by the app while updating.
  35. for module in list(sys.modules.values()):
  36. name = getattr(module, "__file__", None)
  37. if name is None or name.startswith(_ignore_always):
  38. continue
  39. while not os.path.isfile(name):
  40. # Zip file, find the base file without the module path.
  41. old = name
  42. name = os.path.dirname(name)
  43. if name == old: # skip if it was all directories somehow
  44. break
  45. else:
  46. yield name
  47. def _remove_by_pattern(paths: t.Set[str], exclude_patterns: t.Set[str]) -> None:
  48. for pattern in exclude_patterns:
  49. paths.difference_update(fnmatch.filter(paths, pattern))
  50. def _find_stat_paths(
  51. extra_files: t.Set[str], exclude_patterns: t.Set[str]
  52. ) -> t.Iterable[str]:
  53. """Find paths for the stat reloader to watch. Returns imported
  54. module files, Python files under non-system paths. Extra files and
  55. Python files under extra directories can also be scanned.
  56. System paths have to be excluded for efficiency. Non-system paths,
  57. such as a project root or ``sys.path.insert``, should be the paths
  58. of interest to the user anyway.
  59. """
  60. paths = set()
  61. for path in chain(list(sys.path), extra_files):
  62. path = os.path.abspath(path)
  63. if os.path.isfile(path):
  64. # zip file on sys.path, or extra file
  65. paths.add(path)
  66. continue
  67. parent_has_py = {os.path.dirname(path): True}
  68. for root, dirs, files in os.walk(path):
  69. # Optimizations: ignore system prefixes, __pycache__ will
  70. # have a py or pyc module at the import path, ignore some
  71. # common known dirs such as version control and tool caches.
  72. if (
  73. root.startswith(_stat_ignore_scan)
  74. or os.path.basename(root) in _ignore_common_dirs
  75. ):
  76. dirs.clear()
  77. continue
  78. has_py = False
  79. for name in files:
  80. if name.endswith((".py", ".pyc")):
  81. has_py = True
  82. paths.add(os.path.join(root, name))
  83. # Optimization: stop scanning a directory if neither it nor
  84. # its parent contained Python files.
  85. if not (has_py or parent_has_py[os.path.dirname(root)]):
  86. dirs.clear()
  87. continue
  88. parent_has_py[root] = has_py
  89. paths.update(_iter_module_paths())
  90. _remove_by_pattern(paths, exclude_patterns)
  91. return paths
  92. def _find_watchdog_paths(
  93. extra_files: t.Set[str], exclude_patterns: t.Set[str]
  94. ) -> t.Iterable[str]:
  95. """Find paths for the stat reloader to watch. Looks at the same
  96. sources as the stat reloader, but watches everything under
  97. directories instead of individual files.
  98. """
  99. dirs = set()
  100. for name in chain(list(sys.path), extra_files):
  101. name = os.path.abspath(name)
  102. if os.path.isfile(name):
  103. name = os.path.dirname(name)
  104. dirs.add(name)
  105. for name in _iter_module_paths():
  106. dirs.add(os.path.dirname(name))
  107. _remove_by_pattern(dirs, exclude_patterns)
  108. return _find_common_roots(dirs)
  109. def _find_common_roots(paths: t.Iterable[str]) -> t.Iterable[str]:
  110. root: t.Dict[str, dict] = {}
  111. for chunks in sorted((PurePath(x).parts for x in paths), key=len, reverse=True):
  112. node = root
  113. for chunk in chunks:
  114. node = node.setdefault(chunk, {})
  115. node.clear()
  116. rv = set()
  117. def _walk(node: t.Mapping[str, dict], path: t.Tuple[str, ...]) -> None:
  118. for prefix, child in node.items():
  119. _walk(child, path + (prefix,))
  120. if not node:
  121. rv.add(os.path.join(*path))
  122. _walk(root, ())
  123. return rv
  124. def _get_args_for_reloading() -> t.List[str]:
  125. """Determine how the script was executed, and return the args needed
  126. to execute it again in a new process.
  127. """
  128. rv = [sys.executable]
  129. py_script = sys.argv[0]
  130. args = sys.argv[1:]
  131. # Need to look at main module to determine how it was executed.
  132. __main__ = sys.modules["__main__"]
  133. # The value of __package__ indicates how Python was called. It may
  134. # not exist if a setuptools script is installed as an egg. It may be
  135. # set incorrectly for entry points created with pip on Windows.
  136. if getattr(__main__, "__package__", None) is None or (
  137. os.name == "nt"
  138. and __main__.__package__ == ""
  139. and not os.path.exists(py_script)
  140. and os.path.exists(f"{py_script}.exe")
  141. ):
  142. # Executed a file, like "python app.py".
  143. py_script = os.path.abspath(py_script)
  144. if os.name == "nt":
  145. # Windows entry points have ".exe" extension and should be
  146. # called directly.
  147. if not os.path.exists(py_script) and os.path.exists(f"{py_script}.exe"):
  148. py_script += ".exe"
  149. if (
  150. os.path.splitext(sys.executable)[1] == ".exe"
  151. and os.path.splitext(py_script)[1] == ".exe"
  152. ):
  153. rv.pop(0)
  154. rv.append(py_script)
  155. else:
  156. # Executed a module, like "python -m werkzeug.serving".
  157. if os.path.isfile(py_script):
  158. # Rewritten by Python from "-m script" to "/path/to/script.py".
  159. py_module = t.cast(str, __main__.__package__)
  160. name = os.path.splitext(os.path.basename(py_script))[0]
  161. if name != "__main__":
  162. py_module += f".{name}"
  163. else:
  164. # Incorrectly rewritten by pydevd debugger from "-m script" to "script".
  165. py_module = py_script
  166. rv.extend(("-m", py_module.lstrip(".")))
  167. rv.extend(args)
  168. return rv
  169. class ReloaderLoop:
  170. name = ""
  171. def __init__(
  172. self,
  173. extra_files: t.Optional[t.Iterable[str]] = None,
  174. exclude_patterns: t.Optional[t.Iterable[str]] = None,
  175. interval: t.Union[int, float] = 1,
  176. ) -> None:
  177. self.extra_files: t.Set[str] = {os.path.abspath(x) for x in extra_files or ()}
  178. self.exclude_patterns: t.Set[str] = set(exclude_patterns or ())
  179. self.interval = interval
  180. def __enter__(self) -> "ReloaderLoop":
  181. """Do any setup, then run one step of the watch to populate the
  182. initial filesystem state.
  183. """
  184. self.run_step()
  185. return self
  186. def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore
  187. """Clean up any resources associated with the reloader."""
  188. pass
  189. def run(self) -> None:
  190. """Continually run the watch step, sleeping for the configured
  191. interval after each step.
  192. """
  193. while True:
  194. self.run_step()
  195. time.sleep(self.interval)
  196. def run_step(self) -> None:
  197. """Run one step for watching the filesystem. Called once to set
  198. up initial state, then repeatedly to update it.
  199. """
  200. pass
  201. def restart_with_reloader(self) -> int:
  202. """Spawn a new Python interpreter with the same arguments as the
  203. current one, but running the reloader thread.
  204. """
  205. while True:
  206. _log("info", f" * Restarting with {self.name}")
  207. args = _get_args_for_reloading()
  208. new_environ = os.environ.copy()
  209. new_environ["WERKZEUG_RUN_MAIN"] = "true"
  210. exit_code = subprocess.call(args, env=new_environ, close_fds=False)
  211. if exit_code != 3:
  212. return exit_code
  213. def trigger_reload(self, filename: str) -> None:
  214. self.log_reload(filename)
  215. sys.exit(3)
  216. def log_reload(self, filename: str) -> None:
  217. filename = os.path.abspath(filename)
  218. _log("info", f" * Detected change in {filename!r}, reloading")
  219. class StatReloaderLoop(ReloaderLoop):
  220. name = "stat"
  221. def __enter__(self) -> ReloaderLoop:
  222. self.mtimes: t.Dict[str, float] = {}
  223. return super().__enter__()
  224. def run_step(self) -> None:
  225. for name in _find_stat_paths(self.extra_files, self.exclude_patterns):
  226. try:
  227. mtime = os.stat(name).st_mtime
  228. except OSError:
  229. continue
  230. old_time = self.mtimes.get(name)
  231. if old_time is None:
  232. self.mtimes[name] = mtime
  233. continue
  234. if mtime > old_time:
  235. self.trigger_reload(name)
  236. class WatchdogReloaderLoop(ReloaderLoop):
  237. def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
  238. from watchdog.observers import Observer
  239. from watchdog.events import PatternMatchingEventHandler
  240. super().__init__(*args, **kwargs)
  241. trigger_reload = self.trigger_reload
  242. class EventHandler(PatternMatchingEventHandler): # type: ignore
  243. def on_any_event(self, event): # type: ignore
  244. trigger_reload(event.src_path)
  245. reloader_name = Observer.__name__.lower()
  246. if reloader_name.endswith("observer"):
  247. reloader_name = reloader_name[:-8]
  248. self.name = f"watchdog ({reloader_name})"
  249. self.observer = Observer()
  250. # Extra patterns can be non-Python files, match them in addition
  251. # to all Python files in default and extra directories. Ignore
  252. # __pycache__ since a change there will always have a change to
  253. # the source file (or initial pyc file) as well. Ignore Git and
  254. # Mercurial internal changes.
  255. extra_patterns = [p for p in self.extra_files if not os.path.isdir(p)]
  256. self.event_handler = EventHandler(
  257. patterns=["*.py", "*.pyc", "*.zip", *extra_patterns],
  258. ignore_patterns=[
  259. *[f"*/{d}/*" for d in _ignore_common_dirs],
  260. *self.exclude_patterns,
  261. ],
  262. )
  263. self.should_reload = False
  264. def trigger_reload(self, filename: str) -> None:
  265. # This is called inside an event handler, which means throwing
  266. # SystemExit has no effect.
  267. # https://github.com/gorakhargosh/watchdog/issues/294
  268. self.should_reload = True
  269. self.log_reload(filename)
  270. def __enter__(self) -> ReloaderLoop:
  271. self.watches: t.Dict[str, t.Any] = {}
  272. self.observer.start()
  273. return super().__enter__()
  274. def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore
  275. self.observer.stop()
  276. self.observer.join()
  277. def run(self) -> None:
  278. while not self.should_reload:
  279. self.run_step()
  280. time.sleep(self.interval)
  281. sys.exit(3)
  282. def run_step(self) -> None:
  283. to_delete = set(self.watches)
  284. for path in _find_watchdog_paths(self.extra_files, self.exclude_patterns):
  285. if path not in self.watches:
  286. try:
  287. self.watches[path] = self.observer.schedule(
  288. self.event_handler, path, recursive=True
  289. )
  290. except OSError:
  291. # Clear this path from list of watches We don't want
  292. # the same error message showing again in the next
  293. # iteration.
  294. self.watches[path] = None
  295. to_delete.discard(path)
  296. for path in to_delete:
  297. watch = self.watches.pop(path, None)
  298. if watch is not None:
  299. self.observer.unschedule(watch)
  300. reloader_loops: t.Dict[str, t.Type[ReloaderLoop]] = {
  301. "stat": StatReloaderLoop,
  302. "watchdog": WatchdogReloaderLoop,
  303. }
  304. try:
  305. __import__("watchdog.observers")
  306. except ImportError:
  307. reloader_loops["auto"] = reloader_loops["stat"]
  308. else:
  309. reloader_loops["auto"] = reloader_loops["watchdog"]
  310. def ensure_echo_on() -> None:
  311. """Ensure that echo mode is enabled. Some tools such as PDB disable
  312. it which causes usability issues after a reload."""
  313. # tcgetattr will fail if stdin isn't a tty
  314. if sys.stdin is None or not sys.stdin.isatty():
  315. return
  316. try:
  317. import termios
  318. except ImportError:
  319. return
  320. attributes = termios.tcgetattr(sys.stdin)
  321. if not attributes[3] & termios.ECHO:
  322. attributes[3] |= termios.ECHO
  323. termios.tcsetattr(sys.stdin, termios.TCSANOW, attributes)
  324. def run_with_reloader(
  325. main_func: t.Callable[[], None],
  326. extra_files: t.Optional[t.Iterable[str]] = None,
  327. exclude_patterns: t.Optional[t.Iterable[str]] = None,
  328. interval: t.Union[int, float] = 1,
  329. reloader_type: str = "auto",
  330. ) -> None:
  331. """Run the given function in an independent Python interpreter."""
  332. import signal
  333. signal.signal(signal.SIGTERM, lambda *args: sys.exit(0))
  334. reloader = reloader_loops[reloader_type](
  335. extra_files=extra_files, exclude_patterns=exclude_patterns, interval=interval
  336. )
  337. try:
  338. if os.environ.get("WERKZEUG_RUN_MAIN") == "true":
  339. ensure_echo_on()
  340. t = threading.Thread(target=main_func, args=())
  341. t.daemon = True
  342. # Enter the reloader to set up initial state, then start
  343. # the app thread and reloader update loop.
  344. with reloader:
  345. t.start()
  346. reloader.run()
  347. else:
  348. sys.exit(reloader.restart_with_reloader())
  349. except KeyboardInterrupt:
  350. pass