shader_reload_test.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. #
  2. # Copyright (c) Contributors to the Open 3D Engine Project.
  3. # For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. #
  5. # SPDX-License-Identifier: Apache-2.0 OR MIT
  6. #
  7. #
  8. import os
  9. import numpy as np
  10. import tempfile
  11. from PIL import Image
  12. import azlmbr.bus as azbus
  13. import azlmbr.legacy.general as azgeneral
  14. import azlmbr.editor as azeditor
  15. import azlmbr.atom as azatom
  16. import azlmbr.entity as azentity
  17. import azlmbr.camera as azcamera
# See README.md for details.
# Some global constants
# Name of the level that must be open in the Editor. It is compared against
# the basename of the current level path before any iteration runs.
EXPECTED_LEVEL_NAME = "ShaderReloadTest"
# File extensions accepted for a user-supplied screenshot path.
SUPPORTED_IMAGE_FILE_EXTENSIONS = {".ppm", ".bmp", ".tiff", ".tif", ".png"}
  22. def OverwriteFile(azslFilePath: str, begin: list, middle: list, end: list):
  23. fileObj = open(azslFilePath, "w")
  24. fileObj.writelines(begin)
  25. fileObj.writelines(middle)
  26. fileObj.writelines(end)
  27. fileObj.close()
  28. def FlipTestColorLine(line: str) -> tuple[tuple[int, int, int], str]:
  29. """
  30. Flips/toggles "BLUE_COLOR" for "GREEN_COLOR" in @line
  31. and viceversa. Also retuns the expected pixel color in addition
  32. to the modified line.
  33. """
  34. if "BLUE_COLOR" in line:
  35. expectedColor = "GREEN_COLOR"
  36. newLine = line.replace("BLUE_COLOR", expectedColor)
  37. return (0, 152, 15), newLine
  38. elif "GREEN_COLOR" in line:
  39. expectedColor = "BLUE_COLOR"
  40. newLine = line.replace("GREEN_COLOR", expectedColor)
  41. return (0, 1, 155), newLine
  42. raise Exception("Can't find color")
  43. def FlipShaderColor(azslFilePath: str) -> tuple[int, int, int]:
  44. """
  45. Modifies the Shader code in @azslFilePath by flipping the line
  46. that outputs the expected color.
  47. If the line is found as:
  48. const float3 TEST_COLOR = BLUE_COLOR;
  49. It gets flipped to:
  50. const float3 TEST_COLOR = GREEN_COLOR;
  51. and viceversa.
  52. """
  53. begin = []
  54. middle = []
  55. end = []
  56. fileObj = open(azslFilePath, 'rt')
  57. sectionStart = False
  58. sectionEnd = False
  59. expectedColor = ""
  60. for line in fileObj:
  61. if ("ShaderReloadTest" in line):
  62. if ("START" in line):
  63. middle.append(line)
  64. sectionStart = True
  65. continue
  66. elif ("END" in line):
  67. middle.append(line)
  68. sectionEnd = True
  69. continue
  70. if (not sectionStart) and (not sectionEnd):
  71. begin.append(line)
  72. continue
  73. if sectionEnd:
  74. end.append(line)
  75. continue
  76. if "TEST_COLOR" in line:
  77. expectedColor, newLine = FlipTestColorLine(line)
  78. middle.append(newLine)
  79. fileObj.close()
  80. OverwriteFile(azslFilePath, begin, middle, end)
  81. return expectedColor
  82. def UpdateShaderAndTestPixelResult(azslFilePath: str, screenUpdateWaitTime: float, captureCountOnFailure: int, screenshotImagePath: str, quickShaderOverwriteCount: int, quickShaderOverwriteWait: float) -> bool:
  83. """
  84. This function represents the work done for a single iteration of the shader modification test.
  85. Modifies the content of the Shader (@azslFilePath), waits @screenUpdateWaitTime, and captures
  86. the pixels of the Editor Viewport.
  87. Some leniency was added with the variable @captureCountOnFailure because sometimes @@screenUpdateWaitTime is not
  88. enought time for the screen to update. This function retries @captureCountOnFailure times before
  89. considering it a failure.
  90. """
  91. expectedPixelColor = FlipShaderColor(azslFilePath)
  92. print(f"Expecting color {expectedPixelColor}")
  93. for overwriteCount in range(quickShaderOverwriteCount):
  94. azgeneral.idle_wait(quickShaderOverwriteWait)
  95. expectedPixelColor = FlipShaderColor(azslFilePath)
  96. print(f"Shader quickly overwritten {overwriteCount + 1} of {quickShaderOverwriteCount}")
  97. print(f"Expecting color {expectedPixelColor}")
  98. azgeneral.idle_wait(screenUpdateWaitTime)
  99. # Capture the screenshot
  100. captureCount = -1
  101. success = False
  102. while captureCount < captureCountOnFailure:
  103. captureCount += 1
  104. outcome = azatom.FrameCaptureRequestBus(
  105. azbus.Broadcast, "CaptureScreenshot", screenshotImagePath
  106. )
  107. if not outcome.IsSuccess():
  108. frameCaptureError = outcome.GetError()
  109. errorMsg = frameCaptureError.error_message
  110. print(f"Failed to capture screenshot at outputImagePath='{screenshotImagePath}'\nError:\n{errorMsg}")
  111. return False
  112. azgeneral.idle_wait(screenUpdateWaitTime)
  113. img = Image.open(screenshotImagePath)
  114. width, height = img.size
  115. r = int(height/2)
  116. c = int(width/2)
  117. image_array = np.array(img)
  118. color = image_array[r][c]
  119. print(f"captureCount {captureCount}: Center Pixel[{r},{c}] Color={color}, type:{type(color)}, r={color[0]}, g={color[1]}, b={color[2]}")
  120. success = (color[0] == expectedPixelColor[0]) and (color[1] == expectedPixelColor[1]) and (color[2] == expectedPixelColor[2])
  121. if success:
  122. return success
  123. return success
  124. def ShaderReloadTest(iterationCountMax: int, screenUpdateWaitTime: float, captureCountOnFailure: int, screenshotImagePath: str, quickShaderOverwriteCount: int, quickShaderOverwriteWait: float) -> tuple[bool, int]:
  125. """
  126. This function is the main loop. Runs @iterationCountMax iterations and all iterations must PASS
  127. to consider the test a success.
  128. A single iteration modifies the Shader file, waits @screenUpdateWaitTime, captures the pixel content
  129. of the Editor Viewport, and reads the center pixel of the image for an expected color.
  130. """
  131. levelPath = azeditor.EditorToolsApplicationRequestBus(azbus.Broadcast, "GetCurrentLevelPath")
  132. levelName = os.path.basename(levelPath)
  133. if levelName != EXPECTED_LEVEL_NAME:
  134. print(f"ERROR: This test suite expects a level named '{EXPECTED_LEVEL_NAME}', instead got '{levelName}'")
  135. return False, 0
  136. azslFilePath = os.path.join(levelPath, "SimpleMesh.azsl")
  137. iterationCount = 0
  138. success = False
  139. while iterationCount < iterationCountMax:
  140. iterationCount += 1
  141. print(f"Starting Retry {iterationCount} of {iterationCountMax}...")
  142. success = UpdateShaderAndTestPixelResult(azslFilePath, screenUpdateWaitTime, captureCountOnFailure, screenshotImagePath, quickShaderOverwriteCount, quickShaderOverwriteWait)
  143. if not success:
  144. break
  145. return success, iterationCount
  146. def ValidateImageExtension(screenshotImagePath: str) -> bool:
  147. _, file_extension = os.path.splitext(screenshotImagePath)
  148. if file_extension in SUPPORTED_IMAGE_FILE_EXTENSIONS:
  149. return True
  150. print(f"ERROR: Image path '{screenshotImagePath}' has an unsupported file extension.\nSupported extensions: {SUPPORTED_IMAGE_FILE_EXTENSIONS}")
  151. return False
  152. def AdjustEditorCameraPosition(cameraEntityName: str) -> azentity.EntityId:
  153. """
  154. Searches for an entity named @cameraEntityName, assumes the entity has a Camera Component,
  155. and forces the Editor Viewport to make it the Active Camera. This helps center the `Billboard`
  156. entity because this test Samples the middle the of the screen for the correct Pixel color.
  157. """
  158. if not cameraEntityName:
  159. return None
  160. # Find the first entity with such name.
  161. searchFilter = azentity.SearchFilter()
  162. searchFilter.names = [cameraEntityName,]
  163. entityList = azentity.SearchBus(azbus.Broadcast, "SearchEntities", searchFilter)
  164. print(f"Found {len(entityList)} entities named {cameraEntityName}. Will use the first.")
  165. if len(entityList) < 1:
  166. print(f"No camera entity with name {cameraEntityName} was found. Viewport camera won't be adjusted.")
  167. return None
  168. cameraEntityId = entityList[0]
  169. isActiveCamera = azcamera.EditorCameraViewRequestBus(azbus.Event, "IsActiveCamera", cameraEntityId)
  170. if isActiveCamera:
  171. print(f"Entity '{cameraEntityName}' is already the active camera")
  172. return cameraEntityId
  173. azcamera.EditorCameraViewRequestBus(azbus.Event, "ToggleCameraAsActiveView", cameraEntityId)
  174. print(f"Entity '{cameraEntityName}' is now the active camera. Will wait 2 seconds for the screen to settle.")
  175. print(f"REMARK: It is expected that the camera is located at [0, -1, 2] with all euler angles at 0.")
  176. azgeneral.idle_wait(2.0)
  177. return cameraEntityId
  178. def ClearViewportOfHelpers():
  179. """
  180. Makes sure all helpers and artifacts that add unwanted pixels
  181. are hidden.
  182. """
  183. # Make sure no entity is selected when the test runs because entity selection adds unwanted colored pixels
  184. azeditor.ToolsApplicationRequestBus(azbus.Broadcast, "SetSelectedEntities", [])
  185. # Hide helpers
  186. if azgeneral.is_helpers_shown():
  187. azgeneral.toggle_helpers()
  188. # Hide icons
  189. if azgeneral.is_icons_shown():
  190. azgeneral.toggle_icons()
  191. #Hide FPS, etc
  192. azgeneral.set_cvar_integer("r_displayInfo", 0)
  193. # Wait a little for the screen to update.
  194. azgeneral.idle_wait(0.25)
  195. # Quick Example on how to run this test from the Editor Console (See README.md for more details):
  196. # Runs 10 iterations:
  197. # pyRunFile C:\GIT\o3de\AutomatedTesting\Levels\ShaderReloadTest\shader_reload_test.py -i 10
  198. def MainFunc():
  199. import argparse
  200. parser = argparse.ArgumentParser(
  201. description="Records several frames of pass attachments as image files."
  202. )
  203. parser.add_argument(
  204. "-i",
  205. "--iterations",
  206. type=int,
  207. default=1,
  208. help="How many times the Shader should be modified and the screen pixel validated.",
  209. )
  210. parser.add_argument(
  211. "--screen_update_wait_time",
  212. type=float,
  213. default=3.0,
  214. help="Minimum time to wait after modifying the shader and taking the screen snapshot to validate color output.",
  215. )
  216. parser.add_argument(
  217. "--capture_count_on_failure",
  218. type=int,
  219. default=2,
  220. help="How many times the screen should be recaptured if the pixel output failes.",
  221. )
  222. parser.add_argument(
  223. "-p",
  224. "--screenshot_image_path",
  225. default="",
  226. help="Absolute path of the file where the screenshot will be written to. Must include image extensions 'ppm', 'png', 'bmp', 'tif'. By default a temporary png path will be created",
  227. )
  228. parser.add_argument(
  229. "-q",
  230. "--quick_shader_overwrite_count",
  231. type=int,
  232. default=0,
  233. help="How many times the shader should be overwritten before capturing the screenshot. This simulates real life cases where a shader file is updated and saved to the file system several times consecutively",
  234. )
  235. parser.add_argument(
  236. "-w",
  237. "--quick_shader_overwrite_wait",
  238. type=float,
  239. default=1.0,
  240. help="Minimum time to wait in between quick shader overwrites.",
  241. )
  242. parser.add_argument(
  243. "-c",
  244. "--camera_entity_name",
  245. default="Camera",
  246. help="Name of the entity that contains a Camera Component. If found, the Editor camera will be set to it before starting the test.",
  247. )
  248. args = parser.parse_args()
  249. iterationCountMax = args.iterations
  250. screenUpdateWaitTime = args.screen_update_wait_time
  251. captureCountOnFailure = args.capture_count_on_failure
  252. screenshotImagePath = args.screenshot_image_path
  253. quickShaderOverwriteCount = args.quick_shader_overwrite_count
  254. quickShaderOverwriteWait = args.quick_shader_overwrite_wait
  255. cameraEntityName = args.camera_entity_name
  256. tmpDir = None
  257. if not screenshotImagePath:
  258. tmpDir = tempfile.TemporaryDirectory()
  259. screenshotImagePath = os.path.join(tmpDir.name, "shader_reload.png")
  260. print(f"The temporary file '{screenshotImagePath}' will be used to capture screenshots")
  261. else:
  262. if not ValidateImageExtension(screenshotImagePath):
  263. return # Exit test suite.
  264. cameraEntityId = AdjustEditorCameraPosition(cameraEntityName)
  265. ClearViewportOfHelpers()
  266. result, iterationCount = ShaderReloadTest(iterationCountMax, screenUpdateWaitTime, captureCountOnFailure, screenshotImagePath, quickShaderOverwriteCount, quickShaderOverwriteWait)
  267. if result:
  268. print(f"ShaderReloadTest PASSED after retrying {iterationCount}/{iterationCountMax} times.")
  269. else:
  270. print(f"ShaderReloadTest FAILED after retrying {iterationCount}/{iterationCountMax} times.")
  271. if cameraEntityId is not None:
  272. azcamera.EditorCameraViewRequestBus(azbus.Event, "ToggleCameraAsActiveView", cameraEntityId)
  273. if tmpDir:
  274. tmpDir.cleanup()
# Run the test suite when this file is executed as a script, e.g. through
# the Editor Console `pyRunFile` command.
if __name__ == "__main__":
    MainFunc()