diff --git a/internal/socket/subscriber_request_channel.go b/internal/socket/subscriber_request_channel.go index f64b3f8..af48662 100644 --- a/internal/socket/subscriber_request_channel.go +++ b/internal/socket/subscriber_request_channel.go @@ -12,6 +12,26 @@ import ( "go.uber.org/atomic" ) +// FinalPayload is a marker interface for payloads that should be sent with FlagNext|FlagComplete. +type FinalPayload interface { + payload.Payload + IsFinal() bool +} + +// finalPayloadWrapper wraps a payload and marks it as final. +type finalPayloadWrapper struct { + payload.Payload +} + +func (f finalPayloadWrapper) IsFinal() bool { + return true +} + +// NewFinalPayload creates a final payload that will be sent with FlagNext|FlagComplete. +func NewFinalPayload(p payload.Payload) payload.Payload { + return finalPayloadWrapper{Payload: p} +} + type requestChannelSubscriber struct { sid uint32 dc *DuplexConnection @@ -52,15 +72,22 @@ func (r *requestChannelSubscriber) OnSubscribe(ctx context.Context, s rx.Subscri } type respondChannelSubscriber struct { - sid uint32 - n uint32 - dc *DuplexConnection - rcv flux.Processor - subscribed chan<- struct{} - calls *atomic.Int32 + sid uint32 + n uint32 + dc *DuplexConnection + rcv flux.Processor + subscribed chan<- struct{} + calls *atomic.Int32 + sentFinalNext atomic.Bool } func (r *respondChannelSubscriber) OnNext(next payload.Payload) { + if _, ok := next.(FinalPayload); ok { + r.sentFinalNext.Store(true) + r.OnComplete() + r.dc.sendPayload(r.sid, next, core.FlagNext|core.FlagComplete) + return + } r.dc.sendPayload(r.sid, next, core.FlagNext) } @@ -75,6 +102,9 @@ func (r *respondChannelSubscriber) OnComplete() { if r.calls.Inc() == 2 { r.dc.unregister(r.sid) } + if r.sentFinalNext.Load() { + return + } complete := framing.NewWriteablePayloadFrame(r.sid, nil, nil, core.FlagComplete) done := make(chan struct{}) complete.HandleDone(func() { diff --git a/rsocket.go b/rsocket.go index 0e58358..c9543cf 100644 --- a/rsocket.go +++ b/rsocket.go @@ -86,6 +86,11 @@ func NewAbstractSocket(opts ...OptAbstractSocket) RSocket { return sk } +// NewFinalPayload is a wrapper/marker that allows a payload to be sent with both Next and Complete flags. +func NewFinalPayload(p payload.Payload) payload.Payload { + return socket.NewFinalPayload(p) +} + // MetadataPush register request handler for MetadataPush. func MetadataPush(fn func(request payload.Payload)) OptAbstractSocket { return func(socket *socket.AbstractRSocket) {