utils.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
  1. #!/usr/bin/env python
  2. # License: GPLv3 Copyright: 2022, Kovid Goyal <kovid at kovidgoyal.net>
  3. import os
  4. import subprocess
  5. import traceback
  6. from collections.abc import Iterator, Sequence
  7. from contextlib import suppress
  8. from typing import Any
  9. from kitty.types import run_once
  10. from kitty.utils import SSHConnectionData
  11. @run_once
  12. def ssh_options() -> dict[str, str]:
  13. try:
  14. p = subprocess.run(['ssh'], stderr=subprocess.PIPE, encoding='utf-8')
  15. raw = p.stderr or ''
  16. except FileNotFoundError:
  17. return {
  18. '4': '', '6': '', 'A': '', 'a': '', 'C': '', 'f': '', 'G': '', 'g': '', 'K': '', 'k': '',
  19. 'M': '', 'N': '', 'n': '', 'q': '', 's': '', 'T': '', 't': '', 'V': '', 'v': '', 'X': '',
  20. 'x': '', 'Y': '', 'y': '', 'B': 'bind_interface', 'b': 'bind_address', 'c': 'cipher_spec',
  21. 'D': '[bind_address:]port', 'E': 'log_file', 'e': 'escape_char', 'F': 'configfile', 'I': 'pkcs11',
  22. 'i': 'identity_file', 'J': '[user@]host[:port]', 'L': 'address', 'l': 'login_name', 'm': 'mac_spec',
  23. 'O': 'ctl_cmd', 'o': 'option', 'p': 'port', 'Q': 'query_option', 'R': 'address',
  24. 'S': 'ctl_path', 'W': 'host:port', 'w': 'local_tun[:remote_tun]'
  25. }
  26. ans: dict[str, str] = {}
  27. pos = 0
  28. while True:
  29. pos = raw.find('[', pos)
  30. if pos < 0:
  31. break
  32. num = 1
  33. epos = pos
  34. while num > 0:
  35. epos += 1
  36. if raw[epos] not in '[]':
  37. continue
  38. num += 1 if raw[epos] == '[' else -1
  39. q = raw[pos+1:epos]
  40. pos = epos
  41. if len(q) < 2 or q[0] != '-':
  42. continue
  43. if ' ' in q:
  44. opt, desc = q.split(' ', 1)
  45. ans[opt[1:]] = desc
  46. else:
  47. ans.update(dict.fromkeys(q[1:], ''))
  48. return ans
  49. def is_kitten_cmdline(q: Sequence[str]) -> bool:
  50. if not q:
  51. return False
  52. exe_name = os.path.basename(q[0]).lower()
  53. if exe_name == 'kitten' and q[1:2] == ['ssh']:
  54. return True
  55. if len(q) < 4:
  56. return False
  57. if exe_name != 'kitty':
  58. return False
  59. if q[1:3] == ['+kitten', 'ssh'] or q[1:4] == ['+', 'kitten', 'ssh']:
  60. return True
  61. return q[1:3] == ['+runpy', 'from kittens.runner import main; main()'] and len(q) >= 6 and q[5] == 'ssh'
  62. def patch_cmdline(key: str, val: str, argv: list[str]) -> None:
  63. for i, arg in enumerate(tuple(argv)):
  64. if arg.startswith(f'--kitten={key}='):
  65. argv[i] = f'--kitten={key}={val}'
  66. return
  67. elif i > 0 and argv[i-1] == '--kitten' and (arg.startswith(f'{key}=') or arg.startswith(f'{key} ')):
  68. argv[i] = f'{key}={val}'
  69. return
  70. idx = argv.index('ssh')
  71. argv.insert(idx + 1, f'--kitten={key}={val}')
  72. def set_cwd_in_cmdline(cwd: str, argv: list[str]) -> None:
  73. patch_cmdline('cwd', cwd, argv)
def create_shared_memory(data: Any, prefix: str) -> str:
    """Store *data* as JSON in a new shared-memory object and return its name.

    The object is sized for the UTF-8 encoded JSON plus the length header
    written by ``write_data_with_size`` (``SharedMemory.num_bytes_for_size``
    bytes). After writing, ``shm.close`` is registered with :mod:`atexit`, and
    the name is passed to ``get_boss().atexit.shm_unlink``. Presumably that
    unlinks the object when kitty exits; TODO confirm in kitty.fast_data_types.

    :param data: a JSON-serializable value (``json.dumps`` raises TypeError
        otherwise).
    :param prefix: prefix for the generated shared-memory name.
    :return: the name of the shared-memory object, as reported by ``shm.name``.
    """
    import atexit
    import json

    from kitty.fast_data_types import get_boss
    from kitty.shm import SharedMemory
    db = json.dumps(data).encode('utf-8')
    with SharedMemory(size=len(db) + SharedMemory.num_bytes_for_size, prefix=prefix) as shm:
        shm.write_data_with_size(db)
        shm.flush()
        atexit.register(shm.close)  # keeps shm alive till exit
        get_boss().atexit.shm_unlink(shm.name)
    return shm.name
def read_data_from_shared_memory(shm_name: str) -> Any:
    """Read, validate and return the JSON value in shared-memory object *shm_name*.

    The object is opened read-only and unlinked immediately, before any
    checks, so its name cannot be opened again afterwards. The data is only
    used if the object is owned by the current effective uid and gid and its
    permission bits are exactly ``0o600`` (owner read/write).

    :raises ValueError: if the owner or permission check fails.
    :return: the decoded JSON value.
    """
    import json
    import stat

    from kitty.shm import SharedMemory
    with SharedMemory(shm_name, readonly=True) as shm:
        # Unlink first: the name is removed even when a check below fails.
        shm.unlink()
        if shm.stats.st_uid != os.geteuid() or shm.stats.st_gid != os.getegid():
            raise ValueError(f'Incorrect owner on pwfile: uid={shm.stats.st_uid} gid={shm.stats.st_gid}')
        mode = stat.S_IMODE(shm.stats.st_mode)
        # S_IREAD | S_IWRITE == 0o600: reject anything readable by others.
        if mode != stat.S_IREAD | stat.S_IWRITE:
            raise ValueError(f'Incorrect permissions on pwfile: 0o{mode:03o}')
        return json.loads(shm.read_data_with_size())
  98. def get_ssh_data(msgb: memoryview, request_id: str) -> Iterator[bytes]:
  99. from base64 import standard_b64decode
  100. yield b'\nKITTY_DATA_START\n' # to discard leading data
  101. try:
  102. msg = standard_b64decode(msgb).decode('utf-8')
  103. md = dict(x.split('=', 1) for x in msg.split(':'))
  104. pw = md['pw']
  105. pwfilename = md['pwfile']
  106. rq_id = md['id']
  107. except Exception:
  108. traceback.print_exc()
  109. yield b'invalid ssh data request message\n'
  110. else:
  111. try:
  112. env_data = read_data_from_shared_memory(pwfilename)
  113. if pw != env_data['pw']:
  114. raise ValueError('Incorrect password')
  115. if rq_id != request_id:
  116. raise ValueError(f'Incorrect request id: {rq_id!r} expecting the KITTY_PID-KITTY_WINDOW_ID for the current kitty window')
  117. except Exception as e:
  118. traceback.print_exc()
  119. yield f'{e}\n'.encode()
  120. else:
  121. yield b'OK\n'
  122. encoded_data = memoryview(env_data['tarfile'].encode('ascii'))
  123. # macOS has a 255 byte limit on its input queue as per man stty.
  124. # Not clear if that applies to canonical mode input as well, but
  125. # better to be safe.
  126. line_sz = 254
  127. while encoded_data:
  128. yield encoded_data[:line_sz]
  129. yield b'\n'
  130. encoded_data = encoded_data[line_sz:]
  131. yield b'KITTY_DATA_END\n'
  132. def set_env_in_cmdline(env: dict[str, str], argv: list[str], clone: bool = True) -> None:
  133. from kitty.options.utils import DELETE_ENV_VAR
  134. if clone:
  135. patch_cmdline('clone_env', create_shared_memory(env, 'ksse-'), argv)
  136. return
  137. idx = argv.index('ssh')
  138. for i in range(idx, len(argv)):
  139. if argv[i] == '--kitten':
  140. idx = i + 1
  141. elif argv[i].startswith('--kitten='):
  142. idx = i
  143. env_dirs = []
  144. for k, v in env.items():
  145. if v is DELETE_ENV_VAR:
  146. x = f'--kitten=env={k}'
  147. else:
  148. x = f'--kitten=env={k}={v}'
  149. env_dirs.append(x)
  150. argv[idx+1:idx+1] = env_dirs
  151. def get_ssh_cli() -> tuple[set[str], set[str]]:
  152. other_ssh_args: set[str] = set()
  153. boolean_ssh_args: set[str] = set()
  154. for k, v in ssh_options().items():
  155. k = f'-{k}'
  156. if v:
  157. other_ssh_args.add(k)
  158. else:
  159. boolean_ssh_args.add(k)
  160. return boolean_ssh_args, other_ssh_args
  161. def is_extra_arg(arg: str, extra_args: tuple[str, ...]) -> str:
  162. for x in extra_args:
  163. if arg == x or arg.startswith(f'{x}='):
  164. return x
  165. return ''
  166. passthrough_args = {f'-{x}' for x in 'NnfGT'}
  167. def set_server_args_in_cmdline(
  168. server_args: list[str], argv: list[str],
  169. extra_args: tuple[str, ...] = ('--kitten',),
  170. allocate_tty: bool = False
  171. ) -> None:
  172. boolean_ssh_args, other_ssh_args = get_ssh_cli()
  173. ssh_args = []
  174. expecting_option_val = False
  175. found_extra_args: list[str] = []
  176. expecting_extra_val = ''
  177. ans = list(argv)
  178. found_ssh = False
  179. for i, argument in enumerate(argv):
  180. if not found_ssh:
  181. found_ssh = argument == 'ssh'
  182. continue
  183. if argument.startswith('-') and not expecting_option_val:
  184. if argument == '--':
  185. del ans[i+2:]
  186. if allocate_tty and ans[i-1] != '-t':
  187. ans.insert(i, '-t')
  188. break
  189. if extra_args:
  190. matching_ex = is_extra_arg(argument, extra_args)
  191. if matching_ex:
  192. if '=' in argument:
  193. exval = argument.partition('=')[-1]
  194. found_extra_args.extend((matching_ex, exval))
  195. else:
  196. expecting_extra_val = matching_ex
  197. expecting_option_val = True
  198. continue
  199. # could be a multi-character option
  200. all_args = argument[1:]
  201. for i, arg in enumerate(all_args):
  202. arg = f'-{arg}'
  203. if arg in boolean_ssh_args:
  204. ssh_args.append(arg)
  205. continue
  206. if arg in other_ssh_args:
  207. ssh_args.append(arg)
  208. rest = all_args[i+1:]
  209. if rest:
  210. ssh_args.append(rest)
  211. else:
  212. expecting_option_val = True
  213. break
  214. raise KeyError(f'unknown option -- {arg[1:]}')
  215. continue
  216. if expecting_option_val:
  217. if expecting_extra_val:
  218. found_extra_args.extend((expecting_extra_val, argument))
  219. expecting_extra_val = ''
  220. else:
  221. ssh_args.append(argument)
  222. expecting_option_val = False
  223. continue
  224. del ans[i+1:]
  225. if allocate_tty and ans[i] != '-t':
  226. ans.insert(i, '-t')
  227. break
  228. argv[:] = ans + server_args
  229. def get_connection_data(args: list[str], cwd: str = '', extra_args: tuple[str, ...] = ()) -> SSHConnectionData | None:
  230. boolean_ssh_args, other_ssh_args = get_ssh_cli()
  231. port: int | None = None
  232. expecting_port = expecting_identity = False
  233. expecting_option_val = False
  234. expecting_hostname = False
  235. expecting_extra_val = ''
  236. host_name = identity_file = found_ssh = ''
  237. found_extra_args: list[tuple[str, str]] = []
  238. for i, arg in enumerate(args):
  239. if not found_ssh:
  240. if os.path.basename(arg).lower() in ('ssh', 'ssh.exe'):
  241. found_ssh = arg
  242. continue
  243. if expecting_hostname:
  244. host_name = arg
  245. continue
  246. if arg.startswith('-') and not expecting_option_val:
  247. if arg in boolean_ssh_args:
  248. continue
  249. if arg == '--':
  250. expecting_hostname = True
  251. if arg.startswith('-p'):
  252. if arg[2:].isdigit():
  253. with suppress(Exception):
  254. port = int(arg[2:])
  255. continue
  256. elif arg == '-p':
  257. expecting_port = True
  258. elif arg.startswith('-i'):
  259. if arg == '-i':
  260. expecting_identity = True
  261. else:
  262. identity_file = arg[2:]
  263. continue
  264. if arg.startswith('--') and extra_args:
  265. matching_ex = is_extra_arg(arg, extra_args)
  266. if matching_ex:
  267. if '=' in arg:
  268. exval = arg.partition('=')[-1]
  269. found_extra_args.append((matching_ex, exval))
  270. continue
  271. expecting_extra_val = matching_ex
  272. expecting_option_val = True
  273. continue
  274. if expecting_option_val:
  275. if expecting_port:
  276. with suppress(Exception):
  277. port = int(arg)
  278. expecting_port = False
  279. elif expecting_identity:
  280. identity_file = arg
  281. elif expecting_extra_val:
  282. found_extra_args.append((expecting_extra_val, arg))
  283. expecting_extra_val = ''
  284. expecting_option_val = False
  285. continue
  286. if not host_name:
  287. host_name = arg
  288. if not host_name:
  289. return None
  290. if host_name.startswith('ssh://'):
  291. from urllib.parse import urlparse
  292. purl = urlparse(host_name)
  293. if purl.hostname:
  294. host_name = purl.hostname
  295. if purl.username:
  296. host_name = f'{purl.username}@{host_name}'
  297. if port is None and purl.port:
  298. port = purl.port
  299. if identity_file:
  300. if not os.path.isabs(identity_file):
  301. identity_file = os.path.expanduser(identity_file)
  302. if not os.path.isabs(identity_file):
  303. identity_file = os.path.normpath(os.path.join(cwd or os.getcwd(), identity_file))
  304. return SSHConnectionData(found_ssh, host_name, port, identity_file, tuple(found_extra_args))