{
+ let st = &format!(
+ r#"CREATE TABLE IF NOT EXISTS "{table}" (
+ id BIGSERIAL UNIQUE,
+ created_at TIMESTAMPTZ DEFAULT NOW(),
+ payload BYTEA
+ )"#
+ );
+
+ self.client.execute(st, &[]).await?;
+
+ Ok(())
+ }
+
+ async fn notify(&self, channel: &str, message: &str) -> Result<(), Self::Error> {
+ self.client
+ .execute("SELECT pg_notify($1, $2)", &[&channel, &message])
+ .await?;
+ Ok(())
+ }
+}
+
+impl super::Notification for tokio_postgres::Notification {
+ fn channel(&self) -> &str {
+ tokio_postgres::Notification::channel(self)
+ }
+
+ fn payload(&self) -> &str {
+ tokio_postgres::Notification::payload(self)
+ }
+}
diff --git a/crates/socketioxide-postgres/src/drivers/sqlx.rs b/crates/socketioxide-postgres/src/drivers/sqlx.rs
new file mode 100644
index 00000000..3241dfac
--- /dev/null
+++ b/crates/socketioxide-postgres/src/drivers/sqlx.rs
@@ -0,0 +1,79 @@
+use futures_core::stream::BoxStream;
+use futures_util::StreamExt;
+use sqlx::{
+ PgPool,
+ postgres::{PgListener, PgNotification},
+};
+
+use super::Driver;
+
+pub use sqlx as sqlx_client;
+
+/// A [`Driver`] implementation using the [`sqlx`] PostgreSQL client.
+///
+/// It uses [`PgListener`] for LISTEN/NOTIFY and [`PgPool`] for queries.
+#[derive(Debug, Clone)]
+pub struct SqlxDriver {
+ client: PgPool,
+}
+
+impl SqlxDriver {
+ /// Create a new SqlxDriver instance.
+ pub fn new(client: PgPool) -> Self {
+ Self { client }
+ }
+}
+
+impl Driver for SqlxDriver {
+ type Error = sqlx::Error;
+ type Notification = PgNotification;
+ type NotificationStream = BoxStream<'static, Self::Notification>;
+
+ async fn init(&self, table: &str) -> Result<(), Self::Error> {
+ sqlx::query(&format!(
+ r#"CREATE TABLE IF NOT EXISTS "{table}" (
+ id BIGSERIAL UNIQUE,
+ created_at TIMESTAMPTZ DEFAULT NOW(),
+ payload BYTEA
+ )"#,
+ ))
+ .execute(&self.client)
+ .await?;
+
+ Ok(())
+ }
+
+ async fn listen(&self, channels: &[&str]) -> Result<Self::NotificationStream, Self::Error> {
+ let mut listener = PgListener::connect_with(&self.client).await?;
+ listener.listen_all(channels.iter().copied()).await?;
+
+ let stream = listener.into_stream();
+ let stream = stream.filter_map(async |res| {
+ res.inspect_err(|err| {
+ tracing::warn!("failed to pull sqlx notification from stream: {err}")
+ })
+ .ok()
+ });
+
+ Ok(Box::pin(stream))
+ }
+
+ async fn notify(&self, channel: &str, message: &str) -> Result<(), Self::Error> {
+ sqlx::query("SELECT pg_notify($1, $2)")
+ .bind(channel)
+ .bind(message)
+ .execute(&self.client)
+ .await?;
+ Ok(())
+ }
+}
+
+impl super::Notification for PgNotification {
+ fn channel(&self) -> &str {
+ PgNotification::channel(self)
+ }
+
+ fn payload(&self) -> &str {
+ PgNotification::payload(self)
+ }
+}
diff --git a/crates/socketioxide-postgres/src/lib.rs b/crates/socketioxide-postgres/src/lib.rs
new file mode 100644
index 00000000..ec09d9e7
--- /dev/null
+++ b/crates/socketioxide-postgres/src/lib.rs
@@ -0,0 +1,917 @@
+#![cfg_attr(docsrs, feature(doc_auto_cfg))]
+#![warn(
+ clippy::all,
+ clippy::todo,
+ clippy::empty_enums,
+ clippy::mem_forget,
+ clippy::unused_self,
+ clippy::filter_map_next,
+ clippy::needless_continue,
+ clippy::needless_borrow,
+ clippy::match_wildcard_for_single_variants,
+ clippy::if_let_mutex,
+ clippy::await_holding_lock,
+ clippy::indexing_slicing,
+ clippy::imprecise_flops,
+ clippy::suboptimal_flops,
+ clippy::lossy_float_literal,
+ clippy::rest_pat_in_fully_bound_structs,
+ clippy::fn_params_excessive_bools,
+ clippy::exit,
+ clippy::inefficient_to_string,
+ clippy::linkedlist,
+ clippy::macro_use_imports,
+ clippy::option_option,
+ clippy::verbose_file_reads,
+ clippy::unnested_or_patterns,
+ rust_2018_idioms,
+ future_incompatible,
+ nonstandard_style,
+ missing_docs
+)]
+//! # A PostgreSQL adapter implementation for the socketioxide crate.
+//! The adapter is used to communicate with other nodes of the same application.
+//! This allows to broadcast messages to sockets connected on other servers,
+//! to get the list of rooms, to add or remove sockets from rooms, etc.
+//!
+//! To achieve this, the adapter uses [LISTEN/NOTIFY](https://www.postgresql.org/docs/current/sql-notify.html)
+//! through PostgreSQL to communicate with other servers.
+//!
+//! The [`Driver`] abstraction allows the use of any PostgreSQL client.
+//! One implementation is provided:
+//! * [`SqlxDriver`](crate::drivers::sqlx::SqlxDriver) for the [`sqlx`] crate.
+//!
+//! You can also implement your own driver by implementing the [`Driver`] trait.
+//!
+//!
+//! Socketioxide-postgres is not compatible with @socketio/postgres-adapter.
+//! They use completely different protocols and cannot be used together.
+//! Do not mix socket.io JS servers with socketioxide rust servers.
+//!
+//!
+//! ## How does it work?
+//!
+//! The [`PostgresAdapterCtr`] is a constructor for the [`SqlxAdapter`] which is an implementation of
+//! the [`Adapter`](https://docs.rs/socketioxide/latest/socketioxide/adapter/trait.Adapter.html) trait.
+//!
+//! Then, for each namespace, an adapter is created and it takes a corresponding [`CoreLocalAdapter`].
+//! The [`CoreLocalAdapter`] allows to manage the local rooms and local sockets. The default `LocalAdapter`
+//! is simply a wrapper around this [`CoreLocalAdapter`].
+//!
+//! Once it is created the adapter is initialized with the [`CustomPostgresAdapter::init`] method.
+//! It will subscribe to three PostgreSQL NOTIFY channels and emit heartbeats.
+//! All messages are encoded with JSON.
+//!
+//! There are 7 types of requests:
+//! * Broadcast a packet to all the matching sockets.
+//! * Broadcast a packet to all the matching sockets and wait for a stream of acks.
+//! * Disconnect matching sockets.
+//! * Get all the rooms.
+//! * Add matching sockets to rooms.
+//! * Remove matching sockets from rooms.
+//! * Fetch all the remote sockets matching the options.
+//! * Heartbeat
+//! * Initial heartbeat. When receiving an initial heartbeat all other servers reply a heartbeat immediately.
+//!
+//! For ack streams, the adapter will first send a `BroadcastAckCount` response to the server that sent the request,
+//! and then send the acks as they are received (more details in [`CustomPostgresAdapter::broadcast_with_ack`] fn).
+//!
+//! On the other side, each time an action has to be performed on the local server, the adapter will
+//! first broadcast a request to all the servers and then perform the action locally.
+
+use drivers::Driver;
+use futures_core::Stream;
+use futures_util::{StreamExt, pin_mut};
+use serde::{Deserialize, Serialize, de::DeserializeOwned};
+use serde_json::value::RawValue;
+use socketioxide_core::{
+ Sid, Uid,
+ adapter::{
+ BroadcastOptions, CoreAdapter, CoreLocalAdapter, DefinedAdapter, RemoteSocketData, Room,
+ RoomParam, SocketEmitter, Spawnable,
+ errors::{AdapterError, BroadcastError},
+ remote_packet::{
+ RequestIn, RequestOut, RequestTypeIn, RequestTypeOut, Response, ResponseType,
+ ResponseTypeId,
+ },
+ },
+ packet::Packet,
+};
+use std::{
+ borrow::Cow,
+ collections::HashMap,
+ fmt, future,
+ pin::Pin,
+ sync::{Arc, Mutex},
+ task::{Context, Poll},
+ time::{Duration, Instant},
+};
+use tokio::sync::mpsc;
+
+use crate::{
+ drivers::Notification,
+ stream::{AckStream, ChanStream},
+};
+
+pub mod drivers;
+mod stream;
+
+/// The configuration of the [`CustomPostgresAdapter`].
+#[derive(Debug, Clone)]
+pub struct PostgresAdapterConfig {
+ /// The heartbeat timeout duration. If a remote node does not respond within this duration,
+ /// it will be considered disconnected. Default is 60 seconds.
+ pub hb_timeout: Duration,
+ /// The heartbeat interval duration. The current node will broadcast a heartbeat to the
+ /// remote nodes at this interval. Default is 10 seconds.
+ pub hb_interval: Duration,
+ /// The request timeout. When expecting a response from remote nodes, if they do not respond within
+ /// this duration, the request will be considered failed. Default is 5 seconds.
+ pub request_timeout: Duration,
+ /// The channel size used to receive ack responses. Default is 255.
+ ///
+ /// If you have a lot of servers/sockets and that you may miss acknowledgement because they arrive faster
+ /// than you poll them with the returned stream, you might want to increase this value.
+ pub ack_response_buffer: usize,
+ /// The table name used to store socket.io attachments. Default is "socket_io_attachments".
+ ///
+ /// > The table name must be a sanitized string. Do not use special characters or spaces.
+ pub table_name: Cow<'static, str>,
+ /// The prefix used for the channels. Default is "socket.io".
+ pub prefix: Cow<'static, str>,
+ /// The threshold to the payload size in bytes. It should match the configured value on your PostgreSQL instance:
+ /// <https://www.postgresql.org/docs/current/sql-notify.html>. By default it is 8KB (8000 bytes).
+ pub payload_threshold: usize,
+ /// The duration between cleanup queries on the attachment table.
+ pub cleanup_interval: Duration,
+}
+
+impl PostgresAdapterConfig {
+ /// Create a new [`PostgresAdapterConfig`] with default values.
+ pub fn new() -> Self {
+ Self::default()
+ }
+
+ /// The heartbeat timeout duration. If a remote node does not respond within this duration,
+ /// it will be considered disconnected. Default is 60 seconds.
+ pub fn with_hb_timeout(mut self, hb_timeout: Duration) -> Self {
+ self.hb_timeout = hb_timeout;
+ self
+ }
+
+ /// The heartbeat interval duration. The current node will broadcast a heartbeat to the
+ /// remote nodes at this interval. Default is 10 seconds.
+ pub fn with_hb_interval(mut self, hb_interval: Duration) -> Self {
+ self.hb_interval = hb_interval;
+ self
+ }
+
+ /// The request timeout. When expecting a response from remote nodes, if they do not respond within
+ /// this duration, the request will be considered failed. Default is 5 seconds.
+ pub fn with_request_timeout(mut self, request_timeout: Duration) -> Self {
+ self.request_timeout = request_timeout;
+ self
+ }
+
+ /// The channel size used to receive ack responses. Default is 255.
+ ///
+ /// If you have a lot of servers/sockets and that you may miss acknowledgement because they arrive faster
+ /// than you poll them with the returned stream, you might want to increase this value.
+ pub fn with_ack_response_buffer(mut self, ack_response_buffer: usize) -> Self {
+ self.ack_response_buffer = ack_response_buffer;
+ self
+ }
+
+ /// The table name used to store socket.io attachments. Default is "socket_io_attachments".
+ ///
+ /// > The table name must be a sanitized string. Do not use special characters or spaces.
+ pub fn with_table_name(mut self, table_name: impl Into<Cow<'static, str>>) -> Self {
+ self.table_name = table_name.into();
+ self
+ }
+
+ /// The prefix used for the channels. Default is "socket.io".
+ pub fn with_prefix(mut self, prefix: impl Into<Cow<'static, str>>) -> Self {
+ self.prefix = prefix.into();
+ self
+ }
+
+ /// The threshold to the payload size in bytes. It should match the configured value on your PostgreSQL instance:
+ /// <https://www.postgresql.org/docs/current/sql-notify.html>. By default it is 8KB (8000 bytes).
+ pub fn with_payload_threshold(mut self, payload_threshold: usize) -> Self {
+ self.payload_threshold = payload_threshold;
+ self
+ }
+
+ /// The duration between cleanup queries on the attachment table. Default is 60 seconds.
+ pub fn with_cleanup_interval(mut self, cleanup_interval: Duration) -> Self {
+ self.cleanup_interval = cleanup_interval;
+ self
+ }
+}
+
+impl Default for PostgresAdapterConfig {
+ fn default() -> Self {
+ Self {
+ hb_timeout: Duration::from_secs(60),
+ hb_interval: Duration::from_secs(10),
+ request_timeout: Duration::from_secs(5),
+ ack_response_buffer: 255,
+ table_name: "socket_io_attachments".into(),
+ prefix: "socket.io".into(),
+ payload_threshold: 8_000,
+ cleanup_interval: Duration::from_secs(60),
+ }
+ }
+}
+
+/// Represent any error that might happen when using this adapter.
+#[derive(thiserror::Error)]
+pub enum Error<D: Driver> {
+ /// Postgres driver error
+ #[error("driver error: {0}")]
+ Driver(D::Error),
+ /// Packet encoding/decoding error
+ #[error("packet decoding error: {0}")]
+ Serde(#[from] serde_json::Error),
+}
+
+impl<D: Driver> fmt::Debug for Error<D> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ Self::Driver(err) => write!(f, "Driver error: {:?}", err),
+ Self::Serde(err) => write!(f, "Encode/Decode error: {:?}", err),
+ }
+ }
+}
+
+impl<D: Driver> From<Error<D>> for AdapterError {
+ fn from(err: Error<D>) -> Self {
+ AdapterError::from(Box::new(err) as Box<dyn std::error::Error + Send>)
+ }
+}
+
+/// The adapter constructor. For each namespace you define, a new adapter instance is created
+/// from this constructor.
+#[derive(Debug, Clone)]
+pub struct PostgresAdapterCtr<D> {
+ driver: D,
+ config: PostgresAdapterConfig,
+}
+
+#[cfg(feature = "sqlx")]
+impl PostgresAdapterCtr<drivers::sqlx::SqlxDriver> {
+ /// Create a new adapter constructor with the [`sqlx`](drivers::sqlx) driver
+ /// and a default config.
+ pub fn new_with_sqlx(pool: drivers::sqlx::sqlx_client::PgPool) -> Self {
+ Self::new_with_sqlx_config(pool, PostgresAdapterConfig::default())
+ }
+
+ /// Create a new adapter constructor with the [`sqlx`](drivers::sqlx) driver
+ /// and a custom config.
+ pub fn new_with_sqlx_config(
+ pool: drivers::sqlx::sqlx_client::PgPool,
+ config: PostgresAdapterConfig,
+ ) -> Self {
+ let driver = drivers::sqlx::SqlxDriver::new(pool);
+ Self { driver, config }
+ }
+}
+
+impl<D: Driver> PostgresAdapterCtr<D> {
+ /// Create a new adapter constructor with a custom postgres driver and a config.
+ ///
+ /// You can implement your own driver by implementing the [`Driver`] trait with any postgres client.
+ /// Check the [`drivers`] module for more information.
+ pub fn new_with_driver(driver: D, config: PostgresAdapterConfig) -> Self {
+ Self { driver, config }
+ }
+}
+
+/// The postgres adapter with the [`sqlx`](drivers::sqlx) driver.
+#[cfg(feature = "sqlx")]
+pub type SqlxAdapter<E> = CustomPostgresAdapter<drivers::sqlx::SqlxDriver, E>;
+
+type ResponseHandlers = HashMap<Sid, mpsc::Sender<Box<RawValue>>>;
+
+/// The postgres adapter implementation.
+/// It is generic over the [`Driver`] used to communicate with the postgres server.
+/// And over the [`SocketEmitter`] used to communicate with the local server. This allows to
+/// avoid cyclic dependencies between the adapter, `socketioxide-core` and `socketioxide` crates.
+pub struct CustomPostgresAdapter<D, E> {
+ /// The driver used by the adapter. This is used to communicate with the postgres server.
+ /// All the postgres adapter instances share the same driver.
+ driver: D,
+ /// The configuration of the adapter.
+ config: PostgresAdapterConfig,
+ /// The local adapter, used to manage local rooms and socket stores.
+ local: CoreLocalAdapter<E>,
+ /// A map of nodes liveness, with the last time remote nodes were seen alive.
+ nodes_liveness: Mutex<Vec<(Uid, Instant)>>,
+ /// A map of response handlers used to await for responses from the remote servers.
+ responses: Arc<Mutex<ResponseHandlers>>,
+}
+
+impl<D: Driver, E: SocketEmitter> DefinedAdapter for CustomPostgresAdapter<D, E> {}
+impl<D: Driver, E: SocketEmitter> CoreAdapter<E> for CustomPostgresAdapter<D, E> {
+ type Error = Error<D>;
+ type State = PostgresAdapterCtr<D>;
+ type AckStream = AckStream<E::AckStream>;
+ type InitRes = InitRes<D>;
+
+ fn new(state: &Self::State, local: CoreLocalAdapter<E>) -> Self {
+ Self {
+ local,
+ driver: state.driver.clone(),
+ config: state.config.clone(),
+ nodes_liveness: Mutex::new(Vec::new()),
+ responses: Arc::new(Mutex::new(HashMap::new())),
+ }
+ }
+
+ fn init(self: Arc<Self>, on_success: impl FnOnce() + Send + 'static) -> Self::InitRes {
+ let fut = async move {
+ self.driver.init(&self.config.table_name).await?;
+
+ let global_chan = self.get_global_chan();
+ let node_chan = self.get_node_chan(self.local.server_id());
+ let response_chan = self.get_response_chan(self.local.server_id());
+
+ let channels = [
+ global_chan.as_str(),
+ node_chan.as_str(),
+ response_chan.as_str(),
+ ];
+
+ let stream = self.driver.listen(&channels).await?;
+ tokio::spawn(self.clone().handle_ev_stream(stream));
+ tokio::spawn(self.clone().heartbeat_job());
+
+ // Send initial heartbeat when starting.
+ self.emit_init_heartbeat().await.map_err(|e| match e {
+ Error::Driver(e) => e,
+ Error::Serde(_) => unreachable!(),
+ })?;
+
+ on_success();
+ Ok(())
+ };
+ InitRes(Box::pin(fut))
+ }
+
+ async fn close(&self) -> Result<(), Self::Error> {
+ Ok(())
+ }
+
+ /// Get the number of servers by iterating over the node liveness heartbeats.
+ async fn server_count(&self) -> Result<u16, Self::Error> {
+ let treshold = std::time::Instant::now() - self.config.hb_timeout;
+ let mut nodes_liveness = self.nodes_liveness.lock().unwrap();
+ nodes_liveness.retain(|(_, v)| v > &treshold);
+ Ok((nodes_liveness.len() + 1) as u16)
+ }
+
+ /// Broadcast a packet to all the servers to send them through their sockets.
+ async fn broadcast(
+ &self,
+ packet: Packet,
+ opts: BroadcastOptions,
+ ) -> Result<(), BroadcastError> {
+ let node_id = self.local.server_id();
+ if !opts.is_local(node_id) {
+ let req = RequestOut::new(node_id, RequestTypeOut::Broadcast(&packet), &opts);
+ self.send_req(req, None).await.map_err(AdapterError::from)?;
+ }
+
+ self.local.broadcast(packet, opts)?;
+ Ok(())
+ }
+
+ /// Broadcast a packet to all the servers to send them through their sockets.
+ ///
+ /// Returns a Stream that is a combination of the local ack stream and a remote ack stream.
+ /// Here is a specific protocol in order to know how many message the server expect to close
+ /// the stream at the right time:
+ /// * Get the number `n` of remote servers.
+ /// * Send the broadcast request.
+ /// * Expect `n` `BroadcastAckCount` response in the stream to know the number `m` of expected ack responses.
+ /// * Expect `sum(m)` broadcast counts sent by the servers.
+ ///
+ /// Example with 3 remote servers (n = 3):
+ /// ```text
+ /// +---+ +---+ +---+
+ /// | A | | B | | C |
+ /// +---+ +---+ +---+
+ /// | | |
+ /// |---BroadcastWithAck--->| |
+ /// |---BroadcastWithAck--------------------------->|
+ /// | | |
+ /// |<-BroadcastAckCount(2)-| (n = 2; m = 2) |
+ /// |<-BroadcastAckCount(2)-------(n = 2; m = 4)----|
+ /// | | |
+ /// |<----------------Ack---------------------------|
+ /// |<----------------Ack---| |
+ /// | | |
+ /// |<----------------Ack---------------------------|
+ /// |<----------------Ack---| |
+ async fn broadcast_with_ack(
+ &self,
+ packet: Packet,
+ opts: BroadcastOptions,
+ timeout: Option<Duration>,
+ ) -> Result<Self::AckStream, Self::Error> {
+ if opts.is_local(self.local.server_id()) {
+ tracing::debug!(?opts, "broadcast with ack is local");
+ let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout);
+ let stream = AckStream::new_local(local);
+ return Ok(stream);
+ }
+ let req = RequestOut::new(
+ self.local.server_id(),
+ RequestTypeOut::BroadcastWithAck(&packet),
+ &opts,
+ );
+ let req_id = req.id;
+
+ let remote_serv_cnt = self.server_count().await?.saturating_sub(1);
+ tracing::trace!(?remote_serv_cnt, "expecting acks from remote servers");
+
+ let (tx, rx) = mpsc::channel(self.config.ack_response_buffer + remote_serv_cnt as usize);
+ self.responses.lock().unwrap().insert(req_id, tx);
+
+ self.send_req(req, None).await?;
+ let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout);
+
+ Ok(AckStream::new(
+ local,
+ rx,
+ self.config.request_timeout,
+ remote_serv_cnt,
+ req_id,
+ self.responses.clone(),
+ ))
+ }
+
+ async fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), BroadcastError> {
+ if !opts.is_local(self.local.server_id()) {
+ let req = RequestOut::new(
+ self.local.server_id(),
+ RequestTypeOut::DisconnectSockets,
+ &opts,
+ );
+ self.send_req(req, None).await.map_err(AdapterError::from)?;
+ }
+ self.local
+ .disconnect_socket(opts)
+ .map_err(BroadcastError::Socket)?;
+
+ Ok(())
+ }
+
+ async fn rooms(&self, opts: BroadcastOptions) -> Result<Vec<Room>, Self::Error> {
+ if opts.is_local(self.local.server_id()) {
+ return Ok(self.local.rooms(opts).into_iter().collect());
+ }
+ let req = RequestOut::new(self.local.server_id(), RequestTypeOut::AllRooms, &opts);
+ let req_id = req.id;
+
+ // First get the remote stream because postgres might send
+ // the responses before subscription is done.
+ let stream = self
+ .get_res::<()>(req_id, ResponseTypeId::AllRooms, opts.server_id)
+ .await?;
+ self.send_req(req, opts.server_id).await?;
+ let local = self.local.rooms(opts);
+ let rooms = stream
+ .filter_map(|item| std::future::ready(item.into_rooms()))
+ .fold(local, |mut acc, item| async move {
+ acc.extend(item);
+ acc
+ })
+ .await;
+ Ok(Vec::from_iter(rooms))
+ }
+
+ async fn add_sockets(
+ &self,
+ opts: BroadcastOptions,
+ rooms: impl RoomParam,
+ ) -> Result<(), Self::Error> {
+ let rooms: Vec<Room> = rooms.into_room_iter().collect();
+ if !opts.is_local(self.local.server_id()) {
+ let req = RequestOut::new(
+ self.local.server_id(),
+ RequestTypeOut::AddSockets(&rooms),
+ &opts,
+ );
+ self.send_req(req, opts.server_id).await?;
+ }
+ self.local.add_sockets(opts, rooms);
+ Ok(())
+ }
+
+ async fn del_sockets(
+ &self,
+ opts: BroadcastOptions,
+ rooms: impl RoomParam,
+ ) -> Result<(), Self::Error> {
+ let rooms: Vec<Room> = rooms.into_room_iter().collect();
+ if !opts.is_local(self.local.server_id()) {
+ let req = RequestOut::new(
+ self.local.server_id(),
+ RequestTypeOut::DelSockets(&rooms),
+ &opts,
+ );
+ self.send_req(req, opts.server_id).await?;
+ }
+ self.local.del_sockets(opts, rooms);
+ Ok(())
+ }
+
+ async fn fetch_sockets(
+ &self,
+ opts: BroadcastOptions,
+ ) -> Result<Vec<RemoteSocketData>, Self::Error> {
+ if opts.is_local(self.local.server_id()) {
+ return Ok(self.local.fetch_sockets(opts));
+ }
+ let req = RequestOut::new(self.local.server_id(), RequestTypeOut::FetchSockets, &opts);
+ // First get the remote stream because postgres might send
+ // the responses before subscription is done.
+ let remote = self
+ .get_res::<RemoteSocketData>(req.id, ResponseTypeId::FetchSockets, opts.server_id)
+ .await?;
+
+ self.send_req(req, opts.server_id).await?;
+ let local = self.local.fetch_sockets(opts);
+ let sockets = remote
+ .filter_map(|item| future::ready(item.into_fetch_sockets()))
+ .fold(local, |mut acc, item| async move {
+ acc.extend(item);
+ acc
+ })
+ .await;
+ Ok(sockets)
+ }
+
+ fn get_local(&self) -> &CoreLocalAdapter<E> {
+ &self.local
+ }
+}
+
+impl<D: Driver, E: SocketEmitter> CustomPostgresAdapter<D, E> {
+ async fn heartbeat_job(self: Arc<Self>) -> Result<(), Error<D>> {
+ let mut interval = tokio::time::interval(self.config.hb_interval);
+ interval.tick().await; // first tick yields immediately
+ loop {
+ interval.tick().await;
+ self.emit_heartbeat(None).await?;
+ }
+ }
+
+ async fn handle_ev_stream(self: Arc<Self>, stream: impl Stream<Item = D::Notification>) {
+ pin_mut!(stream);
+ while let Some(notif) = stream.next().await {
+ let chan = notif.channel();
+ let resp_chan = self.get_response_chan(self.local.server_id());
+ tracing::info!(chan, resp_chan, notif = notif.payload(), "");
+ if chan == resp_chan {
+ match serde_json::from_str(notif.payload()) {
+ Ok(ResponsePacket {
+ req_id,
+ node_id,
+ payload,
+ }) if node_id != self.local.server_id() => {
+ let handlers = self.responses.lock().unwrap();
+ if let Some(handler) = handlers.get(&req_id) {
+ if let Err(e) = handler.try_send(payload) {
+ tracing::warn!(channel = resp_chan, req_id = %req_id, "error sending response: {e}");
+ }
+ } else {
+ tracing::warn!(channel = resp_chan, req_id = %req_id, "response handler not found");
+ }
+ }
+ Ok(_) => {
+ tracing::trace!("skipping loopback packets");
+ }
+ Err(e) => {
+ tracing::warn!(channel = %notif.channel(), "error handling response: {e}")
+ }
+ };
+ } else {
+ match serde_json::from_str::<RequestIn>(notif.payload()) {
+ Ok(req) if req.node_id != self.local.server_id() => self.recv_req(req),
+ Ok(_) => {
+ tracing::trace!("skipping loopback packets")
+ }
+ Err(e) => {
+ tracing::warn!(channel = %notif.channel(), "error decoding request: {e}")
+ }
+ };
+ }
+ }
+ }
+
+ fn recv_req(self: &Arc<Self>, req: RequestIn) {
+ tracing::trace!(?req, "incoming request");
+ match (req.r#type, req.opts) {
+ (RequestTypeIn::Broadcast(p), Some(opts)) => self.recv_broadcast(opts, p),
+ (RequestTypeIn::BroadcastWithAck(p), Some(opts)) => self
+ .clone()
+ .recv_broadcast_with_ack(req.node_id, req.id, p, opts),
+ (RequestTypeIn::DisconnectSockets, Some(opts)) => self.recv_disconnect_sockets(opts),
+ (RequestTypeIn::AllRooms, Some(opts)) => self.recv_rooms(req.node_id, req.id, opts),
+ (RequestTypeIn::AddSockets(rooms), Some(opts)) => self.recv_add_sockets(opts, rooms),
+ (RequestTypeIn::DelSockets(rooms), Some(opts)) => self.recv_del_sockets(opts, rooms),
+ (RequestTypeIn::FetchSockets, Some(opts)) => {
+ self.recv_fetch_sockets(req.node_id, req.id, opts)
+ }
+ req_type @ (RequestTypeIn::Heartbeat | RequestTypeIn::InitHeartbeat, _) => {
+ self.recv_heartbeat(req_type.0, req.node_id)
+ }
+ _ => (),
+ }
+ }
+
+ fn recv_broadcast(&self, opts: BroadcastOptions, packet: Packet) {
+ tracing::trace!(?opts, "incoming broadcast");
+ if let Err(e) = self.local.broadcast(packet, opts) {
+ let ns = self.local.path();
+ tracing::warn!(node_id = %self.local.server_id(), ?ns, "remote request broadcast handler: {:?}", e);
+ }
+ }
+
+ fn recv_disconnect_sockets(&self, opts: BroadcastOptions) {
+ if let Err(e) = self.local.disconnect_socket(opts) {
+ let ns = self.local.path();
+ tracing::warn!(
+ node_id = %self.local.server_id(),
+ %ns,
+ "remote request disconnect sockets handler: {:?}",
+ e
+ );
+ }
+ }
+
+ fn recv_broadcast_with_ack(
+ self: Arc<Self>,
+ origin: Uid,
+ req_id: Sid,
+ packet: Packet,
+ opts: BroadcastOptions,
+ ) {
+ let (stream, count) = self.local.broadcast_with_ack(packet, opts, None);
+ tokio::spawn(async move {
+ let on_err = |err| {
+ let ns = self.local.path();
+ tracing::warn!(
+ node_id = %self.local.server_id(),
+ %ns,
+ "remote request broadcast with ack handler errors: {:?}",
+ err
+ );
+ };
+ // First send the count of expected acks to the server that sent the request.
+ // This is used to keep track of the number of expected acks.
+ let res = Response {
+ r#type: ResponseType::<()>::BroadcastAckCount(count),
+ node_id: self.local.server_id(),
+ };
+ if let Err(err) = self.send_res(req_id, origin, res).await {
+ on_err(err);
+ return;
+ }
+
+ // Then send the acks as they are received.
+ futures_util::pin_mut!(stream);
+ while let Some(ack) = stream.next().await {
+ let res = Response {
+ r#type: ResponseType::BroadcastAck(ack),
+ node_id: self.local.server_id(),
+ };
+ if let Err(err) = self.send_res(req_id, origin, res).await {
+ on_err(err);
+ return;
+ }
+ }
+ });
+ }
+
+ fn recv_rooms(&self, origin: Uid, req_id: Sid, opts: BroadcastOptions) {
+ let rooms = self.local.rooms(opts);
+ let res = Response {
+ r#type: ResponseType::<()>::AllRooms(rooms),
+ node_id: self.local.server_id(),
+ };
+ let fut = self.send_res(req_id, origin, res);
+ let ns = self.local.path().clone();
+ let uid = self.local.server_id();
+ tokio::spawn(async move {
+ if let Err(err) = fut.await {
+ tracing::warn!(?uid, ?ns, "remote request rooms handler: {:?}", err);
+ }
+ });
+ }
+
+ fn recv_add_sockets(&self, opts: BroadcastOptions, rooms: Vec<Room>) {
+ self.local.add_sockets(opts, rooms);
+ }
+
+ fn recv_del_sockets(&self, opts: BroadcastOptions, rooms: Vec<Room>) {
+ self.local.del_sockets(opts, rooms);
+ }
+ fn recv_fetch_sockets(&self, origin: Uid, req_id: Sid, opts: BroadcastOptions) {
+ let sockets = self.local.fetch_sockets(opts);
+ let res = Response {
+ node_id: self.local.server_id(),
+ r#type: ResponseType::FetchSockets(sockets),
+ };
+ let fut = self.send_res(req_id, origin, res);
+ let ns = self.local.path().clone();
+ let uid = self.local.server_id();
+ tokio::spawn(async move {
+ if let Err(err) = fut.await {
+ tracing::warn!(?uid, ?ns, "remote request fetch sockets handler: {:?}", err);
+ }
+ });
+ }
+
+ /// Receive a heartbeat from a remote node.
+ /// It might be a FirstHeartbeat packet, in which case we are re-emitting a heartbeat to the remote node.
+ fn recv_heartbeat(self: &Arc<Self>, req_type: RequestTypeIn, origin: Uid) {
+ tracing::debug!(?req_type, "{:?} received", req_type);
+ let mut node_liveness = self.nodes_liveness.lock().unwrap();
+ // Even with a FirstHeartbeat packet we first consume the node liveness to
+ // ensure that the node is not already in the list.
+ for (id, liveness) in node_liveness.iter_mut() {
+ if *id == origin {
+ *liveness = Instant::now();
+ return;
+ }
+ }
+
+ node_liveness.push((origin, Instant::now()));
+
+ if matches!(req_type, RequestTypeIn::InitHeartbeat) {
+ tracing::debug!(
+ ?origin,
+ "initial heartbeat detected, saying hello to the new node"
+ );
+
+ let this = self.clone();
+ tokio::spawn(async move {
+ if let Err(err) = this.emit_heartbeat(Some(origin)).await {
+ tracing::warn!(
+ "could not re-emit heartbeat after new node detection: {:?}",
+ err
+ );
+ }
+ });
+ }
+ }
+
+ /// Send a request to a specific target node or broadcast it to all nodes if no target is specified.
+ async fn send_req(&self, req: RequestOut<'_>, target: Option<Uid>) -> Result<(), Error<D>> {
+ tracing::trace!(?req, "sending request");
+ let chan = match target {
+ Some(target) => self.get_node_chan(target),
+ None => self.get_global_chan(),
+ };
+ let payload = serde_json::to_string(&req)?;
+ self.driver
+ .notify(&chan, &payload)
+ .await
+ .map_err(Error::Driver)?;
+ Ok(())
+ }
+
+ /// Send a response to the node that sent the request.
+ fn send_res(
+ &self,
+ req_id: Sid,
+ req_origin: Uid,
+ payload: Response,
+ ) -> impl Future