aboutsummaryrefslogtreecommitdiff
path: root/src/app/llm.rs
blob: 979b7908f17d0d770f7fd6b31d482523cd241b09 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
use crate::helper::init::warn;
use reqwest::{header::CONTENT_TYPE, Client};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::fmt;
use std::fs::{self, OpenOptions, create_dir_all};
use std::io::Write;

/// Connection settings for an Ollama-style chat endpoint, deserialized from a JSON
/// configuration file by [`LLM::new`].
#[derive(Deserialize, Debug)]
pub struct LLM {
    /// Endpoint that every chat request is POSTed to.
    url: String,
    /// Model name sent in the `"model"` field of every request.
    model: String,
    /// System prompt from the configuration file; not read in this file — presumably used by
    /// callers to seed a conversation (verify against callers).
    pub system_prompt: String,
    /// Tool definitions forwarded verbatim in the `"tools"` field by [`LLM::ask_tools`].
    pub tools: serde_json::Value,
}

impl LLM {
    pub fn new(config_file: &str) -> LLM {
        let contents = fs::read_to_string(config_file).unwrap();
        serde_json::from_str(&contents).unwrap()
    }

    pub async fn ask(&self, messages: &Vec<Message>) -> Result<String, Box<dyn std::error::Error>> {
        let client = Client::new();
        let response = client
            .post(&self.url)
            .header(CONTENT_TYPE, "application/json")
            .json(&serde_json::json!({
                "model": self.model,
                "messages": messages,
                "stream": true}))
            .send()
            .await?;

        let mut full_message = String::new();

        // Reading the stream and saving the response
        match response.error_for_status() {
            Ok(mut res) => {
                while let Some(chunk) = res.chunk().await? {
                    let answer: Value = serde_json::from_slice(chunk.as_ref())?;

                    //warn(answer.to_string());
                    if answer["done"].as_bool().unwrap_or(false) {
                        break;
                    }

                    let msg = answer["message"]["content"].as_str().unwrap_or("\n");

                    full_message.push_str(msg);
                }
            }
            Err(e) => return Err(Box::new(e)),
        }
    
        warn(full_message.clone());
        Ok(full_message)
    }

    // Use tools functionnality of Ollama, only some models supports it:
    // https://ollama.com/search?c=tools
    pub async fn ask_tools(&self, messages: &Vec<Message>) -> Result<serde_json::Value, Box<dyn std::error::Error>> {
        let client = Client::new();
        let response = client
            .post(&self.url)
            .header(CONTENT_TYPE, "application/json")
            .json(&serde_json::json!({
                "model": self.model,
                "messages": messages,
                "stream": false,
                "tools": self.tools}))
            .send()
            .await?.json::<Value>().await?;

        //warn(response.to_string());

        if let Some(tool_calls) = response
            .get("message")
                .and_then(|msg| msg.get("tool_calls"))
                .cloned()
        {
            Ok(tool_calls)
        } else {
            Err("tool_calls not found".into())
        }
    }
}

#[derive(Debug, Serialize, Clone)]
pub enum MessageType {
    ASSISTANT,
    SYSTEM,
    USER,
}

impl fmt::Display for MessageType {
    /// Writes the lowercase role name: `assistant`, `system` or `user`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match *self {
            MessageType::USER => "user",
            MessageType::SYSTEM => "system",
            MessageType::ASSISTANT => "assistant",
        };
        f.write_str(name)
    }
}

/// A single chat message: who wrote it and what it says.
///
/// Serialized to JSON both in the `"messages"` array of requests built by `LLM` and as one
/// line per message in the conversation files written by [`Message::save_message`].
#[derive(Debug, Serialize, Clone)]
pub struct Message {
    /// Author of the message.
    pub role: MessageType,
    /// Message text.
    pub content: String,
}

impl Message {
    /// Creates a message with the given `role` and `content`.
    pub fn new(role: MessageType, content: String) -> Message {
        Message { role, content }
    }

    /// Returns an empty message authored by the user.
    pub fn default() -> Message {
        Message {
            role: MessageType::USER,
            content: String::new(),
        }
    }

    /// Appends this message as one JSON line to the conversation file `conv/<conv_id>`,
    /// creating the `conv` directory and the file if they do not exist yet.
    ///
    /// # Errors
    ///
    /// Returns an error if the directory or file cannot be created or opened, the message
    /// cannot be serialized, or the write fails. (Previously a failure to open the file
    /// panicked instead of being returned.)
    pub fn save_message(&self, conv_id: String) -> Result<(), Box<dyn std::error::Error>> {
        create_dir_all("conv")?;

        // Serialize first so a serialization error never touches the file, and emit the
        // record together with its newline in a single write.
        let mut line = serde_json::to_string(self)?;
        line.push('\n');

        // NOTE(review): `conv_id` is used verbatim as a path component; if it can come from
        // untrusted input, reject `/` and `..` so writes cannot escape `conv/`.
        // `append(true)` already implies write access.
        let mut file = OpenOptions::new()
            .append(true)
            .create(true)
            .open(format!("conv/{}", conv_id))?;

        file.write_all(line.as_bytes())?;
        Ok(())
    }
}

impl fmt::Display for Message {
    /// Renders the message as `<speaker>: <content>`, where the speaker is `You`, `System`
    /// or `Néo AI` depending on the role.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let speaker = match &self.role {
            MessageType::ASSISTANT => "Néo AI",
            MessageType::SYSTEM => "System",
            MessageType::USER => "You",
        };
        write!(f, "{}: {}", speaker, self.content)
    }
}
ArKa projects. All rights to me, and your next child right arm.