oku_fs/fs/net/
embeddings.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
use super::core::{home_replica_filters, EmbeddingModality};
use crate::fs::OkuFs;
use iroh_docs::DocTicket;
use miette::IntoDiagnostic;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use std::{collections::HashMap, path::PathBuf};

impl OkuFs {
    /// The embedding vector database for text media.
    ///
    /// Opens (or creates) the database at the relative path `text.zebra`; presumably this is
    /// resolved against the process's current working directory — TODO confirm.
    pub fn text_database(&self) -> zebra::database::default::text::DefaultTextDatabase {
        zebra::database::default::text::DefaultTextDatabase::open_or_create(&"text.zebra".into())
    }

    /// The embedding vector database for image media.
    ///
    /// Opens (or creates) the database at the relative path `image.zebra`; presumably this is
    /// resolved against the process's current working directory — TODO confirm.
    pub fn image_database(&self) -> zebra::database::default::image::DefaultImageDatabase {
        zebra::database::default::image::DefaultImageDatabase::open_or_create(&"image.zebra".into())
    }

    /// The embedding vector database for audio media.
    ///
    /// Opens (or creates) the database at the relative path `audio.zebra`; presumably this is
    /// resolved against the process's current working directory — TODO confirm.
    pub fn audio_database(&self) -> zebra::database::default::audio::DefaultAudioDatabase {
        zebra::database::default::audio::DefaultAudioDatabase::open_or_create(&"audio.zebra".into())
    }

    /// Fetch the embedding file associated with an OkuNet post and index its vectors.
    ///
    /// The fetched file is parsed as TOML mapping each [`EmbeddingModality`] to a vector of
    /// `f32`. Each vector is inserted into the database for its modality (text, image or audio)
    /// with `uri` as the associated record, and afterwards the index of every one of the three
    /// databases is deduplicated.
    ///
    /// # Arguments
    ///
    /// * `ticket` - A ticket for the replica containing the file to retrieve.
    ///
    /// * `path` - The path to the file to retrieve.
    ///
    /// * `uri` - The URI associated with the OkuNet post.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    ///
    /// * the file is not TOML of the expected shape;
    /// * a vector cannot be converted into its database's embedding type (for example, because it
    ///   has the wrong dimension);
    /// * inserting into, or deduplicating, a database fails.
    ///
    /// Modalities are inserted in parallel with no rollback, so records for other modalities may
    /// already have been written when an error is returned.
    ///
    /// If the file itself cannot be fetched, the fetch error is ignored and `Ok(())` is returned.
    pub(crate) async fn fetch_post_embeddings(
        &self,
        ticket: &DocTicket,
        path: &PathBuf,
        uri: &str,
    ) -> miette::Result<()> {
        // Fetch failures are deliberately swallowed: a post without a retrievable embedding file
        // is skipped rather than treated as an error.
        if let Ok(bytes) = self
            .fetch_file_with_ticket(ticket, path, &Some(home_replica_filters()))
            .await
        {
            let embeddings = toml::from_str::<HashMap<EmbeddingModality, Vec<f32>>>(
                String::from_utf8_lossy(&bytes).as_ref(),
            )
            .into_diagnostic()?;
            let text_db = self.text_database();
            let image_db = self.image_database();
            let audio_db = self.audio_database();
            embeddings
                .into_par_iter()
                .try_for_each(|(modality, embedding)| -> miette::Result<()> {
                    // A vector that does not fit the database's embedding type is rejected
                    // instead of being replaced by a default embedding, which would otherwise
                    // be indexed under this post's URI and produce bogus search hits.
                    match modality {
                        EmbeddingModality::Text => {
                            let record = embedding.try_into().map_err(|_| {
                                miette::miette!(
                                    "Text embedding for {uri} could not be converted into the database's embedding type (wrong dimension?)"
                                )
                            })?;
                            text_db
                                .insert_records(&vec![record], &vec![uri.to_owned().into()])
                                .map_err(|e| miette::miette!("{e}"))?;
                        }
                        EmbeddingModality::Image => {
                            let record = embedding.try_into().map_err(|_| {
                                miette::miette!(
                                    "Image embedding for {uri} could not be converted into the database's embedding type (wrong dimension?)"
                                )
                            })?;
                            image_db
                                .insert_records(&vec![record], &vec![uri.to_owned().into()])
                                .map_err(|e| miette::miette!("{e}"))?;
                        }
                        EmbeddingModality::Audio => {
                            let record = embedding.try_into().map_err(|_| {
                                miette::miette!(
                                    "Audio embedding for {uri} could not be converted into the database's embedding type (wrong dimension?)"
                                )
                            })?;
                            audio_db
                                .insert_records(&vec![record], &vec![uri.to_owned().into()])
                                .map_err(|e| miette::miette!("{e}"))?;
                        }
                    }
                    Ok(())
                })?;
            // All three indexes are deduplicated, including those for modalities absent from
            // this file.
            text_db
                .index
                .deduplicate()
                .map_err(|e| miette::miette!("{e}"))?;
            image_db
                .index
                .deduplicate()
                .map_err(|e| miette::miette!("{e}"))?;
            audio_db
                .index
                .deduplicate()
                .map_err(|e| miette::miette!("{e}"))?;
        }
        Ok(())
    }
}