requirements_jax.txt 158 B

tensorboard
jax[cpu]  # For GPU support, instead run: pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# The CUDA toolkit for the GPU build can be installed with: conda install cuda -c nvidia