aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/app.rs (renamed from src/app/mod.rs)1
-rw-r--r--src/app/init.rs26
-rw-r--r--src/app/llm.rs19
-rw-r--r--src/app/modules.rs1
-rw-r--r--src/app/modules/wikipedia.rs108
-rw-r--r--src/helper.rs (renamed from src/helper/mod.rs)0
-rw-r--r--src/helper/init.rs2
-rw-r--r--src/lib.rs3
-rw-r--r--src/ui.rs (renamed from src/ui/mod.rs)0
-rw-r--r--src/ui/init.rs19
10 files changed, 154 insertions, 25 deletions
diff --git a/src/app/mod.rs b/src/app.rs
index 3cff678..8376f8c 100644
--- a/src/app/mod.rs
+++ b/src/app.rs
@@ -1,2 +1,3 @@
pub mod init;
pub mod llm;
+pub mod modules;
diff --git a/src/app/init.rs b/src/app/init.rs
index dd8d3b7..0930319 100644
--- a/src/app/init.rs
+++ b/src/app/init.rs
@@ -1,27 +1,30 @@
use crate::app::llm::{Message, MessageType, LLM};
+use crate::app::modules::wikipedia::ask_wiki;
use crate::helper::init::warn;
use uuid::Uuid;
use tokio::runtime::Builder;
pub struct App {
pub messages: Vec<Message>, // History of recorded messages
- conv_id: Uuid,
- chat_llm: LLM,
- resume_llm: LLM,
+ pub conv_id: Uuid, // ID for retrieving and saving the history of messages
+ categorize_llm: LLM, // Configuration for the LLM that categorizes each request
+ chat_llm: LLM, // Configuration for the LLM that chats with you
+ resume_llm: LLM, // Configuration for the LLM used when the mode is "resume"
}
impl App {
pub fn new() -> App { // Builds the app: loads the three LLM configs, seeds the history with one SYSTEM message and generates a fresh v4 conversation id.
- let chat_llm: LLM = LLM::new("config/chat-LLM.json".to_string());
+ let categorize_llm = LLM::new("config/categorize-LLM.json"); // LLM::new unwraps, so a missing or undeserializable config file panics here
App {
messages: vec![Message::new(
MessageType::SYSTEM,
- chat_llm.system_prompt.clone(),
+ categorize_llm.system_prompt.clone(), // NOTE(review): the history now starts with the categorizer's system prompt, yet the same history is later handed to chat_llm / resume_llm - confirm this is intended
)],
conv_id: Uuid::new_v4(),
- chat_llm,
- resume_llm: LLM::new("config/resume-LLM.json".to_string()),
+ categorize_llm,
+ chat_llm: LLM::new("config/chat-LLM.json"),
+ resume_llm: LLM::new("config/resume-LLM.json"),
}
}
@@ -29,7 +32,7 @@ impl App {
let message = Message::new(role, msg);
let err = message.save_message(self.conv_id.to_string());
- warn(err.is_err().to_string());
+ //warn(err.is_err().to_string());
self.messages.push(message);
}
@@ -39,19 +42,20 @@ impl App {
let result = runtime.block_on(async {
// Ask the LLM to categorise the request between (chat, code, wikipedia)
- self.chat_llm.ask_format(&self.messages).await
+ self.categorize_llm.ask_tools(&self.messages).await
});
match result {
Ok(msg) => {
let categorie = msg[0]["function"]["arguments"]["category"].clone();
- self.ask(categorie.to_string().as_str());
+ self.ask(&categorie.to_string().replace("\"", ""));
},
Err(e) => self.append_message(e.to_string(), MessageType::ASSISTANT),
}
}
fn ask(&mut self, mode: &str) {
+ warn(format!("Categorie: {}", mode));
let runtime = Builder::new_current_thread()
.enable_all()
.build().unwrap();
@@ -59,6 +63,8 @@ impl App {
let result = runtime.block_on(async {
if mode == "resume" {
self.resume_llm.ask(&self.messages).await
+ } else if mode == "wikipedia" {
+ ask_wiki(&self.messages).await
} else {
self.chat_llm.ask(&self.messages).await
}
diff --git a/src/app/llm.rs b/src/app/llm.rs
index 9c6d222..979b790 100644
--- a/src/app/llm.rs
+++ b/src/app/llm.rs
@@ -15,7 +15,7 @@ pub struct LLM {
}
impl LLM {
- pub fn new(config_file: String) -> LLM {
+ pub fn new(config_file: &str) -> LLM { // Builds an LLM by deserializing the JSON file located at config_file.
let contents = fs::read_to_string(config_file).unwrap(); // panics if the file cannot be read
serde_json::from_str(&contents).unwrap() // panics if the contents do not deserialize into LLM
}
@@ -40,7 +40,7 @@ impl LLM {
while let Some(chunk) = res.chunk().await? {
let answer: Value = serde_json::from_slice(chunk.as_ref())?;
- warn(answer.to_string());
+ //warn(answer.to_string());
if answer["done"].as_bool().unwrap_or(false) {
break;
}
@@ -59,7 +59,7 @@ impl LLM {
// Use the tools functionality of Ollama; only some models support it:
// https://ollama.com/search?c=tools
- pub async fn ask_format(&self, messages: &Vec<Message>) -> Result<serde_json::Value, Box<dyn std::error::Error>> {
+ pub async fn ask_tools(&self, messages: &Vec<Message>) -> Result<serde_json::Value, Box<dyn std::error::Error>> {
let client = Client::new();
let response = client
.post(&self.url)
@@ -72,7 +72,7 @@ impl LLM {
.send()
.await?.json::<Value>().await?;
- warn(response.to_string());
+ //warn(response.to_string());
if let Some(tool_calls) = response
.get("message")
@@ -105,8 +105,8 @@ impl fmt::Display for MessageType {
#[derive(Debug, Serialize, Clone)]
pub struct Message {
- role: MessageType,
- content: String,
+ pub role: MessageType,
+ pub content: String,
}
impl Message {
@@ -114,6 +114,13 @@ impl Message {
Message { role, content }
}
+ pub fn default() -> Message { // Returns a USER message with empty content. NOTE(review): an inherent fn named default sidesteps the std Default trait - consider `impl Default for Message` instead
+ Message {
+ role: MessageType::USER,
+ content: "".to_string(),
+ }
+ }
+
pub fn save_message(&self, conv_id: String) -> Result<(), Box<dyn std::error::Error>> {
// Create conv directory if doesn't exist
create_dir_all("conv")?;
diff --git a/src/app/modules.rs b/src/app/modules.rs
new file mode 100644
index 0000000..622d63c
--- /dev/null
+++ b/src/app/modules.rs
@@ -0,0 +1 @@
+pub mod wikipedia;
diff --git a/src/app/modules/wikipedia.rs b/src/app/modules/wikipedia.rs
new file mode 100644
index 0000000..5864df4
--- /dev/null
+++ b/src/app/modules/wikipedia.rs
@@ -0,0 +1,108 @@
+use crate::app::llm::{Message, MessageType, LLM};
+use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC};
+use crate::helper::init::warn;
+use std::fs;
+use select::document::Document;
+use select::predicate::{Name, Class};
+use regex::Regex;
+
+pub async fn ask_wiki(messages: &Vec<Message>) -> Result<String, Box<dyn std::error::Error>> { // Answers the last message of `messages` from wiki content: search for titles, fetch the best article, then let wiki_resume write the reply.
+ let wiki_search = LLM::new("config/wiki/wiki-search.json");
+ let wiki_best = LLM::new("config/wiki/wiki-best.json");
+ let wiki_resume = LLM::new("config/wiki/wiki-resume.json");
+ // Read the wiki endpoint settings (wiki_url, zim_name); panics if the file is unreadable, is not valid JSON, or lacks either key
+ let settings: serde_json::Value = serde_json::from_str(&fs::read_to_string("config/wiki/wiki.json").unwrap()).unwrap();
+ let wiki_url: String = settings.get("wiki_url").unwrap().to_string().replace("\"", ""); // Value::to_string() gives the JSON-quoted form, hence the quote stripping
+ let zim_name: String = settings.get("zim_name").unwrap().to_string().replace("\"", "");
+
+ // Search for article titles matching the user query (the last message; panics if `messages` is empty)
+ let user_query: Message = messages.last().unwrap().clone();
+ let articles: Vec<String> = search_articles(user_query.clone(), wiki_search, &wiki_url, &zim_name).await?;
+
+ // Pick the best article to answer the user query and fetch its text
+ let best_article_content = find_get_best_article(articles, &user_query.content, wiki_best, &wiki_url, &zim_name).await?;
+
+ // Ask wiki_resume to produce the response from the user query and the article text
+ let messages = vec![
+ Message::new(MessageType::SYSTEM, wiki_resume.system_prompt.clone()),
+ Message::new(MessageType::USER, format!("The users query is: {}", user_query.content)),
+ Message::new(MessageType::USER, format!("The search results are: {}", best_article_content)),
+ ];
+ let query_response: String = wiki_resume.ask(&messages).await.unwrap(); // NOTE(review): unwrap panics on failure although this fn returns Result - `?` would propagate the error like the calls above
+
+ Ok(query_response)
+}
+
+async fn search_articles(user_query: Message, search_llm: LLM, wiki_url: &String, zim_name: &String) -> Result<Vec<String>, Box<dyn std::error::Error>> { // Returns the link texts found on the search-result page of every LLM-generated query.
+ // Ask the LLM (through its tools call) for search queries; the original note says 4 queries - presumably fixed by the prompt/config, it is not enforced here
+ let messages = vec![
+ Message::new(MessageType::SYSTEM, search_llm.system_prompt.clone()),
+ user_query,
+ ];
+ let result = search_llm.ask_tools(&messages).await?;
+ // Expects result[0].function.arguments.queries to be an array of strings; panics otherwise
+ let queries: Vec<String> = result[0]["function"]["arguments"]["queries"].as_array().unwrap().iter().map(|x| x.as_str().unwrap().to_string()).collect();
+
+ // Run every query against the search endpoint and collect the resulting titles
+ let mut articles: Vec<String> = Vec::new();
+ for query in queries.iter() {
+ warn(query.clone()); // log the query
+
+ // Request the kiwix search endpoint for articles matching the percent-encoded query
+ let encoded_query = utf8_percent_encode(&query, NON_ALPHANUMERIC).to_string();
+ let client = reqwest::Client::new(); // NOTE(review): a new Client is built on every iteration; it could be created once before the loop
+ let url = format!("{}/search?books.name={}&pattern={}", wiki_url, zim_name, encoded_query);
+ let body = client.get(url).send().await?.text().await?;
+
+ // Parse the returned HTML
+ let document = Document::from(body.as_str());
+
+ // Collect the text of every <a> inside the first element with class "results"; panics if the page has no such element
+ let results_div = document.find(Class("results")).next().unwrap();
+ for node in results_div.find(Name("a")) {
+ let article = node.text();
+ articles.push(article.clone());
+ }
+ }
+ Ok(articles)
+}
+
+async fn find_get_best_article(articles: Vec<String>, user_query: &String, best_llm: LLM, wiki_url: &String, zim_name: &String) -> Result<String, Box<dyn std::error::Error>> { // Asks best_llm to choose one title from `articles`, fetches that article and returns its extracted text.
+ // Build a single string holding every article title
+ let mut articles_headings: String = String::new();
+ for article in articles {
+ articles_headings = format!("{}, {}", &articles_headings, article); // NOTE(review): this yields a leading ", " and re-allocates on every iteration - articles.join(", ") would avoid both
+ }
+ // Ask the LLM to pick the most relevant heading
+ let messages = vec![
+ Message::new(MessageType::SYSTEM, best_llm.system_prompt.clone()),
+ Message::new(MessageType::USER, format!("The user's query is: {}. Here are the headings:\n{}\n\nPlease select the most relevant heading. Output the heading **only** and nothing else.", user_query, articles_headings))];
+ let best_article = best_llm.ask(&messages).await?;
+
+ // Fetch the chosen article page and extract its text
+ let client = reqwest::Client::new();
+ let url: String = format!("{}/content/{}/A/{}", wiki_url, zim_name, best_article.replace("*","").replace(" ", "_")); // strips '*' and turns spaces into '_' in the LLM answer; NOTE(review): the title is neither trimmed nor percent-encoded - confirm that is safe
+ let body = client.get(url).send().await?.text().await?;
+ let content = extract_text_from_tags(&body);
+
+ Ok(content)
+}
+
+fn extract_text_from_tags(html: &str) -> String { // Returns the inner content of every <p>, <h1>, <h2> and <h3> element of `html` as one string of whitespace-normalized words (nested markup is kept as-is).
+ // Build a regular expression that finds the content inside <p>, <h1>, <h2> and <h3> tags
+ let re = Regex::new(r#"<p[^>]*>(.*?)</p>|<h1[^>]*>(.*?)</h1>|<h2[^>]*>(.*?)</h2>|<h3[^>]*>(.*?)</h3>"#).unwrap();
+ // NOTE(review): "." does not match a newline by default, so elements spanning several lines are presumably skipped - confirm
+ // Use the regular expression to capture the content of the <p>, <h1>, <h2>, <h3> tags
+ let text = re.captures_iter(html)
+ .flat_map(|cap| {
+ // Keep whichever capture group (among groups 1 to 4) took part in this match
+ (1..=4)
+ .filter_map(|i| cap.get(i))
+ .map(|m| m.as_str()) // &str
+ .flat_map(|s| s.split_whitespace())
+ .collect::<Vec<_>>() // Vec<&str>
+ })
+ .collect::<Vec<_>>() // collect words
+ .join(" "); // join with spaces
+ text
+}
diff --git a/src/helper/mod.rs b/src/helper.rs
index 43763f1..43763f1 100644
--- a/src/helper/mod.rs
+++ b/src/helper.rs
diff --git a/src/helper/init.rs b/src/helper/init.rs
index 2e0537d..d004b34 100644
--- a/src/helper/init.rs
+++ b/src/helper/init.rs
@@ -10,5 +10,5 @@ pub fn warn(content: String) {
.open("log.txt")
.unwrap();
let utc: DateTime<Local> = Local::now();
- writeln!(file, "[{}] {}", utc, content);
+ writeln!(file, "[{}] {}", utc, content).unwrap();
}
diff --git a/src/lib.rs b/src/lib.rs
deleted file mode 100644
index 52c86c2..0000000
--- a/src/lib.rs
+++ /dev/null
@@ -1,3 +0,0 @@
-pub mod app;
-pub mod helper;
-pub mod ui;
diff --git a/src/ui/mod.rs b/src/ui.rs
index 424376c..424376c 100644
--- a/src/ui/mod.rs
+++ b/src/ui.rs
diff --git a/src/ui/init.rs b/src/ui/init.rs
index ea0882c..afd686a 100644
--- a/src/ui/init.rs
+++ b/src/ui/init.rs
@@ -95,6 +95,10 @@ impl Ui {
]);
let [help_area, messages_area, input_area] = vertical.areas(frame.area());
+ let help_horizontal =
+ Layout::horizontal([Constraint::Percentage(75), Constraint::Percentage(25)]);
+ let [help_text_area, conv_id_area] = help_horizontal.areas(help_area);
+
let (msg, style) = match self.input_field.input_mode {
InputMode::Normal => (
vec![
@@ -103,8 +107,8 @@ impl Ui {
" to exit, ".into(),
"e".bold(),
" to start editing, ".into(),
- "r".bold(),
- " to resume the conversation.".into(),
+ "s".bold(),
+ " to save a resume of the conversation.".into(),
],
Style::default(),
),
@@ -119,9 +123,13 @@ impl Ui {
Style::default(),
),
};
- let text = Text::from(Line::from(msg)).patch_style(style);
- let help_message = Paragraph::new(text);
- frame.render_widget(help_message, help_area);
+ let help_text = Text::from(Line::from(msg)).patch_style(style);
+ let help_message = Paragraph::new(help_text);
+ frame.render_widget(help_message, help_text_area);
+
+ let conv_id = self.app.conv_id.to_string().clone();
+ let conv_id_text = Paragraph::new(format!("Conv id: {conv_id}"));
+ frame.render_widget(conv_id_text, conv_id_area);
// Rendering inputfield
let input = Paragraph::new(self.input_field.input.as_str())
@@ -187,6 +195,7 @@ impl Ui {
max_char_per_line = size;
}
}
+
let messages = Paragraph::new(Text::from(messages))
.block(Block::bordered().title("Chat with Néo AI"))
.wrap(Wrap { trim: false })
ArKa projects. All rights to me, and your next child right arm.