From 2f0444f329f79b86963d784cb96bd74562bbbcda Mon Sep 17 00:00:00 2001
From: Arun Sharma
Date: Mon, 27 Apr 2026 17:17:50 -0700
Subject: [PATCH] Add dedicated Arrow CSR result type

---
 src_cpp/include/py_connection.h   |  2 +
 src_cpp/include/py_query_result.h |  1 +
 src_cpp/py_connection.cpp         | 10 +++++
 src_cpp/py_query_result.cpp       | 55 ++++++++++++++++++++++++++
 src_py/__init__.py                |  4 +-
 src_py/_lbug_capi.py              |  5 +++
 src_py/connection.py              | 23 ++++++++++-
 src_py/query_result.py            | 54 ++++++++++++++++++++++++++
 test/test_arrow.py                | 64 +++++++++++++++++++++++++++++++
 9 files changed, 216 insertions(+), 2 deletions(-)

diff --git a/src_cpp/include/py_connection.h b/src_cpp/include/py_connection.h
index b05a087..2817f87 100644
--- a/src_cpp/include/py_connection.h
+++ b/src_cpp/include/py_connection.h
@@ -29,6 +29,8 @@ class PyConnection {
         const py::dict& params);
 
     std::unique_ptr<QueryResult> query(const std::string& statement);
+    std::unique_ptr<QueryResult> queryAsArrow(const std::string& statement,
+        int64_t chunkSize);
 
     void setMaxNumThreadForExec(uint64_t numThreads);
 
diff --git a/src_cpp/include/py_query_result.h b/src_cpp/include/py_query_result.h
index 4243bdf..dfec9ab 100644
--- a/src_cpp/include/py_query_result.h
+++ b/src_cpp/include/py_query_result.h
@@ -35,6 +35,7 @@ class PyQueryResult {
     py::object getAsDF();
 
    lbug::pyarrow::Table getAsArrow(std::int64_t chunkSize,
        bool fallbackExtensionTypes);
+    py::dict getCSR();
 
     py::list getColumnDataTypes();
diff --git a/src_cpp/py_connection.cpp b/src_cpp/py_connection.cpp
index f8a7139..1abf5ea 100644
--- a/src_cpp/py_connection.cpp
+++ b/src_cpp/py_connection.cpp
@@ -31,6 +31,8 @@ void PyConnection::initialize(py::handle& m) {
         .def("execute", &PyConnection::execute, py::arg("prepared_statement"),
             py::arg("parameters") = py::dict())
         .def("query", &PyConnection::query, py::arg("statement"))
+        .def("query_as_arrow", &PyConnection::queryAsArrow, py::arg("statement"),
+            py::arg("chunk_size"))
         .def("set_max_threads_for_exec", &PyConnection::setMaxNumThreadForExec,
             py::arg("num_threads"))
         .def("prepare", &PyConnection::prepare, py::arg("query"),
@@ -175,6 +177,14 @@ std::unique_ptr<QueryResult> PyConnection::query(const std::string& statement) {
     return checkAndWrapQueryResult(queryResult);
 }
 
+std::unique_ptr<QueryResult> PyConnection::queryAsArrow(const std::string& statement,
+    int64_t chunkSize) {
+    py::gil_scoped_release release;
+    auto queryResult = conn->queryAsArrow(statement, chunkSize);
+    py::gil_scoped_acquire acquire;
+    return checkAndWrapQueryResult(queryResult);
+}
+
 void PyConnection::setMaxNumThreadForExec(uint64_t numThreads) {
     conn->setMaxNumThreadForExec(numThreads);
 }
diff --git a/src_cpp/py_query_result.cpp b/src_cpp/py_query_result.cpp
index b3e0db8..01d65de 100644
--- a/src_cpp/py_query_result.cpp
+++ b/src_cpp/py_query_result.cpp
@@ -7,12 +7,14 @@
 #include "common/arrow/arrow_row_batch.h"
 #include "common/constants.h"
 #include "common/exception/not_implemented.h"
+#include "common/exception/runtime.h"
 #include "common/types/uuid.h"
 #include "common/types/value/nested.h"
 #include "common/types/value/node.h"
 #include "common/types/value/rel.h"
 #include "datetime.h" // python lib
 #include "include/py_query_result_converter.h"
+#include "main/query_result/arrow_query_result.h"
 
 using namespace lbug::common;
 using lbug::importCache;
@@ -30,6 +32,7 @@ void PyQueryResult::initialize(py::handle& m) {
         .def("close", &PyQueryResult::close)
         .def("getAsDF", &PyQueryResult::getAsDF)
         .def("getAsArrow", &PyQueryResult::getAsArrow)
+        .def("getCSR", &PyQueryResult::getCSR)
         .def("getColumnNames", &PyQueryResult::getColumnNames)
         .def("getColumnDataTypes", &PyQueryResult::getColumnDataTypes)
         .def("resetIterator", &PyQueryResult::resetIterator)
@@ -85,6 +88,30 @@ void PyQueryResult::close() {
     }
 }
 
+namespace {
+
+py::array_t<int64_t> copyToNumpyArray(const std::vector<int64_t>& values) {
+    auto result = py::array_t<int64_t>(values.size());
+    auto* data = static_cast<int64_t*>(result.request().ptr);
+    std::copy(values.begin(), values.end(), data);
+    return result;
+}
+
+py::dict buildCSRResult(std::vector<int64_t> indptr, std::vector<int64_t> indices,
+    std::vector<int64_t> edgeIDs, bool includeEdgeIDs) {
+    py::dict result;
+    result["indptr"] = copyToNumpyArray(indptr);
+    result["indices"] = copyToNumpyArray(indices);
+    if (includeEdgeIDs) {
+        result["edge_ids"] = copyToNumpyArray(edgeIDs);
+    } else {
+        result["edge_ids"] = py::none();
+    }
+    return result;
+}
+
+} // namespace
+
 static py::object converTimestampToPyObject(timestamp_t& timestamp) {
     int32_t year = 0, month = 0, day = 0, hour = 0, min = 0, sec = 0, micros = 0;
     date_t date;
@@ -320,6 +347,23 @@ py::object PyQueryResult::getArrowChunks(const std::vector<LogicalType>& types,
 
 lbug::pyarrow::Table PyQueryResult::getAsArrow(std::int64_t chunkSize,
     bool fallbackExtensionTypes) {
+    if (queryResult->getType() == QueryResultType::ARROW) {
+        auto types = queryResult->getColumnDataTypes();
+        auto names = queryResult->getColumnNames();
+        py::list batches;
+        auto batchImportFunc = importCache->pyarrow.lib.RecordBatch._import_from_c();
+        while (queryResult->hasNextArrowChunk()) {
+            auto data = queryResult->getNextArrowChunk(chunkSize);
+            auto schema = ArrowConverter::toArrowSchema(types, names, fallbackExtensionTypes);
+            batches.append(
+                batchImportFunc((std::uint64_t)data.get(), (std::uint64_t)schema.get()));
+        }
+        auto schema = ArrowConverter::toArrowSchema(types, names, fallbackExtensionTypes);
+        auto fromBatchesFunc = importCache->pyarrow.lib.Table.from_batches();
+        auto schemaImportFunc = importCache->pyarrow.lib.Schema._import_from_c();
+        auto schemaObj = schemaImportFunc((std::uint64_t)schema.get());
+        return py::cast(fromBatchesFunc(batches, schemaObj));
+    }
     auto types = queryResult->getColumnDataTypes();
     auto names = queryResult->getColumnNames();
     py::list batches = getArrowChunks(types, names, chunkSize, fallbackExtensionTypes);
@@ -330,6 +374,17 @@ lbug::pyarrow::Table PyQueryResult::getAsArrow(std::int64_t chunkSize,
     return py::cast(fromBatchesFunc(batches, schemaObj));
 }
 
+py::dict PyQueryResult::getCSR() {
+    if (auto* arrowQueryResult = dynamic_cast<ArrowQueryResult*>(queryResult);
+        arrowQueryResult != nullptr && arrowQueryResult->hasCSRMetadata()) {
+        const auto& metadata = arrowQueryResult->getCSRMetadata();
+        return buildCSRResult(metadata.indptr, metadata.indices, metadata.edgeIDs,
+            metadata.hasEdgeIDs);
+    }
+    throw RuntimeException(
+        "CSR export is only supported for Arrow query results with native CSR metadata.");
+}
+
 py::list PyQueryResult::getColumnDataTypes() {
     auto columnDataTypes = queryResult->getColumnDataTypes();
     py::tuple result(columnDataTypes.size());
diff --git a/src_py/__init__.py b/src_py/__init__.py
index 6a60db1..782cbda 100644
--- a/src_py/__init__.py
+++ b/src_py/__init__.py
@@ -56,7 +56,7 @@
 from .connection import Connection  # noqa: E402
 from .database import Database  # noqa: E402
 from .prepared_statement import PreparedStatement  # noqa: E402
-from .query_result import QueryResult  # noqa: E402
+from .query_result import ArrowQueryResult, CSRResult, QueryResult  # noqa: E402
 from .types import Type  # noqa: E402
 
 _VERSION_INFO: tuple[str, int] | None = None
@@ -80,7 +80,9 @@ def __getattr__(name: str) -> str | int:
 
 __all__ = [
     "AsyncConnection",
+    "ArrowQueryResult",
     "Connection",
+    "CSRResult",
     "Database",
     "PreparedStatement",
     "QueryResult",
diff --git a/src_py/_lbug_capi.py b/src_py/_lbug_capi.py
index 75e4f80..d7fffd8 100644
--- a/src_py/_lbug_capi.py
+++ b/src_py/_lbug_capi.py
@@ -1229,6 +1229,11 @@ def getAsArrow(self, *_args: Any, **_kwargs: Any) -> Any:
         raise NotImplementedError(
             "Arrow export is not yet implemented in C-API backend"
         )
 
+    def getCSR(self, *_args: Any, **_kwargs: Any) -> Any:
+        raise NotImplementedError(
+            "CSR export is not yet implemented in C-API backend"
+        )
+
     def getAsDF(self) -> Any:
         raise NotImplementedError(
diff --git a/src_py/connection.py b/src_py/connection.py
index 6f47b3a..0fad143 100644
--- a/src_py/connection.py
+++ b/src_py/connection.py
@@ -8,7 +8,7 @@
 
 from ._backend import get_capi_module, get_pybind_module
 from .prepared_statement import PreparedStatement
-from .query_result import QueryResult
+from .query_result import ArrowQueryResult, QueryResult
 
 if TYPE_CHECKING:
     import sys
@@ -369,6 +369,27 @@
             all_query_results.append(next_query_result)
         return all_query_results
 
+    def query_as_arrow(self, query: str, chunk_size: int) -> ArrowQueryResult:
+        """
+        Execute a query with the native Arrow collector path.
+
+        This is the efficient path for CSR-aware Arrow export.
+        """
+        self.init_connection()
+        if not self._using_pybind_backend():
+            msg = "query_as_arrow requires the pybind backend"
+            raise NotImplementedError(msg)
+        query_result_internal = self._get_pybind_connection().query_as_arrow(
+            query, chunk_size
+        )
+        if not query_result_internal.isSuccess():
+            raise RuntimeError(query_result_internal.getErrorMessage())
+        current_query_result = ArrowQueryResult(
+            self, query_result_internal, native_chunk_size=chunk_size
+        )
+        self._register_query_result(current_query_result)
+        return current_query_result
+
     def _prepare(
         self,
         query: str,
diff --git a/src_py/query_result.py b/src_py/query_result.py
index 12cd8d6..b3372c4 100644
--- a/src_py/query_result.py
+++ b/src_py/query_result.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from dataclasses import dataclass
 from typing import TYPE_CHECKING
 
 from .constants import DST, ID, LABEL, NODES, RELS, SRC
@@ -525,6 +526,59 @@ def rows_as_dict(self, state=True) -> Self:
         return self
 
 
+class ArrowQueryResult(QueryResult):
+    """QueryResult backed by the native Arrow collector path."""
+
+    def __init__(
+        self, connection: Any, query_result: Any, native_chunk_size: int
+    ) -> None:
+        super().__init__(connection, query_result)
+        self._native_chunk_size = native_chunk_size
+
+    def get_as_arrow(
+        self, chunk_size: int | None = None, *, fallbackExtensionTypes: bool = False
+    ) -> pa.Table:
+        """
+        Get the query result as a PyArrow Table.
+
+        Arrow-native results preserve the execution-time chunking chosen by
+        `Connection.query_as_arrow(...)`. Requesting `None`, `0`, or `-1`
+        reuses that native chunk size instead of rechunking the result.
+        """
+        if chunk_size is None or chunk_size <= 0:
+            chunk_size = self._native_chunk_size
+        return super().get_as_arrow(
+            chunk_size, fallbackExtensionTypes=fallbackExtensionTypes
+        )
+
+    def csr(self) -> CSRResult:
+        """
+        Get native CSR arrays from an Arrow query result.
+
+        This is available only for Arrow results with CSR metadata, typically
+        from `Connection.query_as_arrow(...)` on relationship-shaped projections.
+        """
+        self.check_for_query_result_close()
+
+        import pyarrow as pa
+
+        csr = self._query_result.getCSR()
+        return CSRResult(
+            indptr=pa.array(csr["indptr"]),
+            indices=pa.array(csr["indices"]),
+            edge_ids=(
+                None if csr["edge_ids"] is None else pa.array(csr["edge_ids"])
+            ),
+        )
+
+
+@dataclass(frozen=True)
+class CSRResult:
+    indptr: pa.Array
+    indices: pa.Array
+    edge_ids: pa.Array | None = None
+
+
 def _row_to_dict(columns: list[str], row: list[Any]) -> dict[str, Any]:
     if len(columns) != len(row):
         msg = "Number of columns in output row does not match number of columns"
diff --git a/test/test_arrow.py b/test/test_arrow.py
index 72c7af2..40f0b50 100644
--- a/test/test_arrow.py
+++ b/test/test_arrow.py
@@ -772,3 +772,67 @@ def test_to_arrow1(conn: lb.Connection) -> None:
         -1
     )  # what is a chunk size of -1 even supposed to mean?
     assert arrow_tbl == []
+
+
+def test_query_as_arrow_csr_with_rel_ids(conn_db_readonly: ConnDB) -> None:
+    conn, _ = conn_db_readonly
+    query = """
+    MATCH (a:person)-[b:knows]->(c:person)
+    RETURN a.rowid, b.rowid, c.rowid
+    """
+    rows = conn.execute(query).get_all()
+    csr = conn.query_as_arrow(query, 8).csr()
+
+    assert csr.edge_ids is not None
+
+    reconstructed = []
+    indptr = csr.indptr.to_pylist()
+    indices = csr.indices.to_pylist()
+    edge_ids = csr.edge_ids.to_pylist()
+    for src_rowid in range(len(indptr) - 1):
+        for idx in range(indptr[src_rowid], indptr[src_rowid + 1]):
+            reconstructed.append([src_rowid, edge_ids[idx], indices[idx]])
+
+    assert reconstructed == rows
+
+
+def test_query_as_arrow_csr_with_extra_columns(conn_db_readonly: ConnDB) -> None:
+    conn, _ = conn_db_readonly
+    query = """
+    MATCH (a:person)-[b:knows]->(c:person)
+    RETURN a.rowid, b.rowid, c.rowid, b.date, c.fName
+    """
+    result = conn.query_as_arrow(query, 8)
+    csr = result.csr()
+    arrow_tbl = result.get_as_arrow(0)
+
+    assert csr.edge_ids is not None
+    assert arrow_tbl.column_names == ["a.rowid", "b.rowid", "c.rowid", "b.date", "c.fName"]
+    assert len(csr.indptr) >= 2
+
+
+def test_query_as_arrow_csr_without_rel_ids(conn_db_readonly: ConnDB) -> None:
+    conn, _ = conn_db_readonly
+    query = """
+    MATCH (a:person)-[:knows]->(c:person)
+    RETURN a.rowid, c.rowid
+    """
+    rows = conn.execute(query).get_all()
+    csr = conn.query_as_arrow(query, 8).csr()
+
+    assert csr.edge_ids is None
+
+    reconstructed = []
+    indptr = csr.indptr.to_pylist()
+    indices = csr.indices.to_pylist()
+    for src_rowid in range(len(indptr) - 1):
+        for idx in range(indptr[src_rowid], indptr[src_rowid + 1]):
+            reconstructed.append([src_rowid, indices[idx]])
+
+    assert reconstructed == rows
+
+
+def test_query_as_arrow_csr_rejects_non_csr_shape(conn_db_readonly: ConnDB) -> None:
+    conn, _ = conn_db_readonly
+    with pytest.raises(RuntimeError, match="CSR export is only supported"):
+        conn.query_as_arrow("MATCH (a:person) RETURN a.fName", 8).csr()