main.rs 3.2 KB

use std::env;
use std::io::{self, BufRead};

use reqwest::Client;
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize)]
struct ChatMessage {
    role: String,
    content: String,
}

#[derive(Serialize)]
struct ChatRequest {
    model: String,
    messages: Vec<ChatMessage>,
}

#[derive(Deserialize)]
struct ChatResponseChoice {
    message: ChatMessage,
}

#[derive(Deserialize)]
struct ChatResponse {
    choices: Vec<ChatResponseChoice>,
}

struct ChatBot {
    openai_key: String,
    client: Client,
}
impl ChatBot {
    fn new(openai_key: String) -> ChatBot {
        ChatBot {
            openai_key,
            client: Client::new(),
        }
    }

    // Sends the given messages as a single chat completion request and
    // returns the model's reply. If no messages were supplied, the prompt
    // is read from stdin instead.
    async fn chat(&self, messages: Vec<String>, model: &str) -> Result<String, Box<dyn std::error::Error>> {
        let mut chat_messages: Vec<ChatMessage> = Vec::new();
        for message in messages {
            chat_messages.push(ChatMessage {
                role: String::from("user"),
                content: message,
            });
        }

        // Fall back to reading the prompt from stdin, one line per message.
        if chat_messages.is_empty() {
            let stdin = io::stdin();
            println!("Enter your message (press Ctrl + D to send):");
            for line in stdin.lock().lines() {
                match line {
                    Ok(input) => {
                        chat_messages.push(ChatMessage {
                            role: String::from("user"),
                            content: input,
                        });
                    }
                    Err(_) => break, // exit the loop on a read error
                }
            }
        }

        let chat_request = ChatRequest {
            model: model.to_string(),
            messages: chat_messages,
        };
        let response = self
            .client
            .post("https://api.openai.com/v1/chat/completions")
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", self.openai_key))
            .json(&chat_request)
            .send()
            .await?;

        if response.status().is_success() {
            let response_body: ChatResponse = response.json().await?;
            // A successful response is expected to contain at least one choice.
            let message = &response_body.choices[0].message;
            Ok(message.content.to_string())
        } else {
            let status = response.status();
            let error_message = response.text().await?;
            Err(format!("HTTP Error {}: {}", status, error_message).into())
        }
    }
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Retrieve the API key from the environment.
    let openai_key = match env::var("OPENAI_KEY") {
        Ok(val) => val,
        Err(_) => {
            eprintln!("OPENAI_KEY must be set in the environment!");
            return Ok(());
        }
    };
    // Retrieve the model from the environment, or use the default.
    let model = env::var("OPENAI_MODEL")
        .unwrap_or_else(|_| "gpt-3.5-turbo-0613".to_string());
    let bot = ChatBot::new(openai_key);
    // Any command-line arguments are treated as the user's messages.
    let args: Vec<String> = env::args().skip(1).collect();
    let response = bot.chat(args, &model).await?;
    println!("{}", response);
    Ok(())
}
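
The listing only builds with the right crate features enabled: `.json(&chat_request)` needs reqwest's "json" feature, the derive macros need serde's "derive" feature, and `#[tokio::main]` needs tokio's "macros" and a runtime feature. A minimal Cargo.toml sketch along those lines (the package name and version numbers are assumptions, not requirements of the original project):

[package]
name = "chatbot"        # assumed name; use your own
version = "0.1.0"
edition = "2021"

[dependencies]
reqwest = { version = "0.11", features = ["json"] }
serde = { version = "1", features = ["derive"] }
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }

With that in place, something like `OPENAI_KEY=... cargo run -- "Hello there"` should print the model's reply, and running the binary with no arguments drops into the stdin prompt instead.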