Commit e350cfd

format

1 parent 7b51d3a
3 files changed: +448 -2 lines changed

asv_bench/benchmarks/groupby.py

Lines changed: 84 additions & 0 deletions
@@ -1152,4 +1152,88 @@ def time_resample_multiindex(self):
         ).mean()
 
 
+class GroupByAggregateArrowDtypes:
+    param_names = ["dtype", "method"]
+    params = [
+        [
+            "bool[pyarrow]",
+            "int32[pyarrow]",
+            "float64[pyarrow]",
+            "decimal128(25, 3)[pyarrow]",
+            "string[pyarrow]",
+            "timestamp[s, tz=UTC][pyarrow]",
+            "duration[ms][pyarrow]",
+        ],
+        [
+            "any",
+            "all",
+            "sum",
+            "prod",
+            "min",
+            "max",
+            "mean",
+            "std",
+            "var",
+            "sem",
+            "count",
+            "median",
+        ],
+    ]
+
+    def setup(self, dtype, method):
+        import pyarrow as pa
+
+        # Parse dtype string
+        if dtype.startswith("decimal128"):
+            pa_type = pa.decimal128(25, 3)
+        elif dtype.startswith("timestamp"):
+            pa_type = pa.timestamp("s", "UTC")
+        elif dtype.startswith("duration"):
+            pa_type = pa.duration("ms")
+        elif dtype == "bool[pyarrow]":
+            pa_type = pa.bool_()
+        elif dtype == "int32[pyarrow]":
+            pa_type = pa.int32()
+        elif dtype == "float64[pyarrow]":
+            pa_type = pa.float64()
+        elif dtype == "string[pyarrow]":
+            pa_type = pa.string()
+        else:
+            raise ValueError(f"Unsupported dtype: {dtype}")
+
+        size = 100_000
+        ncols = 5
+        columns = list("abcde")
+
+        # Generate data based on type
+        if pa.types.is_floating(pa_type):
+            data = np.random.randn(size, ncols)
+        elif pa.types.is_integer(pa_type):
+            data = np.random.randint(0, 10_000, (size, ncols))
+        elif pa.types.is_decimal(pa_type):
+            from decimal import Decimal
+
+            data = np.random.randn(size, ncols).round(3)
+            data = [[Decimal(str(x)) for x in row] for row in data]
+        elif pa.types.is_boolean(pa_type):
+            data = np.random.choice([True, False], (size, ncols))
+        elif pa.types.is_timestamp(pa_type):
+            data = np.random.randint(0, 1_000_000, (size, ncols))
+        elif pa.types.is_duration(pa_type):
+            data = np.random.randint(0, 1_000_000, (size, ncols))
+        elif pa.types.is_string(pa_type):
+            data = np.random.choice(list(ascii_letters), (size, ncols))
+        else:
+            raise ValueError(f"Unsupported pyarrow type: {pa_type}")
+
+        df = DataFrame(data, columns=columns, dtype=dtype)
+        # Add some NAs
+        df.iloc[::10, ::2] = NA
+        df["key"] = np.random.randint(0, 100, size)
+        self.df = df
+
+    def time_frame_agg(self, dtype, method):
+        self.df.groupby("key").agg(method)
+
+
 from .pandas_vb_common import setup  # noqa: F401 isort:skip
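
For a quick sanity check outside the asv harness, the workload this benchmark measures can be reproduced directly. The following is a minimal standalone sketch, assuming a pandas build with PyArrow available; it times a single (dtype, method) pair from the parameter grid above (string[pyarrow] with "count"), using time.perf_counter rather than asv's timing machinery.

    from string import ascii_letters
    import time

    import numpy as np
    import pandas as pd

    size = 100_000
    df = pd.DataFrame(
        np.random.choice(list(ascii_letters), (size, 5)),
        columns=list("abcde"),
        dtype="string[pyarrow]",
    )
    df.iloc[::10, ::2] = pd.NA                   # sprinkle NAs, as setup() does
    df["key"] = np.random.randint(0, 100, size)  # ~100 groups, as in setup()

    start = time.perf_counter()
    df.groupby("key").agg("count")
    print(f"count over {size:_} rows: {time.perf_counter() - start:.4f}s")

Under asv itself the class can be selected with a benchmark filter, e.g. asv run --bench GroupByAggregateArrowDtypes; the exact invocation depends on the local asv configuration.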

pandas/core/arrays/arrow/array.py

Lines changed: 157 additions & 2 deletions
@@ -2601,6 +2601,143 @@ def _to_masked(self):
         arr = self.to_numpy(dtype=dtype.numpy_dtype, na_value=na_value)
         return dtype.construct_array_type()(arr, mask)
 
+    # Mapping from pandas groupby 'how' to PyArrow aggregation function names
+    _PYARROW_AGG_FUNCS: dict[str, str] = {
+        "sum": "sum",
+        "prod": "product",
+        "min": "min",
+        "max": "max",
+        "mean": "mean",
+        "std": "stddev",
+        "var": "variance",
+        "sem": "stddev",  # sem = stddev / sqrt(count), computed below
+        "count": "count",
+        "any": "any",
+        "all": "all",
+    }
+
+    # Default values for missing groups (identity elements for each operation)
+    _PYARROW_AGG_DEFAULTS: dict[str, int | bool] = {
+        "sum": 0,
+        "prod": 1,
+        "count": 0,
+        "any": False,
+        "all": True,
+    }
+
+    def _groupby_op_pyarrow(
+        self,
+        *,
+        how: str,
+        min_count: int,
+        ngroups: int,
+        ids: npt.NDArray[np.intp],
+        **kwargs,
+    ) -> Self | None:
+        """
+        Perform groupby aggregation using PyArrow's native Table.group_by.
+
+        Returns None if the operation is not supported by PyArrow,
+        in which case the caller should fall back to the Cython path.
+        """
+        pa_agg_func = self._PYARROW_AGG_FUNCS.get(how)
+        if pa_agg_func is None:
+            return None
+
+        pa_type = self._pa_array.type
+        # PyArrow doesn't support these aggregations for temporal types directly
+        if pa.types.is_temporal(pa_type) and how in ["std", "var", "sem"]:
+            return None
+
+        # PyArrow's any/all only work on boolean types
+        if how in ["any", "all"] and not pa.types.is_boolean(pa_type):
+            return None
+
+        # Filter out the NA group (ids == -1) to avoid unnecessary computation
+        mask = ids >= 0
+        if not mask.all():
+            ids = ids[mask]
+            values = pc.filter(self._pa_array, mask)
+        else:
+            values = self._pa_array
+
+        # Create a PyArrow table with the values and group IDs.
+        # Explicitly cast ids to int64 since np.intp is platform-dependent.
+        group_id_arr = pa.array(ids, type=pa.int64())
+        table = pa.table({"value": values, "group_id": group_id_arr})
+
+        # Build the aggregation list - always include count for null handling.
+        # For std/var/sem, pass VarianceOptions with ddof to match pandas behavior.
+        if how in ["std", "var", "sem"]:
+            ddof = kwargs.get("ddof", 1)
+            agg_with_opts = ("value", pa_agg_func, pc.VarianceOptions(ddof=ddof))
+            aggs = [agg_with_opts, ("value", "count")]
+        else:
+            aggs = [("value", pa_agg_func), ("value", "count")]
+
+        # Perform the groupby aggregation
+        result_table = table.group_by("group_id").aggregate(aggs)
+
+        # Extract results
+        result_group_ids = result_table.column("group_id")
+        result_values = result_table.column(f"value_{pa_agg_func}")
+        result_counts = result_table.column("value_count")
+
+        # For sem, compute stddev / sqrt(count) using PyArrow compute
+        if how == "sem":
+            sqrt_counts = pc.sqrt(result_counts)
+            result_values = pc.divide(result_values, sqrt_counts)
+
+        output_type = result_values.type
+        default_value = pa.scalar(self._PYARROW_AGG_DEFAULTS.get(how), type=output_type)
+
+        # Handle nulls from all-null groups for sum/prod with min_count=0
+        if result_values.null_count > 0 and how in ["sum", "prod"] and min_count == 0:
+            result_values = pc.if_else(
+                pc.is_null(result_values), default_value, result_values
+            )
+
+        # Handle min_count: groups with count < min_count should be null
+        if min_count > 0:
+            below_min_count = pc.less(result_counts, pa.scalar(min_count))
+            result_values = pc.if_else(below_min_count, None, result_values)
+
+        # PyArrow returns results in encounter order. We need to reorder to
+        # match the expected output (group 0, 1, 2, ..., ngroups-1) and fill
+        # missing groups with default values.
+        #
+        # We use a NumPy scatter (O(n)) instead of:
+        # - pc.scatter: doesn't handle missing groups, workaround is slower
+        # - join+sort: O(n log n), slower for high-cardinality groupbys
+        #
+        # Explicitly cast to int64 so the ids are usable as NumPy indices
+        result_group_ids_np = result_group_ids.to_numpy(zero_copy_only=False).astype(
+            np.int64, copy=False
+        )
+        result_values_np = result_values.to_numpy(zero_copy_only=False)
+
+        default_py = default_value.as_py()
+        if default_py is not None and min_count == 0:
+            # Operations with identity elements (sum=0, prod=1, count=0, any=False,
+            # all=True): fill missing groups with the default value
+            output_np = np.full(ngroups, default_py, dtype=result_values_np.dtype)
+            output_np[result_group_ids_np] = result_values_np
+            pa_result = pa.array(output_np, type=output_type)
+        else:
+            # Operations without identity elements (mean, std, var, min, max, sem):
+            # fill missing groups with null using a boolean mask
+            output_np = np.empty(ngroups, dtype=result_values_np.dtype)
+            null_mask = np.ones(ngroups, dtype=bool)  # True = null/missing
+            output_np[result_group_ids_np] = result_values_np
+            null_mask[result_group_ids_np] = False
+            # Restore nulls for groups that had null results (min_count or all-null)
+            if result_values.null_count > 0:
+                result_nulls = pc.is_null(result_values).to_numpy()
+                null_mask[result_group_ids_np[result_nulls]] = True
+            pa_result = pa.array(output_np, type=output_type, mask=null_mask)
+
+        return self._from_pyarrow_array(pa_result)
+
     def _groupby_op(
         self,
         *,
@@ -2635,9 +2772,27 @@ def _groupby_op(
             **kwargs,
         )
 
-        # maybe convert to a compatible dtype optimized for groupby
-        values: ExtensionArray
         pa_type = self._pa_array.type
+
+        # Try the PyArrow-native path for decimal and string types, where it's faster.
+        # For integer/float/boolean, the fallback path via _to_masked() is faster.
+        if (
+            pa.types.is_decimal(pa_type)
+            or pa.types.is_string(pa_type)
+            or pa.types.is_large_string(pa_type)
+        ):
+            result = self._groupby_op_pyarrow(
+                how=how,
+                min_count=min_count,
+                ngroups=ngroups,
+                ids=ids,
+                **kwargs,
+            )
+            if result is not None:
+                return result
+
+        # Fall back to converting to a masked/datetime array and using Cython
+        values: ExtensionArray
         if pa.types.is_timestamp(pa_type):
             values = self._to_datetimearray()
         elif pa.types.is_duration(pa_type):