Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions internal/tests/integration/dataframe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1204,3 +1204,27 @@ func TestDataFrame_RangeIter(t *testing.T) {
assert.Error(t, err)
}
}

func TestDataFrame_Dtypes(t *testing.T) {
ctx, spark := connect()
data := [][]any{
{"bob", "Developer", 125000, 1},
}
schema := types.StructOf(
types.NewStructField("Name", types.STRING),
types.NewStructField("Role", types.STRING),
types.NewStructField("Salary", types.LONG),
types.NewStructField("Performance", types.LONG),
)

df, err := spark.CreateDataFrame(ctx, data, schema)
require.NoError(t, err)
dtypes, err := df.Dtypes(ctx)
require.NoError(t, err)
assert.Equal(t, []sql.DataTypeInfo{
{Name: "Name", Type: "string"},
{Name: "Role", Type: "string"},
{Name: "Salary", Type: "bigint"},
{Name: "Performance", Type: "bigint"},
}, dtypes)
}
20 changes: 20 additions & 0 deletions spark/sql/dataframe.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ type ResultCollector interface {
WriteRow(values []any)
}

type DataTypeInfo struct {
Name string
Type string
}

// DataFrame is a wrapper for data frame, representing a distributed collection of data row.
type DataFrame interface {
// PlanId returns the plan id of the data frame.
Expand Down Expand Up @@ -115,6 +120,8 @@ type DataFrame interface {
DropNaAll(ctx context.Context, cols ...string) (DataFrame, error)
// Drops all rows containing null or NaN values in the specified columns. with a max threshold.
DropNaWithThreshold(ctx context.Context, threshold int32, cols ...string) (DataFrame, error)
// Returns a list of column names and their data types.
Dtypes(ctx context.Context) ([]DataTypeInfo, error)
// ExceptAll is similar to Substract but does not perform the distinct operation.
ExceptAll(ctx context.Context, other DataFrame) DataFrame
// Explain returns the string explain plan for the current DataFrame according to the explainMode.
Expand Down Expand Up @@ -1716,3 +1723,16 @@ func (df *dataFrameImpl) All(ctx context.Context) iter.Seq2[types.Row, error] {
}
}
}

func (df *dataFrameImpl) Dtypes(ctx context.Context) ([]DataTypeInfo, error) {
schema, err := df.Schema(ctx)
if err != nil {
return nil, err
}
dtypes := make([]DataTypeInfo, len(schema.Fields))
for i, field := range schema.Fields {
dtypes[i].Name = field.Name
dtypes[i].Type = field.DataType.TypeName()
}
return dtypes, nil
}