@@ -2601,7 +2601,7 @@ def _to_masked(self):
         arr = self.to_numpy(dtype=dtype.numpy_dtype, na_value=na_value)
         return dtype.construct_array_type()(arr, mask)
 
-    # Mapping from pandas groupby 'how' to PyArrow aggregation function names
+    # pandas groupby 'how' -> PyArrow aggregation function name
     _PYARROW_AGG_FUNCS: dict[str, str] = {
         "sum": "sum",
         "prod": "product",
@@ -2610,13 +2610,13 @@ def _to_masked(self):
         "mean": "mean",
         "std": "stddev",
         "var": "variance",
-        "sem": "stddev",  # sem = stddev / sqrt(count), computed below
+        "sem": "stddev",  # sem = stddev / sqrt(count)
         "count": "count",
         "any": "any",
         "all": "all",
     }
 
-    # Default values for missing groups (identity elements for each operation)
+    # Identity elements for operations (used to fill missing groups)
     _PYARROW_AGG_DEFAULTS: dict[str, int | bool] = {
         "sum": 0,
         "prod": 1,
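
For context, a minimal standalone sketch (not part of the patch) of why the name mapping is needed: PyArrow's hash-aggregation names do not always match pandas' `how` strings, e.g. pandas `"prod"` is PyArrow `"product"` and pandas `"std"` is `"stddev"`.

```python
import pyarrow as pa

# Toy table: two groups of two values each.
table = pa.table({"value": [1.0, 2.0, 3.0, 4.0], "group_id": [0, 0, 1, 1]})

# PyArrow spells these aggregations "product" and "stddev", not "prod"/"std".
result = table.group_by("group_id").aggregate(
    [("value", "product"), ("value", "stddev")]
)
# Output columns are named "<column>_<function>".
print(result.column("value_product"))  # products per group: 2.0 and 12.0
```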
@@ -2637,100 +2637,78 @@ def _groupby_op_pyarrow(
         """
         Perform groupby aggregation using PyArrow's native Table.group_by.
 
-        Returns None if the operation is not supported by PyArrow,
-        in which case the caller should fall back to the Cython path.
+        Returns None if not supported, caller should fall back to Cython path.
         """
         pa_agg_func = self._PYARROW_AGG_FUNCS.get(how)
         if pa_agg_func is None:
             return None
 
         pa_type = self._pa_array.type
-        # PyArrow doesn't support these aggregations for temporal types directly
         if pa.types.is_temporal(pa_type) and how in ["std", "var", "sem"]:
             return None
-
-        # PyArrow's any/all only work on boolean types
         if how in ["any", "all"] and not pa.types.is_boolean(pa_type):
             return None
 
-        # Filter out NA group (ids == -1) to avoid unnecessary computation
+        # Filter out NA group (ids == -1)
         mask = ids >= 0
         if not mask.all():
             ids = ids[mask]
             values = pc.filter(self._pa_array, mask)
         else:
             values = self._pa_array
 
-        # Create a PyArrow table with the values and group IDs
-        # Explicitly cast ids to int64 since np.intp is platform-dependent
+        # Build table and run aggregation (cast ids to int64 for portability)
         group_id_arr = pa.array(ids, type=pa.int64())
         table = pa.table({"value": values, "group_id": group_id_arr})
 
-        # Build aggregation list - always include count for null handling
-        # For std/var/sem, pass VarianceOptions with ddof to match pandas behavior
         if how in ["std", "var", "sem"]:
             ddof = kwargs.get("ddof", 1)
-            agg_with_opts = ("value", pa_agg_func, pc.VarianceOptions(ddof=ddof))
-            aggs = [agg_with_opts, ("value", "count")]
+            aggs = [("value", pa_agg_func, pc.VarianceOptions(ddof=ddof))]
         else:
-            aggs = [("value", pa_agg_func), ("value", "count")]
+            aggs = [("value", pa_agg_func)]
+        aggs.append(("value", "count"))
 
-        # Perform the groupby aggregation
         result_table = table.group_by("group_id").aggregate(aggs)
-
-        # Extract results
         result_group_ids = result_table.column("group_id")
         result_values = result_table.column(f"value_{pa_agg_func}")
         result_counts = result_table.column("value_count")
 
-        # For sem, compute stddev / sqrt(count) using PyArrow compute
         if how == "sem":
-            sqrt_counts = pc.sqrt(result_counts)
-            result_values = pc.divide(result_values, sqrt_counts)
+            result_values = pc.divide(result_values, pc.sqrt(result_counts))
 
         output_type = result_values.type
         default_value = pa.scalar(self._PYARROW_AGG_DEFAULTS.get(how), type=output_type)
 
-        # Handle nulls from all-null groups for sum/prod with min_count=0
+        # Replace nulls from all-null groups with identity element
        if result_values.null_count > 0 and how in ["sum", "prod"] and min_count == 0:
             result_values = pc.if_else(
                 pc.is_null(result_values), default_value, result_values
             )
 
-        # Handle min_count: groups with count < min_count should be null
+        # Null out groups below min_count
         if min_count > 0:
             below_min_count = pc.less(result_counts, pa.scalar(min_count))
             result_values = pc.if_else(below_min_count, None, result_values)
 
-        # PyArrow returns results in encounter order. We need to reorder to
-        # match expected output (group 0, 1, 2, ..., ngroups-1) and fill
-        # missing groups with default values.
-        #
-        # We use NumPy scatter (O(n)) instead of:
-        # - pc.scatter: doesn't handle missing groups, workaround is slower
-        # - join+sort: O(n log n), slower for high-cardinality groupby
-        #
-        # Explicitly cast to int64 to ensure usable as NumPy indices
+        # Scatter results into output array ordered by group id.
+        # NumPy scatter is O(n) vs O(n log n) for join+sort or pc.scatter workaround.
         result_group_ids_np = result_group_ids.to_numpy(zero_copy_only=False).astype(
             np.int64, copy=False
         )
         result_values_np = result_values.to_numpy(zero_copy_only=False)
 
         default_py = default_value.as_py()
         if default_py is not None and min_count == 0:
-            # Operations with identity elements (sum=0, prod=1, count=0, any=False,
-            # all=True): fill missing groups with default value
+            # Fill missing groups with identity element
             output_np = np.full(ngroups, default_py, dtype=result_values_np.dtype)
             output_np[result_group_ids_np] = result_values_np
             pa_result = pa.array(output_np, type=output_type)
         else:
-            # Operations without identity elements (mean, std, var, min, max, sem):
-            # fill missing groups with null using a boolean mask
+            # Fill missing groups with null
             output_np = np.empty(ngroups, dtype=result_values_np.dtype)
-            null_mask = np.ones(ngroups, dtype=bool)  # True = null/missing
+            null_mask = np.ones(ngroups, dtype=bool)
             output_np[result_group_ids_np] = result_values_np
             null_mask[result_group_ids_np] = False
-            # Restore nulls for groups that had null results (min_count or all-null)
             if result_values.null_count > 0:
                 result_nulls = pc.is_null(result_values).to_numpy()
                 null_mask[result_group_ids_np[result_nulls]] = True
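
A standalone sketch of the scatter step above, with made-up inputs: PyArrow emits one row per observed group, in encounter order, and scattering by group id densifies that into an array of length `ngroups` in a single O(n) pass.

```python
import numpy as np

ngroups = 5
# Suppose only groups 3, 0 and 1 appeared in the data, in that order.
result_group_ids_np = np.array([3, 0, 1], dtype=np.int64)
result_values_np = np.array([10.0, 20.0, 30.0])

output_np = np.empty(ngroups, dtype=result_values_np.dtype)
null_mask = np.ones(ngroups, dtype=bool)  # True = group missing
output_np[result_group_ids_np] = result_values_np  # O(n) scatter
null_mask[result_group_ids_np] = False

print(null_mask)  # [False False  True False  True]: groups 2 and 4 are missing
```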
@@ -2780,6 +2758,8 @@ def _groupby_op(
             pa.types.is_decimal(pa_type)
             or pa.types.is_string(pa_type)
             or pa.types.is_large_string(pa_type)
+            or pa.types.is_integer(pa_type)  # TEMPORARY: for testing
+            or pa.types.is_floating(pa_type)  # TEMPORARY: for testing
         ):
             result = self._groupby_op_pyarrow(
                 how=how,
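
A hedged end-to-end sketch of what this gating enables; since the integer/floating branches are marked TEMPORARY in the diff, this routing is illustrative rather than final behavior.

```python
import pandas as pd

df = pd.DataFrame(
    {
        "key": ["a", "a", "b"],
        "x": pd.array([1, 2, None], dtype="int64[pyarrow]"),
    }
)
# With integer types gated in, this aggregation would take the
# _groupby_op_pyarrow path rather than the Cython fallback.
print(df.groupby("key")["x"].sum())  # a -> 3, b -> 0 (min_count defaults to 0)
```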