cuda_common.go 489 B

12345678910111213141516171819202122
  1. //go:build linux || windows
  2. package gpu
  3. import (
  4. "log/slog"
  5. "strings"
  6. )
  7. func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
  8. ids := []string{}
  9. for _, info := range gpuInfo {
  10. if info.Library != "cuda" {
  11. // TODO shouldn't happen if things are wired correctly...
  12. slog.Debug("cudaGetVisibleDevicesEnv skipping over non-cuda device", "library", info.Library)
  13. continue
  14. }
  15. ids = append(ids, info.ID)
  16. }
  17. return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",")
  18. }