diff --git a/lib/s2/connection.rb b/lib/s2/connection.rb index 6ce6e94..6640c0f 100644 --- a/lib/s2/connection.rb +++ b/lib/s2/connection.rb @@ -9,13 +9,14 @@ class Connection attr_reader :connected_at, :status - def initialize(resource_id:, task:, ws_url:) + def initialize(resource_id:, task:, ws_url:, headers: {}) @connected_at = nil @queue = nil @resource_id = resource_id @session = nil @task = task @ws_url = ws_url + @headers = headers @status = :initialized @stopping = false @backoff = INITIAL_BACKOFF @@ -81,7 +82,7 @@ def connect_and_run def connect_websocket(&) @status = :connecting - Async::WebSocket::Client.connect(@endpoint) do |ws| + Async::WebSocket::Client.connect(@endpoint, headers: @headers) do |ws| ActiveSupport::Notifications.instrument( "connected.session.s2", resource_id: @resource_id, diff --git a/spec/s2/connection_spec.rb b/spec/s2/connection_spec.rb index 508ae59..8f8ab80 100644 --- a/spec/s2/connection_spec.rb +++ b/spec/s2/connection_spec.rb @@ -35,6 +35,38 @@ expect(sleep_durations).to eq([5, 10, 20, 40, 80, 160, 320, 640, 1280, 2560, 3600, 3600, 3600, 3600]) end + it "passes headers to the websocket client" do + ws = FakeWebSocket.new + received_headers = nil + + allow(Async::WebSocket::Client).to receive(:connect) do |_endpoint, headers:, &block| + received_headers = headers + block.call(ws) + end + + resource_id = SecureRandom.uuid + ws_url = "ws://example.com/#{resource_id}" + headers = { "authorization" => "Basic dXNlcjpwYXNz" } + + connection = described_class.new( + resource_id:, + task: Async::Task.current, + ws_url:, + headers:, + ) + + task = Async do + connection.connect + end + + Async::Task.current.sleep 0.1 + + connection.disconnect + task.stop + + expect(received_headers).to eq(headers) + end + it "reuses the same endpoint across reconnection attempts" do ws = FakeWebSocket.new connection_attempts = []