From e25ea8f285515872def6df28b0c7d34e091e8de6 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 6 Jul 2023 15:19:36 +0200 Subject: [PATCH 1/8] setup logging --- Cargo.lock | 154 +++++++++++++++++++++++++++++++------ Cargo.toml | 1 + libsqlx-server/Cargo.toml | 14 ++++ libsqlx-server/src/main.rs | 29 +++++++ 4 files changed, 176 insertions(+), 22 deletions(-) create mode 100644 libsqlx-server/Cargo.toml create mode 100644 libsqlx-server/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index 6d0349e7..fc45ee48 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,7 +8,16 @@ version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9ecd88a8c8378ca913a680cd98f0f13ac67383d35993f86c90a70e3f137816b" dependencies = [ - "gimli", + "gimli 0.26.2", +] + +[[package]] +name = "addr2line" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4fa78e18c64fce05e902adecd7a5eed15a5e0a3439f7b0e169f0252214865e3" +dependencies = [ + "gimli 0.27.3", ] [[package]] @@ -643,6 +652,21 @@ dependencies = [ "tower-service", ] +[[package]] +name = "backtrace" +version = "0.3.68" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4319208da049c43661739c5fade2ba182f09d1dc2299b32298d3a31692b17e12" +dependencies = [ + "addr2line 0.20.0", + "cc", + "cfg-if", + "libc", + "miniz_oxide", + "object 0.31.1", + "rustc-demangle", +] + [[package]] name = "base64" version = "0.13.1" @@ -970,9 +994,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.3.0" +version = "4.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93aae7a4192245f70fe75dd9157fc7b4a5bf53e88d30bd4396f7d8f9284d5acc" +checksum = "1640e5cc7fb47dbb8338fd471b105e7ed6c3cb2aeb00c2e067127ffd3764a05d" dependencies = [ "clap_builder", "clap_derive", @@ -981,22 +1005,21 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.3.0" +version = "4.3.11" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f423e341edefb78c9caba2d9c7f7687d0e72e89df3ce3394554754393ac3990" +checksum = "98c59138d527eeaf9b53f35a77fcc1fad9d883116070c63d5de1c7dc7b00c72b" dependencies = [ "anstream", "anstyle", - "bitflags 1.3.2", "clap_lex", "strsim", ] [[package]] name = "clap_derive" -version = "4.3.0" +version = "4.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "191d9573962933b4027f932c600cd252ce27a8ad5979418fe78e43c07996f27b" +checksum = "b8cd2b2a819ad6eec39e8f1d6b53001af1e5469f8c177579cdaeb313115b825f" dependencies = [ "heck", "proc-macro2", @@ -1010,6 +1033,33 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2da6da31387c7e4ef160ffab6d5e7f00c42626fe39aea70a7b0f1773f7dd6c1b" +[[package]] +name = "color-eyre" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a667583cca8c4f8436db8de46ea8233c42a7d9ae424a82d338f2e4675229204" +dependencies = [ + "backtrace", + "color-spantrace", + "eyre", + "indenter", + "once_cell", + "owo-colors", + "tracing-error", +] + +[[package]] +name = "color-spantrace" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ba75b3d9449ecdccb27ecbc479fdc0b87fa2dd43d2f8298f9bf0e59aacc8dce" +dependencies = [ + "once_cell", + "owo-colors", + "tracing-core", + "tracing-error", +] + [[package]] name = "colorchoice" version = "1.0.0" @@ -1145,7 +1195,7 @@ dependencies = [ "cranelift-egraph", "cranelift-entity", "cranelift-isle", - "gimli", + "gimli 0.26.2", "log", "regalloc2", "smallvec", @@ -1514,6 +1564,16 @@ version = "2.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" +[[package]] +name = "eyre" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"4c2b6b5a29c02cdc822728b7d7b8ae1bab3e3b05d44522770ddd49722eeac7eb" +dependencies = [ + "indenter", + "once_cell", +] + [[package]] name = "fallible-iterator" version = "0.2.0" @@ -1784,6 +1844,12 @@ dependencies = [ "stable_deref_trait", ] +[[package]] +name = "gimli" +version = "0.27.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c80984affa11d98d1b88b66ac8853f143217b399d3c74116778ff8fdb4ed2e" + [[package]] name = "glob" version = "0.3.1" @@ -2082,6 +2148,12 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "indenter" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683" + [[package]] name = "indexmap" version = "1.9.3" @@ -2292,9 +2364,9 @@ checksum = "884e2677b40cc8c339eaefcb701c32ef1fd2493d71118dc0ca4b6a736c93bd67" [[package]] name = "libc" -version = "0.2.144" +version = "0.2.147" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b00cc1c228a6782d0f076e7b232802e0c5689d41bb5df366f2a6b6621cfdfe1" +checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" [[package]] name = "libloading" @@ -2391,6 +2463,18 @@ dependencies = [ "uuid", ] +[[package]] +name = "libsqlx-server" +version = "0.1.0" +dependencies = [ + "axum", + "clap", + "color-eyre", + "tokio", + "tracing", + "tracing-subscriber", +] + [[package]] name = "linked-hash-map" version = "0.5.6" @@ -2684,6 +2768,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "object" +version = "0.31.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8bda667d9f2b5051b8833f59f3bf748b28ef54f850f4fcb389a252aa383866d1" +dependencies = [ + "memchr", +] + [[package]] name = "octopod" version = "0.1.0" @@ -2775,6 +2868,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" 
+[[package]] +name = "owo-colors" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1b04fb49957986fdce4d6ee7a65027d55d4b6d2265e5848bbb507b58ccfdb6f" + [[package]] name = "parking_lot" version = "0.12.1" @@ -4035,11 +4134,12 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.28.2" +version = "1.29.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94d7b1cfd2aa4011f2de74c2c4c63665e27a71006b0a192dcd2710272e73dfa2" +checksum = "532826ff75199d5833b9d2c5fe410f29235e25704ee5f0ef599fb51c21f4a4da" dependencies = [ "autocfg", + "backtrace", "bytes 1.4.0", "libc", "mio", @@ -4314,6 +4414,16 @@ dependencies = [ "valuable", ] +[[package]] +name = "tracing-error" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d686ec1c0f384b1277f097b2f279a2ecc11afe8c133c1aabf036a27cb4cd206e" +dependencies = [ + "tracing", + "tracing-subscriber", +] + [[package]] name = "tracing-futures" version = "0.2.5" @@ -4684,7 +4794,7 @@ dependencies = [ "indexmap 1.9.3", "libc", "log", - "object", + "object 0.29.0", "once_cell", "paste", "psm", @@ -4743,9 +4853,9 @@ dependencies = [ "cranelift-frontend", "cranelift-native", "cranelift-wasm", - "gimli", + "gimli 0.26.2", "log", - "object", + "object 0.29.0", "target-lexicon", "thiserror", "wasmparser", @@ -4760,10 +4870,10 @@ checksum = "754b97f7441ac780a7fa738db5b9c23c1b70ef4abccd8ad205ada5669d196ba2" dependencies = [ "anyhow", "cranelift-entity", - "gimli", + "gimli 0.26.2", "indexmap 1.9.3", "log", - "object", + "object 0.29.0", "serde", "target-lexicon", "thiserror", @@ -4790,15 +4900,15 @@ version = "3.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32800cb6e29faabab7056593f70a4c00c65c75c365aaf05406933f2169d0c22f" dependencies = [ - "addr2line", + "addr2line 0.17.0", "anyhow", "bincode", "cfg-if", "cpp_demangle", - 
"gimli", + "gimli 0.26.2", "ittapi", "log", - "object", + "object 0.29.0", "rustc-demangle", "serde", "target-lexicon", @@ -4816,7 +4926,7 @@ version = "3.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fe057012a0ba6cee3685af1e923d6e0a6cb9baf15fb3ffa4be3d7f712c7dec42" dependencies = [ - "object", + "object 0.29.0", "once_cell", "rustix 0.35.13", ] diff --git a/Cargo.toml b/Cargo.toml index 238333f0..26e14f1e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ members = [ "sqld-libsql-bindings", "testing/end-to-end", "libsqlx", + "libsqlx-server", ] [workspace.dependencies] diff --git a/libsqlx-server/Cargo.toml b/libsqlx-server/Cargo.toml new file mode 100644 index 00000000..66ba92b7 --- /dev/null +++ b/libsqlx-server/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "libsqlx-server" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +axum = "0.6.18" +clap = { version = "4.3.11", features = ["derive"] } +color-eyre = "0.6.2" +tokio = { version = "1.29.1", features = ["full"] } +tracing = "0.1.37" +tracing-subscriber = "0.3.17" diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs new file mode 100644 index 00000000..a4eeb1e1 --- /dev/null +++ b/libsqlx-server/src/main.rs @@ -0,0 +1,29 @@ +use color_eyre::eyre::Result; +use tracing::metadata::LevelFilter; +use tracing_subscriber::prelude::*; + +#[tokio::main] +async fn main() -> Result<()> { + init(); + + Ok(()) +} + +fn init() { + let registry = tracing_subscriber::registry(); + + registry + .with( + tracing_subscriber::fmt::layer() + .with_ansi(false) + .with_filter( + tracing_subscriber::EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(), + ), + ) + .init(); + + color_eyre::install().unwrap(); +} + From e301b80d2c05d1d8100f53e3d0807b77228d7591 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 6 Jul 2023 15:34:58 +0200 
Subject: [PATCH 2/8] admin server boilerplate --- Cargo.lock | 5 +++-- libsqlx-server/Cargo.toml | 1 + libsqlx-server/src/http/admin.rs | 21 +++++++++++++++++++++ libsqlx-server/src/http/mod.rs | 1 + libsqlx-server/src/main.rs | 3 ++- 5 files changed, 28 insertions(+), 3 deletions(-) create mode 100644 libsqlx-server/src/http/admin.rs create mode 100644 libsqlx-server/src/http/mod.rs diff --git a/Cargo.lock b/Cargo.lock index fc45ee48..06c5b482 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2014,9 +2014,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "0.14.26" +version = "0.14.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab302d72a6f11a3b910431ff93aae7e773078c769f0a3ef15fb9ec692ed147d4" +checksum = "ffb1cfd654a8219eaef89881fdb3bb3b1cdc5fa75ded05d6933b2b382e395468" dependencies = [ "bytes 1.4.0", "futures-channel", @@ -2470,6 +2470,7 @@ dependencies = [ "axum", "clap", "color-eyre", + "hyper", "tokio", "tracing", "tracing-subscriber", diff --git a/libsqlx-server/Cargo.toml b/libsqlx-server/Cargo.toml index 66ba92b7..7fe6b737 100644 --- a/libsqlx-server/Cargo.toml +++ b/libsqlx-server/Cargo.toml @@ -9,6 +9,7 @@ edition = "2021" axum = "0.6.18" clap = { version = "4.3.11", features = ["derive"] } color-eyre = "0.6.2" +hyper = { version = "0.14.27", features = ["h2"] } tokio = { version = "1.29.1", features = ["full"] } tracing = "0.1.37" tracing-subscriber = "0.3.17" diff --git a/libsqlx-server/src/http/admin.rs b/libsqlx-server/src/http/admin.rs new file mode 100644 index 00000000..cda7ee55 --- /dev/null +++ b/libsqlx-server/src/http/admin.rs @@ -0,0 +1,21 @@ +use std::sync::Arc; + +use axum::Router; +use color_eyre::eyre::Result; +use hyper::server::accept::Accept; +use tokio::io::{AsyncRead, AsyncWrite}; + +pub struct AdminServerConfig { } + +struct AdminServerState { } + +pub async fn run_admin_server(_config: AdminServerConfig, listener: I) -> Result<()> 
+where I: Accept, + I::Conn: AsyncRead + AsyncWrite + Send + Unpin + 'static, +{ + let state = AdminServerState { }; + let app = Router::new().with_state(Arc::new(state)); + axum::Server::builder(listener).serve(app.into_make_service()).await?; + + Ok(()) +} diff --git a/libsqlx-server/src/http/mod.rs b/libsqlx-server/src/http/mod.rs new file mode 100644 index 00000000..92918b09 --- /dev/null +++ b/libsqlx-server/src/http/mod.rs @@ -0,0 +1 @@ +pub mod admin; diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs index a4eeb1e1..15e33e24 100644 --- a/libsqlx-server/src/main.rs +++ b/libsqlx-server/src/main.rs @@ -2,6 +2,8 @@ use color_eyre::eyre::Result; use tracing::metadata::LevelFilter; use tracing_subscriber::prelude::*; +mod http; + #[tokio::main] async fn main() -> Result<()> { init(); @@ -26,4 +28,3 @@ fn init() { color_eyre::install().unwrap(); } - From 92c18367b791379f465c732d241f02924ecfe75a Mon Sep 17 00:00:00 2001 From: ad hoc Date: Fri, 7 Jul 2023 10:46:15 +0200 Subject: [PATCH 3/8] introduce allocation type to manages a single database --- libsqlx-server/src/allocation/config.rs | 9 ++ libsqlx-server/src/allocation/mod.rs | 124 ++++++++++++++++++++++++ 2 files changed, 133 insertions(+) create mode 100644 libsqlx-server/src/allocation/config.rs create mode 100644 libsqlx-server/src/allocation/mod.rs diff --git a/libsqlx-server/src/allocation/config.rs b/libsqlx-server/src/allocation/config.rs new file mode 100644 index 00000000..19a6396b --- /dev/null +++ b/libsqlx-server/src/allocation/config.rs @@ -0,0 +1,9 @@ +use serde::{Serialize, Deserialize}; + +#[derive(Debug, Serialize, Deserialize)] +pub enum AllocConfig { + Primary { }, + Replica { + primary_node_id: String, + } +} diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs new file mode 100644 index 00000000..11ae61ae --- /dev/null +++ b/libsqlx-server/src/allocation/mod.rs @@ -0,0 +1,124 @@ +use std::collections::HashMap; + +use libsqlx::Database; 
+use tokio::{sync::{mpsc, oneshot}, task::{JoinSet, block_in_place}}; + +pub mod config; + +type ExecFn = Box; + +#[derive(Clone)] +struct ConnectionId { + id: u32, + close_sender: mpsc::Sender<()>, +} + +enum AllocationMessage { + /// Execute callback against connection + Exec { + connection_id: ConnectionId, + exec: ExecFn, + }, + /// Create a new connection, execute the callback and return the connection id. + NewConnExec { + exec: ExecFn, + ret: oneshot::Sender, + } +} + +pub struct Allocation { + inbox: mpsc::Receiver, + database: Box, + /// senders to the spawned connections + connections: HashMap>, + /// spawned connection futures, returning their connection id on completion. + connections_futs: JoinSet, + next_conn_id: u32, + max_concurrent_connections: u32, +} + +impl Allocation { + async fn run(mut self) { + loop { + tokio::select! { + Some(msg) = self.inbox.recv() => { + match msg { + AllocationMessage::Exec { connection_id, exec } => { + if let Some(sender) = self.connections.get(&connection_id.id) { + if let Err(_) = sender.send(exec).await { + tracing::debug!("connection {} closed.", connection_id.id); + self.connections.remove_entry(&connection_id.id); + } + } + }, + AllocationMessage::NewConnExec { exec, ret } => { + let id = self.new_conn_exec(exec).await; + let _ = ret.send(id); + }, + } + }, + maybe_id = self.connections_futs.join_next() => { + if let Some(Ok(id)) = maybe_id { + self.connections.remove_entry(&id); + } + }, + else => break, + } + } + } + + async fn new_conn_exec(&mut self, exec: ExecFn) -> ConnectionId { + let id = self.next_conn_id(); + let conn = block_in_place(|| self.database.connect()).unwrap(); + let (close_sender, exit) = mpsc::channel(1); + let (exec_sender, exec_receiver) = mpsc::channel(1); + let conn = Connection { + id, + conn, + exit, + exec: exec_receiver, + }; + + + self.connections_futs.spawn(conn.run()); + // This should never block! 
+ assert!(exec_sender.try_send(exec).is_ok()); + assert!(self.connections.insert(id, exec_sender).is_none()); + + ConnectionId { + id, + close_sender, + } + } + + fn next_conn_id(&mut self) -> u32 { + loop { + self.next_conn_id = self.next_conn_id.wrapping_add(1); + if !self.connections.contains_key(&self.next_conn_id) { + return self.next_conn_id + } + } + } +} + +struct Connection { + id: u32, + conn: Box, + exit: mpsc::Receiver<()>, + exec: mpsc::Receiver, +} + +impl Connection { + async fn run(mut self) -> u32 { + loop { + tokio::select! { + _ = self.exit.recv() => break, + Some(exec) = self.exec.recv() => { + tokio::task::block_in_place(|| exec(&mut *self.conn)); + } + } + } + + self.id + } +} From 110c5fa0404e61186a71d454470d1a302c36cf42 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Fri, 7 Jul 2023 10:47:46 +0200 Subject: [PATCH 4/8] sketch meta store --- libsqlx-server/src/meta.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 libsqlx-server/src/meta.rs diff --git a/libsqlx-server/src/meta.rs b/libsqlx-server/src/meta.rs new file mode 100644 index 00000000..7f48a456 --- /dev/null +++ b/libsqlx-server/src/meta.rs @@ -0,0 +1,10 @@ +use uuid::Uuid; + +use crate::allocation::config::AllocConfig; + +pub struct MetaStore {} + +impl MetaStore { + pub async fn allocate(&self, alloc_id: &str, meta: &AllocConfig) {} + pub async fn deallocate(&self, alloc_id: Uuid) {} +} From 09f7f0955b93c0010ef77279a57215906085d1d3 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Fri, 7 Jul 2023 10:48:08 +0200 Subject: [PATCH 5/8] changes to libsqlx --- Cargo.lock | 58 +++++----- libsqlx-server/Cargo.toml | 4 + libsqlx-server/src/allocation/mod.rs | 15 ++- libsqlx-server/src/http/admin.rs | 48 ++++++-- libsqlx-server/src/main.rs | 3 + libsqlx/src/connection.rs | 20 +++- libsqlx/src/database/libsql/connection.rs | 12 +- libsqlx/src/database/libsql/injector/hook.rs | 7 +- libsqlx/src/database/libsql/injector/mod.rs | 20 +++- libsqlx/src/database/libsql/mod.rs | 73 
+++++++----- libsqlx/src/database/mod.rs | 112 +------------------ libsqlx/src/database/proxy/connection.rs | 30 ++--- libsqlx/src/database/proxy/database.rs | 7 +- libsqlx/src/database/test_utils.rs | 10 +- libsqlx/src/lib.rs | 1 - libsqlx/src/result_builder.rs | 2 +- libsqlx/src/semaphore.rs | 98 ---------------- 17 files changed, 203 insertions(+), 317 deletions(-) delete mode 100644 libsqlx/src/semaphore.rs diff --git a/Cargo.lock b/Cargo.lock index 06c5b482..63b0f30a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -205,7 +205,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.23", ] [[package]] @@ -216,7 +216,7 @@ checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.23", ] [[package]] @@ -716,7 +716,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.16", + "syn 2.0.23", ] [[package]] @@ -836,7 +836,7 @@ checksum = "fdde5c9cd29ebd706ce1b35600920a33550e402fc998a2e53ad3b42c3c47a192" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.23", ] [[package]] @@ -1024,7 +1024,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.23", ] [[package]] @@ -1647,7 +1647,7 @@ checksum = "2cd66269887534af4b0c3e3337404591daa8dc8b9b2b3db71f9523beb4bafb41" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.23", ] [[package]] @@ -1747,7 +1747,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.23", ] [[package]] @@ -1830,7 +1830,7 @@ checksum = "e77ac7b51b8e6313251737fcef4b1c01a2ea102bde68415b62c0ee9268fec357" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.23", ] [[package]] @@ -2470,10 +2470,14 @@ dependencies = [ "axum", "clap", "color-eyre", + "futures", "hyper", + "libsqlx", + "serde", "tokio", 
"tracing", "tracing-subscriber", + "uuid", ] [[package]] @@ -2836,7 +2840,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.23", ] [[package]] @@ -3011,7 +3015,7 @@ checksum = "39407670928234ebc5e6e580247dd567ad73a3578460c5990f9503df207e8f07" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.23", ] [[package]] @@ -3091,7 +3095,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b69d39aab54d069e7f2fe8cb970493e7834601ca2d8c65fd7bbd183578080d1" dependencies = [ "proc-macro2", - "syn 2.0.16", + "syn 2.0.23", ] [[package]] @@ -3106,9 +3110,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.58" +version = "1.0.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa1fb82fc0c281dd9671101b66b771ebbe1eaf967b96ac8740dcba4b70005ca8" +checksum = "7b368fba921b0dce7e60f5e04ec15e565b3303972b42bcfde1d0713b881959eb" dependencies = [ "unicode-ident", ] @@ -3204,9 +3208,9 @@ checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" [[package]] name = "quote" -version = "1.0.27" +version = "1.0.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f4f29d145265ec1c483c7c654450edde0bfe043d3938d6972630663356d9500" +checksum = "573015e8ab27661678357f27dc26460738fd2b6c86e46f386fde94cb5d913105" dependencies = [ "proc-macro2", ] @@ -3621,22 +3625,22 @@ checksum = "bebd363326d05ec3e2f532ab7660680f3b02130d780c299bca73469d521bc0ed" [[package]] name = "serde" -version = "1.0.164" +version = "1.0.166" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e8c8cf938e98f769bc164923b06dce91cea1751522f46f8466461af04c9027d" +checksum = "d01b7404f9d441d3ad40e6a636a7782c377d2abdbe4fa2440e2edcc2f4f10db8" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.164" +version = "1.0.166" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9735b638ccc51c28bf6914d90a2e9725b377144fc612c49a611fddd1b631d68" +checksum = "5dd83d6dde2b6b2d466e14d9d1acce8816dedee94f735eac6395808b3483c6d6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.23", ] [[package]] @@ -3967,9 +3971,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.16" +version = "2.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6f671d4b5ffdb8eadec19c0ae67fe2639df8684bd7bc4b83d986b8db549cf01" +checksum = "59fb7d6d8281a51045d62b8eb3a7d1ce347b76f312af50cd3dc0af39c87c1737" dependencies = [ "proc-macro2", "quote", @@ -4067,7 +4071,7 @@ checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.23", ] [[package]] @@ -4172,7 +4176,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.23", ] [[package]] @@ -4402,7 +4406,7 @@ checksum = "0f57e3ca2a01450b1a921183a9c9cbfda207fd822cef4ccb00a65402cbba7a74" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.23", ] [[package]] @@ -4719,7 +4723,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.23", "wasm-bindgen-shared", ] @@ -4753,7 +4757,7 @@ checksum = "e128beba882dd1eb6200e1dc92ae6c5dbaa4311aa7bb211ca035779e5efc39f8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.23", "wasm-bindgen-backend", "wasm-bindgen-shared", ] diff --git a/libsqlx-server/Cargo.toml b/libsqlx-server/Cargo.toml index 7fe6b737..90f2ca0b 100644 --- a/libsqlx-server/Cargo.toml +++ b/libsqlx-server/Cargo.toml @@ -9,7 +9,11 @@ edition = "2021" axum = "0.6.18" clap = { version = "4.3.11", features = ["derive"] } color-eyre = "0.6.2" +futures = "0.3.28" hyper = { version = "0.14.27", features = ["h2"] } +libsqlx = { version = "0.1.0", path = "../libsqlx" } 
+serde = { version = "1.0.166", features = ["derive"] } tokio = { version = "1.29.1", features = ["full"] } tracing = "0.1.37" tracing-subscriber = "0.3.17" +uuid = { version = "1.4.0", features = ["v4"] } diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index 11ae61ae..2fa9a4cd 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -1,6 +1,5 @@ use std::collections::HashMap; -use libsqlx::Database; use tokio::{sync::{mpsc, oneshot}, task::{JoinSet, block_in_place}}; pub mod config; @@ -26,9 +25,17 @@ enum AllocationMessage { } } +enum Database {} + +impl Database { + fn connect(&self) -> Box { + todo!(); + } +} + pub struct Allocation { inbox: mpsc::Receiver, - database: Box, + database: Database, /// senders to the spawned connections connections: HashMap>, /// spawned connection futures, returning their connection id on completion. @@ -69,7 +76,7 @@ impl Allocation { async fn new_conn_exec(&mut self, exec: ExecFn) -> ConnectionId { let id = self.next_conn_id(); - let conn = block_in_place(|| self.database.connect()).unwrap(); + let conn = block_in_place(|| self.database.connect()); let (close_sender, exit) = mpsc::channel(1); let (exec_sender, exec_receiver) = mpsc::channel(1); let conn = Connection { @@ -103,7 +110,7 @@ impl Allocation { struct Connection { id: u32, - conn: Box, + conn: Box, exit: mpsc::Receiver<()>, exec: mpsc::Receiver, } diff --git a/libsqlx-server/src/http/admin.rs b/libsqlx-server/src/http/admin.rs index cda7ee55..2d9c8054 100644 --- a/libsqlx-server/src/http/admin.rs +++ b/libsqlx-server/src/http/admin.rs @@ -1,21 +1,53 @@ use std::sync::Arc; -use axum::Router; +use axum::{extract::State, routing::post, Json, Router}; use color_eyre::eyre::Result; use hyper::server::accept::Accept; +use serde::{Deserialize, Serialize}; use tokio::io::{AsyncRead, AsyncWrite}; -pub struct AdminServerConfig { } +use crate::{meta::MetaStore, allocation::config::AllocConfig}; 
-struct AdminServerState { } +pub struct AdminServerConfig {} + +struct AdminServerState { + meta_store: Arc, +} pub async fn run_admin_server(_config: AdminServerConfig, listener: I) -> Result<()> -where I: Accept, - I::Conn: AsyncRead + AsyncWrite + Send + Unpin + 'static, +where + I: Accept, + I::Conn: AsyncRead + AsyncWrite + Send + Unpin + 'static, { - let state = AdminServerState { }; - let app = Router::new().with_state(Arc::new(state)); - axum::Server::builder(listener).serve(app.into_make_service()).await?; + let state = AdminServerState { + meta_store: todo!(), + }; + let app = Router::new() + .route("/manage/allocation/create", post(allocate)) + .with_state(Arc::new(state)); + axum::Server::builder(listener) + .serve(app.into_make_service()) + .await?; Ok(()) } + +#[derive(Serialize, Debug)] +struct ErrorResponse {} + +#[derive(Serialize, Debug)] +struct AllocateResp { } + +#[derive(Deserialize, Debug)] +struct AllocateReq { + alloc_id: String, + config: AllocConfig, +} + +async fn allocate( + State(state): State>, + Json(req): Json, +) -> Result, Json> { + state.meta_store.allocate(&req.alloc_id, &req.config).await; + todo!(); +} diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs index 15e33e24..1ee047bf 100644 --- a/libsqlx-server/src/main.rs +++ b/libsqlx-server/src/main.rs @@ -2,7 +2,10 @@ use color_eyre::eyre::Result; use tracing::metadata::LevelFilter; use tracing_subscriber::prelude::*; +mod allocation; +mod databases; mod http; +mod meta; #[tokio::main] async fn main() -> Result<()> { diff --git a/libsqlx/src/connection.rs b/libsqlx/src/connection.rs index d21ca9a6..cc4776c4 100644 --- a/libsqlx/src/connection.rs +++ b/libsqlx/src/connection.rs @@ -22,12 +22,26 @@ pub struct DescribeCol { pub trait Connection { /// Executes a query program - fn execute_program( + fn execute_program( &mut self, pgm: Program, - result_builder: B, - ) -> crate::Result; + result_builder: &mut dyn ResultBuilder, + ) -> crate::Result<()>; /// Parse 
the SQL statement and return information about it. fn describe(&self, sql: String) -> crate::Result; } + +impl Connection for Box { + fn execute_program( + &mut self, + pgm: Program, + result_builder: &mut dyn ResultBuilder, + ) -> crate::Result<()> { + self.as_mut().execute_program(pgm, result_builder) + } + + fn describe(&self, sql: String) -> crate::Result { + self.as_ref().describe(sql) + } +} diff --git a/libsqlx/src/database/libsql/connection.rs b/libsqlx/src/database/libsql/connection.rs index 88632501..8a1c8c55 100644 --- a/libsqlx/src/database/libsql/connection.rs +++ b/libsqlx/src/database/libsql/connection.rs @@ -101,14 +101,14 @@ impl LibsqlConnection { &self.conn } - fn run(&mut self, pgm: Program, mut builder: B) -> Result { + fn run(&mut self, pgm: Program, builder: &mut dyn ResultBuilder) -> Result<()> { let mut results = Vec::with_capacity(pgm.steps.len()); builder.init(&self.builder_config)?; let is_autocommit_before = self.conn.is_autocommit(); for step in pgm.steps() { - let res = self.execute_step(step, &results, &mut builder)?; + let res = self.execute_step(step, &results, builder)?; results.push(res); } @@ -119,14 +119,14 @@ impl LibsqlConnection { builder.finish(!self.conn.is_autocommit(), None)?; - Ok(builder) + Ok(()) } fn execute_step( &mut self, step: &Step, results: &[bool], - builder: &mut impl ResultBuilder, + builder: &mut dyn ResultBuilder, ) -> Result { builder.begin_step()?; let mut enabled = match step.cond.as_ref() { @@ -163,7 +163,7 @@ impl LibsqlConnection { fn execute_query( &self, query: &Query, - builder: &mut impl ResultBuilder, + builder: &mut dyn ResultBuilder, ) -> Result<(u64, Option)> { tracing::trace!("executing query: {}", query.stmt.stmt); @@ -237,7 +237,7 @@ fn eval_cond(cond: &Cond, results: &[bool]) -> Result { } impl Connection for LibsqlConnection { - fn execute_program(&mut self, pgm: Program, builder: B) -> crate::Result { + fn execute_program(&mut self, pgm: Program, builder: &mut dyn ResultBuilder) -> 
crate::Result<()> { self.run(pgm, builder) } diff --git a/libsqlx/src/database/libsql/injector/hook.rs b/libsqlx/src/database/libsql/injector/hook.rs index 0479fb2d..f87172db 100644 --- a/libsqlx/src/database/libsql/injector/hook.rs +++ b/libsqlx/src/database/libsql/injector/hook.rs @@ -27,14 +27,11 @@ pub struct InjectorHookCtx { } impl InjectorHookCtx { - pub fn new( - buffer: FrameBuffer, - injector_commit_handler: impl InjectorCommitHandler + 'static, - ) -> Self { + pub fn new(buffer: FrameBuffer, commit_handler: Box) -> Self { Self { buffer, is_txn: false, - commit_handler: Box::new(injector_commit_handler), + commit_handler, } } diff --git a/libsqlx/src/database/libsql/injector/mod.rs b/libsqlx/src/database/libsql/injector/mod.rs index df01cd34..1682e3b4 100644 --- a/libsqlx/src/database/libsql/injector/mod.rs +++ b/libsqlx/src/database/libsql/injector/mod.rs @@ -42,6 +42,16 @@ pub trait InjectorCommitHandler: 'static { fn post_commit(&mut self, frame_no: FrameNo) -> anyhow::Result<()>; } +impl InjectorCommitHandler for Box { + fn pre_commit(&mut self, frame_no: FrameNo) -> anyhow::Result<()> { + self.as_mut().pre_commit(frame_no) + } + + fn post_commit(&mut self, frame_no: FrameNo) -> anyhow::Result<()> { + self.as_mut().post_commit(frame_no) + } +} + #[cfg(test)] impl InjectorCommitHandler for () { fn pre_commit(&mut self, _frame_no: FrameNo) -> anyhow::Result<()> { @@ -56,11 +66,11 @@ impl InjectorCommitHandler for () { impl Injector { pub fn new( path: &Path, - injector_commit_hanlder: impl InjectorCommitHandler + 'static, + injector_commit_handler: Box, buffer_capacity: usize, ) -> crate::Result { let buffer = FrameBuffer::default(); - let ctx = InjectorHookCtx::new(buffer.clone(), injector_commit_hanlder); + let ctx = InjectorHookCtx::new(buffer.clone(), injector_commit_handler); let mut ctx = Box::new(ctx); let connection = sqld_libsql_bindings::Connection::open( path, @@ -162,7 +172,7 @@ mod test { let log = LogFile::new(file).unwrap(); let temp = 
tempfile::tempdir().unwrap(); - let mut injector = Injector::new(temp.path(), (), 10).unwrap(); + let mut injector = Injector::new(temp.path(), Box::new(()), 10).unwrap(); for frame in log.frames_iter().unwrap() { let frame = frame.unwrap(); injector.inject_frame(frame).unwrap(); @@ -184,7 +194,7 @@ mod test { let temp = tempfile::tempdir().unwrap(); // inject one frame at a time - let mut injector = Injector::new(temp.path(), (), 1).unwrap(); + let mut injector = Injector::new(temp.path(), Box::new(()), 1).unwrap(); for frame in log.frames_iter().unwrap() { let frame = frame.unwrap(); injector.inject_frame(frame).unwrap(); @@ -206,7 +216,7 @@ mod test { let temp = tempfile::tempdir().unwrap(); // inject one frame at a time - let mut injector = Injector::new(temp.path(), (), 10).unwrap(); + let mut injector = Injector::new(temp.path(), Box::new(()), 10).unwrap(); let mut iter = log.frames_iter().unwrap(); assert!(injector diff --git a/libsqlx/src/database/libsql/mod.rs b/libsqlx/src/database/libsql/mod.rs index e5f4ad0a..27397663 100644 --- a/libsqlx/src/database/libsql/mod.rs +++ b/libsqlx/src/database/libsql/mod.rs @@ -44,8 +44,8 @@ impl LibsqlDbType for PrimaryType { } pub struct ReplicaType { - // frame injector for the database - injector: Injector, + commit_handler: Option>, + injector_buffer_capacity: usize, } impl LibsqlDbType for ReplicaType { @@ -83,13 +83,13 @@ pub struct LibsqlDatabase { } /// Handler trait for gathering row stats when executing queries. 
-pub trait RowStatsHandler { +pub trait RowStatsHandler: Send + Sync { fn handle_row_stats(&self, stats: RowStats); } impl RowStatsHandler for F where - F: Fn(RowStats), + F: Fn(RowStats) + Send + Sync, { fn handle_row_stats(&self, stats: RowStats) { (self)(stats) @@ -104,7 +104,8 @@ impl LibsqlDatabase { injector_commit_handler: impl InjectorCommitHandler, ) -> crate::Result { let ty = ReplicaType { - injector: Injector::new(&db_path, injector_commit_handler, injector_buffer_capacity)?, + commit_handler: Some(Box::new(injector_commit_handler)), + injector_buffer_capacity, }; Ok(Self::new(db_path, ty)) @@ -154,7 +155,7 @@ impl Database for LibsqlDatabase { type Connection = LibsqlConnection<::Context>; fn connect(&self) -> Result { - LibsqlConnection::<::Context>::new( + Ok(LibsqlConnection::<::Context>::new( &self.db_path, self.extensions.clone(), T::hook(), @@ -163,13 +164,24 @@ impl Database for LibsqlDatabase { QueryBuilderConfig { max_size: Some(self.response_size_limit), }, - ) + )?) 
} } impl InjectableDatabase for LibsqlDatabase { - fn inject_frame(&mut self, frame: Frame) -> Result<(), InjectError> { - self.ty.injector.inject_frame(frame).unwrap(); + fn injector(&mut self) -> crate::Result> { + let Some(commit_handler) = self.ty.commit_handler.take() else { panic!("there can be only one injector") }; + Ok(Box::new(Injector::new( + &self.db_path, + commit_handler, + self.ty.injector_buffer_capacity, + )?)) + } +} + +impl super::Injector for Injector { + fn inject(&mut self, frame: Frame) -> Result<(), InjectError> { + self.inject_frame(frame).unwrap(); Ok(()) } } @@ -205,32 +217,36 @@ mod test { fn inject_libsql_db() { let temp = tempfile::tempdir().unwrap(); let replica = ReplicaType { - injector: Injector::new(temp.path(), (), 10).unwrap(), + commit_handler: Some(Box::new(())), + injector_buffer_capacity: 10, }; let mut db = LibsqlDatabase::new(temp.path().to_path_buf(), replica); let mut conn = db.connect().unwrap(); - let res = conn + let mut builder = ReadRowBuilder(Vec::new()); + conn .execute_program( Program::seq(&["select count(*) from test"]), - ReadRowBuilder(Vec::new()), + &mut builder ) .unwrap(); - assert!(res.0.is_empty()); + assert!(builder.0.is_empty()); let file = File::open("assets/test/simple_wallog").unwrap(); let log = LogFile::new(file).unwrap(); + let mut injector = db.injector().unwrap(); log.frames_iter() .unwrap() - .for_each(|f| db.inject_frame(f.unwrap()).unwrap()); + .for_each(|f| injector.inject(f.unwrap()).unwrap()); - let res = conn + let mut builder = ReadRowBuilder(Vec::new()); + conn .execute_program( Program::seq(&["select count(*) from test"]), - ReadRowBuilder(Vec::new()), + &mut builder ) .unwrap(); - assert_eq!(res.0[0], Value::Integer(5)); + assert_eq!(builder.0[0], Value::Integer(5)); } #[test] @@ -248,7 +264,8 @@ mod test { let mut replica = LibsqlDatabase::new( temp_replica.path().to_path_buf(), ReplicaType { - injector: Injector::new(temp_replica.path(), (), 10).unwrap(), + commit_handler: 
Some(Box::new(())), + injector_buffer_capacity: 10, }, ); @@ -256,27 +273,29 @@ mod test { primary_conn .execute_program( Program::seq(&["create table test (x)", "insert into test values (42)"]), - (), + &mut (), ) .unwrap(); let logfile = primary.ty.logger.log_file.read(); + let mut injector = replica.injector().unwrap(); for frame in logfile.frames_iter().unwrap() { let frame = frame.unwrap(); - replica.inject_frame(frame).unwrap(); + injector.inject(frame).unwrap(); } let mut replica_conn = replica.connect().unwrap(); - let result = replica_conn + let mut builder = ReadRowBuilder(Vec::new()); + replica_conn .execute_program( Program::seq(&["select * from test limit 1"]), - ReadRowBuilder(Vec::new()), + &mut builder ) .unwrap(); - assert_eq!(result.0.len(), 1); - assert_eq!(result.0[0], Value::Integer(42)); + assert_eq!(builder.0.len(), 1); + assert_eq!(builder.0[0], Value::Integer(42)); } #[test] @@ -311,7 +330,7 @@ mod test { let mut conn = db.connect().unwrap(); conn.execute_program( Program::seq(&["create table test (x)", "insert into test values (12)"]), - (), + &mut (), ) .unwrap(); assert!(compactor_called.get()); @@ -354,12 +373,12 @@ mod test { "create table test (x)", "insert into test values (12)", ]), - (), + &mut (), ) .unwrap(); conn.inner_connection().cache_flush().unwrap(); assert!(!compactor_called.get()); - conn.execute_program(Program::seq(&["commit"]), ()).unwrap(); + conn.execute_program(Program::seq(&["commit"]), &mut ()).unwrap(); assert!(compactor_called.get()); } } diff --git a/libsqlx/src/database/mod.rs b/libsqlx/src/database/mod.rs index edfb0b8e..fa1ce874 100644 --- a/libsqlx/src/database/mod.rs +++ b/libsqlx/src/database/mod.rs @@ -1,11 +1,7 @@ use std::time::Duration; -use crate::connection::{Connection, DescribeResponse}; +use crate::connection::Connection; use crate::error::Error; -use crate::program::Program; -use crate::result_builder::ResultBuilder; -use crate::semaphore::{Permit, Semaphore}; - use self::frame::Frame; mod 
frame; @@ -25,111 +21,13 @@ pub trait Database { type Connection: Connection; /// Create a new connection to the database fn connect(&self) -> Result; - - /// returns a database with a limit on the number of conccurent connections - fn throttled(self, limit: usize, timeout: Option) -> ThrottledDatabase - where - Self: Sized, - { - ThrottledDatabase::new(limit, self, timeout) - } } -// Trait implemented by databases that support frame injection pub trait InjectableDatabase { - fn inject_frame(&mut self, frame: Frame) -> Result<(), InjectError>; -} - -/// A Database that limits the number of conccurent connections to the underlying database. -pub struct ThrottledDatabase { - semaphore: Semaphore, - db: T, - timeout: Option, -} - -impl ThrottledDatabase { - fn new(conccurency: usize, db: T, timeout: Option) -> Self { - Self { - semaphore: Semaphore::new(conccurency), - db, - timeout, - } - } -} - -impl Database for ThrottledDatabase { - type Connection = TrackedDb; - - fn connect(&self) -> Result { - let permit = match self.timeout { - Some(t) => self - .semaphore - .acquire_timeout(t) - .ok_or(Error::DbCreateTimeout)?, - None => self.semaphore.acquire(), - }; - - let inner = self.db.connect()?; - Ok(TrackedDb { permit, inner }) - } -} - -pub struct TrackedDb { - inner: DB, - #[allow(dead_code)] // just hold on to it - permit: Permit, -} - -impl Connection for TrackedDb { - #[inline] - fn execute_program(&mut self, pgm: Program, builder: B) -> crate::Result { - self.inner.execute_program(pgm, builder) - } - - #[inline] - fn describe(&self, sql: String) -> crate::Result { - self.inner.describe(sql) - } + fn injector(&mut self) -> crate::Result>; } -#[cfg(test)] -mod test { - use super::*; - - struct DummyConn; - - impl Connection for DummyConn { - fn execute_program(&mut self, _pgm: Program, _builder: B) -> crate::Result - where - B: ResultBuilder, - { - unreachable!() - } - - fn describe(&self, _sql: String) -> crate::Result { - unreachable!() - } - } - - struct 
DummyDatabase; - - impl Database for DummyDatabase { - type Connection = DummyConn; - - fn connect(&self) -> Result { - Ok(DummyConn) - } - } - - #[test] - fn throttle_db_creation() { - let db = DummyDatabase.throttled(1, Some(Duration::from_millis(100))); - let conn = db.connect().unwrap(); - - assert!(db.connect().is_err()); - - drop(conn); - - assert!(db.connect().is_ok()); - } +// Trait implemented by databases that support frame injection +pub trait Injector { + fn inject(&mut self, frame: Frame) -> Result<(), InjectError>; } diff --git a/libsqlx/src/database/proxy/connection.rs b/libsqlx/src/database/proxy/connection.rs index 6fec7d23..24c10a47 100644 --- a/libsqlx/src/database/proxy/connection.rs +++ b/libsqlx/src/database/proxy/connection.rs @@ -26,9 +26,9 @@ where ReadDb: Connection, WriteDb: Connection, { - fn execute_program(&mut self, pgm: Program, builder: B) -> Result { + fn execute_program(&mut self, pgm: Program, builder: &mut dyn ResultBuilder) -> Result<()> { let mut state = self.state.lock(); - let builder = ExtractFrameNoBuilder::new(builder); + let mut builder = ExtractFrameNoBuilder::new(builder); if !state.is_txn && pgm.is_read_only() { if let Some(frame_no) = state.last_frame_no { (self.wait_frame_no_cb)(frame_no); @@ -36,24 +36,24 @@ where // We know that this program won't perform any writes. We attempt to run it on the // replica. If it leaves an open transaction, then this program is an interactive // transaction, so we rollback the replica, and execute again on the primary. 
- let builder = self.read_db.execute_program(pgm.clone(), builder)?; + self.read_db.execute_program(pgm.clone(), &mut builder)?; // still in transaction state after running a read-only txn if builder.is_txn { // TODO: rollback // self.read_db.rollback().await?; - let builder = self.write_db.execute_program(pgm, builder)?; + self.write_db.execute_program(pgm, &mut builder)?; state.is_txn = builder.is_txn; state.last_frame_no = builder.frame_no; - Ok(builder.inner) + Ok(()) } else { - Ok(builder.inner) + Ok(()) } } else { - let builder = self.write_db.execute_program(pgm, builder)?; + self.write_db.execute_program(pgm, &mut builder)?; state.is_txn = builder.is_txn; state.last_frame_no = builder.frame_no; - Ok(builder.inner) + Ok(()) } } @@ -65,14 +65,14 @@ where } } -struct ExtractFrameNoBuilder { - inner: B, +struct ExtractFrameNoBuilder<'a> { + inner: &'a mut dyn ResultBuilder, frame_no: Option, is_txn: bool, } -impl ExtractFrameNoBuilder { - fn new(inner: B) -> Self { +impl<'a> ExtractFrameNoBuilder<'a> { + fn new(inner: &'a mut dyn ResultBuilder) -> Self { Self { inner, frame_no: None, @@ -81,7 +81,7 @@ impl ExtractFrameNoBuilder { } } -impl ResultBuilder for ExtractFrameNoBuilder { +impl<'a> ResultBuilder for ExtractFrameNoBuilder<'a> { fn init( &mut self, config: &QueryBuilderConfig, @@ -206,14 +206,14 @@ mod test { ); let mut conn = db.connect().unwrap(); - conn.execute_program(Program::seq(&["insert into test values (12)"]), ()) + conn.execute_program(Program::seq(&["insert into test values (12)"]), &mut ()) .unwrap(); assert!(!wait_called.get()); assert!(!read_called.get()); assert!(write_called.get()); - conn.execute_program(Program::seq(&["select * from test"]), ()) + conn.execute_program(Program::seq(&["select * from test"]), &mut ()) .unwrap(); assert!(read_called.get()); diff --git a/libsqlx/src/database/proxy/database.rs b/libsqlx/src/database/proxy/database.rs index fedd0ef7..129cc5e2 100644 --- a/libsqlx/src/database/proxy/database.rs +++ 
b/libsqlx/src/database/proxy/database.rs @@ -1,4 +1,3 @@ -use crate::database::frame::Frame; use crate::database::{Database, InjectableDatabase}; use crate::error::Error; @@ -27,7 +26,6 @@ where WDB: Database, { type Connection = WriteProxyConnection; - /// Create a new connection to the database fn connect(&self) -> Result { Ok(WriteProxyConnection { @@ -43,8 +41,7 @@ impl InjectableDatabase for WriteProxyDatabase where RDB: InjectableDatabase, { - fn inject_frame(&mut self, frame: Frame) -> Result<(), crate::database::InjectError> { - // TODO: handle frame index - self.read_db.inject_frame(frame) + fn injector(&mut self) -> crate::Result> { + self.read_db.injector() } } diff --git a/libsqlx/src/database/test_utils.rs b/libsqlx/src/database/test_utils.rs index 86c072ea..a46aa2ac 100644 --- a/libsqlx/src/database/test_utils.rs +++ b/libsqlx/src/database/test_utils.rs @@ -51,13 +51,13 @@ impl Database for MockDatabase { } impl Connection for MockConnection { - fn execute_program( + fn execute_program( &mut self, pgm: crate::program::Program, - mut reponse_builder: B, - ) -> crate::Result { - (self.execute_fn)(pgm, &mut reponse_builder)?; - Ok(reponse_builder) + reponse_builder: &mut dyn ResultBuilder, + ) -> crate::Result<()> { + (self.execute_fn)(pgm, reponse_builder)?; + Ok(()) } fn describe(&self, sql: String) -> crate::Result { diff --git a/libsqlx/src/lib.rs b/libsqlx/src/lib.rs index 986044e2..a6e3c3a2 100644 --- a/libsqlx/src/lib.rs +++ b/libsqlx/src/lib.rs @@ -7,7 +7,6 @@ mod database; mod program; mod result_builder; mod seal; -mod semaphore; pub type Result = std::result::Result; diff --git a/libsqlx/src/result_builder.rs b/libsqlx/src/result_builder.rs index 2784274c..ae299b1e 100644 --- a/libsqlx/src/result_builder.rs +++ b/libsqlx/src/result_builder.rs @@ -80,7 +80,7 @@ pub struct QueryBuilderConfig { pub max_size: Option, } -pub trait ResultBuilder: Send + 'static { +pub trait ResultBuilder { /// (Re)initialize the builder. 
This method can be called multiple times. fn init(&mut self, _config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { Ok(()) diff --git a/libsqlx/src/semaphore.rs b/libsqlx/src/semaphore.rs deleted file mode 100644 index a47a4eb1..00000000 --- a/libsqlx/src/semaphore.rs +++ /dev/null @@ -1,98 +0,0 @@ -use std::sync::Arc; -use std::time::Duration; -use std::time::Instant; - -use parking_lot::Condvar; -use parking_lot::Mutex; - -struct SemaphoreInner { - max_permits: usize, - permits: Mutex, - condvar: Condvar, -} - -#[derive(Clone)] -pub struct Semaphore { - inner: Arc, -} - -pub struct Permit(Semaphore); - -impl Drop for Permit { - fn drop(&mut self) { - *self.0.inner.permits.lock() -= 1; - self.0.inner.condvar.notify_one(); - } -} - -impl Semaphore { - pub fn new(max_permits: usize) -> Self { - Self { - inner: Arc::new(SemaphoreInner { - max_permits, - permits: Mutex::new(0), - condvar: Condvar::new(), - }), - } - } - - pub fn acquire(&self) -> Permit { - let mut permits = self.inner.permits.lock(); - self.inner - .condvar - .wait_while(&mut permits, |permits| *permits >= self.inner.max_permits); - *permits += 1; - assert!(*permits <= self.inner.max_permits); - Permit(self.clone()) - } - - pub fn acquire_timeout(&self, timeout: Duration) -> Option { - let deadline = Instant::now() + timeout; - let mut permits = self.inner.permits.lock(); - if self - .inner - .condvar - .wait_while_until( - &mut permits, - |permits| *permits >= self.inner.max_permits, - deadline, - ) - .timed_out() - { - return None; - } - - *permits += 1; - assert!(*permits <= self.inner.max_permits); - Some(Permit(self.clone())) - } - - #[cfg(test)] - fn try_acquire(&self) -> Option { - let mut permits = self.inner.permits.lock(); - if *permits >= self.inner.max_permits { - None - } else { - *permits += 1; - Some(Permit(self.clone())) - } - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn semaphore() { - let sem = Semaphore::new(2); - let permit1 = sem.acquire(); - 
let _permit2 = sem.acquire(); - - assert!(sem.try_acquire().is_none()); - drop(permit1); - let perm = sem.try_acquire(); - assert!(perm.is_some()); - assert!(sem.acquire_timeout(Duration::from_millis(100)).is_none()); - } -} From 19c8f4fecfafbad6b7af7cb13d323ca87bcdaf62 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Fri, 7 Jul 2023 16:48:59 +0200 Subject: [PATCH 6/8] bare bone allocation manager/store --- Cargo.lock | 317 +++++++++++++++++- libsqlx-server/Cargo.toml | 7 +- libsqlx-server/src/allocation/config.rs | 24 +- libsqlx-server/src/allocation/mod.rs | 71 ++-- libsqlx-server/src/databases/mod.rs | 5 + libsqlx-server/src/databases/store.rs | 12 + libsqlx-server/src/http/admin.rs | 69 +++- libsqlx-server/src/main.rs | 15 + libsqlx-server/src/manager.rs | 50 +++ libsqlx-server/src/meta.rs | 53 ++- libsqlx/src/connection.rs | 61 +++- libsqlx/src/database/libsql/connection.rs | 6 +- libsqlx/src/database/libsql/mod.rs | 91 ++--- .../database/libsql/replication_log/logger.rs | 22 +- libsqlx/src/database/mod.rs | 2 +- 15 files changed, 712 insertions(+), 93 deletions(-) create mode 100644 libsqlx-server/src/databases/mod.rs create mode 100644 libsqlx-server/src/databases/store.rs create mode 100644 libsqlx-server/src/manager.rs diff --git a/Cargo.lock b/Cargo.lock index 63b0f30a..0864a55e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -177,6 +177,26 @@ dependencies = [ "tokio", ] +[[package]] +name = "async-io" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fc5b45d93ef0529756f812ca52e44c221b35341892d3dcc34132ac02f3dd2af" +dependencies = [ + "async-lock", + "autocfg", + "cfg-if", + "concurrent-queue", + "futures-lite", + "log", + "parking", + "polling", + "rustix 0.37.19", + "slab", + "socket2", + "waker-fn", +] + [[package]] name = "async-lock" version = "2.7.0" @@ -819,6 +839,12 @@ version = "3.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1" +[[package]] +name = "bytecount" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c676a478f63e9fa2dd5368a42f28bba0d6c560b775f38583c8bbaa7fcd67c9c" + [[package]] name = "bytemuck" version = "1.13.1" @@ -876,6 +902,15 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38fcc2979eff34a4b84e1cf9a1e3da42a7d44b3b690a40cdcb23e3d556cfb2e5" +[[package]] +name = "camino" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c530edf18f37068ac2d977409ed5cd50d53d73bc653c7647b48eb78976ac9ae2" +dependencies = [ + "serde", +] + [[package]] name = "cap-fs-ext" version = "0.26.1" @@ -941,6 +976,28 @@ dependencies = [ "winx", ] +[[package]] +name = "cargo-platform" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cbdb825da8a5df079a43676dbe042702f1707b1109f713a01420fbb4cc71fa27" +dependencies = [ + "serde", +] + +[[package]] +name = "cargo_metadata" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4acbb09d9ee8e23699b9634375c72795d095bf268439da88562cf9b501f181fa" +dependencies = [ + "camino", + "cargo-platform", + "semver", + "serde", + "serde_json", +] + [[package]] name = "cc" version = "1.0.79" @@ -1066,6 +1123,15 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" +[[package]] +name = "concurrent-queue" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62ec6771ecfa0762d24683ee5a32ad78487a3d3afdc0fb8cae19d2c5deb50b7c" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "console" version = "0.15.7" @@ -1558,6 +1624,15 @@ dependencies = [ "libc", ] +[[package]] +name = "error-chain" +version = "0.12.4" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d2f06b9cac1506ece98fe3231e3cc9c4410ec3d5b1f24ae1c8946f0742cdefc" +dependencies = [ + "version_check", +] + [[package]] name = "event-listener" version = "2.5.3" @@ -1691,6 +1766,16 @@ dependencies = [ "windows-sys 0.36.1", ] +[[package]] +name = "fs2" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9564fc758e15025b46aa6643b1b77d047d1a56a1aea6e01002ac0c7026876213" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "futures" version = "0.3.28" @@ -1739,6 +1824,21 @@ version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" +[[package]] +name = "futures-lite" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49a9d51ce47660b1e808d3c990b4709f2f415d928835a17dfd16991515c46bce" +dependencies = [ + "fastrand", + "futures-core", + "futures-io", + "memchr", + "parking", + "pin-project-lite", + "waker-fn", +] + [[package]] name = "futures-macro" version = "0.3.28" @@ -2448,7 +2548,7 @@ dependencies = [ "itertools 0.11.0", "nix", "once_cell", - "parking_lot", + "parking_lot 0.12.1", "rand", "regex", "rusqlite", @@ -2468,12 +2568,15 @@ name = "libsqlx-server" version = "0.1.0" dependencies = [ "axum", + "bincode", "clap", "color-eyre", "futures", "hyper", "libsqlx", + "moka", "serde", + "sled", "tokio", "tracing", "tracing-subscriber", @@ -2526,6 +2629,15 @@ dependencies = [ "libc", ] +[[package]] +name = "mach2" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d0d1830bcd151a6fc4aea1369af235b36c1528fe976b8ff678683c9995eade8" +dependencies = [ + "libc", +] + [[package]] name = "maplit" version = "1.0.2" @@ -2656,6 +2768,31 @@ dependencies = [ "windows-sys 0.45.0", ] +[[package]] +name = "moka" +version = "0.11.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "206bf83f415b0579fd885fe0804eb828e727636657dc1bf73d80d2f1218e14a1" +dependencies = [ + "async-io", + "async-lock", + "crossbeam-channel", + "crossbeam-epoch", + "crossbeam-utils", + "futures-util", + "once_cell", + "parking_lot 0.12.1", + "quanta", + "rustc_version", + "scheduled-thread-pool", + "skeptic", + "smallvec", + "tagptr", + "thiserror", + "triomphe", + "uuid", +] + [[package]] name = "multimap" version = "0.8.3" @@ -2879,6 +3016,23 @@ version = "3.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c1b04fb49957986fdce4d6ee7a65027d55d4b6d2265e5848bbb507b58ccfdb6f" +[[package]] +name = "parking" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14f2252c834a40ed9bb5422029649578e63aa341ac401f74e719dd1afda8394e" + +[[package]] +name = "parking_lot" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" +dependencies = [ + "instant", + "lock_api", + "parking_lot_core 0.8.6", +] + [[package]] name = "parking_lot" version = "0.12.1" @@ -2886,7 +3040,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" dependencies = [ "lock_api", - "parking_lot_core", + "parking_lot_core 0.9.7", +] + +[[package]] +name = "parking_lot_core" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a2cfe6f0ad2bfc16aefa463b497d5c7a5ecd44a23efa72aa342d90177356dc" +dependencies = [ + "cfg-if", + "instant", + "libc", + "redox_syscall 0.2.16", + "smallvec", + "winapi", ] [[package]] @@ -3072,6 +3240,22 @@ dependencies = [ "serde_json", ] +[[package]] +name = "polling" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"4b2d323e8ca7996b3e23126511a523f7e62924d93ecd5ae73b333815b0eb3dce" +dependencies = [ + "autocfg", + "bitflags 1.3.2", + "cfg-if", + "concurrent-queue", + "libc", + "log", + "pin-project-lite", + "windows-sys 0.48.0", +] + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -3200,6 +3384,33 @@ dependencies = [ "cc", ] +[[package]] +name = "pulldown-cmark" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a1a2f1f0a7ecff9c31abbe177637be0e97a0aef46cf8738ece09327985d998" +dependencies = [ + "bitflags 1.3.2", + "memchr", + "unicase", +] + +[[package]] +name = "quanta" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a17e662a7a8291a865152364c20c7abc5e60486ab2001e8ec10b24862de0b9ab" +dependencies = [ + "crossbeam-utils", + "libc", + "mach2", + "once_cell", + "raw-cpuid", + "wasi 0.11.0+wasi-snapshot-preview1", + "web-sys", + "winapi", +] + [[package]] name = "quick-error" version = "1.2.3" @@ -3254,6 +3465,15 @@ dependencies = [ "rand_core", ] +[[package]] +name = "raw-cpuid" +version = "10.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332" +dependencies = [ + "bitflags 1.3.2", +] + [[package]] name = "rayon" version = "1.7.0" @@ -3558,6 +3778,15 @@ version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "schannel" version = "0.1.21" @@ -3567,6 +3796,15 @@ dependencies = [ "windows-sys 0.42.0", ] +[[package]] +name = "scheduled-thread-pool" +version = "0.2.7" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "3cbc66816425a074528352f5789333ecff06ca41b36b0b0efdfbb29edc391a19" +dependencies = [ + "parking_lot 0.12.1", +] + [[package]] name = "scopeguard" version = "1.1.0" @@ -3622,6 +3860,9 @@ name = "semver" version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bebd363326d05ec3e2f532ab7660680f3b02130d780c299bca73469d521bc0ed" +dependencies = [ + "serde", +] [[package]] name = "serde" @@ -3765,6 +4006,21 @@ version = "0.3.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7bd3e3206899af3f8b12af284fafc038cc1dc2b41d1b89dd17297221c5d225de" +[[package]] +name = "skeptic" +version = "0.13.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16d23b015676c90a0f01c197bfdc786c20342c73a0afdda9025adb0bc42940a8" +dependencies = [ + "bytecount", + "cargo_metadata", + "error-chain", + "glob", + "pulldown-cmark", + "tempfile", + "walkdir", +] + [[package]] name = "slab" version = "0.4.8" @@ -3774,6 +4030,22 @@ dependencies = [ "autocfg", ] +[[package]] +name = "sled" +version = "0.34.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f96b4737c2ce5987354855aed3797279def4ebf734436c6aa4552cf8e169935" +dependencies = [ + "crc32fast", + "crossbeam-epoch", + "crossbeam-utils", + "fs2", + "fxhash", + "libc", + "log", + "parking_lot 0.11.2", +] + [[package]] name = "slice-group-by" version = "0.3.1" @@ -3855,7 +4127,7 @@ dependencies = [ "mimalloc", "nix", "once_cell", - "parking_lot", + "parking_lot 0.12.1", "priority-queue", "proptest", "prost", @@ -4002,6 +4274,12 @@ dependencies = [ "winx", ] +[[package]] +name = "tagptr" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" + [[package]] name = "tar" version = "0.4.38" @@ -4149,7 +4427,7 @@ dependencies = [ "libc", "mio", "num_cpus", 
- "parking_lot", + "parking_lot 0.12.1", "pin-project-lite", "signal-hook-registry", "socket2", @@ -4468,6 +4746,12 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "triomphe" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee8098afad3fb0c54a9007aab6804558410503ad676d4633f9c2559a00ac0f" + [[package]] name = "try-lock" version = "0.2.4" @@ -4514,6 +4798,15 @@ dependencies = [ "version_check", ] +[[package]] +name = "unicase" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50f37be617794602aabbeee0be4f259dc1778fabe05e2d67ee8f79326d5cb4f6" +dependencies = [ + "version_check", +] + [[package]] name = "unicode-bidi" version = "0.3.13" @@ -4637,6 +4930,22 @@ dependencies = [ "libc", ] +[[package]] +name = "waker-fn" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d5b2c62b4012a3e1eca5a7e077d13b3bf498c4073e33ccd58626607748ceeca" + +[[package]] +name = "walkdir" +version = "2.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36df944cda56c7d8d8b7496af378e6b16de9284591917d307c9b4d313c44e698" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "want" version = "0.3.0" diff --git a/libsqlx-server/Cargo.toml b/libsqlx-server/Cargo.toml index 90f2ca0b..26eb60ff 100644 --- a/libsqlx-server/Cargo.toml +++ b/libsqlx-server/Cargo.toml @@ -7,13 +7,16 @@ edition = "2021" [dependencies] axum = "0.6.18" +bincode = "1.3.3" clap = { version = "4.3.11", features = ["derive"] } color-eyre = "0.6.2" futures = "0.3.28" -hyper = { version = "0.14.27", features = ["h2"] } +hyper = { version = "0.14.27", features = ["h2", "server"] } libsqlx = { version = "0.1.0", path = "../libsqlx" } +moka = { version = "0.11.2", features = ["future"] } serde = { version = "1.0.166", features = ["derive"] } +sled = "0.34.7" tokio = { version = "1.29.1", features = ["full"] } tracing = "0.1.37" 
-tracing-subscriber = "0.3.17" +tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } uuid = { version = "1.4.0", features = ["v4"] } diff --git a/libsqlx-server/src/allocation/config.rs b/libsqlx-server/src/allocation/config.rs index 19a6396b..f5839e9c 100644 --- a/libsqlx-server/src/allocation/config.rs +++ b/libsqlx-server/src/allocation/config.rs @@ -1,9 +1,21 @@ -use serde::{Serialize, Deserialize}; +use serde::{Deserialize, Serialize}; +/// Structural supertype of AllocConfig, used for checking the meta version. Subsequent version of +/// AllocConfig need to conform to this prototype. #[derive(Debug, Serialize, Deserialize)] -pub enum AllocConfig { - Primary { }, - Replica { - primary_node_id: String, - } +struct ConfigVersion { + config_version: u32, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct AllocConfig { + pub max_conccurent_connection: u32, + pub id: String, + pub db_config: DbConfig, +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum DbConfig { + Primary {}, + Replica { primary_node_id: String }, } diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index 2fa9a4cd..21e4c97c 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -1,18 +1,24 @@ use std::collections::HashMap; +use std::path::PathBuf; -use tokio::{sync::{mpsc, oneshot}, task::{JoinSet, block_in_place}}; +use libsqlx::libsql::{LibsqlDatabase, LogCompactor, LogFile, PrimaryType}; +use libsqlx::Database as _; +use tokio::sync::{mpsc, oneshot}; +use tokio::task::{block_in_place, JoinSet}; + +use self::config::{AllocConfig, DbConfig}; pub mod config; type ExecFn = Box; #[derive(Clone)] -struct ConnectionId { +pub struct ConnectionId { id: u32, close_sender: mpsc::Sender<()>, } -enum AllocationMessage { +pub enum AllocationMessage { /// Execute callback against connection Exec { connection_id: ConnectionId, @@ -22,30 +28,61 @@ enum AllocationMessage { NewConnExec { exec: ExecFn, 
ret: oneshot::Sender, - } + }, +} + +pub enum Database { + Primary(libsqlx::libsql::LibsqlDatabase), } -enum Database {} +struct Compactor; + +impl LogCompactor for Compactor { + fn should_compact(&self, _log: &LogFile) -> bool { + false + } + + fn compact( + &self, + _log: LogFile, + _path: std::path::PathBuf, + _size_after: u32, + ) -> Result<(), Box> { + todo!() + } +} impl Database { + pub fn from_config(config: &AllocConfig, path: PathBuf) -> Self { + match config.db_config { + DbConfig::Primary {} => { + let db = LibsqlDatabase::new_primary(path, Compactor, false).unwrap(); + Self::Primary(db) + } + DbConfig::Replica { .. } => todo!(), + } + } + fn connect(&self) -> Box { - todo!(); + match self { + Database::Primary(db) => Box::new(db.connect().unwrap()), + } } } pub struct Allocation { - inbox: mpsc::Receiver, - database: Database, + pub inbox: mpsc::Receiver, + pub database: Database, /// senders to the spawned connections - connections: HashMap>, + pub connections: HashMap>, /// spawned connection futures, returning their connection id on completion. - connections_futs: JoinSet, - next_conn_id: u32, - max_concurrent_connections: u32, + pub connections_futs: JoinSet, + pub next_conn_id: u32, + pub max_concurrent_connections: u32, } impl Allocation { - async fn run(mut self) { + pub async fn run(mut self) { loop { tokio::select! { Some(msg) = self.inbox.recv() => { @@ -86,23 +123,19 @@ impl Allocation { exec: exec_receiver, }; - self.connections_futs.spawn(conn.run()); // This should never block! 
assert!(exec_sender.try_send(exec).is_ok()); assert!(self.connections.insert(id, exec_sender).is_none()); - ConnectionId { - id, - close_sender, - } + ConnectionId { id, close_sender } } fn next_conn_id(&mut self) -> u32 { loop { self.next_conn_id = self.next_conn_id.wrapping_add(1); if !self.connections.contains_key(&self.next_conn_id) { - return self.next_conn_id + return self.next_conn_id; } } } diff --git a/libsqlx-server/src/databases/mod.rs b/libsqlx-server/src/databases/mod.rs new file mode 100644 index 00000000..0494174b --- /dev/null +++ b/libsqlx-server/src/databases/mod.rs @@ -0,0 +1,5 @@ +use uuid::Uuid; + +mod store; + +pub type DatabaseId = Uuid; diff --git a/libsqlx-server/src/databases/store.rs b/libsqlx-server/src/databases/store.rs new file mode 100644 index 00000000..206beb34 --- /dev/null +++ b/libsqlx-server/src/databases/store.rs @@ -0,0 +1,12 @@ +use std::collections::HashMap; + +use super::DatabaseId; + +pub enum Database { + Replica, + Primary, +} + +pub struct DatabaseManager { + databases: HashMap, +} diff --git a/libsqlx-server/src/http/admin.rs b/libsqlx-server/src/http/admin.rs index 2d9c8054..51ba1b7f 100644 --- a/libsqlx-server/src/http/admin.rs +++ b/libsqlx-server/src/http/admin.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::{path::PathBuf, sync::Arc}; use axum::{extract::State, routing::post, Json, Router}; use color_eyre::eyre::Result; @@ -6,24 +6,30 @@ use hyper::server::accept::Accept; use serde::{Deserialize, Serialize}; use tokio::io::{AsyncRead, AsyncWrite}; -use crate::{meta::MetaStore, allocation::config::AllocConfig}; +use crate::{ + allocation::config::{AllocConfig, DbConfig}, + meta::Store, +}; -pub struct AdminServerConfig {} +pub struct AdminServerConfig { + pub db_path: PathBuf, +} struct AdminServerState { - meta_store: Arc, + meta_store: Arc, } -pub async fn run_admin_server(_config: AdminServerConfig, listener: I) -> Result<()> +pub async fn run_admin_server(config: AdminServerConfig, listener: I) -> Result<()> 
where I: Accept, I::Conn: AsyncRead + AsyncWrite + Send + Unpin + 'static, { let state = AdminServerState { - meta_store: todo!(), + meta_store: Arc::new(Store::new(&config.db_path)), }; + let app = Router::new() - .route("/manage/allocation/create", post(allocate)) + .route("/manage/allocation", post(allocate).get(list_allocs)) .with_state(Arc::new(state)); axum::Server::builder(listener) .serve(app.into_make_service()) @@ -36,18 +42,59 @@ where struct ErrorResponse {} #[derive(Serialize, Debug)] -struct AllocateResp { } +struct AllocateResp {} #[derive(Deserialize, Debug)] struct AllocateReq { alloc_id: String, - config: AllocConfig, + max_conccurent_connection: Option, + config: DbConfigReq, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum DbConfigReq { + Primary { }, + Replica { primary_node_id: String }, } async fn allocate( State(state): State>, Json(req): Json, ) -> Result, Json> { - state.meta_store.allocate(&req.alloc_id, &req.config).await; - todo!(); + let config = AllocConfig { + max_conccurent_connection: req.max_conccurent_connection.unwrap_or(16), + id: req.alloc_id.clone(), + db_config: match req.config { + DbConfigReq::Primary { } => DbConfig::Primary { }, + DbConfigReq::Replica { primary_node_id } => DbConfig::Replica { primary_node_id }, + }, + }; + state.meta_store.allocate(&req.alloc_id, &config).await; + + Ok(Json(AllocateResp {})) +} + +#[derive(Serialize, Debug)] +struct ListAllocResp { + allocs: Vec, +} + +#[derive(Serialize, Debug)] +struct AllocView { + id: String, +} + +async fn list_allocs( + State(state): State>, +) -> Result, Json> { + let allocs = state + .meta_store + .list_allocs() + .await + .into_iter() + .map(|cfg| AllocView { id: cfg.id }) + .collect(); + + Ok(Json(ListAllocResp { allocs })) } diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs index 1ee047bf..fb213397 100644 --- a/libsqlx-server/src/main.rs +++ b/libsqlx-server/src/main.rs @@ -1,16 
+1,31 @@ +use std::path::PathBuf; + use color_eyre::eyre::Result; +use http::admin::{run_admin_server, AdminServerConfig}; +use hyper::server::conn::AddrIncoming; use tracing::metadata::LevelFilter; use tracing_subscriber::prelude::*; mod allocation; mod databases; mod http; +mod manager; mod meta; #[tokio::main] async fn main() -> Result<()> { init(); + let admin_api_listener = tokio::net::TcpListener::bind("0.0.0.0:3456").await?; + run_admin_server( + AdminServerConfig { + db_path: PathBuf::from("database"), + }, + AddrIncoming::from_listener(admin_api_listener)?, + ) + .await + .unwrap(); + Ok(()) } diff --git a/libsqlx-server/src/manager.rs b/libsqlx-server/src/manager.rs new file mode 100644 index 00000000..8d7737a7 --- /dev/null +++ b/libsqlx-server/src/manager.rs @@ -0,0 +1,50 @@ +use std::collections::HashMap; +use std::path::PathBuf; +use std::sync::Arc; + +use moka::future::Cache; +use tokio::sync::mpsc; +use tokio::task::JoinSet; + +use crate::allocation::{Allocation, AllocationMessage, Database}; +use crate::meta::Store; + +pub struct Manager { + cache: Cache>, + meta_store: Arc, + db_path: PathBuf, +} + +const MAX_ALLOC_MESSAGE_QUEUE_LEN: usize = 32; + +impl Manager { + pub async fn alloc(&self, alloc_id: &str) -> mpsc::Sender { + if let Some(sender) = self.cache.get(alloc_id) { + return sender.clone(); + } + + if let Some(config) = self.meta_store.meta(alloc_id).await { + let path = self.db_path.join("dbs").join(alloc_id); + tokio::fs::create_dir_all(&path).await.unwrap(); + let (alloc_sender, inbox) = mpsc::channel(MAX_ALLOC_MESSAGE_QUEUE_LEN); + let alloc = Allocation { + inbox, + database: Database::from_config(&config, path), + connections: HashMap::new(), + connections_futs: JoinSet::new(), + next_conn_id: 0, + max_concurrent_connections: config.max_conccurent_connection, + }; + + tokio::spawn(alloc.run()); + + self.cache + .insert(alloc_id.to_string(), alloc_sender.clone()) + .await; + + return alloc_sender; + } + + todo!("alloc doesn't exist") 
+ } +} diff --git a/libsqlx-server/src/meta.rs b/libsqlx-server/src/meta.rs index 7f48a456..475a0250 100644 --- a/libsqlx-server/src/meta.rs +++ b/libsqlx-server/src/meta.rs @@ -1,10 +1,57 @@ +use std::path::Path; + +use sled::Tree; use uuid::Uuid; use crate::allocation::config::AllocConfig; -pub struct MetaStore {} +type ExecFn = Box)>; + +pub struct Store { + meta_store: Tree, +} + +impl Store { + pub fn new(path: &Path) -> Self { + std::fs::create_dir_all(&path).unwrap(); + let path = path.join("store"); + let db = sled::open(path).unwrap(); + let meta_store = db.open_tree("meta_store").unwrap(); + + Self { meta_store } + } + + pub async fn allocate(&self, alloc_id: &str, meta: &AllocConfig) { + //TODO: Handle conflict + tokio::task::block_in_place(|| { + let meta_bytes = bincode::serialize(meta).unwrap(); + self.meta_store + .compare_and_swap(alloc_id, None as Option<&[u8]>, Some(meta_bytes)) + .unwrap() + .unwrap(); + }); + } -impl MetaStore { - pub async fn allocate(&self, alloc_id: &str, meta: &AllocConfig) {} pub async fn deallocate(&self, alloc_id: Uuid) {} + + pub async fn meta(&self, alloc_id: &str) -> Option { + tokio::task::block_in_place(|| { + let config = self.meta_store.get(alloc_id).unwrap()?; + let config = bincode::deserialize(config.as_ref()).unwrap(); + Some(config) + }) + } + + pub async fn list_allocs(&self) -> Vec { + tokio::task::block_in_place(|| { + let mut out = Vec::new(); + for kv in self.meta_store.iter() { + let (k, v) = kv.unwrap(); + let alloc = bincode::deserialize(&v).unwrap(); + out.push(alloc); + } + + out + }) + } } diff --git a/libsqlx/src/connection.rs b/libsqlx/src/connection.rs index cc4776c4..a5eb7e60 100644 --- a/libsqlx/src/connection.rs +++ b/libsqlx/src/connection.rs @@ -1,5 +1,9 @@ -use crate::program::Program; +use rusqlite::types::Value; + +use crate::program::{Program, Step}; +use crate::query::Query; use crate::result_builder::ResultBuilder; +use crate::QueryBuilderConfig; #[derive(Debug, Clone)] pub struct 
DescribeResponse { @@ -30,6 +34,61 @@ pub trait Connection { /// Parse the SQL statement and return information about it. fn describe(&self, sql: String) -> crate::Result; + + /// execute a single query + fn execute(&mut self, query: Query) -> crate::Result>> { + #[derive(Default)] + struct RowsBuilder { + error: Option, + rows: Vec>, + current_row: Vec, + } + + impl ResultBuilder for RowsBuilder { + fn init( + &mut self, + _config: &QueryBuilderConfig, + ) -> std::result::Result<(), crate::QueryResultBuilderError> { + self.error = None; + self.rows.clear(); + self.current_row.clear(); + + Ok(()) + } + + fn add_row_value( + &mut self, + v: rusqlite::types::ValueRef, + ) -> Result<(), crate::QueryResultBuilderError> { + self.current_row.push(v.into()); + Ok(()) + } + + fn finish_row(&mut self) -> Result<(), crate::QueryResultBuilderError> { + let row = std::mem::take(&mut self.current_row); + self.rows.push(row); + + Ok(()) + } + + fn step_error( + &mut self, + error: crate::error::Error, + ) -> Result<(), crate::QueryResultBuilderError> { + self.error.replace(error); + Ok(()) + } + } + + let pgm = Program::new(vec![Step { cond: None, query }]); + let mut builder = RowsBuilder::default(); + self.execute_program(pgm, &mut builder)?; + if let Some(err) = builder.error.take() { + Err(err) + } else { + Ok(builder.rows) + } + } } impl Connection for Box { diff --git a/libsqlx/src/database/libsql/connection.rs b/libsqlx/src/database/libsql/connection.rs index 8a1c8c55..0a2cb6b0 100644 --- a/libsqlx/src/database/libsql/connection.rs +++ b/libsqlx/src/database/libsql/connection.rs @@ -237,7 +237,11 @@ fn eval_cond(cond: &Cond, results: &[bool]) -> Result { } impl Connection for LibsqlConnection { - fn execute_program(&mut self, pgm: Program, builder: &mut dyn ResultBuilder) -> crate::Result<()> { + fn execute_program( + &mut self, + pgm: Program, + builder: &mut dyn ResultBuilder, + ) -> crate::Result<()> { self.run(pgm, builder) } diff --git 
a/libsqlx/src/database/libsql/mod.rs b/libsqlx/src/database/libsql/mod.rs index 27397663..41de3569 100644 --- a/libsqlx/src/database/libsql/mod.rs +++ b/libsqlx/src/database/libsql/mod.rs @@ -9,15 +9,16 @@ use crate::database::{Database, InjectError, InjectableDatabase}; use crate::error::Error; use crate::result_builder::QueryBuilderConfig; -use connection::{LibsqlConnection, RowStats}; +use connection::RowStats; use injector::Injector; use replication_log::logger::{ ReplicationLogger, ReplicationLoggerHook, ReplicationLoggerHookCtx, REPLICATION_METHODS, }; use self::injector::InjectorCommitHandler; -use self::replication_log::logger::LogCompactor; +pub use connection::LibsqlConnection; +pub use replication_log::logger::{LogCompactor, LogFile}; pub use replication_log::merger::SnapshotMerger; mod connection; @@ -67,6 +68,18 @@ pub trait LibsqlDbType { fn hook_context(&self) -> ::Context; } +pub struct PlainType; + +impl LibsqlDbType for PlainType { + type ConnectionHook = TransparentMethods; + + fn hook() -> &'static WalMethodsHook { + &TRANSPARENT_METHODS + } + + fn hook_context(&self) -> ::Context {} +} + /// A generic wrapper around a libsql database. /// `LibsqlDatabase` can be specialized into either a `ReplicaType` or a `PrimaryType`. /// In `PrimaryType` mode, the LibsqlDatabase maintains a replication log that can be replicated to @@ -112,6 +125,12 @@ impl LibsqlDatabase { } } +impl LibsqlDatabase { + pub fn new_plain(db_path: PathBuf) -> crate::Result { + Ok(Self::new(db_path, PlainType)) + } +} + impl LibsqlDatabase { pub fn new_primary( db_path: PathBuf, @@ -155,16 +174,18 @@ impl Database for LibsqlDatabase { type Connection = LibsqlConnection<::Context>; fn connect(&self) -> Result { - Ok(LibsqlConnection::<::Context>::new( - &self.db_path, - self.extensions.clone(), - T::hook(), - self.ty.hook_context(), - self.row_stats_callback.clone(), - QueryBuilderConfig { - max_size: Some(self.response_size_limit), - }, - )?) 
+ Ok( + LibsqlConnection::<::Context>::new( + &self.db_path, + self.extensions.clone(), + T::hook(), + self.ty.hook_context(), + self.row_stats_callback.clone(), + QueryBuilderConfig { + max_size: Some(self.response_size_limit), + }, + )?, + ) } } @@ -188,9 +209,9 @@ impl super::Injector for Injector { #[cfg(test)] mod test { - use std::cell::Cell; use std::fs::File; - use std::rc::Rc; + use std::sync::atomic::AtomicBool; + use std::sync::atomic::Ordering::Relaxed; use rusqlite::types::Value; @@ -224,11 +245,7 @@ mod test { let mut conn = db.connect().unwrap(); let mut builder = ReadRowBuilder(Vec::new()); - conn - .execute_program( - Program::seq(&["select count(*) from test"]), - &mut builder - ) + conn.execute_program(Program::seq(&["select count(*) from test"]), &mut builder) .unwrap(); assert!(builder.0.is_empty()); @@ -240,11 +257,7 @@ mod test { .for_each(|f| injector.inject(f.unwrap()).unwrap()); let mut builder = ReadRowBuilder(Vec::new()); - conn - .execute_program( - Program::seq(&["select count(*) from test"]), - &mut builder - ) + conn.execute_program(Program::seq(&["select count(*) from test"]), &mut builder) .unwrap(); assert_eq!(builder.0[0], Value::Integer(5)); } @@ -288,10 +301,7 @@ mod test { let mut replica_conn = replica.connect().unwrap(); let mut builder = ReadRowBuilder(Vec::new()); replica_conn - .execute_program( - Program::seq(&["select * from test limit 1"]), - &mut builder - ) + .execute_program(Program::seq(&["select * from test limit 1"]), &mut builder) .unwrap(); assert_eq!(builder.0.len(), 1); @@ -300,7 +310,7 @@ mod test { #[test] fn primary_compact_log() { - struct Compactor(Rc>); + struct Compactor(Arc); impl LogCompactor for Compactor { fn should_compact(&self, log: &LogFile) -> bool { @@ -312,14 +322,14 @@ mod test { _file: LogFile, _path: PathBuf, _size_after: u32, - ) -> anyhow::Result<()> { - self.0.set(true); + ) -> Result<(), Box> { + self.0.store(true, Relaxed); Ok(()) } } let temp = tempfile::tempdir().unwrap(); - let 
compactor_called = Rc::new(Cell::new(false)); + let compactor_called = Arc::new(AtomicBool::new(false)); let db = LibsqlDatabase::new_primary( temp.path().to_path_buf(), Compactor(compactor_called.clone()), @@ -333,17 +343,17 @@ mod test { &mut (), ) .unwrap(); - assert!(compactor_called.get()); + assert!(compactor_called.load(Relaxed)); } #[test] fn no_compaction_uncommited_frames() { - struct Compactor(Rc>); + struct Compactor(Arc); impl LogCompactor for Compactor { fn should_compact(&self, log: &LogFile) -> bool { assert_eq!(log.uncommitted_frame_count, 0); - self.0.set(true); + self.0.store(true, Relaxed); false } @@ -352,13 +362,13 @@ mod test { _file: LogFile, _path: PathBuf, _size_after: u32, - ) -> anyhow::Result<()> { + ) -> Result<(), Box> { unreachable!() } } let temp = tempfile::tempdir().unwrap(); - let compactor_called = Rc::new(Cell::new(false)); + let compactor_called = Arc::new(AtomicBool::new(false)); let db = LibsqlDatabase::new_primary( temp.path().to_path_buf(), Compactor(compactor_called.clone()), @@ -377,8 +387,9 @@ mod test { ) .unwrap(); conn.inner_connection().cache_flush().unwrap(); - assert!(!compactor_called.get()); - conn.execute_program(Program::seq(&["commit"]), &mut ()).unwrap(); - assert!(compactor_called.get()); + assert!(!compactor_called.load(Relaxed)); + conn.execute_program(Program::seq(&["commit"]), &mut ()) + .unwrap(); + assert!(compactor_called.load(Relaxed)); } } diff --git a/libsqlx/src/database/libsql/replication_log/logger.rs b/libsqlx/src/database/libsql/replication_log/logger.rs index 25546e36..aebff0db 100644 --- a/libsqlx/src/database/libsql/replication_log/logger.rs +++ b/libsqlx/src/database/libsql/replication_log/logger.rs @@ -602,7 +602,9 @@ impl LogFile { // swap old and new snapshot atomic_rename(&temp_log_path, path.join("wallog")).unwrap(); let old_log_file = std::mem::replace(self, new_log_file); - compactor.compact(old_log_file, temp_log_path, size_after)?; + compactor + .compact(old_log_file, 
temp_log_path, size_after) + .unwrap(); Ok(()) } @@ -717,17 +719,27 @@ impl Generation { } } -pub trait LogCompactor: 'static { +pub trait LogCompactor: Sync + Send + 'static { /// returns whether the passed log file should be compacted. If this method returns true, /// compact should be called next. fn should_compact(&self, log: &LogFile) -> bool; /// Compact the given snapshot - fn compact(&self, log: LogFile, path: PathBuf, size_after: u32) -> anyhow::Result<()>; + fn compact( + &self, + log: LogFile, + path: PathBuf, + size_after: u32, + ) -> Result<(), Box>; } #[cfg(test)] impl LogCompactor for () { - fn compact(&self, _file: LogFile, _path: PathBuf, _size_after: u32) -> anyhow::Result<()> { + fn compact( + &self, + _file: LogFile, + _path: PathBuf, + _size_after: u32, + ) -> Result<(), Box> { Ok(()) } @@ -739,7 +751,7 @@ impl LogCompactor for () { pub struct ReplicationLogger { pub generation: Generation, pub log_file: RwLock, - compactor: Box, + compactor: Box, db_path: PathBuf, /// a notifier channel other tasks can subscribe to, and get notified when new frames become /// available. 
diff --git a/libsqlx/src/database/mod.rs b/libsqlx/src/database/mod.rs index fa1ce874..62581402 100644 --- a/libsqlx/src/database/mod.rs +++ b/libsqlx/src/database/mod.rs @@ -1,8 +1,8 @@ use std::time::Duration; +use self::frame::Frame; use crate::connection::Connection; use crate::error::Error; -use self::frame::Frame; mod frame; pub mod libsql; From 52c566c4d86dca7de1c2e7f7d1ec350b7d7e366a Mon Sep 17 00:00:00 2001 From: ad hoc Date: Mon, 10 Jul 2023 09:56:41 +0200 Subject: [PATCH 7/8] user API & database extractor --- Cargo.lock | 34 ++++++++--- libsqlx-server/Cargo.toml | 2 + libsqlx-server/src/http/admin.rs | 12 ++-- libsqlx-server/src/http/mod.rs | 1 + libsqlx-server/src/http/user.rs | 101 +++++++++++++++++++++++++++++++ libsqlx-server/src/main.rs | 33 +++++++--- libsqlx-server/src/manager.rs | 17 ++++-- libsqlx-server/src/meta.rs | 4 +- 8 files changed, 174 insertions(+), 30 deletions(-) create mode 100644 libsqlx-server/src/http/user.rs diff --git a/Cargo.lock b/Cargo.lock index 0864a55e..b4364605 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2575,8 +2575,10 @@ dependencies = [ "hyper", "libsqlx", "moka", + "regex", "serde", "sled", + "thiserror", "tokio", "tracing", "tracing-subscriber", @@ -2650,7 +2652,7 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" dependencies = [ - "regex-automata", + "regex-automata 0.1.10", ] [[package]] @@ -3548,13 +3550,14 @@ dependencies = [ [[package]] name = "regex" -version = "1.8.4" +version = "1.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0ab3ca65655bb1e41f2a8c8cd662eb4fb035e67c3f78da1d61dffe89d07300f" +checksum = "b2eae68fc220f7cf2532e4494aded17545fce192d59cd996e0fe7887f4ceb575" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.7.2", + "regex-automata 0.3.2", + "regex-syntax 0.7.3", ] [[package]] @@ -3566,6 +3569,17 @@ dependencies = [ "regex-syntax 0.6.29", ] 
+[[package]] +name = "regex-automata" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83d3daa6976cffb758ec878f108ba0e062a45b2d6ca3a2cca965338855476caf" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax 0.7.3", +] + [[package]] name = "regex-syntax" version = "0.6.29" @@ -3574,9 +3588,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" [[package]] name = "regex-syntax" -version = "0.7.2" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "436b050e76ed2903236f032a59761c1eb99e1b0aead2c257922771dab1fc8c78" +checksum = "2ab07dc67230e4a4718e70fd5c20055a4334b121f1f9db8fe63ef39ce9b8c846" [[package]] name = "reqwest" @@ -4334,18 +4348,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.40" +version = "1.0.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" +checksum = "a35fc5b8971143ca348fa6df4f024d4d55264f3468c71ad1c2f365b0a4d58c42" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.40" +version = "1.0.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" +checksum = "463fe12d7993d3b327787537ce8dd4dfa058de32fc2b195ef3cde03dc4771e8f" dependencies = [ "proc-macro2", "quote", diff --git a/libsqlx-server/Cargo.toml b/libsqlx-server/Cargo.toml index 26eb60ff..4b4668ba 100644 --- a/libsqlx-server/Cargo.toml +++ b/libsqlx-server/Cargo.toml @@ -14,8 +14,10 @@ futures = "0.3.28" hyper = { version = "0.14.27", features = ["h2", "server"] } libsqlx = { version = "0.1.0", path = "../libsqlx" } moka = { version = "0.11.2", features = ["future"] } +regex = "1.9.1" serde = { version = "1.0.166", features = ["derive"] } sled = "0.34.7" +thiserror = "1.0.43" tokio = { version = "1.29.1", features = ["full"] } 
tracing = "0.1.37" tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } diff --git a/libsqlx-server/src/http/admin.rs b/libsqlx-server/src/http/admin.rs index 51ba1b7f..80e787d6 100644 --- a/libsqlx-server/src/http/admin.rs +++ b/libsqlx-server/src/http/admin.rs @@ -11,21 +11,21 @@ use crate::{ meta::Store, }; -pub struct AdminServerConfig { - pub db_path: PathBuf, +pub struct AdminApiConfig { + pub meta_store: Arc, } struct AdminServerState { meta_store: Arc, } -pub async fn run_admin_server(config: AdminServerConfig, listener: I) -> Result<()> +pub async fn run_admin_api(config: AdminApiConfig, listener: I) -> Result<()> where I: Accept, I::Conn: AsyncRead + AsyncWrite + Send + Unpin + 'static, { let state = AdminServerState { - meta_store: Arc::new(Store::new(&config.db_path)), + meta_store: config.meta_store, }; let app = Router::new() @@ -54,7 +54,7 @@ struct AllocateReq { #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum DbConfigReq { - Primary { }, + Primary {}, Replica { primary_node_id: String }, } @@ -66,7 +66,7 @@ async fn allocate( max_conccurent_connection: req.max_conccurent_connection.unwrap_or(16), id: req.alloc_id.clone(), db_config: match req.config { - DbConfigReq::Primary { } => DbConfig::Primary { }, + DbConfigReq::Primary {} => DbConfig::Primary {}, DbConfigReq::Replica { primary_node_id } => DbConfig::Replica { primary_node_id }, }, }; diff --git a/libsqlx-server/src/http/mod.rs b/libsqlx-server/src/http/mod.rs index 92918b09..1e6bf65b 100644 --- a/libsqlx-server/src/http/mod.rs +++ b/libsqlx-server/src/http/mod.rs @@ -1 +1,2 @@ pub mod admin; +pub mod user; diff --git a/libsqlx-server/src/http/user.rs b/libsqlx-server/src/http/user.rs new file mode 100644 index 00000000..040f5a66 --- /dev/null +++ b/libsqlx-server/src/http/user.rs @@ -0,0 +1,101 @@ +use std::sync::Arc; + +use axum::{async_trait, extract::FromRequestParts, response::IntoResponse, routing::get, Router, Json}; 
+use color_eyre::Result; +use hyper::{http::request::Parts, server::accept::Accept, StatusCode}; +use serde::Serialize; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::mpsc, +}; + +use crate::{allocation::AllocationMessage, manager::Manager}; + +pub struct UserApiConfig { + pub manager: Arc, +} + +struct UserApiState { + manager: Arc, +} + +pub async fn run_user_api(config: UserApiConfig, listener: I) -> Result<()> +where + I: Accept, + I::Conn: AsyncRead + AsyncWrite + Send + Unpin + 'static, +{ + let state = UserApiState { manager: config.manager }; + + let app = Router::new() + .route("/", get(test_database)) + .with_state(Arc::new(state)); + + axum::Server::builder(listener) + .serve(app.into_make_service()) + .await?; + + Ok(()) +} + +struct Database { + sender: mpsc::Sender, +} + +#[derive(Debug, thiserror::Error)] +enum UserApiError { + #[error("missing host header")] + MissingHost, + #[error("invalid host header format")] + InvalidHost, + #[error("Database `{0}` doesn't exist")] + UnknownDatabase(String), +} + +impl UserApiError { + fn http_status(&self) -> StatusCode { + match self { + UserApiError::MissingHost + | UserApiError::InvalidHost + | UserApiError::UnknownDatabase(_) => StatusCode::BAD_REQUEST, + } + } +} + +#[derive(Debug, Serialize)] +struct ApiError { + error: String, +} + +impl IntoResponse for UserApiError { + fn into_response(self) -> axum::response::Response { + let mut resp = Json(ApiError { + error: self.to_string() + }).into_response(); + *resp.status_mut() = self.http_status(); + + resp + } +} + +#[async_trait] +impl FromRequestParts> for Database { + type Rejection = UserApiError; + + async fn from_request_parts( + parts: &mut Parts, + state: &Arc, + ) -> Result { + let Some(host) = parts.headers.get("host") else { return Err(UserApiError::MissingHost) }; + let Ok(host_str) = std::str::from_utf8(host.as_bytes()) else {return Err(UserApiError::MissingHost)}; + let db_id = parse_host(host_str)?; + let Some(sender) = 
state.manager.alloc(db_id).await else { return Err(UserApiError::UnknownDatabase(db_id.to_owned())) }; + + Ok(Database { sender }) + } +} + +fn parse_host(host: &str) -> Result<&str, UserApiError> { + let mut split = host.split("."); + let Some(db_id) = split.next() else { return Err(UserApiError::InvalidHost) }; + Ok(db_id) +} diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs index fb213397..2e9411cf 100644 --- a/libsqlx-server/src/main.rs +++ b/libsqlx-server/src/main.rs @@ -1,13 +1,20 @@ -use std::path::PathBuf; +use std::{path::PathBuf, sync::Arc}; use color_eyre::eyre::Result; -use http::admin::{run_admin_server, AdminServerConfig}; +use http::{ + admin::{run_admin_api, AdminApiConfig}, + user::{run_user_api, UserApiConfig}, +}; use hyper::server::conn::AddrIncoming; +use manager::Manager; +use meta::Store; +use tokio::task::JoinSet; use tracing::metadata::LevelFilter; use tracing_subscriber::prelude::*; mod allocation; mod databases; +mod hrana; mod http; mod manager; mod meta; @@ -15,16 +22,24 @@ mod meta; #[tokio::main] async fn main() -> Result<()> { init(); + let mut join_set = JoinSet::new(); + let db_path = PathBuf::from("database"); + let store = Arc::new(Store::new(&db_path)); let admin_api_listener = tokio::net::TcpListener::bind("0.0.0.0:3456").await?; - run_admin_server( - AdminServerConfig { - db_path: PathBuf::from("database"), - }, + join_set.spawn(run_admin_api( + AdminApiConfig { meta_store: store.clone() }, AddrIncoming::from_listener(admin_api_listener)?, - ) - .await - .unwrap(); + )); + + let manager = Arc::new(Manager::new(db_path.clone(), store, 100)); + let user_api_listener = tokio::net::TcpListener::bind("0.0.0.0:3457").await?; + join_set.spawn(run_user_api( + UserApiConfig { manager }, + AddrIncoming::from_listener(user_api_listener)?, + )); + + join_set.join_next().await; Ok(()) } diff --git a/libsqlx-server/src/manager.rs b/libsqlx-server/src/manager.rs index 8d7737a7..81ac3b72 100644 --- 
a/libsqlx-server/src/manager.rs +++ b/libsqlx-server/src/manager.rs @@ -18,9 +18,18 @@ pub struct Manager { const MAX_ALLOC_MESSAGE_QUEUE_LEN: usize = 32; impl Manager { - pub async fn alloc(&self, alloc_id: &str) -> mpsc::Sender { + pub fn new(db_path: PathBuf, meta_store: Arc, max_conccurent_allocs: u64) -> Self { + Self { + cache: Cache::new(max_conccurent_allocs), + meta_store, + db_path, + } + } + + /// Returns a handle to an allocation, lazily initializing if it isn't already loaded. + pub async fn alloc(&self, alloc_id: &str) -> Option> { if let Some(sender) = self.cache.get(alloc_id) { - return sender.clone(); + return Some(sender.clone()); } if let Some(config) = self.meta_store.meta(alloc_id).await { @@ -42,9 +51,9 @@ impl Manager { .insert(alloc_id.to_string(), alloc_sender.clone()) .await; - return alloc_sender; + return Some(alloc_sender); } - todo!("alloc doesn't exist") + None } } diff --git a/libsqlx-server/src/meta.rs b/libsqlx-server/src/meta.rs index 475a0250..4eade1b0 100644 --- a/libsqlx-server/src/meta.rs +++ b/libsqlx-server/src/meta.rs @@ -32,7 +32,9 @@ impl Store { }); } - pub async fn deallocate(&self, alloc_id: Uuid) {} + pub async fn deallocate(&self, alloc_id: Uuid) { + todo!() + } pub async fn meta(&self, alloc_id: &str) -> Option { tokio::task::block_in_place(|| { From c2fed59890d12c5914a7e7151237f63276872fd0 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Mon, 10 Jul 2023 15:38:08 +0200 Subject: [PATCH 8/8] port hrana for libsqlx server we can now allocate primaries, and query them --- Cargo.lock | 36 +- libsqlx-server/Cargo.toml | 10 +- libsqlx-server/src/allocation/mod.rs | 101 ++++-- libsqlx-server/src/database.rs | 19 + libsqlx-server/src/databases/mod.rs | 5 - libsqlx-server/src/databases/store.rs | 12 - libsqlx-server/src/hrana/batch.rs | 131 +++++++ libsqlx-server/src/hrana/http/mod.rs | 118 ++++++ libsqlx-server/src/hrana/http/proto.rs | 115 ++++++ libsqlx-server/src/hrana/http/request.rs | 115 ++++++ 
libsqlx-server/src/hrana/http/stream.rs | 404 +++++++++++++++++++++ libsqlx-server/src/hrana/mod.rs | 68 ++++ libsqlx-server/src/hrana/proto.rs | 160 ++++++++ libsqlx-server/src/hrana/result_builder.rs | 320 ++++++++++++++++ libsqlx-server/src/hrana/stmt.rs | 289 +++++++++++++++ libsqlx-server/src/hrana/ws/conn.rs | 301 +++++++++++++++ libsqlx-server/src/hrana/ws/handshake.rs | 140 +++++++ libsqlx-server/src/hrana/ws/mod.rs | 104 ++++++ libsqlx-server/src/hrana/ws/proto.rs | 127 +++++++ libsqlx-server/src/hrana/ws/session.rs | 329 +++++++++++++++++ libsqlx-server/src/http/admin.rs | 2 +- libsqlx-server/src/http/user.rs | 101 ------ libsqlx-server/src/http/user/error.rs | 41 +++ libsqlx-server/src/http/user/extractors.rs | 32 ++ libsqlx-server/src/http/user/mod.rs | 48 +++ libsqlx-server/src/main.rs | 6 +- libsqlx-server/src/manager.rs | 4 +- libsqlx-server/src/meta.rs | 4 +- libsqlx/src/analysis.rs | 11 +- libsqlx/src/connection.rs | 11 +- libsqlx/src/database/libsql/connection.rs | 2 +- libsqlx/src/database/libsql/mod.rs | 1 + libsqlx/src/error.rs | 12 +- libsqlx/src/lib.rs | 11 +- libsqlx/src/program.rs | 19 +- libsqlx/src/result_builder.rs | 8 +- 36 files changed, 3015 insertions(+), 202 deletions(-) create mode 100644 libsqlx-server/src/database.rs delete mode 100644 libsqlx-server/src/databases/mod.rs delete mode 100644 libsqlx-server/src/databases/store.rs create mode 100644 libsqlx-server/src/hrana/batch.rs create mode 100644 libsqlx-server/src/hrana/http/mod.rs create mode 100644 libsqlx-server/src/hrana/http/proto.rs create mode 100644 libsqlx-server/src/hrana/http/request.rs create mode 100644 libsqlx-server/src/hrana/http/stream.rs create mode 100644 libsqlx-server/src/hrana/mod.rs create mode 100644 libsqlx-server/src/hrana/proto.rs create mode 100644 libsqlx-server/src/hrana/result_builder.rs create mode 100644 libsqlx-server/src/hrana/stmt.rs create mode 100644 libsqlx-server/src/hrana/ws/conn.rs create mode 100644 
libsqlx-server/src/hrana/ws/handshake.rs create mode 100644 libsqlx-server/src/hrana/ws/mod.rs create mode 100644 libsqlx-server/src/hrana/ws/proto.rs create mode 100644 libsqlx-server/src/hrana/ws/session.rs delete mode 100644 libsqlx-server/src/http/user.rs create mode 100644 libsqlx-server/src/http/user/error.rs create mode 100644 libsqlx-server/src/http/user/extractors.rs create mode 100644 libsqlx-server/src/http/user/mod.rs diff --git a/Cargo.lock b/Cargo.lock index b4364605..9752e54f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -695,9 +695,9 @@ checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" [[package]] name = "base64" -version = "0.21.1" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f1e31e207a6b8fb791a38ea3105e6cb541f55e4d029902d3039a4ad07cc4105" +checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" [[package]] name = "base64-simd" @@ -2436,7 +2436,7 @@ version = "8.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6971da4d9c3aa03c3d8f3ff0f4155b534aad021292003895a469716b2a230378" dependencies = [ - "base64 0.21.1", + "base64 0.21.2", "pem", "ring", "serde", @@ -2502,7 +2502,7 @@ checksum = "9c7b1c078b4d3d45ba0db91accc23dcb8d2761d67f819efd94293065597b7ac8" dependencies = [ "anyhow", "async-trait", - "base64 0.21.1", + "base64 0.21.2", "num-traits", "reqwest", "serde_json", @@ -2568,15 +2568,23 @@ name = "libsqlx-server" version = "0.1.0" dependencies = [ "axum", + "base64 0.21.2", "bincode", + "bytes 1.4.0", "clap", "color-eyre", "futures", + "hmac", "hyper", "libsqlx", "moka", + "parking_lot 0.12.1", + "priority-queue", + "rand", "regex", "serde", + "serde_json", + "sha2", "sled", "thiserror", "tokio", @@ -3286,9 +3294,9 @@ dependencies = [ [[package]] name = "priority-queue" -version = "1.3.1" +version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"5ca9c6be70d989d21a136eb86c2d83e4b328447fac4a88dace2143c179c86267" +checksum = "fff39edfcaec0d64e8d0da38564fad195d2d51b680940295fcc307366e101e61" dependencies = [ "autocfg", "indexmap 1.9.3", @@ -3598,7 +3606,7 @@ version = "0.11.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cde824a14b7c14f85caff81225f411faacc04a2013f41670f41443742b1c1c55" dependencies = [ - "base64 0.21.1", + "base64 0.21.2", "bytes 1.4.0", "encoding_rs", "futures-core", @@ -3755,7 +3763,7 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" dependencies = [ - "base64 0.21.1", + "base64 0.21.2", ] [[package]] @@ -3900,9 +3908,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.99" +version = "1.0.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46266871c240a00b8f503b877622fe33430b3c7d963bdc0f2adc511e54a1eae3" +checksum = "0f1e14e89be7aa4c4b78bdbdc9eb5bf8517829a600ae8eaa39a6e1d960b5185c" dependencies = [ "indexmap 2.0.0", "itoa", @@ -3944,9 +3952,9 @@ dependencies = [ [[package]] name = "sha2" -version = "0.10.6" +version = "0.10.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82e6b795fe2e3b1e845bafcb27aa35405c4d47cdfc92af5fc8d3002f76cebdc0" +checksum = "479fb9d862239e610720565ca91403019f2f00410f1864c5aa7479b950a76ed8" dependencies = [ "cfg-if", "cpufeatures", @@ -4116,7 +4124,7 @@ dependencies = [ "aws-config", "aws-sdk-s3", "axum", - "base64 0.21.1", + "base64 0.21.2", "bincode", "bottomless", "bytemuck", @@ -4590,7 +4598,7 @@ checksum = "3082666a3a6433f7f511c7192923fa1fe07c69332d3c6a2e6bb040b569199d5a" dependencies = [ "async-trait", "axum", - "base64 0.21.1", + "base64 0.21.2", "bytes 1.4.0", "futures-core", "futures-util", diff --git a/libsqlx-server/Cargo.toml b/libsqlx-server/Cargo.toml index 4b4668ba..5e6d5c15 100644 --- a/libsqlx-server/Cargo.toml +++ 
b/libsqlx-server/Cargo.toml @@ -7,15 +7,23 @@ edition = "2021" [dependencies] axum = "0.6.18" +base64 = "0.21.2" bincode = "1.3.3" +bytes = "1.4.0" clap = { version = "4.3.11", features = ["derive"] } color-eyre = "0.6.2" futures = "0.3.28" +hmac = "0.12.1" hyper = { version = "0.14.27", features = ["h2", "server"] } libsqlx = { version = "0.1.0", path = "../libsqlx" } moka = { version = "0.11.2", features = ["future"] } +parking_lot = "0.12.1" +priority-queue = "1.3.2" +rand = "0.8.5" regex = "1.9.1" -serde = { version = "1.0.166", features = ["derive"] } +serde = { version = "1.0.166", features = ["derive", "rc"] } +serde_json = "1.0.100" +sha2 = "0.10.7" sled = "0.34.7" thiserror = "1.0.43" tokio = { version = "1.29.1", features = ["full"] } diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index 21e4c97c..a086f479 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -1,11 +1,15 @@ -use std::collections::HashMap; use std::path::PathBuf; +use std::sync::Arc; use libsqlx::libsql::{LibsqlDatabase, LogCompactor, LogFile, PrimaryType}; use libsqlx::Database as _; use tokio::sync::{mpsc, oneshot}; use tokio::task::{block_in_place, JoinSet}; +use crate::hrana; +use crate::hrana::http::handle_pipeline; +use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody}; + use self::config::{AllocConfig, DbConfig}; pub mod config; @@ -19,16 +23,11 @@ pub struct ConnectionId { } pub enum AllocationMessage { - /// Execute callback against connection - Exec { - connection_id: ConnectionId, - exec: ExecFn, - }, - /// Create a new connection, execute the callback and return the connection id. 
- NewConnExec { - exec: ExecFn, - ret: oneshot::Sender, - }, + NewConnection(oneshot::Sender), + HranaPipelineReq { + req: PipelineRequestBody, + ret: oneshot::Sender>, + } } pub enum Database { @@ -73,12 +72,34 @@ impl Database { pub struct Allocation { pub inbox: mpsc::Receiver, pub database: Database, - /// senders to the spawned connections - pub connections: HashMap>, /// spawned connection futures, returning their connection id on completion. pub connections_futs: JoinSet, pub next_conn_id: u32, pub max_concurrent_connections: u32, + + pub hrana_server: Arc, +} + +pub struct ConnectionHandle { + exec: mpsc::Sender, + exit: oneshot::Sender<()>, +} + +impl ConnectionHandle { + pub async fn exec(&self, f: F) -> crate::Result + where F: for<'a> FnOnce(&'a mut (dyn libsqlx::Connection + 'a)) -> R + Send + 'static, + R: Send + 'static, + { + let (sender, ret) = oneshot::channel(); + let cb = move |conn: &mut dyn libsqlx::Connection| { + let res = f(conn); + let _ = sender.send(res); + }; + + self.exec.send(Box::new(cb)).await.unwrap(); + + Ok(ret.await?) + } } impl Allocation { @@ -87,23 +108,22 @@ impl Allocation { tokio::select! 
{ Some(msg) = self.inbox.recv() => { match msg { - AllocationMessage::Exec { connection_id, exec } => { - if let Some(sender) = self.connections.get(&connection_id.id) { - if let Err(_) = sender.send(exec).await { - tracing::debug!("connection {} closed.", connection_id.id); - self.connections.remove_entry(&connection_id.id); - } - } - }, - AllocationMessage::NewConnExec { exec, ret } => { - let id = self.new_conn_exec(exec).await; - let _ = ret.send(id); + AllocationMessage::NewConnection(ret) => { + let _ =ret.send(self.new_conn().await); }, + AllocationMessage::HranaPipelineReq { req, ret} => { + let res = handle_pipeline(&self.hrana_server.clone(), req, || async { + let conn= self.new_conn().await; + dbg!(); + Ok(conn) + }).await; + let _ = ret.send(res); + } } }, maybe_id = self.connections_futs.join_next() => { - if let Some(Ok(id)) = maybe_id { - self.connections.remove_entry(&id); + if let Some(Ok(_id)) = maybe_id { + // self.connections.remove_entry(&id); } }, else => break, @@ -111,10 +131,13 @@ impl Allocation { } } - async fn new_conn_exec(&mut self, exec: ExecFn) -> ConnectionId { + async fn new_conn(&mut self) -> ConnectionHandle { + dbg!(); let id = self.next_conn_id(); + dbg!(); let conn = block_in_place(|| self.database.connect()); - let (close_sender, exit) = mpsc::channel(1); + dbg!(); + let (close_sender, exit) = oneshot::channel(); let (exec_sender, exec_receiver) = mpsc::channel(1); let conn = Connection { id, @@ -123,20 +146,24 @@ impl Allocation { exec: exec_receiver, }; + dbg!(); self.connections_futs.spawn(conn.run()); - // This should never block! 
- assert!(exec_sender.try_send(exec).is_ok()); - assert!(self.connections.insert(id, exec_sender).is_none()); + dbg!(); + + ConnectionHandle { + exec: exec_sender, + exit: close_sender, + } - ConnectionId { id, close_sender } } fn next_conn_id(&mut self) -> u32 { loop { self.next_conn_id = self.next_conn_id.wrapping_add(1); - if !self.connections.contains_key(&self.next_conn_id) { - return self.next_conn_id; - } + return self.next_conn_id; + // if !self.connections.contains_key(&self.next_conn_id) { + // return self.next_conn_id; + // } } } } @@ -144,7 +171,7 @@ impl Allocation { struct Connection { id: u32, conn: Box, - exit: mpsc::Receiver<()>, + exit: oneshot::Receiver<()>, exec: mpsc::Receiver, } @@ -152,7 +179,7 @@ impl Connection { async fn run(mut self) -> u32 { loop { tokio::select! { - _ = self.exit.recv() => break, + _ = &mut self.exit => break, Some(exec) = self.exec.recv() => { tokio::task::block_in_place(|| exec(&mut *self.conn)); } diff --git a/libsqlx-server/src/database.rs b/libsqlx-server/src/database.rs new file mode 100644 index 00000000..d0c979cc --- /dev/null +++ b/libsqlx-server/src/database.rs @@ -0,0 +1,19 @@ +use tokio::sync::{mpsc, oneshot}; + +use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody}; +use crate::allocation::{AllocationMessage, ConnectionHandle}; + +pub struct Database { + pub sender: mpsc::Sender, +} + +impl Database { + pub async fn hrana_pipeline(&self, req: PipelineRequestBody) -> crate::Result { + dbg!(); + let (sender, ret) = oneshot::channel(); + dbg!(); + self.sender.send(AllocationMessage::HranaPipelineReq { req, ret: sender }).await.unwrap(); + dbg!(); + ret.await.unwrap() + } +} diff --git a/libsqlx-server/src/databases/mod.rs b/libsqlx-server/src/databases/mod.rs deleted file mode 100644 index 0494174b..00000000 --- a/libsqlx-server/src/databases/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -use uuid::Uuid; - -mod store; - -pub type DatabaseId = Uuid; diff --git a/libsqlx-server/src/databases/store.rs 
b/libsqlx-server/src/databases/store.rs deleted file mode 100644 index 206beb34..00000000 --- a/libsqlx-server/src/databases/store.rs +++ /dev/null @@ -1,12 +0,0 @@ -use std::collections::HashMap; - -use super::DatabaseId; - -pub enum Database { - Replica, - Primary, -} - -pub struct DatabaseManager { - databases: HashMap, -} diff --git a/libsqlx-server/src/hrana/batch.rs b/libsqlx-server/src/hrana/batch.rs new file mode 100644 index 00000000..7d2a1f0c --- /dev/null +++ b/libsqlx-server/src/hrana/batch.rs @@ -0,0 +1,131 @@ +use std::collections::HashMap; + +use crate::allocation::ConnectionHandle; +use crate::hrana::stmt::StmtError; + +use super::result_builder::HranaBatchProtoBuilder; +use super::stmt::{proto_stmt_to_query, stmt_error_from_sqld_error}; +use super::{proto, ProtocolError, Version}; + +use color_eyre::eyre::anyhow; +use libsqlx::analysis::Statement; +use libsqlx::program::{Cond, Program, Step}; +use libsqlx::query::{Query, Params}; +use libsqlx::result_builder::{StepResult, StepResultsBuilder}; + +fn proto_cond_to_cond(cond: &proto::BatchCond, max_step_i: usize) -> color_eyre::Result { + let try_convert_step = |step: i32| -> Result { + let step = usize::try_from(step).map_err(|_| ProtocolError::BatchCondBadStep)?; + if step >= max_step_i { + return Err(ProtocolError::BatchCondBadStep); + } + Ok(step) + }; + let cond = match cond { + proto::BatchCond::Ok { step } => Cond::Ok { + step: try_convert_step(*step)?, + }, + proto::BatchCond::Error { step } => Cond::Err { + step: try_convert_step(*step)?, + }, + proto::BatchCond::Not { cond } => Cond::Not { + cond: proto_cond_to_cond(cond, max_step_i)?.into(), + }, + proto::BatchCond::And { conds } => Cond::And { + conds: conds + .iter() + .map(|cond| proto_cond_to_cond(cond, max_step_i)) + .collect::>()?, + }, + proto::BatchCond::Or { conds } => Cond::Or { + conds: conds + .iter() + .map(|cond| proto_cond_to_cond(cond, max_step_i)) + .collect::>()?, + }, + }; + + Ok(cond) +} + +pub fn proto_batch_to_program( 
+ batch: &proto::Batch, + sqls: &HashMap, + version: Version, +) -> color_eyre::Result { + let mut steps = Vec::with_capacity(batch.steps.len()); + for (step_i, step) in batch.steps.iter().enumerate() { + let query = proto_stmt_to_query(&step.stmt, sqls, version)?; + let cond = step + .condition + .as_ref() + .map(|cond| proto_cond_to_cond(cond, step_i)) + .transpose()?; + let step = Step { query, cond }; + + steps.push(step); + } + + Ok(Program::new(steps)) +} + +pub async fn execute_batch( + db: &ConnectionHandle, + pgm: Program, +) -> color_eyre::Result { + let builder = db.exec(move |conn| -> color_eyre::Result<_> { + let mut builder = HranaBatchProtoBuilder::default(); + conn.execute_program(pgm, &mut builder)?; + Ok(builder) + }).await??; + + Ok(builder.into_ret()) +} + +pub fn proto_sequence_to_program(sql: &str) -> color_eyre::Result { + let stmts = Statement::parse(sql) + .collect::>>() + .map_err(|err| anyhow!(StmtError::SqlParse { source: err.into() }))?; + + let steps = stmts + .into_iter() + .enumerate() + .map(|(step_i, stmt)| { + let cond = match step_i { + 0 => None, + _ => Some(Cond::Ok { step: step_i - 1 }), + }; + let query = Query { + stmt, + params: Params::empty(), + want_rows: false, + }; + Step { cond, query } + }) + .collect(); + + Ok(Program { + steps, + }) +} + +pub async fn execute_sequence(conn: &ConnectionHandle, pgm: Program) -> color_eyre::Result<()> { + let builder = conn.exec(move |conn| -> color_eyre::Result<_> { + let mut builder = StepResultsBuilder::default(); + conn.execute_program(pgm, &mut builder)?; + + Ok(builder) + }).await??; + + builder + .into_ret() + .into_iter() + .try_for_each(|result| match result { + StepResult::Ok => Ok(()), + StepResult::Err(e) => match stmt_error_from_sqld_error(e) { + Ok(stmt_err) => Err(anyhow!(stmt_err)), + Err(sqld_err) => Err(anyhow!(sqld_err)), + }, + StepResult::Skipped => Err(anyhow!("Statement in sequence was not executed")), + }) +} diff --git a/libsqlx-server/src/hrana/http/mod.rs 
b/libsqlx-server/src/hrana/http/mod.rs new file mode 100644 index 00000000..5e22bedc --- /dev/null +++ b/libsqlx-server/src/hrana/http/mod.rs @@ -0,0 +1,118 @@ +use color_eyre::eyre::Context; +use futures::Future; +use parking_lot::Mutex; +use serde::{de::DeserializeOwned, Serialize}; + +use crate::allocation::ConnectionHandle; + +use self::proto::{PipelineRequestBody, PipelineResponseBody}; + +use super::ProtocolError; + +pub mod proto; +mod request; +mod stream; + +pub struct Server { + self_url: Option, + baton_key: [u8; 32], + stream_state: Mutex, +} + +#[derive(Debug)] +pub enum Route { + GetIndex, + PostPipeline, +} + +impl Server { + pub fn new(self_url: Option) -> Self { + Self { + self_url, + baton_key: rand::random(), + stream_state: Mutex::new(stream::ServerStreamState::new()), + } + } + + pub async fn run_expire(&self) { + stream::run_expire(self).await + } +} + +fn handle_index() -> color_eyre::Result> { + Ok(text_response( + hyper::StatusCode::OK, + "Hello, this is HTTP API v2 (Hrana over HTTP)".into(), + )) +} + +pub async fn handle_pipeline( + server: &Server, + req: PipelineRequestBody, + mk_conn: F +) -> color_eyre::Result +where F: FnOnce() -> Fut, + Fut: Future>, +{ + let mut stream_guard = stream::acquire(server, req.baton.as_deref(), mk_conn).await?; + + let mut results = Vec::with_capacity(req.requests.len()); + for request in req.requests.into_iter() { + let result = request::handle(&mut stream_guard, request) + .await + .context("Could not execute a request in pipeline")?; + results.push(result); + } + + let resp_body = proto::PipelineResponseBody { + baton: stream_guard.release(), + base_url: server.self_url.clone(), + results, + }; + + Ok(resp_body) +} + +async fn read_request_json(req: hyper::Request) -> color_eyre::Result { + let req_body = hyper::body::to_bytes(req.into_body()) + .await + .context("Could not read request body")?; + let req_body = serde_json::from_slice(&req_body) + .map_err(|err| ProtocolError::Deserialize { source: 
err }) + .context("Could not deserialize JSON request body")?; + Ok(req_body) +} + +fn protocol_error_response(err: ProtocolError) -> hyper::Response { + text_response(hyper::StatusCode::BAD_REQUEST, err.to_string()) +} + +fn stream_error_response(err: stream::StreamError) -> hyper::Response { + json_response( + hyper::StatusCode::INTERNAL_SERVER_ERROR, + &proto::Error { + message: err.to_string(), + code: err.code().into(), + }, + ) +} + +fn json_response( + status: hyper::StatusCode, + resp_body: &T, +) -> hyper::Response { + let resp_body = serde_json::to_vec(resp_body).unwrap(); + hyper::Response::builder() + .status(status) + .header(hyper::http::header::CONTENT_TYPE, "application/json") + .body(hyper::Body::from(resp_body)) + .unwrap() +} + +fn text_response(status: hyper::StatusCode, resp_body: String) -> hyper::Response { + hyper::Response::builder() + .status(status) + .header(hyper::http::header::CONTENT_TYPE, "text/plain") + .body(hyper::Body::from(resp_body)) + .unwrap() +} diff --git a/libsqlx-server/src/hrana/http/proto.rs b/libsqlx-server/src/hrana/http/proto.rs new file mode 100644 index 00000000..ba1285f1 --- /dev/null +++ b/libsqlx-server/src/hrana/http/proto.rs @@ -0,0 +1,115 @@ +//! Structures for Hrana-over-HTTP. 
+ +pub use super::super::proto::*; +use serde::{Deserialize, Serialize}; + +#[derive(Deserialize, Debug)] +pub struct PipelineRequestBody { + pub baton: Option, + pub requests: Vec, +} + +#[derive(Serialize, Debug)] +pub struct PipelineResponseBody { + pub baton: Option, + pub base_url: Option, + pub results: Vec, +} + +#[derive(Serialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum StreamResult { + Ok { response: StreamResponse }, + Error { error: Error }, +} + +#[derive(Deserialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum StreamRequest { + Close(CloseStreamReq), + Execute(ExecuteStreamReq), + Batch(BatchStreamReq), + Sequence(SequenceStreamReq), + Describe(DescribeStreamReq), + StoreSql(StoreSqlStreamReq), + CloseSql(CloseSqlStreamReq), +} + +#[derive(Serialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum StreamResponse { + Close(CloseStreamResp), + Execute(ExecuteStreamResp), + Batch(BatchStreamResp), + Sequence(SequenceStreamResp), + Describe(DescribeStreamResp), + StoreSql(StoreSqlStreamResp), + CloseSql(CloseSqlStreamResp), +} + +#[derive(Deserialize, Debug)] +pub struct CloseStreamReq {} + +#[derive(Serialize, Debug)] +pub struct CloseStreamResp {} + +#[derive(Deserialize, Debug)] +pub struct ExecuteStreamReq { + pub stmt: Stmt, +} + +#[derive(Serialize, Debug)] +pub struct ExecuteStreamResp { + pub result: StmtResult, +} + +#[derive(Deserialize, Debug)] +pub struct BatchStreamReq { + pub batch: Batch, +} + +#[derive(Serialize, Debug)] +pub struct BatchStreamResp { + pub result: BatchResult, +} + +#[derive(Deserialize, Debug)] +pub struct SequenceStreamReq { + #[serde(default)] + pub sql: Option, + #[serde(default)] + pub sql_id: Option, +} + +#[derive(Serialize, Debug)] +pub struct SequenceStreamResp {} + +#[derive(Deserialize, Debug)] +pub struct DescribeStreamReq { + #[serde(default)] + pub sql: Option, + #[serde(default)] + pub sql_id: Option, +} + +#[derive(Serialize, Debug)] 
+pub struct DescribeStreamResp { + pub result: DescribeResult, +} + +#[derive(Deserialize, Debug)] +pub struct StoreSqlStreamReq { + pub sql_id: i32, + pub sql: String, +} + +#[derive(Serialize, Debug)] +pub struct StoreSqlStreamResp {} + +#[derive(Deserialize, Debug)] +pub struct CloseSqlStreamReq { + pub sql_id: i32, +} + +#[derive(Serialize, Debug)] +pub struct CloseSqlStreamResp {} diff --git a/libsqlx-server/src/hrana/http/request.rs b/libsqlx-server/src/hrana/http/request.rs new file mode 100644 index 00000000..ac6d8912 --- /dev/null +++ b/libsqlx-server/src/hrana/http/request.rs @@ -0,0 +1,115 @@ +use color_eyre::eyre::{anyhow, bail}; + +use super::super::{batch, stmt, ProtocolError, Version}; +use super::{proto, stream}; + +/// An error from executing a [`proto::StreamRequest`] +#[derive(thiserror::Error, Debug)] +pub enum StreamResponseError { + #[error("The server already stores {count} SQL texts, it cannot store more")] + SqlTooMany { count: usize }, + #[error(transparent)] + Stmt(stmt::StmtError), +} + +pub async fn handle( + stream_guard: &mut stream::Guard<'_>, + request: proto::StreamRequest, +) -> color_eyre::Result { + let result = match try_handle(stream_guard, request).await { + Ok(response) => proto::StreamResult::Ok { response }, + Err(err) => { + let resp_err = err.downcast::()?; + let error = proto::Error { + message: resp_err.to_string(), + code: resp_err.code().into(), + }; + proto::StreamResult::Error { error } + } + }; + Ok(result) +} + +async fn try_handle( + stream_guard: &mut stream::Guard<'_>, + request: proto::StreamRequest, +) -> color_eyre::Result { + Ok(match request { + proto::StreamRequest::Close(_req) => { + stream_guard.close_db(); + proto::StreamResponse::Close(proto::CloseStreamResp {}) + } + proto::StreamRequest::Execute(req) => { + let db = stream_guard.get_db()?; + let sqls = stream_guard.sqls(); + let query = stmt::proto_stmt_to_query(&req.stmt, sqls, Version::Hrana2) + .map_err(catch_stmt_error)?; + let result = 
stmt::execute_stmt(db, query) + .await + .map_err(catch_stmt_error)?; + proto::StreamResponse::Execute(proto::ExecuteStreamResp { result }) + } + proto::StreamRequest::Batch(req) => { + let db = stream_guard.get_db()?; + let sqls = stream_guard.sqls(); + let pgm = batch::proto_batch_to_program(&req.batch, sqls, Version::Hrana2)?; + let result = batch::execute_batch(db, pgm).await?; + proto::StreamResponse::Batch(proto::BatchStreamResp { result }) + } + proto::StreamRequest::Sequence(req) => { + let db = stream_guard.get_db()?; + let sqls = stream_guard.sqls(); + let sql = + stmt::proto_sql_to_sql(req.sql.as_deref(), req.sql_id, sqls, Version::Hrana2)?; + let pgm = batch::proto_sequence_to_program(sql).map_err(catch_stmt_error)?; + batch::execute_sequence(db, pgm) + .await + .map_err(catch_stmt_error)?; + proto::StreamResponse::Sequence(proto::SequenceStreamResp {}) + } + proto::StreamRequest::Describe(req) => { + let db = stream_guard.get_db()?; + let sqls = stream_guard.sqls(); + let sql = + stmt::proto_sql_to_sql(req.sql.as_deref(), req.sql_id, sqls, Version::Hrana2)?; + let result = stmt::describe_stmt(db, sql.into()) + .await + .map_err(catch_stmt_error)?; + proto::StreamResponse::Describe(proto::DescribeStreamResp { result }) + } + proto::StreamRequest::StoreSql(req) => { + let sqls = stream_guard.sqls_mut(); + let sql_id = req.sql_id; + if sqls.contains_key(&sql_id) { + bail!(ProtocolError::SqlExists { sql_id }) + } else if sqls.len() >= MAX_SQL_COUNT { + bail!(StreamResponseError::SqlTooMany { count: sqls.len() }) + } + sqls.insert(sql_id, req.sql); + proto::StreamResponse::StoreSql(proto::StoreSqlStreamResp {}) + } + proto::StreamRequest::CloseSql(req) => { + let sqls = stream_guard.sqls_mut(); + sqls.remove(&req.sql_id); + proto::StreamResponse::CloseSql(proto::CloseSqlStreamResp {}) + } + }) +} + +const MAX_SQL_COUNT: usize = 50; + +fn catch_stmt_error(err: color_eyre::eyre::Error) -> color_eyre::eyre::Error { + match err.downcast::() { + Ok(stmt_err) => 
anyhow!(StreamResponseError::Stmt(stmt_err)), + Err(err) => err, + } +} + +impl StreamResponseError { + pub fn code(&self) -> &'static str { + match self { + Self::SqlTooMany { .. } => "SQL_STORE_TOO_MANY", + Self::Stmt(err) => err.code(), + } + } +} diff --git a/libsqlx-server/src/hrana/http/stream.rs b/libsqlx-server/src/hrana/http/stream.rs new file mode 100644 index 00000000..1261e7c2 --- /dev/null +++ b/libsqlx-server/src/hrana/http/stream.rs @@ -0,0 +1,404 @@ +use std::cmp::Reverse; +use std::collections::{HashMap, VecDeque}; +use std::pin::Pin; +use std::{future, mem, task}; + +use base64::prelude::{Engine as _, BASE64_STANDARD_NO_PAD}; +use color_eyre::eyre::{anyhow, WrapErr}; +use futures::Future; +use hmac::Mac as _; +use priority_queue::PriorityQueue; +use tokio::time::{Duration, Instant}; + +use super::super::ProtocolError; +use super::Server; +use crate::allocation::ConnectionHandle; + +/// Mutable state related to streams, owned by [`Server`] and protected with a mutex. +pub struct ServerStreamState { + /// Map from stream ids to stream handles. The stream ids are random integers. + handles: HashMap, + /// Queue of streams ordered by the instant when they should expire. All these stream ids + /// should refer to handles in the [`Handle::Available`] variant. + expire_queue: PriorityQueue>, + /// Queue of expired streams that are still stored as [`Handle::Expired`], together with the + /// instant when we should remove them completely. + cleanup_queue: VecDeque<(u64, Instant)>, + /// The timer that we use to wait for the next item in `expire_queue`. + expire_sleep: Pin>, + /// A waker to wake up the task that expires streams from the `expire_queue`. + expire_waker: Option, + /// See [`roundup_instant()`]. + expire_round_base: Instant, +} + +/// Handle to a stream, owned by the [`ServerStreamState`]. +enum Handle { + /// A stream that is open and ready to be used by requests. [`Stream::db`] should always be + /// `Some`. 
+ Available(Box), + /// A stream that has been acquired by a request that hasn't finished processing. This will be + /// replaced with `Available` when the request completes and releases the stream. + Acquired, + /// A stream that has been expired. This stream behaves as closed, but we keep this around for + /// some time to provide a nicer error messages (i.e., if the stream is expired, we return a + /// "stream expired" error rather than "invalid baton" error). + Expired, +} + +/// State of a Hrana-over-HTTP stream. +/// +/// The stream is either owned by [`Handle::Available`] (when it's not in use) or by [`Guard`] +/// (when it's being used by a request). +struct Stream { + /// The database connection that corresponds to this stream. This is `None` after the `"close"` + /// request was executed. + conn: Option, + /// The cache of SQL texts stored on the server with `"store_sql"` requests. + sqls: HashMap, + /// Stream id of this stream. The id is generated randomly (it should be unguessable). + stream_id: u64, + /// Sequence number that is expected in the next baton. To make sure that clients issue stream + /// requests sequentially, the baton returned from each HTTP request includes this sequence + /// number, and the following HTTP request must show a baton with the same sequence number. + baton_seq: u64, +} + +/// Guard object that is used to access a stream from the outside. The guard makes sure that the +/// stream's entry in [`ServerStreamState::handles`] is either removed or replaced with +/// [`Handle::Available`] after the guard goes out of scope. +pub struct Guard<'srv> { + server: &'srv Server, + /// The guarded stream. This is only set to `None` in the destructor. + stream: Option>, + /// If set to `true`, the destructor will release the stream for further use (saving it as + /// [`Handle::Available`] in [`ServerStreamState::handles`]. If false, the stream is removed on + /// drop. 
+ release: bool, +} + +/// An unrecoverable error that should close the stream. The difference from [`ProtocolError`] is +/// that a correct client may trigger this error, it does not mean that the protocol has been +/// violated. +#[derive(thiserror::Error, Debug)] +pub enum StreamError { + #[error("The stream has expired due to inactivity")] + StreamExpired, +} + +impl ServerStreamState { + pub fn new() -> Self { + Self { + handles: HashMap::new(), + expire_queue: PriorityQueue::new(), + cleanup_queue: VecDeque::new(), + expire_sleep: Box::pin(tokio::time::sleep(Duration::ZERO)), + expire_waker: None, + expire_round_base: Instant::now(), + } + } +} + +/// Acquire a guard to a new or existing stream. If baton is `Some`, we try to look up the stream, +/// otherwise we create a new stream. +pub async fn acquire<'srv, F, Fut>( + server: &'srv Server, + baton: Option<&str>, + mk_conn: F, +) -> color_eyre::Result> +where F: FnOnce() -> Fut, + Fut: Future>, +{ + let stream = match baton { + Some(baton) => { + let (stream_id, baton_seq) = decode_baton(server, baton)?; + + let mut state = server.stream_state.lock(); + let handle = state.handles.get_mut(&stream_id); + match handle { + None => { + return Err(ProtocolError::BatonInvalid(format!("Stream handle for {stream_id} was not found")).into()) + } + Some(Handle::Acquired) => { + return Err(ProtocolError::BatonReused) + .context(format!("Stream handle for {stream_id} is acquired")); + } + Some(Handle::Expired) => { + return Err(StreamError::StreamExpired) + .context(format!("Stream handle for {stream_id} is expired")); + } + Some(Handle::Available(stream)) => { + if stream.baton_seq != baton_seq { + return Err(ProtocolError::BatonReused).context(format!( + "Expected baton seq {}, received {baton_seq}", + stream.baton_seq + )); + } + } + }; + + let Handle::Available(mut stream) = mem::replace(handle.unwrap(), Handle::Acquired) else { + unreachable!() + }; + + tracing::debug!("Stream {stream_id} was acquired with baton 
seq {baton_seq}"); + // incrementing the sequence number forces the next HTTP request to use a different + // baton + stream.baton_seq = stream.baton_seq.wrapping_add(1); + unmark_expire(&mut state, stream.stream_id); + stream + } + None => { + let conn = mk_conn().await.context("Could not create a database connection")?; + + let mut state = server.stream_state.lock(); + let stream = Box::new(Stream { + conn: Some(conn), + sqls: HashMap::new(), + stream_id: gen_stream_id(&mut state), + // initializing the sequence number randomly makes it much harder to exploit + // collisions in batons + baton_seq: rand::random(), + }); + state.handles.insert(stream.stream_id, Handle::Acquired); + tracing::debug!( + "Stream {} was created with baton seq {}", + stream.stream_id, + stream.baton_seq + ); + stream + } + }; + Ok(Guard { + server, + stream: Some(stream), + release: false, + }) +} + +impl<'srv> Guard<'srv> { + pub fn get_db(&self) -> Result<&ConnectionHandle, ProtocolError> { + let stream = self.stream.as_ref().unwrap(); + stream.conn.as_ref().ok_or(ProtocolError::BatonStreamClosed) + } + + /// Closes the database connection. The next call to [`Guard::release()`] will then remove the + /// stream. + pub fn close_db(&mut self) { + let stream = self.stream.as_mut().unwrap(); + stream.conn = None; + } + + pub fn sqls(&self) -> &HashMap { + &self.stream.as_ref().unwrap().sqls + } + + pub fn sqls_mut(&mut self) -> &mut HashMap { + &mut self.stream.as_mut().unwrap().sqls + } + + /// Releases the guard and returns the baton that can be used to access this stream in the next + /// HTTP request. Returns `None` if the stream has been closed (and thus cannot be accessed + /// again). 
+ pub fn release(mut self) -> Option { + let stream = self.stream.as_ref().unwrap(); + if stream.conn.is_some() { + self.release = true; // tell destructor to make the stream available again + Some(encode_baton( + self.server, + stream.stream_id, + stream.baton_seq, + )) + } else { + None + } + } +} + +impl<'srv> Drop for Guard<'srv> { + fn drop(&mut self) { + let stream = self.stream.take().unwrap(); + let stream_id = stream.stream_id; + + let mut state = self.server.stream_state.lock(); + let Some(handle) = state.handles.remove(&stream_id) else { + panic!("Dropped a Guard for stream {stream_id}, \ + but Server does not contain a handle to it"); + }; + if !matches!(handle, Handle::Acquired) { + panic!( + "Dropped a Guard for stream {stream_id}, \ + but Server contained handle that is not acquired" + ); + } + + if self.release { + state.handles.insert(stream_id, Handle::Available(stream)); + mark_expire(&mut state, stream_id); + tracing::debug!("Stream {stream_id} was released for further use"); + } else { + tracing::debug!("Stream {stream_id} was closed"); + } + } +} + +fn gen_stream_id(state: &mut ServerStreamState) -> u64 { + for _ in 0..10 { + let stream_id = rand::random(); + if !state.handles.contains_key(&stream_id) { + return stream_id; + } + } + panic!("Failed to generate a free stream id with rejection sampling") +} + +/// Encodes the baton. +/// +/// The baton is base64-encoded byte string that is composed from: +/// +/// - payload (16 bytes): +/// - `stream_id` (8 bytes, big endian) +/// - `baton_seq` (8 bytes, big endian) +/// - MAC (32 bytes): an authentication code generated with HMAC-SHA256 +/// +/// The MAC is used to cryptographically verify that the baton was generated by this server. It is +/// unlikely that we ever issue the same baton twice, because there are 2^128 possible combinations +/// for payload (note that both `stream_id` and the initial `baton_seq` are generated randomly). 
+fn encode_baton(server: &Server, stream_id: u64, baton_seq: u64) -> String { + let mut payload = [0; 16]; + payload[0..8].copy_from_slice(&stream_id.to_be_bytes()); + payload[8..16].copy_from_slice(&baton_seq.to_be_bytes()); + + let mut hmac = hmac::Hmac::::new_from_slice(&server.baton_key).unwrap(); + hmac.update(&payload); + let mac = hmac.finalize().into_bytes(); + + let mut baton_data = [0; 48]; + baton_data[0..16].copy_from_slice(&payload); + baton_data[16..48].copy_from_slice(&mac); + BASE64_STANDARD_NO_PAD.encode(baton_data) +} + +/// Decodes a baton encoded with `encode_baton()` and returns `(stream_id, baton_seq)`. Always +/// returns a [`ProtocolError::BatonInvalid`] if the baton is invalid, but it attaches an anyhow +/// context that describes the precise cause. +fn decode_baton(server: &Server, baton_str: &str) -> color_eyre::Result<(u64, u64)> { + let baton_data = BASE64_STANDARD_NO_PAD.decode(baton_str).map_err(|err| { + ProtocolError::BatonInvalid(format!("Could not base64-decode baton: {err}")) + })?; + + if baton_data.len() != 48 { + return Err(ProtocolError::BatonInvalid(format!( + "Baton has invalid size of {} bytes", + baton_data.len() + )).into()); + } + + let payload = &baton_data[0..16]; + let received_mac = &baton_data[16..48]; + + let mut hmac = hmac::Hmac::::new_from_slice(&server.baton_key).unwrap(); + hmac.update(payload); + hmac.verify_slice(received_mac) + .map_err(|_| anyhow!(ProtocolError::BatonInvalid("Invalid MAC on baton".to_string())))?; + + let stream_id = u64::from_be_bytes(payload[0..8].try_into().unwrap()); + let baton_seq = u64::from_be_bytes(payload[8..16].try_into().unwrap()); + Ok((stream_id, baton_seq)) +} + +/// How long do we keep a stream in [`Handle::Available`] state before expiration. Note that every +/// HTTP request resets the timer to beginning, so the client can keep a stream alive for a long +/// time, as long as it pings regularly. 
const EXPIRATION: Duration = Duration::from_secs(10);

/// How long do we keep an expired stream in [`Handle::Expired`] state before removing it for good.
const CLEANUP: Duration = Duration::from_secs(300);

/// Schedules (or reschedules) expiration of `stream_id` at roughly `now + EXPIRATION`.
/// Deadlines are rounded up to whole seconds (see [`roundup_instant()`]) so nearby streams
/// expire together.
fn mark_expire(state: &mut ServerStreamState, stream_id: u64) {
    let expire_at = roundup_instant(state, Instant::now() + EXPIRATION);
    // If the expire task is currently sleeping until a LATER instant, wake it so it can
    // re-arm its sleep for this earlier deadline; otherwise it would oversleep.
    if state.expire_sleep.deadline() > expire_at {
        if let Some(waker) = state.expire_waker.take() {
            waker.wake();
        }
    }
    // NOTE(review): `expire_queue` appears to be a priority queue keyed by `Reverse(expire_at)`
    // (soonest deadline first) where `push` also updates an existing entry's priority — confirm
    // against the declaration of `ServerStreamState`.
    state.expire_queue.push(stream_id, Reverse(expire_at));
}

/// Cancels a pending expiration for `stream_id`, if any.
fn unmark_expire(state: &mut ServerStreamState, stream_id: u64) {
    state.expire_queue.remove(&stream_id);
}

/// Handles stream expiration (and cleanup). The returned future is never resolved.
pub async fn run_expire(server: &Server) {
    // Always returning `Pending` keeps this future alive forever; `pump_expire()` stores the
    // waker and arms `expire_sleep` so we are polled again when the next deadline fires.
    future::poll_fn(|cx| {
        let mut state = server.stream_state.lock();
        pump_expire(&mut state, cx);
        task::Poll::Pending
    })
    .await
}

/// One pass of the expiration loop: expires due streams, removes streams whose cleanup
/// period has elapsed, and re-arms the sleep/waker for the next deadline.
fn pump_expire(state: &mut ServerStreamState, cx: &mut task::Context) {
    let now = Instant::now();

    // expire all streams in the `expire_queue` that have passed their expiration time
    let wakeup_at = loop {
        let stream_id = match state.expire_queue.peek() {
            Some((&stream_id, &Reverse(expire_at))) => {
                if expire_at <= now {
                    stream_id
                } else {
                    // Soonest remaining deadline: sleep until then.
                    break expire_at;
                }
            }
            // Queue empty: poll again within a minute as a fallback.
            None => break now + Duration::from_secs(60),
        };
        state.expire_queue.pop();

        // Only `Available` streams transition to `Expired`; anything else (already removed,
        // acquired, ...) is left alone and its queue entry is simply dropped.
        match state.handles.get_mut(&stream_id) {
            Some(handle @ Handle::Available(_)) => {
                *handle = Handle::Expired;
            }
            _ => continue,
        }
        tracing::debug!("Stream {stream_id} was expired");

        // An expired stream lingers for CLEANUP so clients get a "stream expired" error
        // instead of "stream not found". `cleanup_queue` stays sorted because `cleanup_at`
        // is monotonic in insertion order.
        let cleanup_at = roundup_instant(state, now + CLEANUP);
        state.cleanup_queue.push_back((stream_id, cleanup_at));
    };

    // completely remove streams that are due in `cleanup_queue`
    loop {
        let stream_id = match state.cleanup_queue.front() {
            Some(&(stream_id, cleanup_at)) if cleanup_at <= now => stream_id,
            _ => break,
        };
        state.cleanup_queue.pop_front();

        let handle = state.handles.remove(&stream_id);
        // A stream in the cleanup queue must still be in `Expired` state; any other state
        // here is a bookkeeping bug.
        assert!(matches!(handle, Some(Handle::Expired)));
        tracing::debug!("Stream {stream_id} was cleaned up after expiration");
    }

    // make sure that this function is called again no later than at time `wakeup_at`
    state.expire_sleep.as_mut().reset(wakeup_at);
    state.expire_waker = Some(cx.waker().clone());
    // Polling the sleep registers our waker with the timer; the returned readiness value
    // itself is irrelevant here.
    let _: task::Poll<()> = state.expire_sleep.as_mut().poll(cx);
}

/// Rounds the `instant` to the next second. This is used to ensure that streams that expire close
/// together are expired at exactly the same instant, thus reducing the number of times that
/// [`pump_expire()`] is called during periods of high load.
fn roundup_instant(state: &ServerStreamState, instant: Instant) -> Instant {
    let duration_s = (instant - state.expire_round_base).as_secs();
    state.expire_round_base + Duration::from_secs(duration_s + 1)
}

impl StreamError {
    /// Machine-readable error code reported to Hrana clients.
    pub fn code(&self) -> &'static str {
        match self {
            Self::StreamExpired => "STREAM_EXPIRED",
        }
    }
}
diff --git a/libsqlx-server/src/hrana/mod.rs b/libsqlx-server/src/hrana/mod.rs
new file mode 100644
index 00000000..fc85fcfe
--- /dev/null
+++ b/libsqlx-server/src/hrana/mod.rs
@@ -0,0 +1,68 @@
use std::fmt;

pub mod batch;
pub mod http;
pub mod proto;
mod result_builder;
pub mod stmt;
// pub mod ws;

/// Hrana protocol version negotiated with a client. `Ord` is derived so versions can be
/// compared with `<`/`>=` for feature gating (e.g. `version < Version::Hrana2`).
#[derive(Debug, Copy, Clone, PartialOrd, Ord, PartialEq, Eq)]
pub enum Version {
    Hrana1,
    Hrana2,
}

impl fmt::Display for Version {
    // Renders the version as its WebSocket subprotocol name ("hrana1" / "hrana2").
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Version::Hrana1 => write!(f, "hrana1"),
            Version::Hrana2 => write!(f, "hrana2"),
        }
    }
}

/// An unrecoverable protocol error that should close the WebSocket or HTTP stream. A correct
/// client should never trigger any of these errors.
#[derive(thiserror::Error, Debug)]
pub enum ProtocolError {
    // --- message framing / handshake ---
    #[error("Cannot deserialize client message: {source}")]
    Deserialize { source: serde_json::Error },
    #[error("Received a binary WebSocket message, which is not supported")]
    BinaryWebSocketMessage,
    #[error("Received a request before hello message")]
    RequestBeforeHello,

    // --- stream management (WebSocket protocol) ---
    #[error("Stream {stream_id} not found")]
    StreamNotFound { stream_id: i32 },
    #[error("Stream {stream_id} already exists")]
    StreamExists { stream_id: i32 },

    // --- stored SQL texts (`sql_id`, Hrana 2) ---
    #[error("Either `sql` or `sql_id` are required, but not both")]
    SqlIdAndSqlGiven,
    #[error("Either `sql` or `sql_id` are required")]
    SqlIdOrSqlNotGiven,
    #[error("SQL text {sql_id} not found")]
    SqlNotFound { sql_id: i32 },
    #[error("SQL text {sql_id} already exists")]
    SqlExists { sql_id: i32 },

    // --- batches ---
    #[error("Invalid reference to step in a batch condition")]
    BatchCondBadStep,

    // --- batons (HTTP protocol stream tokens) ---
    #[error("Received an invalid baton: {0}")]
    BatonInvalid(String),
    #[error("Received a baton that has already been used")]
    BatonReused,
    #[error("Stream for this baton was closed")]
    BatonStreamClosed,

    // --- version / limit violations ---
    #[error("{what} is only supported in protocol version {min_version} and higher")]
    NotSupported {
        what: &'static str,
        min_version: Version,
    },

    #[error("{0}")]
    ResponseTooLarge(String),
}
diff --git a/libsqlx-server/src/hrana/proto.rs b/libsqlx-server/src/hrana/proto.rs
new file mode 100644
index 00000000..8d544a07
--- /dev/null
+++ b/libsqlx-server/src/hrana/proto.rs
@@ -0,0 +1,160 @@
//! Structures in Hrana that are common for WebSockets and HTTP.
+ +use bytes::Bytes; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; + +#[derive(Serialize, Debug)] +pub struct Error { + pub message: String, + pub code: String, +} + +#[derive(Deserialize, Debug)] +pub struct Stmt { + #[serde(default)] + pub sql: Option, + #[serde(default)] + pub sql_id: Option, + #[serde(default)] + pub args: Vec, + #[serde(default)] + pub named_args: Vec, + #[serde(default)] + pub want_rows: Option, +} + +#[derive(Deserialize, Debug)] +pub struct NamedArg { + pub name: String, + pub value: Value, +} + +#[derive(Serialize, Debug)] +pub struct StmtResult { + pub cols: Vec, + pub rows: Vec>, + pub affected_row_count: u64, + #[serde(with = "option_i64_as_str")] + pub last_insert_rowid: Option, +} + +#[derive(Serialize, Debug)] +pub struct Col { + pub name: Option, + pub decltype: Option, +} + +#[derive(Deserialize, Debug)] +pub struct Batch { + pub steps: Vec, +} + +#[derive(Deserialize, Debug)] +pub struct BatchStep { + pub stmt: Stmt, + #[serde(default)] + pub condition: Option, +} + +#[derive(Serialize, Debug)] +pub struct BatchResult { + pub step_results: Vec>, + pub step_errors: Vec>, +} + +#[derive(Deserialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum BatchCond { + Ok { step: i32 }, + Error { step: i32 }, + Not { cond: Box }, + And { conds: Vec }, + Or { conds: Vec }, +} + +#[derive(Serialize, Debug)] +pub struct DescribeResult { + pub params: Vec, + pub cols: Vec, + pub is_explain: bool, + pub is_readonly: bool, +} + +#[derive(Serialize, Debug)] +pub struct DescribeParam { + pub name: Option, +} + +#[derive(Serialize, Debug)] +pub struct DescribeCol { + pub name: String, + pub decltype: Option, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum Value { + Null, + Integer { + #[serde(with = "i64_as_str")] + value: i64, + }, + Float { + value: f64, + }, + Text { + value: Arc, + }, + Blob { + #[serde(with = "bytes_as_base64", rename = 
"base64")] + value: Bytes, + }, +} + +mod i64_as_str { + use serde::{de, ser}; + use serde::{de::Error as _, Serialize as _}; + + pub fn serialize(value: &i64, ser: S) -> Result { + value.to_string().serialize(ser) + } + + pub fn deserialize<'de, D: de::Deserializer<'de>>(de: D) -> Result { + let str_value = <&'de str as de::Deserialize>::deserialize(de)?; + str_value.parse().map_err(|_| { + D::Error::invalid_value( + de::Unexpected::Str(str_value), + &"decimal integer as a string", + ) + }) + } +} + +mod option_i64_as_str { + use serde::{ser, Serialize as _}; + + pub fn serialize(value: &Option, ser: S) -> Result { + value.map(|v| v.to_string()).serialize(ser) + } +} + +mod bytes_as_base64 { + use base64::{engine::general_purpose::STANDARD_NO_PAD, Engine as _}; + use bytes::Bytes; + use serde::{de, ser}; + use serde::{de::Error as _, Serialize as _}; + + pub fn serialize(value: &Bytes, ser: S) -> Result { + STANDARD_NO_PAD.encode(value).serialize(ser) + } + + pub fn deserialize<'de, D: de::Deserializer<'de>>(de: D) -> Result { + let text = <&'de str as de::Deserialize>::deserialize(de)?; + let text = text.trim_end_matches('='); + let bytes = STANDARD_NO_PAD.decode(text).map_err(|_| { + D::Error::invalid_value(de::Unexpected::Str(text), &"binary data encoded as base64") + })?; + Ok(Bytes::from(bytes)) + } +} diff --git a/libsqlx-server/src/hrana/result_builder.rs b/libsqlx-server/src/hrana/result_builder.rs new file mode 100644 index 00000000..94b23775 --- /dev/null +++ b/libsqlx-server/src/hrana/result_builder.rs @@ -0,0 +1,320 @@ +use std::fmt::{self, Write as _}; +use std::io; + +use bytes::Bytes; +use libsqlx::{result_builder::*, FrameNo}; + +use crate::hrana::stmt::{proto_error_from_stmt_error, stmt_error_from_sqld_error}; + +use super::proto; + +#[derive(Debug, Default)] +pub struct SingleStatementBuilder { + has_step: bool, + cols: Vec, + rows: Vec>, + err: Option, + affected_row_count: u64, + last_insert_rowid: Option, + current_size: u64, + 
max_response_size: u64, +} + +impl SingleStatementBuilder { + pub fn into_ret(self) -> Result { + match self.err { + Some(err) => Err(err), + None => Ok(proto::StmtResult { + cols: self.cols, + rows: self.rows, + affected_row_count: self.affected_row_count, + last_insert_rowid: self.last_insert_rowid, + }), + } + } +} + +struct SizeFormatter(u64); + +impl io::Write for SizeFormatter { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0 += buf.len() as u64; + Ok(buf.len()) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +impl fmt::Write for SizeFormatter { + fn write_str(&mut self, s: &str) -> fmt::Result { + self.0 += s.len() as u64; + Ok(()) + } +} + +fn value_json_size(v: &ValueRef) -> u64 { + let mut f = SizeFormatter(0); + match v { + ValueRef::Null => write!(&mut f, r#"{{"type":"null"}}"#).unwrap(), + ValueRef::Integer(i) => write!(&mut f, r#"{{"type":"integer", "value": "{i}"}}"#).unwrap(), + ValueRef::Real(x) => write!(&mut f, r#"{{"type":"integer","value": {x}"}}"#).unwrap(), + ValueRef::Text(s) => { + // error will be caught later. 
+ if let Ok(s) = std::str::from_utf8(s) { + write!(&mut f, r#"{{"type":"text","value":"{s}"}}"#).unwrap() + } + } + ValueRef::Blob(b) => return b.len() as u64, + } + + f.0 +} + +impl ResultBuilder for SingleStatementBuilder { + fn init(&mut self, config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { + *self = Self { + max_response_size: config.max_size.unwrap_or(u64::MAX), + ..Default::default() + }; + + Ok(()) + } + + fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> { + // SingleStatementBuilder only builds a single statement + assert!(!self.has_step); + self.has_step = true; + Ok(()) + } + + fn finish_step( + &mut self, + affected_row_count: u64, + last_insert_rowid: Option, + ) -> Result<(), QueryResultBuilderError> { + self.last_insert_rowid = last_insert_rowid; + self.affected_row_count = affected_row_count; + + Ok(()) + } + + fn step_error(&mut self, error: libsqlx::error::Error) -> Result<(), QueryResultBuilderError> { + assert!(self.err.is_none()); + let mut f = SizeFormatter(0); + write!(&mut f, "{error}").unwrap(); + self.current_size = f.0; + + self.err = Some(error); + + Ok(()) + } + + fn cols_description<'a>( + &mut self, + cols: &mut dyn Iterator, + ) -> Result<(), QueryResultBuilderError> { + assert!(self.err.is_none()); + assert!(self.cols.is_empty()); + + let mut cols_size = 0; + + self.cols.extend(cols.into_iter().map(Into::into).map(|c| { + cols_size += estimate_cols_json_size(&c); + proto::Col { + name: Some(c.name.to_owned()), + decltype: c.decl_ty.map(ToString::to_string), + } + })); + + self.current_size += cols_size; + if self.current_size > self.max_response_size { + return Err(QueryResultBuilderError::ResponseTooLarge( + self.max_response_size, + )); + } + + Ok(()) + } + + fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { + assert!(self.err.is_none()); + assert!(self.rows.is_empty()); + Ok(()) + } + + fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { + 
assert!(self.err.is_none()); + self.rows.push(Vec::with_capacity(self.cols.len())); + Ok(()) + } + + fn add_row_value(&mut self, v: ValueRef) -> Result<(), QueryResultBuilderError> { + assert!(self.err.is_none()); + let estimate_size = value_json_size(&v); + if self.current_size + estimate_size > self.max_response_size { + return Err(QueryResultBuilderError::ResponseTooLarge( + self.max_response_size, + )); + } + + self.current_size += estimate_size; + + let val = match v { + ValueRef::Null => proto::Value::Null, + ValueRef::Integer(value) => proto::Value::Integer { value }, + ValueRef::Real(value) => proto::Value::Float { value }, + ValueRef::Text(s) => proto::Value::Text { + value: String::from_utf8(s.to_vec()) + .map_err(QueryResultBuilderError::from_any)? + .into(), + }, + ValueRef::Blob(d) => proto::Value::Blob { + value: Bytes::copy_from_slice(d), + }, + }; + + self.rows + .last_mut() + .expect("row must be initialized") + .push(val); + + Ok(()) + } + + fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { + assert!(self.err.is_none()); + Ok(()) + } + + fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> { + assert!(self.err.is_none()); + Ok(()) + } + + fn finish( + &mut self, + _is_txn: bool, + _frame_no: Option, + ) -> Result<(), QueryResultBuilderError> { + Ok(()) + } +} + +fn estimate_cols_json_size(c: &Column) -> u64 { + let mut f = SizeFormatter(0); + write!( + &mut f, + r#"{{"name":"{}","decltype":"{}"}}"#, + c.name, + c.decl_ty.unwrap_or("null") + ) + .unwrap(); + f.0 +} + +#[derive(Debug, Default)] +pub struct HranaBatchProtoBuilder { + step_results: Vec>, + step_errors: Vec>, + stmt_builder: SingleStatementBuilder, + current_size: u64, + max_response_size: u64, + step_empty: bool, +} + +impl HranaBatchProtoBuilder { + pub fn into_ret(self) -> proto::BatchResult { + proto::BatchResult { + step_results: self.step_results, + step_errors: self.step_errors, + } + } +} + +impl ResultBuilder for HranaBatchProtoBuilder { + fn 
init(&mut self, config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { + *self = Self { + max_response_size: config.max_size.unwrap_or(u64::MAX), + ..Default::default() + }; + self.stmt_builder.init(config)?; + Ok(()) + } + + fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> { + self.step_empty = true; + self.stmt_builder.begin_step() + } + + fn finish_step( + &mut self, + affected_row_count: u64, + last_insert_rowid: Option, + ) -> Result<(), QueryResultBuilderError> { + self.stmt_builder + .finish_step(affected_row_count, last_insert_rowid)?; + self.current_size += self.stmt_builder.current_size; + + let new_builder = SingleStatementBuilder { + current_size: 0, + max_response_size: self.max_response_size - self.current_size, + ..Default::default() + }; + match std::mem::replace(&mut self.stmt_builder, new_builder).into_ret() { + Ok(res) => { + self.step_results.push((!self.step_empty).then_some(res)); + self.step_errors.push(None); + } + Err(e) => { + self.step_results.push(None); + self.step_errors.push(Some(proto_error_from_stmt_error( + &stmt_error_from_sqld_error(e).map_err(QueryResultBuilderError::from_any)?, + ))); + } + } + + Ok(()) + } + + fn step_error(&mut self, error: libsqlx::error::Error) -> Result<(), QueryResultBuilderError> { + self.stmt_builder.step_error(error) + } + + fn cols_description( + &mut self, + cols: &mut dyn Iterator, + ) -> Result<(), QueryResultBuilderError> { + self.step_empty = false; + self.stmt_builder.cols_description(cols) + } + + fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { + self.stmt_builder.begin_rows() + } + + fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { + self.stmt_builder.begin_row() + } + + fn add_row_value(&mut self, v: ValueRef) -> Result<(), QueryResultBuilderError> { + self.stmt_builder.add_row_value(v) + } + + fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { + self.stmt_builder.finish_row() + } + + fn finish_rows(&mut self) -> 
Result<(), QueryResultBuilderError> { + Ok(()) + } + + fn finish( + &mut self, + _is_txn: bool, + _frame_no: Option, + ) -> Result<(), QueryResultBuilderError> { + Ok(()) + } +} diff --git a/libsqlx-server/src/hrana/stmt.rs b/libsqlx-server/src/hrana/stmt.rs new file mode 100644 index 00000000..e74c3d42 --- /dev/null +++ b/libsqlx-server/src/hrana/stmt.rs @@ -0,0 +1,289 @@ +use std::collections::HashMap; + +use color_eyre::eyre::{bail, anyhow}; +use libsqlx::analysis::Statement; +use libsqlx::query::{Query, Params, Value}; + +use super::result_builder::SingleStatementBuilder; +use super::{proto, ProtocolError, Version}; +use crate::allocation::ConnectionHandle; +use crate::hrana; + +/// An error during execution of an SQL statement. +#[derive(thiserror::Error, Debug)] +pub enum StmtError { + #[error("SQL string could not be parsed: {source}")] + SqlParse { source: color_eyre::eyre::Error }, + #[error("SQL string does not contain any statement")] + SqlNoStmt, + #[error("SQL string contains more than one statement")] + SqlManyStmts, + #[error("Arguments do not match SQL parameters: {msg}")] + ArgsInvalid { msg: String }, + #[error("Specifying both positional and named arguments is not supported")] + ArgsBothPositionalAndNamed, + + #[error("Transaction timed out")] + TransactionTimeout, + #[error("Server cannot handle additional transactions")] + TransactionBusy, + #[error("SQLite error: {message}")] + SqliteError { + source: libsqlx::rusqlite::ffi::Error, + message: String, + }, + #[error("SQL input error: {message} (at offset {offset})")] + SqlInputError { + source: color_eyre::eyre::Error, + message: String, + offset: i32, + }, + + #[error("Operation was blocked{}", .reason.as_ref().map(|msg| format!(": {}", msg)).unwrap_or_default())] + Blocked { reason: Option }, +} + +pub async fn execute_stmt( + conn: &ConnectionHandle, + query: Query, +) -> color_eyre::Result { + let builder = conn.exec(move |conn| -> color_eyre::Result<_> { + let mut builder = 
SingleStatementBuilder::default(); + let pgm = libsqlx::program::Program::from_queries(std::iter::once(query)); + conn.execute_program(pgm, &mut builder)?; + + Ok(builder) + + }).await??; + + builder + .into_ret() + .map_err(|sqld_error| match stmt_error_from_sqld_error(sqld_error) { + Ok(stmt_error) => anyhow!(stmt_error), + Err(sqld_error) => anyhow!(sqld_error), + }) +} + +pub async fn describe_stmt( + _db: &ConnectionHandle, + _sql: String, +) -> color_eyre::Result { + todo!(); + // match db.describe(sql).await? { + // Ok(describe_response) => todo!(), + // // Ok(proto_describe_result_from_describe_response( + // // describe_response, + // // )), + // Err(sqld_error) => match stmt_error_from_sqld_error(sqld_error) { + // Ok(stmt_error) => bail!(stmt_error), + // Err(sqld_error) => bail!(sqld_error), + // }, + // } +} + +pub fn proto_stmt_to_query( + proto_stmt: &proto::Stmt, + sqls: &HashMap, + version: Version, +) -> color_eyre::Result { + let sql = proto_sql_to_sql(proto_stmt.sql.as_deref(), proto_stmt.sql_id, sqls, version)?; + + let mut stmt_iter = Statement::parse(sql); + let stmt = match stmt_iter.next() { + Some(Ok(stmt)) => stmt, + Some(Err(err)) => bail!(StmtError::SqlParse { source: err.into() }), + None => bail!(StmtError::SqlNoStmt), + }; + + if stmt_iter.next().is_some() { + bail!(StmtError::SqlManyStmts) + } + + let params = if proto_stmt.named_args.is_empty() { + let values = proto_stmt.args.iter().map(proto_value_to_value).collect(); + Params::Positional(values) + } else if proto_stmt.args.is_empty() { + let values = proto_stmt + .named_args + .iter() + .map(|arg| (arg.name.clone(), proto_value_to_value(&arg.value))) + .collect(); + Params::Named(values) + } else { + bail!(StmtError::ArgsBothPositionalAndNamed) + }; + + let want_rows = proto_stmt.want_rows.unwrap_or(true); + Ok(Query { + stmt, + params, + want_rows, + }) +} + +pub fn proto_sql_to_sql<'s>( + proto_sql: Option<&'s str>, + proto_sql_id: Option, + sqls: &'s HashMap, + verion: 
Version, +) -> Result<&'s str, ProtocolError> { + if proto_sql_id.is_some() && verion < Version::Hrana2 { + return Err(ProtocolError::NotSupported { + what: "`sql_id`", + min_version: Version::Hrana2, + }); + } + + match (proto_sql, proto_sql_id) { + (Some(sql), None) => Ok(sql), + (None, Some(sql_id)) => match sqls.get(&sql_id) { + Some(sql) => Ok(sql), + None => Err(ProtocolError::SqlNotFound { sql_id }), + }, + (Some(_), Some(_)) => Err(ProtocolError::SqlIdAndSqlGiven), + (None, None) => Err(ProtocolError::SqlIdOrSqlNotGiven), + } +} + +fn proto_value_to_value(proto_value: &proto::Value) -> Value { + match proto_value { + proto::Value::Null => Value::Null, + proto::Value::Integer { value } => Value::Integer(*value), + proto::Value::Float { value } => Value::Real(*value), + proto::Value::Text { value } => Value::Text(value.as_ref().into()), + proto::Value::Blob { value } => Value::Blob(value.as_ref().into()), + } +} + +fn proto_value_from_value(value: Value) -> proto::Value { + match value { + Value::Null => proto::Value::Null, + Value::Integer(value) => proto::Value::Integer { value }, + Value::Real(value) => proto::Value::Float { value }, + Value::Text(value) => proto::Value::Text { + value: value.into(), + }, + Value::Blob(value) => proto::Value::Blob { + value: value.into(), + }, + } +} + +// fn proto_describe_result_from_describe_response( +// response: DescribeResponse, +// ) -> proto::DescribeResult { +// proto::DescribeResult { +// params: response +// .params +// .into_iter() +// .map(|p| proto::DescribeParam { name: p.name }) +// .collect(), +// cols: response +// .cols +// .into_iter() +// .map(|c| proto::DescribeCol { +// name: c.name, +// decltype: c.decltype, +// }) +// .collect(), +// is_explain: response.is_explain, +// is_readonly: response.is_readonly, +// } +// } + +pub fn stmt_error_from_sqld_error(sqld_error: libsqlx::error::Error) -> Result { + Ok(match sqld_error { + libsqlx::error::Error::LibSqlInvalidQueryParams(msg) => 
StmtError::ArgsInvalid { msg }, + libsqlx::error::Error::LibSqlTxTimeout => StmtError::TransactionTimeout, + libsqlx::error::Error::LibSqlTxBusy => StmtError::TransactionBusy, + libsqlx::error::Error::Blocked(reason) => StmtError::Blocked { reason }, + libsqlx::error::Error::RusqliteError(rusqlite_error) => match rusqlite_error { + libsqlx::error::RusqliteError::SqliteFailure(sqlite_error, Some(message)) => StmtError::SqliteError { + source: sqlite_error, + message, + }, + libsqlx::error::RusqliteError::SqliteFailure(sqlite_error, None) => StmtError::SqliteError { + message: sqlite_error.to_string(), + source: sqlite_error, + }, + libsqlx::error::RusqliteError::SqlInputError { + error: sqlite_error, + msg: message, + offset, + .. + } => StmtError::SqlInputError { + source: sqlite_error.into(), + message, + offset, + }, + rusqlite_error => return Err(libsqlx::error::Error::RusqliteError(rusqlite_error)), + }, + sqld_error => return Err(sqld_error), + }) +} + +pub fn proto_error_from_stmt_error(error: &StmtError) -> hrana::proto::Error { + hrana::proto::Error { + message: error.to_string(), + code: error.code().into(), + } +} + +impl StmtError { + pub fn code(&self) -> &'static str { + match self { + Self::SqlParse { .. } => "SQL_PARSE_ERROR", + Self::SqlNoStmt => "SQL_NO_STATEMENT", + Self::SqlManyStmts => "SQL_MANY_STATEMENTS", + Self::ArgsInvalid { .. } => "ARGS_INVALID", + Self::ArgsBothPositionalAndNamed => "ARGS_BOTH_POSITIONAL_AND_NAMED", + Self::TransactionTimeout => "TRANSACTION_TIMEOUT", + Self::TransactionBusy => "TRANSACTION_BUSY", + Self::SqliteError { source, .. } => sqlite_error_code(source.code), + Self::SqlInputError { .. } => "SQL_INPUT_ERROR", + Self::Blocked { .. 
} => "BLOCKED", + } + } +} + +fn sqlite_error_code(code: libsqlx::error::ErrorCode) -> &'static str { + match code { + libsqlx::error::ErrorCode::InternalMalfunction => "SQLITE_INTERNAL", + libsqlx::error::ErrorCode::PermissionDenied => "SQLITE_PERM", + libsqlx::error::ErrorCode::OperationAborted => "SQLITE_ABORT", + libsqlx::error::ErrorCode::DatabaseBusy => "SQLITE_BUSY", + libsqlx::error::ErrorCode::DatabaseLocked => "SQLITE_LOCKED", + libsqlx::error::ErrorCode::OutOfMemory => "SQLITE_NOMEM", + libsqlx::error::ErrorCode::ReadOnly => "SQLITE_READONLY", + libsqlx::error::ErrorCode::OperationInterrupted => "SQLITE_INTERRUPT", + libsqlx::error::ErrorCode::SystemIoFailure => "SQLITE_IOERR", + libsqlx::error::ErrorCode::DatabaseCorrupt => "SQLITE_CORRUPT", + libsqlx::error::ErrorCode::NotFound => "SQLITE_NOTFOUND", + libsqlx::error::ErrorCode::DiskFull => "SQLITE_FULL", + libsqlx::error::ErrorCode::CannotOpen => "SQLITE_CANTOPEN", + libsqlx::error::ErrorCode::FileLockingProtocolFailed => "SQLITE_PROTOCOL", + libsqlx::error::ErrorCode::SchemaChanged => "SQLITE_SCHEMA", + libsqlx::error::ErrorCode::TooBig => "SQLITE_TOOBIG", + libsqlx::error::ErrorCode::ConstraintViolation => "SQLITE_CONSTRAINT", + libsqlx::error::ErrorCode::TypeMismatch => "SQLITE_MISMATCH", + libsqlx::error::ErrorCode::ApiMisuse => "SQLITE_MISUSE", + libsqlx::error::ErrorCode::NoLargeFileSupport => "SQLITE_NOLFS", + libsqlx::error::ErrorCode::AuthorizationForStatementDenied => "SQLITE_AUTH", + libsqlx::error::ErrorCode::ParameterOutOfRange => "SQLITE_RANGE", + libsqlx::error::ErrorCode::NotADatabase => "SQLITE_NOTADB", + libsqlx::error::ErrorCode::Unknown => "SQLITE_UNKNOWN", + _ => "SQLITE_UNKNOWN", + } +} + +impl From<&proto::Value> for Value { + fn from(proto_value: &proto::Value) -> Value { + proto_value_to_value(proto_value) + } +} + +impl From for proto::Value { + fn from(value: Value) -> proto::Value { + proto_value_from_value(value) + } +} diff --git a/libsqlx-server/src/hrana/ws/conn.rs 
b/libsqlx-server/src/hrana/ws/conn.rs new file mode 100644 index 00000000..44daf98f --- /dev/null +++ b/libsqlx-server/src/hrana/ws/conn.rs @@ -0,0 +1,301 @@ +use std::borrow::Cow; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use anyhow::{bail, Context as _, Result}; +use futures::stream::FuturesUnordered; +use futures::{ready, FutureExt as _, StreamExt as _}; +use tokio::sync::oneshot; +use tokio_tungstenite::tungstenite; +use tungstenite::protocol::frame::coding::CloseCode; + +use crate::database::Database; + +use super::super::{ProtocolError, Version}; +use super::handshake::WebSocket; +use super::{handshake, proto, session, Server, Upgrade}; + +/// State of a Hrana connection. +struct Conn { + conn_id: u64, + server: Arc>, + ws: WebSocket, + ws_closed: bool, + /// The version of the protocol that has been negotiated in the WebSocket handshake. + version: Version, + /// After a successful authentication, this contains the session-level state of the connection. + session: Option>, + /// Join set for all tasks that were spawned to handle the connection. + join_set: tokio::task::JoinSet<()>, + /// Future responses to requests that we have received but are evaluating asynchronously. + responses: FuturesUnordered, +} + +/// A `Future` that stores a handle to a future response to request which is being evaluated +/// asynchronously. +struct ResponseFuture { + /// The request id, which must be included in the response. + request_id: i32, + /// The future that will be resolved with the response. 
+ response_rx: futures::future::Fuse>>, +} + +pub(super) async fn handle_tcp( + server: Arc>, + socket: tokio::net::TcpStream, + conn_id: u64, +) -> Result<()> { + let (ws, version) = handshake::handshake_tcp(socket) + .await + .context("Could not perform the WebSocket handshake on TCP connection")?; + handle_ws(server, ws, version, conn_id).await +} + +pub(super) async fn handle_upgrade( + server: Arc>, + upgrade: Upgrade, + conn_id: u64, +) -> Result<()> { + let (ws, version) = handshake::handshake_upgrade(upgrade) + .await + .context("Could not perform the WebSocket handshake on HTTP connection")?; + handle_ws(server, ws, version, conn_id).await +} + +async fn handle_ws( + server: Arc>, + ws: WebSocket, + version: Version, + conn_id: u64, +) -> Result<()> { + let mut conn = Conn { + conn_id, + server, + ws, + ws_closed: false, + version, + session: None, + join_set: tokio::task::JoinSet::new(), + responses: FuturesUnordered::new(), + }; + + loop { + if let Some(kicker) = conn.server.idle_kicker.as_ref() { + kicker.kick(); + } + + tokio::select! 
{ + Some(client_msg_res) = conn.ws.recv() => { + let client_msg = client_msg_res + .context("Could not receive a WebSocket message")?; + match handle_msg(&mut conn, client_msg).await { + Ok(true) => continue, + Ok(false) => break, + Err(err) => { + match err.downcast::() { + Ok(proto_err) => { + tracing::warn!( + "Connection #{} terminated due to protocol error: {}", + conn.conn_id, + proto_err, + ); + let close_code = protocol_error_to_close_code(&proto_err); + close(&mut conn, close_code, proto_err.to_string()).await; + return Ok(()) + } + Err(err) => { + close(&mut conn, CloseCode::Error, "Internal server error".into()).await; + return Err(err); + } + } + } + } + }, + Some(task_res) = conn.join_set.join_next() => { + task_res.expect("Connection subtask failed") + }, + Some(response_res) = conn.responses.next() => { + let response_msg = response_res?; + send_msg(&mut conn, &response_msg).await?; + }, + else => break, + } + } + + close( + &mut conn, + CloseCode::Normal, + "Thank you for using sqld".into(), + ) + .await; + Ok(()) +} + +async fn handle_msg( + conn: &mut Conn, + client_msg: tungstenite::Message, +) -> Result { + match client_msg { + tungstenite::Message::Text(client_msg) => { + // client messages are received as text WebSocket messages that encode the `ClientMsg` + // in JSON + let client_msg: proto::ClientMsg = match serde_json::from_str(&client_msg) { + Ok(client_msg) => client_msg, + Err(err) => bail!(ProtocolError::Deserialize { source: err }), + }; + + match client_msg { + proto::ClientMsg::Hello { jwt } => handle_hello_msg(conn, jwt).await, + proto::ClientMsg::Request { + request_id, + request, + } => handle_request_msg(conn, request_id, request).await, + } + } + tungstenite::Message::Binary(_) => bail!(ProtocolError::BinaryWebSocketMessage), + tungstenite::Message::Ping(ping_data) => { + let pong_msg = tungstenite::Message::Pong(ping_data); + conn.ws + .send(pong_msg) + .await + .context("Could not send pong to the WebSocket")?; + Ok(true) + } 
+ tungstenite::Message::Pong(_) => Ok(true), + tungstenite::Message::Close(_) => Ok(false), + tungstenite::Message::Frame(_) => panic!("Received a tungstenite::Message::Frame"), + } +} + +async fn handle_hello_msg(conn: &mut Conn, jwt: Option) -> Result { + let hello_res = match conn.session.as_mut() { + None => session::handle_initial_hello(&conn.server, conn.version, jwt) + .map(|session| conn.session = Some(session)), + Some(session) => session::handle_repeated_hello(&conn.server, session, jwt), + }; + + match hello_res { + Ok(_) => { + send_msg(conn, &proto::ServerMsg::HelloOk {}).await?; + Ok(true) + } + Err(err) => match downcast_error(err) { + Ok(error) => { + send_msg(conn, &proto::ServerMsg::HelloError { error }).await?; + Ok(false) + } + Err(err) => Err(err), + }, + } +} + +async fn handle_request_msg( + conn: &mut Conn, + request_id: i32, + request: proto::Request, +) -> Result { + let Some(session) = conn.session.as_mut() else { + bail!(ProtocolError::RequestBeforeHello) + }; + + let response_rx = session::handle_request(&conn.server, session, &mut conn.join_set, request) + .await + .unwrap_or_else(|err| { + // we got an error immediately, but let's treat it as a special case of the general + // flow + let (tx, rx) = oneshot::channel(); + tx.send(Err(err)).unwrap(); + rx + }); + + conn.responses.push(ResponseFuture { + request_id, + response_rx: response_rx.fuse(), + }); + Ok(true) +} + +impl Future for ResponseFuture { + type Output = Result; + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + match ready!(Pin::new(&mut self.response_rx).poll(cx)) { + Ok(Ok(response)) => Poll::Ready(Ok(proto::ServerMsg::ResponseOk { + request_id: self.request_id, + response, + })), + Ok(Err(err)) => match downcast_error(err) { + Ok(error) => Poll::Ready(Ok(proto::ServerMsg::ResponseError { + request_id: self.request_id, + error, + })), + Err(err) => Poll::Ready(Err(err)), + }, + Err(_recv_err) => { + // do not propagate this error, because the error that 
caused the receiver to drop + // is very likely propagating from another task at this moment, and we don't want + // to hide it. + // this is also the reason why we need to use `Fuse` in self.response_rx + tracing::warn!("Response sender was dropped"); + Poll::Pending + } + } + } +} + +fn downcast_error(err: anyhow::Error) -> Result { + match err.downcast_ref::() { + Some(error) => Ok(proto::Error { + message: error.to_string(), + code: error.code().into(), + }), + None => Err(err), + } +} + +async fn send_msg(conn: &mut Conn, msg: &proto::ServerMsg) -> Result<()> { + let msg = serde_json::to_string(&msg).context("Could not serialize response message")?; + let msg = tungstenite::Message::Text(msg); + conn.ws + .send(msg) + .await + .context("Could not send response to the WebSocket") +} + +async fn close(conn: &mut Conn, code: CloseCode, reason: String) { + if conn.ws_closed { + return; + } + + let close_frame = tungstenite::protocol::frame::CloseFrame { + code, + reason: Cow::Owned(reason), + }; + if let Err(err) = conn + .ws + .send(tungstenite::Message::Close(Some(close_frame))) + .await + { + if !matches!( + err, + tungstenite::Error::AlreadyClosed | tungstenite::Error::ConnectionClosed + ) { + tracing::warn!( + "Could not send close frame to WebSocket of connection #{}: {:?}", + conn.conn_id, + err + ); + } + } + + conn.ws_closed = true; +} + +fn protocol_error_to_close_code(err: &ProtocolError) -> CloseCode { + match err { + ProtocolError::Deserialize { .. 
} => CloseCode::Invalid, + ProtocolError::BinaryWebSocketMessage => CloseCode::Unsupported, + _ => CloseCode::Policy, + } +} diff --git a/libsqlx-server/src/hrana/ws/handshake.rs b/libsqlx-server/src/hrana/ws/handshake.rs new file mode 100644 index 00000000..ef187a6a --- /dev/null +++ b/libsqlx-server/src/hrana/ws/handshake.rs @@ -0,0 +1,140 @@ +use anyhow::{anyhow, bail, Context as _, Result}; +use futures::{SinkExt as _, StreamExt as _}; +use tokio_tungstenite::tungstenite; +use tungstenite::http; + +use super::super::Version; +use super::Upgrade; + +#[derive(Debug)] +pub enum WebSocket { + Tcp(tokio_tungstenite::WebSocketStream), + Upgraded(tokio_tungstenite::WebSocketStream), +} + +pub async fn handshake_tcp(socket: tokio::net::TcpStream) -> Result<(WebSocket, Version)> { + let mut version = None; + let callback = |req: &http::Request<()>, resp: http::Response<()>| { + let (mut resp_parts, _) = resp.into_parts(); + resp_parts + .headers + .insert("server", http::HeaderValue::from_static("sqld-hrana-tcp")); + + match negotiate_version(req.headers(), &mut resp_parts.headers) { + Ok(version_) => { + version = Some(version_); + Ok(http::Response::from_parts(resp_parts, ())) + } + Err(resp_body) => Err(http::Response::from_parts(resp_parts, Some(resp_body))), + } + }; + + let ws_config = Some(get_ws_config()); + let stream = + tokio_tungstenite::accept_hdr_async_with_config(socket, callback, ws_config).await?; + Ok((WebSocket::Tcp(stream), version.unwrap())) +} + +pub async fn handshake_upgrade(upgrade: Upgrade) -> Result<(WebSocket, Version)> { + let mut req = upgrade.request; + + let ws_config = Some(get_ws_config()); + let (mut resp, stream_fut_version_res) = match hyper_tungstenite::upgrade(&mut req, ws_config) { + Ok((mut resp, stream_fut)) => match negotiate_version(req.headers(), resp.headers_mut()) { + Ok(version) => (resp, Ok((stream_fut, version))), + Err(msg) => { + *resp.status_mut() = http::StatusCode::BAD_REQUEST; + *resp.body_mut() = 
hyper::Body::from(msg.clone()); + ( + resp, + Err(anyhow!("Could not negotiate subprotocol: {}", msg)), + ) + } + }, + Err(err) => { + let resp = http::Response::builder() + .status(http::StatusCode::BAD_REQUEST) + .body(hyper::Body::from(format!("{err}"))) + .unwrap(); + ( + resp, + Err(anyhow!(err).context("Protocol error in HTTP upgrade")), + ) + } + }; + + resp.headers_mut().insert( + "server", + http::HeaderValue::from_static("sqld-hrana-upgrade"), + ); + if upgrade.response_tx.send(resp).is_err() { + bail!("Could not send the HTTP upgrade response") + } + + let (stream_fut, version) = stream_fut_version_res?; + let stream = stream_fut + .await + .context("Could not upgrade HTTP request to a WebSocket")?; + Ok((WebSocket::Upgraded(stream), version)) +} + +fn negotiate_version( + req_headers: &http::HeaderMap, + resp_headers: &mut http::HeaderMap, +) -> Result { + if let Some(protocol_hdr) = req_headers.get("sec-websocket-protocol") { + let supported_by_client = protocol_hdr + .to_str() + .unwrap_or("") + .split(',') + .map(|p| p.trim()); + + let mut hrana1_supported = false; + let mut hrana2_supported = false; + for protocol_str in supported_by_client { + hrana1_supported |= protocol_str.eq_ignore_ascii_case("hrana1"); + hrana2_supported |= protocol_str.eq_ignore_ascii_case("hrana2"); + } + + let version = if hrana2_supported { + Version::Hrana2 + } else if hrana1_supported { + Version::Hrana1 + } else { + return Err("Only 'hrana1' and 'hrana2' subprotocols are supported".into()); + }; + + resp_headers.append( + "sec-websocket-protocol", + http::HeaderValue::from_str(&version.to_string()).unwrap(), + ); + Ok(version) + } else { + // Sec-WebSocket-Protocol header not present, assume that the client wants hrana1 + // According to RFC 6455, we must not set the Sec-WebSocket-Protocol response header + Ok(Version::Hrana1) + } +} + +fn get_ws_config() -> tungstenite::protocol::WebSocketConfig { + tungstenite::protocol::WebSocketConfig { + max_send_queue: Some(1 << 
20), + ..Default::default() + } +} + +impl WebSocket { + pub async fn recv(&mut self) -> Option> { + match self { + Self::Tcp(stream) => stream.next().await, + Self::Upgraded(stream) => stream.next().await, + } + } + + pub async fn send(&mut self, msg: tungstenite::Message) -> tungstenite::Result<()> { + match self { + Self::Tcp(stream) => stream.send(msg).await, + Self::Upgraded(stream) => stream.send(msg).await, + } + } +} diff --git a/libsqlx-server/src/hrana/ws/mod.rs b/libsqlx-server/src/hrana/ws/mod.rs new file mode 100644 index 00000000..32a34957 --- /dev/null +++ b/libsqlx-server/src/hrana/ws/mod.rs @@ -0,0 +1,104 @@ +use crate::auth::Auth; +use crate::database::Database; +use crate::utils::services::idle_shutdown::IdleKicker; +use anyhow::{Context as _, Result}; +use enclose::enclose; +use std::net::SocketAddr; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; +use tokio::sync::{mpsc, oneshot}; + +pub mod proto; + +mod conn; +mod handshake; +mod session; + +struct Server { + db_factory: Arc>, + auth: Arc, + idle_kicker: Option, + next_conn_id: AtomicU64, +} + +#[derive(Debug)] +pub struct Accept { + pub socket: tokio::net::TcpStream, + pub peer_addr: SocketAddr, +} + +#[derive(Debug)] +pub struct Upgrade { + pub request: hyper::Request, + pub response_tx: oneshot::Sender>, +} + +pub async fn serve( + db_factory: Arc>, + auth: Arc, + idle_kicker: Option, + mut accept_rx: mpsc::Receiver, + mut upgrade_rx: mpsc::Receiver, +) -> Result<()> { + let server = Arc::new(Server { + db_factory, + auth, + idle_kicker, + next_conn_id: AtomicU64::new(0), + }); + + let mut join_set = tokio::task::JoinSet::new(); + loop { + if let Some(kicker) = server.idle_kicker.as_ref() { + kicker.kick(); + } + + tokio::select! 
{ + Some(accept) = accept_rx.recv() => { + let conn_id = server.next_conn_id.fetch_add(1, Ordering::AcqRel); + tracing::info!("Received TCP connection #{} from {}", conn_id, accept.peer_addr); + + join_set.spawn(enclose!{(server, conn_id) async move { + match conn::handle_tcp(server, accept.socket, conn_id).await { + Ok(_) => tracing::info!("TCP connection #{} was terminated", conn_id), + Err(err) => tracing::error!("TCP connection #{} failed: {:?}", conn_id, err), + } + }}); + }, + Some(upgrade) = upgrade_rx.recv() => { + let conn_id = server.next_conn_id.fetch_add(1, Ordering::AcqRel); + tracing::info!("Received HTTP upgrade connection #{}", conn_id); + + join_set.spawn(enclose!{(server, conn_id) async move { + match conn::handle_upgrade(server, upgrade, conn_id).await { + Ok(_) => tracing::info!("HTTP upgrade connection #{} was terminated", conn_id), + Err(err) => tracing::error!("HTTP upgrade connection #{} failed: {:?}", conn_id, err), + } + }}); + }, + Some(task_res) = join_set.join_next() => { + task_res.expect("Hrana connection task failed") + }, + else => { + tracing::error!("hrana server loop exited"); + return Ok(()) + } + } + } +} + +pub async fn listen(bind_addr: SocketAddr, accept_tx: mpsc::Sender) -> Result<()> { + let listener = tokio::net::TcpListener::bind(bind_addr) + .await + .context("Could not bind TCP listener")?; + let local_addr = listener.local_addr()?; + tracing::info!("Listening for Hrana connections on {}", local_addr); + + loop { + let (socket, peer_addr) = listener + .accept() + .await + .context("Could not accept a TCP connection")?; + let _: Result<_, _> = accept_tx.send(Accept { socket, peer_addr }).await; + } +} diff --git a/libsqlx-server/src/hrana/ws/proto.rs b/libsqlx-server/src/hrana/ws/proto.rs new file mode 100644 index 00000000..6bb88367 --- /dev/null +++ b/libsqlx-server/src/hrana/ws/proto.rs @@ -0,0 +1,127 @@ +//! Structures for Hrana-over-WebSockets. 
+ +pub use super::super::proto::*; +use serde::{Deserialize, Serialize}; + +#[derive(Deserialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ClientMsg { + Hello { jwt: Option }, + Request { request_id: i32, request: Request }, +} + +#[derive(Serialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ServerMsg { + HelloOk {}, + HelloError { error: Error }, + ResponseOk { request_id: i32, response: Response }, + ResponseError { request_id: i32, error: Error }, +} + +#[derive(Deserialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum Request { + OpenStream(OpenStreamReq), + CloseStream(CloseStreamReq), + Execute(ExecuteReq), + Batch(BatchReq), + Sequence(SequenceReq), + Describe(DescribeReq), + StoreSql(StoreSqlReq), + CloseSql(CloseSqlReq), +} + +#[derive(Serialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum Response { + OpenStream(OpenStreamResp), + CloseStream(CloseStreamResp), + Execute(ExecuteResp), + Batch(BatchResp), + Sequence(SequenceResp), + Describe(DescribeResp), + StoreSql(StoreSqlResp), + CloseSql(CloseSqlResp), +} + +#[derive(Deserialize, Debug)] +pub struct OpenStreamReq { + pub stream_id: i32, +} + +#[derive(Serialize, Debug)] +pub struct OpenStreamResp {} + +#[derive(Deserialize, Debug)] +pub struct CloseStreamReq { + pub stream_id: i32, +} + +#[derive(Serialize, Debug)] +pub struct CloseStreamResp {} + +#[derive(Deserialize, Debug)] +pub struct ExecuteReq { + pub stream_id: i32, + pub stmt: Stmt, +} + +#[derive(Serialize, Debug)] +pub struct ExecuteResp { + pub result: StmtResult, +} + +#[derive(Deserialize, Debug)] +pub struct BatchReq { + pub stream_id: i32, + pub batch: Batch, +} + +#[derive(Serialize, Debug)] +pub struct BatchResp { + pub result: BatchResult, +} + +#[derive(Deserialize, Debug)] +pub struct SequenceReq { + pub stream_id: i32, + #[serde(default)] + pub sql: Option, + #[serde(default)] + pub sql_id: Option, +} + +#[derive(Serialize, 
Debug)] +pub struct SequenceResp {} + +#[derive(Deserialize, Debug)] +pub struct DescribeReq { + pub stream_id: i32, + #[serde(default)] + pub sql: Option, + #[serde(default)] + pub sql_id: Option, +} + +#[derive(Serialize, Debug)] +pub struct DescribeResp { + pub result: DescribeResult, +} + +#[derive(Deserialize, Debug)] +pub struct StoreSqlReq { + pub sql_id: i32, + pub sql: String, +} + +#[derive(Serialize, Debug)] +pub struct StoreSqlResp {} + +#[derive(Deserialize, Debug)] +pub struct CloseSqlReq { + pub sql_id: i32, +} + +#[derive(Serialize, Debug)] +pub struct CloseSqlResp {} diff --git a/libsqlx-server/src/hrana/ws/session.rs b/libsqlx-server/src/hrana/ws/session.rs new file mode 100644 index 00000000..f59bcecc --- /dev/null +++ b/libsqlx-server/src/hrana/ws/session.rs @@ -0,0 +1,329 @@ +use std::collections::HashMap; + +use anyhow::{anyhow, bail, Context as _, Result}; +use futures::future::BoxFuture; +use tokio::sync::{mpsc, oneshot}; + +use super::super::{batch, stmt, ProtocolError, Version}; +use super::{proto, Server}; +use crate::auth::{AuthError, Authenticated}; +use crate::database::Database; + +/// Session-level state of an authenticated Hrana connection. +pub struct Session { + authenticated: Authenticated, + version: Version, + streams: HashMap>, + sqls: HashMap, +} + +struct StreamHandle { + job_tx: mpsc::Sender>, +} + +/// An arbitrary job that is executed on a [`Stream`]. +/// +/// All jobs are executed sequentially on a single task (as evidenced by the `&mut Stream` passed +/// to `f`). +struct StreamJob { + /// The async function which performs the job. + #[allow(clippy::type_complexity)] + f: Box FnOnce(&'s mut Stream) -> BoxFuture<'s, Result> + Send>, + /// The result of `f` will be sent here. + resp_tx: oneshot::Sender>, +} + +/// State of a Hrana stream, which corresponds to a standalone database connection. 
+struct Stream { + /// The database handle is `None` when the stream is created, and normally set to `Some` by the + /// first job executed on the stream by the [`proto::OpenStreamReq`] request. However, if that + /// request returns an error, the following requests may encounter a `None` here. + db: Option, +} + +/// An error which can be converted to a Hrana [Error][proto::Error]. +#[derive(thiserror::Error, Debug)] +pub enum ResponseError { + #[error("Authentication failed: {source}")] + Auth { source: AuthError }, + #[error("Stream {stream_id} has failed to open")] + StreamNotOpen { stream_id: i32 }, + #[error("The server already stores {count} SQL texts, it cannot store more")] + SqlTooMany { count: usize }, + #[error(transparent)] + Stmt(stmt::StmtError), +} + +pub(super) fn handle_initial_hello( + server: &Server, + version: Version, + jwt: Option, +) -> Result> { + let authenticated = server + .auth + .authenticate_jwt(jwt.as_deref()) + .map_err(|err| anyhow!(ResponseError::Auth { source: err }))?; + + Ok(Session { + authenticated, + version, + streams: HashMap::new(), + sqls: HashMap::new(), + }) +} + +pub(super) fn handle_repeated_hello( + server: &Server, + session: &mut Session, + jwt: Option, +) -> Result<()> { + if session.version < Version::Hrana2 { + bail!(ProtocolError::NotSupported { + what: "Repeated hello message", + min_version: Version::Hrana2, + }) + } + + session.authenticated = server + .auth + .authenticate_jwt(jwt.as_deref()) + .map_err(|err| anyhow!(ResponseError::Auth { source: err }))?; + Ok(()) +} + +pub(super) async fn handle_request( + server: &Server, + session: &mut Session, + join_set: &mut tokio::task::JoinSet<()>, + req: proto::Request, +) -> Result>> { + // TODO: this function has rotten: it is too long and contains too much duplicated code. It + // should be refactored at the next opportunity, together with code in stmt.rs and batch.rs + + let (resp_tx, resp_rx) = oneshot::channel(); + + macro_rules! 
stream_respond { + ($stream_hnd:expr, async move |$stream:ident| { $($body:tt)* }) => { + stream_respond($stream_hnd, resp_tx, move |$stream| { + Box::pin(async move { $($body)* }) + }) + .await + }; + } + + macro_rules! respond { + ($value:expr) => { + resp_tx.send(Ok($value)).unwrap() + }; + } + + macro_rules! ensure_version { + ($min_version:expr, $what:expr) => { + if session.version < $min_version { + bail!(ProtocolError::NotSupported { + what: $what, + min_version: $min_version, + }) + } + }; + } + + macro_rules! get_stream_mut { + ($stream_id:expr) => { + match session.streams.get_mut(&$stream_id) { + Some(stream_hdn) => stream_hdn, + None => bail!(ProtocolError::StreamNotFound { + stream_id: $stream_id + }), + } + }; + } + + macro_rules! get_stream_db { + ($stream:expr, $stream_id:expr) => { + match $stream.db.as_ref() { + Some(db) => db, + None => bail!(ResponseError::StreamNotOpen { + stream_id: $stream_id + }), + } + }; + } + + match req { + proto::Request::OpenStream(req) => { + let stream_id = req.stream_id; + if session.streams.contains_key(&stream_id) { + bail!(ProtocolError::StreamExists { stream_id }) + } + + let mut stream_hnd = stream_spawn(join_set, Stream { db: None }); + let db_factory = server.db_factory.clone(); + + stream_respond!(&mut stream_hnd, async move |stream| { + let db = db_factory + .create() + .await + .context("Could not create a database connection")?; + stream.db = Some(db); + Ok(proto::Response::OpenStream(proto::OpenStreamResp {})) + }); + + session.streams.insert(stream_id, stream_hnd); + } + proto::Request::CloseStream(req) => { + let stream_id = req.stream_id; + let Some(mut stream_hnd) = session.streams.remove(&stream_id) else { + bail!(ProtocolError::StreamNotFound { stream_id }) + }; + + stream_respond!(&mut stream_hnd, async move |_stream| { + Ok(proto::Response::CloseStream(proto::CloseStreamResp {})) + }); + } + proto::Request::Execute(req) => { + let stream_id = req.stream_id; + let stream_hnd = 
get_stream_mut!(stream_id); + + let query = stmt::proto_stmt_to_query(&req.stmt, &session.sqls, session.version) + .map_err(catch_stmt_error)?; + let auth = session.authenticated; + + stream_respond!(stream_hnd, async move |stream| { + let db = get_stream_db!(stream, stream_id); + let result = stmt::execute_stmt(db, auth, query) + .await + .map_err(catch_stmt_error)?; + Ok(proto::Response::Execute(proto::ExecuteResp { result })) + }); + } + proto::Request::Batch(req) => { + let stream_id = req.stream_id; + let stream_hnd = get_stream_mut!(stream_id); + + let pgm = batch::proto_batch_to_program(&req.batch, &session.sqls, session.version) + .map_err(catch_stmt_error)?; + let auth = session.authenticated; + + stream_respond!(stream_hnd, async move |stream| { + let db = get_stream_db!(stream, stream_id); + let result = batch::execute_batch(db, auth, pgm).await?; + Ok(proto::Response::Batch(proto::BatchResp { result })) + }); + } + proto::Request::Sequence(req) => { + ensure_version!(Version::Hrana2, "The `sequence` request"); + let stream_id = req.stream_id; + let stream_hnd = get_stream_mut!(stream_id); + + let sql = stmt::proto_sql_to_sql( + req.sql.as_deref(), + req.sql_id, + &session.sqls, + session.version, + )?; + let pgm = batch::proto_sequence_to_program(sql).map_err(catch_stmt_error)?; + let auth = session.authenticated; + + stream_respond!(stream_hnd, async move |stream| { + let db = get_stream_db!(stream, stream_id); + batch::execute_sequence(db, auth, pgm) + .await + .map_err(catch_stmt_error)?; + Ok(proto::Response::Sequence(proto::SequenceResp {})) + }); + } + proto::Request::Describe(req) => { + ensure_version!(Version::Hrana2, "The `describe` request"); + let stream_id = req.stream_id; + let stream_hnd = get_stream_mut!(stream_id); + + let sql = stmt::proto_sql_to_sql( + req.sql.as_deref(), + req.sql_id, + &session.sqls, + session.version, + )? 
+ .into(); + let auth = session.authenticated; + + stream_respond!(stream_hnd, async move |stream| { + let db = get_stream_db!(stream, stream_id); + let result = stmt::describe_stmt(db, auth, sql) + .await + .map_err(catch_stmt_error)?; + Ok(proto::Response::Describe(proto::DescribeResp { result })) + }); + } + proto::Request::StoreSql(req) => { + ensure_version!(Version::Hrana2, "The `store_sql` request"); + let sql_id = req.sql_id; + if session.sqls.contains_key(&sql_id) { + bail!(ProtocolError::SqlExists { sql_id }) + } else if session.sqls.len() >= MAX_SQL_COUNT { + bail!(ResponseError::SqlTooMany { + count: session.sqls.len() + }) + } + + session.sqls.insert(sql_id, req.sql); + respond!(proto::Response::StoreSql(proto::StoreSqlResp {})); + } + proto::Request::CloseSql(req) => { + ensure_version!(Version::Hrana2, "The `close_sql` request"); + session.sqls.remove(&req.sql_id); + respond!(proto::Response::CloseSql(proto::CloseSqlResp {})); + } + } + Ok(resp_rx) +} + +const MAX_SQL_COUNT: usize = 150; + +fn stream_spawn( + join_set: &mut tokio::task::JoinSet<()>, + stream: Stream, +) -> StreamHandle { + let (job_tx, mut job_rx) = mpsc::channel::>(8); + join_set.spawn(async move { + let mut stream = stream; + while let Some(job) = job_rx.recv().await { + let res = (job.f)(&mut stream).await; + let _: Result<_, _> = job.resp_tx.send(res); + } + }); + StreamHandle { job_tx } +} + +async fn stream_respond( + stream_hnd: &mut StreamHandle, + resp_tx: oneshot::Sender>, + f: F, +) where + for<'s> F: FnOnce(&'s mut Stream) -> BoxFuture<'s, Result>, + F: Send + 'static, +{ + let job = StreamJob { + f: Box::new(f), + resp_tx, + }; + let _: Result<_, _> = stream_hnd.job_tx.send(job).await; +} + +fn catch_stmt_error(err: anyhow::Error) -> anyhow::Error { + match err.downcast::() { + Ok(stmt_err) => anyhow!(ResponseError::Stmt(stmt_err)), + Err(err) => err, + } +} + +impl ResponseError { + pub fn code(&self) -> &'static str { + match self { + Self::Auth { source } => 
source.code(), + Self::SqlTooMany { .. } => "SQL_STORE_TOO_MANY", + Self::StreamNotOpen { .. } => "STREAM_NOT_OPEN", + Self::Stmt(err) => err.code(), + } + } +} diff --git a/libsqlx-server/src/http/admin.rs b/libsqlx-server/src/http/admin.rs index 80e787d6..6b23ef58 100644 --- a/libsqlx-server/src/http/admin.rs +++ b/libsqlx-server/src/http/admin.rs @@ -1,4 +1,4 @@ -use std::{path::PathBuf, sync::Arc}; +use std::sync::Arc; use axum::{extract::State, routing::post, Json, Router}; use color_eyre::eyre::Result; diff --git a/libsqlx-server/src/http/user.rs b/libsqlx-server/src/http/user.rs deleted file mode 100644 index 040f5a66..00000000 --- a/libsqlx-server/src/http/user.rs +++ /dev/null @@ -1,101 +0,0 @@ -use std::sync::Arc; - -use axum::{async_trait, extract::FromRequestParts, response::IntoResponse, routing::get, Router, Json}; -use color_eyre::Result; -use hyper::{http::request::Parts, server::accept::Accept, StatusCode}; -use serde::Serialize; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - sync::mpsc, -}; - -use crate::{allocation::AllocationMessage, manager::Manager}; - -pub struct UserApiConfig { - pub manager: Arc, -} - -struct UserApiState { - manager: Arc, -} - -pub async fn run_user_api(config: UserApiConfig, listener: I) -> Result<()> -where - I: Accept, - I::Conn: AsyncRead + AsyncWrite + Send + Unpin + 'static, -{ - let state = UserApiState { manager: config.manager }; - - let app = Router::new() - .route("/", get(test_database)) - .with_state(Arc::new(state)); - - axum::Server::builder(listener) - .serve(app.into_make_service()) - .await?; - - Ok(()) -} - -struct Database { - sender: mpsc::Sender, -} - -#[derive(Debug, thiserror::Error)] -enum UserApiError { - #[error("missing host header")] - MissingHost, - #[error("invalid host header format")] - InvalidHost, - #[error("Database `{0}` doesn't exist")] - UnknownDatabase(String), -} - -impl UserApiError { - fn http_status(&self) -> StatusCode { - match self { - UserApiError::MissingHost - | 
UserApiError::InvalidHost - | UserApiError::UnknownDatabase(_) => StatusCode::BAD_REQUEST, - } - } -} - -#[derive(Debug, Serialize)] -struct ApiError { - error: String, -} - -impl IntoResponse for UserApiError { - fn into_response(self) -> axum::response::Response { - let mut resp = Json(ApiError { - error: self.to_string() - }).into_response(); - *resp.status_mut() = self.http_status(); - - resp - } -} - -#[async_trait] -impl FromRequestParts> for Database { - type Rejection = UserApiError; - - async fn from_request_parts( - parts: &mut Parts, - state: &Arc, - ) -> Result { - let Some(host) = parts.headers.get("host") else { return Err(UserApiError::MissingHost) }; - let Ok(host_str) = std::str::from_utf8(host.as_bytes()) else {return Err(UserApiError::MissingHost)}; - let db_id = parse_host(host_str)?; - let Some(sender) = state.manager.alloc(db_id).await else { return Err(UserApiError::UnknownDatabase(db_id.to_owned())) }; - - Ok(Database { sender }) - } -} - -fn parse_host(host: &str) -> Result<&str, UserApiError> { - let mut split = host.split("."); - let Some(db_id) = split.next() else { return Err(UserApiError::InvalidHost) }; - Ok(db_id) -} diff --git a/libsqlx-server/src/http/user/error.rs b/libsqlx-server/src/http/user/error.rs new file mode 100644 index 00000000..9aab9a71 --- /dev/null +++ b/libsqlx-server/src/http/user/error.rs @@ -0,0 +1,41 @@ +use axum::response::IntoResponse; +use axum::Json; +use hyper::StatusCode; +use serde::Serialize; + +#[derive(Debug, thiserror::Error)] +pub enum UserApiError { + #[error("missing host header")] + MissingHost, + #[error("invalid host header format")] + InvalidHost, + #[error("Database `{0}` doesn't exist")] + UnknownDatabase(String), +} + +impl UserApiError { + fn http_status(&self) -> StatusCode { + match self { + UserApiError::MissingHost + | UserApiError::InvalidHost + | UserApiError::UnknownDatabase(_) => StatusCode::BAD_REQUEST, + } + } +} + +#[derive(Debug, Serialize)] +pub struct ApiError { + error: 
String, +} + +impl IntoResponse for UserApiError { + fn into_response(self) -> axum::response::Response { + let mut resp = Json(ApiError { + error: self.to_string(), + }) + .into_response(); + *resp.status_mut() = self.http_status(); + + resp + } +} diff --git a/libsqlx-server/src/http/user/extractors.rs b/libsqlx-server/src/http/user/extractors.rs new file mode 100644 index 00000000..2b3f5a14 --- /dev/null +++ b/libsqlx-server/src/http/user/extractors.rs @@ -0,0 +1,32 @@ +use std::sync::Arc; + +use axum::async_trait; +use axum::extract::FromRequestParts; +use hyper::http::request::Parts; + +use crate::database::Database; + +use super::{error::UserApiError, UserApiState}; + +#[async_trait] +impl FromRequestParts> for Database { + type Rejection = UserApiError; + + async fn from_request_parts( + parts: &mut Parts, + state: &Arc, + ) -> Result { + let Some(host) = parts.headers.get("host") else { return Err(UserApiError::MissingHost) }; + let Ok(host_str) = std::str::from_utf8(host.as_bytes()) else {return Err(UserApiError::MissingHost)}; + let db_id = parse_host(host_str)?; + let Some(sender) = state.manager.alloc(db_id).await else { return Err(UserApiError::UnknownDatabase(db_id.to_owned())) }; + + Ok(Database { sender }) + } +} + +fn parse_host(host: &str) -> Result<&str, UserApiError> { + let mut split = host.split("."); + let Some(db_id) = split.next() else { return Err(UserApiError::InvalidHost) }; + Ok(db_id) +} diff --git a/libsqlx-server/src/http/user/mod.rs b/libsqlx-server/src/http/user/mod.rs new file mode 100644 index 00000000..4c314a39 --- /dev/null +++ b/libsqlx-server/src/http/user/mod.rs @@ -0,0 +1,48 @@ +use std::sync::Arc; + +use axum::routing::post; +use axum::{Json, Router}; +use color_eyre::Result; +use hyper::server::accept::Accept; +use tokio::io::{AsyncRead, AsyncWrite}; + +use crate::database::Database; +use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody}; +use crate::manager::Manager; + +mod error; +mod extractors; + 
+pub struct UserApiConfig { + pub manager: Arc, +} + +struct UserApiState { + manager: Arc, +} + +pub async fn run_user_api(config: UserApiConfig, listener: I) -> Result<()> +where + I: Accept, + I::Conn: AsyncRead + AsyncWrite + Send + Unpin + 'static, +{ + let state = UserApiState { + manager: config.manager, + }; + + let app = Router::new() + .route("/v2/pipeline", post(handle_hrana_pipeline)) + .with_state(Arc::new(state)); + + axum::Server::builder(listener) + .serve(app.into_make_service()) + .await?; + + Ok(()) +} + +async fn handle_hrana_pipeline(db: Database, Json(req): Json) -> Json { + let resp = db.hrana_pipeline(req).await; + dbg!(); + Json(resp.unwrap()) +} diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs index 2e9411cf..a8829093 100644 --- a/libsqlx-server/src/main.rs +++ b/libsqlx-server/src/main.rs @@ -13,7 +13,7 @@ use tracing::metadata::LevelFilter; use tracing_subscriber::prelude::*; mod allocation; -mod databases; +mod database; mod hrana; mod http; mod manager; @@ -28,7 +28,9 @@ async fn main() -> Result<()> { let store = Arc::new(Store::new(&db_path)); let admin_api_listener = tokio::net::TcpListener::bind("0.0.0.0:3456").await?; join_set.spawn(run_admin_api( - AdminApiConfig { meta_store: store.clone() }, + AdminApiConfig { + meta_store: store.clone(), + }, AddrIncoming::from_listener(admin_api_listener)?, )); diff --git a/libsqlx-server/src/manager.rs b/libsqlx-server/src/manager.rs index 81ac3b72..48315e0a 100644 --- a/libsqlx-server/src/manager.rs +++ b/libsqlx-server/src/manager.rs @@ -1,4 +1,3 @@ -use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; @@ -7,6 +6,7 @@ use tokio::sync::mpsc; use tokio::task::JoinSet; use crate::allocation::{Allocation, AllocationMessage, Database}; +use crate::hrana; use crate::meta::Store; pub struct Manager { @@ -39,10 +39,10 @@ impl Manager { let alloc = Allocation { inbox, database: Database::from_config(&config, path), - connections: HashMap::new(), 
connections_futs: JoinSet::new(), next_conn_id: 0, max_concurrent_connections: config.max_conccurent_connection, + hrana_server: Arc::new(hrana::http::Server::new(None)), // TODO: handle self URL? }; tokio::spawn(alloc.run()); diff --git a/libsqlx-server/src/meta.rs b/libsqlx-server/src/meta.rs index 4eade1b0..06e37a76 100644 --- a/libsqlx-server/src/meta.rs +++ b/libsqlx-server/src/meta.rs @@ -32,7 +32,7 @@ impl Store { }); } - pub async fn deallocate(&self, alloc_id: Uuid) { + pub async fn deallocate(&self, _alloc_id: Uuid) { todo!() } @@ -48,7 +48,7 @@ impl Store { tokio::task::block_in_place(|| { let mut out = Vec::new(); for kv in self.meta_store.iter() { - let (k, v) = kv.unwrap(); + let (_k, v) = kv.unwrap(); let alloc = bincode::deserialize(&v).unwrap(); out.push(alloc); } diff --git a/libsqlx/src/analysis.rs b/libsqlx/src/analysis.rs index 0c7f7d43..fccbf3dc 100644 --- a/libsqlx/src/analysis.rs +++ b/libsqlx/src/analysis.rs @@ -1,4 +1,3 @@ -use anyhow::Result; use fallible_iterator::FallibleIterator; use sqlite3_parser::ast::{Cmd, PragmaBody, QualifiedName, Stmt}; use sqlite3_parser::lexer::sql::{Parser, ParserError}; @@ -201,15 +200,15 @@ impl Statement { } } - pub fn parse(s: &str) -> impl Iterator> + '_ { + pub fn parse(s: &str) -> impl Iterator> + '_ { fn parse_inner( original: &str, stmt_count: u64, has_more_stmts: bool, c: Cmd, - ) -> Result { + ) -> crate::Result { let kind = - StmtKind::kind(&c).ok_or_else(|| anyhow::anyhow!("unsupported statement"))?; + StmtKind::kind(&c).ok_or_else(|| crate::error::Error::UnsupportedStatement)?; if stmt_count == 1 && !has_more_stmts { // XXX: Temporary workaround for integration with Atlas @@ -259,9 +258,7 @@ impl Statement { found: Some(found), }, Some((line, col)), - )) => Some(Err(anyhow::anyhow!( - "syntax error around L{line}:{col}: `{found}`" - ))), + )) => Some(Err(crate::error::Error::SyntaxError { line, col, found})), Err(e) => Some(Err(e.into())), } }) diff --git a/libsqlx/src/connection.rs 
b/libsqlx/src/connection.rs index a5eb7e60..38d31964 100644 --- a/libsqlx/src/connection.rs +++ b/libsqlx/src/connection.rs @@ -2,8 +2,7 @@ use rusqlite::types::Value; use crate::program::{Program, Step}; use crate::query::Query; -use crate::result_builder::ResultBuilder; -use crate::QueryBuilderConfig; +use crate::result_builder::{ResultBuilder, QueryBuilderConfig, QueryResultBuilderError}; #[derive(Debug, Clone)] pub struct DescribeResponse { @@ -48,7 +47,7 @@ pub trait Connection { fn init( &mut self, _config: &QueryBuilderConfig, - ) -> std::result::Result<(), crate::QueryResultBuilderError> { + ) -> std::result::Result<(), QueryResultBuilderError> { self.error = None; self.rows.clear(); self.current_row.clear(); @@ -59,12 +58,12 @@ pub trait Connection { fn add_row_value( &mut self, v: rusqlite::types::ValueRef, - ) -> Result<(), crate::QueryResultBuilderError> { + ) -> Result<(), QueryResultBuilderError> { self.current_row.push(v.into()); Ok(()) } - fn finish_row(&mut self) -> Result<(), crate::QueryResultBuilderError> { + fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { let row = std::mem::take(&mut self.current_row); self.rows.push(row); @@ -74,7 +73,7 @@ pub trait Connection { fn step_error( &mut self, error: crate::error::Error, - ) -> Result<(), crate::QueryResultBuilderError> { + ) -> Result<(), QueryResultBuilderError> { self.error.replace(error); Ok(()) } diff --git a/libsqlx/src/database/libsql/connection.rs b/libsqlx/src/database/libsql/connection.rs index 0a2cb6b0..554a22da 100644 --- a/libsqlx/src/database/libsql/connection.rs +++ b/libsqlx/src/database/libsql/connection.rs @@ -177,7 +177,7 @@ impl LibsqlConnection { query .params .bind(&mut stmt) - .map_err(Error::LibSqlInvalidQueryParams)?; + .map_err(|e|Error::LibSqlInvalidQueryParams(e.to_string()))?; let mut qresult = stmt.raw_query(); builder.begin_rows()?; diff --git a/libsqlx/src/database/libsql/mod.rs b/libsqlx/src/database/libsql/mod.rs index 41de3569..2844a204 100644 
--- a/libsqlx/src/database/libsql/mod.rs +++ b/libsqlx/src/database/libsql/mod.rs @@ -174,6 +174,7 @@ impl Database for LibsqlDatabase { type Connection = LibsqlConnection<::Context>; fn connect(&self) -> Result { + dbg!(); Ok( LibsqlConnection::<::Context>::new( &self.db_path, diff --git a/libsqlx/src/error.rs b/libsqlx/src/error.rs index 6e35e217..47fde1ae 100644 --- a/libsqlx/src/error.rs +++ b/libsqlx/src/error.rs @@ -1,10 +1,12 @@ use crate::result_builder::QueryResultBuilderError; +pub use rusqlite::Error as RusqliteError; +pub use rusqlite::ffi::ErrorCode; #[allow(clippy::enum_variant_names)] #[derive(Debug, thiserror::Error)] pub enum Error { #[error("LibSQL failed to bind provided query parameters: `{0}`")] - LibSqlInvalidQueryParams(anyhow::Error), + LibSqlInvalidQueryParams(String), #[error("Transaction timed-out")] LibSqlTxTimeout, #[error("Server can't handle additional transactions")] @@ -33,6 +35,14 @@ pub enum Error { Blocked(Option), #[error("invalid replication log header")] InvalidLogHeader, + #[error("unsupported statement")] + UnsupportedStatement, + #[error("Syntax error at {line}:{col}: {found}")] + SyntaxError { + line: u64, col: usize, found: String + }, + #[error(transparent)] + LexerError(#[from] sqlite3_parser::lexer::sql::Error) } impl From for Error { diff --git a/libsqlx/src/lib.rs b/libsqlx/src/lib.rs index a6e3c3a2..f9ef106d 100644 --- a/libsqlx/src/lib.rs +++ b/libsqlx/src/lib.rs @@ -4,8 +4,8 @@ pub mod query; mod connection; mod database; -mod program; -mod result_builder; +pub mod program; +pub mod result_builder; mod seal; pub type Result = std::result::Result; @@ -14,7 +14,6 @@ pub use connection::Connection; pub use database::libsql; pub use database::proxy; pub use database::Database; -pub use program::Program; -pub use result_builder::{ - Column, QueryBuilderConfig, QueryResultBuilderError, ResultBuilder, ResultBuilderExt, -}; +pub use database::libsql::replication_log::FrameNo; + +pub use rusqlite; diff --git 
a/libsqlx/src/program.rs b/libsqlx/src/program.rs index 3eb2f551..131dd125 100644 --- a/libsqlx/src/program.rs +++ b/libsqlx/src/program.rs @@ -4,13 +4,13 @@ use crate::query::Query; #[derive(Debug, Clone)] pub struct Program { - pub steps: Arc>, + pub steps: Arc<[Step]>, } impl Program { pub fn new(steps: Vec) -> Self { Self { - steps: Arc::new(steps), + steps: steps.into(), } } @@ -19,7 +19,20 @@ impl Program { } pub fn steps(&self) -> &[Step] { - self.steps.as_slice() + &self.steps + } + + /// transforms a collection of queries into a batch program. The execution of each query + /// depends on the success of the previous one. + pub fn from_queries(qs: impl IntoIterator) -> Self { + let steps = qs.into_iter().enumerate().map(|(idx, query)| Step { + cond: (idx > 0).then(|| Cond::Ok { step: idx - 1 }), + query, + }) + .collect(); + + Self { steps } + } #[cfg(test)] diff --git a/libsqlx/src/result_builder.rs b/libsqlx/src/result_builder.rs index ae299b1e..be5e27a7 100644 --- a/libsqlx/src/result_builder.rs +++ b/libsqlx/src/result_builder.rs @@ -2,7 +2,7 @@ use std::fmt; use std::io::{self, ErrorKind}; use bytesize::ByteSize; -use rusqlite::types::ValueRef; +pub use rusqlite::types::ValueRef; use crate::database::FrameNo; @@ -170,6 +170,12 @@ pub struct StepResultsBuilder { is_skipped: bool, } +impl StepResultsBuilder { + pub fn into_ret(self) -> Vec { + self.step_results + } +} + impl ResultBuilder for StepResultsBuilder { fn init(&mut self, _config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { *self = Default::default();