shell_helpers.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517
  1. #!/usr/bin/env python3
  2. import base64
  3. import distutils.file_util
  4. import io
  5. import itertools
  6. import os
  7. import shlex
  8. import shutil
  9. import signal
  10. import stat
  11. import subprocess
  12. import sys
  13. import threading
  14. from typing import List, Union
  15. import urllib.request
  16. class LF:
  17. '''
  18. LineFeed (AKA newline).
  19. Singleton class. Can be used in print_cmd to print out nicer command lines
  20. with --key on the same line as "--key value".
  21. '''
  22. pass
  23. class ShellHelpers:
  24. '''
  25. Helpers to do things which are easy from the shell,
  26. usually filesystem, process or pipe operations.
  27. Attempt to print shell equivalents of all commands to make things
  28. easy to debug and understand what is going on.
  29. '''
  30. _print_lock = threading.Lock()
  31. def __init__(self, dry_run=False, quiet=False):
  32. '''
  33. :param dry_run: don't run the commands, just potentially print them. Debug aid.
  34. :type dry_run: Bool
  35. :param quiet: don't print the commands
  36. :type dry_run: Bool
  37. '''
  38. self.dry_run = dry_run
  39. self.quiet = quiet
  40. @classmethod
  41. def _print_thread_safe(cls, string):
  42. '''
  43. Python sucks: a naive print adds a bunch of random spaces to stdout,
  44. and then copy pasting the command fails.
  45. https://stackoverflow.com/questions/3029816/how-do-i-get-a-thread-safe-print-in-python-2-6
  46. The initial use case was test-gdb which must create a thread for GDB to run the program in parallel.
  47. '''
  48. with cls._print_lock:
  49. try:
  50. print(string, flush=True)
  51. except BrokenPipeError:
  52. # https://stackoverflow.com/questions/26692284/how-to-prevent-brokenpipeerror-when-doing-a-flush-in-python
  53. # https://stackoverflow.com/questions/16314321/suppressing-printout-of-exception-ignored-message-in-python-3
  54. pass
  55. def add_newlines(self, cmd):
  56. out = []
  57. for arg in cmd:
  58. out.extend([arg, LF])
  59. return out
  60. def base64_encode(self, string):
  61. '''
  62. TODO deal with redirection and print nicely.
  63. '''
  64. return base64.b64encode(string.encode()).decode()
  65. def base64_decode(self, string):
  66. return base64.b64decode(string.encode()).decode()
  67. def check_output(self, *args, **kwargs):
  68. '''
  69. Analogous to subprocess.check_output: get the stdout / stderr
  70. of a program back as a byte array.
  71. '''
  72. out_str = []
  73. actual_kwargs = {
  74. 'show_stdout': False,
  75. 'show_cmd': False
  76. }
  77. actual_kwargs.update(kwargs)
  78. self.run_cmd(
  79. *args,
  80. out_str=out_str,
  81. **actual_kwargs
  82. )
  83. return out_str[0]
  84. def chmod(self, path, add_rm_abs='+', mode_delta=stat.S_IXUSR):
  85. '''
  86. TODO extend further, shell print equivalent.
  87. '''
  88. old_mode = os.stat(path).st_mode
  89. if add_rm_abs == '+':
  90. new_mode = old_mode | mode_delta
  91. elif add_rm_abs == '':
  92. new_mode = mode_delta
  93. elif add_rm_abs == '-':
  94. new_mode = old_mode & ~mode_delta
  95. os.chmod(path, new_mode)
  96. @staticmethod
  97. def cmd_to_string(
  98. cmd: List[Union[str, LF]],
  99. cwd=None,
  100. extra_env=None,
  101. extra_paths=None,
  102. force_oneline: bool =False,
  103. ):
  104. '''
  105. Format a command given as a list of strings so that it can
  106. be viewed nicely and executed by bash directly and print it to stdout.
  107. If cmd contains:
  108. * no LF, then newlines are added after every word
  109. * exactly one LF at the end, then no newlines are added
  110. * otherwise: newlines are added exactly at each LF
  111. '''
  112. last_newline = ' \\\n'
  113. newline_separator = last_newline + ' '
  114. out = []
  115. if extra_env is None:
  116. extra_env = {}
  117. if cwd is not None:
  118. out.append('cd {} &&'.format(shlex.quote(cwd)))
  119. if extra_paths is not None:
  120. out.append('PATH="{}:${{PATH}}"'.format(':'.join(extra_paths)))
  121. for key in extra_env:
  122. out.append('{}={}'.format(shlex.quote(key), shlex.quote(extra_env[key])))
  123. cmd_quote = []
  124. newline_count = 0
  125. for arg in cmd:
  126. if arg == LF:
  127. if not force_oneline:
  128. cmd_quote.append(arg)
  129. newline_count += 1
  130. else:
  131. cmd_quote.append(shlex.quote(arg))
  132. if force_oneline or newline_count > 0:
  133. cmd_quote = [
  134. ' '.join(list(y))
  135. for x, y in itertools.groupby(
  136. cmd_quote,
  137. lambda z: z == LF
  138. )
  139. if not x
  140. ]
  141. out.extend(cmd_quote)
  142. if force_oneline or newline_count == 1 and cmd[-1] == LF:
  143. ending = ''
  144. else:
  145. ending = last_newline + ';'
  146. return newline_separator.join(out) + ending
  147. def copy_file_if_update(self, src, destfile):
  148. if os.path.isdir(destfile):
  149. destfile = os.path.join(destfile, os.path.basename(src))
  150. self.mkdir_p(os.path.dirname(destfile))
  151. if (
  152. not os.path.exists(destfile) or \
  153. os.path.getmtime(src) > os.path.getmtime(destfile)
  154. ):
  155. self.cp(src, destfile)
  156. def copy_dir_if_update_non_recursive(
  157. self,
  158. srcdir,
  159. destdir,
  160. filter_ext=None
  161. ):
  162. # TODO print rsync equivalent.
  163. os.makedirs(destdir, exist_ok=True)
  164. if not os.path.exists(srcdir) and self.dry_run:
  165. basenames = []
  166. else:
  167. basenames = os.listdir(srcdir)
  168. for basename in sorted(basenames):
  169. src = os.path.join(srcdir, basename)
  170. if os.path.isfile(src) or os.path.islink(src):
  171. noext, ext = os.path.splitext(basename)
  172. if (filter_ext is None or ext == filter_ext):
  173. dest = os.path.join(destdir, basename)
  174. self.copy_file_if_update(src, dest)
  175. def copy_dir_if_update(
  176. self,
  177. srcdir,
  178. destdir,
  179. filter_ext=None
  180. ):
  181. self.copy_dir_if_update_non_recursive(srcdir, destdir, filter_ext)
  182. srcdir_abs = os.path.abspath(srcdir)
  183. srcdir_abs_len = len(srcdir_abs)
  184. for path, dirnames, filenames in self.walk(srcdir_abs):
  185. for dirname in dirnames:
  186. dirpath = os.path.join(path, dirname)
  187. dirpath_relative_root = dirpath[srcdir_abs_len + 1:]
  188. self.copy_dir_if_update_non_recursive(
  189. dirpath,
  190. os.path.join(destdir, dirpath_relative_root),
  191. filter_ext
  192. )
  193. def cp(self, src, dest, **kwargs):
  194. self.print_cmd(['cp', src, dest])
  195. if not self.dry_run:
  196. if os.path.islink(src):
  197. if os.path.lexists(dest):
  198. os.unlink(dest)
  199. linkto = os.readlink(src)
  200. os.symlink(linkto, dest)
  201. else:
  202. shutil.copy2(src, dest)
  203. def mkdir_p(self, d):
  204. if not os.path.exists(d):
  205. self.print_cmd(['mkdir', d, LF])
  206. if not self.dry_run:
  207. os.makedirs(d)
  208. def mv(self, src, dest, **kwargs):
  209. self.print_cmd(['mv', src, dest])
  210. if not self.dry_run:
  211. shutil.move(src, dest)
  212. def print_cmd(
  213. self,
  214. cmd,
  215. cwd=None,
  216. cmd_file=None,
  217. extra_env=None,
  218. extra_paths=None,
  219. force_oneline=False,
  220. ):
  221. '''
  222. Print cmd_to_string to stdout.
  223. Optionally save the command to cmd_file file, and add extra_env
  224. environment variables to the command generated.
  225. '''
  226. if type(cmd) is str:
  227. cmd_string = cmd
  228. else:
  229. cmd_string = self.cmd_to_string(
  230. cmd,
  231. cwd=cwd,
  232. extra_env=extra_env,
  233. extra_paths=extra_paths,
  234. force_oneline=force_oneline,
  235. )
  236. if not self.quiet:
  237. self._print_thread_safe('+ ' + cmd_string)
  238. if cmd_file is not None:
  239. os.makedirs(os.path.dirname(cmd_file), exist_ok=True)
  240. with open(cmd_file, 'w') as f:
  241. f.write('#!/usr/bin/env bash\n')
  242. f.write(cmd_string)
  243. self.chmod(cmd_file)
  244. def rmrf(self, path):
  245. self.print_cmd(['rm', '-r', '-f', path, LF])
  246. if not self.dry_run and os.path.exists(path):
  247. if os.path.isdir(path):
  248. shutil.rmtree(path)
  249. else:
  250. os.unlink(path)
  251. def run_cmd(
  252. self,
  253. cmd,
  254. cmd_file=None,
  255. out_file=None,
  256. show_stdout=True,
  257. show_cmd=True,
  258. extra_env=None,
  259. extra_paths=None,
  260. delete_env=None,
  261. raise_on_failure=True,
  262. *,
  263. out_str=None,
  264. **kwargs
  265. ):
  266. '''
  267. Run a command. Write the command to stdout before running it.
  268. Wait until the command finishes execution.
  269. :param cmd: command to run. LF entries are magic get skipped.
  270. :type cmd: List[str]
  271. :param cmd_file: if not None, write the command to be run to that file
  272. :type cmd_file: str
  273. :param out_file: if not None, write the stdout and stderr of the command the file
  274. :type out_file: str
  275. :param out_str: if not None, append the stdout and stderr string to this list
  276. :type out_str: Union(List,None)
  277. :param show_stdout: wether to show stdout and stderr on the terminal or not
  278. :type show_stdout: bool
  279. :param extra_env: extra environment variables to add when running the command
  280. :type extra_env: Dict[str,str]
  281. :return: exit status of the command
  282. :rtype: int
  283. '''
  284. if out_file is None and out_str is None:
  285. if show_stdout:
  286. stdout = None
  287. stderr = None
  288. else:
  289. stdout = subprocess.DEVNULL
  290. stderr = subprocess.DEVNULL
  291. else:
  292. stdout = subprocess.PIPE
  293. stderr = subprocess.STDOUT
  294. if extra_env is None:
  295. extra_env = {}
  296. if delete_env is None:
  297. delete_env = []
  298. if 'cwd' in kwargs:
  299. cwd = kwargs['cwd']
  300. else:
  301. cwd = None
  302. env = os.environ.copy()
  303. env.update(extra_env)
  304. if extra_paths is not None:
  305. path = ':'.join(extra_paths)
  306. if 'PATH' in os.environ:
  307. path += ':' + os.environ['PATH']
  308. env['PATH'] = path
  309. for key in delete_env:
  310. if key in env:
  311. del env[key]
  312. if show_cmd:
  313. self.print_cmd(
  314. cmd,
  315. cwd=cwd,
  316. cmd_file=cmd_file,
  317. extra_env=extra_env,
  318. extra_paths=extra_paths
  319. )
  320. # Otherwise, if called from a non-main thread:
  321. # ValueError: signal only works in main thread
  322. if threading.current_thread() == threading.main_thread():
  323. # Otherwise Ctrl + C gives:
  324. # - ugly Python stack trace for gem5 (QEMU takes over terminal and is fine).
  325. # - kills Python, and that then kills GDB:
  326. # https://stackoverflow.com/questions/19807134/does-python-always-raise-an-exception-if-you-do-ctrlc-when-a-subprocess-is-exec
  327. sigint_old = signal.getsignal(signal.SIGINT)
  328. signal.signal(signal.SIGINT, signal.SIG_IGN)
  329. # Otherwise BrokenPipeError when piping through | grep
  330. # But if I do this_module, my terminal gets broken at the end. Why, why, why.
  331. # https://stackoverflow.com/questions/14207708/ioerror-errno-32-broken-pipe-python
  332. # Ignoring the exception is not enough as it prints a warning anyways.
  333. #sigpipe_old = signal.getsignal(signal.SIGPIPE)
  334. #signal.signal(signal.SIGPIPE, signal.SIG_DFL)
  335. cmd = self.strip_newlines(cmd)
  336. if not self.dry_run:
  337. # https://stackoverflow.com/questions/15535240/python-popen-write-to-stdout-and-log-file-simultaneously/52090802#52090802
  338. with subprocess.Popen(
  339. cmd,
  340. stdout=stdout,
  341. stderr=stderr,
  342. env=env,
  343. **kwargs
  344. ) as proc:
  345. if out_file is not None or out_str is not None:
  346. if out_file is not None:
  347. os.makedirs(os.path.split(os.path.abspath(out_file))[0], exist_ok=True)
  348. if out_file is not None:
  349. logfile = open(out_file, 'bw')
  350. logfile_str = []
  351. while True:
  352. byte = proc.stdout.read(1)
  353. if byte:
  354. if show_stdout:
  355. sys.stdout.buffer.write(byte)
  356. try:
  357. sys.stdout.flush()
  358. except BlockingIOError:
  359. # TODO understand. Why, Python, why.
  360. pass
  361. if out_file is not None:
  362. logfile.write(byte)
  363. if out_str is not None:
  364. logfile_str.append(byte)
  365. else:
  366. break
  367. if out_file is not None:
  368. logfile.close()
  369. if out_str is not None:
  370. out_str.append((b''.join(logfile_str)))
  371. if threading.current_thread() == threading.main_thread():
  372. signal.signal(signal.SIGINT, sigint_old)
  373. #signal.signal(signal.SIGPIPE, sigpipe_old)
  374. returncode = proc.returncode
  375. if returncode != 0 and raise_on_failure:
  376. e = Exception('Command exited with status: {}'.format(returncode))
  377. e.returncode = returncode
  378. raise e
  379. return returncode
  380. else:
  381. if not out_str is None:
  382. out_str.append(b'')
  383. return 0
  384. def shlex_split(self, string):
  385. '''
  386. shlex_split, but also add Newline after every word.
  387. Not perfect since it does not group arguments, but I don't see a solution.
  388. '''
  389. return self.add_newlines(shlex.split(string))
  390. def strip_newlines(self, cmd):
  391. if type(cmd) is str:
  392. return cmd
  393. else:
  394. return [x for x in cmd if x != LF]
  395. def walk(self, root):
  396. '''
  397. Extended walk that can take files or directories.
  398. '''
  399. if not os.path.exists(root):
  400. raise Exception('Path does not exist: ' + root)
  401. if os.path.isfile(root):
  402. dirname, basename = os.path.split(root)
  403. yield dirname, [], [basename]
  404. else:
  405. for path, dirnames, filenames in os.walk(root):
  406. dirnames.sort()
  407. filenames.sort()
  408. yield path, dirnames, filenames
  409. def wget(self, url, download_path):
  410. '''
  411. Append extra KEY=val configs into the given config file.
  412. I wissh we could have a progress indicator, but impossible:
  413. https://stackoverflow.com/questions/51212/how-to-write-a-download-progress-indicator-in-python
  414. '''
  415. self.print_cmd([
  416. 'wget', LF,
  417. '-O', download_path, LF,
  418. url, LF,
  419. ])
  420. urllib.request.urlretrieve(url, download_path)
  421. def write_configs(self, config_path, configs, config_fragments=None, mode='a'):
  422. '''
  423. Append extra KEY=val configs into the given config file.
  424. '''
  425. if config_fragments is None:
  426. config_fragments = []
  427. for config_fragment in config_fragments:
  428. self.print_cmd(['cat', config_fragment, '>>', config_path])
  429. if not self.dry_run:
  430. with open(config_path, 'a') as config_file:
  431. for config_fragment in config_fragments:
  432. with open(config_fragment, 'r') as config_fragment_file:
  433. for line in config_fragment_file:
  434. config_file.write(line)
  435. self.write_string_to_file(config_path, '\n'.join(configs), mode=mode)
  436. def write_string_to_file(self, path, string, mode='w'):
  437. if mode == 'a':
  438. redirect = '>>'
  439. else:
  440. redirect = '>'
  441. self.print_cmd("cat << 'EOF' {} {}\n{}\nEOF".format(redirect, path, string))
  442. if not self.dry_run:
  443. with open(path, mode) as f:
  444. f.write(string)
  445. if __name__ == '__main__':
  446. shell_helpers = ShellHelpers()
  447. if 'cmd_to_string':
  448. # Default.
  449. assert shell_helpers.cmd_to_string(['cmd']) == 'cmd \\\n;'
  450. assert shell_helpers.cmd_to_string(['cmd', 'arg1']) == 'cmd \\\n arg1 \\\n;'
  451. assert shell_helpers.cmd_to_string(['cmd', 'arg1', 'arg2']) == 'cmd \\\n arg1 \\\n arg2 \\\n;'
  452. # Argument with a space gets escaped.
  453. assert shell_helpers.cmd_to_string(['cmd', 'arg1 arg2']) == "cmd \\\n 'arg1 arg2' \\\n;"
  454. # Ending in LF with no other LFs get separated only by spaces.
  455. assert shell_helpers.cmd_to_string(['cmd', LF]) == 'cmd'
  456. assert shell_helpers.cmd_to_string(['cmd', 'arg1', LF]) == 'cmd arg1'
  457. assert shell_helpers.cmd_to_string(['cmd', 'arg1', 'arg2', LF]) == 'cmd arg1 arg2'
  458. # More than one LF adds newline separators at each LF.
  459. assert shell_helpers.cmd_to_string(['cmd', LF, 'arg1', LF]) == 'cmd \\\n arg1 \\\n;'
  460. assert shell_helpers.cmd_to_string(['cmd', LF, 'arg1', LF, 'arg2', LF]) == 'cmd \\\n arg1 \\\n arg2 \\\n;'
  461. assert shell_helpers.cmd_to_string(['cmd', LF, 'arg1', 'arg2', LF]) == 'cmd \\\n arg1 arg2 \\\n;'
  462. # force_oneline separates everything simply by spaces.
  463. assert \
  464. shell_helpers.cmd_to_string(['cmd', LF, 'arg1', LF, 'arg2', LF], force_oneline=True) \
  465. == 'cmd arg1 arg2'