author     Oxbian <oxbian@mailbox.org>  2025-03-04 21:15:42 -0500
committer  Oxbian <oxbian@mailbox.org>  2025-03-04 21:15:42 -0500
commit     43f26405e818aec791b28c50373843851fe1320e (patch)
tree       81fe0cb2180afaebedc0edf65bd1c077ab267893 /src
parent     b9061a3e652cb7594397c38cd0078a47ddab960a (diff)
download   NAI-43f26405e818aec791b28c50373843851fe1320e.tar.gz
           NAI-43f26405e818aec791b28c50373843851fe1320e.zip
feat: routing request
Diffstat (limited to 'src')
-rw-r--r--  src/app/init.rs  18
-rw-r--r--  src/app/llm.rs   35
2 files changed, 48 insertions(+), 5 deletions(-)
diff --git a/src/app/init.rs b/src/app/init.rs
index 201e79e..f1f917a 100644
--- a/src/app/init.rs
+++ b/src/app/init.rs
@@ -30,10 +30,24 @@ impl App {
self.messages.push(message);
}
+ fn categorize_ask(&mut self) {
+ let runtime = Builder::new_current_thread().enable_all().build().unwrap();
+
+ let result = runtime.block_on(async {
+ // Ask the LLM to categorize the request as one of: chat, code, wikipedia
+ self.chat_llm.ask_format(&self.messages).await
+ });
+
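+ // ask_format returns message.tool_calls: an array of {function: {name, arguments}}
+ // objects; the chosen category lives in the first call's arguments.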
+ let category = result.unwrap()[0]["function"]["arguments"]["category"].clone();
+
+ // as_str() strips the JSON quotes that Value::to_string() would keep
+ self.ask(category.as_str().unwrap_or("chat"));
+ }
+
fn ask(&mut self, mode: &str) {
let runtime = Builder::new_current_thread()
.enable_all()
.build().unwrap();
+
let result = runtime.block_on(async {
if mode == "resume" {
self.resume_llm.ask(&self.messages).await
@@ -43,14 +57,14 @@ impl App {
});
match result {
- Ok(msg) => self.append_message(msg, MessageType::ASSISTANT),
+ Ok(msg) => self.append_message(msg.to_string(), MessageType::ASSISTANT),
Err(e) => self.append_message(e.to_string(), MessageType::ASSISTANT),
}
}
pub fn send_message(&mut self, content: String) {
self.append_message(content, MessageType::USER);
- self.ask("chat");
+ self.categorize_ask();
}
pub fn resume_conv(&mut self) {
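
For reference, the indexing in categorize_ask assumes Ollama's tool-call
response shape, where message.tool_calls is an array of function-call
objects. A minimal sketch of the Value that ask_format hands back (the tool
name and category value are illustrative assumptions, not part of this
commit):

    use serde_json::json;

    fn main() {
        // The path categorize_ask walks: [0]["function"]["arguments"]["category"]
        let tool_calls = json!([{
            "function": {
                "name": "categorize", // hypothetical tool name
                "arguments": { "category": "code" }
            }
        }]);
        assert_eq!(tool_calls[0]["function"]["arguments"]["category"], "code");
    }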
diff --git a/src/app/llm.rs b/src/app/llm.rs
index 9fc1b3a..a172855 100644
--- a/src/app/llm.rs
+++ b/src/app/llm.rs
@@ -11,6 +11,7 @@ pub struct LLM {
url: String,
model: String,
pub system_prompt: String,
+ pub tools: serde_json::Value,
}
impl LLM {
@@ -27,9 +28,9 @@ impl LLM {
.post(&self.url)
.header(CONTENT_TYPE, "application/json")
.json(&serde_json::json!({
- "model": self.model,
- "messages": messages,
- "stream": true}))
+ "model": self.model,
+ "messages": messages,
+ "stream": true}))
.send()
.await?;
@@ -57,6 +58,34 @@ impl LLM {
warn(full_message.clone());
Ok(full_message)
}
+
+ // Use the tools functionality of Ollama; only some models support it:
+ // https://ollama.com/search?c=tools
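+ // Unlike ask(), this request disables streaming and attaches the tool schema,
+ // so the model can answer with a structured tool call instead of free text.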
+ pub async fn ask_format(&self, messages: &Vec<Message>) -> Result<serde_json::Value, Box<dyn std::error::Error>> {
+ let client = Client::new();
+ let response = client
+ .post(&self.url)
+ .header(CONTENT_TYPE, "application/json")
+ .json(&serde_json::json!({
+ "model": self.model,
+ "messages": messages,
+ "stream": false,
+ "tools": self.tools}))
+ .send()
+ .await?.json::<Value>().await?;
+
+ warn(response.to_string());
+
+ if let Some(tool_calls) = response
+ .get("message")
+ .and_then(|msg| msg.get("tool_calls"))
+ .cloned()
+ {
+ Ok(tool_calls)
+ } else {
+ Err("tool_calls not found".into())
+ }
+ }
}
#[derive(Debug, Serialize, Clone)]
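
The new tools field is serialized verbatim into the request body, so it is
expected to hold an Ollama tool definition. A minimal sketch of what it
could contain for the chat/code/wikipedia routing above (the tool name and
description are assumptions, not part of this commit):

    use serde_json::json;

    fn main() {
        let tools = json!([{
            "type": "function",
            "function": {
                "name": "categorize", // hypothetical
                "description": "Route the user request to a category",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "category": {
                            "type": "string",
                            "enum": ["chat", "code", "wikipedia"]
                        }
                    },
                    "required": ["category"]
                }
            }
        }]);
        println!("{}", tools);
    }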