replay_video.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. __author__ = "Christian Heider Nielsen"
  4. __doc__ = r"""
  5. Created on 28-03-2021
  6. """
  7. import time
  8. import numpy
  9. from trolls.render_mode import RenderModeEnum
  10. from draugr import PROJECT_APP_PATH
  11. from draugr.torch_utilities import (
  12. TensorBoardPytorchWriter,
  13. to_tensor,
  14. )
  15. from draugr.torch_utilities.tensors.dimension_order import (
  16. nhwc_to_nchw_tensor,
  17. )
  18. from draugr.torch_utilities.writers.tensorboard.tensorboard_pytorch_writer import (
  19. VideoInputDimsEnum,
  20. )
  21. if __name__ == "__main__":
  22. def main() -> None:
  23. """
  24. :rtype: None
  25. """
  26. import gym
  27. env = gym.make("Pendulum-v1")
  28. state = env.reset()
  29. with TensorBoardPytorchWriter(
  30. PROJECT_APP_PATH.user_log / "Tests" / "Writers"
  31. ) as writer:
  32. frames = []
  33. done = False
  34. start = time.time()
  35. while not done:
  36. frames.append(env.render(mode=RenderModeEnum.rgb_array.value))
  37. state, reward, done, info = env.step(env.action_space.sample())
  38. fps = len(frames) / (time.time() - start)
  39. env.close()
  40. video_array = numpy.array(frames)
  41. print(video_array.shape)
  42. writer.video(
  43. "replay05",
  44. nhwc_to_nchw_tensor(to_tensor(video_array)).unsqueeze(0),
  45. frame_rate=fps,
  46. )
  47. writer.video(
  48. "replay06",
  49. video_array,
  50. 0,
  51. input_dims=VideoInputDimsEnum.thwc,
  52. frame_rate=fps,
  53. )
  54. writer.video(
  55. "replay08",
  56. numpy.stack([video_array, video_array]),
  57. 0,
  58. input_dims=VideoInputDimsEnum.nthwc,
  59. frame_rate=fps,
  60. )
  61. main()