models.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. from dataclasses import dataclass
  2. from enum import Enum
  3. import torch
  4. from omegaconf import OmegaConf
  5. from IPython.display import Audio
  6. @dataclass
  7. class ModelInfo():
  8. model_id: str
  9. languages: list[str]
  10. language: str
  11. device: str
  12. sample_rate: int
  13. speakers: list
  14. class Device(Enum):
  15. CPU = 'cpu'
  16. GPU = 'gpu'
  17. class SampleRate(Enum):
  18. HI = 48000
  19. MID = 24000
  20. LOW = 8000
  21. class SpeakerModel:
  22. """
  23. Ititialize speaker's model object
  24. """
  25. def __init__(self):
  26. self.model_id = 'v3_1_ru'
  27. self.lang = 'ru'
  28. self.sample_rate = SampleRate.HI.value
  29. self.device = Device.CPU
  30. self.put_accent=True
  31. self.put_yo=True
  32. self.languages = []
  33. self.models = self.__get_models()
  34. def __get_models(self):
  35. """Get models from github"""
  36. try:
  37. torch.hub.download_url_to_file(
  38. 'https://raw.githubusercontent.com/snakers4/silero-models/master/models.yml',
  39. 'latest_silero_models.yml',
  40. progress=False)
  41. except Exception:
  42. print("ERROR downloading models. You need internet connection")
  43. exit()
  44. self.models = OmegaConf.load('latest_silero_models.yml')
  45. self.languages = list(self.models.tts_models.keys())
  46. def specify_model(
  47. self,
  48. lang: str = None,
  49. device: Device = None,
  50. sample_rate: SampleRate = None,
  51. model_id: str = None):
  52. """
  53. Accept lenguage, device and
  54. """
  55. if lang: self.lang = lang
  56. if sample_rate: self.semple_rate = sample_rate.value
  57. if device: self.device = device
  58. if model_id: self.model_id = model_id
  59. set_device = torch.device(self.device.value)
  60. if self.lang not in self.languages:
  61. raise Exception(f"Language is not valid. You cat shoose from {self.languages}")
  62. self.model, self.example_text = torch.hub.load(repo_or_dir='snakers4/silero-models',
  63. model='silero_tts',
  64. language=self.lang,
  65. speaker=self.model_id)
  66. self.model.to(set_device)
  67. self.__get_speakers()
  68. def __get_speakers(self):
  69. """
  70. Gets the list of aviabble languages
  71. """
  72. try:
  73. audio = self.model.apply_tts(text='test',
  74. speaker="test",
  75. sample_rate=self.sample_rate,
  76. put_accent=self.put_accent,
  77. put_yo=self.put_yo)
  78. except Exception as err:
  79. self.speakers = str(err).split('`speaker` should be in ')[1].split(', ')
  80. self.speaker = self.speakers[0]
  81. def get_model_info(self):
  82. return ModelInfo(
  83. model_id=self.model_id,
  84. languages=self.languages,
  85. language=self.lang,
  86. device=self.device.value,
  87. sample_rate=self.sample_rate,
  88. speakers=self.speakers)
  89. def text2speech(self, text: str, filename: str, speaker: str = None) -> None:
  90. audio = self.model.apply_tts(text=text,
  91. speaker=speaker or self.speaker,
  92. sample_rate=self.sample_rate,
  93. put_accent=self.put_accent,
  94. put_yo=self.put_yo)
  95. a = Audio(audio, rate=self.sample_rate)
  96. with open(filename, 'wb') as f:
  97. f.write(a.data)