_winconsole.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
# This module is based on the excellent work by Adam Bartoš who
# provided a lot of what went into the implementation here in
# the discussion of issue1602 in the Python bug tracker.
#
# There are some general differences with regard to how this works
# compared to the original patches, as we do not need to patch
# the entire interpreter but just work in our little world of
# echo and prompt.
  9. import io
  10. import sys
  11. import time
  12. import typing as t
  13. from ctypes import byref
  14. from ctypes import c_char
  15. from ctypes import c_char_p
  16. from ctypes import c_int
  17. from ctypes import c_ssize_t
  18. from ctypes import c_ulong
  19. from ctypes import c_void_p
  20. from ctypes import POINTER
  21. from ctypes import py_object
  22. from ctypes import Structure
  23. from ctypes.wintypes import DWORD
  24. from ctypes.wintypes import HANDLE
  25. from ctypes.wintypes import LPCWSTR
  26. from ctypes.wintypes import LPWSTR
  27. from ._compat import _NonClosingTextIOWrapper
  28. assert sys.platform == "win32"
  29. import msvcrt # noqa: E402
  30. from ctypes import windll # noqa: E402
  31. from ctypes import WINFUNCTYPE # noqa: E402
  32. c_ssize_p = POINTER(c_ssize_t)
  33. kernel32 = windll.kernel32
  34. GetStdHandle = kernel32.GetStdHandle
  35. ReadConsoleW = kernel32.ReadConsoleW
  36. WriteConsoleW = kernel32.WriteConsoleW
  37. GetConsoleMode = kernel32.GetConsoleMode
  38. GetLastError = kernel32.GetLastError
  39. GetCommandLineW = WINFUNCTYPE(LPWSTR)(("GetCommandLineW", windll.kernel32))
  40. CommandLineToArgvW = WINFUNCTYPE(POINTER(LPWSTR), LPCWSTR, POINTER(c_int))(
  41. ("CommandLineToArgvW", windll.shell32)
  42. )
  43. LocalFree = WINFUNCTYPE(c_void_p, c_void_p)(("LocalFree", windll.kernel32))
  44. STDIN_HANDLE = GetStdHandle(-10)
  45. STDOUT_HANDLE = GetStdHandle(-11)
  46. STDERR_HANDLE = GetStdHandle(-12)
  47. PyBUF_SIMPLE = 0
  48. PyBUF_WRITABLE = 1
  49. ERROR_SUCCESS = 0
  50. ERROR_NOT_ENOUGH_MEMORY = 8
  51. ERROR_OPERATION_ABORTED = 995
  52. STDIN_FILENO = 0
  53. STDOUT_FILENO = 1
  54. STDERR_FILENO = 2
  55. EOF = b"\x1a"
  56. MAX_BYTES_WRITTEN = 32767
try:
    from ctypes import pythonapi
except ImportError:
    # On PyPy we cannot get buffers so our ability to operate here is
    # severely limited.
    get_buffer = None
else:

    class Py_buffer(Structure):
        # ctypes mirror of the C-level ``Py_buffer`` struct that
        # PyObject_GetBuffer fills in. ctypes lays fields out in list
        # order, so the order and types here must match the C definition.
        _fields_ = [
            ("buf", c_void_p),
            ("obj", py_object),
            ("len", c_ssize_t),
            ("itemsize", c_ssize_t),
            ("readonly", c_int),
            ("ndim", c_int),
            ("format", c_char_p),
            ("shape", c_ssize_p),
            ("strides", c_ssize_p),
            ("suboffsets", c_ssize_p),
            ("internal", c_void_p),
        ]

    # Buffer-protocol C-API functions, called through ctypes.pythonapi.
    PyObject_GetBuffer = pythonapi.PyObject_GetBuffer
    PyBuffer_Release = pythonapi.PyBuffer_Release

    def get_buffer(obj, writable=False):
        """Return a ``c_char`` array aliasing the memory of *obj*'s buffer.

        :param obj: object supporting the buffer protocol.
        :param writable: request ``PyBUF_WRITABLE`` instead of
            ``PyBUF_SIMPLE``.
        :return: a ``c_char`` array of ``buf.len`` bytes built with
            ``from_address``, so it shares memory with *obj* rather than
            copying it.
        """
        buf = Py_buffer()
        flags = PyBUF_WRITABLE if writable else PyBUF_SIMPLE
        PyObject_GetBuffer(py_object(obj), byref(buf), flags)
        try:
            buffer_type = c_char * buf.len
            return buffer_type.from_address(buf.buf)
        finally:
            # The view is released before the caller uses the returned
            # array. NOTE(review): the array therefore relies on *obj*
            # staying alive and unresized while it is in use.
            PyBuffer_Release(byref(buf))
  89. class _WindowsConsoleRawIOBase(io.RawIOBase):
  90. def __init__(self, handle):
  91. self.handle = handle
  92. def isatty(self):
  93. super().isatty()
  94. return True
  95. class _WindowsConsoleReader(_WindowsConsoleRawIOBase):
  96. def readable(self):
  97. return True
  98. def readinto(self, b):
  99. bytes_to_be_read = len(b)
  100. if not bytes_to_be_read:
  101. return 0
  102. elif bytes_to_be_read % 2:
  103. raise ValueError(
  104. "cannot read odd number of bytes from UTF-16-LE encoded console"
  105. )
  106. buffer = get_buffer(b, writable=True)
  107. code_units_to_be_read = bytes_to_be_read // 2
  108. code_units_read = c_ulong()
  109. rv = ReadConsoleW(
  110. HANDLE(self.handle),
  111. buffer,
  112. code_units_to_be_read,
  113. byref(code_units_read),
  114. None,
  115. )
  116. if GetLastError() == ERROR_OPERATION_ABORTED:
  117. # wait for KeyboardInterrupt
  118. time.sleep(0.1)
  119. if not rv:
  120. raise OSError(f"Windows error: {GetLastError()}")
  121. if buffer[0] == EOF:
  122. return 0
  123. return 2 * code_units_read.value
  124. class _WindowsConsoleWriter(_WindowsConsoleRawIOBase):
  125. def writable(self):
  126. return True
  127. @staticmethod
  128. def _get_error_message(errno):
  129. if errno == ERROR_SUCCESS:
  130. return "ERROR_SUCCESS"
  131. elif errno == ERROR_NOT_ENOUGH_MEMORY:
  132. return "ERROR_NOT_ENOUGH_MEMORY"
  133. return f"Windows error {errno}"
  134. def write(self, b):
  135. bytes_to_be_written = len(b)
  136. buf = get_buffer(b)
  137. code_units_to_be_written = min(bytes_to_be_written, MAX_BYTES_WRITTEN) // 2
  138. code_units_written = c_ulong()
  139. WriteConsoleW(
  140. HANDLE(self.handle),
  141. buf,
  142. code_units_to_be_written,
  143. byref(code_units_written),
  144. None,
  145. )
  146. bytes_written = 2 * code_units_written.value
  147. if bytes_written == 0 and bytes_to_be_written > 0:
  148. raise OSError(self._get_error_message(GetLastError()))
  149. return bytes_written
  150. class ConsoleStream:
  151. def __init__(self, text_stream: t.TextIO, byte_stream: t.BinaryIO) -> None:
  152. self._text_stream = text_stream
  153. self.buffer = byte_stream
  154. @property
  155. def name(self) -> str:
  156. return self.buffer.name
  157. def write(self, x: t.AnyStr) -> int:
  158. if isinstance(x, str):
  159. return self._text_stream.write(x)
  160. try:
  161. self.flush()
  162. except Exception:
  163. pass
  164. return self.buffer.write(x)
  165. def writelines(self, lines: t.Iterable[t.AnyStr]) -> None:
  166. for line in lines:
  167. self.write(line)
  168. def __getattr__(self, name: str) -> t.Any:
  169. return getattr(self._text_stream, name)
  170. def isatty(self) -> bool:
  171. return self.buffer.isatty()
  172. def __repr__(self):
  173. return f"<ConsoleStream name={self.name!r} encoding={self.encoding!r}>"
  174. def _get_text_stdin(buffer_stream: t.BinaryIO) -> t.TextIO:
  175. text_stream = _NonClosingTextIOWrapper(
  176. io.BufferedReader(_WindowsConsoleReader(STDIN_HANDLE)),
  177. "utf-16-le",
  178. "strict",
  179. line_buffering=True,
  180. )
  181. return t.cast(t.TextIO, ConsoleStream(text_stream, buffer_stream))
  182. def _get_text_stdout(buffer_stream: t.BinaryIO) -> t.TextIO:
  183. text_stream = _NonClosingTextIOWrapper(
  184. io.BufferedWriter(_WindowsConsoleWriter(STDOUT_HANDLE)),
  185. "utf-16-le",
  186. "strict",
  187. line_buffering=True,
  188. )
  189. return t.cast(t.TextIO, ConsoleStream(text_stream, buffer_stream))
  190. def _get_text_stderr(buffer_stream: t.BinaryIO) -> t.TextIO:
  191. text_stream = _NonClosingTextIOWrapper(
  192. io.BufferedWriter(_WindowsConsoleWriter(STDERR_HANDLE)),
  193. "utf-16-le",
  194. "strict",
  195. line_buffering=True,
  196. )
  197. return t.cast(t.TextIO, ConsoleStream(text_stream, buffer_stream))
  198. _stream_factories: t.Mapping[int, t.Callable[[t.BinaryIO], t.TextIO]] = {
  199. 0: _get_text_stdin,
  200. 1: _get_text_stdout,
  201. 2: _get_text_stderr,
  202. }
  203. def _is_console(f: t.TextIO) -> bool:
  204. if not hasattr(f, "fileno"):
  205. return False
  206. try:
  207. fileno = f.fileno()
  208. except (OSError, io.UnsupportedOperation):
  209. return False
  210. handle = msvcrt.get_osfhandle(fileno)
  211. return bool(GetConsoleMode(handle, byref(DWORD())))
  212. def _get_windows_console_stream(
  213. f: t.TextIO, encoding: t.Optional[str], errors: t.Optional[str]
  214. ) -> t.Optional[t.TextIO]:
  215. if (
  216. get_buffer is not None
  217. and encoding in {"utf-16-le", None}
  218. and errors in {"strict", None}
  219. and _is_console(f)
  220. ):
  221. func = _stream_factories.get(f.fileno())
  222. if func is not None:
  223. b = getattr(f, "buffer", None)
  224. if b is None:
  225. return None
  226. return func(b)