Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import datafusion as dfn
import numpy as np
import pyarrow as pa
import pytest
from datafusion import col, lit
from datafusion import functions as F
Expand All @@ -29,6 +30,8 @@ def _doctest_namespace(doctest_namespace: dict) -> None:
"""Add common imports to the doctest namespace."""
doctest_namespace["dfn"] = dfn
doctest_namespace["np"] = np
doctest_namespace["pa"] = pa
doctest_namespace["col"] = col
doctest_namespace["lit"] = lit
doctest_namespace["F"] = F
doctest_namespace["ctx"] = dfn.SessionContext()
58 changes: 57 additions & 1 deletion crates/core/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ use datafusion::execution::context::{
};
use datafusion::execution::disk_manager::DiskManagerMode;
use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool};
use datafusion::execution::options::ReadOptions;
use datafusion::execution::options::{ArrowReadOptions, ReadOptions};
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::prelude::{
Expand Down Expand Up @@ -956,6 +956,39 @@ impl PySessionContext {
Ok(())
}

#[pyo3(signature = (name, path, schema=None, file_extension=".arrow", table_partition_cols=vec![]))]
pub fn register_arrow(
&self,
name: &str,
path: &str,
schema: Option<PyArrowType<Schema>>,
file_extension: &str,
table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
py: Python,
) -> PyDataFusionResult<()> {
let mut options = ArrowReadOptions::default().table_partition_cols(
table_partition_cols
.into_iter()
.map(|(name, ty)| (name, ty.0))
.collect::<Vec<(String, DataType)>>(),
);
options.file_extension = file_extension;
options.schema = schema.as_ref().map(|x| &x.0);

let result = self.ctx.register_arrow(name, path, options);
wait_for_future(py, result)??;
Ok(())
}

pub fn register_batch(
&self,
name: &str,
batch: PyArrowType<RecordBatch>,
) -> PyDataFusionResult<()> {
self.ctx.register_batch(name, batch.0)?;
Ok(())
}

// Registers a PyArrow.Dataset
pub fn register_dataset(
&self,
Expand Down Expand Up @@ -1184,6 +1217,29 @@ impl PySessionContext {
Ok(PyDataFrame::new(df))
}

#[pyo3(signature = (path, schema=None, file_extension=".arrow", table_partition_cols=vec![]))]
pub fn read_arrow(
&self,
path: &str,
schema: Option<PyArrowType<Schema>>,
file_extension: &str,
table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
py: Python,
) -> PyDataFusionResult<PyDataFrame> {
let mut options = ArrowReadOptions::default().table_partition_cols(
table_partition_cols
.into_iter()
.map(|(name, ty)| (name, ty.0))
.collect::<Vec<(String, DataType)>>(),
);
options.file_extension = file_extension;
options.schema = schema.as_ref().map(|x| &x.0);

let result = self.ctx.read_arrow(path, options);
let df = wait_for_future(py, result)??;
Ok(PyDataFrame::new(df))
}

pub fn read_table(&self, table: Bound<'_, PyAny>) -> PyDataFusionResult<PyDataFrame> {
let session = self.clone().into_bound_py_any(table.py())?;
let table = PyTable::new(table, Some(session))?;
Expand Down
122 changes: 122 additions & 0 deletions python/datafusion/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,26 @@ def register_udtf(self, func: TableFunction) -> None:
"""Register a user defined table function."""
self.ctx.register_udtf(func._udtf)

def register_batch(self, name: str, batch: pa.RecordBatch) -> None:
    """Expose a single :py:class:`pa.RecordBatch` as a named table.

    Args:
        name: Name under which the batch becomes queryable.
        batch: The record batch whose contents back the table.

    Examples:
        >>> batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3]})
        >>> ctx.register_batch("batch_tbl", batch)
        >>> ctx.sql("SELECT * FROM batch_tbl").collect()[0].column(0)
        <pyarrow.lib.Int64Array object at ...>
        [
          1,
          2,
          3
        ]
    """
    self.ctx.register_batch(name, batch)

def register_record_batches(
self, name: str, partitions: list[list[pa.RecordBatch]]
) -> None:
Expand Down Expand Up @@ -1092,6 +1112,49 @@ def register_avro(
name, str(path), schema, file_extension, table_partition_cols
)

def register_arrow(
    self,
    name: str,
    path: str | pathlib.Path,
    schema: pa.Schema | None = None,
    file_extension: str = ".arrow",
    table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
) -> None:
    """Register the Arrow IPC file at ``path`` as a SQL-queryable table.

    Once registered, the table may be referenced by name in any SQL
    statement executed against this context.

    Args:
        name: Table name to register under.
        path: Location of the Arrow IPC file.
        schema: Optional explicit schema for the data source.
        file_extension: Extension used to select files.
        table_partition_cols: Partition columns.

    Examples:
        >>> import tempfile, os
        >>> table = pa.table({"x": [10, 20, 30]})
        >>> with tempfile.TemporaryDirectory() as tmpdir:
        ...     path = os.path.join(tmpdir, "data.arrow")
        ...     with pa.ipc.new_file(path, table.schema) as writer:
        ...         writer.write_table(table)
        ...     ctx.register_arrow("arrow_tbl", path)
        ...     ctx.sql("SELECT * FROM arrow_tbl").collect()[0].column(0)
        <pyarrow.lib.Int64Array object at ...>
        [
          10,
          20,
          30
        ]
    """
    # `None` means "no partition columns"; normalize before conversion.
    cols = _convert_table_partition_cols(table_partition_cols or [])
    self.ctx.register_arrow(name, str(path), schema, file_extension, cols)

def register_dataset(self, name: str, dataset: pa.dataset.Dataset) -> None:
"""Register a :py:class:`pa.dataset.Dataset` as a table.

Expand Down Expand Up @@ -1328,6 +1391,65 @@ def read_avro(
self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension)
)

def read_arrow(
    self,
    path: str | pathlib.Path,
    schema: pa.Schema | None = None,
    file_extension: str = ".arrow",
    file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
) -> DataFrame:
    """Build a :py:class:`DataFrame` backed by an Arrow IPC data source.

    Args:
        path: Location of the Arrow IPC file.
        schema: Optional explicit schema for the data source.
        file_extension: Extension used to select files.
        file_partition_cols: Partition columns.

    Returns:
        DataFrame over the contents of the Arrow IPC file.

    Examples:
        >>> import tempfile, os
        >>> table = pa.table({"a": [1, 2, 3]})
        >>> with tempfile.TemporaryDirectory() as tmpdir:
        ...     path = os.path.join(tmpdir, "data.arrow")
        ...     with pa.ipc.new_file(path, table.schema) as writer:
        ...         writer.write_table(table)
        ...     df = ctx.read_arrow(path)
        ...     df.collect()[0].column(0)
        <pyarrow.lib.Int64Array object at ...>
        [
          1,
          2,
          3
        ]
    """
    # `None` means "no partition columns"; normalize before conversion.
    cols = _convert_table_partition_cols(file_partition_cols or [])
    return DataFrame(self.ctx.read_arrow(str(path), schema, file_extension, cols))

def read_empty(self) -> DataFrame:
    """Return an empty :py:class:`DataFrame` — an alias of :meth:`empty_table`.

    Returns:
        A DataFrame with no columns and no rows.

    Examples:
        >>> df = ctx.read_empty()
        >>> result = df.collect()
        >>> len(result)
        1
        >>> result[0].num_columns
        0
    """
    # Delegate entirely; the two entry points are interchangeable.
    return self.empty_table()

def read_table(
self, table: Table | TableProviderExportable | DataFrame | pa.dataset.Dataset
) -> DataFrame:
Expand Down
62 changes: 62 additions & 0 deletions python/tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,68 @@ def test_read_avro(ctx):
assert avro_df is not None


def test_read_arrow(ctx, tmp_path):
    # Persist a table as an Arrow IPC file, then load it back via the context.
    source = pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]})
    ipc_file = tmp_path / "test.arrow"
    with pa.ipc.new_file(str(ipc_file), source.schema) as writer:
        writer.write_table(source)

    batches = ctx.read_arrow(str(ipc_file)).collect()
    assert batches[0].column(0) == pa.array([1, 2, 3])
    assert batches[0].column(1) == pa.array(["x", "y", "z"])

    # A pathlib.Path argument must work as well as a plain string.
    batches = ctx.read_arrow(ipc_file).collect()
    assert batches[0].column(0) == pa.array([1, 2, 3])


def test_read_empty(ctx):
    # Both spellings should yield a single batch with zero columns.
    for frame in (ctx.read_empty(), ctx.empty_table()):
        batches = frame.collect()
        assert len(batches) == 1
        assert batches[0].num_columns == 0


def test_register_arrow(ctx, tmp_path):
    # Persist a table as Arrow IPC, register it, then query it via SQL.
    source = pa.table({"x": [10, 20, 30]})
    ipc_file = tmp_path / "test.arrow"
    with pa.ipc.new_file(str(ipc_file), source.schema) as writer:
        writer.write_table(source)

    ctx.register_arrow("arrow_tbl", str(ipc_file))
    batches = ctx.sql("SELECT * FROM arrow_tbl").collect()
    assert batches[0].column(0) == pa.array([10, 20, 30])

    # Registration must also accept a pathlib.Path.
    ctx.register_arrow("arrow_tbl_path", ipc_file)
    batches = ctx.sql("SELECT * FROM arrow_tbl_path").collect()
    assert batches[0].column(0) == pa.array([10, 20, 30])


def test_register_batch(ctx):
    # A registered RecordBatch is queryable like any other table.
    data = {"a": [1, 2, 3], "b": [4, 5, 6]}
    ctx.register_batch("batch_tbl", pa.RecordBatch.from_pydict(data))
    batches = ctx.sql("SELECT * FROM batch_tbl").collect()
    assert batches[0].column(0) == pa.array(data["a"])
    assert batches[0].column(1) == pa.array(data["b"])


def test_register_batch_empty(ctx):
    # Registering a zero-row batch should succeed and produce zero rows.
    empty = pa.RecordBatch.from_pydict({"a": pa.array([], type=pa.int64())})
    ctx.register_batch("empty_batch_tbl", empty)
    batches = ctx.sql("SELECT * FROM empty_batch_tbl").collect()
    assert batches[0].num_rows == 0


def test_create_sql_options():
    # Smoke test: constructing SQLOptions with defaults must not raise.
    SQLOptions()

Expand Down
Loading