diff options
Diffstat (limited to 'src/app')
-rw-r--r-- | src/app/init.rs | 18 | ||||
-rw-r--r-- | src/app/llm.rs | 35 |
2 files changed, 48 insertions, 5 deletions
diff --git a/src/app/init.rs b/src/app/init.rs index 201e79e..f1f917a 100644 --- a/src/app/init.rs +++ b/src/app/init.rs @@ -30,10 +30,24 @@ impl App { self.messages.push(message); } + fn categorize_ask(&mut self) { + let runtime = Builder::new_current_thread().enable_all().build().unwrap(); + + let result = runtime.block_on(async { + // Ask the LLM to categorise the request between (chat, code, wikipedia) + self.chat_llm.ask_format(&self.messages).await + }); + + let categorie = result.unwrap()[0]["function"]["arguments"]["category"].clone(); + + self.ask(categorie.to_string().as_str()); + } + fn ask(&mut self, mode: &str) { let runtime = Builder::new_current_thread() .enable_all() .build().unwrap(); + let result = runtime.block_on(async { if mode == "resume" { self.resume_llm.ask(&self.messages).await @@ -43,14 +57,14 @@ impl App { }); match result { - Ok(msg) => self.append_message(msg, MessageType::ASSISTANT), + Ok(msg) => self.append_message(msg.to_string(), MessageType::ASSISTANT), Err(e) => self.append_message(e.to_string(), MessageType::ASSISTANT), } } pub fn send_message(&mut self, content: String) { self.append_message(content, MessageType::USER); - self.ask("chat"); + self.categorize_ask(); } pub fn resume_conv(&mut self) { diff --git a/src/app/llm.rs b/src/app/llm.rs index 9fc1b3a..a172855 100644 --- a/src/app/llm.rs +++ b/src/app/llm.rs @@ -11,6 +11,7 @@ pub struct LLM { url: String, model: String, pub system_prompt: String, + pub tools: serde_json::Value, } impl LLM { @@ -27,9 +28,9 @@ impl LLM { .post(&self.url) .header(CONTENT_TYPE, "application/json") .json(&serde_json::json!({ - "model": self.model, - "messages": messages, - "stream": true})) + "model": self.model, + "messages": messages, + "stream": true})) .send() .await?; @@ -57,6 +58,34 @@ impl LLM { warn(full_message.clone()); Ok(full_message) } + + // Use tools functionnality of Ollama, only some models supports it: + // https://ollama.com/search?c=tools + pub async fn ask_format(&self, messages: &Vec<Message>) -> Result<serde_json::Value, Box<dyn std::error::Error>> { + let client = Client::new(); + let response = client + .post(&self.url) + .header(CONTENT_TYPE, "application/json") + .json(&serde_json::json!({ + "model": self.model, + "messages": messages, + "stream": false, + "tools": self.tools})) + .send() + .await?.json::<Value>().await?; + + warn(response.to_string()); + + if let Some(tool_calls) = response + .get("message") + .and_then(|msg| msg.get("tool_calls")) + .cloned() + { + Ok(tool_calls) + } else { + Err("tool_calls not found".into()) + } + } } #[derive(Debug, Serialize, Clone)] |