🚀 The feature, motivation and pitch
Good First Issue: Add MLX Op Handler for aten.hardtanh
Summary
Add support for aten.hardtanh in the MLX delegate. This activation function clamps values to a range and is used in ReLU6 and other bounded activation functions.
Background
The MLX delegate converts PyTorch aten ops into MLX graph nodes during export. When an aten op has no handler, it falls back to CPU execution, breaking the GPU acceleration pipeline. Adding a handler lets the op run on the Metal GPU via MLX.
Approach: Decomposed handler (preferred)
`aten.hardtanh` is equivalent to `clamp(x, min_val, max_val)`, which maps directly to the existing `ClipNode`:

```python
# hardtanh(x, min_val, max_val) = clamp(x, min_val, max_val)
```
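The decomposition can be sanity-checked in plain Python before touching the delegate (a minimal scalar sketch; the real handler operates on tensors):

```python
def clamp(x, lo, hi):
    """Scalar clamp: restrict x to the closed interval [lo, hi]."""
    return max(lo, min(hi, x))

def hardtanh(x, min_val=-1.0, max_val=1.0):
    # hardtanh is exactly clamp with the given bounds
    return clamp(x, min_val, max_val)

# Values below min_val saturate to min_val, values above max_val
# saturate to max_val, and in-range values pass through unchanged.
print(hardtanh(-2.5))  # -1.0
print(hardtanh(0.5))   # 0.5
print(hardtanh(3.0))   # 1.0
```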
Steps
- Add a handler in `backends/mlx/ops.py`:
```python
@REGISTRY.register(target=[torch.ops.aten.hardtanh.default])
def _hardtanh_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """Handle aten.hardtanh - bounded activation function.

    hardtanh(x, min_val, max_val) = clamp(x, min_val, max_val)
    Default: min_val=-1, max_val=1
    """
    args = P.args(n)
    require_args(args, 1, 3, "aten.hardtanh")
    require_kwargs(P.kwargs(n), set(), "aten.hardtanh")

    x = args[0]
    min_val = args[1] if len(args) > 1 else -1.0
    max_val = args[2] if len(args) > 2 else 1.0

    x_meta = n.args[0].meta.get("val")
    dtype = x_meta.dtype if x_meta is not None else torch.float32

    # Lift scalar bounds to 0-D constant tensors
    a_min_tid = P.slot_to_tid(emit_lifted_constant(P, float(min_val), dtype))
    a_max_tid = P.slot_to_tid(emit_lifted_constant(P, float(max_val), dtype))

    out = P.make_or_get_slot(n)
    P.emit(
        ClipNode(
            x=P.slot_to_tid(x),
            out=P.slot_to_tid(out),
            a_min=a_min_tid,
            a_max=a_max_tid,
        )
    )
    return out
```
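The positional-argument defaulting used by the handler can be isolated and checked on its own (a hedged sketch; `resolve_hardtanh_args` is a hypothetical helper for illustration, not part of the delegate):

```python
def resolve_hardtanh_args(args):
    """Mirror the handler's defaulting: aten.hardtanh takes 1-3
    positional args, with bounds defaulting to [-1.0, 1.0]."""
    if not 1 <= len(args) <= 3:
        raise ValueError("aten.hardtanh expects 1 to 3 positional args")
    x = args[0]
    min_val = float(args[1]) if len(args) > 1 else -1.0
    max_val = float(args[2]) if len(args) > 2 else 1.0
    return x, min_val, max_val

print(resolve_hardtanh_args(("x",)))       # ('x', -1.0, 1.0)
print(resolve_hardtanh_args(("x", 0, 6)))  # ('x', 0.0, 6.0)
```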
- Add a test in `backends/mlx/test/test_ops.py`, using the existing `_make_unary_op_test` infrastructure:

```python
# Add to the _UNARY_OP_TESTS list:
{"op_name": "hardtanh", "op_fn": torch.nn.functional.hardtanh, "shapes": _SHAPES_3, "input_fn": _input_fn(scale=3)},
```
Note: this exercises only the default parameters (`min_val=-1`, `max_val=1`). To cover custom bounds, add a separate test class following the `ClampTest` pattern.
Running tests

```shell
python -m executorch.backends.mlx.test.run_all_tests -k hardtanh
```
References

- MLX C++:
  `array clip(const array &a, const optional<array> &a_min, const optional<array> &a_max, StreamOrDevice s = {})`
- PyTorch signature:
  `hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor`
- Use case: ReLU6 is `hardtanh(x, 0, 6)`, used in MobileNet.
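The ReLU6 relationship is easy to verify with a scalar reference (a sketch for illustration; `hardtanh` and `relu6` here are local functions, not the delegate's ops):

```python
def hardtanh(x, min_val=-1.0, max_val=1.0):
    """Scalar reference: clamp x to [min_val, max_val]."""
    return max(min_val, min(max_val, x))

def relu6(x):
    # ReLU6 == hardtanh with bounds [0, 6]: negative inputs go to 0,
    # values above 6 saturate at 6.
    return hardtanh(x, 0.0, 6.0)

print(relu6(-3.0))  # 0.0
print(relu6(4.5))   # 4.5
print(relu6(9.0))   # 6.0
```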
Alternatives
No response
Additional context
No response
RFC (Optional)
No response