conversation.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. from pysentimiento import create_analyzer
  2. import collections
  3. import re
  4. # msg constructor and formatter
  5. class character_msg_constructor:
  6. def __init__(self, name, char_persona):
  7. self.name = name
  8. self.persona = char_persona
  9. self.conversation_history = ''
  10. self.emotion_analyzer = create_analyzer(task="emotion", lang="en")
  11. self.split_counter = 0
  12. self.history_loop_cache = ''
  13. def construct_msg(self, text:str, conversation_history=None) -> str:
  14. if conversation_history != None:
  15. self.conversation_history = f'{self.conversation_history}\n{conversation_history}' # add conversation history
  16. if len(self.conversation_history.split('\n')) > 40: # limit conversation history to prevent memory leak
  17. self.conversation_history = self.conversation_history.split('\n')[-6:] # replace with last 4 lines
  18. self.split_counter = 2
  19. conversation_template = f"""{self.name}'s Persona: {self.persona}
  20. {self.conversation_history.strip()}
  21. You: {text}
  22. """
  23. return '\n'.join([x.strip() for x in conversation_template.split('\n')])
  24. # conversation formatter
  25. def get_current_converse(self, conversation_text:str) -> list:
  26. splited = [x.strip() for x in conversation_text.split('\n') if x != '']
  27. conversation_list = []
  28. conversation_line_count = 0
  29. for idx, thisline in enumerate(splited):
  30. holder = conversation_line_count
  31. if thisline.startswith(f'{self.name}:') or thisline.startswith('You:'): # if found talking line
  32. holder += 1
  33. if holder > conversation_line_count: # append talking line at each found
  34. conversation_list.append(thisline)
  35. conversation_line_count = holder
  36. elif conversation_line_count > 0: # concat conversation into the line before if no new converse line found
  37. conversation_list[-1] = f'{conversation_list[-1].strip()} {thisline.strip()}'
  38. return conversation_list
  39. def emotion_analyze(self, text:str) -> list:
  40. emotions_text = text
  41. if '*' in text:
  42. emotions_text = re.findall(r'\*(.*?)\*', emotions_text) # get emotion *action* as input if exist
  43. emotions_text = ' '.join(emotions_text) # create input
  44. emotions = self.emotion_analyzer.predict(emotions_text).probas
  45. ordered = dict(sorted(emotions.items(), key=lambda x: x[1]))
  46. ordered = [k for k, v in ordered.items()] # top two emotion
  47. ordered.reverse()
  48. return ordered[:2]
  49. def clean_emotion_action_text_for_speech(self, text):
  50. clean_text = re.sub(r'\*.*?\*', '', text) # remove *action* from text
  51. clean_text = clean_text.replace(f'{self.name}:', '') # replace -> name: "dialog"
  52. return clean_text