
Commit 79892b3

improve comments and benchmark, turn on arrow path for int and float

1 parent ec9659c

2 files changed: +43 -102 lines

asv_bench/benchmarks/groupby.py (24 additions & 63 deletions)

@@ -6,6 +6,7 @@
 
 from pandas import (
     NA,
+    ArrowDtype,
     Categorical,
     DataFrame,
     Index,
@@ -1156,84 +1157,44 @@ class GroupByAggregateArrowDtypes:
     param_names = ["dtype", "method"]
     params = [
         [
-            "bool[pyarrow]",
             "int32[pyarrow]",
+            "int64[pyarrow]",
+            "float32[pyarrow]",
             "float64[pyarrow]",
-            "decimal128(25, 3)[pyarrow]",
+            "decimal128",
             "string[pyarrow]",
-            "timestamp[s, tz=UTC][pyarrow]",
-            "duration[ms][pyarrow]",
-        ],
-        [
-            "any",
-            "all",
-            "sum",
-            "prod",
-            "min",
-            "max",
-            "mean",
-            "std",
-            "var",
-            "sem",
-            "count",
-            "median",
         ],
+        ["sum", "prod", "min", "max", "mean", "std", "var", "count"],
     ]
 
     def setup(self, dtype, method):
        import pyarrow as pa
 
-        # Parse dtype string
-        if dtype.startswith("decimal128"):
-            pa_type = pa.decimal128(25, 3)
-        elif dtype.startswith("timestamp"):
-            pa_type = pa.timestamp("s", "UTC")
-        elif dtype.startswith("duration"):
-            pa_type = pa.duration("ms")
-        elif dtype == "bool[pyarrow]":
-            pa_type = pa.bool_()
-        elif dtype == "int32[pyarrow]":
-            pa_type = pa.int32()
-        elif dtype == "float64[pyarrow]":
-            pa_type = pa.float64()
-        elif dtype == "string[pyarrow]":
-            pa_type = pa.string()
-        else:
-            raise ValueError(f"Unsupported dtype: {dtype}")
+        from pandas.api.types import is_string_dtype
 
        size = 100_000
-        ncols = 5
-        columns = list("abcde")
-
-        # Generate data based on type
-        if pa.types.is_floating(pa_type):
-            data = np.random.randn(size, ncols)
-        elif pa.types.is_integer(pa_type):
-            data = np.random.randint(0, 10_000, (size, ncols))
-        elif pa.types.is_decimal(pa_type):
+        ngroups = 1000
+
+        if dtype in ("int32[pyarrow]", "int64[pyarrow]"):
+            data = np.random.randint(0, 10_000, size)
+        elif dtype in ("float32[pyarrow]", "float64[pyarrow]"):
+            data = np.random.randn(size)
+        elif dtype == "decimal128":
            from decimal import Decimal
 
-            data = np.random.randn(size, ncols).round(3)
-            data = [[Decimal(str(x)) for x in row] for row in data]
-        elif pa.types.is_boolean(pa_type):
-            data = np.random.choice([True, False], (size, ncols))
-        elif pa.types.is_timestamp(pa_type):
-            data = np.random.randint(0, 1_000_000, (size, ncols))
-        elif pa.types.is_duration(pa_type):
-            data = np.random.randint(0, 1_000_000, (size, ncols))
-        elif pa.types.is_string(pa_type):
-            data = np.random.choice(list(ascii_letters), (size, ncols))
-        else:
-            raise ValueError(f"Unsupported pyarrow type: {pa_type}")
+            data = [Decimal(str(round(x, 3))) for x in np.random.randn(size)]
+            dtype = ArrowDtype(pa.decimal128(10, 3))
+        elif dtype == "string[pyarrow]":
+            data = np.random.choice(list(ascii_letters), size)
 
-        df = DataFrame(data, columns=columns, dtype=dtype)
-        # Add some NAs
-        df.iloc[::10, ::2] = NA
-        df["key"] = np.random.randint(0, 100, size)
-        self.df = df
+        ser = Series(data, dtype=dtype)
+        if not is_string_dtype(ser.dtype):
+            ser.iloc[::10] = NA
+        self.ser = ser
+        self.key = np.random.randint(0, ngroups, size)
 
-    def time_frame_agg(self, dtype, method):
-        self.df.groupby("key").agg(method)
+    def time_series_agg(self, dtype, method):
+        self.ser.groupby(self.key).agg(method)
 
 
 from .pandas_vb_common import setup  # noqa: F401 isort:skip
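
For reference, the rewritten benchmark now times a Series-level aggregation rather than a frame-level one. The following standalone sketch reproduces what one (dtype, method) combination measures; the sizes and the NA pattern mirror setup() above, and it assumes pandas with pyarrow installed.

import numpy as np
import pandas as pd

size = 100_000
ngroups = 1000

# pyarrow-backed Series with NAs every 10th element, as in setup()
ser = pd.Series(np.random.randn(size), dtype="float64[pyarrow]")
ser.iloc[::10] = pd.NA
key = np.random.randint(0, ngroups, size)

# the operation time_series_agg() measures, here for method="mean"
result = ser.groupby(key).agg("mean")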

pandas/core/arrays/arrow/array.py (19 additions & 39 deletions)

@@ -2601,7 +2601,7 @@ def _to_masked(self):
         arr = self.to_numpy(dtype=dtype.numpy_dtype, na_value=na_value)
         return dtype.construct_array_type()(arr, mask)
 
-    # Mapping from pandas groupby 'how' to PyArrow aggregation function names
+    # pandas groupby 'how' -> PyArrow aggregation function name
     _PYARROW_AGG_FUNCS: dict[str, str] = {
         "sum": "sum",
         "prod": "product",
@@ -2610,13 +2610,13 @@ def _to_masked(self):
         "mean": "mean",
         "std": "stddev",
         "var": "variance",
-        "sem": "stddev",  # sem = stddev / sqrt(count), computed below
+        "sem": "stddev",  # sem = stddev / sqrt(count)
         "count": "count",
         "any": "any",
         "all": "all",
     }
 
-    # Default values for missing groups (identity elements for each operation)
+    # Identity elements for operations (used to fill missing groups)
     _PYARROW_AGG_DEFAULTS: dict[str, int | bool] = {
         "sum": 0,
         "prod": 1,
@@ -2637,100 +2637,78 @@ def _groupby_op_pyarrow(
         """
         Perform groupby aggregation using PyArrow's native Table.group_by.
 
-        Returns None if the operation is not supported by PyArrow,
-        in which case the caller should fall back to the Cython path.
+        Returns None if not supported, caller should fall back to Cython path.
         """
         pa_agg_func = self._PYARROW_AGG_FUNCS.get(how)
         if pa_agg_func is None:
             return None
 
         pa_type = self._pa_array.type
-        # PyArrow doesn't support these aggregations for temporal types directly
         if pa.types.is_temporal(pa_type) and how in ["std", "var", "sem"]:
             return None
-
-        # PyArrow's any/all only work on boolean types
         if how in ["any", "all"] and not pa.types.is_boolean(pa_type):
             return None
 
-        # Filter out NA group (ids == -1) to avoid unnecessary computation
+        # Filter out NA group (ids == -1)
         mask = ids >= 0
         if not mask.all():
             ids = ids[mask]
             values = pc.filter(self._pa_array, mask)
         else:
             values = self._pa_array
 
-        # Create a PyArrow table with the values and group IDs
-        # Explicitly cast ids to int64 since np.intp is platform-dependent
+        # Build table and run aggregation (cast ids to int64 for portability)
         group_id_arr = pa.array(ids, type=pa.int64())
         table = pa.table({"value": values, "group_id": group_id_arr})
 
-        # Build aggregation list - always include count for null handling
-        # For std/var/sem, pass VarianceOptions with ddof to match pandas behavior
         if how in ["std", "var", "sem"]:
             ddof = kwargs.get("ddof", 1)
-            agg_with_opts = ("value", pa_agg_func, pc.VarianceOptions(ddof=ddof))
-            aggs = [agg_with_opts, ("value", "count")]
+            aggs = [("value", pa_agg_func, pc.VarianceOptions(ddof=ddof))]
         else:
-            aggs = [("value", pa_agg_func), ("value", "count")]
+            aggs = [("value", pa_agg_func)]
+        aggs.append(("value", "count"))
 
-        # Perform the groupby aggregation
         result_table = table.group_by("group_id").aggregate(aggs)
-
-        # Extract results
         result_group_ids = result_table.column("group_id")
         result_values = result_table.column(f"value_{pa_agg_func}")
         result_counts = result_table.column("value_count")
 
-        # For sem, compute stddev / sqrt(count) using PyArrow compute
         if how == "sem":
-            sqrt_counts = pc.sqrt(result_counts)
-            result_values = pc.divide(result_values, sqrt_counts)
+            result_values = pc.divide(result_values, pc.sqrt(result_counts))
 
         output_type = result_values.type
         default_value = pa.scalar(self._PYARROW_AGG_DEFAULTS.get(how), type=output_type)
 
-        # Handle nulls from all-null groups for sum/prod with min_count=0
+        # Replace nulls from all-null groups with identity element
         if result_values.null_count > 0 and how in ["sum", "prod"] and min_count == 0:
             result_values = pc.if_else(
                 pc.is_null(result_values), default_value, result_values
             )
 
-        # Handle min_count: groups with count < min_count should be null
+        # Null out groups below min_count
         if min_count > 0:
             below_min_count = pc.less(result_counts, pa.scalar(min_count))
             result_values = pc.if_else(below_min_count, None, result_values)
 
-        # PyArrow returns results in encounter order. We need to reorder to
-        # match expected output (group 0, 1, 2, ..., ngroups-1) and fill
-        # missing groups with default values.
-        #
-        # We use NumPy scatter (O(n)) instead of:
-        # - pc.scatter: doesn't handle missing groups, workaround is slower
-        # - join+sort: O(n log n), slower for high-cardinality groupby
-        #
-        # Explicitly cast to int64 to ensure usable as NumPy indices
+        # Scatter results into output array ordered by group id.
+        # NumPy scatter is O(n) vs O(n log n) for join+sort or pc.scatter workaround.
         result_group_ids_np = result_group_ids.to_numpy(zero_copy_only=False).astype(
             np.int64, copy=False
         )
         result_values_np = result_values.to_numpy(zero_copy_only=False)
 
         default_py = default_value.as_py()
         if default_py is not None and min_count == 0:
-            # Operations with identity elements (sum=0, prod=1, count=0, any=False,
-            # all=True): fill missing groups with default value
+            # Fill missing groups with identity element
             output_np = np.full(ngroups, default_py, dtype=result_values_np.dtype)
             output_np[result_group_ids_np] = result_values_np
             pa_result = pa.array(output_np, type=output_type)
         else:
-            # Operations without identity elements (mean, std, var, min, max, sem):
-            # fill missing groups with null using a boolean mask
+            # Fill missing groups with null
             output_np = np.empty(ngroups, dtype=result_values_np.dtype)
-            null_mask = np.ones(ngroups, dtype=bool)  # True = null/missing
+            null_mask = np.ones(ngroups, dtype=bool)
             output_np[result_group_ids_np] = result_values_np
             null_mask[result_group_ids_np] = False
-            # Restore nulls for groups that had null results (min_count or all-null)
             if result_values.null_count > 0:
                 result_nulls = pc.is_null(result_values).to_numpy()
                 null_mask[result_group_ids_np[result_nulls]] = True
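
The scatter at the end of this hunk is the part worth illustrating: PyArrow returns one row per group in encounter order, and NumPy fancy-index assignment places each row at its group id in O(n). A toy sketch with made-up values, reusing the variable names from the hunk:

import numpy as np

ngroups = 5
result_group_ids_np = np.array([3, 0, 2])    # encounter order from PyArrow
result_values_np = np.array([7.0, 1.5, 9.0])

output_np = np.empty(ngroups, dtype=result_values_np.dtype)
null_mask = np.ones(ngroups, dtype=bool)     # True marks missing groups
output_np[result_group_ids_np] = result_values_np
null_mask[result_group_ids_np] = False
# groups 1 and 4 never appeared, so null_mask leaves them flagged as NA
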
@@ -2780,6 +2758,8 @@ def _groupby_op(
             pa.types.is_decimal(pa_type)
             or pa.types.is_string(pa_type)
             or pa.types.is_large_string(pa_type)
+            or pa.types.is_integer(pa_type)  # TEMPORARY: for testing
+            or pa.types.is_floating(pa_type)  # TEMPORARY: for testing
         ):
             result = self._groupby_op_pyarrow(
                 how=how,
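
With integer and floating types added to the gate (marked TEMPORARY above), a pyarrow-backed int or float groupby should now take the PyArrow path. A quick smoke test, assuming a pandas build that includes this commit plus pyarrow:

import numpy as np
import pandas as pd

ser = pd.Series(np.arange(10), dtype="int64[pyarrow]")
print(ser.groupby(np.repeat([0, 1], 5)).sum())  # exercises _groupby_op_pyarrow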
