markov.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. import random
  2. import re
  3. MRK_START = '__start'
  4. MRK_END = '__end'
  5. DEFAULT_CONFIG = {
  6. 'read': False,
  7. 'reply_on_mention': False,
  8. 'remove_mentions': True,
  9. }
  10. class MarkovGen:
  11. def __init__(self, states: dict[str, list[str]] = {}, config = DEFAULT_CONFIG) -> None:
  12. self.stateTable: dict[str, set[str]] = {k: set(v) for k,v in states.items()}
  13. self.config = DEFAULT_CONFIG | config
  14. def addMessage(self, inpMsg: str) -> None:
  15. samples = [v.lower() for v in inpMsg.split() if v != MRK_START and v != MRK_END]
  16. if len(samples) == 0: return
  17. samples.insert(0, MRK_START)
  18. samples.append(MRK_END)
  19. for i, val in enumerate(samples):
  20. if i + 1 > len(samples) - 1: break
  21. if val not in self.stateTable:
  22. self.stateTable[val] = set()
  23. self.stateTable[val].add(samples[i + 1])
  24. def dumpState(self) -> dict[str, list[str]]:
  25. return {k: list(v) for k,v in self.stateTable.items()}
  26. def generate(self, startMsg: str='') -> str:
  27. if len(self.stateTable) == 0: raise ValueError('No messages recorded!')
  28. out = [ MRK_START ]
  29. out.extend(startMsg.split())
  30. while out[-1] != MRK_END:
  31. try:
  32. nextVals = self.stateTable[out[-1]]
  33. except KeyError as exc:
  34. raise ValueError(f'Невозможно сгенерировать предложение, начинающеяся на {out[-1]}!') from exc
  35. out.append(random.choice(list(nextVals)))
  36. outString = ''.join([str(token) + ' ' for token in out if token != MRK_START and token != MRK_END])
  37. if not outString.startswith('http'):
  38. if self.config['remove_mentions']:
  39. print(outString.capitalize())
  40. sub = re.sub('(<@&*[0-9]*>)', '', outString.capitalize())
  41. print(sub)
  42. if sub == '':
  43. return '*(Пустое сообщение)*'
  44. return sub
  45. return outString.capitalize()
  46. if self.config['remove_mentions']:
  47. sub = re.sub('(<@&*[0-9]*>)', '', outString)
  48. return sub
  49. return outString
  50. if __name__ == '__main__':
  51. generator = MarkovGen()
  52. generator.addMessage('Дарова мусороиды прекрасные')
  53. generator.addMessage('Прекрасные сегодня цветы')
  54. generator.addMessage('Сегодня собаки прекрасные топчут цветы мои хорошие')
  55. for _ in range(10):
  56. print(generator.generate())