@@ -2601,6 +2601,143 @@ def _to_masked(self):
         arr = self.to_numpy(dtype=dtype.numpy_dtype, na_value=na_value)
         return dtype.construct_array_type()(arr, mask)

+    # Mapping from pandas groupby 'how' to PyArrow aggregation function names
+    _PYARROW_AGG_FUNCS: dict[str, str] = {
+        "sum": "sum",
+        "prod": "product",
+        "min": "min",
+        "max": "max",
+        "mean": "mean",
+        "std": "stddev",
+        "var": "variance",
+        "sem": "stddev",  # sem = stddev / sqrt(count), computed below
+        "count": "count",
+        "any": "any",
+        "all": "all",
+    }
+
+    # Default values for missing groups (identity elements for each operation)
+    _PYARROW_AGG_DEFAULTS: dict[str, int | bool] = {
+        "sum": 0,
+        "prod": 1,
+        "count": 0,
+        "any": False,
+        "all": True,
+    }
+
+    def _groupby_op_pyarrow(
+        self,
+        *,
+        how: str,
+        min_count: int,
+        ngroups: int,
+        ids: npt.NDArray[np.intp],
+        **kwargs,
+    ) -> Self | None:
+        """
+        Perform a groupby aggregation using PyArrow's native Table.group_by.
+
+        Returns None if the operation is not supported by PyArrow,
+        in which case the caller should fall back to the Cython path.
+        """
+        pa_agg_func = self._PYARROW_AGG_FUNCS.get(how)
+        if pa_agg_func is None:
+            return None
+
+        pa_type = self._pa_array.type
+        # PyArrow doesn't support these aggregations for temporal types directly
+        if pa.types.is_temporal(pa_type) and how in ["std", "var", "sem"]:
+            return None
+
+        # PyArrow's any/all only work on boolean types
+        if how in ["any", "all"] and not pa.types.is_boolean(pa_type):
+            return None
+
+        # Filter out the NA group (ids == -1) to avoid unnecessary computation
+        mask = ids >= 0
+        if not mask.all():
+            ids = ids[mask]
+            values = pc.filter(self._pa_array, mask)
+        else:
+            values = self._pa_array
+
+        # Create a PyArrow table with the values and group IDs.
+        # Explicitly cast ids to int64 since np.intp is platform-dependent.
+        group_id_arr = pa.array(ids, type=pa.int64())
+        table = pa.table({"value": values, "group_id": group_id_arr})
+
+        # Build the aggregation list - always include count for null handling.
+        # For std/var/sem, pass VarianceOptions with ddof to match pandas behavior.
+        if how in ["std", "var", "sem"]:
+            ddof = kwargs.get("ddof", 1)
+            agg_with_opts = ("value", pa_agg_func, pc.VarianceOptions(ddof=ddof))
+            aggs = [agg_with_opts, ("value", "count")]
+        else:
+            aggs = [("value", pa_agg_func), ("value", "count")]
+
+        # Perform the groupby aggregation
+        result_table = table.group_by("group_id").aggregate(aggs)
+
+        # Extract results
+        result_group_ids = result_table.column("group_id")
+        result_values = result_table.column(f"value_{pa_agg_func}")
+        result_counts = result_table.column("value_count")
+
+        # For sem, compute stddev / sqrt(count) using PyArrow compute
+        if how == "sem":
+            sqrt_counts = pc.sqrt(result_counts)
+            result_values = pc.divide(result_values, sqrt_counts)

+        output_type = result_values.type
+        default_value = pa.scalar(self._PYARROW_AGG_DEFAULTS.get(how), type=output_type)
+
+        # Handle nulls from all-null groups for sum/prod with min_count=0
+        if result_values.null_count > 0 and how in ["sum", "prod"] and min_count == 0:
+            result_values = pc.if_else(
+                pc.is_null(result_values), default_value, result_values
+            )
+
+        # Handle min_count: groups with count < min_count should be null
+        if min_count > 0:
+            below_min_count = pc.less(result_counts, pa.scalar(min_count))
+            result_values = pc.if_else(below_min_count, None, result_values)
+
+        # PyArrow returns results in encounter order. We need to reorder to
+        # match the expected output (group 0, 1, 2, ..., ngroups-1) and fill
+        # missing groups with default values.
+        #
+        # We use NumPy scatter (O(n)) instead of:
+        # - pc.scatter: doesn't handle missing groups, workaround is slower
+        # - join+sort: O(n log n), slower for high-cardinality groupbys
+        #
+        # Explicitly cast to int64 to ensure the ids are usable as NumPy indices
+        result_group_ids_np = result_group_ids.to_numpy(zero_copy_only=False).astype(
+            np.int64, copy=False
+        )
+        result_values_np = result_values.to_numpy(zero_copy_only=False)
+
+        default_py = default_value.as_py()
+        if default_py is not None and min_count == 0:
+            # Operations with identity elements (sum=0, prod=1, count=0, any=False,
+            # all=True): fill missing groups with the default value
+            output_np = np.full(ngroups, default_py, dtype=result_values_np.dtype)
+            output_np[result_group_ids_np] = result_values_np
+            pa_result = pa.array(output_np, type=output_type)
+        else:
+            # Operations without identity elements (mean, std, var, min, max, sem):
+            # fill missing groups with null using a boolean mask
+            output_np = np.empty(ngroups, dtype=result_values_np.dtype)
+            null_mask = np.ones(ngroups, dtype=bool)  # True = null/missing
+            output_np[result_group_ids_np] = result_values_np
+            null_mask[result_group_ids_np] = False
+            # Restore nulls for groups that had null results (min_count or all-null)
+            if result_values.null_count > 0:
+                result_nulls = pc.is_null(result_values).to_numpy()
+                null_mask[result_group_ids_np[result_nulls]] = True
+            pa_result = pa.array(output_np, type=output_type, mask=null_mask)
+
+        return self._from_pyarrow_array(pa_result)
+
     def _groupby_op(
         self,
         *,
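A minimal standalone sketch of the technique the method above relies on, outside pandas internals: aggregate with PyArrow's Table.group_by, then scatter the per-group results into positions 0..ngroups-1 with NumPy, filling groups that never appear. The toy values, ids, and ngroups here are illustrative, not from the patch:

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc

# Toy inputs mirroring the method's arguments.
values = pa.array(["a", "b", None, "d", "e"], type=pa.string())
ids = np.array([2, 0, 0, 2, -1], dtype=np.intp)  # -1 marks the NA group
ngroups = 4  # groups 1 and 3 never occur in ids

# Drop rows belonging to the NA group.
keep = ids >= 0
table = pa.table(
    {
        "value": pc.filter(values, pa.array(keep)),
        "group_id": pa.array(ids[keep], type=pa.int64()),
    }
)

# PyArrow returns one row per observed group, in encounter order.
result = table.group_by("group_id").aggregate([("value", "count")])

# Scatter per-group results into positions 0..ngroups-1; "count" has an
# identity element (0), so missing groups are filled with it.
group_pos = result.column("group_id").to_numpy(zero_copy_only=False).astype(np.int64)
counts = result.column("value_count").to_numpy(zero_copy_only=False)
out = np.zeros(ngroups, dtype=counts.dtype)
out[group_pos] = counts
print(out)  # [1 0 2 0]: group 0 has one non-null value, group 2 has two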
@@ -2635,9 +2772,27 @@ def _groupby_op(
             **kwargs,
         )

-        # maybe convert to a compatible dtype optimized for groupby
-        values: ExtensionArray
         pa_type = self._pa_array.type
+
+        # Try the PyArrow-native path for decimal and string types where it's
+        # faster. For integer/float/boolean, the fallback path via
+        # _to_masked() is faster.
+        if (
+            pa.types.is_decimal(pa_type)
+            or pa.types.is_string(pa_type)
+            or pa.types.is_large_string(pa_type)
+        ):
+            result = self._groupby_op_pyarrow(
+                how=how,
+                min_count=min_count,
+                ngroups=ngroups,
+                ids=ids,
+                **kwargs,
+            )
+            if result is not None:
+                return result
+
+        # Fall back to converting to a masked/datetime array and using Cython
+        values: ExtensionArray
         if pa.types.is_timestamp(pa_type):
             values = self._to_datetimearray()
         elif pa.types.is_duration(pa_type):
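For context, a hedged user-level sketch of a call that the dispatch above would route through the new path, assuming a pandas build that includes this patch (the exact result dtype may vary by version):

import pandas as pd
import pyarrow as pa

# An ArrowDtype-backed string column: per the dispatch above, eligible for
# the PyArrow-native groupby path (falling back to Cython if unsupported).
df = pd.DataFrame(
    {
        "key": ["x", "x", "y", "y", "y"],
        "val": pd.array(["a", None, "c", "d", "e"], dtype=pd.ArrowDtype(pa.string())),
    }
)
counts = df.groupby("key")["val"].count()  # x -> 1 non-null value, y -> 3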