# script_ray_cpu.py (1.9 KB)
  1. import os
  2. import gym
  3. from agit import Agent#之前的是eternatus
  4. from gym.spaces import Discrete, Box
  5. from ray import tune
  6. class SimpleCorridor(gym.Env):
  7. def __init__(self, config):
  8. self.end_pos = config['corridor_length']
  9. self.cur_pos = 0
  10. self.action_space = Discrete(2)
  11. self.observation_space = Box(0.0, self.end_pos, shape=(1,))
  12. def reset(self):
  13. self.cur_pos = 0
  14. return [self.cur_pos]
  15. def step(self, action):
  16. if action == 0 and self.cur_pos > 0:
  17. self.cur_pos -= 1
  18. elif action == 1:
  19. self.cur_pos += 1
  20. done = self.cur_pos >= self.end_pos
  21. return [self.cur_pos], 1 if done else 0, done, {}
  22. def main():
  23. from datetime import datetime
  24. start_time = datetime.utcnow()
  25. print('Python start time: {} UTC'.format(start_time))
  26. import tensorflow as tf
  27. print('TensorFlow CUDA is available: {}'.format(tf.config.list_physical_devices('GPU')))
  28. import torch
  29. print('pyTorch CUDA is available: {}'.format(torch.cuda.is_available()))
  30. if 'CLOUD_PROVIDER' in os.environ and os.environ['CLOUD_PROVIDER'] == 'Agit':
  31. provider = 'Agit'
  32. log_dir = '/root/.agit'
  33. results_dir = '/root/.agit'
  34. else:
  35. provider = 'local'
  36. log_dir = '../temp'
  37. results_dir = '../temp'
  38. # Initialize Ray Cluster
  39. #ray_init()
  40. tune.run(
  41. 'PPO',
  42. queue_trials=True, # Don't use this parameter unless you know what you do.
  43. stop={'training_iteration': 10},
  44. config={
  45. 'env': SimpleCorridor,
  46. 'env_config': {'corridor_length': 5}
  47. }
  48. )
  49. with open(os.path.join(results_dir, 'model.pkl'), 'wb') as file:
  50. file.write(b'model data')
  51. complete_time = datetime.utcnow()
  52. print('Python complete time: {} UTC'.format(complete_time))
  53. print('Python resource time: {} UTC'.format(complete_time - start_time))
  54. if __name__ == '__main__':
  55. main()