server.py

import os
import time

import mindspore as ms
import ujson
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from mindformers.models.glm import GLMConfig, GLMChatModel
from mindformers.models.glm.chatglm_6b_tokenizer import ChatGLMTokenizer
from mindformers.models.glm.glm_processor import process_response
from pydantic import BaseModel
# Configure ChatGLM (2D position encoding, incremental decoding, sampling acceleration)
config = GLMConfig(
    position_encoding_2d=True,
    use_past=True,
    is_sample_acceleration=True,
)
# Initialize the MindSpore context, build the model, and load the checkpoint
ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend", device_id=1)
model = GLMChatModel(config)
ms.load_checkpoint("./checkpoint_download/glm/glm_6b_chat.ckpt", model)
tokenizer = ChatGLMTokenizer('./checkpoint_download/glm/ice_text.model')
# Initialize the FastAPI app with permissive CORS
app = FastAPI()
app.add_middleware(CORSMiddleware,
                   allow_origins=['*'],
                   allow_credentials=True,
                   allow_methods=['*'],
                   allow_headers=['*'])
# Define the ChatInfo message class
class ChatInfo(BaseModel):
    owner: str       # 'seeker' or 'supporter'
    msg: str         # message text
    unique_id: str   # conversation identifier
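# Example request body matching this schema (illustrative values; the
# unique_id is arbitrary):
# {"owner": "seeker", "msg": "我很焦虑,我应该怎么办", "unique_id": "demo-001"}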
# Generate a reply from the model for the given prompt
def generate_response(query):
    input_ids = tokenizer(query)['input_ids']
    start_time = time.time()
    outputs = model.generate(input_ids,
                             max_length=config.max_decode_length,
                             do_sample=False)
    end_time = time.time()
    print(f'generate speed: {outputs[0].shape[0]/(end_time-start_time):.2f} tokens/s')
    response = tokenizer.decode(outputs)
    response = process_response(response[0])
    return response
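# Example call (hypothetical prompt, in the labelled-dialogue format built
# by the /chat handler below):
#   generate_response('求助者:我很焦虑,我应该怎么办支持者:')
# returns the supporter reply the model generates as a continuation.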
# Sample seeker prompts, in Chinese: "I feel anxious, what should I do?",
# "How do other people cope with anxiety?", "Have you ever felt anxious?"
prompts = ["我很焦虑,我应该怎么办", "其他人是怎么应对焦虑的呢?", "你有过焦虑的时候吗?"]
# The dialogue history still needs to be filled in here
history = []
# Handle a submitted chat message
@app.post('/chat')
async def chat(chat_info: ChatInfo):
    unique_id = chat_info.unique_id
    # Make sure the dialogue directory exists, then load any saved history
    os.makedirs('./dialogues', exist_ok=True)
    existing_files = os.listdir('./dialogues')
    target_file = f'{unique_id}.json'
    if target_file in existing_files:
        with open(f'./dialogues/{unique_id}.json', 'r', encoding='utf-8') as f:
            data: list = ujson.load(f)
    else:
        data = []
    data.append({
        'owner': chat_info.owner,
        'msg': chat_info.msg,
        'unique_id': chat_info.unique_id
    })
    # Concatenate the history into one prompt, labelling each turn as
    # 求助者 (seeker) or 支持者 (supporter), and end with the supporter
    # label so the model continues as the supporter
    input_str = ''
    for item in data:
        if item['owner'] == 'seeker':
            input_str += '求助者:' + item['msg']
        else:
            input_str += '支持者:' + item['msg']
    input_str += '支持者:'
    # While the prompt is too long, drop the oldest turn
    while len(input_str) > 2000:
        if input_str.index('求助者:') > input_str.index('支持者:'):
            start_idx = input_str.index('求助者:')
        else:
            start_idx = input_str.index('支持者:')
        input_str = input_str[start_idx:]
    wrapped_data = input_str
    response = generate_response(wrapped_data)
    supporter_msg = {
        'owner': 'supporter',
        'msg': response,
        'unique_id': unique_id
    }
    data.append(supporter_msg)
    # Persist the updated dialogue to disk
    with open(f'./dialogues/{unique_id}.json', 'w', encoding='utf-8') as f:
        ujson.dump(data, f, ensure_ascii=False, indent=2)
    return {'item': supporter_msg, 'responseCode': 200}
if __name__ == '__main__':
    uvicorn.run(app, host='127.0.0.1', port=8080)
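# Example request against the running server (a sketch; the unique_id value
# is arbitrary and chosen here for illustration):
#   curl -X POST http://127.0.0.1:8080/chat \
#        -H 'Content-Type: application/json' \
#        -d '{"owner": "seeker", "msg": "我很焦虑,我应该怎么办", "unique_id": "demo-001"}'
# The response carries the generated supporter message under 'item' along
# with a responseCode of 200.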