diff --git a/internal/tests/integration/dataframe_test.go b/internal/tests/integration/dataframe_test.go index 720cc98..cfd1985 100644 --- a/internal/tests/integration/dataframe_test.go +++ b/internal/tests/integration/dataframe_test.go @@ -1204,3 +1204,27 @@ func TestDataFrame_RangeIter(t *testing.T) { assert.Error(t, err) } } + +func TestDataFrame_Dtypes(t *testing.T) { + ctx, spark := connect() + data := [][]any{ + {"bob", "Developer", 125000, 1}, + } + schema := types.StructOf( + types.NewStructField("Name", types.STRING), + types.NewStructField("Role", types.STRING), + types.NewStructField("Salary", types.LONG), + types.NewStructField("Performance", types.LONG), + ) + + df, err := spark.CreateDataFrame(ctx, data, schema) + require.NoError(t, err) + dtypes, err := df.Dtypes(ctx) + require.NoError(t, err) + assert.Equal(t, []sql.DataTypeInfo{ + {Name: "Name", Type: "string"}, + {Name: "Role", Type: "string"}, + {Name: "Salary", Type: "bigint"}, + {Name: "Performance", Type: "bigint"}, + }, dtypes) +} diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go index ce67b0e..7147b99 100644 --- a/spark/sql/dataframe.go +++ b/spark/sql/dataframe.go @@ -38,6 +38,11 @@ type ResultCollector interface { WriteRow(values []any) } +type DataTypeInfo struct { + Name string + Type string +} + // DataFrame is a wrapper for data frame, representing a distributed collection of data row. type DataFrame interface { // PlanId returns the plan id of the data frame. @@ -115,6 +120,8 @@ type DataFrame interface { DropNaAll(ctx context.Context, cols ...string) (DataFrame, error) // Drops all rows containing null or NaN values in the specified columns. with a max threshold. DropNaWithThreshold(ctx context.Context, threshold int32, cols ...string) (DataFrame, error) + // Returns a list of column names and their data types. + Dtypes(ctx context.Context) ([]DataTypeInfo, error) // ExceptAll is similar to Substract but does not perform the distinct operation. ExceptAll(ctx context.Context, other DataFrame) DataFrame // Explain returns the string explain plan for the current DataFrame according to the explainMode. @@ -1716,3 +1723,16 @@ func (df *dataFrameImpl) All(ctx context.Context) iter.Seq2[types.Row, error] { } } } + +func (df *dataFrameImpl) Dtypes(ctx context.Context) ([]DataTypeInfo, error) { + schema, err := df.Schema(ctx) + if err != nil { + return nil, err + } + dtypes := make([]DataTypeInfo, len(schema.Fields)) + for i, field := range schema.Fields { + dtypes[i].Name = field.Name + dtypes[i].Type = field.DataType.TypeName() + } + return dtypes, nil +}