From 471f94c35099c5aecbf5cfbd64c81357de056ba6 Mon Sep 17 00:00:00 2001 From: totodore Date: Tue, 6 May 2025 00:23:09 +0200 Subject: [PATCH 01/12] wip --- Cargo.lock | 446 +++++++++++ crates/socketioxide-postgres/Cargo.toml | 58 ++ crates/socketioxide-postgres/README.md | 0 .../socketioxide-postgres/src/drivers/mod.rs | 15 + .../src/drivers/postgres.rs | 34 + .../socketioxide-postgres/src/drivers/sqlx.rs | 43 ++ crates/socketioxide-postgres/src/lib.rs | 694 ++++++++++++++++++ 7 files changed, 1290 insertions(+) create mode 100644 crates/socketioxide-postgres/Cargo.toml create mode 100644 crates/socketioxide-postgres/README.md create mode 100644 crates/socketioxide-postgres/src/drivers/mod.rs create mode 100644 crates/socketioxide-postgres/src/drivers/postgres.rs create mode 100644 crates/socketioxide-postgres/src/drivers/sqlx.rs create mode 100644 crates/socketioxide-postgres/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index fcbc8e71..20d29952 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -54,6 +54,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + [[package]] name = "android-tzdata" version = "0.1.1" @@ -107,6 +113,15 @@ dependencies = [ "syn 2.0.100", ] +[[package]] +name = "atoi" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" +dependencies = [ + "num-traits", +] + [[package]] name = "autocfg" version = "1.4.0" @@ -387,6 +402,15 @@ dependencies = [ "tokio-util", ] +[[package]] +name = "concurrent-queue" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "const-random" version = "0.1.18" @@ 
-434,6 +458,21 @@ dependencies = [ "libc", ] +[[package]] +name = "crc" +version = "3.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69e6e4d7b33a94f0991c26729976b10ebde1d34c3ee82408fb536164fa10d636" +dependencies = [ + "crc-catalog", +] + +[[package]] +name = "crc-catalog" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" + [[package]] name = "crc16" version = "0.4.0" @@ -494,6 +533,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-queue" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.21" @@ -635,11 +683,20 @@ dependencies = [ "syn 2.0.100", ] +[[package]] +name = "dotenvy" +version = "0.15.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" + [[package]] name = "either" version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +dependencies = [ + "serde", +] [[package]] name = "engineioxide" @@ -705,6 +762,50 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "errno" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "976dd42dc7e85965fe702eb8164f21f450704bdde31faefd6471dba214cb594e" +dependencies = [ + "libc", + "windows-sys 0.59.0", +] + +[[package]] +name = "etcetera" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"136d1b5283a1ab77bd9257427ffd09d8667ced0570b6f938942bc7568ed5b943" +dependencies = [ + "cfg-if", + "home", + "windows-sys 0.48.0", +] + +[[package]] +name = "event-listener" +version = "5.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3492acde4c3fc54c845eaab3eed8bd00c7a7d881f78bfc801e43a93dec1331ae" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + +[[package]] +name = "fallible-iterator" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + [[package]] name = "float-cmp" version = "0.10.0" @@ -720,6 +821,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -814,6 +921,17 @@ dependencies = [ "futures-util", ] +[[package]] +name = "futures-intrusive" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f" +dependencies = [ + "futures-core", + "lock_api", + "parking_lot", +] + [[package]] name = "futures-io" version = "0.3.31" @@ -938,6 +1056,20 @@ name = "hashbrown" version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash", +] + +[[package]] +name = "hashlink" 
+version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1" +dependencies = [ + "hashbrown 0.15.2", +] [[package]] name = "heaptrack" @@ -952,6 +1084,12 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + [[package]] name = "hermit-abi" version = "0.5.0" @@ -964,6 +1102,15 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hkdf" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5f8eb2ad728638ea2c7d47a21db23b7b58a72ed6a38256b8a1849f15fbbdf7" +dependencies = [ + "hmac", +] + [[package]] name = "hmac" version = "0.12.1" @@ -973,6 +1120,15 @@ dependencies = [ "digest", ] +[[package]] +name = "home" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf" +dependencies = [ + "windows-sys 0.59.0", +] + [[package]] name = "http" version = "1.3.1" @@ -1293,6 +1449,12 @@ version = "0.2.172" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" +[[package]] +name = "linux-raw-sys" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" + [[package]] name = "litemap" version = "0.7.5" @@ -1581,6 +1743,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "parking" +version = 
"2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" + [[package]] name = "parking_lot" version = "0.12.3" @@ -1625,6 +1793,24 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "phf" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078" +dependencies = [ + "phf_shared", +] + +[[package]] +name = "phf_shared" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" +dependencies = [ + "siphasher", +] + [[package]] name = "pin-project-lite" version = "0.2.16" @@ -1637,6 +1823,35 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "postgres-protocol" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76ff0abab4a9b844b93ef7b81f1efc0a366062aaef2cd702c76256b5dc075c54" +dependencies = [ + "base64 0.22.1", + "byteorder", + "bytes", + "fallible-iterator", + "hmac", + "md-5", + "memchr", + "rand 0.9.1", + "sha2", + "stringprep", +] + +[[package]] +name = "postgres-types" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613283563cd90e1dfc3518d548caee47e0e725455ed619881f5cf21f36de4b48" +dependencies = [ + "bytes", + "fallible-iterator", + "postgres-protocol", +] + [[package]] name = "powerfmt" version = "0.2.0" @@ -1927,6 +2142,19 @@ dependencies = [ "semver", ] +[[package]] +name = "rustix" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266" +dependencies = [ + "bitflags 2.9.0", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.59.0", +] + [[package]] name = "rustls" version = "0.21.12" @@ -2158,6 +2386,12 @@ dependencies = [ "libc", ] +[[package]] +name = "siphasher" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" + [[package]] name = "slab" version = "0.4.9" @@ -2296,6 +2530,27 @@ dependencies = [ "socketioxide-core", ] +[[package]] +name = "socketioxide-postgres" +version = "0.1.0" +dependencies = [ + "bytes", + "futures-core", + "futures-util", + "pin-project-lite", + "rmp-serde", + "serde", + "smallvec", + "socketioxide", + "socketioxide-core", + "sqlx", + "thiserror 2.0.12", + "tokio", + "tokio-postgres", + "tracing", + "tracing-subscriber", +] + [[package]] name = "socketioxide-redis" version = "0.2.2" @@ -2318,6 +2573,122 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "sqlx" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3c3a85280daca669cfd3bcb68a337882a8bc57ec882f72c5d13a430613a738e" +dependencies = [ + "sqlx-core", + "sqlx-macros", + "sqlx-postgres", +] + +[[package]] +name = "sqlx-core" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f743f2a3cea30a58cd479013f75550e879009e3a02f616f18ca699335aa248c3" +dependencies = [ + "base64 0.22.1", + "bytes", + "crc", + "crossbeam-queue", + "either", + "event-listener", + "futures-core", + "futures-intrusive", + "futures-io", + "futures-util", + "hashbrown 0.15.2", + "hashlink", + "indexmap 2.9.0", + "log", + "memchr", + "once_cell", + "percent-encoding", + "serde", + "serde_json", + "sha2", + "smallvec", + "thiserror 2.0.12", + "tracing", + "url", +] + +[[package]] +name = "sqlx-macros" +version = "0.8.5" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f4200e0fde19834956d4252347c12a083bdcb237d7a1a1446bffd8768417dce" +dependencies = [ + "proc-macro2", + "quote", + "sqlx-core", + "sqlx-macros-core", + "syn 2.0.100", +] + +[[package]] +name = "sqlx-macros-core" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ceaa29cade31beca7129b6beeb05737f44f82dbe2a9806ecea5a7093d00b7" +dependencies = [ + "dotenvy", + "either", + "heck", + "hex", + "once_cell", + "proc-macro2", + "quote", + "serde", + "serde_json", + "sha2", + "sqlx-core", + "sqlx-postgres", + "syn 2.0.100", + "tempfile", + "url", +] + +[[package]] +name = "sqlx-postgres" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0bedbe1bbb5e2615ef347a5e9d8cd7680fb63e77d9dafc0f29be15e53f1ebe6" +dependencies = [ + "atoi", + "base64 0.22.1", + "bitflags 2.9.0", + "byteorder", + "crc", + "dotenvy", + "etcetera", + "futures-channel", + "futures-core", + "futures-util", + "hex", + "hkdf", + "hmac", + "home", + "itoa", + "log", + "md-5", + "memchr", + "once_cell", + "rand 0.8.5", + "serde", + "serde_json", + "sha2", + "smallvec", + "sqlx-core", + "stringprep", + "thiserror 2.0.12", + "tracing", + "whoami", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -2407,6 +2778,19 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "tempfile" +version = "3.19.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7437ac7763b9b123ccf33c338a5cc1bac6f69b45a136c19bdd8a65e3916435bf" +dependencies = [ + "fastrand", + "getrandom 0.3.2", + "once_cell", + "rustix", + "windows-sys 0.59.0", +] + [[package]] name = "thiserror" version = "1.0.69" @@ -2561,6 +2945,32 @@ dependencies = [ "syn 2.0.100", ] +[[package]] +name = "tokio-postgres" +version = "0.7.13" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c95d533c83082bb6490e0189acaa0bbeef9084e60471b696ca6988cd0541fb0" +dependencies = [ + "async-trait", + "byteorder", + "bytes", + "fallible-iterator", + "futures-channel", + "futures-util", + "log", + "parking_lot", + "percent-encoding", + "phf", + "pin-project-lite", + "postgres-protocol", + "postgres-types", + "rand 0.9.1", + "socket2", + "tokio", + "tokio-util", + "whoami", +] + [[package]] name = "tokio-rustls" version = "0.24.1" @@ -2855,6 +3265,12 @@ dependencies = [ "wit-bindgen-rt", ] +[[package]] +name = "wasite" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" + [[package]] name = "wasm-bindgen" version = "0.2.100" @@ -2913,12 +3329,33 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "web-sys" +version = "0.3.77" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "webpki-roots" version = "0.25.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" +[[package]] +name = "whoami" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6994d13118ab492c3c80c1f81928718159254c53c472bf9ce36f8dae4add02a7" +dependencies = [ + "redox_syscall", + "wasite", + "web-sys", +] + [[package]] name = "winapi" version = "0.3.9" @@ -3018,6 +3455,15 @@ dependencies = [ "windows-link", ] +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets 0.48.5", +] + [[package]] name = "windows-sys" version = "0.52.0" 
diff --git a/crates/socketioxide-postgres/Cargo.toml b/crates/socketioxide-postgres/Cargo.toml new file mode 100644 index 00000000..df86ade0 --- /dev/null +++ b/crates/socketioxide-postgres/Cargo.toml @@ -0,0 +1,58 @@ +[package] +name = "socketioxide-postgres" +description = "PostgreSQL adapter for socketioxide" +version = "0.1.0" +edition.workspace = true +rust-version.workspace = true +authors.workspace = true +repository.workspace = true +homepage.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true +readme = "README.md" + +[features] +sqlx = ["dep:sqlx"] +postgres = ["dep:tokio-postgres"] +default = ["postgres"] + +[dependencies] +socketioxide-core = { version = "0.17", path = "../socketioxide-core", features = [ + "remote-adapter", +] } +futures-core.workspace = true +futures-util.workspace = true +pin-project-lite.workspace = true +serde.workspace = true +smallvec = { workspace = true, features = ["serde"] } +tokio = { workspace = true, features = ["time", "rt", "sync"] } +rmp-serde.workspace = true +tracing.workspace = true +thiserror.workspace = true + +# PostgreSQL implementations +tokio-postgres = { version = "0.7", default-features = false, optional = true, features = [ + "runtime", +] } +sqlx = { version = "0.8", default-features = false, optional = true, features = [ + "postgres", +] } + +[dev-dependencies] +tokio = { workspace = true, features = [ + "macros", + "parking_lot", + "rt-multi-thread", +] } +socketioxide = { path = "../socketioxide", features = [ + "tracing", + "__test_harness", +] } +tracing-subscriber.workspace = true +bytes.workspace = true + +# docs.rs-specific configuration +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--cfg", "docsrs"] diff --git a/crates/socketioxide-postgres/README.md b/crates/socketioxide-postgres/README.md new file mode 100644 index 00000000..e69de29b diff --git a/crates/socketioxide-postgres/src/drivers/mod.rs 
b/crates/socketioxide-postgres/src/drivers/mod.rs new file mode 100644 index 00000000..976f9821 --- /dev/null +++ b/crates/socketioxide-postgres/src/drivers/mod.rs @@ -0,0 +1,15 @@ +mod postgres; +mod sqlx; + +pub type ChanItem = (String, String); + +/// The driver trait can be used to support different LISTEN/NOTIFY backends. +/// It must share handlers/connection between its clones. +pub trait Driver: Clone + Send + Sync + 'static { + type Error: std::error::Error + Send + 'static; + + fn init(&self, table: &str, channels: &[&str]) + -> impl Future>; + fn notify(&self, channel: &str, message: &str) + -> impl Future>; +} diff --git a/crates/socketioxide-postgres/src/drivers/postgres.rs b/crates/socketioxide-postgres/src/drivers/postgres.rs new file mode 100644 index 00000000..8e5b6408 --- /dev/null +++ b/crates/socketioxide-postgres/src/drivers/postgres.rs @@ -0,0 +1,34 @@ +use std::sync::Arc; + +use tokio_postgres::{Client, Connection}; + +use crate::PostgresAdapterConfig; + +use super::Driver; + +#[derive(Debug, Clone)] +pub struct PostgresDriver { + client: Arc, +} + +impl PostgresDriver { + pub fn new(client: Client, connection: Connection) -> Self { + PostgresDriver { + client: Arc::new(client), + } + } +} + +impl Driver for PostgresDriver { + type Error = tokio_postgres::Error; + async fn init(&self, table: &str, channels: &[&str]) -> Result<(), Self::Error> { + self.client + .execute("CREATE TABLE $1 IF NOT EXISTS", &[&table]) + .await?; + Ok(()) + } + + async fn notify(&self, channel: &str, msg: &str) -> Result<(), Self::Error> { + todo!() + } +} diff --git a/crates/socketioxide-postgres/src/drivers/sqlx.rs b/crates/socketioxide-postgres/src/drivers/sqlx.rs new file mode 100644 index 00000000..fa63e56b --- /dev/null +++ b/crates/socketioxide-postgres/src/drivers/sqlx.rs @@ -0,0 +1,43 @@ +use std::{collections::HashMap, sync::Arc}; + +use sqlx::{PgPool, postgres::PgListener}; +use tokio::sync::mpsc; + +use super::{ChanItem, Driver}; + +#[derive(Debug, 
Clone)] +pub struct SqlxDriver { + client: PgPool, +} +impl SqlxDriver { + pub fn new(client: PgPool) -> Self { + Self { client } + } + + async fn spawn_listener(&self, mut listener: PgListener, tx: mpsc::Sender) { + while let Ok(notif) = listener + .recv() + .await + .inspect_err(|e| tracing::warn!(?e, "sqlx listener error")) + {} + } +} + +impl Driver for SqlxDriver { + type Error = sqlx::Error; + + async fn init(&self, table: &str, channels: &[&str]) -> Result<(), Self::Error> { + sqlx::query("CREATE TABLE $1 IF NOT EXISTS") + .bind(&table) + .execute(&self.client) + .await?; + let mut listener = PgListener::connect_with(&self.client).await?; + listener.listen_all(channels.iter().copied()).await?; + + Ok(()) + } + + async fn notify(&self, channel: &str, msg: &str) -> Result<(), Self::Error> { + todo!() + } +} diff --git a/crates/socketioxide-postgres/src/lib.rs b/crates/socketioxide-postgres/src/lib.rs new file mode 100644 index 00000000..a3441158 --- /dev/null +++ b/crates/socketioxide-postgres/src/lib.rs @@ -0,0 +1,694 @@ +#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![warn( + clippy::all, + clippy::todo, + clippy::empty_enum, + clippy::mem_forget, + clippy::unused_self, + clippy::filter_map_next, + clippy::needless_continue, + clippy::needless_borrow, + clippy::match_wildcard_for_single_variants, + clippy::if_let_mutex, + clippy::await_holding_lock, + clippy::match_on_vec_items, + clippy::imprecise_flops, + clippy::suboptimal_flops, + clippy::lossy_float_literal, + clippy::rest_pat_in_fully_bound_structs, + clippy::fn_params_excessive_bools, + clippy::exit, + clippy::inefficient_to_string, + clippy::linkedlist, + clippy::macro_use_imports, + clippy::option_option, + clippy::verbose_file_reads, + clippy::unnested_or_patterns, + rust_2018_idioms, + future_incompatible, + nonstandard_style, + missing_docs +)] +//! 
+ +use drivers::Driver; + +use futures_core::Stream; +use serde::{Serialize, de::DeserializeOwned}; +use socketioxide_core::{ + Sid, Uid, + adapter::{ + BroadcastOptions, CoreAdapter, CoreLocalAdapter, DefinedAdapter, RemoteSocketData, Room, + RoomParam, SocketEmitter, Spawnable, + errors::AdapterError, + remote_packet::{ + RequestIn, RequestOut, RequestTypeIn, RequestTypeOut, Response, ResponseType, + ResponseTypeId, + }, + }, + packet::Packet, +}; +use std::{ + borrow::Cow, + collections::HashMap, + fmt, future, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, + time::{Duration, Instant}, +}; +use tokio::sync::mpsc; + +mod drivers; + +/// The configuration of the [`MongoDbAdapter`]. +#[derive(Debug, Clone)] +pub struct PostgresAdapterConfig { + /// The heartbeat timeout duration. If a remote node does not respond within this duration, + /// it will be considered disconnected. Default is 60 seconds. + pub hb_timeout: Duration, + /// The heartbeat interval duration. The current node will broadcast a heartbeat to the + /// remote nodes at this interval. Default is 10 seconds. + pub hb_interval: Duration, + /// The request timeout. When expecting a response from remote nodes, if they do not respond within + /// this duration, the request will be considered failed. Default is 5 seconds. + pub request_timeout: Duration, + /// The channel size used to receive ack responses. Default is 255. + /// + /// If you have a lot of servers/sockets and that you may miss acknowledgement because they arrive faster + /// than you poll them with the returned stream, you might want to increase this value. + pub ack_response_buffer: usize, + /// The table name used to store socket.io attachments. Default is "socket_io_attachments". + pub table_name: Cow<'static, str>, + /// The prefix used for the channels. Default is "socket.io". + pub prefix: Cow<'static, str>, + /// The treshold to the payload size in bytes. 
It should match the configured value on your PostgreSQL instance: + /// + pub payload_treshold: usize, + /// The duration between cleanup queries on the + pub cleanup_intervals: Duration, +} + +/// Represent any error that might happen when using this adapter. +#[derive(thiserror::Error)] +pub enum Error { + /// Mongo driver error + #[error("driver error: {0}")] + Driver(D::Error), + /// Packet encoding error + #[error("packet encoding error: {0}")] + Encode(#[from] rmp_serde::encode::Error), + /// Packet decoding error + #[error("packet decoding error: {0}")] + Decode(#[from] rmp_serde::decode::Error), +} + +impl Error { + fn from_driver(err: R::Error) -> Self { + Self::Driver(err) + } +} +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Driver(err) => write!(f, "Driver error: {:?}", err), + Self::Decode(err) => write!(f, "Decode error: {:?}", err), + Self::Encode(err) => write!(f, "Encode error: {:?}", err), + } + } +} + +impl From> for AdapterError { + fn from(err: Error) -> Self { + AdapterError::from(Box::new(err) as Box) + } +} + +pub(crate) type ResponseHandlers = HashMap>; + +/// The postgres adapter implementation. +/// It is generic over the [`Driver`] used to communicate with the postgres server. +/// And over the [`SocketEmitter`] used to communicate with the local server. This allows to +/// avoid cyclic dependencies between the adapter, `socketioxide-core` and `socketioxide` crates. +pub struct CustomPostgresAdapter { + /// The driver used by the adapter. This is used to communicate with the postgres server. + /// All the postgres adapter instances share the same driver. + driver: D, + /// The configuration of the adapter. + config: PostgresAdapterConfig, + /// A unique identifier for the adapter to identify itself in the postgres server. + uid: Uid, + /// The local adapter, used to manage local rooms and socket stores. 
+ local: CoreLocalAdapter, + /// A map of nodes liveness, with the last time remote nodes were seen alive. + nodes_liveness: Mutex>, + /// A map of response handlers used to await for responses from the remote servers. + responses: Arc>, +} + +impl DefinedAdapter for CustomPostgresAdapter {} +impl CoreAdapter for CustomPostgresAdapter { + type Error = Error; + type State = PostgresAdapterCtr; + type AckStream = AckStream; + type InitRes = InitRes; + + fn new(state: &Self::State, local: CoreLocalAdapter) -> Self { + let uid = local.server_id(); + Self { + local, + uid, + driver: state.driver.clone(), + config: state.config.clone(), + nodes_liveness: Mutex::new(Vec::new()), + responses: Arc::new(Mutex::new(HashMap::new())), + } + } + + fn init(self: Arc, on_success: impl FnOnce() + Send + 'static) -> Self::InitRes { + let fut = async move { + let stream = self.driver.watch(self.uid, self.local.path()).await?; + tokio::spawn(self.clone().handle_ev_stream(stream)); + tokio::spawn(self.clone().heartbeat_job()); + + // Send initial heartbeat when starting. + self.emit_init_heartbeat().await.map_err(|e| match e { + Error::Driver(e) => e, + Error::Encode(_) | Error::Decode(_) => unreachable!(), + })?; + + on_success(); + Ok(()) + }; + InitRes(Box::pin(fut)) + } + + async fn close(&self) -> Result<(), Self::Error> { + Ok(()) + } + + /// Get the number of servers by iterating over the node liveness heartbeats. + async fn server_count(&self) -> Result { + let treshold = std::time::Instant::now() - self.config.hb_timeout; + let mut nodes_liveness = self.nodes_liveness.lock().unwrap(); + nodes_liveness.retain(|(_, v)| v > &treshold); + Ok((nodes_liveness.len() + 1) as u16) + } + + /// Broadcast a packet to all the servers to send them through their sockets. 
+ async fn broadcast( + &self, + packet: Packet, + opts: BroadcastOptions, + ) -> Result<(), BroadcastError> { + if !opts.is_local(self.uid) { + let req = RequestOut::new(self.uid, RequestTypeOut::Broadcast(&packet), &opts); + self.send_req(req, None).await.map_err(AdapterError::from)?; + } + + self.local.broadcast(packet, opts)?; + Ok(()) + } + + /// Broadcast a packet to all the servers to send them through their sockets. + /// + /// Returns a Stream that is a combination of the local ack stream and a remote ack stream. + /// Here is a specific protocol in order to know how many message the server expect to close + /// the stream at the right time: + /// * Get the number `n` of remote servers. + /// * Send the broadcast request. + /// * Expect `n` `BroadcastAckCount` response in the stream to know the number `m` of expected ack responses. + /// * Expect `sum(m)` broadcast counts sent by the servers. + /// + /// Example with 3 remote servers (n = 3): + /// ```text + /// +---+ +---+ +---+ + /// | A | | B | | C | + /// +---+ +---+ +---+ + /// | | | + /// |---BroadcastWithAck--->| | + /// |---BroadcastWithAck--------------------------->| + /// | | | + /// |<-BroadcastAckCount(2)-| (n = 2; m = 2) | + /// |<-BroadcastAckCount(2)-------(n = 2; m = 4)----| + /// | | | + /// |<----------------Ack---------------------------| + /// |<----------------Ack---| | + /// | | | + /// |<----------------Ack---------------------------| + /// |<----------------Ack---| | + async fn broadcast_with_ack( + &self, + packet: Packet, + opts: BroadcastOptions, + timeout: Option, + ) -> Result { + if opts.is_local(self.uid) { + tracing::debug!(?opts, "broadcast with ack is local"); + let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout); + let stream = AckStream::new_local(local); + return Ok(stream); + } + let req = RequestOut::new(self.uid, RequestTypeOut::BroadcastWithAck(&packet), &opts); + let req_id = req.id; + + let remote_serv_cnt = 
self.server_count().await?.saturating_sub(1); + tracing::trace!(?remote_serv_cnt, "expecting acks from remote servers"); + + let (tx, rx) = mpsc::channel(self.config.ack_response_buffer + remote_serv_cnt as usize); + self.responses.lock().unwrap().insert(req_id, tx); + self.send_req(req, None).await?; + let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout); + + Ok(AckStream::new( + local, + rx, + self.config.request_timeout, + remote_serv_cnt, + req_id, + self.responses.clone(), + )) + } + + async fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), BroadcastError> { + if !opts.is_local(self.uid) { + let req = RequestOut::new(self.uid, RequestTypeOut::DisconnectSockets, &opts); + self.send_req(req, None).await.map_err(AdapterError::from)?; + } + self.local + .disconnect_socket(opts) + .map_err(BroadcastError::Socket)?; + + Ok(()) + } + + async fn rooms(&self, opts: BroadcastOptions) -> Result, Self::Error> { + if opts.is_local(self.uid) { + return Ok(self.local.rooms(opts).into_iter().collect()); + } + let req = RequestOut::new(self.uid, RequestTypeOut::AllRooms, &opts); + let req_id = req.id; + + // First get the remote stream because mongodb might send + // the responses before subscription is done. 
+ let stream = self + .get_res::<()>(req_id, ResponseTypeId::AllRooms, opts.server_id) + .await?; + self.send_req(req, opts.server_id).await?; + let local = self.local.rooms(opts); + let rooms = stream + .filter_map(|item| std::future::ready(item.into_rooms())) + .fold(local, |mut acc, item| async move { + acc.extend(item); + acc + }) + .await; + Ok(Vec::from_iter(rooms)) + } + + async fn add_sockets( + &self, + opts: BroadcastOptions, + rooms: impl RoomParam, + ) -> Result<(), Self::Error> { + let rooms: Vec = rooms.into_room_iter().collect(); + if !opts.is_local(self.uid) { + let req = RequestOut::new(self.uid, RequestTypeOut::AddSockets(&rooms), &opts); + self.send_req(req, opts.server_id).await?; + } + self.local.add_sockets(opts, rooms); + Ok(()) + } + + async fn del_sockets( + &self, + opts: BroadcastOptions, + rooms: impl RoomParam, + ) -> Result<(), Self::Error> { + let rooms: Vec = rooms.into_room_iter().collect(); + if !opts.is_local(self.uid) { + let req = RequestOut::new(self.uid, RequestTypeOut::DelSockets(&rooms), &opts); + self.send_req(req, opts.server_id).await?; + } + self.local.del_sockets(opts, rooms); + Ok(()) + } + + async fn fetch_sockets( + &self, + opts: BroadcastOptions, + ) -> Result, Self::Error> { + if opts.is_local(self.uid) { + return Ok(self.local.fetch_sockets(opts)); + } + let req = RequestOut::new(self.uid, RequestTypeOut::FetchSockets, &opts); + // First get the remote stream because mongodb might send + // the responses before subscription is done. 
+ let remote = self + .get_res::(req.id, ResponseTypeId::FetchSockets, opts.server_id) + .await?; + + self.send_req(req, opts.server_id).await?; + let local = self.local.fetch_sockets(opts); + let sockets = remote + .filter_map(|item| future::ready(item.into_fetch_sockets())) + .fold(local, |mut acc, item| async move { + acc.extend(item); + acc + }) + .await; + Ok(sockets) + } + + fn get_local(&self) -> &CoreLocalAdapter { + &self.local + } +} + +impl CustomPostgresAdapter { + async fn heartbeat_job(self: Arc) -> Result<(), Error> { + let mut interval = tokio::time::interval(self.config.hb_interval); + interval.tick().await; // first tick yields immediately + loop { + interval.tick().await; + self.emit_heartbeat(None).await?; + } + } + + async fn handle_ev_stream( + self: Arc, + mut stream: impl Stream> + Unpin, + ) { + while let Some(item) = stream.next().await { + match item { + Ok(Item { + header: ItemHeader::Req { target, .. }, + data, + .. + }) if target.is_none_or(|id| id == self.uid) => { + tracing::debug!(?target, "request header"); + if let Err(e) = self.recv_req(data).await { + tracing::warn!("error receiving request from driver: {e}"); + } + } + Ok(Item { + header: ItemHeader::Req { target, .. }, + .. + }) => { + tracing::debug!( + ?target, + "receiving request which is not for us, skipping..." + ); + } + Ok( + item @ Item { + header: ItemHeader::Res { request, .. }, + .. 
+ }, + ) => { + tracing::trace!(?request, "received response"); + let handlers = self.responses.lock().unwrap(); + if let Some(tx) = handlers.get(&request) { + if let Err(e) = tx.try_send(item) { + tracing::warn!("error sending response to handler: {e}"); + } + } else { + tracing::warn!(?request, ?handlers, "could not find req handler"); + } + } + Err(e) => { + tracing::warn!("error receiving event from driver: {e}"); + } + } + } + } + + async fn recv_req(self: &Arc, req: Vec) -> Result<(), Error> { + let req = rmp_serde::from_slice::(&req)?; + tracing::trace!(?req, "incoming request"); + match (req.r#type, req.opts) { + (RequestTypeIn::Broadcast(p), Some(opts)) => self.recv_broadcast(opts, p), + (RequestTypeIn::BroadcastWithAck(p), Some(opts)) => self + .clone() + .recv_broadcast_with_ack(req.node_id, req.id, p, opts), + (RequestTypeIn::DisconnectSockets, Some(opts)) => self.recv_disconnect_sockets(opts), + (RequestTypeIn::AllRooms, Some(opts)) => self.recv_rooms(req.node_id, req.id, opts), + (RequestTypeIn::AddSockets(rooms), Some(opts)) => self.recv_add_sockets(opts, rooms), + (RequestTypeIn::DelSockets(rooms), Some(opts)) => self.recv_del_sockets(opts, rooms), + (RequestTypeIn::FetchSockets, Some(opts)) => { + self.recv_fetch_sockets(req.node_id, req.id, opts) + } + req_type @ (RequestTypeIn::Heartbeat | RequestTypeIn::InitHeartbeat, _) => { + self.recv_heartbeat(req_type.0, req.node_id) + } + _ => (), + } + Ok(()) + } + + fn recv_broadcast(&self, opts: BroadcastOptions, packet: Packet) { + tracing::trace!(?opts, "incoming broadcast"); + if let Err(e) = self.local.broadcast(packet, opts) { + let ns = self.local.path(); + tracing::warn!(?self.uid, ?ns, "remote request broadcast handler: {:?}", e); + } + } + + fn recv_disconnect_sockets(&self, opts: BroadcastOptions) { + if let Err(e) = self.local.disconnect_socket(opts) { + let ns = self.local.path(); + tracing::warn!( + ?self.uid, + ?ns, + "remote request disconnect sockets handler: {:?}", + e + ); + } + } + + 
fn recv_broadcast_with_ack( + self: Arc, + origin: Uid, + req_id: Sid, + packet: Packet, + opts: BroadcastOptions, + ) { + let (stream, count) = self.local.broadcast_with_ack(packet, opts, None); + tokio::spawn(async move { + let on_err = |err| { + let ns = self.local.path(); + tracing::warn!( + ?self.uid, + ?ns, + "remote request broadcast with ack handler errors: {:?}", + err + ); + }; + // First send the count of expected acks to the server that sent the request. + // This is used to keep track of the number of expected acks. + let res = Response { + r#type: ResponseType::<()>::BroadcastAckCount(count), + node_id: self.uid, + }; + if let Err(err) = self.send_res(req_id, origin, res).await { + on_err(err); + return; + } + + // Then send the acks as they are received. + futures_util::pin_mut!(stream); + while let Some(ack) = stream.next().await { + let res = Response { + r#type: ResponseType::BroadcastAck(ack), + node_id: self.uid, + }; + if let Err(err) = self.send_res(req_id, origin, res).await { + on_err(err); + return; + } + } + }); + } + + fn recv_rooms(&self, origin: Uid, req_id: Sid, opts: BroadcastOptions) { + let rooms = self.local.rooms(opts); + let res = Response { + r#type: ResponseType::<()>::AllRooms(rooms), + node_id: self.uid, + }; + let fut = self.send_res(req_id, origin, res); + let ns = self.local.path().clone(); + let uid = self.uid; + tokio::spawn(async move { + if let Err(err) = fut.await { + tracing::warn!(?uid, ?ns, "remote request rooms handler: {:?}", err); + } + }); + } + + fn recv_add_sockets(&self, opts: BroadcastOptions, rooms: Vec) { + self.local.add_sockets(opts, rooms); + } + + fn recv_del_sockets(&self, opts: BroadcastOptions, rooms: Vec) { + self.local.del_sockets(opts, rooms); + } + fn recv_fetch_sockets(&self, origin: Uid, req_id: Sid, opts: BroadcastOptions) { + let sockets = self.local.fetch_sockets(opts); + let res = Response { + node_id: self.uid, + r#type: ResponseType::FetchSockets(sockets), + }; + let fut = 
self.send_res(req_id, origin, res); + let ns = self.local.path().clone(); + let uid = self.uid; + tokio::spawn(async move { + if let Err(err) = fut.await { + tracing::warn!(?uid, ?ns, "remote request fetch sockets handler: {:?}", err); + } + }); + } + + /// Receive a heartbeat from a remote node. + /// It might be a FirstHeartbeat packet, in which case we are re-emitting a heartbeat to the remote node. + fn recv_heartbeat(self: &Arc, req_type: RequestTypeIn, origin: Uid) { + tracing::debug!(?req_type, "{:?} received", req_type); + let mut node_liveness = self.nodes_liveness.lock().unwrap(); + // Even with a FirstHeartbeat packet we first consume the node liveness to + // ensure that the node is not already in the list. + for (id, liveness) in node_liveness.iter_mut() { + if *id == origin { + *liveness = Instant::now(); + return; + } + } + + node_liveness.push((origin, Instant::now())); + + if matches!(req_type, RequestTypeIn::InitHeartbeat) { + tracing::debug!( + ?origin, + "initial heartbeat detected, saying hello to the new node" + ); + + let this = self.clone(); + tokio::spawn(async move { + if let Err(err) = this.emit_heartbeat(Some(origin)).await { + tracing::warn!( + "could not re-emit heartbeat after new node detection: {:?}", + err + ); + } + }); + } + } + + /// Send a request to a specific target node or broadcast it to all nodes if no target is specified. + async fn send_req(&self, req: RequestOut<'_>, target: Option) -> Result<(), Error> { + tracing::trace!(?req, "sending request"); + let head = ItemHeader::Req { target }; + let req = self.new_packet(head, &req)?; + self.driver.emit(&req).await.map_err(Error::from_driver)?; + Ok(()) + } + + /// Send a response to the node that sent the request. 
+ fn send_res( + &self, + req_id: Sid, + req_origin: Uid, + res: Response, + ) -> impl Future>> + Send + 'static { + tracing::trace!(?res, "sending response for {req_id} req to {req_origin}"); + let driver = self.driver.clone(); + let head = ItemHeader::Res { + request: req_id, + target: req_origin, + }; + let res = self.new_packet(head, &res); + + async move { + driver.emit(&res?).await.map_err(Error::from_driver)?; + Ok(()) + } + } + + /// Await for all the responses from the remote servers. + /// If the target node is specified, only await for the response from that node. + async fn get_res( + &self, + req_id: Sid, + response_type: ResponseTypeId, + target: Option, + ) -> Result>, Error> { + // Check for specific target node + let remote_serv_cnt = if target.is_none() { + self.server_count().await?.saturating_sub(1) as usize + } else { + 1 + }; + let (tx, rx) = mpsc::channel(std::cmp::max(remote_serv_cnt, 1)); + self.responses.lock().unwrap().insert(req_id, tx); + let stream = ChanStream::new(rx) + .filter_map(|Item { header, data, .. }| { + let data = match rmp_serde::from_slice::>(&data) { + Ok(data) => Some(data), + Err(e) => { + tracing::warn!(header = ?header, "error decoding response: {e}"); + None + } + }; + future::ready(data) + }) + .filter(move |item| future::ready(ResponseTypeId::from(&item.r#type) == response_type)) + .take(remote_serv_cnt) + .take_until(tokio::time::sleep(self.config.request_timeout)); + let stream = DropStream::new(stream, self.responses.clone(), req_id); + Ok(stream) + } + + /// Emit a heartbeat to the specified target node or broadcast to all nodes. + async fn emit_heartbeat(&self, target: Option) -> Result<(), Error> { + // Send heartbeat when starting. + self.send_req( + RequestOut::new_empty(self.uid, RequestTypeOut::Heartbeat), + target, + ) + .await + } + + /// Emit an initial heartbeat to all nodes. + async fn emit_init_heartbeat(&self) -> Result<(), Error> { + // Send initial heartbeat when starting. 
+ self.send_req( + RequestOut::new_empty(self.uid, RequestTypeOut::InitHeartbeat), + None, + ) + .await + } + fn new_packet(&self, head: ItemHeader, data: &impl Serialize) -> Result> { + let ns = &self.local.path(); + let uid = self.uid; + } +} + +/// The result of the init future. +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct InitRes(futures_core::future::BoxFuture<'static, Result<(), D::Error>>); + +impl Future for InitRes { + type Output = Result<(), D::Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.0.as_mut().poll(cx) + } +} +impl Spawnable for InitRes { + fn spawn(self) { + tokio::spawn(async move { + if let Err(e) = self.0.await { + tracing::error!("error initializing adapter: {e}"); + } + }); + } +} From da03dc8b4013d3136bb2f15efdbdd7e6c6cf24f7 Mon Sep 17 00:00:00 2001 From: totodore Date: Wed, 7 May 2025 16:35:04 +0200 Subject: [PATCH 02/12] feat(adapter/postgres): `Driver` wip --- .../socketioxide-postgres/src/drivers/mod.rs | 9 ++ .../src/drivers/postgres.rs | 3 +- .../socketioxide-postgres/src/drivers/sqlx.rs | 87 ++++++++++++++++--- 3 files changed, 85 insertions(+), 14 deletions(-) diff --git a/crates/socketioxide-postgres/src/drivers/mod.rs b/crates/socketioxide-postgres/src/drivers/mod.rs index 976f9821..10c9333a 100644 --- a/crates/socketioxide-postgres/src/drivers/mod.rs +++ b/crates/socketioxide-postgres/src/drivers/mod.rs @@ -1,3 +1,5 @@ +use serde::de::DeserializeOwned; + mod postgres; mod sqlx; @@ -7,9 +9,16 @@ pub type ChanItem = (String, String); /// It must share handlers/connection between its clones. 
pub trait Driver: Clone + Send + Sync + 'static { type Error: std::error::Error + Send + 'static; + type NotifStream: futures_core::Stream + Send + 'static; fn init(&self, table: &str, channels: &[&str]) -> impl Future>; + + fn listen( + &self, + channel: &str, + ) -> impl Future, Self::Error>>; + fn notify(&self, channel: &str, message: &str) -> impl Future>; } diff --git a/crates/socketioxide-postgres/src/drivers/postgres.rs b/crates/socketioxide-postgres/src/drivers/postgres.rs index 8e5b6408..86dda23f 100644 --- a/crates/socketioxide-postgres/src/drivers/postgres.rs +++ b/crates/socketioxide-postgres/src/drivers/postgres.rs @@ -2,8 +2,6 @@ use std::sync::Arc; use tokio_postgres::{Client, Connection}; -use crate::PostgresAdapterConfig; - use super::Driver; #[derive(Debug, Clone)] @@ -21,6 +19,7 @@ impl PostgresDriver { impl Driver for PostgresDriver { type Error = tokio_postgres::Error; + async fn init(&self, table: &str, channels: &[&str]) -> Result<(), Self::Error> { self.client .execute("CREATE TABLE $1 IF NOT EXISTS", &[&table]) diff --git a/crates/socketioxide-postgres/src/drivers/sqlx.rs b/crates/socketioxide-postgres/src/drivers/sqlx.rs index fa63e56b..439f0ed3 100644 --- a/crates/socketioxide-postgres/src/drivers/sqlx.rs +++ b/crates/socketioxide-postgres/src/drivers/sqlx.rs @@ -1,31 +1,66 @@ -use std::{collections::HashMap, sync::Arc}; +use std::{ + collections::HashMap, + marker::PhantomData, + sync::{Arc, RwLock}, +}; -use sqlx::{PgPool, postgres::PgListener}; +use futures_core::Stream; +use serde::de::DeserializeOwned; +use sqlx::{ + PgPool, + postgres::{PgListener, PgNotification}, +}; use tokio::sync::mpsc; -use super::{ChanItem, Driver}; +use super::Driver; +type HandlerMap = HashMap>; #[derive(Debug, Clone)] pub struct SqlxDriver { client: PgPool, + handlers: Arc>, } impl SqlxDriver { pub fn new(client: PgPool) -> Self { - Self { client } + Self { + client, + handlers: Arc::new(RwLock::new(HashMap::new())), + } } +} - async fn 
spawn_listener(&self, mut listener: PgListener, tx: mpsc::Sender) { - while let Ok(notif) = listener - .recv() - .await - .inspect_err(|e| tracing::warn!(?e, "sqlx listener error")) - {} +pin_project_lite::pin_project! { + pub struct NotifStream { + #[pin] + rx: tokio::sync::mpsc::Receiver, + _phantom: std::marker::PhantomData T> + } +} +impl Stream for NotifStream { + type Item = T; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.rx.poll_recv(cx) { + std::task::Poll::Ready(_) => todo!(), + std::task::Poll::Pending => todo!(), + } + } +} +impl NotifStream { + pub fn new(rx: mpsc::Receiver) -> Self { + NotifStream { + rx, + _phantom: PhantomData::default(), + } } } impl Driver for SqlxDriver { type Error = sqlx::Error; - + type NotifStream = NotifStream; async fn init(&self, table: &str, channels: &[&str]) -> Result<(), Self::Error> { sqlx::query("CREATE TABLE $1 IF NOT EXISTS") .bind(&table) @@ -33,11 +68,39 @@ impl Driver for SqlxDriver { .await?; let mut listener = PgListener::connect_with(&self.client).await?; listener.listen_all(channels.iter().copied()).await?; + tokio::spawn(spawn_listener(self.handlers.clone(), listener)); Ok(()) } + async fn listen( + &self, + channel: &str, + ) -> Result, Self::Error> { + let (tx, rx) = mpsc::channel(255); + self.handlers.write().unwrap().insert(channel.into(), tx); + Ok(NotifStream::new(rx)) + } async fn notify(&self, channel: &str, msg: &str) -> Result<(), Self::Error> { - todo!() + sqlx::query("NOTIFY $1 $2") + .bind(channel) + .bind(msg) + .execute(&self.client) + .await?; + Ok(()) + } +} + +async fn spawn_listener(handlers: Arc>, mut listener: PgListener) { + while let Ok(notif) = listener + .recv() + .await + .inspect_err(|e| tracing::warn!(?e, "sqlx listener error")) + { + if let Some(tx) = handlers.read().unwrap().get(notif.channel()) { + tx.try_send(notif); + } else { + tracing::warn!("handler not found for channel {}", 
notif.channel()); + } } } From 092e4474d14bbd9a50e5a64c086a57ae9db26da9 Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 22 Jun 2025 20:44:38 +0200 Subject: [PATCH 03/12] wip --- Cargo.lock | 1 + crates/socketioxide-postgres/Cargo.toml | 1 + .../socketioxide-postgres/src/drivers/mod.rs | 17 +- .../socketioxide-postgres/src/drivers/sqlx.rs | 41 ++- crates/socketioxide-postgres/src/lib.rs | 109 ++----- crates/socketioxide-postgres/src/stream.rs | 265 ++++++++++++++++++ 6 files changed, 333 insertions(+), 101 deletions(-) create mode 100644 crates/socketioxide-postgres/src/stream.rs diff --git a/Cargo.lock b/Cargo.lock index 20d29952..013253bf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2540,6 +2540,7 @@ dependencies = [ "pin-project-lite", "rmp-serde", "serde", + "serde_json", "smallvec", "socketioxide", "socketioxide-core", diff --git a/crates/socketioxide-postgres/Cargo.toml b/crates/socketioxide-postgres/Cargo.toml index df86ade0..537fab4b 100644 --- a/crates/socketioxide-postgres/Cargo.toml +++ b/crates/socketioxide-postgres/Cargo.toml @@ -25,6 +25,7 @@ futures-core.workspace = true futures-util.workspace = true pin-project-lite.workspace = true serde.workspace = true +serde_json.workspace = true smallvec = { workspace = true, features = ["serde"] } tokio = { workspace = true, features = ["time", "rt", "sync"] } rmp-serde.workspace = true diff --git a/crates/socketioxide-postgres/src/drivers/mod.rs b/crates/socketioxide-postgres/src/drivers/mod.rs index 10c9333a..d038e341 100644 --- a/crates/socketioxide-postgres/src/drivers/mod.rs +++ b/crates/socketioxide-postgres/src/drivers/mod.rs @@ -1,15 +1,19 @@ -use serde::de::DeserializeOwned; +use futures_core::Stream; +use serde::{Deserialize, Serialize, de::DeserializeOwned}; mod postgres; mod sqlx; pub type ChanItem = (String, String); +#[derive(Deserialize)] +pub struct Item {} + /// The driver trait can be used to support different LISTEN/NOTIFY backends. /// It must share handlers/connection between its clones. 
pub trait Driver: Clone + Send + Sync + 'static { type Error: std::error::Error + Send + 'static; - type NotifStream: futures_core::Stream + Send + 'static; + type NotifStream: Stream + Send + 'static; fn init(&self, table: &str, channels: &[&str]) -> impl Future>; @@ -17,8 +21,11 @@ pub trait Driver: Clone + Send + Sync + 'static { fn listen( &self, channel: &str, - ) -> impl Future, Self::Error>>; + ) -> impl Future, Self::Error>> + Send; - fn notify(&self, channel: &str, message: &str) - -> impl Future>; + fn notify( + &self, + channel: &str, + message: &T, + ) -> impl Future> + Send; } diff --git a/crates/socketioxide-postgres/src/drivers/sqlx.rs b/crates/socketioxide-postgres/src/drivers/sqlx.rs index 439f0ed3..c798fa0b 100644 --- a/crates/socketioxide-postgres/src/drivers/sqlx.rs +++ b/crates/socketioxide-postgres/src/drivers/sqlx.rs @@ -5,26 +5,30 @@ use std::{ }; use futures_core::Stream; -use serde::de::DeserializeOwned; +use serde::{Serialize, de::DeserializeOwned}; use sqlx::{ PgPool, postgres::{PgListener, PgNotification}, }; use tokio::sync::mpsc; +use crate::PostgresAdapterConfig; + use super::Driver; -type HandlerMap = HashMap>; +type HandlerMap = HashMap>; #[derive(Debug, Clone)] pub struct SqlxDriver { client: PgPool, handlers: Arc>, + config: PostgresAdapterConfig, } impl SqlxDriver { - pub fn new(client: PgPool) -> Self { + pub fn new(client: PgPool, config: PostgresAdapterConfig) -> Self { Self { client, handlers: Arc::new(RwLock::new(HashMap::new())), + config, } } } @@ -32,7 +36,7 @@ impl SqlxDriver { pin_project_lite::pin_project! 
{ pub struct NotifStream { #[pin] - rx: tokio::sync::mpsc::Receiver, + rx: mpsc::UnboundedReceiver, _phantom: std::marker::PhantomData T> } } @@ -50,7 +54,7 @@ impl Stream for NotifStream { } } impl NotifStream { - pub fn new(rx: mpsc::Receiver) -> Self { + pub fn new(rx: mpsc::UnboundedReceiver) -> Self { NotifStream { rx, _phantom: PhantomData::default(), @@ -76,18 +80,27 @@ impl Driver for SqlxDriver { &self, channel: &str, ) -> Result, Self::Error> { - let (tx, rx) = mpsc::channel(255); + let (tx, rx) = mpsc::unbounded_channel(); self.handlers.write().unwrap().insert(channel.into(), tx); Ok(NotifStream::new(rx)) } - async fn notify(&self, channel: &str, msg: &str) -> Result<(), Self::Error> { - sqlx::query("NOTIFY $1 $2") - .bind(channel) - .bind(msg) - .execute(&self.client) - .await?; - Ok(()) + fn notify( + &self, + channel: &str, + req: &T, + ) -> impl Future> + Send { + let client = self.client.clone(); + //TODO: handle error + let msg = serde_json::to_string(req).unwrap(); + async move { + sqlx::query("NOTIFY $1 $2") + .bind(channel) + .bind(msg) + .execute(&client) + .await?; + Ok(()) + } } } @@ -98,7 +111,7 @@ async fn spawn_listener(handlers: Arc>, mut listener: PgListe .inspect_err(|e| tracing::warn!(?e, "sqlx listener error")) { if let Some(tx) = handlers.read().unwrap().get(notif.channel()) { - tx.try_send(notif); + tx.send(notif); } else { tracing::warn!("handler not found for channel {}", notif.channel()); } diff --git a/crates/socketioxide-postgres/src/lib.rs b/crates/socketioxide-postgres/src/lib.rs index a3441158..91196003 100644 --- a/crates/socketioxide-postgres/src/lib.rs +++ b/crates/socketioxide-postgres/src/lib.rs @@ -32,15 +32,15 @@ //! 
use drivers::Driver; - use futures_core::Stream; -use serde::{Serialize, de::DeserializeOwned}; +use futures_util::StreamExt; +use serde::{Deserialize, Serialize, de::DeserializeOwned}; use socketioxide_core::{ Sid, Uid, adapter::{ BroadcastOptions, CoreAdapter, CoreLocalAdapter, DefinedAdapter, RemoteSocketData, Room, RoomParam, SocketEmitter, Spawnable, - errors::AdapterError, + errors::{AdapterError, BroadcastError}, remote_packet::{ RequestIn, RequestOut, RequestTypeIn, RequestTypeOut, Response, ResponseType, ResponseTypeId, @@ -50,7 +50,6 @@ use socketioxide_core::{ }; use std::{ borrow::Cow, - collections::HashMap, fmt, future, pin::Pin, sync::{Arc, Mutex}, @@ -60,6 +59,7 @@ use std::{ use tokio::sync::mpsc; mod drivers; +mod stream; /// The configuration of the [`MongoDbAdapter`]. #[derive(Debug, Clone)] @@ -103,11 +103,6 @@ pub enum Error { Decode(#[from] rmp_serde::decode::Error), } -impl Error { - fn from_driver(err: R::Error) -> Self { - Self::Driver(err) - } -} impl fmt::Debug for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -124,7 +119,9 @@ impl From> for AdapterError { } } -pub(crate) type ResponseHandlers = HashMap>; +/// An event we should answer to +#[derive(Debug, Deserialize)] +struct Event {} /// The postgres adapter implementation. /// It is generic over the [`Driver`] used to communicate with the postgres server. @@ -142,8 +139,6 @@ pub struct CustomPostgresAdapter { local: CoreLocalAdapter, /// A map of nodes liveness, with the last time remote nodes were seen alive. nodes_liveness: Mutex>, - /// A map of response handlers used to await for responses from the remote servers. 
- responses: Arc>, } impl DefinedAdapter for CustomPostgresAdapter {} @@ -161,13 +156,12 @@ impl CoreAdapter for CustomPostgresAdapter driver: state.driver.clone(), config: state.config.clone(), nodes_liveness: Mutex::new(Vec::new()), - responses: Arc::new(Mutex::new(HashMap::new())), } } fn init(self: Arc, on_success: impl FnOnce() + Send + 'static) -> Self::InitRes { let fut = async move { - let stream = self.driver.watch(self.uid, self.local.path()).await?; + let stream = self.driver.listen("event").await?; tokio::spawn(self.clone().handle_ev_stream(stream)); tokio::spawn(self.clone().heartbeat_job()); @@ -254,9 +248,9 @@ impl CoreAdapter for CustomPostgresAdapter let remote_serv_cnt = self.server_count().await?.saturating_sub(1); tracing::trace!(?remote_serv_cnt, "expecting acks from remote servers"); + let res = self.driver.listen("").await?; let (tx, rx) = mpsc::channel(self.config.ack_response_buffer + remote_serv_cnt as usize); - self.responses.lock().unwrap().insert(req_id, tx); self.send_req(req, None).await?; let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout); @@ -375,56 +369,14 @@ impl CustomPostgresAdapter { } } - async fn handle_ev_stream( - self: Arc, - mut stream: impl Stream> + Unpin, - ) { - while let Some(item) = stream.next().await { - match item { - Ok(Item { - header: ItemHeader::Req { target, .. }, - data, - .. - }) if target.is_none_or(|id| id == self.uid) => { - tracing::debug!(?target, "request header"); - if let Err(e) = self.recv_req(data).await { - tracing::warn!("error receiving request from driver: {e}"); - } - } - Ok(Item { - header: ItemHeader::Req { target, .. }, - .. - }) => { - tracing::debug!( - ?target, - "receiving request which is not for us, skipping..." - ); - } - Ok( - item @ Item { - header: ItemHeader::Res { request, .. }, - .. 
- }, - ) => { - tracing::trace!(?request, "received response"); - let handlers = self.responses.lock().unwrap(); - if let Some(tx) = handlers.get(&request) { - if let Err(e) = tx.try_send(item) { - tracing::warn!("error sending response to handler: {e}"); - } - } else { - tracing::warn!(?request, ?handlers, "could not find req handler"); - } - } - Err(e) => { - tracing::warn!("error receiving event from driver: {e}"); - } - } + async fn handle_ev_stream(self: Arc, stream: impl Stream) { + futures_util::pin_mut!(stream); + while let Some(req) = stream.next().await { + self.recv_req(req); } } - async fn recv_req(self: &Arc, req: Vec) -> Result<(), Error> { - let req = rmp_serde::from_slice::(&req)?; + fn recv_req(self: &Arc, req: RequestIn) { tracing::trace!(?req, "incoming request"); match (req.r#type, req.opts) { (RequestTypeIn::Broadcast(p), Some(opts)) => self.recv_broadcast(opts, p), @@ -443,7 +395,6 @@ impl CustomPostgresAdapter { } _ => (), } - Ok(()) } fn recv_broadcast(&self, opts: BroadcastOptions, packet: Packet) { @@ -586,31 +537,29 @@ impl CustomPostgresAdapter { /// Send a request to a specific target node or broadcast it to all nodes if no target is specified. async fn send_req(&self, req: RequestOut<'_>, target: Option) -> Result<(), Error> { tracing::trace!(?req, "sending request"); - let head = ItemHeader::Req { target }; - let req = self.new_packet(head, &req)?; - self.driver.emit(&req).await.map_err(Error::from_driver)?; + // let head = ItemHeader::Req { target }; + // let req = self.new_packet(head, &req)?; + self.driver + .notify("yolo", &req) + .await + .map_err(Error::Driver)?; Ok(()) } /// Send a response to the node that sent the request. 
- fn send_res( + async fn send_res( &self, req_id: Sid, req_origin: Uid, res: Response, - ) -> impl Future>> + Send + 'static { + ) -> Result<(), Error> { tracing::trace!(?res, "sending response for {req_id} req to {req_origin}"); - let driver = self.driver.clone(); - let head = ItemHeader::Res { - request: req_id, - target: req_origin, - }; - let res = self.new_packet(head, &res); - async move { - driver.emit(&res?).await.map_err(Error::from_driver)?; - Ok(()) - } + self.driver + .notify("response", &res) + .await + .map_err(Error::Driver)?; + Ok(()) } /// Await for all the responses from the remote servers. @@ -666,10 +615,6 @@ impl CustomPostgresAdapter { ) .await } - fn new_packet(&self, head: ItemHeader, data: &impl Serialize) -> Result> { - let ns = &self.local.path(); - let uid = self.uid; - } } /// The result of the init future. diff --git a/crates/socketioxide-postgres/src/stream.rs b/crates/socketioxide-postgres/src/stream.rs new file mode 100644 index 00000000..e27960e5 --- /dev/null +++ b/crates/socketioxide-postgres/src/stream.rs @@ -0,0 +1,265 @@ +use std::{ + fmt, + pin::Pin, + sync::{Arc, Mutex}, + task::{self, Poll}, + time::Duration, +}; + +use futures_core::{FusedStream, Stream}; +use futures_util::{StreamExt, stream::TakeUntil}; +use pin_project_lite::pin_project; +use serde::de::DeserializeOwned; +use socketioxide_core::{ + Sid, + adapter::AckStreamItem, + adapter::remote_packet::{Response, ResponseType}, +}; +use tokio::{sync::mpsc, time}; + +pin_project! { + /// A stream of acknowledgement messages received from the local and remote servers. + /// It merges the local ack stream with the remote ack stream from all the servers. + // The server_cnt is the number of servers that are expected to send a AckCount message. + // It is decremented each time a AckCount message is received. + // + // The ack_cnt is the number of acks that are expected to be received. It is the sum of all the the ack counts. 
+ // And it is decremented each time an ack is received. + // + // Therefore an exhausted stream correspond to `ack_cnt == 0` and `server_cnt == 0`. + pub struct AckStream { + #[pin] + local: S, + #[pin] + remote: DropStream>, + ack_cnt: u32, + total_ack_cnt: usize, + serv_cnt: u16, + } +} + +impl AckStream { + pub fn new( + local: S, + rx: mpsc::Receiver, + timeout: Duration, + serv_cnt: u16, + req_id: Sid, + ) -> Self { + let remote = ChanStream::new(rx).take_until(time::sleep(timeout)); + let remote = DropStream::new(remote, handlers, req_id); + Self { + local, + ack_cnt: 0, + total_ack_cnt: 0, + serv_cnt, + } + } + pub fn new_local(local: S) -> Self { + let handlers = Arc::new(Mutex::new(ResponseHandlers::new())); + let rx = mpsc::channel(1).1; + let remote = ChanStream::new(rx).take_until(time::sleep(Duration::ZERO)); + let remote = DropStream::new(remote, handlers, Sid::ZERO); + Self { + local, + remote, + ack_cnt: 0, + total_ack_cnt: 0, + serv_cnt: 0, + } + } +} +impl AckStream +where + Err: DeserializeOwned + fmt::Debug, + S: Stream> + FusedStream, +{ + /// Poll the remote stream. First the count of acks is received, then the acks are received. + /// We expect `serv_cnt` of `BroadcastAckCount` messages to be received, then we expect + /// `ack_cnt` of `BroadcastAck` messages. + fn poll_remote( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll>> { + // remote stream is not fused, so we need to check if it is terminated + if FusedStream::is_terminated(&self) { + return Poll::Ready(None); + } + let mut projection = self.project(); + loop { + match projection.remote.as_mut().poll_next(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(None) => return Poll::Ready(None), + Poll::Ready(Some(Item { header, data, .. 
})) => { + let res = rmp_serde::from_slice::>(&data); + match res { + Ok(Response { + node_id: uid, + r#type: ResponseType::BroadcastAckCount(count), + }) if *projection.serv_cnt > 0 => { + tracing::trace!(?uid, ?header, "receiving broadcast ack count {count}"); + *projection.ack_cnt += count; + *projection.total_ack_cnt += count as usize; + *projection.serv_cnt -= 1; + } + Ok(Response { + node_id: uid, + r#type: ResponseType::BroadcastAck((sid, res)), + }) if *projection.ack_cnt > 0 => { + tracing::trace!( + ?uid, + ?header, + "receiving broadcast ack {sid} {:?}", + res + ); + *projection.ack_cnt -= 1; + return Poll::Ready(Some((sid, res))); + } + Ok(Response { node_id: uid, .. }) => { + tracing::warn!(?uid, ?header, "unexpected response type"); + } + Err(e) => { + tracing::warn!("error decoding ack response: {e}"); + } + } + } + } + } + } +} +impl Stream for AckStream +where + E: DeserializeOwned + fmt::Debug, + S: Stream> + FusedStream, +{ + type Item = AckStreamItem; + fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + match self.as_mut().project().local.poll_next(cx) { + Poll::Pending => match self.poll_remote(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Some(item)) => Poll::Ready(Some(item)), + Poll::Ready(None) => Poll::Pending, + }, + Poll::Ready(Some(item)) => Poll::Ready(Some(item)), + Poll::Ready(None) => self.poll_remote(cx), + } + } + + fn size_hint(&self) -> (usize, Option) { + let (lower, upper) = self.local.size_hint(); + (lower, upper.map(|upper| upper + self.total_ack_cnt)) + } +} + +impl FusedStream for AckStream +where + Err: DeserializeOwned + fmt::Debug, + S: Stream> + FusedStream, +{ + /// The stream is terminated if: + /// * The local stream is terminated. + /// * All the servers have sent the expected ack count. + /// * We have received all the expected acks. 
+ fn is_terminated(&self) -> bool { + // remote stream is terminated if the timeout is reached + let remote_term = (self.ack_cnt == 0 && self.serv_cnt == 0) || self.remote.is_terminated(); + self.local.is_terminated() && remote_term + } +} +impl fmt::Debug for AckStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AckStream") + .field("ack_cnt", &self.ack_cnt) + .field("total_ack_cnt", &self.total_ack_cnt) + .field("serv_cnt", &self.serv_cnt) + .finish() + } +} + +pin_project! { + /// A stream of messages received from a channel. + pub struct ChanStream { + #[pin] + rx: mpsc::Receiver + } +} +impl ChanStream { + pub fn new(rx: mpsc::Receiver) -> Self { + Self { rx } + } +} +impl Stream for ChanStream { + type Item = Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + self.project().rx.poll_recv(cx) + } +} +pin_project! { + /// A stream that unsubscribes from its source channel when dropped. + pub struct DropStream { + #[pin] + stream: S, + req_id: Sid, + handlers: Arc> + } + impl PinnedDrop for DropStream { + fn drop(this: Pin<&mut Self>) { + let stream = this.project(); + let chan = stream.req_id; + tracing::debug!(?chan, "dropping stream"); + stream.handlers.lock().unwrap().remove(chan); + } + } +} +impl DropStream { + pub fn new(stream: S, handlers: Arc>, req_id: Sid) -> Self { + Self { + stream, + handlers, + req_id, + } + } +} +impl Stream for DropStream { + type Item = S::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + self.project().stream.poll_next(cx) + } +} +impl FusedStream for DropStream { + fn is_terminated(&self) -> bool { + self.stream.is_terminated() + } +} + +#[cfg(test)] +mod tests { + use futures_core::FusedStream; + use futures_util::StreamExt; + use socketioxide_core::{Sid, Value}; + + use super::AckStream; + + #[tokio::test] + async fn local_ack_stream_should_have_a_closed_remote() { + let sid = Sid::new(); + let local = 
futures_util::stream::once(async move { + (sid, Ok::<_, ()>(Value::Str("local".into(), None))) + }); + let stream = AckStream::new_local(local); + futures_util::pin_mut!(stream); + assert_eq!(stream.ack_cnt, 0); + assert_eq!(stream.total_ack_cnt, 0); + assert_eq!(stream.serv_cnt, 0); + assert!(!stream.local.is_terminated()); + assert!(!stream.is_terminated()); + let data = stream.next().await; + assert!( + matches!(data, Some((id, Ok(Value::Str(msg, None)))) if id == sid && msg == "local") + ); + assert_eq!(stream.next().await, None); + assert!(stream.is_terminated()); + } +} From 818a3066a1bfea961d47fecc2f3d9f85b9ec9aca Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 22 Feb 2026 22:00:53 +0100 Subject: [PATCH 04/12] wip --- .../socketioxide-postgres/src/drivers/mod.rs | 38 +++++- .../src/drivers/postgres.rs | 15 ++- .../socketioxide-postgres/src/drivers/sqlx.rs | 57 ++++----- crates/socketioxide-postgres/src/lib.rs | 66 ++++++----- crates/socketioxide-postgres/src/stream.rs | 108 ++++-------------- 5 files changed, 133 insertions(+), 151 deletions(-) diff --git a/crates/socketioxide-postgres/src/drivers/mod.rs b/crates/socketioxide-postgres/src/drivers/mod.rs index d038e341..a0e92ed9 100644 --- a/crates/socketioxide-postgres/src/drivers/mod.rs +++ b/crates/socketioxide-postgres/src/drivers/mod.rs @@ -1,7 +1,8 @@ use futures_core::Stream; use serde::{Deserialize, Serialize, de::DeserializeOwned}; +use tokio::sync::mpsc; -mod postgres; +// mod postgres; mod sqlx; pub type ChanItem = (String, String); @@ -13,7 +14,8 @@ pub struct Item {} /// It must share handlers/connection between its clones. 
pub trait Driver: Clone + Send + Sync + 'static { type Error: std::error::Error + Send + 'static; - type NotifStream: Stream + Send + 'static; + type NotifStream: Stream + Send + 'static; + type Notification: Notification; fn init(&self, table: &str, channels: &[&str]) -> impl Future>; @@ -21,7 +23,7 @@ pub trait Driver: Clone + Send + Sync + 'static { fn listen( &self, channel: &str, - ) -> impl Future, Self::Error>> + Send; + ) -> impl Future> + Send; fn notify( &self, @@ -29,3 +31,33 @@ pub trait Driver: Clone + Send + Sync + 'static { message: &T, ) -> impl Future> + Send; } + +pub trait Notification: Send + 'static { + fn channel(&self) -> &str; + fn payload(&self) -> &str; +} + +pin_project_lite::pin_project! { + pub struct NotifStream { + #[pin] + rx: mpsc::UnboundedReceiver, + } +} +impl Stream for NotifStream { + type Item = T; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.rx.poll_recv(cx) { + std::task::Poll::Ready(notif) => std::task::Poll::Ready(notif), + std::task::Poll::Pending => std::task::Poll::Pending, + } + } +} +impl NotifStream { + pub fn new(rx: mpsc::UnboundedReceiver) -> Self { + NotifStream { rx } + } +} diff --git a/crates/socketioxide-postgres/src/drivers/postgres.rs b/crates/socketioxide-postgres/src/drivers/postgres.rs index 86dda23f..0d6d408a 100644 --- a/crates/socketioxide-postgres/src/drivers/postgres.rs +++ b/crates/socketioxide-postgres/src/drivers/postgres.rs @@ -19,15 +19,28 @@ impl PostgresDriver { impl Driver for PostgresDriver { type Error = tokio_postgres::Error; + type NotifStream; async fn init(&self, table: &str, channels: &[&str]) -> Result<(), Self::Error> { self.client .execute("CREATE TABLE $1 IF NOT EXISTS", &[&table]) .await?; + Ok(()) } - async fn notify(&self, channel: &str, msg: &str) -> Result<(), Self::Error> { + fn listen( + &self, + channel: &str, + ) -> impl Future, Self::Error>> + Send { + todo!() + } + + fn notify( + &self, 
+ channel: &str, + message: &T, + ) -> impl Future> + Send { todo!() } } diff --git a/crates/socketioxide-postgres/src/drivers/sqlx.rs b/crates/socketioxide-postgres/src/drivers/sqlx.rs index c798fa0b..9a0c6d86 100644 --- a/crates/socketioxide-postgres/src/drivers/sqlx.rs +++ b/crates/socketioxide-postgres/src/drivers/sqlx.rs @@ -1,10 +1,8 @@ use std::{ collections::HashMap, - marker::PhantomData, sync::{Arc, RwLock}, }; -use futures_core::Stream; use serde::{Serialize, de::DeserializeOwned}; use sqlx::{ PgPool, @@ -12,9 +10,10 @@ use sqlx::{ }; use tokio::sync::mpsc; -use crate::PostgresAdapterConfig; +use crate::{PostgresAdapterConfig, drivers::NotifStream}; use super::Driver; + type HandlerMap = HashMap>; #[derive(Debug, Clone)] @@ -23,6 +22,7 @@ pub struct SqlxDriver { handlers: Arc>, config: PostgresAdapterConfig, } + impl SqlxDriver { pub fn new(client: PgPool, config: PostgresAdapterConfig) -> Self { Self { @@ -33,43 +33,17 @@ impl SqlxDriver { } } -pin_project_lite::pin_project! 
{ - pub struct NotifStream { - #[pin] - rx: mpsc::UnboundedReceiver, - _phantom: std::marker::PhantomData T> - } -} -impl Stream for NotifStream { - type Item = T; - - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - match self.rx.poll_recv(cx) { - std::task::Poll::Ready(_) => todo!(), - std::task::Poll::Pending => todo!(), - } - } -} -impl NotifStream { - pub fn new(rx: mpsc::UnboundedReceiver) -> Self { - NotifStream { - rx, - _phantom: PhantomData::default(), - } - } -} - impl Driver for SqlxDriver { type Error = sqlx::Error; - type NotifStream = NotifStream; + type NotifStream = NotifStream; + type Notification = PgNotification; + async fn init(&self, table: &str, channels: &[&str]) -> Result<(), Self::Error> { sqlx::query("CREATE TABLE $1 IF NOT EXISTS") .bind(&table) .execute(&self.client) .await?; + let mut listener = PgListener::connect_with(&self.client).await?; listener.listen_all(channels.iter().copied()).await?; tokio::spawn(spawn_listener(self.handlers.clone(), listener)); @@ -79,9 +53,12 @@ impl Driver for SqlxDriver { async fn listen( &self, channel: &str, - ) -> Result, Self::Error> { + ) -> Result { let (tx, rx) = mpsc::unbounded_channel(); - self.handlers.write().unwrap().insert(channel.into(), tx); + self.handlers + .write() + .unwrap() + .insert(channel.to_string(), tx); Ok(NotifStream::new(rx)) } @@ -117,3 +94,13 @@ async fn spawn_listener(handlers: Arc>, mut listener: PgListe } } } + +impl super::Notification for PgNotification { + fn channel(&self) -> &str { + PgNotification::channel(self) + } + + fn payload(&self) -> &str { + PgNotification::payload(self) + } +} diff --git a/crates/socketioxide-postgres/src/lib.rs b/crates/socketioxide-postgres/src/lib.rs index 91196003..92e18675 100644 --- a/crates/socketioxide-postgres/src/lib.rs +++ b/crates/socketioxide-postgres/src/lib.rs @@ -2,7 +2,7 @@ #![warn( clippy::all, clippy::todo, - clippy::empty_enum, + clippy::empty_enums, 
clippy::mem_forget, clippy::unused_self, clippy::filter_map_next, @@ -11,7 +11,7 @@ clippy::match_wildcard_for_single_variants, clippy::if_let_mutex, clippy::await_holding_lock, - clippy::match_on_vec_items, + clippy::indexing_slicing, clippy::imprecise_flops, clippy::suboptimal_flops, clippy::lossy_float_literal, @@ -50,13 +50,15 @@ use socketioxide_core::{ }; use std::{ borrow::Cow, + collections::HashMap, fmt, future, pin::Pin, sync::{Arc, Mutex}, task::{Context, Poll}, time::{Duration, Instant}, }; -use tokio::sync::mpsc; + +use crate::{drivers::Notification, stream::AckStream}; mod drivers; mod stream; @@ -123,11 +125,16 @@ impl From> for AdapterError { #[derive(Debug, Deserialize)] struct Event {} +pub struct PostgresAdapterCtr { + driver: D, + config: PostgresAdapterConfig, +} + /// The postgres adapter implementation. /// It is generic over the [`Driver`] used to communicate with the postgres server. /// And over the [`SocketEmitter`] used to communicate with the local server. This allows to /// avoid cyclic dependencies between the adapter, `socketioxide-core` and `socketioxide` crates. -pub struct CustomPostgresAdapter { +pub struct CustomPostgresAdapter { /// The driver used by the adapter. This is used to communicate with the postgres server. /// All the postgres adapter instances share the same driver. driver: D, @@ -139,13 +146,15 @@ pub struct CustomPostgresAdapter { local: CoreLocalAdapter, /// A map of nodes liveness, with the last time remote nodes were seen alive. nodes_liveness: Mutex>, + /// A map of response handlers used to await for responses from the remote servers. 
+ responses: Arc>>, } -impl DefinedAdapter for CustomPostgresAdapter {} +impl DefinedAdapter for CustomPostgresAdapter {} impl CoreAdapter for CustomPostgresAdapter { type Error = Error; type State = PostgresAdapterCtr; - type AckStream = AckStream; + type AckStream = AckStream; type InitRes = InitRes; fn new(state: &Self::State, local: CoreLocalAdapter) -> Self { @@ -156,6 +165,7 @@ impl CoreAdapter for CustomPostgresAdapter driver: state.driver.clone(), config: state.config.clone(), nodes_liveness: Mutex::new(Vec::new()), + responses: Arc::new(Mutex::new(HashMap::new())), } } @@ -248,19 +258,16 @@ impl CoreAdapter for CustomPostgresAdapter let remote_serv_cnt = self.server_count().await?.saturating_sub(1); tracing::trace!(?remote_serv_cnt, "expecting acks from remote servers"); - let res = self.driver.listen("").await?; + let remote = self.driver.listen("").await?; - let (tx, rx) = mpsc::channel(self.config.ack_response_buffer + remote_serv_cnt as usize); self.send_req(req, None).await?; let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout); Ok(AckStream::new( local, - rx, + remote, self.config.request_timeout, remote_serv_cnt, - req_id, - self.responses.clone(), )) } @@ -283,7 +290,7 @@ impl CoreAdapter for CustomPostgresAdapter let req = RequestOut::new(self.uid, RequestTypeOut::AllRooms, &opts); let req_id = req.id; - // First get the remote stream because mongodb might send + // First get the remote stream because postgres might send // the responses before subscription is done. let stream = self .get_res::<()>(req_id, ResponseTypeId::AllRooms, opts.server_id) @@ -547,19 +554,22 @@ impl CustomPostgresAdapter { } /// Send a response to the node that sent the request. 
- async fn send_res( + fn send_res( &self, req_id: Sid, req_origin: Uid, res: Response, - ) -> Result<(), Error> { + ) -> impl Future>> + 'static { tracing::trace!(?res, "sending response for {req_id} req to {req_origin}"); - - self.driver - .notify("response", &res) - .await - .map_err(Error::Driver)?; - Ok(()) + let driver = self.driver.clone(); + //TODO: is this the right way? + async move { + driver + .notify("response", &res) + .await + .map_err(Error::Driver)?; + Ok(()) + } } /// Await for all the responses from the remote servers. @@ -576,14 +586,16 @@ impl CustomPostgresAdapter { } else { 1 }; - let (tx, rx) = mpsc::channel(std::cmp::max(remote_serv_cnt, 1)); - self.responses.lock().unwrap().insert(req_id, tx); - let stream = ChanStream::new(rx) - .filter_map(|Item { header, data, .. }| { - let data = match rmp_serde::from_slice::>(&data) { + + let stream = self.driver.listen("test").await.unwrap(); + self.responses.lock().unwrap().insert(req_id, stream); + + let stream = stream + .filter_map(|notif| { + let data = match serde_json::from_str::>(notif.payload()) { Ok(data) => Some(data), Err(e) => { - tracing::warn!(header = ?header, "error decoding response: {e}"); + tracing::warn!(channel = %notif.channel(), "error decoding response: {e}"); None } }; @@ -592,13 +604,13 @@ impl CustomPostgresAdapter { .filter(move |item| future::ready(ResponseTypeId::from(&item.r#type) == response_type)) .take(remote_serv_cnt) .take_until(tokio::time::sleep(self.config.request_timeout)); + let stream = DropStream::new(stream, self.responses.clone(), req_id); Ok(stream) } /// Emit a heartbeat to the specified target node or broadcast to all nodes. async fn emit_heartbeat(&self, target: Option) -> Result<(), Error> { - // Send heartbeat when starting. 
self.send_req( RequestOut::new_empty(self.uid, RequestTypeOut::Heartbeat), target, diff --git a/crates/socketioxide-postgres/src/stream.rs b/crates/socketioxide-postgres/src/stream.rs index e27960e5..5e7a945f 100644 --- a/crates/socketioxide-postgres/src/stream.rs +++ b/crates/socketioxide-postgres/src/stream.rs @@ -1,7 +1,6 @@ use std::{ fmt, pin::Pin, - sync::{Arc, Mutex}, task::{self, Poll}, time::Duration, }; @@ -11,12 +10,13 @@ use futures_util::{StreamExt, stream::TakeUntil}; use pin_project_lite::pin_project; use serde::de::DeserializeOwned; use socketioxide_core::{ - Sid, adapter::AckStreamItem, adapter::remote_packet::{Response, ResponseType}, }; use tokio::{sync::mpsc, time}; +use crate::drivers::{NotifStream, Notification}; + pin_project! { /// A stream of acknowledgement messages received from the local and remote servers. /// It merges the local ack stream with the remote ack stream from all the servers. @@ -27,39 +27,32 @@ pin_project! { // And it is decremented each time an ack is received. // // Therefore an exhausted stream correspond to `ack_cnt == 0` and `server_cnt == 0`. 
- pub struct AckStream { + pub struct AckStream { #[pin] local: S, #[pin] - remote: DropStream>, + remote: TakeUntil, time::Sleep>, ack_cnt: u32, total_ack_cnt: usize, serv_cnt: u16, } } -impl AckStream { - pub fn new( - local: S, - rx: mpsc::Receiver, - timeout: Duration, - serv_cnt: u16, - req_id: Sid, - ) -> Self { - let remote = ChanStream::new(rx).take_until(time::sleep(timeout)); - let remote = DropStream::new(remote, handlers, req_id); +impl AckStream { + pub fn new(local: S, remote: NotifStream, timeout: Duration, serv_cnt: u16) -> Self { + let remote = remote.take_until(time::sleep(timeout)); Self { local, ack_cnt: 0, total_ack_cnt: 0, serv_cnt, + remote, } } + pub fn new_local(local: S) -> Self { - let handlers = Arc::new(Mutex::new(ResponseHandlers::new())); - let rx = mpsc::channel(1).1; - let remote = ChanStream::new(rx).take_until(time::sleep(Duration::ZERO)); - let remote = DropStream::new(remote, handlers, Sid::ZERO); + let rx = mpsc::unbounded_channel().1; + let remote = NotifStream::new(rx).take_until(time::sleep(Duration::ZERO)); Self { local, remote, @@ -69,7 +62,7 @@ impl AckStream { } } } -impl AckStream +impl AckStream where Err: DeserializeOwned + fmt::Debug, S: Stream> + FusedStream, @@ -90,14 +83,15 @@ where match projection.remote.as_mut().poll_next(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(None) => return Poll::Ready(None), - Poll::Ready(Some(Item { header, data, .. 
})) => { - let res = rmp_serde::from_slice::>(&data); + Poll::Ready(Some(notif)) => { + let channel = notif.channel(); + let res = serde_json::from_str::>(notif.payload()); match res { Ok(Response { node_id: uid, r#type: ResponseType::BroadcastAckCount(count), }) if *projection.serv_cnt > 0 => { - tracing::trace!(?uid, ?header, "receiving broadcast ack count {count}"); + tracing::trace!(?uid, channel, "receiving broadcast ack count {count}"); *projection.ack_cnt += count; *projection.total_ack_cnt += count as usize; *projection.serv_cnt -= 1; @@ -108,7 +102,7 @@ where }) if *projection.ack_cnt > 0 => { tracing::trace!( ?uid, - ?header, + channel, "receiving broadcast ack {sid} {:?}", res ); @@ -116,7 +110,7 @@ where return Poll::Ready(Some((sid, res))); } Ok(Response { node_id: uid, .. }) => { - tracing::warn!(?uid, ?header, "unexpected response type"); + tracing::warn!(?uid, channel, "unexpected response type"); } Err(e) => { tracing::warn!("error decoding ack response: {e}"); @@ -127,10 +121,11 @@ where } } } -impl Stream for AckStream +impl Stream for AckStream where E: DeserializeOwned + fmt::Debug, S: Stream> + FusedStream, + T: Notification, { type Item = AckStreamItem; fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { @@ -151,10 +146,11 @@ where } } -impl FusedStream for AckStream +impl FusedStream for AckStream where Err: DeserializeOwned + fmt::Debug, S: Stream> + FusedStream, + T: Notification, { /// The stream is terminated if: /// * The local stream is terminated. @@ -166,7 +162,7 @@ where self.local.is_terminated() && remote_term } } -impl fmt::Debug for AckStream { +impl fmt::Debug for AckStream { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("AckStream") .field("ack_cnt", &self.ack_cnt) @@ -176,64 +172,6 @@ impl fmt::Debug for AckStream { } } -pin_project! { - /// A stream of messages received from a channel. 
- pub struct ChanStream { - #[pin] - rx: mpsc::Receiver - } -} -impl ChanStream { - pub fn new(rx: mpsc::Receiver) -> Self { - Self { rx } - } -} -impl Stream for ChanStream { - type Item = Item; - - fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { - self.project().rx.poll_recv(cx) - } -} -pin_project! { - /// A stream that unsubscribes from its source channel when dropped. - pub struct DropStream { - #[pin] - stream: S, - req_id: Sid, - handlers: Arc> - } - impl PinnedDrop for DropStream { - fn drop(this: Pin<&mut Self>) { - let stream = this.project(); - let chan = stream.req_id; - tracing::debug!(?chan, "dropping stream"); - stream.handlers.lock().unwrap().remove(chan); - } - } -} -impl DropStream { - pub fn new(stream: S, handlers: Arc>, req_id: Sid) -> Self { - Self { - stream, - handlers, - req_id, - } - } -} -impl Stream for DropStream { - type Item = S::Item; - - fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { - self.project().stream.poll_next(cx) - } -} -impl FusedStream for DropStream { - fn is_terminated(&self) -> bool { - self.stream.is_terminated() - } -} - #[cfg(test)] mod tests { use futures_core::FusedStream; From 001740372a61fdba6fe273e0103aa02aa7677263 Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 29 Mar 2026 14:33:02 +0200 Subject: [PATCH 05/12] feat(adapter/postgre): wip --- .../workflows/adapter-ci/docker-compose.yml | 14 + Cargo.lock | 4 + crates/socketioxide-postgres/Cargo.toml | 3 +- .../socketioxide-postgres/src/drivers/mod.rs | 47 +-- .../socketioxide-postgres/src/drivers/sqlx.rs | 84 +++--- crates/socketioxide-postgres/src/lib.rs | 271 +++++++++++++----- crates/socketioxide-postgres/src/stream.rs | 123 ++++++-- e2e/adapter/Cargo.toml | 13 + e2e/adapter/main.rs | 40 ++- e2e/adapter/src/bins/sqlx.rs | 64 +++++ e2e/adapter/src/bins/sqlx_msgpack.rs | 65 +++++ 11 files changed, 518 insertions(+), 210 deletions(-) create mode 100644 e2e/adapter/src/bins/sqlx.rs create mode 100644 
e2e/adapter/src/bins/sqlx_msgpack.rs diff --git a/.github/workflows/adapter-ci/docker-compose.yml b/.github/workflows/adapter-ci/docker-compose.yml index 3965afd3..7c0403d2 100644 --- a/.github/workflows/adapter-ci/docker-compose.yml +++ b/.github/workflows/adapter-ci/docker-compose.yml @@ -140,3 +140,17 @@ services: '; wait " + + postgres: + image: postgres:18-alpine + ports: + - 5432:5432 + environment: + POSTGRES_DB: socketio + POSTGRES_PASSWORD: socketio + POSTGRES_USER: socketio + healthcheck: + test: "pg_isready -U socketio" + interval: 2s + timeout: 5s + retries: 5 diff --git a/Cargo.lock b/Cargo.lock index 7cbc7442..597e0aa6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,6 +11,7 @@ dependencies = [ "hyper-util", "socketioxide", "socketioxide-mongodb", + "socketioxide-postgres", "socketioxide-redis", "tokio", "tracing", @@ -2665,6 +2666,8 @@ dependencies = [ "sha2", "smallvec", "thiserror 2.0.17", + "tokio", + "tokio-stream", "tracing", "url", ] @@ -2701,6 +2704,7 @@ dependencies = [ "sqlx-core", "sqlx-postgres", "syn", + "tokio", "url", ] diff --git a/crates/socketioxide-postgres/Cargo.toml b/crates/socketioxide-postgres/Cargo.toml index 537fab4b..4029a391 100644 --- a/crates/socketioxide-postgres/Cargo.toml +++ b/crates/socketioxide-postgres/Cargo.toml @@ -25,7 +25,7 @@ futures-core.workspace = true futures-util.workspace = true pin-project-lite.workspace = true serde.workspace = true -serde_json.workspace = true +serde_json = { workspace = true, features = ["raw_value"] } smallvec = { workspace = true, features = ["serde"] } tokio = { workspace = true, features = ["time", "rt", "sync"] } rmp-serde.workspace = true @@ -38,6 +38,7 @@ tokio-postgres = { version = "0.7", default-features = false, optional = true, f ] } sqlx = { version = "0.8", default-features = false, optional = true, features = [ "postgres", + "runtime-tokio", ] } [dev-dependencies] diff --git a/crates/socketioxide-postgres/src/drivers/mod.rs 
b/crates/socketioxide-postgres/src/drivers/mod.rs index a0e92ed9..94de7f17 100644 --- a/crates/socketioxide-postgres/src/drivers/mod.rs +++ b/crates/socketioxide-postgres/src/drivers/mod.rs @@ -1,29 +1,21 @@ use futures_core::Stream; -use serde::{Deserialize, Serialize, de::DeserializeOwned}; -use tokio::sync::mpsc; +use serde::Serialize; -// mod postgres; -mod sqlx; - -pub type ChanItem = (String, String); - -#[derive(Deserialize)] -pub struct Item {} +pub mod sqlx; /// The driver trait can be used to support different LISTEN/NOTIFY backends. /// It must share handlers/connection between its clones. pub trait Driver: Clone + Send + Sync + 'static { type Error: std::error::Error + Send + 'static; - type NotifStream: Stream + Send + 'static; type Notification: Notification; + type NotificationStream: Stream + Send; - fn init(&self, table: &str, channels: &[&str]) - -> impl Future>; + fn init(&self, table: &str) -> impl Future> + Send; - fn listen( + fn listen( &self, - channel: &str, - ) -> impl Future> + Send; + channels: &[&str], + ) -> impl Future> + Send; fn notify( &self, @@ -36,28 +28,3 @@ pub trait Notification: Send + 'static { fn channel(&self) -> &str; fn payload(&self) -> &str; } - -pin_project_lite::pin_project! 
{ - pub struct NotifStream { - #[pin] - rx: mpsc::UnboundedReceiver, - } -} -impl Stream for NotifStream { - type Item = T; - - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - match self.rx.poll_recv(cx) { - std::task::Poll::Ready(notif) => std::task::Poll::Ready(notif), - std::task::Poll::Pending => std::task::Poll::Pending, - } - } -} -impl NotifStream { - pub fn new(rx: mpsc::UnboundedReceiver) -> Self { - NotifStream { rx } - } -} diff --git a/crates/socketioxide-postgres/src/drivers/sqlx.rs b/crates/socketioxide-postgres/src/drivers/sqlx.rs index 9a0c6d86..a47b5169 100644 --- a/crates/socketioxide-postgres/src/drivers/sqlx.rs +++ b/crates/socketioxide-postgres/src/drivers/sqlx.rs @@ -1,65 +1,59 @@ -use std::{ - collections::HashMap, - sync::{Arc, RwLock}, -}; - -use serde::{Serialize, de::DeserializeOwned}; +use futures_core::stream::BoxStream; +use futures_util::StreamExt; +use serde::Serialize; use sqlx::{ PgPool, postgres::{PgListener, PgNotification}, }; -use tokio::sync::mpsc; - -use crate::{PostgresAdapterConfig, drivers::NotifStream}; use super::Driver; -type HandlerMap = HashMap>; +pub use sqlx as sqlx_client; #[derive(Debug, Clone)] pub struct SqlxDriver { client: PgPool, - handlers: Arc>, - config: PostgresAdapterConfig, } impl SqlxDriver { - pub fn new(client: PgPool, config: PostgresAdapterConfig) -> Self { - Self { - client, - handlers: Arc::new(RwLock::new(HashMap::new())), - config, - } + /// Create a new SqlxDriver instance. 
+ pub fn new(client: PgPool) -> Self { + Self { client } } } impl Driver for SqlxDriver { type Error = sqlx::Error; - type NotifStream = NotifStream; type Notification = PgNotification; + type NotificationStream = BoxStream<'static, Self::Notification>; - async fn init(&self, table: &str, channels: &[&str]) -> Result<(), Self::Error> { - sqlx::query("CREATE TABLE $1 IF NOT EXISTS") - .bind(&table) - .execute(&self.client) - .await?; + async fn init(&self, table: &str) -> Result<(), Self::Error> { + sqlx::query(&format!( + r#"CREATE TABLE IF NOT EXISTS "{table}" ( + id BIGSERIAL UNIQUE, + created_at TIMESTAMPTZ DEFAULT NOW(), + payload BYTEA + )"#, + )) + .execute(&self.client) + .await?; + Ok(()) + } + + async fn listen(&self, channels: &[&str]) -> Result { let mut listener = PgListener::connect_with(&self.client).await?; listener.listen_all(channels.iter().copied()).await?; - tokio::spawn(spawn_listener(self.handlers.clone(), listener)); - Ok(()) - } - async fn listen( - &self, - channel: &str, - ) -> Result { - let (tx, rx) = mpsc::unbounded_channel(); - self.handlers - .write() - .unwrap() - .insert(channel.to_string(), tx); - Ok(NotifStream::new(rx)) + let stream = listener.into_stream(); + let stream = stream.filter_map(async |res| { + res.inspect_err(|err| { + tracing::warn!("failed to pull sqlx notification from stream: {err}") + }) + .ok() + }); + + Ok(Box::pin(stream)) } fn notify( @@ -71,7 +65,7 @@ impl Driver for SqlxDriver { //TODO: handle error let msg = serde_json::to_string(req).unwrap(); async move { - sqlx::query("NOTIFY $1 $2") + sqlx::query("SELECT pg_notify($1, $2)") .bind(channel) .bind(msg) .execute(&client) @@ -81,20 +75,6 @@ impl Driver for SqlxDriver { } } -async fn spawn_listener(handlers: Arc>, mut listener: PgListener) { - while let Ok(notif) = listener - .recv() - .await - .inspect_err(|e| tracing::warn!(?e, "sqlx listener error")) - { - if let Some(tx) = handlers.read().unwrap().get(notif.channel()) { - tx.send(notif); - } else { - 
tracing::warn!("handler not found for channel {}", notif.channel()); - } - } -} - impl super::Notification for PgNotification { fn channel(&self) -> &str { PgNotification::channel(self) diff --git a/crates/socketioxide-postgres/src/lib.rs b/crates/socketioxide-postgres/src/lib.rs index 92e18675..f4ee9a23 100644 --- a/crates/socketioxide-postgres/src/lib.rs +++ b/crates/socketioxide-postgres/src/lib.rs @@ -29,12 +29,13 @@ nonstandard_style, missing_docs )] -//! +//! test use drivers::Driver; use futures_core::Stream; -use futures_util::StreamExt; +use futures_util::{StreamExt, pin_mut}; use serde::{Deserialize, Serialize, de::DeserializeOwned}; +use serde_json::value::RawValue; use socketioxide_core::{ Sid, Uid, adapter::{ @@ -57,10 +58,14 @@ use std::{ task::{Context, Poll}, time::{Duration, Instant}, }; +use tokio::sync::mpsc; -use crate::{drivers::Notification, stream::AckStream}; +use crate::{ + drivers::Notification, + stream::{AckStream, ChanStream}, +}; -mod drivers; +pub mod drivers; mod stream; /// The configuration of the [`MongoDbAdapter`]. @@ -81,36 +86,49 @@ pub struct PostgresAdapterConfig { /// than you poll them with the returned stream, you might want to increase this value. pub ack_response_buffer: usize, /// The table name used to store socket.io attachments. Default is "socket_io_attachments". + /// + /// > The table name must be a sanitized string. Do not use special characters or spaces. pub table_name: Cow<'static, str>, /// The prefix used for the channels. Default is "socket.io". pub prefix: Cow<'static, str>, /// The treshold to the payload size in bytes. It should match the configured value on your PostgreSQL instance: - /// + /// . By default it is 8KB (8000 bytes). pub payload_treshold: usize, - /// The duration between cleanup queries on the + /// The duration between cleanup queries on the attachment table. 
pub cleanup_intervals: Duration, } +impl Default for PostgresAdapterConfig { + fn default() -> Self { + Self { + hb_timeout: Duration::from_secs(60), + hb_interval: Duration::from_secs(10), + request_timeout: Duration::from_secs(5), + ack_response_buffer: 255, + table_name: "socket_io_attachments".into(), + prefix: "socket.io".into(), + payload_treshold: 8_000, + cleanup_intervals: Duration::from_secs(60), + } + } +} + /// Represent any error that might happen when using this adapter. #[derive(thiserror::Error)] pub enum Error { /// Mongo driver error #[error("driver error: {0}")] Driver(D::Error), - /// Packet encoding error - #[error("packet encoding error: {0}")] - Encode(#[from] rmp_serde::encode::Error), - /// Packet decoding error + /// Packet encoding/decoding error #[error("packet decoding error: {0}")] - Decode(#[from] rmp_serde::decode::Error), + Serde(#[from] serde_json::Error), } impl fmt::Debug for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Driver(err) => write!(f, "Driver error: {:?}", err), - Self::Decode(err) => write!(f, "Decode error: {:?}", err), - Self::Encode(err) => write!(f, "Encode error: {:?}", err), + Self::Serde(err) => write!(f, "Encode/Decode error: {:?}", err), } } } @@ -121,15 +139,24 @@ impl From> for AdapterError { } } -/// An event we should answer to -#[derive(Debug, Deserialize)] -struct Event {} - +/// Constructor for the PostgresAdapterCtr struct. pub struct PostgresAdapterCtr { driver: D, config: PostgresAdapterConfig, } +impl PostgresAdapterCtr { + /// Create a new adapter constructor with a custom postgres driver and a config. + /// + /// You can implement your own driver by implementing the [`Driver`] trait with any postgres client. + /// Check the [`drivers`] module for more information. + pub fn new_with_driver(driver: D, config: PostgresAdapterConfig) -> Self { + Self { driver, config } + } +} + +type ResponseHandlers = HashMap>>; + /// The postgres adapter implementation. 
/// It is generic over the [`Driver`] used to communicate with the postgres server. /// And over the [`SocketEmitter`] used to communicate with the local server. This allows to @@ -140,28 +167,24 @@ pub struct CustomPostgresAdapter { driver: D, /// The configuration of the adapter. config: PostgresAdapterConfig, - /// A unique identifier for the adapter to identify itself in the postgres server. - uid: Uid, /// The local adapter, used to manage local rooms and socket stores. local: CoreLocalAdapter, /// A map of nodes liveness, with the last time remote nodes were seen alive. nodes_liveness: Mutex>, /// A map of response handlers used to await for responses from the remote servers. - responses: Arc>>, + responses: Arc>, } impl DefinedAdapter for CustomPostgresAdapter {} impl CoreAdapter for CustomPostgresAdapter { type Error = Error; type State = PostgresAdapterCtr; - type AckStream = AckStream; + type AckStream = AckStream; type InitRes = InitRes; fn new(state: &Self::State, local: CoreLocalAdapter) -> Self { - let uid = local.server_id(); Self { local, - uid, driver: state.driver.clone(), config: state.config.clone(), nodes_liveness: Mutex::new(Vec::new()), @@ -171,14 +194,26 @@ impl CoreAdapter for CustomPostgresAdapter fn init(self: Arc, on_success: impl FnOnce() + Send + 'static) -> Self::InitRes { let fut = async move { - let stream = self.driver.listen("event").await?; + self.driver.init(&self.config.table_name).await?; + + let global_chan = self.get_global_chan(); + let node_chan = self.get_node_chan(self.local.server_id()); + let response_chan = self.get_response_chan(self.local.server_id()); + + let channels = [ + global_chan.as_str(), + node_chan.as_str(), + response_chan.as_str(), + ]; + + let stream = self.driver.listen(&channels).await?; tokio::spawn(self.clone().handle_ev_stream(stream)); tokio::spawn(self.clone().heartbeat_job()); // Send initial heartbeat when starting. 
self.emit_init_heartbeat().await.map_err(|e| match e { Error::Driver(e) => e, - Error::Encode(_) | Error::Decode(_) => unreachable!(), + Error::Serde(_) => unreachable!(), })?; on_success(); @@ -205,8 +240,9 @@ impl CoreAdapter for CustomPostgresAdapter packet: Packet, opts: BroadcastOptions, ) -> Result<(), BroadcastError> { - if !opts.is_local(self.uid) { - let req = RequestOut::new(self.uid, RequestTypeOut::Broadcast(&packet), &opts); + let node_id = self.local.server_id(); + if !opts.is_local(node_id) { + let req = RequestOut::new(node_id, RequestTypeOut::Broadcast(&packet), &opts); self.send_req(req, None).await.map_err(AdapterError::from)?; } @@ -247,33 +283,45 @@ impl CoreAdapter for CustomPostgresAdapter opts: BroadcastOptions, timeout: Option, ) -> Result { - if opts.is_local(self.uid) { + if opts.is_local(self.local.server_id()) { tracing::debug!(?opts, "broadcast with ack is local"); let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout); let stream = AckStream::new_local(local); return Ok(stream); } - let req = RequestOut::new(self.uid, RequestTypeOut::BroadcastWithAck(&packet), &opts); + let req = RequestOut::new( + self.local.server_id(), + RequestTypeOut::BroadcastWithAck(&packet), + &opts, + ); let req_id = req.id; let remote_serv_cnt = self.server_count().await?.saturating_sub(1); tracing::trace!(?remote_serv_cnt, "expecting acks from remote servers"); - let remote = self.driver.listen("").await?; + + let (tx, rx) = mpsc::channel(self.config.ack_response_buffer + remote_serv_cnt as usize); + self.responses.lock().unwrap().insert(req_id, tx); self.send_req(req, None).await?; let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout); Ok(AckStream::new( local, - remote, + rx, self.config.request_timeout, remote_serv_cnt, + req_id, + self.responses.clone(), )) } async fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), BroadcastError> { - if !opts.is_local(self.uid) { - let req = RequestOut::new(self.uid, 
RequestTypeOut::DisconnectSockets, &opts); + if !opts.is_local(self.local.server_id()) { + let req = RequestOut::new( + self.local.server_id(), + RequestTypeOut::DisconnectSockets, + &opts, + ); self.send_req(req, None).await.map_err(AdapterError::from)?; } self.local @@ -284,10 +332,10 @@ impl CoreAdapter for CustomPostgresAdapter } async fn rooms(&self, opts: BroadcastOptions) -> Result, Self::Error> { - if opts.is_local(self.uid) { + if opts.is_local(self.local.server_id()) { return Ok(self.local.rooms(opts).into_iter().collect()); } - let req = RequestOut::new(self.uid, RequestTypeOut::AllRooms, &opts); + let req = RequestOut::new(self.local.server_id(), RequestTypeOut::AllRooms, &opts); let req_id = req.id; // First get the remote stream because postgres might send @@ -313,8 +361,12 @@ impl CoreAdapter for CustomPostgresAdapter rooms: impl RoomParam, ) -> Result<(), Self::Error> { let rooms: Vec = rooms.into_room_iter().collect(); - if !opts.is_local(self.uid) { - let req = RequestOut::new(self.uid, RequestTypeOut::AddSockets(&rooms), &opts); + if !opts.is_local(self.local.server_id()) { + let req = RequestOut::new( + self.local.server_id(), + RequestTypeOut::AddSockets(&rooms), + &opts, + ); self.send_req(req, opts.server_id).await?; } self.local.add_sockets(opts, rooms); @@ -327,8 +379,12 @@ impl CoreAdapter for CustomPostgresAdapter rooms: impl RoomParam, ) -> Result<(), Self::Error> { let rooms: Vec = rooms.into_room_iter().collect(); - if !opts.is_local(self.uid) { - let req = RequestOut::new(self.uid, RequestTypeOut::DelSockets(&rooms), &opts); + if !opts.is_local(self.local.server_id()) { + let req = RequestOut::new( + self.local.server_id(), + RequestTypeOut::DelSockets(&rooms), + &opts, + ); self.send_req(req, opts.server_id).await?; } self.local.del_sockets(opts, rooms); @@ -339,11 +395,11 @@ impl CoreAdapter for CustomPostgresAdapter &self, opts: BroadcastOptions, ) -> Result, Self::Error> { - if opts.is_local(self.uid) { + if 
opts.is_local(self.local.server_id()) { return Ok(self.local.fetch_sockets(opts)); } - let req = RequestOut::new(self.uid, RequestTypeOut::FetchSockets, &opts); - // First get the remote stream because mongodb might send + let req = RequestOut::new(self.local.server_id(), RequestTypeOut::FetchSockets, &opts); + // First get the remote stream because postgres might send // the responses before subscription is done. let remote = self .get_res::(req.id, ResponseTypeId::FetchSockets, opts.server_id) @@ -376,10 +432,46 @@ impl CustomPostgresAdapter { } } - async fn handle_ev_stream(self: Arc, stream: impl Stream) { - futures_util::pin_mut!(stream); - while let Some(req) = stream.next().await { - self.recv_req(req); + async fn handle_ev_stream(self: Arc, stream: impl Stream) { + pin_mut!(stream); + while let Some(notif) = stream.next().await { + let chan = notif.channel(); + let resp_chan = self.get_response_chan(self.local.server_id()); + tracing::info!(chan, resp_chan, notif = notif.payload(), ""); + if chan == resp_chan { + match serde_json::from_str(notif.payload()) { + Ok(ResponsePacket { + req_id, + node_id, + payload, + }) if node_id != self.local.server_id() => { + let handlers = self.responses.lock().unwrap(); + if let Some(handler) = handlers.get(&req_id) { + if let Err(e) = handler.try_send(payload) { + tracing::warn!(channel = resp_chan, req_id = %req_id, "error sending response: {e}"); + } + } else { + tracing::warn!(channel = resp_chan, req_id = %req_id, "response handler not found"); + } + } + Ok(_) => { + tracing::trace!("skipping loopback packets"); + } + Err(e) => { + tracing::warn!(channel = %notif.channel(), "error handling response: {e}") + } + }; + } else { + match serde_json::from_str::(notif.payload()) { + Ok(req) if req.node_id != self.local.server_id() => self.recv_req(req), + Ok(_) => { + tracing::trace!("skipping loopback packets") + } + Err(e) => { + tracing::warn!(channel = %notif.channel(), "error decoding request: {e}") + } + }; + } } } @@ 
-408,7 +500,7 @@ impl CustomPostgresAdapter { tracing::trace!(?opts, "incoming broadcast"); if let Err(e) = self.local.broadcast(packet, opts) { let ns = self.local.path(); - tracing::warn!(?self.uid, ?ns, "remote request broadcast handler: {:?}", e); + tracing::warn!(node_id = %self.local.server_id(), ?ns, "remote request broadcast handler: {:?}", e); } } @@ -416,8 +508,8 @@ impl CustomPostgresAdapter { if let Err(e) = self.local.disconnect_socket(opts) { let ns = self.local.path(); tracing::warn!( - ?self.uid, - ?ns, + node_id = %self.local.server_id(), + %ns, "remote request disconnect sockets handler: {:?}", e ); @@ -436,8 +528,8 @@ impl CustomPostgresAdapter { let on_err = |err| { let ns = self.local.path(); tracing::warn!( - ?self.uid, - ?ns, + node_id = %self.local.server_id(), + %ns, "remote request broadcast with ack handler errors: {:?}", err ); @@ -446,7 +538,7 @@ impl CustomPostgresAdapter { // This is used to keep track of the number of expected acks. let res = Response { r#type: ResponseType::<()>::BroadcastAckCount(count), - node_id: self.uid, + node_id: self.local.server_id(), }; if let Err(err) = self.send_res(req_id, origin, res).await { on_err(err); @@ -458,7 +550,7 @@ impl CustomPostgresAdapter { while let Some(ack) = stream.next().await { let res = Response { r#type: ResponseType::BroadcastAck(ack), - node_id: self.uid, + node_id: self.local.server_id(), }; if let Err(err) = self.send_res(req_id, origin, res).await { on_err(err); @@ -472,11 +564,11 @@ impl CustomPostgresAdapter { let rooms = self.local.rooms(opts); let res = Response { r#type: ResponseType::<()>::AllRooms(rooms), - node_id: self.uid, + node_id: self.local.server_id(), }; let fut = self.send_res(req_id, origin, res); let ns = self.local.path().clone(); - let uid = self.uid; + let uid = self.local.server_id(); tokio::spawn(async move { if let Err(err) = fut.await { tracing::warn!(?uid, ?ns, "remote request rooms handler: {:?}", err); @@ -494,12 +586,12 @@ impl 
CustomPostgresAdapter { fn recv_fetch_sockets(&self, origin: Uid, req_id: Sid, opts: BroadcastOptions) { let sockets = self.local.fetch_sockets(opts); let res = Response { - node_id: self.uid, + node_id: self.local.server_id(), r#type: ResponseType::FetchSockets(sockets), }; let fut = self.send_res(req_id, origin, res); let ns = self.local.path().clone(); - let uid = self.uid; + let uid = self.local.server_id(); tokio::spawn(async move { if let Err(err) = fut.await { tracing::warn!(?uid, ?ns, "remote request fetch sockets handler: {:?}", err); @@ -544,10 +636,12 @@ impl CustomPostgresAdapter { /// Send a request to a specific target node or broadcast it to all nodes if no target is specified. async fn send_req(&self, req: RequestOut<'_>, target: Option) -> Result<(), Error> { tracing::trace!(?req, "sending request"); - // let head = ItemHeader::Req { target }; - // let req = self.new_packet(head, &req)?; + let chan = match target { + Some(target) => self.get_node_chan(target), + None => self.get_global_chan(), + }; self.driver - .notify("yolo", &req) + .notify(&chan, &req) .await .map_err(Error::Driver)?; Ok(()) @@ -558,16 +652,23 @@ impl CustomPostgresAdapter { &self, req_id: Sid, req_origin: Uid, - res: Response, + payload: Response, ) -> impl Future>> + 'static { - tracing::trace!(?res, "sending response for {req_id} req to {req_origin}"); + tracing::trace!( + ?payload, + "sending response for {req_id} req to {req_origin}" + ); let driver = self.driver.clone(); + let chan = self.get_response_chan(req_origin); + let payload = RawValue::from_string(serde_json::to_string(&payload).unwrap()).unwrap(); + let res = ResponsePacket { + req_id, + node_id: self.local.server_id(), + payload, + }; //TODO: is this the right way? 
async move { - driver - .notify("response", &res) - .await - .map_err(Error::Driver)?; + driver.notify(&chan, &res).await.map_err(Error::Driver)?; Ok(()) } } @@ -587,15 +688,16 @@ impl CustomPostgresAdapter { 1 }; - let stream = self.driver.listen("test").await.unwrap(); - self.responses.lock().unwrap().insert(req_id, stream); + let (tx, rx) = mpsc::channel(std::cmp::max(remote_serv_cnt, 1)); + self.responses.lock().unwrap().insert(req_id, tx); + let stream = ChanStream::new(rx); let stream = stream - .filter_map(|notif| { - let data = match serde_json::from_str::>(notif.payload()) { + .filter_map(|payload| { + let data = match serde_json::from_str::>(payload.get()) { Ok(data) => Some(data), Err(e) => { - tracing::warn!(channel = %notif.channel(), "error decoding response: {e}"); + tracing::warn!("error decoding response: {e}"); None } }; @@ -605,14 +707,13 @@ impl CustomPostgresAdapter { .take(remote_serv_cnt) .take_until(tokio::time::sleep(self.config.request_timeout)); - let stream = DropStream::new(stream, self.responses.clone(), req_id); Ok(stream) } /// Emit a heartbeat to the specified target node or broadcast to all nodes. async fn emit_heartbeat(&self, target: Option) -> Result<(), Error> { self.send_req( - RequestOut::new_empty(self.uid, RequestTypeOut::Heartbeat), + RequestOut::new_empty(self.local.server_id(), RequestTypeOut::Heartbeat), target, ) .await @@ -622,11 +723,26 @@ impl CustomPostgresAdapter { async fn emit_init_heartbeat(&self) -> Result<(), Error> { // Send initial heartbeat when starting. 
self.send_req( - RequestOut::new_empty(self.uid, RequestTypeOut::InitHeartbeat), + RequestOut::new_empty(self.local.server_id(), RequestTypeOut::InitHeartbeat), None, ) .await } + + fn get_global_chan(&self) -> String { + format!("{}#{}", self.config.prefix, self.local.path()) + } + fn get_node_chan(&self, uid: Uid) -> String { + format!("{}#{}", self.get_global_chan(), uid) + } + fn get_response_chan(&self, uid: Uid) -> String { + format!( + "{}-response#{}#{}", + &self.config.prefix, + self.local.path(), + uid + ) + } } /// The result of the init future. @@ -649,3 +765,10 @@ impl Spawnable for InitRes { }); } } + +#[derive(Deserialize, Serialize)] +struct ResponsePacket { + req_id: Sid, + node_id: Uid, + payload: Box, +} diff --git a/crates/socketioxide-postgres/src/stream.rs b/crates/socketioxide-postgres/src/stream.rs index 5e7a945f..3fec6df7 100644 --- a/crates/socketioxide-postgres/src/stream.rs +++ b/crates/socketioxide-postgres/src/stream.rs @@ -1,6 +1,7 @@ use std::{ fmt, pin::Pin, + sync::{Arc, Mutex}, task::{self, Poll}, time::Duration, }; @@ -9,13 +10,17 @@ use futures_core::{FusedStream, Stream}; use futures_util::{StreamExt, stream::TakeUntil}; use pin_project_lite::pin_project; use serde::de::DeserializeOwned; +use serde_json::value::RawValue; use socketioxide_core::{ - adapter::AckStreamItem, - adapter::remote_packet::{Response, ResponseType}, + Sid, + adapter::{ + AckStreamItem, + remote_packet::{Response, ResponseType}, + }, }; use tokio::{sync::mpsc, time}; -use crate::drivers::{NotifStream, Notification}; +use crate::{ResponseHandlers, drivers::Notification}; pin_project! { /// A stream of acknowledgement messages received from the local and remote servers. @@ -27,20 +32,28 @@ pin_project! { // And it is decremented each time an ack is received. // // Therefore an exhausted stream correspond to `ack_cnt == 0` and `server_cnt == 0`. 
- pub struct AckStream { + pub struct AckStream { #[pin] local: S, #[pin] - remote: TakeUntil, time::Sleep>, + remote: DropStream>, time::Sleep>>, ack_cnt: u32, total_ack_cnt: usize, serv_cnt: u16, } } -impl AckStream { - pub fn new(local: S, remote: NotifStream, timeout: Duration, serv_cnt: u16) -> Self { - let remote = remote.take_until(time::sleep(timeout)); +impl AckStream { + pub fn new( + local: S, + remote: mpsc::Receiver>, + timeout: Duration, + serv_cnt: u16, + req_sid: Sid, + handlers: Arc>, + ) -> Self { + let remote = ChanStream::new(remote).take_until(time::sleep(timeout)); + let remote = DropStream::new(remote, handlers, req_sid); Self { local, ack_cnt: 0, @@ -51,8 +64,10 @@ impl AckStream { } pub fn new_local(local: S) -> Self { - let rx = mpsc::unbounded_channel().1; - let remote = NotifStream::new(rx).take_until(time::sleep(Duration::ZERO)); + let handlers = Arc::new(Mutex::new(ResponseHandlers::new())); + let rx = mpsc::channel(1).1; + let remote = ChanStream::new(rx).take_until(time::sleep(Duration::ZERO)); + let remote = DropStream::new(remote, handlers, Sid::ZERO); Self { local, remote, @@ -62,13 +77,13 @@ impl AckStream { } } } -impl AckStream +impl AckStream where Err: DeserializeOwned + fmt::Debug, S: Stream> + FusedStream, { - /// Poll the remote stream. First the count of acks is received, then the acks are received. - /// We expect `serv_cnt` of `BroadcastAckCount` messages to be received, then we expect + /// Poll the remote stream. First the count of acks is received, then the acks are received. + /// We expect `serv_cnt` of `BroadcastAckCount` messages to be received, then we expect /// `ack_cnt` of `BroadcastAck` messages.
fn poll_remote( self: Pin<&mut Self>, @@ -84,14 +99,13 @@ where Poll::Pending => return Poll::Pending, Poll::Ready(None) => return Poll::Ready(None), Poll::Ready(Some(notif)) => { - let channel = notif.channel(); - let res = serde_json::from_str::>(notif.payload()); + let res = serde_json::from_str::>(notif.get()); match res { Ok(Response { node_id: uid, r#type: ResponseType::BroadcastAckCount(count), }) if *projection.serv_cnt > 0 => { - tracing::trace!(?uid, channel, "receiving broadcast ack count {count}"); + tracing::trace!(?uid, "receiving broadcast ack count {count}"); *projection.ack_cnt += count; *projection.total_ack_cnt += count as usize; *projection.serv_cnt -= 1; @@ -100,17 +114,12 @@ where node_id: uid, r#type: ResponseType::BroadcastAck((sid, res)), }) if *projection.ack_cnt > 0 => { - tracing::trace!( - ?uid, - channel, - "receiving broadcast ack {sid} {:?}", - res - ); + tracing::trace!(?uid, "receiving broadcast ack {sid} {:?}", res); *projection.ack_cnt -= 1; return Poll::Ready(Some((sid, res))); } Ok(Response { node_id: uid, .. }) => { - tracing::warn!(?uid, channel, "unexpected response type"); + tracing::warn!(?uid, "unexpected response type"); } Err(e) => { tracing::warn!("error decoding ack response: {e}"); @@ -121,11 +130,10 @@ where } } } -impl Stream for AckStream +impl Stream for AckStream where E: DeserializeOwned + fmt::Debug, S: Stream> + FusedStream, - T: Notification, { type Item = AckStreamItem; fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { @@ -146,11 +154,10 @@ where } } -impl FusedStream for AckStream +impl FusedStream for AckStream where Err: DeserializeOwned + fmt::Debug, S: Stream> + FusedStream, - T: Notification, { /// The stream is terminated if: /// * The local stream is terminated. 
@@ -162,7 +169,7 @@ where self.local.is_terminated() && remote_term } } -impl fmt::Debug for AckStream { +impl fmt::Debug for AckStream { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("AckStream") .field("ack_cnt", &self.ack_cnt) @@ -171,6 +178,64 @@ impl fmt::Debug for AckStream { .finish() } } +pin_project! { + /// A stream of messages received from a channel. + pub struct ChanStream { + #[pin] + rx: mpsc::Receiver + } +} +impl ChanStream { + pub fn new(rx: mpsc::Receiver) -> Self { + Self { rx } + } +} +impl Stream for ChanStream { + type Item = T; + + fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + self.project().rx.poll_recv(cx) + } +} + +pin_project! { + /// A stream that unsubscribes from its source channel when dropped. + pub struct DropStream { + #[pin] + stream: S, + req_id: Sid, + handlers: Arc> + } + impl PinnedDrop for DropStream { + fn drop(this: Pin<&mut Self>) { + let stream = this.project(); + let chan = stream.req_id; + tracing::debug!(?chan, "dropping stream"); + stream.handlers.lock().unwrap().remove(chan); + } + } +} +impl DropStream { + pub fn new(stream: S, handlers: Arc>, req_id: Sid) -> Self { + Self { + stream, + handlers, + req_id, + } + } +} +impl Stream for DropStream { + type Item = S::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + self.project().stream.poll_next(cx) + } +} +impl FusedStream for DropStream { + fn is_terminated(&self) -> bool { + self.stream.is_terminated() + } +} #[cfg(test)] mod tests { @@ -186,7 +251,7 @@ mod tests { let local = futures_util::stream::once(async move { (sid, Ok::<_, ()>(Value::Str("local".into(), None))) }); - let stream = AckStream::new_local(local); + let stream = AckStream::<_>::new_local(local); futures_util::pin_mut!(stream); assert_eq!(stream.ack_cnt, 0); assert_eq!(stream.total_ack_cnt, 0); diff --git a/e2e/adapter/Cargo.toml b/e2e/adapter/Cargo.toml index 1e94a487..cc7ea709 100644 --- 
a/e2e/adapter/Cargo.toml +++ b/e2e/adapter/Cargo.toml @@ -22,6 +22,11 @@ socketioxide-redis = { path = "../../crates/socketioxide-redis", features = [ "fred", ] } socketioxide-mongodb = { path = "../../crates/socketioxide-mongodb" } +socketioxide-postgres = { path = "../../crates/socketioxide-postgres", features = [ + "sqlx", + "postgres", +] } + hyper-util = { workspace = true, features = ["tokio"] } hyper = { workspace = true, features = ["server", "http1"] } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } @@ -80,3 +85,11 @@ path = "src/bins/mongodb_ttl.rs" [[bin]] name = "mongodb-ttl-e2e-msgpack" path = "src/bins/mongodb_ttl_msgpack.rs" + +[[bin]] +name = "sqlx-e2e" +path = "src/bins/sqlx.rs" + +[[bin]] +name = "sqlx-e2e-msgpack" +path = "src/bins/sqlx_msgpack.rs" diff --git a/e2e/adapter/main.rs b/e2e/adapter/main.rs index f174b849..cd2f0bf9 100644 --- a/e2e/adapter/main.rs +++ b/e2e/adapter/main.rs @@ -3,7 +3,7 @@ use std::fs; use std::process::{Child, Command}; use std::time::Duration; -const BINS: [&str; 12] = [ +const BINS: &[&str] = &[ "fred-e2e", "fred-e2e-msgpack", "redis-e2e", @@ -16,27 +16,34 @@ const BINS: [&str; 12] = [ "mongodb-ttl-e2e-msgpack", "mongodb-capped-e2e", "mongodb-capped-e2e-msgpack", + "sqlx-e2e", + "sqlx-e2e-msgpack", ]; const EXEC_SUFFIX: &str = if cfg!(windows) { ".exe" } else { "" }; const LOG_DIR: &str = "e2e/adapter/logs"; -fn main() { - let filter = args().skip(1).next().unwrap_or("".to_string()); - println!("filter: {}", filter); +fn main() -> Result<(), Box> { + let bin_filter = args().nth(1).unwrap_or("".to_string()); + println!("binary target filter: {}", bin_filter); - if fs::exists(LOG_DIR).unwrap() { - fs::remove_dir_all(LOG_DIR).unwrap(); + let test_filter = args().nth(2); + println!("test filter: {}", test_filter.as_deref().unwrap_or("*")); + + if fs::exists(LOG_DIR)? 
{ + fs::remove_dir_all(LOG_DIR)?; } - fs::create_dir_all(LOG_DIR).unwrap(); + fs::create_dir_all(LOG_DIR)?; // run everything - for target in BINS.into_iter().filter(|name| name.contains(&filter)) { - run(target); + for target in BINS.iter().filter(|name| name.contains(&bin_filter)) { + run(target, test_filter.as_deref()); } println!("All tests passed!"); + + Ok(()) } -fn run(target: &'static str) { +fn run(target: &'static str, test_filter: Option<&str>) { let parser = if target.ends_with("msgpack") { "msgpack" } else { @@ -50,10 +57,15 @@ fn run(target: &'static str) { std::thread::sleep(Duration::from_millis(200)); - let child = Command::new("node") - .arg("--experimental-strip-types") - .arg("--test-reporter=spec") - .arg("--test") + let mut cmd = Command::new("node"); + + cmd.arg("--test-reporter=spec").arg("--test"); + + if let Some(filter) = test_filter { + cmd.arg(format!("--test-name-pattern=\"{filter}\"")); + } + + let child = cmd .arg("e2e/adapter/client.ts") .env("PORTS", "3000,3001,3002") .env("PARSER", parser) diff --git a/e2e/adapter/src/bins/sqlx.rs b/e2e/adapter/src/bins/sqlx.rs new file mode 100644 index 00000000..fb260072 --- /dev/null +++ b/e2e/adapter/src/bins/sqlx.rs @@ -0,0 +1,64 @@ +use hyper::server::conn::http1; +use hyper_util::rt::TokioIo; +use socketioxide::SocketIo; + +use socketioxide_postgres::{ + CustomPostgresAdapter, PostgresAdapterConfig, PostgresAdapterCtr, + drivers::sqlx::{SqlxDriver, sqlx_client::PgPool}, +}; +use tokio::net::TcpListener; +use tracing::{Level, info}; +use tracing_subscriber::FmtSubscriber; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let subscriber = FmtSubscriber::builder() + .with_line_number(true) + .with_max_level(Level::TRACE) + .finish(); + tracing::subscriber::set_global_default(subscriber)?; + let variant = std::env::args().next().unwrap(); + let variant = variant.split("/").last().unwrap(); + + let config = PostgresAdapterConfig { + prefix: format!("socket.io-{variant}").into(), + 
..Default::default() + }; + + let pg_pool = PgPool::connect("postgres://socketio:socketio@localhost:5432/socketio").await?; + let adapter = PostgresAdapterCtr::new_with_driver(SqlxDriver::new(pg_pool), config); + let (svc, io) = SocketIo::builder() + .with_adapter::>(adapter) + .build_svc(); + + io.ns("/", adapter_e2e::handler).await?; + + info!("Starting server with v5 protocol"); + let port: u16 = std::env::var("PORT") + .expect("a PORT env var should be set") + .parse()?; + + let listener = TcpListener::bind(("127.0.0.1", port)).await?; + + // We start a loop to continuously accept incoming connections + loop { + let (stream, _) = listener.accept().await?; + + // Use an adapter to access something implementing `tokio::io` traits as if they implement + // `hyper::rt` IO traits. + let io = TokioIo::new(stream); + let svc = svc.clone(); + + // Spawn a tokio task to serve multiple connections concurrently + tokio::task::spawn(async move { + // Finally, we bind the incoming connection to our `hello` service + if let Err(err) = http1::Builder::new() + .serve_connection(io, svc) + .with_upgrades() + .await + { + println!("Error serving connection: {:?}", err); + } + }); + } +} diff --git a/e2e/adapter/src/bins/sqlx_msgpack.rs b/e2e/adapter/src/bins/sqlx_msgpack.rs new file mode 100644 index 00000000..d7f420f3 --- /dev/null +++ b/e2e/adapter/src/bins/sqlx_msgpack.rs @@ -0,0 +1,65 @@ +use hyper::server::conn::http1; +use hyper_util::rt::TokioIo; +use socketioxide::{ParserConfig, SocketIo}; + +use socketioxide_postgres::{ + CustomPostgresAdapter, PostgresAdapterConfig, PostgresAdapterCtr, + drivers::sqlx::{SqlxDriver, sqlx_client::PgPool}, +}; +use tokio::net::TcpListener; +use tracing::{Level, info}; +use tracing_subscriber::FmtSubscriber; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let subscriber = FmtSubscriber::builder() + .with_line_number(true) + .with_max_level(Level::TRACE) + .finish(); + tracing::subscriber::set_global_default(subscriber)?; + let 
variant = std::env::args().next().unwrap(); + let variant = variant.split("/").last().unwrap(); + + let config = PostgresAdapterConfig { + prefix: format!("socket.io-{variant}").into(), + ..Default::default() + }; + + let pg_pool = PgPool::connect("postgres://socketio:socketio@localhost:5432/socketio").await?; + let adapter = PostgresAdapterCtr::new_with_driver(SqlxDriver::new(pg_pool), config); + let (svc, io) = SocketIo::builder() + .with_parser(ParserConfig::msgpack()) + .with_adapter::>(adapter) + .build_svc(); + + io.ns("/", adapter_e2e::handler).await?; + + info!("Starting server with v5 protocol"); + let port: u16 = std::env::var("PORT") + .expect("a PORT env var should be set") + .parse()?; + + let listener = TcpListener::bind(("127.0.0.1", port)).await?; + + // We start a loop to continuously accept incoming connections + loop { + let (stream, _) = listener.accept().await?; + + // Use an adapter to access something implementing `tokio::io` traits as if they implement + // `hyper::rt` IO traits. 
+ let io = TokioIo::new(stream); + let svc = svc.clone(); + + // Spawn a tokio task to serve multiple connections concurrently + tokio::task::spawn(async move { + // Finally, we bind the incoming connection to our `hello` service + if let Err(err) = http1::Builder::new() + .serve_connection(io, svc) + .with_upgrades() + .await + { + println!("Error serving connection: {:?}", err); + } + }); + } +} From 7acdde2ff040ee0400422d95981353a5e1dd2208 Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 29 Mar 2026 14:38:38 +0200 Subject: [PATCH 06/12] feat(adapter/postgre): wip --- e2e/adapter/main.rs | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/e2e/adapter/main.rs b/e2e/adapter/main.rs index cd2f0bf9..f18af801 100644 --- a/e2e/adapter/main.rs +++ b/e2e/adapter/main.rs @@ -57,15 +57,9 @@ fn run(target: &'static str, test_filter: Option<&str>) { std::thread::sleep(Duration::from_millis(200)); - let mut cmd = Command::new("node"); - - cmd.arg("--test-reporter=spec").arg("--test"); - - if let Some(filter) = test_filter { - cmd.arg(format!("--test-name-pattern=\"{filter}\"")); - } - - let child = cmd + let child = Command::new("node") + .arg("--test-reporter=spec") + .arg("--test") .arg("e2e/adapter/client.ts") .env("PORTS", "3000,3001,3002") .env("PARSER", parser) From db72426427ae3fee59763ab6bd77333481c9f4e2 Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 29 Mar 2026 14:39:45 +0200 Subject: [PATCH 07/12] feat(adapter/postgre): wip --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index af7063a2..5485323b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace.package] edition = "2024" -rust-version = "1.86.0" +rust-version = "1.88.0" authors = ["Théodore Prévot <"] repository = "https://github.com/totodore/socketioxide" homepage = "https://github.com/totodore/socketioxide" From 3453a8680d0613f664fb4edaed1c2439d2a9f6ba Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 29 Mar 
2026 14:48:45 +0200 Subject: [PATCH 08/12] feat(adapter/postgre): wip --- crates/socketioxide-postgres/src/drivers/mod.rs | 1 + crates/socketioxide-postgres/src/drivers/sqlx.rs | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/crates/socketioxide-postgres/src/drivers/mod.rs b/crates/socketioxide-postgres/src/drivers/mod.rs index 94de7f17..57fa558c 100644 --- a/crates/socketioxide-postgres/src/drivers/mod.rs +++ b/crates/socketioxide-postgres/src/drivers/mod.rs @@ -1,6 +1,7 @@ use futures_core::Stream; use serde::Serialize; +#[cfg(feature = "sqlx")] pub mod sqlx; /// The driver trait can be used to support different LISTEN/NOTIFY backends. diff --git a/crates/socketioxide-postgres/src/drivers/sqlx.rs b/crates/socketioxide-postgres/src/drivers/sqlx.rs index a47b5169..f8977284 100644 --- a/crates/socketioxide-postgres/src/drivers/sqlx.rs +++ b/crates/socketioxide-postgres/src/drivers/sqlx.rs @@ -63,11 +63,11 @@ impl Driver for SqlxDriver { ) -> impl Future> + Send { let client = self.client.clone(); //TODO: handle error - let msg = serde_json::to_string(req).unwrap(); + let msg = serde_json::to_string(req).map_err(|err| sqlx::Error::Decode(Box::new(err))); async move { sqlx::query("SELECT pg_notify($1, $2)") .bind(channel) - .bind(msg) + .bind(msg?) 
.execute(&client) .await?; Ok(()) From aa5430526c38b6013b63a940bbba392e0b0ab6c0 Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 29 Mar 2026 15:23:29 +0200 Subject: [PATCH 09/12] feat(adapter/postgre): wip --- Cargo.lock | 1 - crates/socketioxide-postgres/Cargo.toml | 1 - .../socketioxide-postgres/src/drivers/mod.rs | 8 ++- .../src/drivers/postgres.rs | 59 ++++++++++++------- .../socketioxide-postgres/src/drivers/sqlx.rs | 24 +++----- crates/socketioxide-postgres/src/lib.rs | 9 ++- 6 files changed, 57 insertions(+), 45 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1dae8524..845b0369 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2614,7 +2614,6 @@ dependencies = [ "futures-core", "futures-util", "pin-project-lite", - "rmp-serde", "serde", "serde_json", "smallvec", diff --git a/crates/socketioxide-postgres/Cargo.toml b/crates/socketioxide-postgres/Cargo.toml index 4029a391..3aa2604e 100644 --- a/crates/socketioxide-postgres/Cargo.toml +++ b/crates/socketioxide-postgres/Cargo.toml @@ -28,7 +28,6 @@ serde.workspace = true serde_json = { workspace = true, features = ["raw_value"] } smallvec = { workspace = true, features = ["serde"] } tokio = { workspace = true, features = ["time", "rt", "sync"] } -rmp-serde.workspace = true tracing.workspace = true thiserror.workspace = true diff --git a/crates/socketioxide-postgres/src/drivers/mod.rs b/crates/socketioxide-postgres/src/drivers/mod.rs index 57fa558c..6b64b59a 100644 --- a/crates/socketioxide-postgres/src/drivers/mod.rs +++ b/crates/socketioxide-postgres/src/drivers/mod.rs @@ -1,9 +1,11 @@ use futures_core::Stream; -use serde::Serialize; #[cfg(feature = "sqlx")] pub mod sqlx; +// #[cfg(feature = "postgres")] +// pub mod postgres; + /// The driver trait can be used to support different LISTEN/NOTIFY backends. /// It must share handlers/connection between its clones. 
pub trait Driver: Clone + Send + Sync + 'static { @@ -18,10 +20,10 @@ pub trait Driver: Clone + Send + Sync + 'static { channels: &[&str], ) -> impl Future> + Send; - fn notify( + fn notify( &self, channel: &str, - message: &T, + message: &str, ) -> impl Future> + Send; } diff --git a/crates/socketioxide-postgres/src/drivers/postgres.rs b/crates/socketioxide-postgres/src/drivers/postgres.rs index 0d6d408a..fea516a5 100644 --- a/crates/socketioxide-postgres/src/drivers/postgres.rs +++ b/crates/socketioxide-postgres/src/drivers/postgres.rs @@ -1,46 +1,63 @@ -use std::sync::Arc; +use std::{pin::Pin, sync::Arc}; -use tokio_postgres::{Client, Connection}; +use futures_core::{Stream, stream::BoxStream}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_postgres::{AsyncMessage, Client, Config, Connection, Socket}; use super::Driver; #[derive(Debug, Clone)] pub struct PostgresDriver { client: Arc, + config: Config, } impl PostgresDriver { - pub fn new(client: Client, connection: Connection) -> Self { - PostgresDriver { - client: Arc::new(client), - } + pub fn new(config: Config) -> Self + where + T: AsyncRead + AsyncWrite + Unpin, + { + PostgresDriver { config } } } impl Driver for PostgresDriver { type Error = tokio_postgres::Error; - type NotifStream; + type Notification = tokio_postgres::Notification; + type NotificationStream = BoxStream<'static, Self::Notification>; - async fn init(&self, table: &str, channels: &[&str]) -> Result<(), Self::Error> { + async fn init( + &self, + table: &str, + channels: &[&str], + ) -> Result { + let st = &format!( + r#"CREATE TABLE IF NOT EXISTS "{table}" ( + id BIGSERIAL UNIQUE, + created_at TIMESTAMPTZ DEFAULT NOW(), + payload BYTEA + )"# + ); + + self.client.execute(st, &[]).await?; + + Ok(()) + } + + async fn notify(&self, channel: &str, message: &str) -> Result<(), Self::Error> { self.client - .execute("CREATE TABLE $1 IF NOT EXISTS", &[&table]) + .execute("SELECT pg_notify($1, $2)", &[&channel, &message]) .await?; - Ok(()) } +} 
- fn listen( - &self, - channel: &str, - ) -> impl Future, Self::Error>> + Send { - todo!() +impl super::Notification for tokio_postgres::Notification { + fn channel(&self) -> &str { + tokio_postgres::Notification::channel(self) } - fn notify( - &self, - channel: &str, - message: &T, - ) -> impl Future> + Send { - todo!() + fn payload(&self) -> &str { + tokio_postgres::Notification::payload(self) } } diff --git a/crates/socketioxide-postgres/src/drivers/sqlx.rs b/crates/socketioxide-postgres/src/drivers/sqlx.rs index f8977284..d23a05f8 100644 --- a/crates/socketioxide-postgres/src/drivers/sqlx.rs +++ b/crates/socketioxide-postgres/src/drivers/sqlx.rs @@ -1,6 +1,5 @@ use futures_core::stream::BoxStream; use futures_util::StreamExt; -use serde::Serialize; use sqlx::{ PgPool, postgres::{PgListener, PgNotification}, @@ -56,22 +55,13 @@ impl Driver for SqlxDriver { Ok(Box::pin(stream)) } - fn notify( - &self, - channel: &str, - req: &T, - ) -> impl Future> + Send { - let client = self.client.clone(); - //TODO: handle error - let msg = serde_json::to_string(req).map_err(|err| sqlx::Error::Decode(Box::new(err))); - async move { - sqlx::query("SELECT pg_notify($1, $2)") - .bind(channel) - .bind(msg?) 
- .execute(&client) - .await?; - Ok(()) - } + async fn notify(&self, channel: &str, message: &str) -> Result<(), Self::Error> { + sqlx::query("SELECT pg_notify($1, $2)") + .bind(channel) + .bind(message) + .execute(&self.client) + .await?; + Ok(()) } } diff --git a/crates/socketioxide-postgres/src/lib.rs b/crates/socketioxide-postgres/src/lib.rs index f4ee9a23..6a142d90 100644 --- a/crates/socketioxide-postgres/src/lib.rs +++ b/crates/socketioxide-postgres/src/lib.rs @@ -640,8 +640,9 @@ impl CustomPostgresAdapter { Some(target) => self.get_node_chan(target), None => self.get_global_chan(), }; + let payload = serde_json::to_string(&req)?; self.driver - .notify(&chan, &req) + .notify(&chan, &payload) .await .map_err(Error::Driver)?; Ok(()) @@ -666,9 +667,13 @@ impl CustomPostgresAdapter { node_id: self.local.server_id(), payload, }; + let message = serde_json::to_string(&res); //TODO: is this the right way? async move { - driver.notify(&chan, &res).await.map_err(Error::Driver)?; + driver + .notify(&chan, &message?) 
+ .await + .map_err(Error::Driver)?; Ok(()) } } From 874799247b99f200bdbd8ac792ca275ef235dd45 Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 29 Mar 2026 18:28:52 +0200 Subject: [PATCH 10/12] feat(adapter/postgre): wip --- crates/socketioxide-postgres/README.md | 137 ++++++++++++++ .../socketioxide-postgres/src/drivers/mod.rs | 14 ++ .../socketioxide-postgres/src/drivers/sqlx.rs | 3 + crates/socketioxide-postgres/src/lib.rs | 172 ++++++++++++++++-- e2e/adapter/src/bins/sqlx.rs | 12 +- e2e/adapter/src/bins/sqlx_msgpack.rs | 13 +- 6 files changed, 317 insertions(+), 34 deletions(-) diff --git a/crates/socketioxide-postgres/README.md b/crates/socketioxide-postgres/README.md index e69de29b..dd9781db 100644 --- a/crates/socketioxide-postgres/README.md +++ b/crates/socketioxide-postgres/README.md @@ -0,0 +1,137 @@ +# [`Socketioxide-Postgres`](https://github.com/totodore/socketioxide) 🚀🦀 + +A [***`socket.io`***](https://socket.io) adapter for [***`Socketioxide`***](https://github.com/totodore/socketioxide), using [PostgreSQL LISTEN/NOTIFY](https://www.postgresql.org/docs/current/sql-notify.html) for event broadcasting. This adapter enables **horizontal scaling** of your Socketioxide servers across distributed deployments by leveraging PostgreSQL as a message bus. 
+ +[![Crates.io](https://img.shields.io/crates/v/socketioxide-postgres.svg)](https://crates.io/crates/socketioxide-postgres) +[![Documentation](https://docs.rs/socketioxide-postgres/badge.svg)](https://docs.rs/socketioxide-postgres) +[![CI](https://github.com/Totodore/socketioxide/actions/workflows/github-ci.yml/badge.svg)](https://github.com/Totodore/socketioxide/actions/workflows/github-ci.yml) + + + +## Features + +- **PostgreSQL LISTEN/NOTIFY-based adapter** +- **Support for any PostgreSQL client** via the [`Driver`] abstraction +- Built-in driver for the [sqlx](https://docs.rs/sqlx) crate: [`SqlxDriver`](https://docs.rs/socketioxide-postgres/latest/socketioxide_postgres/drivers/sqlx/struct.SqlxDriver.html) +- **Heartbeat-based liveness detection** for tracking active server nodes +- Fully compatible with the asynchronous Rust ecosystem +- Implement your own custom driver by implementing the `Driver` trait + +> [!WARNING] +> This adapter is **not compatible** with [`@socket.io/postgres-adapter`](https://github.com/socketio/socket.io-postgres-adapter). +> These projects use entirely different protocols and cannot interoperate. +> **Do not mix Socket.IO JavaScript servers with Socketioxide Rust servers**. 
+
+
+
+## Example: Using the PostgreSQL Adapter with Axum
+
+```rust
+use serde::{Deserialize, Serialize};
+use socketioxide::{
+    adapter::Adapter,
+    extract::{Data, Extension, SocketRef},
+    SocketIo,
+};
+use socketioxide_postgres::{
+    drivers::sqlx::sqlx_client::{self as sqlx, PgPool},
+    SqlxAdapter, PostgresAdapterCtr, PostgresAdapterConfig,
+};
+use tower::ServiceBuilder;
+use tower_http::{cors::CorsLayer, services::ServeDir};
+use tracing::info;
+use tracing_subscriber::FmtSubscriber;
+
+#[derive(Deserialize, Serialize, Debug, Clone)]
+#[serde(transparent)]
+struct Username(String);
+
+#[derive(Deserialize, Serialize, Debug, Clone)]
+#[serde(rename_all = "camelCase", untagged)]
+enum Res {
+    Message {
+        username: Username,
+        message: String,
+    },
+    Username {
+        username: Username,
+    },
+}
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    let subscriber = FmtSubscriber::new();
+
+    tracing::subscriber::set_global_default(subscriber)?;
+
+    info!("Starting server");
+
+    let pool = PgPool::connect("postgres://user:password@localhost/socketio").await?;
+    let adapter = PostgresAdapterCtr::new_with_sqlx(pool);
+
+    let (layer, io) = SocketIo::builder()
+        .with_adapter::<SqlxAdapter<_>>(adapter)
+        .build_layer();
+    io.ns("/", on_connect).await?;
+
+    let app = axum::Router::new()
+        .fallback_service(ServeDir::new("dist"))
+        .layer(
+            ServiceBuilder::new()
+                .layer(CorsLayer::permissive()) // Enable CORS policy
+                .layer(layer),
+        );
+
+    let port = std::env::var("PORT")
+        .map(|s| s.parse().unwrap())
+        .unwrap_or(3000);
+    let listener = tokio::net::TcpListener::bind(("0.0.0.0", port))
+        .await
+        .unwrap();
+    axum::serve(listener, app).await.unwrap();
+
+    Ok(())
+}
+
+async fn on_connect<A: Adapter>(socket: SocketRef<A>) {
+    socket.on("new message", on_msg);
+    socket.on("typing", on_typing);
+    socket.on("stop typing", on_stop_typing);
+}
+async fn on_msg<A: Adapter>(
+    s: SocketRef<A>,
+    Data(msg): Data<String>,
+    Extension(username): Extension<Username>,
+) {
+    let msg = &Res::Message {
+        username,
+        message: msg,
+    };
+    s.broadcast().emit("new message", msg).await.ok();
+}
+async fn on_typing<A: Adapter>(s: SocketRef<A>, Extension(username): Extension<Username>) {
+    s.broadcast()
+        .emit("typing", &Res::Username { username })
+        .await
+        .ok();
+}
+async fn on_stop_typing<A: Adapter>(s: SocketRef<A>, Extension(username): Extension<Username>) {
+    s.broadcast()
+        .emit("stop typing", &Res::Username { username })
+        .await
+        .ok();
+}
+
+```
+
+
+
+## Contributions and Feedback / Questions
+
+Contributions are very welcome! Feel free to open an issue or a PR. If you're unsure where to start, check the [issues](https://github.com/totodore/socketioxide/issues).
+
+For feedback or questions, join the discussion on the [discussions](https://github.com/totodore/socketioxide/discussions) page.
+
+## License 🔐
+
+This project is licensed under the [MIT license](./LICENSE).
diff --git a/crates/socketioxide-postgres/src/drivers/mod.rs b/crates/socketioxide-postgres/src/drivers/mod.rs
index 6b64b59a..10e1f5cd 100644
--- a/crates/socketioxide-postgres/src/drivers/mod.rs
+++ b/crates/socketioxide-postgres/src/drivers/mod.rs
@@ -1,5 +1,9 @@
+//! Drivers are an abstraction over the PostgreSQL LISTEN/NOTIFY backend used by the adapter.
+//! You can use the provided implementation or implement your own.
+
 use futures_core::Stream;
 
+/// A driver implementation for the [`sqlx`](https://docs.rs/sqlx) PostgreSQL backend.
 #[cfg(feature = "sqlx")]
 pub mod sqlx;
 
@@ -9,17 +13,24 @@ pub mod sqlx;
 /// The driver trait can be used to support different LISTEN/NOTIFY backends.
 /// It must share handlers/connection between its clones.
 pub trait Driver: Clone + Send + Sync + 'static {
+    /// The error type returned by the driver.
     type Error: std::error::Error + Send + 'static;
+    /// The notification type yielded by the notification stream.
     type Notification: Notification;
+    /// The stream of notifications returned by [`Driver::listen`].
     type NotificationStream: Stream<Item = Self::Notification> + Send;
 
+    /// Initialize the driver. This is called once when the adapter is created.
+    /// It should create the necessary tables or schema if needed.
     fn init(&self, table: &str) -> impl Future<Output = Result<(), Self::Error>> + Send;
 
+    /// Subscribe to the given NOTIFY channels and return a stream of notifications.
     fn listen(
         &self,
         channels: &[&str],
     ) -> impl Future<Output = Result<Self::NotificationStream, Self::Error>> + Send;
 
+    /// Send a NOTIFY message on the given channel with the given payload.
     fn notify(
         &self,
         channel: &str,
@@ -27,7 +38,10 @@ pub trait Driver: Clone + Send + Sync + 'static {
     ) -> impl Future<Output = Result<(), Self::Error>> + Send;
 }
 
+/// A trait representing a PostgreSQL NOTIFY notification.
 pub trait Notification: Send + 'static {
+    /// The channel name on which the notification was received.
     fn channel(&self) -> &str;
+    /// The payload of the notification.
     fn payload(&self) -> &str;
 }
diff --git a/crates/socketioxide-postgres/src/drivers/sqlx.rs b/crates/socketioxide-postgres/src/drivers/sqlx.rs
index d23a05f8..3241dfac 100644
--- a/crates/socketioxide-postgres/src/drivers/sqlx.rs
+++ b/crates/socketioxide-postgres/src/drivers/sqlx.rs
@@ -9,6 +9,9 @@ use super::Driver;
 
 pub use sqlx as sqlx_client;
 
+/// A [`Driver`] implementation using the [`sqlx`] PostgreSQL client.
+///
+/// It uses [`PgListener`] for LISTEN/NOTIFY and [`PgPool`] for queries.
 #[derive(Debug, Clone)]
 pub struct SqlxDriver {
     client: PgPool,
diff --git a/crates/socketioxide-postgres/src/lib.rs b/crates/socketioxide-postgres/src/lib.rs
index 6a142d90..ec09d9e7 100644
--- a/crates/socketioxide-postgres/src/lib.rs
+++ b/crates/socketioxide-postgres/src/lib.rs
@@ -29,7 +29,55 @@
     nonstandard_style,
     missing_docs
 )]
-//! test
+//! # A PostgreSQL adapter implementation for the socketioxide crate.
+//! The adapter is used to communicate with other nodes of the same application.
+//! This allows to broadcast messages to sockets connected on other servers,
+//! to get the list of rooms, to add or remove sockets from rooms, etc.
+//!
+//! To achieve this, the adapter uses [LISTEN/NOTIFY](https://www.postgresql.org/docs/current/sql-notify.html)
+//! 
through PostgreSQL to communicate with other servers.
+//!
+//! The [`Driver`] abstraction allows the use of any PostgreSQL client.
+//! One implementation is provided:
+//! * [`SqlxDriver`](crate::drivers::sqlx::SqlxDriver) for the [`sqlx`] crate.
+//!
+//! You can also implement your own driver by implementing the [`Driver`] trait.
+//!
+//! <div class="warning">
+//!     Socketioxide-postgres is not compatible with @socketio/postgres-adapter.
+//!     They use completely different protocols and cannot be used together.
+//!     Do not mix socket.io JS servers with socketioxide rust servers.
+//! </div>
+//!
+//! ## How does it work?
+//!
+//! The [`PostgresAdapterCtr`] is a constructor for the [`SqlxAdapter`] which is an implementation of
+//! the [`Adapter`](https://docs.rs/socketioxide/latest/socketioxide/adapter/trait.Adapter.html) trait.
+//!
+//! Then, for each namespace, an adapter is created and it takes a corresponding [`CoreLocalAdapter`].
+//! The [`CoreLocalAdapter`] allows to manage the local rooms and local sockets. The default `LocalAdapter`
+//! is simply a wrapper around this [`CoreLocalAdapter`].
+//!
+//! Once it is created the adapter is initialized with the [`CustomPostgresAdapter::init`] method.
+//! It will subscribe to three PostgreSQL NOTIFY channels and emit heartbeats.
+//! All messages are encoded with JSON.
+//!
+//! There are 9 types of requests:
+//! * Broadcast a packet to all the matching sockets.
+//! * Broadcast a packet to all the matching sockets and wait for a stream of acks.
+//! * Disconnect matching sockets.
+//! * Get all the rooms.
+//! * Add matching sockets to rooms.
+//! * Remove matching sockets from rooms.
+//! * Fetch all the remote sockets matching the options.
+//! * Heartbeat
+//! * Initial heartbeat. When receiving an initial heartbeat all other servers reply a heartbeat immediately.
+//!
+//! For ack streams, the adapter will first send a `BroadcastAckCount` response to the server that sent the request,
+//! and then send the acks as they are received (more details in [`CustomPostgresAdapter::broadcast_with_ack`] fn).
+//!
+//! On the other side, each time an action has to be performed on the local server, the adapter will
+//! first broadcast a request to all the servers and then perform the action locally.
 
 use drivers::Driver;
 use futures_core::Stream;
@@ -68,7 +116,7 @@ use crate::{
 pub mod drivers;
 mod stream;
 
-/// The configuration of the [`MongoDbAdapter`].
+/// The configuration of the [`CustomPostgresAdapter`].
 #[derive(Debug, Clone)]
 pub struct PostgresAdapterConfig {
     /// The heartbeat timeout duration. If a remote node does not respond within this duration,
@@ -91,11 +139,75 @@ pub struct PostgresAdapterConfig {
     pub table_name: Cow<'static, str>,
     /// The prefix used for the channels. Default is "socket.io".
     pub prefix: Cow<'static, str>,
-    /// The treshold to the payload size in bytes. It should match the configured value on your PostgreSQL instance:
+    /// The threshold to the payload size in bytes. It should match the configured value on your PostgreSQL instance:
     /// <https://www.postgresql.org/docs/current/sql-notify.html>. By default it is 8KB (8000 bytes).
-    pub payload_treshold: usize,
+    pub payload_threshold: usize,
     /// The duration between cleanup queries on the attachment table.
-    pub cleanup_intervals: Duration,
+    pub cleanup_interval: Duration,
+}
+
+impl PostgresAdapterConfig {
+    /// Create a new [`PostgresAdapterConfig`] with default values.
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// The heartbeat timeout duration. If a remote node does not respond within this duration,
+    /// it will be considered disconnected. Default is 60 seconds.
+    pub fn with_hb_timeout(mut self, hb_timeout: Duration) -> Self {
+        self.hb_timeout = hb_timeout;
+        self
+    }
+
+    /// The heartbeat interval duration. The current node will broadcast a heartbeat to the
+    /// remote nodes at this interval. Default is 10 seconds.
+    pub fn with_hb_interval(mut self, hb_interval: Duration) -> Self {
+        self.hb_interval = hb_interval;
+        self
+    }
+
+    /// The request timeout. When expecting a response from remote nodes, if they do not respond within
+    /// this duration, the request will be considered failed. Default is 5 seconds.
+    pub fn with_request_timeout(mut self, request_timeout: Duration) -> Self {
+        self.request_timeout = request_timeout;
+        self
+    }
+
+    /// The channel size used to receive ack responses. Default is 255.
+    ///
+    /// If you have a lot of servers/sockets and that you may miss acknowledgement because they arrive faster
+    /// than you poll them with the returned stream, you might want to increase this value.
+    pub fn with_ack_response_buffer(mut self, ack_response_buffer: usize) -> Self {
+        self.ack_response_buffer = ack_response_buffer;
+        self
+    }
+
+    /// The table name used to store socket.io attachments. Default is "socket_io_attachments".
+    ///
+    /// > The table name must be a sanitized string. Do not use special characters or spaces.
+    pub fn with_table_name(mut self, table_name: impl Into<Cow<'static, str>>) -> Self {
+        self.table_name = table_name.into();
+        self
+    }
+
+    /// The prefix used for the channels. Default is "socket.io".
+    pub fn with_prefix(mut self, prefix: impl Into<Cow<'static, str>>) -> Self {
+        self.prefix = prefix.into();
+        self
+    }
+
+    /// The threshold to the payload size in bytes. It should match the configured value on your PostgreSQL instance:
+    /// <https://www.postgresql.org/docs/current/sql-notify.html>. By default it is 8KB (8000 bytes).
+    pub fn with_payload_threshold(mut self, payload_threshold: usize) -> Self {
+        self.payload_threshold = payload_threshold;
+        self
+    }
+
+    /// The duration between cleanup queries on the attachment table. Default is 60 seconds.
+    pub fn with_cleanup_interval(mut self, cleanup_interval: Duration) -> Self {
+        self.cleanup_interval = cleanup_interval;
+        self
+    }
 }
 
 impl Default for PostgresAdapterConfig {
@@ -107,8 +219,8 @@ impl Default for PostgresAdapterConfig {
             ack_response_buffer: 255,
             table_name: "socket_io_attachments".into(),
             prefix: "socket.io".into(),
-            payload_treshold: 8_000,
-            cleanup_intervals: Duration::from_secs(60),
+            payload_threshold: 8_000,
+            cleanup_interval: Duration::from_secs(60),
         }
     }
 }
@@ -116,7 +228,7 @@ impl Default for PostgresAdapterConfig {
 /// Represent any error that might happen when using this adapter.
#[derive(thiserror::Error)] pub enum Error { - /// Mongo driver error + /// Postgres driver error #[error("driver error: {0}")] Driver(D::Error), /// Packet encoding/decoding error @@ -139,12 +251,33 @@ impl From> for AdapterError { } } -/// Constructor for the PostgresAdapterCtr struct. +/// The adapter constructor. For each namespace you define, a new adapter instance is created +/// from this constructor. +#[derive(Debug, Clone)] pub struct PostgresAdapterCtr { driver: D, config: PostgresAdapterConfig, } +#[cfg(feature = "sqlx")] +impl PostgresAdapterCtr { + /// Create a new adapter constructor with the [`sqlx`](drivers::sqlx) driver + /// and a default config. + pub fn new_with_sqlx(pool: drivers::sqlx::sqlx_client::PgPool) -> Self { + Self::new_with_sqlx_config(pool, PostgresAdapterConfig::default()) + } + + /// Create a new adapter constructor with the [`sqlx`](drivers::sqlx) driver + /// and a custom config. + pub fn new_with_sqlx_config( + pool: drivers::sqlx::sqlx_client::PgPool, + config: PostgresAdapterConfig, + ) -> Self { + let driver = drivers::sqlx::SqlxDriver::new(pool); + Self { driver, config } + } +} + impl PostgresAdapterCtr { /// Create a new adapter constructor with a custom postgres driver and a config. /// @@ -155,6 +288,10 @@ impl PostgresAdapterCtr { } } +/// The postgres adapter with the [`sqlx`](drivers::sqlx) driver. +#[cfg(feature = "sqlx")] +pub type SqlxAdapter = CustomPostgresAdapter; + type ResponseHandlers = HashMap>>; /// The postgres adapter implementation. @@ -661,14 +798,15 @@ impl CustomPostgresAdapter { ); let driver = self.driver.clone(); let chan = self.get_response_chan(req_origin); - let payload = RawValue::from_string(serde_json::to_string(&payload).unwrap()).unwrap(); - let res = ResponsePacket { - req_id, - node_id: self.local.server_id(), - payload, - }; - let message = serde_json::to_string(&res); - //TODO: is this the right way? 
+ let message = serde_json::to_string(&payload) + .and_then(RawValue::from_string) + .map(|payload| ResponsePacket { + req_id, + node_id: self.local.server_id(), + payload, + }) + .and_then(|res| serde_json::to_string(&res)); + async move { driver .notify(&chan, &message?) diff --git a/e2e/adapter/src/bins/sqlx.rs b/e2e/adapter/src/bins/sqlx.rs index fb260072..37c457ba 100644 --- a/e2e/adapter/src/bins/sqlx.rs +++ b/e2e/adapter/src/bins/sqlx.rs @@ -3,8 +3,7 @@ use hyper_util::rt::TokioIo; use socketioxide::SocketIo; use socketioxide_postgres::{ - CustomPostgresAdapter, PostgresAdapterConfig, PostgresAdapterCtr, - drivers::sqlx::{SqlxDriver, sqlx_client::PgPool}, + PostgresAdapterConfig, PostgresAdapterCtr, SqlxAdapter, drivers::sqlx::sqlx_client::PgPool, }; use tokio::net::TcpListener; use tracing::{Level, info}; @@ -20,15 +19,12 @@ async fn main() -> Result<(), Box> { let variant = std::env::args().next().unwrap(); let variant = variant.split("/").last().unwrap(); - let config = PostgresAdapterConfig { - prefix: format!("socket.io-{variant}").into(), - ..Default::default() - }; + let config = PostgresAdapterConfig::new().with_prefix(format!("socket.io-{variant}")); let pg_pool = PgPool::connect("postgres://socketio:socketio@localhost:5432/socketio").await?; - let adapter = PostgresAdapterCtr::new_with_driver(SqlxDriver::new(pg_pool), config); + let adapter = PostgresAdapterCtr::new_with_sqlx_config(pg_pool, config); let (svc, io) = SocketIo::builder() - .with_adapter::>(adapter) + .with_adapter::>(adapter) .build_svc(); io.ns("/", adapter_e2e::handler).await?; diff --git a/e2e/adapter/src/bins/sqlx_msgpack.rs b/e2e/adapter/src/bins/sqlx_msgpack.rs index d7f420f3..c5a9482d 100644 --- a/e2e/adapter/src/bins/sqlx_msgpack.rs +++ b/e2e/adapter/src/bins/sqlx_msgpack.rs @@ -3,8 +3,7 @@ use hyper_util::rt::TokioIo; use socketioxide::{ParserConfig, SocketIo}; use socketioxide_postgres::{ - CustomPostgresAdapter, PostgresAdapterConfig, PostgresAdapterCtr, - 
drivers::sqlx::{SqlxDriver, sqlx_client::PgPool}, + PostgresAdapterConfig, PostgresAdapterCtr, SqlxAdapter, drivers::sqlx::sqlx_client::PgPool, }; use tokio::net::TcpListener; use tracing::{Level, info}; @@ -19,17 +18,13 @@ async fn main() -> Result<(), Box> { tracing::subscriber::set_global_default(subscriber)?; let variant = std::env::args().next().unwrap(); let variant = variant.split("/").last().unwrap(); - - let config = PostgresAdapterConfig { - prefix: format!("socket.io-{variant}").into(), - ..Default::default() - }; + let config = PostgresAdapterConfig::new().with_prefix(format!("socket.io-{variant}")); let pg_pool = PgPool::connect("postgres://socketio:socketio@localhost:5432/socketio").await?; - let adapter = PostgresAdapterCtr::new_with_driver(SqlxDriver::new(pg_pool), config); + let adapter = PostgresAdapterCtr::new_with_sqlx_config(pg_pool, config); let (svc, io) = SocketIo::builder() .with_parser(ParserConfig::msgpack()) - .with_adapter::>(adapter) + .with_adapter::>(adapter) .build_svc(); io.ns("/", adapter_e2e::handler).await?; From 0a36aded7378ffee50b9ace2a819904392811a17 Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 29 Mar 2026 18:32:38 +0200 Subject: [PATCH 11/12] feat(adapter/postgre): wip --- e2e/adapter/main.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/e2e/adapter/main.rs b/e2e/adapter/main.rs index f18af801..35125745 100644 --- a/e2e/adapter/main.rs +++ b/e2e/adapter/main.rs @@ -26,9 +26,6 @@ fn main() -> Result<(), Box> { let bin_filter = args().nth(1).unwrap_or("".to_string()); println!("binary target filter: {}", bin_filter); - let test_filter = args().nth(2); - println!("test filter: {}", test_filter.as_deref().unwrap_or("*")); - if fs::exists(LOG_DIR)? 
{ fs::remove_dir_all(LOG_DIR)?; } @@ -36,14 +33,14 @@ fn main() -> Result<(), Box> { // run everything for target in BINS.iter().filter(|name| name.contains(&bin_filter)) { - run(target, test_filter.as_deref()); + run(target); } println!("All tests passed!"); Ok(()) } -fn run(target: &'static str, test_filter: Option<&str>) { +fn run(target: &'static str) { let parser = if target.ends_with("msgpack") { "msgpack" } else { From d6fbb14deff22439da0aabb5cb50fd0203ee29c7 Mon Sep 17 00:00:00 2001 From: totodore Date: Sun, 29 Mar 2026 18:43:11 +0200 Subject: [PATCH 12/12] feat(adapter/postgre): add tests --- crates/socketioxide-postgres/Cargo.toml | 1 + .../socketioxide-postgres/tests/broadcast.rs | 149 +++++++++++ crates/socketioxide-postgres/tests/fixture.rs | 247 ++++++++++++++++++ crates/socketioxide-postgres/tests/local.rs | 32 +++ crates/socketioxide-postgres/tests/rooms.rs | 119 +++++++++ crates/socketioxide-postgres/tests/sockets.rs | 170 ++++++++++++ 6 files changed, 718 insertions(+) create mode 100644 crates/socketioxide-postgres/tests/broadcast.rs create mode 100644 crates/socketioxide-postgres/tests/fixture.rs create mode 100644 crates/socketioxide-postgres/tests/local.rs create mode 100644 crates/socketioxide-postgres/tests/rooms.rs create mode 100644 crates/socketioxide-postgres/tests/sockets.rs diff --git a/crates/socketioxide-postgres/Cargo.toml b/crates/socketioxide-postgres/Cargo.toml index 3aa2604e..21724c6f 100644 --- a/crates/socketioxide-postgres/Cargo.toml +++ b/crates/socketioxide-postgres/Cargo.toml @@ -52,6 +52,7 @@ socketioxide = { path = "../socketioxide", features = [ ] } tracing-subscriber.workspace = true bytes.workspace = true +futures-util.workspace = true # docs.rs-specific configuration [package.metadata.docs.rs] diff --git a/crates/socketioxide-postgres/tests/broadcast.rs b/crates/socketioxide-postgres/tests/broadcast.rs new file mode 100644 index 00000000..c7ba71a9 --- /dev/null +++ b/crates/socketioxide-postgres/tests/broadcast.rs @@ 
-0,0 +1,149 @@ +use std::time::Duration; + +use socketioxide::{adapter::Adapter, extract::SocketRef}; +mod fixture; + +#[tokio::test] +pub async fn broadcast() { + async fn handler(socket: SocketRef
) { + // delay to ensure all socket/servers are connected + tokio::time::sleep(Duration::from_millis(1)).await; + socket.broadcast().emit("test", &2).await.unwrap(); + } + + let [io1, io2] = fixture::spawn_servers(); + + io1.ns("/", handler).await.unwrap(); + io2.ns("/", handler).await.unwrap(); + + let ((_tx1, mut rx1), (_tx2, mut rx2)) = + tokio::join!(io1.new_dummy_sock("/", ()), io2.new_dummy_sock("/", ())); + + timeout_rcv!(&mut rx1); // Connect "/" packet + timeout_rcv!(&mut rx2); // Connect "/" packet + assert_eq!(timeout_rcv!(&mut rx1), r#"42["test",2]"#); + assert_eq!(timeout_rcv!(&mut rx2), r#"42["test",2]"#); + + timeout_rcv_err!(&mut rx1); + timeout_rcv_err!(&mut rx2); +} + +#[tokio::test] +pub async fn broadcast_rooms() { + let [io1, io2, io3] = fixture::spawn_servers(); + let handler = |room: &'static str, to: &'static str| { + move |socket: SocketRef<_>| async move { + // delay to ensure all socket/servers are connected + socket.join(room); + tokio::time::sleep(Duration::from_millis(5)).await; + socket.to(to).emit("test", room).await.unwrap(); + } + }; + + io1.ns("/", handler("room1", "room2")).await.unwrap(); + io2.ns("/", handler("room2", "room3")).await.unwrap(); + io3.ns("/", handler("room3", "room1")).await.unwrap(); + + let ((_tx1, mut rx1), (_tx2, mut rx2), (_tx3, mut rx3)) = tokio::join!( + io1.new_dummy_sock("/", ()), + io2.new_dummy_sock("/", ()), + io3.new_dummy_sock("/", ()) + ); + + timeout_rcv!(&mut rx1); // Connect "/" packet + timeout_rcv!(&mut rx2); // Connect "/" packet + timeout_rcv!(&mut rx3); // Connect "/" packet + + // socket 1 is receiving a packet from io3 + assert_eq!(timeout_rcv!(&mut rx1), r#"42["test","room3"]"#); + // socket 2 is receiving a packet from io2 + assert_eq!(timeout_rcv!(&mut rx2), r#"42["test","room1"]"#); + // socket 3 is receiving a packet from io1 + assert_eq!(timeout_rcv!(&mut rx3), r#"42["test","room2"]"#); + + timeout_rcv_err!(&mut rx1); + timeout_rcv_err!(&mut rx2); + timeout_rcv_err!(&mut rx3); +} + 
+#[tokio::test] +pub async fn broadcast_with_ack() { + use futures_util::stream::StreamExt; + + async fn handler(socket: SocketRef) { + // delay to ensure all socket/servers are connected + tokio::time::sleep(Duration::from_millis(1)).await; + socket + .broadcast() + .emit_with_ack::<_, String>("test", "bar") + .await + .unwrap() + .for_each(|(_, res)| { + socket.emit("ack_res", &res).unwrap(); + async move {} + }) + .await; + } + + let [io1, io2] = fixture::spawn_servers(); + + io1.ns("/", handler).await.unwrap(); + io2.ns("/", async || ()).await.unwrap(); + + let ((_tx1, mut rx1), (tx2, mut rx2)) = + tokio::join!(io1.new_dummy_sock("/", ()), io2.new_dummy_sock("/", ())); + + timeout_rcv!(&mut rx1); // Connect "/" packet + timeout_rcv!(&mut rx2); // Connect "/" packet + + assert_eq!(timeout_rcv!(&mut rx2), r#"421["test","bar"]"#); + let packet_res = r#"431["foo"]"#.to_string().try_into().unwrap(); + tx2.try_send(packet_res).unwrap(); + assert_eq!(timeout_rcv!(&mut rx1), r#"42["ack_res",{"Ok":"foo"}]"#); + + timeout_rcv_err!(&mut rx1); + timeout_rcv_err!(&mut rx2); +} + +#[tokio::test] +pub async fn broadcast_with_ack_timeout() { + use futures_util::StreamExt; + const TIMEOUT: Duration = Duration::from_millis(50); + + async fn handler(socket: SocketRef) { + socket + .broadcast() + .emit_with_ack::<_, String>("test", "bar") + .await + .unwrap() + .for_each(|(_, res)| { + socket.emit("ack_res", &res).unwrap(); + async move {} + }) + .await; + socket.emit("ack_res", "timeout").unwrap(); + } + + let [io1, io2] = fixture::spawn_buggy_servers(TIMEOUT); + + io1.ns("/", handler).await.unwrap(); + io2.ns("/", async || ()).await.unwrap(); + + let now = std::time::Instant::now(); + let ((_tx1, mut rx1), (_tx2, mut rx2)) = + tokio::join!(io1.new_dummy_sock("/", ()), io2.new_dummy_sock("/", ())); + + timeout_rcv!(&mut rx1); // Connect "/" packet + timeout_rcv!(&mut rx2); // Connect "/" packet + + assert_eq!(timeout_rcv!(&mut rx2), r#"421["test","bar"]"#); // emit with ack 
message + // We do not answer + assert_eq!( + timeout_rcv!(&mut rx1, TIMEOUT.as_millis() as u64 + 100), + r#"42["ack_res","timeout"]"# + ); + assert!(now.elapsed() >= TIMEOUT); + + timeout_rcv_err!(&mut rx1); + timeout_rcv_err!(&mut rx2); +} diff --git a/crates/socketioxide-postgres/tests/fixture.rs b/crates/socketioxide-postgres/tests/fixture.rs new file mode 100644 index 00000000..070d9b97 --- /dev/null +++ b/crates/socketioxide-postgres/tests/fixture.rs @@ -0,0 +1,247 @@ +#![allow(dead_code)] + +use futures_core::Stream; +use socketioxide_core::Uid; +use socketioxide_postgres::{ + CustomPostgresAdapter, PostgresAdapterConfig, PostgresAdapterCtr, + drivers::{Driver, Notification}, +}; +use std::{ + convert::Infallible, + pin::Pin, + str::FromStr, + sync::{Arc, RwLock}, + task, + time::Duration, +}; +use tokio::sync::mpsc; + +use socketioxide::{SocketIo, SocketIoConfig, adapter::Emitter}; + +/// Spawns a number of servers with a stub driver for testing. +/// Every server will be connected to every other server. +pub fn spawn_servers() -> [SocketIo>; N] +{ + let sync_buff = Arc::new(RwLock::new(Vec::with_capacity(N))); + spawn_inner(sync_buff, PostgresAdapterConfig::default()) +} + +pub fn spawn_buggy_servers( + timeout: Duration, +) -> [SocketIo>; N] { + let sync_buff = Arc::new(RwLock::new(Vec::with_capacity(N))); + let config = PostgresAdapterConfig::default().with_request_timeout(timeout); + let res = spawn_inner(sync_buff.clone(), config); + + // Reinject a false heartbeat request to simulate a bad number of servers. + // This will trigger timeouts when expecting responses from all servers. + // The heartbeat type is 20 (RequestTypeOut::Heartbeat) in the wire format. 
+ let uid: Uid = Uid::from_str("PHHq01ObWy7Godqx").unwrap(); + let heartbeat_json = serde_json::json!({ + "node_id": uid.to_string(), + "id": "ZG9K1r7xSLBiJYWD", + "type": 20, + "opts": null, + }); + let payload = serde_json::to_string(&heartbeat_json).unwrap(); + + for (_, tx) in sync_buff.read().unwrap().iter() { + // Send the heartbeat to the global channel of the "/" namespace + tx.try_send(StubNotification { + channel: "socket.io#/".to_string(), + payload: payload.clone(), + }) + .unwrap(); + } + + res +} + +fn spawn_inner( + sync_buff: Arc>, + config: PostgresAdapterConfig, +) -> [SocketIo>; N] { + [0; N].map(|_| { + let server_id = Uid::new(); + let (driver, mut rx, tx) = StubDriver::new(server_id); + + // pipe messages to all other servers + sync_buff.write().unwrap().push((server_id, tx)); + let sync_buff = sync_buff.clone(); + tokio::spawn(async move { + while let Some(notif) = rx.recv().await { + tracing::debug!("received notify on channel {:?}", notif.channel); + for (sid, tx) in sync_buff.read().unwrap().iter() { + if *sid != server_id { + tracing::debug!("forwarding notify to server {:?}", sid); + tx.try_send(notif.clone()).unwrap(); + } + } + } + }); + + let adapter = PostgresAdapterCtr::new_with_driver(driver, config.clone()); + let mut config = SocketIoConfig::default(); + config.server_id = server_id; + let (_svc, io) = SocketIo::builder() + .with_config(config) + .with_adapter::>(adapter) + .build_svc(); + io + }) +} + +type NotifyHandlers = Vec<(Uid, mpsc::Sender)>; + +#[derive(Debug, Clone)] +pub struct StubNotification { + channel: String, + payload: String, +} + +impl Notification for StubNotification { + fn channel(&self) -> &str { + &self.channel + } + + fn payload(&self) -> &str { + &self.payload + } +} + +#[derive(Debug, Clone)] +pub struct StubDriver { + server_id: Uid, + /// Sender to emit outgoing NOTIFY messages (to be broadcast to other servers). + tx: mpsc::Sender, + /// Handlers for incoming notifications per listened channel. 
+ handlers: Arc)>>>, +} + +impl StubDriver { + pub fn new( + server_id: Uid, + ) -> (Self, mpsc::Receiver, mpsc::Sender) { + let (tx, rx) = mpsc::channel(255); // outgoing notifies + let (tx1, rx1) = mpsc::channel(255); // incoming notifies + let handlers: Arc)>>> = + Arc::new(RwLock::new(Vec::new())); + + tokio::spawn(pipe_handlers(rx1, handlers.clone())); + + let driver = Self { + server_id, + tx, + handlers, + }; + (driver, rx, tx1) + } +} + +/// Pipe incoming notifications to the matching channel handlers. +async fn pipe_handlers( + mut rx: mpsc::Receiver, + handlers: Arc)>>>, +) { + while let Some(notif) = rx.recv().await { + let handlers = handlers.read().unwrap(); + for (chan, handler) in &*handlers { + if *chan == notif.channel { + handler.try_send(notif.clone()).unwrap(); + } + } + } +} + +pin_project_lite::pin_project! { + pub struct NotificationStream { + #[pin] + rx: mpsc::Receiver, + } +} + +impl Stream for NotificationStream { + type Item = StubNotification; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> task::Poll> { + self.project().rx.poll_recv(cx) + } +} + +impl Driver for StubDriver { + type Error = Infallible; + type Notification = StubNotification; + type NotificationStream = NotificationStream; + + async fn init(&self, _table: &str) -> Result<(), Self::Error> { + Ok(()) + } + + async fn listen(&self, channels: &[&str]) -> Result { + let (tx, rx) = mpsc::channel(255); + let mut handlers = self.handlers.write().unwrap(); + for chan in channels { + handlers.push((chan.to_string(), tx.clone())); + } + Ok(NotificationStream { rx }) + } + + async fn notify(&self, channel: &str, message: &str) -> Result<(), Self::Error> { + // Also deliver to local handlers (self-delivery, like real PG NOTIFY). 
+ { + let handlers = self.handlers.read().unwrap(); + for (chan, handler) in &*handlers { + if *chan == channel { + handler + .try_send(StubNotification { + channel: channel.to_string(), + payload: message.to_string(), + }) + .unwrap(); + } + } + } + // Send to the broadcast pipe for delivery to other servers. + self.tx + .try_send(StubNotification { + channel: channel.to_string(), + payload: message.to_string(), + }) + .unwrap(); + Ok(()) + } +} + +#[macro_export] +macro_rules! timeout_rcv_err { + ($srx:expr) => { + tokio::time::timeout(std::time::Duration::from_millis(10), $srx.recv()) + .await + .unwrap_err(); + }; +} + +#[macro_export] +macro_rules! timeout_rcv { + ($srx:expr) => { + TryInto::::try_into( + tokio::time::timeout(std::time::Duration::from_millis(10), $srx.recv()) + .await + .unwrap() + .unwrap(), + ) + .unwrap() + }; + ($srx:expr, $t:expr) => { + TryInto::::try_into( + tokio::time::timeout(std::time::Duration::from_millis($t), $srx.recv()) + .await + .unwrap() + .unwrap(), + ) + .unwrap() + }; +} diff --git a/crates/socketioxide-postgres/tests/local.rs b/crates/socketioxide-postgres/tests/local.rs new file mode 100644 index 00000000..49972933 --- /dev/null +++ b/crates/socketioxide-postgres/tests/local.rs @@ -0,0 +1,32 @@ +//! Check that each adapter function with a broadcast options that is [`Local`] returns an immediate future +mod fixture; + +macro_rules! 
assert_now { + ($fut:expr) => { + #[allow(unused_must_use)] + futures_util::FutureExt::now_or_never($fut) + .expect("Returned future should be sync") + .unwrap() + }; +} + +#[tokio::test] +async fn test_local_fns() { + let [io1, io2] = fixture::spawn_servers(); + + io1.ns("/", async || ()).await.unwrap(); + io2.ns("/", async || ()).await.unwrap(); + + let (_, mut rx1) = io1.new_dummy_sock("/", ()).await; + let (_, mut rx2) = io2.new_dummy_sock("/", ()).await; + + timeout_rcv!(&mut rx1); // connect packet + timeout_rcv!(&mut rx2); // connect packet + + assert_now!(io1.local().emit("test", "test")); + assert_now!(io1.local().emit_with_ack::<_, ()>("test", "test")); + assert_now!(io1.local().join("test")); + assert_now!(io1.local().leave("test")); + assert_now!(io1.local().disconnect()); + assert_now!(io1.local().fetch_sockets()); +} diff --git a/crates/socketioxide-postgres/tests/rooms.rs b/crates/socketioxide-postgres/tests/rooms.rs new file mode 100644 index 00000000..343d400f --- /dev/null +++ b/crates/socketioxide-postgres/tests/rooms.rs @@ -0,0 +1,119 @@ +use std::time::Duration; + +use socketioxide::extract::SocketRef; + +mod fixture; + +#[tokio::test] +pub async fn all_rooms() { + let [io1, io2, io3] = fixture::spawn_servers(); + let handler = + |rooms: &'static [&'static str]| async move |socket: SocketRef<_>| socket.join(rooms); + + io1.ns("/", handler(&["room1", "room2"])).await.unwrap(); + io2.ns("/", handler(&["room2", "room3"])).await.unwrap(); + io3.ns("/", handler(&["room3", "room1"])).await.unwrap(); + + let ((_tx1, mut rx1), (_tx2, mut rx2), (_tx3, mut rx3)) = tokio::join!( + io1.new_dummy_sock("/", ()), + io2.new_dummy_sock("/", ()), + io3.new_dummy_sock("/", ()) + ); + + timeout_rcv!(&mut rx1); // Connect "/" packet + timeout_rcv!(&mut rx2); // Connect "/" packet + timeout_rcv!(&mut rx3); // Connect "/" packet + + const ROOMS: [&str; 3] = ["room1", "room2", "room3"]; + for io in [io1, io2, io3] { + let mut rooms = io.rooms().await.unwrap(); + 
rooms.sort(); + assert_eq!(rooms, ROOMS); + } + + timeout_rcv_err!(&mut rx1); + timeout_rcv_err!(&mut rx2); + timeout_rcv_err!(&mut rx3); +} + +#[tokio::test] +pub async fn all_rooms_timeout() { + const TIMEOUT: Duration = Duration::from_millis(50); + let [io1, io2, io3] = fixture::spawn_buggy_servers(TIMEOUT); + let handler = + |rooms: &'static [&'static str]| async move |socket: SocketRef<_>| socket.join(rooms); + + io1.ns("/", handler(&["room1", "room2"])).await.unwrap(); + io2.ns("/", handler(&["room2", "room3"])).await.unwrap(); + io3.ns("/", handler(&["room3", "room1"])).await.unwrap(); + + let ((_tx1, mut rx1), (_tx2, mut rx2), (_tx3, mut rx3)) = tokio::join!( + io1.new_dummy_sock("/", ()), + io2.new_dummy_sock("/", ()), + io3.new_dummy_sock("/", ()) + ); + + timeout_rcv!(&mut rx1); // Connect "/" packet + timeout_rcv!(&mut rx2); // Connect "/" packet + timeout_rcv!(&mut rx3); // Connect "/" packet + + const ROOMS: [&str; 3] = ["room1", "room2", "room3"]; + for io in [io1, io3, io2] { + let now = std::time::Instant::now(); + let mut rooms = io.rooms().await.unwrap(); + dbg!(&rooms); + assert!(dbg!(now.elapsed()) >= TIMEOUT); // timeout time + rooms.sort(); + assert_eq!(rooms, ROOMS); + } + + timeout_rcv_err!(&mut rx1); + timeout_rcv_err!(&mut rx2); + timeout_rcv_err!(&mut rx3); +} +#[tokio::test] +pub async fn add_sockets() { + let handler = |room: &'static str| async move |socket: SocketRef<_>| socket.join(room); + let [io1, io2] = fixture::spawn_servers(); + + io1.ns("/", handler("room1")).await.unwrap(); + io2.ns("/", handler("room3")).await.unwrap(); + + let ((_tx1, mut rx1), (_tx2, mut rx2)) = + tokio::join!(io1.new_dummy_sock("/", ()), io2.new_dummy_sock("/", ())); + + timeout_rcv!(&mut rx1); // Connect "/" packet + timeout_rcv!(&mut rx2); // Connect "/" packet + io1.broadcast().join("room2").await.unwrap(); + let mut rooms = io1.rooms().await.unwrap(); + rooms.sort(); + assert_eq!(rooms, ["room1", "room2", "room3"]); + + timeout_rcv_err!(&mut rx1); + 
timeout_rcv_err!(&mut rx2);
}

/// Leaving a room through a broadcast must become visible from every server.
#[tokio::test]
pub async fn del_sockets() {
    let handler =
        |rooms: &'static [&'static str]| async move |socket: SocketRef<_>| socket.join(rooms);
    let [io1, io2] = fixture::spawn_servers();

    io1.ns("/", handler(&["room1", "room2"])).await.unwrap();
    io2.ns("/", handler(&["room3", "room2"])).await.unwrap();

    let ((_tx1, mut rx1), (_tx2, mut rx2)) =
        tokio::join!(io1.new_dummy_sock("/", ()), io2.new_dummy_sock("/", ()));

    timeout_rcv!(&mut rx1); // Connect "/" packet
    timeout_rcv!(&mut rx2); // Connect "/" packet

    io1.broadcast().leave("room2").await.unwrap();

    let mut rooms = io1.rooms().await.unwrap();
    rooms.sort();
    assert_eq!(rooms, ["room1", "room3"]);

    timeout_rcv_err!(&mut rx1);
    timeout_rcv_err!(&mut rx2);
}
// ---------------------------------------------------------------------------
// From patch: new file crates/socketioxide-postgres/tests/sockets.rs
// (diff --git, new file mode 100644, index 00000000..947151ff, +170 lines)
// ---------------------------------------------------------------------------
use std::{str::FromStr, time::Duration};

use socketioxide::{
    SocketIo, adapter::Adapter, extract::SocketRef, operators::BroadcastOperators,
    socket::RemoteSocket,
};
use socketioxide_core::{Sid, Str, adapter::RemoteSocketData};
use tokio::time::Instant;

mod fixture;

/// Extract the session id from the JSON payload of a connect packet,
/// e.g. `0{"sid":"...",...}`.
fn extract_sid(data: &str) -> Sid {
    let data = data
        .split("\"sid\":\"")
        .nth(1)
        .and_then(|s| s.split('"').next())
        .unwrap();
    Sid::from_str(data).unwrap()
}

/// Fetch every socket matched by `op` and return their data sorted by id,
/// so results are comparable regardless of the order servers answered in.
async fn fetch_sockets_data<A: Adapter>(op: BroadcastOperators<A>) -> Vec<RemoteSocketData> {
    let mut sockets = op
        .fetch_sockets()
        .await
        .unwrap()
        .into_iter()
        .map(RemoteSocket::into_data)
        .collect::<Vec<_>>();
    sockets.sort_by(|a, b| a.id.cmp(&b.id));
    sockets
}

/// Build the [`RemoteSocketData`] entries we expect a fetch to return:
/// `ids[k]` paired with the server id of `ios[k]`, on the root namespace.
fn create_expected_sockets<A: Adapter, const N: usize>(
    ids: [Sid; N],
    ios: [&SocketIo<A>; N],
) -> [RemoteSocketData; N] {
    let mut sockets: [RemoteSocketData; N] = std::array::from_fn(|i| RemoteSocketData {
        id: ids[i],
        server_id: ios[i].config().server_id,
        ns: Str::from("/"),
    });
sockets.sort_by(|a, b| a.id.cmp(&b.id));
    sockets
}

/// `fetch_sockets` must return every socket known to the cluster,
/// whichever server the query is issued from.
#[tokio::test]
pub async fn fetch_sockets() {
    let [io1, io2, io3] = fixture::spawn_servers::<3>();

    io1.ns("/", async || ()).await.unwrap();
    io2.ns("/", async || ()).await.unwrap();
    io3.ns("/", async || ()).await.unwrap();

    let (_, mut rx1) = io1.new_dummy_sock("/", ()).await;
    let (_, mut rx2) = io2.new_dummy_sock("/", ()).await;
    let (_, mut rx3) = io3.new_dummy_sock("/", ()).await;

    let id1 = extract_sid(&timeout_rcv!(&mut rx1));
    let id2 = extract_sid(&timeout_rcv!(&mut rx2));
    let id3 = extract_sid(&timeout_rcv!(&mut rx3));

    // `create_expected_sockets` already returns the data sorted by id,
    // matching the order produced by `fetch_sockets_data`.
    let expected_sockets = create_expected_sockets([id1, id2, id3], [&io1, &io2, &io3]);

    let sockets = fetch_sockets_data(io1.broadcast()).await;
    assert_eq!(sockets, expected_sockets);

    let sockets = fetch_sockets_data(io2.broadcast()).await;
    assert_eq!(sockets, expected_sockets);

    let sockets = fetch_sockets_data(io3.broadcast()).await;
    assert_eq!(sockets, expected_sockets);
}

/// Sockets must be filterable by room across servers.
#[tokio::test]
pub async fn fetch_sockets_with_rooms() {
    let [io1, io2, io3] = fixture::spawn_servers::<3>();
    let handler =
        |rooms: &'static [&'static str]| async move |socket: SocketRef<_>| socket.join(rooms);

    io1.ns("/", handler(&["room1", "room2"])).await.unwrap();
    io2.ns("/", handler(&["room2", "room3"])).await.unwrap();
    io3.ns("/", handler(&["room3", "room1"])).await.unwrap();

    let (_, mut rx1) = io1.new_dummy_sock("/", ()).await;
    let (_, mut rx2) = io2.new_dummy_sock("/", ()).await;
    let (_, mut rx3) = io3.new_dummy_sock("/", ()).await;

    let id1 = extract_sid(&timeout_rcv!(&mut rx1));
    let id2 = extract_sid(&timeout_rcv!(&mut rx2));
    let id3 = extract_sid(&timeout_rcv!(&mut rx3));

    let sockets = fetch_sockets_data(io1.to("room1")).await;
    assert_eq!(sockets, create_expected_sockets([id1, id3], [&io1, &io3]));

    let sockets = fetch_sockets_data(io1.to("room2")).await;
assert_eq!(sockets, create_expected_sockets([id1, id2], [&io1, &io2]));

    let sockets = fetch_sockets_data(io1.to("room3")).await;
    assert_eq!(sockets, create_expected_sockets([id2, id3], [&io2, &io3]));
}

/// A `fetch_sockets` query against an unresponsive peer must complete
/// only once the adapter timeout has elapsed.
#[tokio::test]
pub async fn fetch_sockets_timeout() {
    const TIMEOUT: Duration = Duration::from_millis(50);
    let [io1, io2] = fixture::spawn_buggy_servers(TIMEOUT);

    io1.ns("/", async || ()).await.unwrap();
    io2.ns("/", async || ()).await.unwrap();

    let (_, mut rx1) = io1.new_dummy_sock("/", ()).await;
    let (_, mut rx2) = io2.new_dummy_sock("/", ()).await;

    timeout_rcv!(&mut rx1); // connect packet
    timeout_rcv!(&mut rx2); // connect packet

    let now = Instant::now();
    io1.fetch_sockets().await.unwrap();
    assert!(now.elapsed() >= TIMEOUT);
}

/// Emitting through a `RemoteSocket` handle must reach sockets that live
/// on another server.
#[tokio::test]
pub async fn remote_socket_emit() {
    let [io1, io2] = fixture::spawn_servers();

    io1.ns("/", async || ()).await.unwrap();
    io2.ns("/", async || ()).await.unwrap();

    let (_, mut rx1) = io1.new_dummy_sock("/", ()).await;
    let (_, mut rx2) = io2.new_dummy_sock("/", ()).await;

    timeout_rcv!(&mut rx1); // connect packet
    timeout_rcv!(&mut rx2); // connect packet

    let sockets = io1.fetch_sockets().await.unwrap();
    for socket in sockets {
        socket.emit("test", "hello").await.unwrap();
    }

    assert_eq!(timeout_rcv!(&mut rx1), r#"42["test","hello"]"#);
    assert_eq!(timeout_rcv!(&mut rx2), r#"42["test","hello"]"#);
}

/// Emitting with an ack through a `RemoteSocket` handle must reach sockets
/// on another server; the `421…` packets carry the ack id.
#[tokio::test]
pub async fn remote_socket_emit_with_ack() {
    let [io1, io2] = fixture::spawn_servers();

    io1.ns("/", async || ()).await.unwrap();
    io2.ns("/", async || ()).await.unwrap();

    let (_, mut rx1) = io1.new_dummy_sock("/", ()).await;
    let (_, mut rx2) = io2.new_dummy_sock("/", ()).await;

    timeout_rcv!(&mut rx1); // connect packet
    timeout_rcv!(&mut rx2); // connect packet

    let sockets = io1.fetch_sockets().await.unwrap();
    for socket in sockets {
        // The ack stream is intentionally dropped: this test only checks
        // that the emit itself is delivered.
        #[allow(unused_must_use)]
        socket
            .emit_with_ack::<_, ()>("test", "hello")
            .await
            .unwrap();
    }

    assert_eq!(timeout_rcv!(&mut rx1), r#"421["test","hello"]"#);
    assert_eq!(timeout_rcv!(&mut rx2), r#"421["test","hello"]"#);
}