
Good First Issue: Add MLX Op Handler for aten.hardtanh #18921

@metascroy

Description

🚀 The feature, motivation and pitch

Good First Issue: Add MLX Op Handler for aten.hardtanh

Summary

Add support for aten.hardtanh in the MLX delegate. This activation function clamps input values to a fixed range [min_val, max_val]; it underlies ReLU6 and other bounded activation functions.

Background

The MLX delegate converts PyTorch aten ops into MLX graph nodes during export. When an aten op has no handler, execution falls back to the CPU, breaking the GPU acceleration pipeline. Adding a handler lets the op run on the Metal GPU via MLX.

Approach: Decomposed handler (preferred)

aten.hardtanh is equivalent to clamp(x, min_val, max_val), which maps directly to the existing ClipNode:

# hardtanh(x, min_val, max_val) = clamp(x, min_val, max_val)
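The decomposition can be sanity-checked on scalars with a pure-Python sketch (the `hardtanh` below is a hypothetical stand-in for illustration, not the aten op; the real handler operates on graph tensors):

```python
def hardtanh(x: float, min_val: float = -1.0, max_val: float = 1.0) -> float:
    # Clamp x into [min_val, max_val] -- exactly what ClipNode does elementwise.
    return max(min_val, min(max_val, x))

print([hardtanh(v) for v in (-3.0, -0.5, 0.0, 2.0)])  # [-1.0, -0.5, 0.0, 1.0]
print(hardtanh(7.0, 0.0, 6.0))  # ReLU6-style bounds -> 6.0
```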

Steps

  1. Add handler in backends/mlx/ops.py

    @REGISTRY.register(target=[torch.ops.aten.hardtanh.default])
    def _hardtanh_handler(P: MLXProgramBuilder, n: Node) -> Slot:
        """Handle aten.hardtanh - bounded activation function.
        
        hardtanh(x, min_val, max_val) = clamp(x, min_val, max_val)
        Default: min_val=-1, max_val=1
        """
        args = P.args(n)
        require_args(args, 1, 3, "aten.hardtanh")
        require_kwargs(P.kwargs(n), set(), "aten.hardtanh")
        x = args[0]
        min_val = args[1] if len(args) > 1 else -1.0
        max_val = args[2] if len(args) > 2 else 1.0
        
        x_meta = n.args[0].meta.get("val")
        dtype = x_meta.dtype if x_meta is not None else torch.float32
        
        # Lift scalar bounds to 0-D constant tensors
        a_min_tid = P.slot_to_tid(emit_lifted_constant(P, float(min_val), dtype))
        a_max_tid = P.slot_to_tid(emit_lifted_constant(P, float(max_val), dtype))
        
        out = P.make_or_get_slot(n)
        P.emit(
            ClipNode(
                x=P.slot_to_tid(x),
                out=P.slot_to_tid(out),
                a_min=a_min_tid,
                a_max=a_max_tid,
            )
        )
        return out
  2. Add test in backends/mlx/test/test_ops.py

    Use the existing _make_unary_op_test infrastructure:

    # Add to _UNARY_OP_TESTS list:
    {"op_name": "hardtanh", "op_fn": torch.nn.functional.hardtanh, "shapes": _SHAPES_3, "input_fn": _input_fn(scale=3)},

    Note: The default parameters (min_val=-1, max_val=1) will be tested. If you want to test custom bounds, create a separate test class following the ClampTest pattern.
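The positional-argument handling in step 1 can be exercised in isolation with a small standalone mock (`resolve_bounds` is a hypothetical helper written for this sketch, not part of the delegate):

```python
def resolve_bounds(args: list) -> tuple[float, float]:
    # Mirror the handler: args[0] is the input tensor slot; the optional
    # bounds default to the aten signature's min_val=-1, max_val=1.
    if not 1 <= len(args) <= 3:
        raise ValueError("aten.hardtanh expects 1 to 3 positional args")
    min_val = float(args[1]) if len(args) > 1 else -1.0
    max_val = float(args[2]) if len(args) > 2 else 1.0
    return min_val, max_val

print(resolve_bounds(["x"]))        # (-1.0, 1.0) -- aten defaults
print(resolve_bounds(["x", 0, 6]))  # (0.0, 6.0)  -- ReLU6 bounds
```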

Running tests

python -m executorch.backends.mlx.test.run_all_tests -k hardtanh

References

  • MLX C++: array clip(const array &a, const optional<array> &a_min, const optional<array> &a_max, StreamOrDevice s = {})
  • PyTorch signature: hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor
  • Use case: ReLU6 is hardtanh(x, 0, 6), used in MobileNet
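To make the ReLU6 relationship in the last bullet concrete, here is a pure-Python sketch (illustrative only, not the MLX implementation):

```python
def hardtanh(x: float, min_val: float = -1.0, max_val: float = 1.0) -> float:
    return max(min_val, min(max_val, x))

def relu6(x: float) -> float:
    # ReLU6(x) = min(max(x, 0), 6) = hardtanh(x, 0, 6)
    return hardtanh(x, 0.0, 6.0)

print([relu6(v) for v in (-2.0, 3.0, 9.0)])  # [0.0, 3.0, 6.0]
```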

Alternatives

No response

Additional context

No response

RFC (Optional)

No response
