// oku_fs/fs/net/embeddings.rs

1use super::core::{home_replica_filters, EmbeddingModality};
2use crate::{database::posts::core::OkuNote, fs::OkuFs};
3use bytes::Bytes;
4use iroh_blobs::Hash;
5use iroh_docs::AuthorId;
6use iroh_docs::DocTicket;
7use log::error;
8use miette::IntoDiagnostic;
9use rayon::iter::IntoParallelIterator;
10use rayon::iter::ParallelIterator;
11use std::path::PathBuf;
12use url::Url;
13use zebra::{
14    database::default::{
15        audio::DefaultAudioModel, image::DefaultImageModel, text::DefaultTextModel,
16    },
17    model::core::{DatabaseEmbeddingModel, DIM_BGESMALL_EN_1_5, DIM_VIT_BASE_PATCH16_224},
18    Embedding,
19};
20
21impl OkuFs {
22    /// The embedding vector database for text media.
23    pub fn text_database(
24        &self,
25    ) -> miette::Result<zebra::database::default::text::DefaultTextDatabase> {
26        zebra::database::default::text::DefaultTextDatabase::open_or_create(
27            &"text.zebra".into(),
28            &Default::default(),
29        )
30        .map_err(|e| miette::miette!("{e}"))
31    }
32
33    /// The embedding vector database for image media.
34    pub fn image_database(
35        &self,
36    ) -> miette::Result<zebra::database::default::image::DefaultImageDatabase> {
37        zebra::database::default::image::DefaultImageDatabase::open_or_create(
38            &"image.zebra".into(),
39            &Default::default(),
40        )
41        .map_err(|e| miette::miette!("{e}"))
42    }
43
44    /// The embedding vector database for audio media.
45    pub fn audio_database(
46        &self,
47    ) -> miette::Result<zebra::database::default::audio::DefaultAudioDatabase> {
48        zebra::database::default::audio::DefaultAudioDatabase::open_or_create(
49            &"audio.zebra".into(),
50            &Default::default(),
51        )
52        .map_err(|e| miette::miette!("{e}"))
53    }
54
55    /// Determine the modality of some data.
56    ///
57    /// # Arguments
58    ///
59    /// * `bytes` - The given data.
60    ///
61    /// # Returns
62    ///
63    /// The modality of the data, if embeddable.
64    pub fn bytes_to_embedding_modality(&self, bytes: &Bytes) -> miette::Result<EmbeddingModality> {
65        let mime_type = tree_magic_mini::from_u8(bytes);
66        let type_ = mime_type.split("/").nth(0).unwrap_or_default();
67        match type_ {
68            "audio" => Ok(EmbeddingModality::Audio),
69            "image" => Ok(EmbeddingModality::Image),
70            "text" => Ok(EmbeddingModality::Text),
71            _ => Err(miette::miette!(
72                "Unexpected MIME type ({mime_type:?}); embedding modality cannot be determined … "
73            )),
74        }
75    }
76
77    /// Create an embedding file in the user's home replica for a document.
78    ///
79    /// # Arguments
80    ///
81    /// * `path` - An optional path to the embedding file; if none is specified, a suggested path will be used.
82    ///
83    /// * `url` - The URL of the document.
84    ///
85    /// * `bytes` - The document's contents.
86    ///
87    /// # Returns
88    ///
89    /// The hash of the file.
90    pub async fn create_post_embedding(&self, url: &Url, bytes: &Bytes) -> miette::Result<Hash> {
91        let home_replica_id = self
92            .home_replica()
93            .await
94            .ok_or(miette::miette!("No home replica set … "))?;
95        let url_string = url.to_string();
96        let embed_path: &PathBuf = &OkuNote::embedding_path_from_url(&url_string).into();
97        let archive_path: &PathBuf = &OkuNote::archive_path_from_url(&url_string).into();
98        if let Err(e) = self
99            .create_or_modify_file(&home_replica_id, archive_path, bytes.clone())
100            .await
101        {
102            error!("{e}");
103        }
104        match self.bytes_to_embedding_modality(bytes)? {
105            EmbeddingModality::Audio => {
106                let model = DefaultAudioModel::default();
107                let embedding = model
108                    .embed(bytes.clone())
109                    .map_err(|e| miette::miette!("{e}"))?;
110                let embedding_json = serde_json::to_string(&embedding).into_diagnostic()?;
111                self.create_or_modify_file(&home_replica_id, embed_path, embedding_json)
112                    .await
113            }
114            EmbeddingModality::Image => {
115                let model = DefaultImageModel::default();
116                let embedding = model
117                    .embed(bytes.clone())
118                    .map_err(|e| miette::miette!("{e}"))?;
119                let embedding_json = serde_json::to_string(&embedding).into_diagnostic()?;
120                self.create_or_modify_file(&home_replica_id, embed_path, embedding_json)
121                    .await
122            }
123            EmbeddingModality::Text => {
124                let model = DefaultTextModel::default();
125                let embedding = model
126                    .embed(bytes.clone())
127                    .map_err(|e| miette::miette!("{e}"))?;
128                let embedding_json = serde_json::to_string(&embedding).into_diagnostic()?;
129                self.create_or_modify_file(&home_replica_id, embed_path, embedding_json)
130                    .await
131            }
132        }
133    }
134
135    /// Find the archival records of the most similar documents.
136    ///
137    /// # Arguments
138    ///
139    /// * `bytes` - A document.
140    ///
141    /// * `number_of_results` - The maximum number of archives to return.
142    ///
143    /// # Returns
144    ///
145    /// The URIs of the documents approximately most similar to the given one, paired with their archivist's authorship ID.
146    pub fn nearest_archives(
147        &self,
148        bytes: &Bytes,
149        number_of_results: usize,
150    ) -> miette::Result<Vec<(AuthorId, String)>> {
151        match self.bytes_to_embedding_modality(bytes)? {
152            EmbeddingModality::Audio => {
153                let db = self.audio_database()?;
154                let results = db
155                    .query_documents(&[bytes.clone()], number_of_results)
156                    .map_err(|e| miette::miette!("{e}"))?;
157                Ok(results
158                    .into_read_only()
159                    .into_par_iter()
160                    .flat_map(|(_x, y)| {
161                        y.into_read_only()
162                            .into_par_iter()
163                            .filter_map(|(_a, b)| serde_json::from_slice(&b).ok())
164                            .collect::<Vec<_>>()
165                    })
166                    .collect())
167            }
168            EmbeddingModality::Image => {
169                let db = self.image_database()?;
170                let results = db
171                    .query_documents(&[bytes.clone()], number_of_results)
172                    .map_err(|e| miette::miette!("{e}"))?;
173                Ok(results
174                    .into_read_only()
175                    .into_par_iter()
176                    .flat_map(|(_x, y)| {
177                        y.into_read_only()
178                            .into_par_iter()
179                            .filter_map(|(_a, b)| serde_json::from_slice(&b).ok())
180                            .collect::<Vec<_>>()
181                    })
182                    .collect())
183            }
184            EmbeddingModality::Text => {
185                let db = self.text_database()?;
186                let results = db
187                    .query_documents(&[bytes.clone()], number_of_results)
188                    .map_err(|e| miette::miette!("{e}"))?;
189                Ok(results
190                    .into_read_only()
191                    .into_par_iter()
192                    .flat_map(|(_x, y)| {
193                        y.into_read_only()
194                            .into_par_iter()
195                            .filter_map(|(_a, b)| serde_json::from_slice(&b).ok())
196                            .collect::<Vec<_>>()
197                    })
198                    .collect())
199            }
200        }
201    }
202
203    /// Fetch an embedding file associated with a post.
204    ///
205    /// # Arguments
206    ///
207    /// * `ticket` - A ticket for the replica containing the file to retrieve.
208    ///
209    /// * `path` - The path to the file to retrieve.
210    ///
211    /// * `uri` - The URI associated with the OkuNet post.
212    pub async fn fetch_post_embeddings(
213        &self,
214        ticket: &DocTicket,
215        author_id: &AuthorId,
216        uri: &str,
217    ) -> miette::Result<()> {
218        let path: &PathBuf = &OkuNote::embedding_path_from_url(&uri.to_string()).into();
219        let archive_path: &PathBuf = &OkuNote::archive_path_from_url(&uri.to_string()).into();
220        if let Ok(embedding_bytes) = self
221            .fetch_file_with_ticket(ticket, path, &Some(home_replica_filters()))
222            .await
223        {
224            if let Ok(bytes) = self
225                .fetch_file_with_ticket(ticket, archive_path, &Some(home_replica_filters()))
226                .await
227            {
228                match self.bytes_to_embedding_modality(&bytes)? {
229                    EmbeddingModality::Audio => {
230                        let embedding =
231                            serde_json::from_str::<Embedding<DIM_VIT_BASE_PATCH16_224>>(
232                                String::from_utf8_lossy(&embedding_bytes).as_ref(),
233                            )
234                            .into_diagnostic()?;
235                        let db = self.audio_database()?;
236                        db.insert_records(
237                            &vec![embedding],
238                            &vec![serde_json::to_string(&(author_id, uri))
239                                .into_diagnostic()?
240                                .into()],
241                        )
242                        .map_err(|e| miette::miette!("{e}"))?;
243                        db.deduplicate().map_err(|e| miette::miette!("{e}"))?;
244                    }
245                    EmbeddingModality::Image => {
246                        let embedding =
247                            serde_json::from_str::<Embedding<DIM_VIT_BASE_PATCH16_224>>(
248                                String::from_utf8_lossy(&embedding_bytes).as_ref(),
249                            )
250                            .into_diagnostic()?;
251                        let db = self.image_database()?;
252                        db.insert_records(
253                            &vec![embedding],
254                            &vec![serde_json::to_string(&(author_id, uri))
255                                .into_diagnostic()?
256                                .into()],
257                        )
258                        .map_err(|e| miette::miette!("{e}"))?;
259                        db.deduplicate().map_err(|e| miette::miette!("{e}"))?;
260                    }
261                    EmbeddingModality::Text => {
262                        let embedding = serde_json::from_str::<Embedding<DIM_BGESMALL_EN_1_5>>(
263                            String::from_utf8_lossy(&embedding_bytes).as_ref(),
264                        )
265                        .into_diagnostic()?;
266                        let db = self.text_database()?;
267                        db.insert_records(
268                            &vec![embedding],
269                            &vec![serde_json::to_string(&(author_id, uri))
270                                .into_diagnostic()?
271                                .into()],
272                        )
273                        .map_err(|e| miette::miette!("{e}"))?;
274                        db.deduplicate().map_err(|e| miette::miette!("{e}"))?;
275                    }
276                }
277            }
278        }
279        Ok(())
280    }
281
282    /// Fetch an archived copy of a document.
283    ///
284    /// # Arguments
285    ///
286    /// * `author_id` - The authorship ID of the OkuNet user who archived the document.
287    ///
288    /// * `uri` - The URI of the document.
289    ///
290    /// # Returns
291    ///
292    /// The archived copy of the document.
293    pub async fn fetch_archive(&self, author_id: &AuthorId, uri: &str) -> anyhow::Result<Bytes> {
294        let path: &PathBuf = &OkuNote::archive_path_from_url(&uri.to_string()).into();
295        let ticket = self.resolve_author_id(author_id).await?;
296        self.fetch_file_with_ticket(&ticket, path, &Some(home_replica_filters()))
297            .await
298    }
299}