Diffstat (limited to 'src')
-rw-r--r--  src/app.rs (renamed from src/app/mod.rs)       |   1
-rw-r--r--  src/app/init.rs                                |  26
-rw-r--r--  src/app/llm.rs                                 |  19
-rw-r--r--  src/app/modules.rs                             |   1
-rw-r--r--  src/app/modules/wikipedia.rs                   | 108
-rw-r--r--  src/helper.rs (renamed from src/helper/mod.rs) |   0
-rw-r--r--  src/helper/init.rs                             |   2
-rw-r--r--  src/lib.rs                                     |   3
-rw-r--r--  src/ui.rs (renamed from src/ui/mod.rs)         |   0
-rw-r--r--  src/ui/init.rs                                 |  19
10 files changed, 154 insertions, 25 deletions
diff --git a/src/app/mod.rs b/src/app.rs
index 3cff678..8376f8c 100644
--- a/src/app/mod.rs
+++ b/src/app.rs
@@ -1,2 +1,3 @@
 pub mod init;
 pub mod llm;
+pub mod modules;
diff --git a/src/app/init.rs b/src/app/init.rs
index dd8d3b7..0930319 100644
--- a/src/app/init.rs
+++ b/src/app/init.rs
@@ -1,27 +1,30 @@
 use crate::app::llm::{Message, MessageType, LLM};
+use crate::app::modules::wikipedia::ask_wiki;
 use crate::helper::init::warn;
 use uuid::Uuid;
 use tokio::runtime::Builder;

 pub struct App {
     pub messages: Vec<Message>, // History of recorded message
-    conv_id: Uuid,
-    chat_llm: LLM,
-    resume_llm: LLM,
+    pub conv_id: Uuid, // ID for retrieving and saving the history of messages
+    categorize_llm: LLM,
+    chat_llm: LLM, // Configuration for the LLM that chat with you
+    resume_llm: LLM, // Configuration for the LLM that resume conversation
 }

 impl App {
     pub fn new() -> App {
-        let chat_llm: LLM = LLM::new("config/chat-LLM.json".to_string());
+        let categorize_llm = LLM::new("config/categorize-LLM.json");

         App {
             messages: vec![Message::new(
                 MessageType::SYSTEM,
-                chat_llm.system_prompt.clone(),
+                categorize_llm.system_prompt.clone(),
             )],
             conv_id: Uuid::new_v4(),
-            chat_llm,
-            resume_llm: LLM::new("config/resume-LLM.json".to_string()),
+            categorize_llm,
+            chat_llm: LLM::new("config/chat-LLM.json"),
+            resume_llm: LLM::new("config/resume-LLM.json"),
         }
     }
@@ -29,7 +32,7 @@ impl App {
         let message = Message::new(role, msg);
         let err = message.save_message(self.conv_id.to_string());

-        warn(err.is_err().to_string());
+        //warn(err.is_err().to_string());

         self.messages.push(message);
     }
@@ -39,19 +42,20 @@ impl App {
         let runtime = Builder::new_current_thread()
             .enable_all()
             .build().unwrap();
         let result = runtime.block_on(async {
             // Ask the LLM to categorise the request between (chat, code, wikipedia)
-            self.chat_llm.ask_format(&self.messages).await
+            self.categorize_llm.ask_tools(&self.messages).await
         });
         match result {
             Ok(msg) => {
                 let categorie = msg[0]["function"]["arguments"]["category"].clone();
-                self.ask(categorie.to_string().as_str());
+                self.ask(&categorie.to_string().replace("\"", ""));
             },
             Err(e) => self.append_message(e.to_string(), MessageType::ASSISTANT),
         }
     }
     fn ask(&mut self, mode: &str) {
+        warn(format!("Categorie: {}", mode));
         let runtime = Builder::new_current_thread()
             .enable_all()
             .build().unwrap();
@@ -59,6 +63,8 @@ impl App {
         let result = runtime.block_on(async {
             if mode == "resume" {
                 self.resume_llm.ask(&self.messages).await
+            } else if mode == "wikipedia" {
+                ask_wiki(&self.messages).await
             } else {
                 self.chat_llm.ask(&self.messages).await
             }
diff --git a/src/app/llm.rs b/src/app/llm.rs
index 9c6d222..979b790 100644
--- a/src/app/llm.rs
+++ b/src/app/llm.rs
@@ -15,7 +15,7 @@ pub struct LLM {
 }

 impl LLM {
-    pub fn new(config_file: String) -> LLM {
+    pub fn new(config_file: &str) -> LLM {
         let contents = fs::read_to_string(config_file).unwrap();
         serde_json::from_str(&contents).unwrap()
     }
@@ -40,7 +40,7 @@ impl LLM {
         while let Some(chunk) = res.chunk().await? {
             let answer: Value = serde_json::from_slice(chunk.as_ref())?;

-            warn(answer.to_string());
+            //warn(answer.to_string());
             if answer["done"].as_bool().unwrap_or(false) {
                 break;
             }
@@ -59,7 +59,7 @@ impl LLM {

     // Use tools functionnality of Ollama, only some models supports it:
     // https://ollama.com/search?c=tools
-    pub async fn ask_format(&self, messages: &Vec<Message>) -> Result<serde_json::Value, Box<dyn std::error::Error>> {
+    pub async fn ask_tools(&self, messages: &Vec<Message>) -> Result<serde_json::Value, Box<dyn std::error::Error>> {
         let client = Client::new();
         let response = client
             .post(&self.url)
@@ -72,7 +72,7 @@ impl LLM {
             .send()
             .await?.json::<Value>().await?;

-        warn(response.to_string());
+        //warn(response.to_string());

         if let Some(tool_calls) = response
             .get("message")
@@ -105,8 +105,8 @@ impl fmt::Display for MessageType {

 #[derive(Debug, Serialize, Clone)]
 pub struct Message {
-    role: MessageType,
-    content: String,
+    pub role: MessageType,
+    pub content: String,
 }

 impl Message {
@@ -114,6 +114,13 @@
         Message { role, content }
     }

+    pub fn default() -> Message {
+        Message {
+            role: MessageType::USER,
+            content: "".to_string(),
+        }
+    }
+
     pub fn save_message(&self, conv_id: String) -> Result<(), Box<dyn std::error::Error>> {
         // Create conv directory if doesn't exist
         create_dir_all("conv")?;
diff --git a/src/app/modules.rs b/src/app/modules.rs
new file mode 100644
index 0000000..622d63c
--- /dev/null
+++ b/src/app/modules.rs
@@ -0,0 +1 @@
+pub mod wikipedia;
diff --git a/src/app/modules/wikipedia.rs b/src/app/modules/wikipedia.rs
new file mode 100644
index 0000000..5864df4
--- /dev/null
+++ b/src/app/modules/wikipedia.rs
@@ -0,0 +1,108 @@
+use crate::app::llm::{Message, MessageType, LLM};
+use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC};
+use crate::helper::init::warn;
+use std::fs;
+use select::document::Document;
+use select::predicate::{Name, Class};
+use regex::Regex;
+
+pub async fn ask_wiki(messages: &Vec<Message>) -> Result<String, Box<dyn std::error::Error>> {
+    let wiki_search = LLM::new("config/wiki/wiki-search.json");
+    let wiki_best = LLM::new("config/wiki/wiki-best.json");
+    let wiki_resume = LLM::new("config/wiki/wiki-resume.json");
+
+    let settings: serde_json::Value = serde_json::from_str(&fs::read_to_string("config/wiki/wiki.json").unwrap()).unwrap();
+    let wiki_url: String = settings.get("wiki_url").unwrap().to_string().replace("\"", "");
+    let zim_name: String = settings.get("zim_name").unwrap().to_string().replace("\"", "");
+
+    // Search articles corresponding to user query
+    let user_query: Message = messages.last().unwrap().clone();
+    let articles: Vec<String> = search_articles(user_query.clone(), wiki_search, &wiki_url, &zim_name).await?;
+
+    // Find best article to respond user query
+    let best_article_content = find_get_best_article(articles, &user_query.content, wiki_best, &wiki_url, &zim_name).await?;
+
+    // Resume article and create the response
+    let messages = vec![
+        Message::new(MessageType::SYSTEM, wiki_resume.system_prompt.clone()),
+        Message::new(MessageType::USER, format!("The users query is: {}", user_query.content)),
+        Message::new(MessageType::USER, format!("The search results are: {}", best_article_content)),
+    ];
+    let query_response: String = wiki_resume.ask(&messages).await.unwrap();
+
+    Ok(query_response)
+}
+
+async fn search_articles(user_query: Message, search_llm: LLM, wiki_url: &String, zim_name: &String) -> Result<Vec<String>, Box<dyn std::error::Error>> {
+    // Use LLM to create 4 queries and fetch articles with those 4 queries
+    let messages = vec![
+        Message::new(MessageType::SYSTEM, search_llm.system_prompt.clone()),
+        user_query,
+    ];
+    let result = search_llm.ask_tools(&messages).await?;
+
+    let queries: Vec<String> = result[0]["function"]["arguments"]["queries"].as_array().unwrap().iter().map(|x| x.as_str().unwrap().to_string()).collect();
+
+    // Search articles on wikipedia API
+    let mut articles: Vec<String> = Vec::new();
+    for query in queries.iter() {
+        warn(query.clone());
+
+        // Request kiwix API for articles matching query
+        let encoded_query = utf8_percent_encode(&query, NON_ALPHANUMERIC).to_string();
+        let client = reqwest::Client::new();
+        let url = format!("{}/search?books.name={}&pattern={}", wiki_url, zim_name, encoded_query);
+        let body = client.get(url).send().await?.text().await?;
+
+        // Select every article corresponding to the query
+        let document = Document::from(body.as_str());
+
+        // Select articles title from the query
+        let results_div = document.find(Class("results")).next().unwrap();
+        for node in results_div.find(Name("a")) {
+            let article = node.text();
+            articles.push(article.clone());
+        }
+    }
+    Ok(articles)
+}
+
+async fn find_get_best_article(articles: Vec<String>, user_query: &String, best_llm: LLM, wiki_url: &String, zim_name: &String) -> Result<String, Box<dyn std::error::Error>> {
+    // Create a string with all the articles title
+    let mut articles_headings: String = String::new();
+    for article in articles {
+        articles_headings = format!("{}, {}", &articles_headings, article);
+    }
+
+    let messages = vec![
+        Message::new(MessageType::SYSTEM, best_llm.system_prompt.clone()),
+        Message::new(MessageType::USER, format!("The user's query is: {}. Here are the headings:\n{}\n\nPlease select the most relevant heading. Output the heading **only** and nothing else.", user_query, articles_headings))];
+    let best_article = best_llm.ask(&messages).await?;
+
+    // wiki query get article content & parse
+    let client = reqwest::Client::new();
+    let url: String = format!("{}/content/{}/A/{}", wiki_url, zim_name, best_article.replace("*","").replace(" ", "_"));
+    let body = client.get(url).send().await?.text().await?;
+    let content = extract_text_from_tags(&body);
+
+    Ok(content)
+}
+
+fn extract_text_from_tags(html: &str) -> String {
+    // Create a regular expression to find the content inside <p>, <h1>, <h2>, <h3> tags
+    let re = Regex::new(r#"<p[^>]*>(.*?)</p>|<h1[^>]*>(.*?)</h1>|<h2[^>]*>(.*?)</h2>|<h3[^>]*>(.*?)</h3>"#).unwrap();
+
+    // Use the regular expression to capture the content of the <p>, <h1>, <h2>, <h3> tags
+    let text = re.captures_iter(html)
+        .flat_map(|cap| {
+            // Find the first non-empty capture group (among cap[1] to cap[4])
+            (1..=4)
+                .filter_map(|i| cap.get(i))
+                .map(|m| m.as_str()) // &str
+                .flat_map(|s| s.split_whitespace())
+                .collect::<Vec<_>>() // Vec<&str>
+        })
+        .collect::<Vec<_>>() // collect words
+        .join(" "); // join with spaces
+    text
+}
diff --git a/src/helper/mod.rs b/src/helper.rs
index 43763f1..43763f1 100644
--- a/src/helper/mod.rs
+++ b/src/helper.rs
diff --git a/src/helper/init.rs b/src/helper/init.rs
index 2e0537d..d004b34 100644
--- a/src/helper/init.rs
+++ b/src/helper/init.rs
@@ -10,5 +10,5 @@ pub fn warn(content: String) {
         .open("log.txt")
         .unwrap();
     let utc: DateTime<Local> = Local::now();
-    writeln!(file, "[{}] {}", utc, content);
+    writeln!(file, "[{}] {}", utc, content).unwrap();
 }
diff --git a/src/lib.rs b/src/lib.rs
deleted file mode 100644
index 52c86c2..0000000
--- a/src/lib.rs
+++ /dev/null
@@ -1,3 +0,0 @@
-pub mod app;
-pub mod helper;
-pub mod ui;
diff --git a/src/ui/mod.rs b/src/ui.rs
index 424376c..424376c 100644
--- a/src/ui/mod.rs
+++ b/src/ui.rs
diff --git a/src/ui/init.rs b/src/ui/init.rs
index ea0882c..afd686a 100644
--- a/src/ui/init.rs
+++ b/src/ui/init.rs
@@ -95,6 +95,10 @@ impl Ui {
         ]);
         let [help_area, messages_area, input_area] = vertical.areas(frame.area());

+        let help_horizontal =
+            Layout::horizontal([Constraint::Percentage(75), Constraint::Percentage(25)]);
+        let [help_text_area, conv_id_area] = help_horizontal.areas(help_area);
+
         let (msg, style) = match self.input_field.input_mode {
             InputMode::Normal => (
                 vec![
@@ -103,8 +107,8 @@ impl Ui {
                     " to exit, ".into(),
                     "e".bold(),
                     " to start editing, ".into(),
-                    "r".bold(),
-                    " to resume the conversation.".into(),
+                    "s".bold(),
+                    " to save a resume of the conversation.".into(),
                 ],
                 Style::default(),
             ),
@@ -119,9 +123,13 @@ impl Ui {
                 Style::default(),
             ),
         };
-        let text = Text::from(Line::from(msg)).patch_style(style);
-        let help_message = Paragraph::new(text);
-        frame.render_widget(help_message, help_area);
+        let help_text = Text::from(Line::from(msg)).patch_style(style);
+        let help_message = Paragraph::new(help_text);
+        frame.render_widget(help_message, help_text_area);
+
+        let conv_id = self.app.conv_id.to_string().clone();
+        let conv_id_text = Paragraph::new(format!("Conv id: {conv_id}"));
+        frame.render_widget(conv_id_text, conv_id_area);

         // Rendering inputfield
         let input = Paragraph::new(self.input_field.input.as_str())
@@ -187,6 +195,7 @@ impl Ui {
                 max_char_per_line = size;
             }
         }
+
         let messages = Paragraph::new(Text::from(messages))
             .block(Block::bordered().title("Chat with Néo AI"))
             .wrap(Wrap { trim: false })