读取数据集GPU.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import os
  2. import gym
  3. import ray
  4. from gym.spaces import Discrete, Box
  5. from ray import tune
  6. class SimpleCorridor(gym.Env):
  7. def __init__(self, config):
  8. self.end_pos = config['corridor_length']
  9. self.cur_pos = 0
  10. self.action_space = Discrete(2)
  11. self.observation_space = Box(0.0, self.end_pos, shape=(1,))
  12. def reset(self):
  13. self.cur_pos = 0
  14. return [self.cur_pos]
  15. def step(self, action):
  16. if action == 0 and self.cur_pos > 0:
  17. self.cur_pos -= 1
  18. elif action == 1:
  19. self.cur_pos += 1
  20. done = self.cur_pos >= self.end_pos
  21. return [self.cur_pos], 1 if done else 0, done, {}
  22. if __name__ == '__main__':
  23. from datetime import datetime
  24. start_time = datetime.utcnow()
  25. print('Python start time: {} UTC'.format(start_time))
  26. if 'CLOUD_PROVIDER' in os.environ and os.environ['CLOUD_PROVIDER'] == 'Agit':
  27. from agit import ray_init
  28. ray_init()
  29. from agit import open
  30. dataset_path = 'agit://'
  31. else:
  32. ray.init()
  33. dataset_path = './'
  34. print('Ray Cluster Resources: {}'.format(ray.cluster_resources()))
  35. import tensorflow as tf
  36. print('TensorFlow CUDA is available: {}'.format(tf.config.list_physical_devices('GPU')))
  37. import torch
  38. print('pyTorch CUDA is available: {}'.format(torch.cuda.is_available()))
  39. with open(dataset_path + 'expert_data.csv', 'rb') as file:
  40. raw_data = file.read()
  41. print(raw_data)
  42. tune.run(
  43. 'PPO',
  44. queue_trials=True, # Don't use this parameter unless you know what you do.
  45. stop={'training_iteration': 10},
  46. config={
  47. 'env': SimpleCorridor,
  48. 'env_config': {'corridor_length': 5},
  49. 'num_gpus': 1
  50. }
  51. )
  52. complete_time = datetime.utcnow()
  53. print('Python complete time: {} UTC'.format(complete_time))
  54. print('Python resource time: {} UTC'.format(complete_time - start_time))