diff options
Diffstat (limited to 'src/app')
-rw-r--r-- | src/app/init.rs | 101 | ||||
-rw-r--r-- | src/app/llm.rs | 98 | ||||
-rw-r--r-- | src/app/mod.rs | 1 |
3 files changed, 142 insertions, 58 deletions
diff --git a/src/app/init.rs b/src/app/init.rs index eabf204..aa74ae5 100644 --- a/src/app/init.rs +++ b/src/app/init.rs @@ -1,76 +1,61 @@ +use crate::app::llm::{Message, MessageType, LLM}; use crate::helper::init::print_in_file; -use color_eyre::Result; -use reqwest; -use serde_json::Value; -use std::{collections::HashMap, fmt}; +use tokio; -#[derive(Debug)] pub struct App { - pub messages: Vec<Message>, // History of recorded messages -} - -#[derive(Debug)] -pub struct Message { - content: String, - msg_type: MessageType, -} - -impl fmt::Display for Message { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self.msg_type { - MessageType::Human => return write!(f, "You: {}", self.content), - MessageType::LLM => return write!(f, "Néo AI: {}", self.content), - } - } -} - -#[derive(Debug)] -pub enum MessageType { - Human, - LLM, + pub messages: Vec<Message>, // History of recorded message + chat_llm: LLM, + resume_llm: LLM, } impl App { pub fn new() -> App { + let chat_llm: LLM = LLM::new("config/chat-LLM.json".to_string()).unwrap(); App { - messages: Vec::new(), + messages: vec![Message::new( + MessageType::SYSTEM, + chat_llm.system_prompt.clone(), + )], + chat_llm, + resume_llm: LLM::new("config/resume-LLM.json".to_string()).unwrap(), } } - pub fn send_message(&mut self, content: String) -> Result<()> { - // POST: http://localhost:8080/completion {"prompt": "lorem ipsum"} - self.messages.push(Message { - content: content.clone(), - msg_type: MessageType::Human, - }); + fn append_message(&mut self, msg: String, role: MessageType) { + let message = Message::new(role, msg); + self.messages.push(message); + } - let client = reqwest::blocking::Client::new(); - let response = client - .post("http://localhost:8080/completion") - .json(&serde_json::json!({ - "prompt": &content, - "n_predict": 400, - })) - .send()?; + pub fn send_message(&mut self, content: String) { + self.append_message(content, MessageType::USER); - if response.status().is_success() { - // Désérialiser la réponse JSON - let json_response: Value = response.json()?; + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build().unwrap(); + let result = runtime.block_on(async { + self.chat_llm.ask(&self.messages).await + }); - //print_in_file(json_response.to_string().clone()); - // Accéder à la partie spécifique du JSON - if let Some(msg) = json_response["content"].as_str() { - self.messages.push(Message { - content: msg.to_string().clone(), - msg_type: MessageType::LLM, - }); - } else { - println!("Le champ 'data.id' est absent ou mal formaté."); - } - } else { - eprintln!("La requête a échoué avec le statut : {}", response.status()); + match result { + Ok(msg) => self.append_message(msg, MessageType::ASSISTANT), + Err(e) => self.append_message(e.to_string(), MessageType::ASSISTANT), } + } - Ok(()) + pub fn resume_conv(&mut self) { + self.append_message(self.resume_llm.system_prompt.to_string(), MessageType::USER); + + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build().unwrap(); + + let result = runtime.block_on(async { + self.resume_llm.ask(&self.messages).await + }); + + match result { + Ok(msg) => self.append_message(msg, MessageType::ASSISTANT), + Err(e) => self.append_message(e.to_string(), MessageType::ASSISTANT), + } } } diff --git a/src/app/llm.rs b/src/app/llm.rs new file mode 100644 index 0000000..8603395 --- /dev/null +++ b/src/app/llm.rs @@ -0,0 +1,98 @@ +use crate::helper::init::print_in_file; +use reqwest::{header::CONTENT_TYPE, Client}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::fmt; +use std::fs; + +#[derive(Deserialize, Debug)] +pub struct LLM { + url: String, + model: String, + pub system_prompt: String, +} + +impl LLM { + pub fn new(config_file: String) -> Result<LLM, Box<dyn std::error::Error>> { + let contents = fs::read_to_string(config_file)?; + let llm: LLM = serde_json::from_str(&contents)?; + + Ok(llm) + } + + pub async fn ask(&self, messages: &Vec<Message>) -> Result<String, Box<dyn std::error::Error>> { + let client = Client::new(); + let response = client + .post(&self.url) + .header(CONTENT_TYPE, "application/json") + .json(&serde_json::json!({ + "model": self.model, + "messages": messages, + "stream": true})) + .send() + .await?; + + let mut full_message = String::new(); + + // Reading the stream and saving the response + match response.error_for_status() { + Ok(mut res) => { + while let Some(chunk) = res.chunk().await? { + let answer: Value = serde_json::from_slice(chunk.as_ref())?; + + print_in_file(answer.to_string()); + if answer["done"].as_bool().unwrap_or(false) { + break; + } + + let msg = answer["message"]["content"].as_str().unwrap_or("\n"); + + full_message.push_str(msg); + } + } + Err(e) => return Err(Box::new(e)), + } + + print_in_file(full_message.clone()); + Ok(full_message) + } +} + +#[derive(Debug, Serialize, Clone)] +pub enum MessageType { + ASSISTANT, + SYSTEM, + USER, +} + +impl fmt::Display for MessageType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + MessageType::ASSISTANT => write!(f, "assistant"), + MessageType::SYSTEM => write!(f, "system"), + MessageType::USER => write!(f, "user"), + } + } +} + +#[derive(Debug, Serialize, Clone)] +pub struct Message { + role: MessageType, + content: String, +} + +impl Message { + pub fn new(role: MessageType, content: String) -> Message { + Message { role, content } + } +} + +impl fmt::Display for Message { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.role { + MessageType::USER => return write!(f, "You: {}", self.content), + MessageType::SYSTEM => return write!(f, "System: {}", self.content), + MessageType::ASSISTANT => return write!(f, "Néo AI: {}", self.content), + } + } +} diff --git a/src/app/mod.rs b/src/app/mod.rs index 43763f1..3cff678 100644 --- a/src/app/mod.rs +++ b/src/app/mod.rs @@ -1 +1,2 @@ pub mod init; +pub mod llm; |