1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
//! # Database access, migration and functions
//!
//! OpenVet uses SQLite as a database. It does this because the amount of data that it needs to
//! handle is small, SQLite is easily extensible and easy to back up.
//!
//! This module handles:
//!
//! - Augmenting SQLite with user-defined functions for ensuring the integrity of the database
//! - Backing up the database. Database backups are performed periodically and uploaded to the
//!   object storage.
//! - Migrating the database.
//! - Providing an abstraction for using the database.
//!
//! It uses a single writer connection, guarded by a mutex, and a pool of reader connections for
//! concurrent access.

use anyhow::Result;
pub use deadpool_sqlite::Pool;
use deadpool_sqlite::{Config, Hook, HookError, Object, Runtime};
use rusqlite::{
    hooks::{AuthAction, AuthContext, Authorization},
    Connection, Error, TransactionBehavior,
};
use std::{
    path::Path,
    sync::{
        atomic::{AtomicU64, Ordering},
        Arc,
    },
};
use tokio::sync::Mutex;

pub mod functions;
pub mod migrations;
mod operations;

#[cfg(test)]
mod tests;

pub use self::operations::*;

/// Shared handle to the single writer connection, guarded by an async (tokio) mutex.
type Writer = Arc<Mutex<Connection>>;

/// Handle to the database.
///
/// Cloning is cheap: the writer is reference-counted and the reader pool is itself a handle.
#[derive(Clone, Debug)]
pub struct Database {
    /// The one connection used by [`Database::write`]; access is serialized by the mutex.
    writer: Writer,
    /// Pool of connections handed out by [`Database::reader`].
    readers: Pool,
}

/// Authorizer callback for reader connections.
///
/// Only `Read`, `Select` and `Function` actions are allowed; every other action is denied.
fn authorize_read(context: AuthContext<'_>) -> Authorization {
    // TODO: log denials
    let permitted = matches!(
        context.action,
        AuthAction::Read { .. } | AuthAction::Select { .. } | AuthAction::Function { .. }
    );

    if permitted {
        Authorization::Allow
    } else {
        Authorization::Deny
    }
}

/// Authorizer callback for the writer connection.
///
/// Allows reading (`Read`, `Select`, `Function`), data modification (`Insert`, `Update`,
/// `Delete`) and transaction control (`Transaction`, `Savepoint`). Every other action is denied.
fn authorize_write(context: AuthContext<'_>) -> Authorization {
    // TODO: log denials
    match &context.action {
        AuthAction::Read { .. }
        | AuthAction::Select { .. }
        | AuthAction::Function { .. }
        | AuthAction::Delete { .. }
        | AuthAction::Insert { .. }
        | AuthAction::Update { .. } => Authorization::Allow,
        // BEGIN/COMMIT/ROLLBACK and SAVEPOINT statements are also passed to the authorizer.
        // Denying them would make explicit transactions on the writer fail, even though the
        // writer is configured with an explicit transaction behavior.
        AuthAction::Transaction { .. } | AuthAction::Savepoint { .. } => Authorization::Allow,
        _ => Authorization::Deny,
    }
}

/// Prepare the writer connection.
///
/// Registers the user-defined functions, runs the migrations, installs [`authorize_write`] as
/// the authorizer and sets the connection's transaction behavior to `Immediate`.
///
/// # Errors
///
/// Fails if registering the functions or migrating fails.
fn init_writer(conn: &mut Connection) -> Result<()> {
    functions::register_all(conn)?;
    migrations::migrate(conn)?;
    // NOTE(review): the authorizer is installed only after migrating, presumably because
    // migrations change the schema, which `authorize_write` does not allow. Confirm.
    conn.authorizer(Some(authorize_write));
    conn.set_transaction_behavior(TransactionBehavior::Immediate);
    Ok(())
}

/// Prepare a reader connection.
///
/// Registers the user-defined functions and installs [`authorize_read`] as the authorizer.
/// Unlike [`init_writer`], this does not run migrations.
///
/// # Errors
///
/// Fails if registering the functions fails.
fn init_reader(conn: &Connection) -> Result<()> {
    functions::register_all(conn)?;
    conn.authorizer(Some(authorize_read));
    Ok(())
}

/// Counter used by [`Database::memory`] to give every in-memory database a distinct name.
static MEMORY_INDEX: AtomicU64 = AtomicU64::new(0);

impl Database {
    /// Open the database at `path`.
    ///
    /// Builds the reader pool, then opens the writer connection and initialises it with
    /// [`init_writer`] (which registers functions and runs the migrations).
    ///
    /// # Errors
    ///
    /// Fails if the pool cannot be built, the connection cannot be opened, initialising the
    /// writer fails, or the blocking task cannot be joined.
    pub async fn open(path: &Path) -> Result<Self> {
        let readers = pool(path)?;

        // Opening the connection and initialising it are both blocking calls, so they run
        // together on the blocking thread pool rather than on the async runtime.
        let path = path.to_path_buf();
        let conn = tokio::task::spawn_blocking(move || -> Result<Connection> {
            let mut conn = Connection::open(path)?;
            init_writer(&mut conn)?;
            Ok(conn)
        })
        .await??;

        Ok(Self {
            writer: Arc::new(Mutex::new(conn)),
            readers,
        })
    }

    /// Open a fresh in-memory database.
    ///
    /// Every call uses a new name (`memdb0`, `memdb1`, ...) taken from [`MEMORY_INDEX`]. The
    /// name is passed as a `mode=memory&cache=shared` URI to [`Database::open`], so the writer
    /// and the reader pool are opened with the same URI.
    pub async fn memory() -> Result<Self> {
        let index = MEMORY_INDEX.fetch_add(1, Ordering::SeqCst);
        let name = format!("file:memdb{index}?mode=memory&cache=shared");
        Self::open(Path::new(&name)).await
    }

    /// Get a handle to a reader.
    ///
    /// The readers are managed by the deadpool crate. Each pooled connection is set up with
    /// [`init_reader`], which installs the [`authorize_read`] authorizer, so readers can only
    /// be used to receive data. Use [`Database::write`] to modify the database.
    ///
    /// # Errors
    ///
    /// Fails if no connection can be obtained from the pool.
    pub async fn reader(&self) -> Result<Object> {
        Ok(self.readers.get().await?)
    }

    /// Run `func` with exclusive access to the writer connection.
    ///
    /// Waits for the writer lock, then runs `func` on the blocking thread pool and returns
    /// whatever it returns. The lock is held until `func` has finished.
    ///
    /// # Errors
    ///
    /// Fails if the blocking task cannot be joined (for example because `func` panicked).
    pub async fn write<T, F>(&self, func: F) -> Result<T>
    where
        T: 'static + Send,
        F: FnOnce(&mut Connection) -> T + 'static + Send,
    {
        let mut handle = self.writer.clone().lock_owned().await;
        let result = tokio::task::spawn_blocking(move || func(&mut handle)).await?;
        Ok(result)
    }
}

/// Initialise a plain connection: register the user-defined functions, then run the migrations.
///
/// Unlike [`init_writer`] and [`init_reader`], this does not install an authorizer.
///
/// # Errors
///
/// Fails if registering the functions or migrating fails.
pub fn init(conn: &Connection) -> Result<()> {
    functions::register_all(conn)?;
    migrations::migrate(conn)?;
    Ok(())
}

pub fn pool(path: &Path) -> Result<Pool> {
    let pool = Config::new(path)
        .builder(Runtime::Tokio1)?
        .post_create(Hook::async_fn(|wrapper, metrics| {
            Box::pin(async move {
                wrapper
                    .interact(|conn| {
                        init_reader(conn).map_err(|e| Error::UserFunctionError(e.into()))?;
                        Ok(()) as Result<(), Error>
                    })
                    .await
                    .map_err(|e| HookError::Message(e.to_string().into()))?
                    .map_err(HookError::Backend)?;
                Ok(()) as Result<(), HookError>
            })
        }))
        .max_size(16)
        .build()?;
    Ok(pool)
}