amd_hip_windows.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. package gpu
  2. import (
  3. "fmt"
  4. "log/slog"
  5. "syscall"
  6. "unsafe"
  7. "golang.org/x/sys/windows"
  8. )
  // Status codes returned by the HIP entry points wrapped in this file.
const (
	// hipErrorNoDevice is reported by hipGetDeviceCount when HIP finds no
	// device at all; HipGetDeviceCount treats it as "zero devices".
	hipErrorNoDevice = 100

	// hipSuccess is the status of a call that completed without error.
	hipSuccess = 0
)
  // hipDevicePropMinimal is a byte-compatible stand-in for HIP's
// hipDeviceProp_t. hipGetDeviceProperties writes into it, and only the
// fields this package reads are named; everything else is opaque padding.
//
// Layout (Windows x64, where C int is 32 bits and size_t is 64 bits):
//   - Name:        bytes [0, 256)
//   - unused1:     bytes [256, 396), the fields between name and gcnArchName
//   - GcnArchName: bytes [396, 652)
//   - iGPU:        bytes [652, 656), the C `int integrated` field. This must
//     be int32. A Go int is 8 bytes on amd64, so Go alignment moved it to
//     byte 656 and it read two unrelated C ints instead.
//   - unused2:     bytes [656, 792), trailing fields
//
// The total size stays 792 bytes, the same as the previous layout. The callee
// writes a whole C struct into this buffer, so never shrink this type.
// NOTE(review): offsets follow the pre-R0600 hipDeviceProp_t. Re-check them
// if the DLL's exported hipGetDeviceProperties switches to the R0600 layout.
type hipDevicePropMinimal struct {
	Name        [256]byte
	unused1     [140]byte
	GcnArchName [256]byte // gfx####
	iGPU        int32     // C `int integrated`; must stay 4 bytes at offset 652
	unused2     [136]byte
}
  20. // Wrap the amdhip64.dll library for GPU discovery
  21. type HipLib struct {
  22. dll windows.Handle
  23. hipGetDeviceCount uintptr
  24. hipGetDeviceProperties uintptr
  25. hipMemGetInfo uintptr
  26. hipSetDevice uintptr
  27. hipDriverGetVersion uintptr
  28. }
  29. func NewHipLib() (*HipLib, error) {
  30. h, err := windows.LoadLibrary("amdhip64.dll")
  31. if err != nil {
  32. return nil, fmt.Errorf("unable to load amdhip64.dll: %w", err)
  33. }
  34. hl := &HipLib{}
  35. hl.dll = h
  36. hl.hipGetDeviceCount, err = windows.GetProcAddress(hl.dll, "hipGetDeviceCount")
  37. if err != nil {
  38. return nil, err
  39. }
  40. hl.hipGetDeviceProperties, err = windows.GetProcAddress(hl.dll, "hipGetDeviceProperties")
  41. if err != nil {
  42. return nil, err
  43. }
  44. hl.hipMemGetInfo, err = windows.GetProcAddress(hl.dll, "hipMemGetInfo")
  45. if err != nil {
  46. return nil, err
  47. }
  48. hl.hipSetDevice, err = windows.GetProcAddress(hl.dll, "hipSetDevice")
  49. if err != nil {
  50. return nil, err
  51. }
  52. hl.hipDriverGetVersion, err = windows.GetProcAddress(hl.dll, "hipDriverGetVersion")
  53. if err != nil {
  54. return nil, err
  55. }
  56. return hl, nil
  57. }
  58. // The hip library only evaluates the HIP_VISIBLE_DEVICES variable at startup
  59. // so we have to unload/reset the library after we do our initial discovery
  60. // to make sure our updates to that variable are processed by llama.cpp
  61. func (hl *HipLib) Release() {
  62. err := windows.FreeLibrary(hl.dll)
  63. if err != nil {
  64. slog.Warn("failed to unload amdhip64.dll", "error", err)
  65. }
  66. hl.dll = 0
  67. }
  68. func (hl *HipLib) AMDDriverVersion() (driverMajor, driverMinor int, err error) {
  69. if hl.dll == 0 {
  70. return 0, 0, fmt.Errorf("dll has been unloaded")
  71. }
  72. var version int
  73. status, _, err := syscall.SyscallN(hl.hipDriverGetVersion, uintptr(unsafe.Pointer(&version)))
  74. if status != hipSuccess {
  75. return 0, 0, fmt.Errorf("failed call to hipDriverGetVersion: %d %s", status, err)
  76. }
  77. slog.Debug("hipDriverGetVersion", "version", version)
  78. driverMajor = version / 10000000
  79. driverMinor = (version - (driverMajor * 10000000)) / 100000
  80. return driverMajor, driverMinor, nil
  81. }
  82. func (hl *HipLib) HipGetDeviceCount() int {
  83. if hl.dll == 0 {
  84. slog.Error("dll has been unloaded")
  85. return 0
  86. }
  87. var count int
  88. status, _, err := syscall.SyscallN(hl.hipGetDeviceCount, uintptr(unsafe.Pointer(&count)))
  89. if status == hipErrorNoDevice {
  90. slog.Info("AMD ROCm reports no devices found")
  91. return 0
  92. }
  93. if status != hipSuccess {
  94. slog.Warn("failed call to hipGetDeviceCount", "status", status, "error", err)
  95. }
  96. return count
  97. }
  98. func (hl *HipLib) HipSetDevice(device int) error {
  99. if hl.dll == 0 {
  100. return fmt.Errorf("dll has been unloaded")
  101. }
  102. status, _, err := syscall.SyscallN(hl.hipSetDevice, uintptr(device))
  103. if status != hipSuccess {
  104. return fmt.Errorf("failed call to hipSetDevice: %d %s", status, err)
  105. }
  106. return nil
  107. }
  108. func (hl *HipLib) HipGetDeviceProperties(device int) (*hipDevicePropMinimal, error) {
  109. if hl.dll == 0 {
  110. return nil, fmt.Errorf("dll has been unloaded")
  111. }
  112. var props hipDevicePropMinimal
  113. status, _, err := syscall.SyscallN(hl.hipGetDeviceProperties, uintptr(unsafe.Pointer(&props)), uintptr(device))
  114. if status != hipSuccess {
  115. return nil, fmt.Errorf("failed call to hipGetDeviceProperties: %d %s", status, err)
  116. }
  117. return &props, nil
  118. }
  119. // free, total, err
  120. func (hl *HipLib) HipMemGetInfo() (uint64, uint64, error) {
  121. if hl.dll == 0 {
  122. return 0, 0, fmt.Errorf("dll has been unloaded")
  123. }
  124. var totalMemory uint64
  125. var freeMemory uint64
  126. status, _, err := syscall.SyscallN(hl.hipMemGetInfo, uintptr(unsafe.Pointer(&freeMemory)), uintptr(unsafe.Pointer(&totalMemory)))
  127. if status != hipSuccess {
  128. return 0, 0, fmt.Errorf("failed call to hipMemGetInfo: %d %s", status, err)
  129. }
  130. return freeMemory, totalMemory, nil
  131. }