# test_plugins.py (9.0 KB)
  1. import importlib
  2. import os
  3. import shutil
  4. import sys
  5. import unittest
  6. from pathlib import Path
  # Put the directory above this test file (presumably the repository root)
  # first on sys.path so ``yt_dlp`` is imported from this checkout, then make
  # the ``testdata`` directory importable so its plugin packages are found.
  _TEST_DIR = os.path.dirname(os.path.abspath(__file__))
  sys.path.insert(0, os.path.dirname(_TEST_DIR))
  TEST_DATA_DIR = Path(_TEST_DIR) / 'testdata'
  sys.path.append(str(TEST_DATA_DIR))
  # Drop finder caches so the newly added search paths are honoured.
  importlib.invalidate_caches()
  11. from yt_dlp.plugins import (
  12. PACKAGE_NAME,
  13. PluginSpec,
  14. directories,
  15. load_plugins,
  16. load_all_plugins,
  17. register_plugin_spec,
  18. )
  19. from yt_dlp.globals import (
  20. extractors,
  21. postprocessors,
  22. plugin_dirs,
  23. plugin_ies,
  24. plugin_pps,
  25. all_plugins_loaded,
  26. plugin_specs,
  27. )
  # Plugin specs exercised by the tests below. ``plugin_destination`` is the
  # global the tests inspect for discovered plugin classes (plugin_ies /
  # plugin_pps); ``destination`` is the global registry where a reloaded
  # plugin must also appear (extractors / postprocessors). ``suffix`` is
  # presumably the class-name suffix that marks a plugin class — confirm
  # against yt_dlp.plugins.
  EXTRACTOR_PLUGIN_SPEC = PluginSpec(
      module_name='extractor',
      suffix='IE',
      plugin_destination=plugin_ies,
      destination=extractors,
  )

  POSTPROCESSOR_PLUGIN_SPEC = PluginSpec(
      module_name='postprocessor',
      suffix='PP',
      plugin_destination=plugin_pps,
      destination=postprocessors,
  )
  def reset_plugins():
      """Return the plugin globals and imported plugin modules to a clean state.

      Empties the discovered-plugin maps and registered specs, sets
      ``plugin_dirs`` back to ``['default']``, clears ``all_plugins_loaded``,
      removes every ``<PACKAGE_NAME>.extractor.*`` and
      ``<PACKAGE_NAME>.postprocessor.*`` submodule from ``sys.modules`` so the
      next load imports them afresh, and invalidates the import caches.

      Effects that override plugins already had on existing classes (e.g. on
      ``GenericIE``) are NOT reverted; that is probably difficult to do.
      """
      plugin_ies.value = {}
      plugin_pps.value = {}
      plugin_dirs.value = ['default']
      plugin_specs.value = {}
      all_plugins_loaded.value = False

      plugin_prefixes = tuple(
          f'{PACKAGE_NAME}.{kind}.' for kind in ('extractor', 'postprocessor'))
      # Snapshot the matching names first; sys.modules must not shrink while
      # it is being iterated.
      stale_modules = [name for name in sys.modules if name.startswith(plugin_prefixes)]
      for name in stale_modules:
          del sys.modules[name]

      importlib.invalidate_caches()
  class TestPlugins(unittest.TestCase):
      """Tests for plugin discovery and loading in ``yt_dlp.plugins``.

      ``setUp`` and ``tearDown`` both call ``reset_plugins()``, so each test
      starts from (and leaves behind) freshly reset plugin globals. Tests that
      modify ``sys.path`` restore it in a ``finally`` block.
      """

      # Plugin package directory inside the test data folder.
      TEST_PLUGIN_DIR = TEST_DATA_DIR / PACKAGE_NAME

      def setUp(self):
          reset_plugins()

      def tearDown(self):
          reset_plugins()

      def test_directories_containing_plugins(self):
          # The test-data plugin package must be among the discovered directories.
          self.assertIn(self.TEST_PLUGIN_DIR, map(Path, directories()))

      def test_extractor_classes(self):
          plugins_ie = load_plugins(EXTRACTOR_PLUGIN_SPEC)

          self.assertIn(f'{PACKAGE_NAME}.extractor.normal', sys.modules.keys())
          self.assertIn('NormalPluginIE', plugins_ie.keys())

          # don't load modules with underscore prefix
          # (assertNotIn, like the other sys.modules checks in this class, so a
          # failure names the unexpected module rather than "True is not false")
          self.assertNotIn(
              f'{PACKAGE_NAME}.extractor._ignore', sys.modules.keys(),
              'loaded module beginning with underscore')
          self.assertNotIn('IgnorePluginIE', plugins_ie.keys())
          self.assertNotIn('IgnorePluginIE', plugin_ies.value)

          # Don't load extractors with underscore prefix
          self.assertNotIn('_IgnoreUnderscorePluginIE', plugins_ie.keys())
          self.assertNotIn('_IgnoreUnderscorePluginIE', plugin_ies.value)

          # Don't load extractors not specified in __all__ (if supplied)
          self.assertNotIn('IgnoreNotInAllPluginIE', plugins_ie.keys())
          self.assertNotIn('IgnoreNotInAllPluginIE', plugin_ies.value)
          self.assertIn('InAllPluginIE', plugins_ie.keys())
          self.assertIn('InAllPluginIE', plugin_ies.value)

          # Don't load override extractors
          self.assertNotIn('OverrideGenericIE', plugins_ie.keys())
          self.assertNotIn('OverrideGenericIE', plugin_ies.value)
          self.assertNotIn('_UnderscoreOverrideGenericIE', plugins_ie.keys())
          self.assertNotIn('_UnderscoreOverrideGenericIE', plugin_ies.value)

      def test_postprocessor_classes(self):
          plugins_pp = load_plugins(POSTPROCESSOR_PLUGIN_SPEC)
          self.assertIn('NormalPluginPP', plugins_pp.keys())
          self.assertIn(f'{PACKAGE_NAME}.postprocessor.normal', sys.modules.keys())
          self.assertIn('NormalPluginPP', plugin_pps.value)

      def test_importing_zipped_module(self):
          zip_path = TEST_DATA_DIR / 'zipped_plugins.zip'
          # Archive the sibling ``zipped_plugins`` directory into ``zip_path``;
          # make_archive takes the base name without the ``.zip`` extension.
          archive_base = str(zip_path.with_suffix(''))
          shutil.make_archive(archive_base, 'zip', archive_base)
          sys.path.append(str(zip_path))  # add zip to search paths
          importlib.invalidate_caches()  # reset the import caches
          try:
              # Each plugin sub-package must include the directory inside the zip.
              for plugin_type in ('extractor', 'postprocessor'):
                  package = importlib.import_module(f'{PACKAGE_NAME}.{plugin_type}')
                  self.assertIn(zip_path / PACKAGE_NAME / plugin_type, map(Path, package.__path__))

              plugins_ie = load_plugins(EXTRACTOR_PLUGIN_SPEC)
              self.assertIn('ZippedPluginIE', plugins_ie.keys())

              plugins_pp = load_plugins(POSTPROCESSOR_PLUGIN_SPEC)
              self.assertIn('ZippedPluginPP', plugins_pp.keys())
          finally:
              sys.path.remove(str(zip_path))
              os.remove(zip_path)
              importlib.invalidate_caches()  # reset the import caches

      def test_reloading_plugins(self):
          reload_plugins_path = TEST_DATA_DIR / 'reload_plugins'
          load_plugins(EXTRACTOR_PLUGIN_SPEC)
          load_plugins(POSTPROCESSOR_PLUGIN_SPEC)

          # Remove default folder and add reload_plugin path. Remember where the
          # default folder sat so cleanup puts it back at the same position
          # instead of moving it to the end of sys.path.
          default_path = str(TEST_DATA_DIR)
          default_index = sys.path.index(default_path)
          del sys.path[default_index]
          sys.path.append(str(reload_plugins_path))
          importlib.invalidate_caches()
          try:
              for plugin_type in ('extractor', 'postprocessor'):
                  package = importlib.import_module(f'{PACKAGE_NAME}.{plugin_type}')
                  self.assertIn(reload_plugins_path / PACKAGE_NAME / plugin_type, map(Path, package.__path__))

              # Reloaded classes must replace the originals both in the returned
              # mapping and in the global destination registry.
              plugins_ie = load_plugins(EXTRACTOR_PLUGIN_SPEC)
              self.assertIn('NormalPluginIE', plugins_ie.keys())
              self.assertTrue(
                  plugins_ie['NormalPluginIE'].REPLACED,
                  msg='Reloading has not replaced original extractor plugin')
              self.assertTrue(
                  extractors.value['NormalPluginIE'].REPLACED,
                  msg='Reloading has not replaced original extractor plugin globally')

              plugins_pp = load_plugins(POSTPROCESSOR_PLUGIN_SPEC)
              self.assertIn('NormalPluginPP', plugins_pp.keys())
              self.assertTrue(plugins_pp['NormalPluginPP'].REPLACED,
                              msg='Reloading has not replaced original postprocessor plugin')
              self.assertTrue(
                  postprocessors.value['NormalPluginPP'].REPLACED,
                  msg='Reloading has not replaced original postprocessor plugin globally')
          finally:
              sys.path.remove(str(reload_plugins_path))
              sys.path.insert(default_index, default_path)
              importlib.invalidate_caches()

      def test_extractor_override_plugin(self):
          load_plugins(EXTRACTOR_PLUGIN_SPEC)

          from yt_dlp.extractor.generic import GenericIE

          # Both override plugins (plain and underscore-prefixed) apply to GenericIE.
          self.assertEqual(GenericIE.TEST_FIELD, 'override')
          self.assertEqual(GenericIE.SECONDARY_TEST_FIELD, 'underscore-override')

          self.assertEqual(GenericIE.IE_NAME, 'generic+override+underscore-override')
          importlib.invalidate_caches()
          # test that loading a second time doesn't wrap a second time
          load_plugins(EXTRACTOR_PLUGIN_SPEC)
          from yt_dlp.extractor.generic import GenericIE
          self.assertEqual(GenericIE.IE_NAME, 'generic+override+underscore-override')

      def test_load_all_plugin_types(self):
          # no plugin specs registered
          load_all_plugins()

          self.assertNotIn(f'{PACKAGE_NAME}.extractor.normal', sys.modules.keys())
          self.assertNotIn(f'{PACKAGE_NAME}.postprocessor.normal', sys.modules.keys())

          # Once both specs are registered, load_all_plugins imports both kinds.
          register_plugin_spec(EXTRACTOR_PLUGIN_SPEC)
          register_plugin_spec(POSTPROCESSOR_PLUGIN_SPEC)
          load_all_plugins()
          self.assertTrue(all_plugins_loaded.value)

          self.assertIn(f'{PACKAGE_NAME}.extractor.normal', sys.modules.keys())
          self.assertIn(f'{PACKAGE_NAME}.postprocessor.normal', sys.modules.keys())

      def test_no_plugin_dirs(self):
          # Registered specs load nothing when there are no plugin directories.
          register_plugin_spec(EXTRACTOR_PLUGIN_SPEC)
          register_plugin_spec(POSTPROCESSOR_PLUGIN_SPEC)

          plugin_dirs.value = []
          load_all_plugins()

          self.assertNotIn(f'{PACKAGE_NAME}.extractor.normal', sys.modules.keys())
          self.assertNotIn(f'{PACKAGE_NAME}.postprocessor.normal', sys.modules.keys())

      def test_set_plugin_dirs(self):
          # Replacing the directory list with a custom path loads plugins from it.
          custom_plugin_dir = str(TEST_DATA_DIR / 'plugin_packages')
          plugin_dirs.value = [custom_plugin_dir]

          load_plugins(EXTRACTOR_PLUGIN_SPEC)

          self.assertIn(f'{PACKAGE_NAME}.extractor.package', sys.modules.keys())
          self.assertIn('PackagePluginIE', plugin_ies.value)

      def test_invalid_plugin_dir(self):
          # An unknown plugin directory makes load_plugins raise ValueError.
          plugin_dirs.value = ['invalid_dir']
          with self.assertRaises(ValueError):
              load_plugins(EXTRACTOR_PLUGIN_SPEC)

      def test_append_plugin_dirs(self):
          # A custom path appended after 'default' is searched as well.
          custom_plugin_dir = str(TEST_DATA_DIR / 'plugin_packages')

          self.assertEqual(plugin_dirs.value, ['default'])
          plugin_dirs.value.append(custom_plugin_dir)
          self.assertEqual(plugin_dirs.value, ['default', custom_plugin_dir])

          load_plugins(EXTRACTOR_PLUGIN_SPEC)

          self.assertIn(f'{PACKAGE_NAME}.extractor.package', sys.modules.keys())
          self.assertIn('PackagePluginIE', plugin_ies.value)

      def test_get_plugin_spec(self):
          # Registered specs are keyed by their module_name.
          register_plugin_spec(EXTRACTOR_PLUGIN_SPEC)
          register_plugin_spec(POSTPROCESSOR_PLUGIN_SPEC)

          self.assertEqual(plugin_specs.value.get('extractor'), EXTRACTOR_PLUGIN_SPEC)
          self.assertEqual(plugin_specs.value.get('postprocessor'), POSTPROCESSOR_PLUGIN_SPEC)
          self.assertIsNone(plugin_specs.value.get('invalid'))
  if __name__ == '__main__':
      # Allow running this suite directly, e.g. ``python test_plugins.py``.
      unittest.main()