// oku_fs/fs/net/embeddings.rs

1use super::core::{home_replica_filters, EmbeddingModality};
2use crate::{database::posts::core::OkuNote, fs::OkuFs};
3use bytes::Bytes;
4use iroh_blobs::Hash;
5use iroh_docs::DocTicket;
6use log::error;
7use miette::IntoDiagnostic;
8use rayon::iter::IntoParallelIterator;
9use rayon::iter::IntoParallelRefIterator;
10use rayon::iter::ParallelIterator;
11use std::path::PathBuf;
12use url::Url;
13use zebra::{
14    database::default::{
15        audio::DefaultAudioModel, image::DefaultImageModel, text::DefaultTextModel,
16    },
17    model::core::{DatabaseEmbeddingModel, DIM_BGESMALL_EN_1_5, DIM_VIT_BASE_PATCH16_224},
18    Embedding,
19};
20
21impl OkuFs {
22    /// The embedding vector database for text media.
23    pub fn text_database(&self) -> zebra::database::default::text::DefaultTextDatabase {
24        zebra::database::default::text::DefaultTextDatabase::open_or_create(&"text.zebra".into())
25    }
26
27    /// The embedding vector database for image media.
28    pub fn image_database(&self) -> zebra::database::default::image::DefaultImageDatabase {
29        zebra::database::default::image::DefaultImageDatabase::open_or_create(&"image.zebra".into())
30    }
31
32    /// The embedding vector database for audio media.
33    pub fn audio_database(&self) -> zebra::database::default::audio::DefaultAudioDatabase {
34        zebra::database::default::audio::DefaultAudioDatabase::open_or_create(&"audio.zebra".into())
35    }
36
37    /// Determine the modality of some data.
38    ///
39    /// # Arguments
40    ///
41    /// * `bytes` - The given data.
42    ///
43    /// # Returns
44    ///
45    /// The modality of the data, if embeddable.
46    pub fn bytes_to_embedding_modality(&self, bytes: &Bytes) -> miette::Result<EmbeddingModality> {
47        let mime_type = tree_magic_mini::from_u8(bytes);
48        let type_ = mime_type.split("/").nth(0).unwrap_or_default();
49        match type_ {
50            "audio" => Ok(EmbeddingModality::Audio),
51            "image" => Ok(EmbeddingModality::Image),
52            "text" => Ok(EmbeddingModality::Text),
53            _ => Err(miette::miette!(
54                "Unexpected MIME type ({mime_type:?}); embedding modality cannot be determined … "
55            )),
56        }
57    }
58
59    /// Create an embedding file in the user's home replica for a document.
60    ///
61    /// # Arguments
62    ///
63    /// * `path` - An optional path to the embedding file; if none is specified, a suggested path will be used.
64    ///
65    /// * `url` - The URL of the document.
66    ///
67    /// * `bytes` - The document's contents.
68    ///
69    /// # Returns
70    ///
71    /// The hash of the file.
72    pub async fn create_post_embedding(
73        &self,
74        path: &Option<PathBuf>,
75        url: &Url,
76        bytes: &Bytes,
77    ) -> miette::Result<Hash> {
78        let home_replica_id = self
79            .home_replica()
80            .await
81            .ok_or(miette::miette!("No home replica set … "))?;
82        let embed_path = match path {
83            Some(given_path) => given_path,
84            None => &{
85                let mut path: PathBuf =
86                    OkuNote::suggested_post_path_from_url(&url.to_string()).into();
87                path.set_extension("okuembed");
88                path
89            },
90        };
91        let mut archive_path = embed_path.clone();
92        archive_path.set_extension("okuarchive");
93        if let Err(e) = self
94            .create_or_modify_file(&home_replica_id, &archive_path, bytes.clone())
95            .await
96        {
97            error!("{e}");
98        }
99        match self.bytes_to_embedding_modality(bytes)? {
100            EmbeddingModality::Audio => {
101                let model = DefaultAudioModel::default();
102                let embedding = model
103                    .embed(bytes.clone())
104                    .map_err(|e| miette::miette!("{e}"))?;
105                let embedding_json = serde_json::to_string(&embedding).into_diagnostic()?;
106                self.create_or_modify_file(&home_replica_id, embed_path, embedding_json)
107                    .await
108            }
109            EmbeddingModality::Image => {
110                let model = DefaultImageModel::default();
111                let embedding = model
112                    .embed(bytes.clone())
113                    .map_err(|e| miette::miette!("{e}"))?;
114                let embedding_json = serde_json::to_string(&embedding).into_diagnostic()?;
115                self.create_or_modify_file(&home_replica_id, embed_path, embedding_json)
116                    .await
117            }
118            EmbeddingModality::Text => {
119                let model = DefaultTextModel::default();
120                let embedding = model
121                    .embed(bytes.clone())
122                    .map_err(|e| miette::miette!("{e}"))?;
123                let embedding_json = serde_json::to_string(&embedding).into_diagnostic()?;
124                self.create_or_modify_file(&home_replica_id, embed_path, embedding_json)
125                    .await
126            }
127        }
128    }
129
130    /// Find the URLs of the most similar documents.
131    ///
132    /// # Arguments
133    ///
134    /// * `bytes` - A document.
135    ///
136    /// * `number_of_results` - The maximum number of URLs to return.
137    ///
138    /// # Returns
139    ///
140    /// The URLs of the documents most similar to the given one, approximately.
141    pub fn nearest_urls(
142        &self,
143        bytes: &Bytes,
144        number_of_results: usize,
145    ) -> miette::Result<Vec<Url>> {
146        match self.bytes_to_embedding_modality(bytes)? {
147            EmbeddingModality::Audio => {
148                let db = self.audio_database();
149                let results = db
150                    .query_documents(&[bytes.clone()], number_of_results)
151                    .map_err(|e| miette::miette!("{e}"))?;
152                let result_strings: Vec<String> = results
153                    .into_read_only()
154                    .into_par_iter()
155                    .flat_map(|(_x, y)| {
156                        y.into_read_only()
157                            .into_par_iter()
158                            .map(|(_a, b)| String::from_utf8_lossy(&b).to_string())
159                            .collect::<Vec<_>>()
160                    })
161                    .collect();
162                Ok(result_strings
163                    .par_iter()
164                    .filter_map(|x| Url::parse(x).ok())
165                    .collect())
166            }
167            EmbeddingModality::Image => {
168                let db = self.image_database();
169                let results = db
170                    .query_documents(&[bytes.clone()], number_of_results)
171                    .map_err(|e| miette::miette!("{e}"))?;
172                let result_strings: Vec<String> = results
173                    .into_read_only()
174                    .into_par_iter()
175                    .flat_map(|(_x, y)| {
176                        y.into_read_only()
177                            .into_par_iter()
178                            .map(|(_a, b)| String::from_utf8_lossy(&b).to_string())
179                            .collect::<Vec<_>>()
180                    })
181                    .collect();
182                Ok(result_strings
183                    .par_iter()
184                    .filter_map(|x| Url::parse(x).ok())
185                    .collect())
186            }
187            EmbeddingModality::Text => {
188                let db = self.text_database();
189                let results = db
190                    .query_documents(&[bytes.clone()], number_of_results)
191                    .map_err(|e| miette::miette!("{e}"))?;
192                let result_strings: Vec<String> = results
193                    .into_read_only()
194                    .into_par_iter()
195                    .flat_map(|(_x, y)| {
196                        y.into_read_only()
197                            .into_par_iter()
198                            .map(|(_a, b)| String::from_utf8_lossy(&b).to_string())
199                            .collect::<Vec<_>>()
200                    })
201                    .collect();
202                Ok(result_strings
203                    .par_iter()
204                    .filter_map(|x| Url::parse(x).ok())
205                    .collect())
206            }
207        }
208    }
209
210    /// Fetch an embedding file associated with a post.
211    ///
212    /// # Arguments
213    ///
214    /// * `ticket` - A ticket for the replica containing the file to retrieve.
215    ///
216    /// * `path` - The path to the file to retrieve.
217    ///
218    /// * `uri` - The URI associated with the OkuNet post.
219    pub async fn fetch_post_embeddings(
220        &self,
221        ticket: &DocTicket,
222        path: &PathBuf,
223        uri: &str,
224    ) -> miette::Result<()> {
225        let mut archive_path = path.clone();
226        archive_path.set_extension("okuarchive");
227        if let Ok(embedding_bytes) = self
228            .fetch_file_with_ticket(ticket, path, &Some(home_replica_filters()))
229            .await
230        {
231            if let Ok(bytes) = self
232                .fetch_file_with_ticket(ticket, &archive_path, &Some(home_replica_filters()))
233                .await
234            {
235                match self.bytes_to_embedding_modality(&bytes)? {
236                    EmbeddingModality::Audio => {
237                        let embedding =
238                            serde_json::from_str::<Embedding<DIM_VIT_BASE_PATCH16_224>>(
239                                String::from_utf8_lossy(&embedding_bytes).as_ref(),
240                            )
241                            .into_diagnostic()?;
242                        let db = self.audio_database();
243                        db.insert_records(&vec![embedding], &vec![uri.to_owned().into()])
244                            .map_err(|e| miette::miette!("{e}"))?;
245                        db.deduplicate().map_err(|e| miette::miette!("{e}"))?;
246                    }
247                    EmbeddingModality::Image => {
248                        let embedding =
249                            serde_json::from_str::<Embedding<DIM_VIT_BASE_PATCH16_224>>(
250                                String::from_utf8_lossy(&embedding_bytes).as_ref(),
251                            )
252                            .into_diagnostic()?;
253                        let db = self.image_database();
254                        db.insert_records(&vec![embedding], &vec![uri.to_owned().into()])
255                            .map_err(|e| miette::miette!("{e}"))?;
256                        db.deduplicate().map_err(|e| miette::miette!("{e}"))?;
257                    }
258                    EmbeddingModality::Text => {
259                        let embedding = serde_json::from_str::<Embedding<DIM_BGESMALL_EN_1_5>>(
260                            String::from_utf8_lossy(&embedding_bytes).as_ref(),
261                        )
262                        .into_diagnostic()?;
263                        let db = self.text_database();
264                        db.insert_records(&vec![embedding], &vec![uri.to_owned().into()])
265                            .map_err(|e| miette::miette!("{e}"))?;
266                        db.deduplicate().map_err(|e| miette::miette!("{e}"))?;
267                    }
268                }
269            }
270        }
271        Ok(())
272    }
273}