// oku_fs/fs/net/embeddings.rs

use std::path::PathBuf;

use bytes::Bytes;
use iroh_blobs::Hash;
use iroh_docs::AuthorId;
use iroh_docs::DocTicket;
use log::error;
use miette::IntoDiagnostic;
use rayon::iter::IntoParallelIterator;
use rayon::iter::ParallelIterator;
use url::Url;
use zebra::{
    database::default::{
        audio::DefaultAudioModel, image::DefaultImageModel, text::DefaultTextModel,
    },
    model::core::{DatabaseEmbeddingModel, DIM_BGESMALL_EN_1_5, DIM_VIT_BASE_PATCH16_224},
    Embedding,
};

use super::core::{home_replica_filters, EmbeddingModality};
use crate::{database::posts::core::OkuNote, fs::OkuFs};
20
21impl OkuFs {
22 pub fn text_database(
24 &self,
25 ) -> miette::Result<zebra::database::default::text::DefaultTextDatabase> {
26 zebra::database::default::text::DefaultTextDatabase::open_or_create(
27 &"text.zebra".into(),
28 &Default::default(),
29 )
30 .map_err(|e| miette::miette!("{e}"))
31 }
32
33 pub fn image_database(
35 &self,
36 ) -> miette::Result<zebra::database::default::image::DefaultImageDatabase> {
37 zebra::database::default::image::DefaultImageDatabase::open_or_create(
38 &"image.zebra".into(),
39 &Default::default(),
40 )
41 .map_err(|e| miette::miette!("{e}"))
42 }
43
44 pub fn audio_database(
46 &self,
47 ) -> miette::Result<zebra::database::default::audio::DefaultAudioDatabase> {
48 zebra::database::default::audio::DefaultAudioDatabase::open_or_create(
49 &"audio.zebra".into(),
50 &Default::default(),
51 )
52 .map_err(|e| miette::miette!("{e}"))
53 }
54
55 pub fn bytes_to_embedding_modality(&self, bytes: &Bytes) -> miette::Result<EmbeddingModality> {
65 let mime_type = tree_magic_mini::from_u8(bytes);
66 let type_ = mime_type.split("/").nth(0).unwrap_or_default();
67 match type_ {
68 "audio" => Ok(EmbeddingModality::Audio),
69 "image" => Ok(EmbeddingModality::Image),
70 "text" => Ok(EmbeddingModality::Text),
71 _ => Err(miette::miette!(
72 "Unexpected MIME type ({mime_type:?}); embedding modality cannot be determined … "
73 )),
74 }
75 }
76
77 pub async fn create_post_embedding(&self, url: &Url, bytes: &Bytes) -> miette::Result<Hash> {
91 let home_replica_id = self
92 .home_replica()
93 .await
94 .ok_or(miette::miette!("No home replica set … "))?;
95 let url_string = url.to_string();
96 let embed_path: &PathBuf = &OkuNote::embedding_path_from_url(&url_string).into();
97 let archive_path: &PathBuf = &OkuNote::archive_path_from_url(&url_string).into();
98 if let Err(e) = self
99 .create_or_modify_file(&home_replica_id, archive_path, bytes.clone())
100 .await
101 {
102 error!("{e}");
103 }
104 match self.bytes_to_embedding_modality(bytes)? {
105 EmbeddingModality::Audio => {
106 let model = DefaultAudioModel::default();
107 let embedding = model
108 .embed(bytes.clone())
109 .map_err(|e| miette::miette!("{e}"))?;
110 let embedding_json = serde_json::to_string(&embedding).into_diagnostic()?;
111 self.create_or_modify_file(&home_replica_id, embed_path, embedding_json)
112 .await
113 }
114 EmbeddingModality::Image => {
115 let model = DefaultImageModel::default();
116 let embedding = model
117 .embed(bytes.clone())
118 .map_err(|e| miette::miette!("{e}"))?;
119 let embedding_json = serde_json::to_string(&embedding).into_diagnostic()?;
120 self.create_or_modify_file(&home_replica_id, embed_path, embedding_json)
121 .await
122 }
123 EmbeddingModality::Text => {
124 let model = DefaultTextModel::default();
125 let embedding = model
126 .embed(bytes.clone())
127 .map_err(|e| miette::miette!("{e}"))?;
128 let embedding_json = serde_json::to_string(&embedding).into_diagnostic()?;
129 self.create_or_modify_file(&home_replica_id, embed_path, embedding_json)
130 .await
131 }
132 }
133 }
134
135 pub fn nearest_archives(
147 &self,
148 bytes: &Bytes,
149 number_of_results: usize,
150 ) -> miette::Result<Vec<(AuthorId, String)>> {
151 match self.bytes_to_embedding_modality(bytes)? {
152 EmbeddingModality::Audio => {
153 let db = self.audio_database()?;
154 let results = db
155 .query_documents(&[bytes.clone()], number_of_results)
156 .map_err(|e| miette::miette!("{e}"))?;
157 Ok(results
158 .into_read_only()
159 .into_par_iter()
160 .flat_map(|(_x, y)| {
161 y.into_read_only()
162 .into_par_iter()
163 .filter_map(|(_a, b)| serde_json::from_slice(&b).ok())
164 .collect::<Vec<_>>()
165 })
166 .collect())
167 }
168 EmbeddingModality::Image => {
169 let db = self.image_database()?;
170 let results = db
171 .query_documents(&[bytes.clone()], number_of_results)
172 .map_err(|e| miette::miette!("{e}"))?;
173 Ok(results
174 .into_read_only()
175 .into_par_iter()
176 .flat_map(|(_x, y)| {
177 y.into_read_only()
178 .into_par_iter()
179 .filter_map(|(_a, b)| serde_json::from_slice(&b).ok())
180 .collect::<Vec<_>>()
181 })
182 .collect())
183 }
184 EmbeddingModality::Text => {
185 let db = self.text_database()?;
186 let results = db
187 .query_documents(&[bytes.clone()], number_of_results)
188 .map_err(|e| miette::miette!("{e}"))?;
189 Ok(results
190 .into_read_only()
191 .into_par_iter()
192 .flat_map(|(_x, y)| {
193 y.into_read_only()
194 .into_par_iter()
195 .filter_map(|(_a, b)| serde_json::from_slice(&b).ok())
196 .collect::<Vec<_>>()
197 })
198 .collect())
199 }
200 }
201 }
202
203 pub async fn fetch_post_embeddings(
213 &self,
214 ticket: &DocTicket,
215 author_id: &AuthorId,
216 uri: &str,
217 ) -> miette::Result<()> {
218 let path: &PathBuf = &OkuNote::embedding_path_from_url(&uri.to_string()).into();
219 let archive_path: &PathBuf = &OkuNote::archive_path_from_url(&uri.to_string()).into();
220 if let Ok(embedding_bytes) = self
221 .fetch_file_with_ticket(ticket, path, &Some(home_replica_filters()))
222 .await
223 {
224 if let Ok(bytes) = self
225 .fetch_file_with_ticket(ticket, archive_path, &Some(home_replica_filters()))
226 .await
227 {
228 match self.bytes_to_embedding_modality(&bytes)? {
229 EmbeddingModality::Audio => {
230 let embedding =
231 serde_json::from_str::<Embedding<DIM_VIT_BASE_PATCH16_224>>(
232 String::from_utf8_lossy(&embedding_bytes).as_ref(),
233 )
234 .into_diagnostic()?;
235 let db = self.audio_database()?;
236 db.insert_records(
237 &vec![embedding],
238 &vec![serde_json::to_string(&(author_id, uri))
239 .into_diagnostic()?
240 .into()],
241 )
242 .map_err(|e| miette::miette!("{e}"))?;
243 db.deduplicate().map_err(|e| miette::miette!("{e}"))?;
244 }
245 EmbeddingModality::Image => {
246 let embedding =
247 serde_json::from_str::<Embedding<DIM_VIT_BASE_PATCH16_224>>(
248 String::from_utf8_lossy(&embedding_bytes).as_ref(),
249 )
250 .into_diagnostic()?;
251 let db = self.image_database()?;
252 db.insert_records(
253 &vec![embedding],
254 &vec![serde_json::to_string(&(author_id, uri))
255 .into_diagnostic()?
256 .into()],
257 )
258 .map_err(|e| miette::miette!("{e}"))?;
259 db.deduplicate().map_err(|e| miette::miette!("{e}"))?;
260 }
261 EmbeddingModality::Text => {
262 let embedding = serde_json::from_str::<Embedding<DIM_BGESMALL_EN_1_5>>(
263 String::from_utf8_lossy(&embedding_bytes).as_ref(),
264 )
265 .into_diagnostic()?;
266 let db = self.text_database()?;
267 db.insert_records(
268 &vec![embedding],
269 &vec![serde_json::to_string(&(author_id, uri))
270 .into_diagnostic()?
271 .into()],
272 )
273 .map_err(|e| miette::miette!("{e}"))?;
274 db.deduplicate().map_err(|e| miette::miette!("{e}"))?;
275 }
276 }
277 }
278 }
279 Ok(())
280 }
281
282 pub async fn fetch_archive(&self, author_id: &AuthorId, uri: &str) -> anyhow::Result<Bytes> {
294 let path: &PathBuf = &OkuNote::archive_path_from_url(&uri.to_string()).into();
295 let ticket = self.resolve_author_id(author_id).await?;
296 self.fetch_file_with_ticket(&ticket, path, &Some(home_replica_filters()))
297 .await
298 }
299}