diff options
Diffstat (limited to 'src/app/init.rs')
-rw-r--r-- | src/app/init.rs | 101 |
1 files changed, 43 insertions, 58 deletions
diff --git a/src/app/init.rs b/src/app/init.rs index eabf204..aa74ae5 100644 --- a/src/app/init.rs +++ b/src/app/init.rs @@ -1,76 +1,61 @@ +use crate::app::llm::{Message, MessageType, LLM}; use crate::helper::init::print_in_file; -use color_eyre::Result; -use reqwest; -use serde_json::Value; -use std::{collections::HashMap, fmt}; +use tokio; -#[derive(Debug)] pub struct App { - pub messages: Vec<Message>, // History of recorded messages -} - -#[derive(Debug)] -pub struct Message { - content: String, - msg_type: MessageType, -} - -impl fmt::Display for Message { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self.msg_type { - MessageType::Human => return write!(f, "You: {}", self.content), - MessageType::LLM => return write!(f, "Néo AI: {}", self.content), - } - } -} - -#[derive(Debug)] -pub enum MessageType { - Human, - LLM, + pub messages: Vec<Message>, // History of recorded message + chat_llm: LLM, + resume_llm: LLM, } impl App { pub fn new() -> App { + let chat_llm: LLM = LLM::new("config/chat-LLM.json".to_string()).unwrap(); App { - messages: Vec::new(), + messages: vec![Message::new( + MessageType::SYSTEM, + chat_llm.system_prompt.clone(), + )], + chat_llm, + resume_llm: LLM::new("config/resume-LLM.json".to_string()).unwrap(), } } - pub fn send_message(&mut self, content: String) -> Result<()> { - // POST: http://localhost:8080/completion {"prompt": "lorem ipsum"} - self.messages.push(Message { - content: content.clone(), - msg_type: MessageType::Human, - }); + fn append_message(&mut self, msg: String, role: MessageType) { + let message = Message::new(role, msg); + self.messages.push(message); + } - let client = reqwest::blocking::Client::new(); - let response = client - .post("http://localhost:8080/completion") - .json(&serde_json::json!({ - "prompt": &content, - "n_predict": 400, - })) - .send()?; + pub fn send_message(&mut self, content: String) { + self.append_message(content, MessageType::USER); - if response.status().is_success() { - // Désérialiser la réponse JSON - let json_response: Value = response.json()?; + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build().unwrap(); + let result = runtime.block_on(async { + self.chat_llm.ask(&self.messages).await + }); - //print_in_file(json_response.to_string().clone()); - // Accéder à la partie spécifique du JSON - if let Some(msg) = json_response["content"].as_str() { - self.messages.push(Message { - content: msg.to_string().clone(), - msg_type: MessageType::LLM, - }); - } else { - println!("Le champ 'data.id' est absent ou mal formaté."); - } - } else { - eprintln!("La requête a échoué avec le statut : {}", response.status()); + match result { + Ok(msg) => self.append_message(msg, MessageType::ASSISTANT), + Err(e) => self.append_message(e.to_string(), MessageType::ASSISTANT), } + } - Ok(()) + pub fn resume_conv(&mut self) { + self.append_message(self.resume_llm.system_prompt.to_string(), MessageType::USER); + + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build().unwrap(); + + let result = runtime.block_on(async { + self.resume_llm.ask(&self.messages).await + }); + + match result { + Ok(msg) => self.append_message(msg, MessageType::ASSISTANT), + Err(e) => self.append_message(e.to_string(), MessageType::ASSISTANT), + } } } |