Source code for draugr.torch_utilities.system.device

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from enum import Enum
from typing import Optional, Union

import torch

__author__ = "Christian Heider Nielsen"
__doc__ = r"""

           Created on 15/11/2019
           """

from sorcery import assigned_names

GLOBAL_DEVICE: Optional[torch.device] = None

__all__ = [
    "global_torch_device",
    "select_cuda_device",
    "get_gpu_usage_mb",
    "auto_select_available_cuda_device",
    "set_global_torch_device",
    "torch_clean_up",
    "TorchDeviceEnum",
]


class TorchDeviceEnum(Enum):
    """Enumeration of the supported torch device backends."""

    (cpu, cuda, vulkan) = assigned_names()
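
# Illustrative note (not part of the module): `sorcery.assigned_names` gives each
# member the string of its own name, so the enum above should be equivalent to a
# plain Enum with explicit string values:
#
#     class TorchDeviceEnumExplicit(Enum):  # hypothetical equivalent spelling
#         cpu = "cpu"
#         cuda = "cuda"
#         vulkan = "vulkan"
#
#     assert TorchDeviceEnum.cuda.value == "cuda"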
def global_torch_device(
    device_preference: Optional[Union[bool, str, TorchDeviceEnum]] = None,
    override: Optional[torch.device] = None,
    verbose: bool = False,
) -> torch.device:
    """The first call resolves and stores a device for global reference; later calls return it unless explicitly overridden.

    :param device_preference: a bool (prefer an accelerator if available), a device string, or a TorchDeviceEnum member
    :param override: a torch.device that unconditionally replaces the global device
    :param verbose: whether to print when overriding
    :return: the resolved torch.device
    """
    global GLOBAL_DEVICE
    if override is not None:
        GLOBAL_DEVICE = override
        if verbose:
            print(f"Overriding global torch device to {override}")
    elif device_preference is not None:
        if isinstance(device_preference, bool):
            if torch.is_vulkan_available() and device_preference:
                d_ = TorchDeviceEnum.vulkan
            elif torch.cuda.is_available() and device_preference:
                d_ = TorchDeviceEnum.cuda
            else:
                d_ = TorchDeviceEnum.cpu
            d = torch.device(d_.value)
        elif isinstance(device_preference, TorchDeviceEnum):
            d = torch.device(device_preference.value)
        elif isinstance(device_preference, str):
            d = torch.device(device_preference)
        else:
            raise TypeError(
                "device_preference must be a bool, str or TorchDeviceEnum"
            )
        if GLOBAL_DEVICE is None:
            GLOBAL_DEVICE = d
        return d
    elif GLOBAL_DEVICE is None:  # lazily resolve the best available backend
        if torch.is_vulkan_available():
            d_ = TorchDeviceEnum.vulkan
        elif torch.cuda.is_available():
            d_ = TorchDeviceEnum.cuda
        else:
            d_ = TorchDeviceEnum.cpu
        GLOBAL_DEVICE = torch.device(d_.value)
    return GLOBAL_DEVICE
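
# A minimal usage sketch (illustrative; `MyModel` is a hypothetical module):
#
#     device = global_torch_device()                     # resolves and caches a device
#     model = MyModel().to(global_torch_device())        # later calls reuse the cache
#     global_torch_device(override=torch.device("cpu"))  # explicit manual override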
def set_global_torch_device(device: torch.device) -> None:
    """Set the module-level global torch device.

    :param device: the torch.device to store globally
    :return: None
    """
    global GLOBAL_DEVICE
    GLOBAL_DEVICE = device
def select_cuda_device(cuda_device_idx: int) -> torch.device:
    """Select a CUDA device by index.

    :param cuda_device_idx: zero-based index of the CUDA device
    :return: the corresponding torch.device
    """
    num_cuda_device = torch.cuda.device_count()
    assert num_cuda_device > 0, "No CUDA devices are available"
    assert (
        0 <= cuda_device_idx < num_cuda_device
    ), f"cuda_device_idx must be in [0, {num_cuda_device})"
    return torch.device(f"cuda:{cuda_device_idx}")
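
# Usage sketch (illustrative): pin work to the second GPU when one is present.
#
#     if torch.cuda.device_count() > 1:
#         second_gpu = select_cuda_device(1)  # -> device(type='cuda', index=1)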
def get_gpu_usage_mb() -> dict:
    """Get the current per-GPU memory usage via `nvidia-smi`.

    :return: a dict mapping integer device ids to used memory in MB
    """
    import subprocess

    result = subprocess.check_output(
        ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"]
    ).decode("utf-8")
    gpu_memory = [int(x) for x in result.strip().split("\n")]  # one line per GPU
    return dict(enumerate(gpu_memory))
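
# Illustrative note: this shells out to `nvidia-smi`, so it requires an NVIDIA
# driver with `nvidia-smi` on PATH and raises FileNotFoundError or
# CalledProcessError otherwise. On a two-GPU machine the result might look like:
#
#     get_gpu_usage_mb()  # -> {0: 1532, 1: 11}  (values are illustrative)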
def torch_clean_up() -> None:
    r"""**Destroy cuda state by emptying cache and collecting IPC.**

    Consecutively calls `torch.cuda.empty_cache()` and `torch.cuda.ipc_collect()`."""
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
def auto_select_available_cuda_device(
    expected_memory_usage_mb: int = 1024,
) -> torch.device:
    r"""Auto-select the CUDA device with the highest compute capability that has the requested memory available.

    :param expected_memory_usage_mb: memory the caller expects to use, in MB
    :return: the selected torch.device
    """
    num_cuda_device = (
        torch.cuda.device_count()
    )  # TODO: torch.vulkan.device_count() variant in the future
    assert num_cuda_device > 0, "No CUDA devices are available"

    """
    print(torch.cuda.cudart())
    print(torch.cuda.memory_snapshot())
    torch.cuda.memory_cached(dev_idx),
    torch.cuda.memory_allocated(dev_idx),
    torch.cuda.max_memory_allocated(dev_idx),
    torch.cuda.max_memory_cached(dev_idx),
    torch.cuda.get_device_name(dev_idx),
    torch.cuda.get_device_properties(dev_idx),
    torch.cuda.memory_stats(dev_idx)
    """

    preferred_idx = None
    highest_capability = 0.0
    for dev_idx, usage in enumerate(get_gpu_usage_mb().values()):
        cuda_capability = float(
            ".".join(str(x) for x in torch.cuda.get_device_capability(dev_idx))
        )
        if expected_memory_usage_mb:
            total_mem = (
                torch.cuda.get_device_properties(dev_idx).total_memory // 1000 // 1000
            )  # bytes -> MB
            if expected_memory_usage_mb < total_mem - usage:
                if cuda_capability > highest_capability:
                    highest_capability = cuda_capability
                    preferred_idx = dev_idx
        elif cuda_capability > highest_capability:
            highest_capability = cuda_capability
            preferred_idx = dev_idx

    if preferred_idx is None:
        raise FileNotFoundError(
            f"No CUDA device with {expected_memory_usage_mb} MB of free memory found"
        )
    return select_cuda_device(preferred_idx)
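
# Usage sketch (illustrative): pick a GPU expected to fit roughly 4 GB, then
# publish it as the global device, falling back to CPU if none qualifies.
#
#     try:
#         set_global_torch_device(auto_select_available_cuda_device(4096))
#     except (AssertionError, FileNotFoundError):
#         set_global_torch_device(torch.device("cpu"))  # no suitable GPU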
if __name__ == "__main__":

    def stest_override() -> None:
        """
        :rtype: None"""
        print(global_torch_device(verbose=True))
        print(
            global_torch_device(
                override=global_torch_device(device_preference=False, verbose=True),
                verbose=True,
            )
        )
        print(global_torch_device(verbose=True))
        print(global_torch_device(device_preference=True))
        print(global_torch_device())
        print(
            global_torch_device(
                override=global_torch_device(device_preference=True, verbose=True)
            )
        )
        print(global_torch_device())

    def a() -> None:
        """
        :rtype: None"""
        print(global_torch_device())
        print(auto_select_available_cuda_device())

    def b() -> None:
        """
        :rtype: None"""
        print(global_torch_device(TorchDeviceEnum.vulkan))

    # stest_override()
    print(global_torch_device(False).type)
    b()