use anyhow::Result;
pub use deadpool_sqlite::Pool;
use deadpool_sqlite::{Config, Hook, HookError, Object, Runtime};
use rusqlite::{
hooks::{AuthAction, AuthContext, Authorization},
Connection, Error, TransactionBehavior,
};
use std::{
path::Path,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
};
use tokio::sync::Mutex;
pub mod functions;
pub mod migrations;
mod operations;
#[cfg(test)]
mod tests;
pub use self::operations::*;
/// Shared, async-mutex-guarded handle to the connection used for writes.
type Writer = Arc<Mutex<Connection>>;
/// Handle to a SQLite database made of one write connection and a pool of
/// read connections.
///
/// Cloning is cheap: the writer sits behind an `Arc` and the clone shares it.
#[derive(Clone, Debug)]
pub struct Database {
/// Write connection; callers must take the mutex before using it.
writer: Writer,
/// Pool of connections handed out by [`Database::reader`].
readers: Pool,
}
/// Authorizer callback for reader connections.
///
/// Allows column reads, `SELECT` statements and SQL function calls; every
/// other action is denied.
fn authorize_read(context: AuthContext<'_>) -> Authorization {
    let permitted = matches!(
        context.action,
        AuthAction::Read { .. } | AuthAction::Select { .. } | AuthAction::Function { .. }
    );
    if permitted {
        Authorization::Allow
    } else {
        Authorization::Deny
    }
}
/// Authorizer callback for the writer connection.
///
/// Allows everything a reader may do (column reads, `SELECT`, SQL function
/// calls) plus row-level `INSERT`, `UPDATE` and `DELETE`, and transaction
/// control. Every other action (schema changes, pragmas, attach, ...) is
/// denied.
fn authorize_write(context: AuthContext<'_>) -> Authorization {
    match &context.action {
        AuthAction::Read { .. }
        | AuthAction::Select { .. }
        | AuthAction::Function { .. }
        | AuthAction::Delete { .. }
        | AuthAction::Insert { .. }
        | AuthAction::Update { .. }
        // SQLite routes BEGIN/COMMIT/ROLLBACK and SAVEPOINT/RELEASE through the
        // authorizer as well. Denying them makes `Connection::transaction()`
        // and `Connection::savepoint()` fail with "not authorized", so the
        // writer could never group its changes atomically.
        | AuthAction::Transaction { .. }
        | AuthAction::Savepoint { .. } => Authorization::Allow,
        _ => Authorization::Deny,
    }
}
/// Prepares a freshly opened connection for use as the writer.
///
/// # Errors
///
/// Returns an error if registering the SQL functions or running the
/// migrations fails.
fn init_writer(conn: &mut Connection) -> Result<()> {
functions::register_all(conn)?;
migrations::migrate(conn)?;
// The authorizer is installed only after the migrations have run.
// NOTE(review): presumably because migrations issue schema statements that
// `authorize_write` would deny — confirm against `migrations`.
conn.authorizer(Some(authorize_write));
// Transactions started on this connection use IMMEDIATE behavior by default.
conn.set_transaction_behavior(TransactionBehavior::Immediate);
Ok(())
}
/// Prepares a connection for use as a reader: registers the SQL functions and
/// then installs the `authorize_read` authorizer.
///
/// Unlike `init_writer`, this does not run migrations.
///
/// # Errors
///
/// Returns an error if registering the SQL functions fails.
fn init_reader(conn: &Connection) -> Result<()> {
functions::register_all(conn)?;
conn.authorizer(Some(authorize_read));
Ok(())
}
/// Process-wide counter used to give every in-memory database created by
/// `Database::memory` a distinct name.
static MEMORY_INDEX: AtomicU64 = AtomicU64::new(0);
impl Database {
pub async fn open(path: &Path) -> Result<Self> {
let pool = pool(path)?;
let writer = Arc::new(Mutex::new(Connection::open(path)?));
let mut handle = writer.clone().lock_owned().await;
tokio::task::spawn_blocking(move || {
init_writer(&mut handle)?;
Ok(()) as Result<_>
})
.await??;
Ok(Self {
writer,
readers: pool,
})
}
pub async fn memory() -> Result<Self> {
let index = MEMORY_INDEX.fetch_add(1, Ordering::SeqCst);
let name = format!("file:memdb{index}?mode=memory&cache=shared");
Self::open(Path::new(&name)).await
}
pub async fn reader(&self) -> Result<Object> {
Ok(self.readers.get().await?)
}
pub async fn write<T: 'static + Send, F: FnOnce(&mut Connection) -> T + 'static + Send>(
&self,
func: F,
) -> Result<T> {
let mut handle = self.writer.clone().lock_owned().await;
let result = tokio::task::spawn_blocking(move || func(&mut handle)).await?;
Ok(result)
}
}
/// Initializes a connection by registering the SQL functions and then running
/// the migrations.
///
/// Unlike `init_writer` and `init_reader`, no authorizer is installed.
///
/// # Errors
///
/// Returns an error if function registration or the migrations fail.
pub fn init(conn: &Connection) -> Result<()> {
functions::register_all(conn)?;
migrations::migrate(conn)?;
Ok(())
}
pub fn pool(path: &Path) -> Result<Pool> {
let pool = Config::new(path)
.builder(Runtime::Tokio1)?
.post_create(Hook::async_fn(|wrapper, metrics| {
Box::pin(async move {
wrapper
.interact(|conn| {
init_reader(conn).map_err(|e| Error::UserFunctionError(e.into()))?;
Ok(()) as Result<(), Error>
})
.await
.map_err(|e| HookError::Message(e.to_string().into()))?
.map_err(HookError::Backend)?;
Ok(()) as Result<(), HookError>
})
}))
.max_size(16)
.build()?;
Ok(pool)
}