Changes from all commits
21 commits
78c7091
[SPARK-52780] Add ToLocalIterator and Arrow Record Streaming
caldempsey Jul 13, 2025
5e0a589
[debug] a case where context cancellations result in a panic
caldempsey Jul 13, 2025
c277f5b
[SPARK-52780] fix test compilation
caldempsey Jul 13, 2025
7ce5d47
[SPARK-52780] TestRowIterator_BothChannelsClosedCleanly should EOF (D…
caldempsey Jul 13, 2025
2b6044a
[SPARK-52780] fix linting error
caldempsey Jul 13, 2025
1a897ef
[SPARK-52780] rowiterator.go channel closing should deterministically…
caldempsey Jul 13, 2025
8c18703
[SPARK-52780] lint errors
caldempsey Jul 13, 2025
3dcab75
fix: merge
caldempsey Sep 3, 2025
485067e
Merge branch 'master' into callum/SPARK-52780
caldempsey Sep 3, 2025
f285079
feat: update the client base to provide lazy fetch
caldempsey Sep 3, 2025
917ce9f
feat: rename ToLocalIterator to StreamRows, establish RowIterator as …
caldempsey Sep 3, 2025
ad7e935
fix: golint-ci
caldempsey Sep 3, 2025
d38170b
fix: improve test doc-comments
caldempsey Sep 3, 2025
a18468f
feat: add tests for streaming rows in DataFrame operations including:
caldempsey Oct 22, 2025
434a579
fix: update Spark version to 4.0.1 in build workflow
caldempsey Oct 22, 2025
0432bde
fix: remove debug print lines from ToTable()
caldempsey Mar 1, 2026
928e9b3
fix: remove c.done race condition in ToRecordSequence
caldempsey Mar 1, 2026
146e423
fix: remove NewRowPull2, fold EOF handling into NewRowSequence
caldempsey Mar 1, 2026
aa4b293
fix: extract rowIterFromRecord to simplify NewRowSequence
caldempsey Mar 1, 2026
fb2a9aa
fix: prefer explicit error yield
caldempsey Mar 1, 2026
b29e5ef
fix: address feedback
caldempsey Mar 5, 2026
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -32,7 +32,7 @@ on:
- master

env:
-SPARK_VERSION: '4.0.0'
+SPARK_VERSION: '4.0.1'
HADOOP_VERSION: '3'

permissions:
195 changes: 194 additions & 1 deletion internal/tests/integration/dataframe_test.go
@@ -1196,7 +1196,6 @@ func TestDataFrame_RangeIter(t *testing.T) {
}
assert.Equal(t, 10, cnt)

// Check that errors are properly propagated
df, err = spark.Sql(ctx, "select if(id = 5, raise_error('handle'), false) from range(10)")
assert.NoError(t, err)
for _, err := range df.All(ctx) {
@@ -1224,3 +1223,197 @@ func TestDataFrame_SchemaTreeString(t *testing.T) {
assert.Contains(t, ts, "|-- second: array")
assert.Contains(t, ts, "|-- third: map")
}

func TestDataFrame_StreamRowsThroughChannel(t *testing.T) {
// Demonstrates how StreamRows can be used to pipe data through a channel for scenarios like:
// - Proxying Spark data through gRPC streaming or unary RPCs
// - Implementing producer-consumer patterns with backpressure based on Spark results
// - Buffering and rate-limiting data flow between systems
ctx, spark := connect()
df, err := spark.Sql(ctx, "select id, id * 2 as doubled, 'test_' || cast(id as string) as label from range(100)")
assert.NoError(t, err)

iter, err := df.StreamRows(ctx)
assert.NoError(t, err)

rowChan := make(chan map[string]interface{}, 10)
errChan := make(chan error, 1)

go func() {
defer close(rowChan)
for row, err := range iter {
if err != nil {
errChan <- err
return
}

rowData := make(map[string]interface{})
names := row.FieldNames()
for i, name := range names {
rowData[name] = row.At(i)
}

select {
case rowChan <- rowData:
case <-ctx.Done():
errChan <- ctx.Err()
return
}
}
}()

// In a gRPC scenario, this would be your response handler
receivedRows := make([]map[string]interface{}, 0)
consumerDone := make(chan struct{})

go func() {
defer close(consumerDone)
for rowData := range rowChan {
receivedRows = append(receivedRows, rowData)

id := rowData["id"].(int64)
doubled := rowData["doubled"].(int64)
assert.Equal(t, id*2, doubled)
}
}()

<-consumerDone

select {
case err := <-errChan:
assert.NoError(t, err)
default:
// no error was reported by the producer
}

assert.Equal(t, 100, len(receivedRows))

assert.Equal(t, int64(0), receivedRows[0]["id"])
assert.Equal(t, int64(99), receivedRows[99]["id"])
assert.Equal(t, "test_0", receivedRows[0]["label"])
assert.Equal(t, "test_99", receivedRows[99]["label"])
}

func TestDataFrame_StreamRows(t *testing.T) {
ctx, spark := connect()
df, err := spark.Sql(ctx, "select * from range(100)")
assert.NoError(t, err)

iter, err := df.StreamRows(ctx)
assert.NoError(t, err)
assert.NotNil(t, iter)

cnt := 0

for row, err := range iter {
assert.NoError(t, err)
assert.NotNil(t, row)
assert.Equal(t, 1, row.Len())
cnt++
}
assert.Equal(t, 100, cnt)
}

func TestDataFrame_StreamRowsWithFilter(t *testing.T) {
ctx, spark := connect()
df, err := spark.Sql(ctx, "select * from range(100)")
assert.NoError(t, err)

df, err = df.Filter(ctx, functions.Col("id").Lt(functions.IntLit(10)))
assert.NoError(t, err)

iter, err := df.StreamRows(ctx)
assert.NoError(t, err)

cnt := 0
for row, err := range iter {
assert.NoError(t, err)
assert.NotNil(t, row)
assert.Equal(t, 1, row.Len())
assert.Less(t, row.At(0).(int64), int64(10))
cnt++
}
assert.Equal(t, 10, cnt)
}

func TestDataFrame_StreamRowsEmpty(t *testing.T) {
ctx, spark := connect()
df, err := spark.Sql(ctx, "select * from range(0)")
assert.NoError(t, err)

iter, err := df.StreamRows(ctx)
assert.NoError(t, err)

cnt := 0
for row, err := range iter {
assert.NoError(t, err)
assert.NotNil(t, row)
cnt++
}
assert.Equal(t, 0, cnt)
}

func TestDataFrame_StreamRowsWithError(t *testing.T) {
ctx, spark := connect()
df, err := spark.Sql(ctx, "select if(id = 5, raise_error('test error'), id) as id from range(10)")
assert.NoError(t, err)

iter, err := df.StreamRows(ctx)
assert.NoError(t, err)

errorEncountered := false
for _, err := range iter {
if err != nil {
errorEncountered = true
assert.Error(t, err)
break
}
}
assert.True(t, errorEncountered, "Expected to encounter an error during iteration")
}

func TestDataFrame_StreamRowsMultipleColumns(t *testing.T) {
ctx, spark := connect()
df, err := spark.Sql(ctx, "select id, id * 2 as doubled, 'test' as name from range(50)")
assert.NoError(t, err)

iter, err := df.StreamRows(ctx)
assert.NoError(t, err)

cnt := 0
for row, err := range iter {
assert.NoError(t, err)
assert.NotNil(t, row)
assert.Equal(t, 3, row.Len())

id := row.At(0).(int64)
doubled := row.At(1).(int64)
name := row.At(2).(string)

assert.Equal(t, id*2, doubled)
assert.Equal(t, "test", name)
cnt++
}
assert.Equal(t, 50, cnt)
}

func TestDataFrame_StreamRowsLargeDataset(t *testing.T) {
ctx, spark := connect()
df, err := spark.Sql(ctx, "select * from range(10000)")
assert.NoError(t, err)

iter, err := df.StreamRows(ctx)
assert.NoError(t, err)

cnt := 0
lastValue := int64(-1)
for row, err := range iter {
assert.NoError(t, err)
assert.NotNil(t, row)
currentValue := row.At(0).(int64)
assert.Greater(t, currentValue, lastValue)
lastValue = currentValue
cnt++
}
assert.Equal(t, 10000, cnt)
}
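
The producer half of TestDataFrame_StreamRowsThroughChannel above can be distilled into a small reusable helper. The sketch below is illustrative and not part of this change: it assumes StreamRows yields (Row, error) pairs as an iter.Seq2, exactly as the tests consume it, and the package name, the helper name rowsToChannel, and the types import path are hypothetical.

package example // hypothetical package, not part of this PR

import (
	"context"
	"iter"

	"github.com/apache/spark-connect-go/spark/sql/types" // assumed location of the Row type
)

// rowsToChannel drains a (Row, error) sequence into a buffered channel so a
// gRPC handler or any other consumer can read rows with backpressure. The
// first error (including context cancellation) is reported on the error
// channel and stops the producer.
func rowsToChannel(ctx context.Context, rows iter.Seq2[types.Row, error]) (<-chan map[string]any, <-chan error) {
	out := make(chan map[string]any, 10)
	errs := make(chan error, 1)
	go func() {
		defer close(out)
		for row, err := range rows {
			if err != nil {
				errs <- err
				return
			}
			// Flatten the row into a name -> value map, as in the test above.
			data := make(map[string]any, row.Len())
			for i, name := range row.FieldNames() {
				data[name] = row.At(i)
			}
			select {
			case out <- data:
			case <-ctx.Done():
				errs <- ctx.Err()
				return
			}
		}
	}()
	return out, errs
}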
4 changes: 4 additions & 0 deletions spark/client/base/base.go
@@ -17,6 +17,7 @@ package base

import (
"context"
"iter"

"github.com/apache/spark-connect-go/spark/sql/utils"

@@ -47,6 +48,9 @@ type SparkConnectClient interface {
}

type ExecuteResponseStream interface {
// ToTable consumes all arrow.Record batches into a single arrow.Table. Useful for collecting the full query result into a client-side DataFrame.
ToTable() (*types.StructType, arrow.Table, error)
// ToRecordSequence lazily yields each arrow.Record as it is received from the server. Useful for streaming query results.
ToRecordSequence(ctx context.Context) iter.Seq2[arrow.Record, error]
Properties() map[string]any
}
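
To contrast the two methods, here is a minimal consumption sketch, not part of this change: ToTable blocks until every batch has arrived, whereas ToRecordSequence can be ranged over with Go 1.23 range-over-func semantics and only pulls batches as the loop demands them. The function countRows, the decision to Release each batch, and the base import path are assumptions.

package example // hypothetical package, not part of this PR

import (
	"context"

	"github.com/apache/spark-connect-go/spark/client/base" // assumed path, mirroring the file above
)

// countRows sums row counts batch by batch without materialising the whole
// result as an arrow.Table the way stream.ToTable() would.
func countRows(ctx context.Context, stream base.ExecuteResponseStream) (int64, error) {
	var total int64
	for record, err := range stream.ToRecordSequence(ctx) {
		if err != nil {
			return 0, err // the first yielded error ends the sequence
		}
		total += record.NumRows()
		record.Release() // assumes the caller owns each yielded batch
	}
	return total, nil
}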
98 changes: 98 additions & 0 deletions spark/client/client.go
@@ -20,6 +20,7 @@ import (
"errors"
"fmt"
"io"
"iter"

"github.com/apache/spark-connect-go/spark/sql/utils"

@@ -434,6 +435,103 @@ func (c *ExecutePlanClient) ToTable() (*types.StructType, arrow.Table, error) {
}
}

// ToRecordSequence returns an iter.Seq2 that yields Arrow record batches as they arrive from the server.
func (c *ExecutePlanClient) ToRecordSequence(ctx context.Context) iter.Seq2[arrow.Record, error] {
return func(yield func(arrow.Record, error) bool) {
// done tracks whether the server has signalled ResultComplete, which Spark's
// reattachable execution requires before the result can be considered complete.
// Tracked locally to avoid racing on shared struct state.
// Spliced from ToTable; we may eventually want to DRY up these workflows.
done := false

for {
select {
case <-ctx.Done():
yield(nil, ctx.Err())
return
default:
}

resp, err := c.responseStream.Recv()

select {
case <-ctx.Done():
yield(nil, ctx.Err())
return
default:
}

if errors.Is(err, io.EOF) {
break
}

if err != nil {
if se := sparkerrors.FromRPCError(err); se != nil {
yield(nil, sparkerrors.WithType(se, sparkerrors.ExecutionError))
} else {
yield(nil, err)
}
return
}

if resp == nil {
continue
}

if resp.GetSessionId() != c.sessionId {
yield(nil, sparkerrors.WithType(
&sparkerrors.InvalidServerSideSessionDetailsError{
OwnSessionId: c.sessionId,
ReceivedSessionId: resp.GetSessionId(),
}, sparkerrors.InvalidServerSideSessionError))
return
}

if resp.Schema != nil {
var schemaErr error
c.schema, schemaErr = types.ConvertProtoDataTypeToStructType(resp.Schema)
if schemaErr != nil {
yield(nil, sparkerrors.WithType(schemaErr, sparkerrors.ExecutionError))
return
}
}

switch x := resp.ResponseType.(type) {
case *proto.ExecutePlanResponse_SqlCommandResult_:
if val := x.SqlCommandResult.GetRelation(); val != nil {
c.properties["sql_command_result"] = val
}

case *proto.ExecutePlanResponse_ArrowBatch_:
record, err := types.ReadArrowBatchToRecord(x.ArrowBatch.Data, c.schema)
if err != nil {
yield(nil, err)
return
}
if !yield(record, nil) {
return
}

case *proto.ExecutePlanResponse_ResultComplete_:
done = true
return

case *proto.ExecutePlanResponse_ExecutionProgress_:
// Progress updates - ignore for now

default:
// Explicitly ignore unknown message types
}
}

// Check that the result is logically complete. With re-attachable execution
// the server may interrupt the connection, and we need a ResultComplete
// message to confirm the full result was received.
if c.opts.ReattachExecution && !done {
yield(nil, sparkerrors.WithType(fmt.Errorf("the result is not complete"), sparkerrors.ExecutionError))
}
}
}

func NewExecuteResponseStream(
responseClient proto.SparkConnectService_ExecutePlanClient,
sessionId string,
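
Because ToRecordSequence checks the boolean returned by yield, a consumer can stop the stream by simply breaking out of, or returning from, the range loop, and a cancelled context is surfaced as ctx.Err() in the final yielded pair. A minimal early-termination sketch follows; it is not part of this change, and firstBatch, the empty-result io.EOF sentinel, and the import paths are assumptions.

package example // hypothetical package, not part of this PR

import (
	"context"
	"io"

	"github.com/apache/arrow-go/v18/arrow" // adjust to the Arrow major version this module pins

	"github.com/apache/spark-connect-go/spark/client/base" // assumed path, mirroring the imports above
)

// firstBatch returns the first Arrow batch of a result. Returning from inside
// the loop makes yield return false, so ToRecordSequence stops calling Recv
// and pulls no further batches from the server.
func firstBatch(ctx context.Context, stream base.ExecuteResponseStream) (arrow.Record, error) {
	for record, err := range stream.ToRecordSequence(ctx) {
		if err != nil {
			return nil, err // includes ctx.Err() when the context was cancelled
		}
		return record, nil
	}
	return nil, io.EOF // hypothetical sentinel for an empty result
}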