oku_fs/fs/net/
embeddings.rs1use super::core::{home_replica_filters, EmbeddingModality};
2use crate::{database::posts::core::OkuNote, fs::OkuFs};
3use bytes::Bytes;
4use iroh_blobs::Hash;
5use iroh_docs::DocTicket;
6use log::error;
7use miette::IntoDiagnostic;
8use rayon::iter::IntoParallelIterator;
9use rayon::iter::IntoParallelRefIterator;
10use rayon::iter::ParallelIterator;
11use std::path::PathBuf;
12use url::Url;
13use zebra::{
14 database::default::{
15 audio::DefaultAudioModel, image::DefaultImageModel, text::DefaultTextModel,
16 },
17 model::core::{DatabaseEmbeddingModel, DIM_BGESMALL_EN_1_5, DIM_VIT_BASE_PATCH16_224},
18 Embedding,
19};
20
21impl OkuFs {
22 pub fn text_database(&self) -> zebra::database::default::text::DefaultTextDatabase {
24 zebra::database::default::text::DefaultTextDatabase::open_or_create(&"text.zebra".into())
25 }
26
27 pub fn image_database(&self) -> zebra::database::default::image::DefaultImageDatabase {
29 zebra::database::default::image::DefaultImageDatabase::open_or_create(&"image.zebra".into())
30 }
31
32 pub fn audio_database(&self) -> zebra::database::default::audio::DefaultAudioDatabase {
34 zebra::database::default::audio::DefaultAudioDatabase::open_or_create(&"audio.zebra".into())
35 }
36
37 pub fn bytes_to_embedding_modality(&self, bytes: &Bytes) -> miette::Result<EmbeddingModality> {
47 let mime_type = tree_magic_mini::from_u8(bytes);
48 let type_ = mime_type.split("/").nth(0).unwrap_or_default();
49 match type_ {
50 "audio" => Ok(EmbeddingModality::Audio),
51 "image" => Ok(EmbeddingModality::Image),
52 "text" => Ok(EmbeddingModality::Text),
53 _ => Err(miette::miette!(
54 "Unexpected MIME type ({mime_type:?}); embedding modality cannot be determined … "
55 )),
56 }
57 }
58
59 pub async fn create_post_embedding(
73 &self,
74 path: &Option<PathBuf>,
75 url: &Url,
76 bytes: &Bytes,
77 ) -> miette::Result<Hash> {
78 let home_replica_id = self
79 .home_replica()
80 .await
81 .ok_or(miette::miette!("No home replica set … "))?;
82 let embed_path = match path {
83 Some(given_path) => given_path,
84 None => &{
85 let mut path: PathBuf =
86 OkuNote::suggested_post_path_from_url(&url.to_string()).into();
87 path.set_extension("okuembed");
88 path
89 },
90 };
91 let mut archive_path = embed_path.clone();
92 archive_path.set_extension("okuarchive");
93 if let Err(e) = self
94 .create_or_modify_file(&home_replica_id, &archive_path, bytes.clone())
95 .await
96 {
97 error!("{e}");
98 }
99 match self.bytes_to_embedding_modality(bytes)? {
100 EmbeddingModality::Audio => {
101 let model = DefaultAudioModel::default();
102 let embedding = model
103 .embed(bytes.clone())
104 .map_err(|e| miette::miette!("{e}"))?;
105 let embedding_json = serde_json::to_string(&embedding).into_diagnostic()?;
106 self.create_or_modify_file(&home_replica_id, embed_path, embedding_json)
107 .await
108 }
109 EmbeddingModality::Image => {
110 let model = DefaultImageModel::default();
111 let embedding = model
112 .embed(bytes.clone())
113 .map_err(|e| miette::miette!("{e}"))?;
114 let embedding_json = serde_json::to_string(&embedding).into_diagnostic()?;
115 self.create_or_modify_file(&home_replica_id, embed_path, embedding_json)
116 .await
117 }
118 EmbeddingModality::Text => {
119 let model = DefaultTextModel::default();
120 let embedding = model
121 .embed(bytes.clone())
122 .map_err(|e| miette::miette!("{e}"))?;
123 let embedding_json = serde_json::to_string(&embedding).into_diagnostic()?;
124 self.create_or_modify_file(&home_replica_id, embed_path, embedding_json)
125 .await
126 }
127 }
128 }
129
130 pub fn nearest_urls(
142 &self,
143 bytes: &Bytes,
144 number_of_results: usize,
145 ) -> miette::Result<Vec<Url>> {
146 match self.bytes_to_embedding_modality(bytes)? {
147 EmbeddingModality::Audio => {
148 let db = self.audio_database();
149 let results = db
150 .query_documents(&[bytes.clone()], number_of_results)
151 .map_err(|e| miette::miette!("{e}"))?;
152 let result_strings: Vec<String> = results
153 .into_read_only()
154 .into_par_iter()
155 .flat_map(|(_x, y)| {
156 y.into_read_only()
157 .into_par_iter()
158 .map(|(_a, b)| String::from_utf8_lossy(&b).to_string())
159 .collect::<Vec<_>>()
160 })
161 .collect();
162 Ok(result_strings
163 .par_iter()
164 .filter_map(|x| Url::parse(x).ok())
165 .collect())
166 }
167 EmbeddingModality::Image => {
168 let db = self.image_database();
169 let results = db
170 .query_documents(&[bytes.clone()], number_of_results)
171 .map_err(|e| miette::miette!("{e}"))?;
172 let result_strings: Vec<String> = results
173 .into_read_only()
174 .into_par_iter()
175 .flat_map(|(_x, y)| {
176 y.into_read_only()
177 .into_par_iter()
178 .map(|(_a, b)| String::from_utf8_lossy(&b).to_string())
179 .collect::<Vec<_>>()
180 })
181 .collect();
182 Ok(result_strings
183 .par_iter()
184 .filter_map(|x| Url::parse(x).ok())
185 .collect())
186 }
187 EmbeddingModality::Text => {
188 let db = self.text_database();
189 let results = db
190 .query_documents(&[bytes.clone()], number_of_results)
191 .map_err(|e| miette::miette!("{e}"))?;
192 let result_strings: Vec<String> = results
193 .into_read_only()
194 .into_par_iter()
195 .flat_map(|(_x, y)| {
196 y.into_read_only()
197 .into_par_iter()
198 .map(|(_a, b)| String::from_utf8_lossy(&b).to_string())
199 .collect::<Vec<_>>()
200 })
201 .collect();
202 Ok(result_strings
203 .par_iter()
204 .filter_map(|x| Url::parse(x).ok())
205 .collect())
206 }
207 }
208 }
209
210 pub async fn fetch_post_embeddings(
220 &self,
221 ticket: &DocTicket,
222 path: &PathBuf,
223 uri: &str,
224 ) -> miette::Result<()> {
225 let mut archive_path = path.clone();
226 archive_path.set_extension("okuarchive");
227 if let Ok(embedding_bytes) = self
228 .fetch_file_with_ticket(ticket, path, &Some(home_replica_filters()))
229 .await
230 {
231 if let Ok(bytes) = self
232 .fetch_file_with_ticket(ticket, &archive_path, &Some(home_replica_filters()))
233 .await
234 {
235 match self.bytes_to_embedding_modality(&bytes)? {
236 EmbeddingModality::Audio => {
237 let embedding =
238 serde_json::from_str::<Embedding<DIM_VIT_BASE_PATCH16_224>>(
239 String::from_utf8_lossy(&embedding_bytes).as_ref(),
240 )
241 .into_diagnostic()?;
242 let db = self.audio_database();
243 db.insert_records(&vec![embedding], &vec![uri.to_owned().into()])
244 .map_err(|e| miette::miette!("{e}"))?;
245 db.deduplicate().map_err(|e| miette::miette!("{e}"))?;
246 }
247 EmbeddingModality::Image => {
248 let embedding =
249 serde_json::from_str::<Embedding<DIM_VIT_BASE_PATCH16_224>>(
250 String::from_utf8_lossy(&embedding_bytes).as_ref(),
251 )
252 .into_diagnostic()?;
253 let db = self.image_database();
254 db.insert_records(&vec![embedding], &vec![uri.to_owned().into()])
255 .map_err(|e| miette::miette!("{e}"))?;
256 db.deduplicate().map_err(|e| miette::miette!("{e}"))?;
257 }
258 EmbeddingModality::Text => {
259 let embedding = serde_json::from_str::<Embedding<DIM_BGESMALL_EN_1_5>>(
260 String::from_utf8_lossy(&embedding_bytes).as_ref(),
261 )
262 .into_diagnostic()?;
263 let db = self.text_database();
264 db.insert_records(&vec![embedding], &vec![uri.to_owned().into()])
265 .map_err(|e| miette::miette!("{e}"))?;
266 db.deduplicate().map_err(|e| miette::miette!("{e}"))?;
267 }
268 }
269 }
270 }
271 Ok(())
272 }
273}