Good First Issue: Add MLX Op Handler for aten.flip
Summary
Add support for aten.flip in the MLX delegate. This op reverses tensor elements along specified dimensions and is needed by image augmentation and sequence reversal operations.
Background
The MLX delegate converts PyTorch aten ops into MLX graph nodes during export. When an aten op has no handler, it falls back to CPU execution, breaking the GPU acceleration pipeline. Adding a handler lets the op run on the Metal GPU via MLX.
Approach: Decomposed handler (preferred)
aten.flip can be decomposed using SliceNode with step=-1:
# flip(x, dims=[0, 2]) reverses along dims 0 and 2
# For each dim d in dims:
# x = slice(x, axis=d, start=size-1, stop=-(size+1), step=-1)
This approach reuses the existing SliceNode which already supports negative step (see topk handler for reference).
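As a quick sanity check on the slice bounds, here is a stdlib-Python sketch (plain lists, not the MLX API) showing that `start=size-1, stop=-(size+1), step=-1` visits every index in reverse, and why `stop=-1` would not work under the same negative-index convention:

```python
# Sanity check of the decomposition's slice bounds (plain Python lists):
# for a dimension of size n, start=n-1, stop=-(n+1), step=-1 reverses it.
# stop=-1 would NOT work: a negative stop wraps to n-1, so start == stop
# and the slice is empty.
def reverse_via_slice(seq):
    n = len(seq)
    return seq[n - 1 : -(n + 1) : -1]

print(reverse_via_slice([1, 2, 3, 4]))  # [4, 3, 2, 1]
print([1, 2, 3, 4][3:-1:-1])            # [] -- why stop=-1 is wrong
```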
Steps
- Add handler in backends/mlx/ops.py
@REGISTRY.register(target=[torch.ops.aten.flip.default])
def _flip_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    args = P.args(n)
    require_args(args, 2, 2, "aten.flip")
    require_kwargs(P.kwargs(n), set(), "aten.flip")
    x, dims = args
    # Get the input shape for computing slice bounds
    x_meta = n.args[0].meta.get("val")
    out = x  # Start with the input and chain one slice per dim
    for dim in dims:
        dim_size = x_meta.shape[dim]
        _, tmp = P.make_tmp_slot()
        P.emit(
            SliceNode(
                x=P.slot_to_tid(out),
                out=P.slot_to_tid(tmp),
                axis=P.to_int_or_vid(dim),
                start=P.to_int_or_vid(dim_size - 1),
                stop=P.to_int_or_vid(-(dim_size + 1)),
                step=-1,
            )
        )
        out = tmp
    final_out = P.make_or_get_slot(n)
    P.emit(IdCopyNode(x=P.slot_to_tid(out), out=P.slot_to_tid(final_out)))
    return final_out
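The handler's loop is just a chain of per-dimension reversals, each applied to the running result. A minimal stdlib sketch of the same chaining on a 2-D nested list (illustration only, not the MLX API):

```python
# Illustration of the chaining above on a 2-D nested list:
# each dim in `dims` contributes one reversal of the running result,
# mirroring the handler's one-SliceNode-per-dim loop.
def flip2d(mat, dims):
    out = [row[:] for row in mat]
    for dim in dims:
        if dim == 0:
            out = out[::-1]                    # reverse along rows
        else:
            out = [row[::-1] for row in out]   # reverse within each row
    return out

m = [[1, 2, 3], [4, 5, 6]]
print(flip2d(m, [0]))     # [[4, 5, 6], [1, 2, 3]]
print(flip2d(m, [0, 1]))  # [[6, 5, 4], [3, 2, 1]]
```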
- Add test in backends/mlx/test/test_ops.py
This op doesn't fit the simple unary pattern (has dims parameter), so create a custom test class following the existing patterns like PermuteTest:
class FlipModel(nn.Module):
    def __init__(self, dims: List[int]):
        super().__init__()
        self.dims = dims

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.flip(x, self.dims)


@register_test
class FlipTest(OpTestCase):
    name = "flip"

    def __init__(self, shape: Tuple[int, ...], dims: List[int]):
        self.shape = shape
        self.dims = dims
        dims_str = "_".join(str(d) for d in dims)
        shape_str = "x".join(str(s) for s in shape)
        self.name = f"flip_{shape_str}_dims{dims_str}"

    @classmethod
    def get_test_configs(cls) -> List["FlipTest"]:
        return [
            cls(shape=(4, 5), dims=[0]),
            cls(shape=(4, 5), dims=[1]),
            cls(shape=(4, 5), dims=[0, 1]),
            cls(shape=(3, 4, 5), dims=[-1]),
            cls(shape=(3, 4, 5), dims=[0, 2]),
        ]

    def create_model(self) -> nn.Module:
        return FlipModel(self.dims)

    def create_inputs(self) -> Tuple[torch.Tensor, ...]:
        return (torch.randn(self.shape),)
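The naming scheme in `__init__` above produces one test id per config, which is what `-k flip` matches against. A standalone sketch of the same string logic, useful when picking a single test to run:

```python
# Mirrors the name construction in FlipTest.__init__ above,
# so you can predict the test id for a given config.
def flip_test_name(shape, dims):
    dims_str = "_".join(str(d) for d in dims)
    shape_str = "x".join(str(s) for s in shape)
    return f"flip_{shape_str}_dims{dims_str}"

print(flip_test_name((4, 5), [0, 1]))   # flip_4x5_dims0_1
print(flip_test_name((3, 4, 5), [-1]))  # flip_3x4x5_dims-1
```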
Running tests
python -m executorch.backends.mlx.test.run_all_tests -k flip
References
- emit_reverse() function
- aten op schema: flip(Tensor self, int[] dims) -> Tensor