123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114 |
- use std::env;
- use reqwest::Client;
- use serde::{Deserialize, Serialize};
- use std::io::{self, BufRead};
- #[derive(Serialize, Deserialize)]
- struct ChatMessage {
- role: String,
- content: String,
- }
- #[derive(Serialize)]
- struct ChatRequest {
- model: String,
- messages: Vec<ChatMessage>,
- }
- #[derive(Deserialize)]
- struct ChatResponseChoice {
- message: ChatMessage,
- }
- #[derive(Deserialize)]
- struct ChatResponse {
- choices: Vec<ChatResponseChoice>,
- }
- struct ChatBot {
- openai_key: String,
- client: Client,
- }
- impl ChatBot {
- fn new(openai_key: String) -> ChatBot {
- ChatBot {
- openai_key,
- client: Client::new(),
- }
- }
- async fn chat(&self, messages: Vec<String>, model: &str) -> Result<String, Box<dyn std::error::Error>> {
- let mut chat_messages: Vec<ChatMessage> = Vec::new();
- for message in messages {
- chat_messages.push(ChatMessage {
- role: String::from("user"),
- content: message,
- });
- }
- if chat_messages.is_empty() {
- let stdin = io::stdin();
- println!("Enter your message (press Ctrl + D to send):");
- for line in stdin.lock().lines() {
- match line {
- Ok(input) => {
- chat_messages.push(ChatMessage {
- role: String::from("user"),
- content: input,
- });
- }
- Err(_) => break, // exit loop on error
- }
- }
- }
- let chat_request = ChatRequest {
- model: model.to_string(),
- messages: chat_messages,
- };
- let response = self.client
- .post("https://api.openai.com/v1/chat/completions")
- .header("Content-Type", "application/json")
- .header("Authorization", format!("Bearer {}", self.openai_key))
- .json(&chat_request)
- .send()
- .await?;
- if response.status().is_success() {
- let response_body: ChatResponse = response.json().await?;
- let message = &response_body.choices[0].message;
- Ok(message.content.to_string())
- } else {
- let status = response.status();
- let error_message = response.text().await?;
- Err(format!("HTTP Error {}: {}", status, error_message).into())
- }
- }
- }
- #[tokio::main]
- async fn main() -> Result<(), Box<dyn std::error::Error>> {
- // Retrieve openai_key from the environment
- let openai_key = match env::var("OPENAI_KEY") {
- Ok(val) => val,
- Err(_) => {
- println!("openai_key must be set in the environment!");
- return Ok(());
- }
- };
- // Retrieve model from the environment or use the default
- let model = env::var("OPENAI_MODEL")
- .unwrap_or_else(|_| "gpt-3.5-turbo-0613".to_string());
- let bot = ChatBot::new(openai_key);
- let args: Vec<String> = env::args().skip(1).collect();
- let response = bot.chat(args, &model).await?;
- println!("{}", response);
- Ok(())
- }
|