
Conversation

@charithaintc charithaintc commented Dec 19, 2025

This PR adds initial support for layout conflict resolution in XeGPU. A layout conflict occurs when an op's use point expects a different layout than the one the op can currently provide. The conflict needs to be resolved by inserting certain other xegpu ops.

Initially, we focus only on conflict handling at LoadNd's tensor_desc operand. In this case we simply duplicate the corresponding CreateNd op with the expected layout.
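
For illustration, here is a minimal before/after sketch of the idea (the shapes, layouts, operand names, and exact attribute syntax are illustrative assumptions, not taken from the PR's test file): when a tensor_desc's layout differs from the load_nd's anchor layout, a duplicate create_nd_tdesc carrying the expected layout is inserted next to the original, and only the conflicting use is redirected to it.

// Before (hypothetical): the tensor_desc layout and the load's anchor layout disagree.
%td  = xegpu.create_nd_tdesc %src : memref<8x16xf16>
         -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
%val = xegpu.load_nd %td <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}>
         : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
         -> vector<8x16xf16>

// After: a duplicated create_nd_tdesc with the expected layout resolves the conflict;
// other (non-conflicting) uses of %td are left untouched.
%td  = xegpu.create_nd_tdesc %src : memref<8x16xf16>
         -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
%td2 = xegpu.create_nd_tdesc %src : memref<8x16xf16>
         -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>>
%val = xegpu.load_nd %td2 <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}>
         : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>>
         -> vector<8x16xf16>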

llvmbot commented Dec 19, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Charitha Saumya (charithaintc)

Changes

This PR adds initial support for layout conflict resolution in XeGPU. A layout conflict occurs when an op's use point expects a different layout than the one the op can currently provide. The conflict needs to be resolved by inserting certain other xegpu ops.

Initially, we focus only on conflict handling at LoadNd's tensor_desc operand. In this case we simply duplicate the corresponding CreateNd op with the expected layout.


Patch is 24.68 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/173090.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h (+8)
  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp (+182-33)
  • (modified) mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir (+1-1)
  • (modified) mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir (+1-1)
  • (modified) mlir/test/Dialect/XeGPU/propagate-layout.mlir (+5-5)
  • (added) mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir (+23)
  • (modified) mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp (+76)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 1776a209d0bf1..80ea1e3407058 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -9,6 +9,8 @@
 #ifndef MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
 #define MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
 
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/Operation.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/LogicalResult.h"
@@ -91,6 +93,12 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns);
 void populateXeGPUUnrollPatterns(RewritePatternSet &patterns,
                                  const UnrollOptions &options);
 
+enum class LayoutKind { Lane, InstData, Subgroup };
+LogicalResult propagateLayouts(OpBuilder &builder, Operation *target,
+                               LayoutKind layoutKind, bool printOnly = false);
+
+LogicalResult resolveLayoutConflicts(Operation *target);
+
 } // namespace xegpu
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 7fc75e7294ea3..c8138a4d16016 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -37,6 +38,7 @@
 #include "llvm/Support/raw_ostream.h"
 
 #include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
+#include "mlir/Support/WalkResult.h"
 
 namespace mlir {
 namespace xegpu {
@@ -53,8 +55,6 @@ using namespace mlir::dataflow;
 
 namespace {
 
-enum class LayoutKind { Lane, InstData, Subgroup };
-
 //===----------------------------------------------------------------------===//
 // LayoutInfo
 //===----------------------------------------------------------------------===//
@@ -380,7 +380,7 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
 class LayoutInfoPropagation
     : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
 private:
-  LayoutKind layoutKind;
+  xegpu::LayoutKind layoutKind;
   void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
                    ArrayRef<const LayoutInfoLattice *> results);
 
@@ -436,7 +436,7 @@ class LayoutInfoPropagation
 public:
   LayoutInfoPropagation(DataFlowSolver &solver,
                         SymbolTableCollection &symbolTable,
-                        LayoutKind layoutKind)
+                        xegpu::LayoutKind layoutKind)
       : SparseBackwardDataFlowAnalysis(solver, symbolTable),
         layoutKind(layoutKind) {}
   using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
@@ -526,12 +526,12 @@ bool LayoutInfoPropagation::hasParamsOfLayoutKind(
   if (anchorLayout == nullptr) {
     return false;
   }
-  if (layoutKind == LayoutKind::InstData) {
+  if (layoutKind == xegpu::LayoutKind::InstData) {
     return !(anchorLayout.getEffectiveInstDataAsInt().empty());
-  } else if (layoutKind == LayoutKind::Lane) {
+  } else if (layoutKind == xegpu::LayoutKind::Lane) {
     return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
              anchorLayout.getEffectiveLaneDataAsInt().empty());
-  } else if (layoutKind == LayoutKind::Subgroup) {
+  } else if (layoutKind == xegpu::LayoutKind::Subgroup) {
     return !(anchorLayout.getEffectiveSgLayoutAsInt().empty() ||
              anchorLayout.getEffectiveSgDataAsInt().empty());
   }
@@ -579,7 +579,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
       instData = {instHeight, instWidth};
     }
 
-    if (layoutKind == LayoutKind::InstData)
+    if (layoutKind == xegpu::LayoutKind::InstData)
       prefetchLayout =
           LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
     else
@@ -748,7 +748,7 @@ void LayoutInfoPropagation::visitDpasOp(
     SmallVector<int> instDataA = {maxALen, subgroupSize};
     SmallVector<int> instDataB = {subgroupSize, maxBLen};
 
-    if (layoutKind == LayoutKind::InstData) {
+    if (layoutKind == xegpu::LayoutKind::InstData) {
       dpasALayout =
           LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
       dpasBLayout =
@@ -762,7 +762,7 @@ void LayoutInfoPropagation::visitDpasOp(
 
     if (operands.size() > 2) {
       VectorType cTy = dpas.getAccType();
-      if (layoutKind == LayoutKind::InstData) {
+      if (layoutKind == xegpu::LayoutKind::InstData) {
         const unsigned dataCLen = bTy.getShape().back();
         auto supportedCLen =
             uArchInstruction->getSupportedN(bTy.getElementType());
@@ -832,7 +832,7 @@ void LayoutInfoPropagation::visitStoreNdOp(
       instData = {instHeight, instWidth};
     }
 
-    if (layoutKind == LayoutKind::InstData)
+    if (layoutKind == xegpu::LayoutKind::InstData)
       storeLayout =
           LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
     else
@@ -992,7 +992,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
         instData.push_back(chunkSize);
     }
 
-    if (layoutKind == LayoutKind::InstData)
+    if (layoutKind == xegpu::LayoutKind::InstData)
       loadLayout =
           LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
     else
@@ -1055,7 +1055,7 @@ void LayoutInfoPropagation::visitStoreScatterOp(
     auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
     const int subgroupSize = uArch->getSubgroupSize();
 
-    if (layoutKind == LayoutKind::InstData) {
+    if (layoutKind == xegpu::LayoutKind::InstData) {
       SmallVector<int> instData{subgroupSize};
       if (auto chunkSize = storeScatter.getChunkSize().value_or(0);
           chunkSize > 1)
@@ -1106,7 +1106,8 @@ class RunLayoutInfoPropagation {
 public:
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
 
-  RunLayoutInfoPropagation(Operation *op, LayoutKind layoutKind) : target(op) {
+  RunLayoutInfoPropagation(Operation *op, xegpu::LayoutKind layoutKind)
+      : target(op) {
     SymbolTableCollection symbolTable;
     loadBaselineAnalyses(solver);
     solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
@@ -1180,6 +1181,77 @@ void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
     printFunctionResult(funcOp);
 }
 
+namespace {
+
+//===----------------------------------------------------------------------===//
+// ResolveLayoutConflicts
+//===----------------------------------------------------------------------===//
+struct ResolveLayoutConflicts {
+  ResolveLayoutConflicts(Operation *parentOp)
+      : parentOp(parentOp), builder(parentOp->getContext()) {}
+  LogicalResult run();
+
+private:
+  Operation *parentOp;
+  OpBuilder builder;
+  LogicalResult resolveLoadNdOp(xegpu::LoadNdOp loadNdOp);
+};
+
+}; // namespace
+
+LogicalResult ResolveLayoutConflicts::run() {
+  auto r = parentOp->walk([&](Operation *op) -> WalkResult {
+    TypeSwitch<Operation *>(op).Case([&](xegpu::LoadNdOp loadNdOp) {
+      return failed(resolveLoadNdOp(loadNdOp)) ? WalkResult::interrupt()
+                                               : WalkResult::advance();
+    });
+    // TODO: Add other layout conflict resolution methods as needed.
+    return WalkResult::advance();
+  });
+
+  return r.wasInterrupted() ? failure() : success();
+}
+
+/// LoadNd has a conflict if the tensor descriptor layout is different from the
+/// load's anchor layout.
+LogicalResult
+ResolveLayoutConflicts::resolveLoadNdOp(xegpu::LoadNdOp loadNdOp) {
+  Attribute anchorLayout = loadNdOp.getLayoutAttr();
+  Attribute tdescLayout = loadNdOp.getTensorDescType().getLayout();
+
+  if (anchorLayout && tdescLayout && anchorLayout != tdescLayout) {
+    // Try to get the defining CreateNdDescOp of the tensor descriptor.
+    auto conflictingCreateNdOp =
+        loadNdOp.getTensorDesc().getDefiningOp<xegpu::CreateNdDescOp>();
+    if (!conflictingCreateNdOp) {
+      DBGS() << "Unable to find defining CreateNdDescOp for tensor descriptor: "
+             << loadNdOp.getTensorDesc() << "\n";
+      return failure();
+    }
+    // Duplicate the CreateNdDescOp with the expected layout.
+    builder.setInsertionPointAfter(conflictingCreateNdOp);
+    xegpu::TensorDescType tdescType = loadNdOp.getTensorDescType();
+    auto expectedLayout = anchorLayout;
+    auto newTensorDescType = xegpu::TensorDescType::get(
+        conflictingCreateNdOp.getContext(), tdescType.getShape(),
+        tdescType.getElementType(), tdescType.getEncoding(), expectedLayout);
+    xegpu::CreateNdDescOp newOp = xegpu::CreateNdDescOp::create(
+        builder, loadNdOp.getLoc(), newTensorDescType,
+        conflictingCreateNdOp->getOperands(),
+        conflictingCreateNdOp->getAttrs());
+    // Replace only the conflicting uses of the createNdOp that can be
+    // resolved using the new layout.
+    conflictingCreateNdOp->replaceUsesWithIf(
+        ArrayRef<Value>(newOp.getResult()), [&](OpOperand &opnd) {
+          auto userLoadNdOp = dyn_cast<xegpu::LoadNdOp>(opnd.getOwner());
+          if (!userLoadNdOp)
+            return false;
+          return userLoadNdOp.getLayoutAttr() == expectedLayout;
+        });
+  }
+  return success();
+}
+
 using GetLayoutFnTy = function_ref<xegpu::DistributeLayoutAttr(Value)>;
 /// Update an operation with the layout of its results. If the result type is
 /// a vector type, a temporary layout attribute is added to the operation. If
@@ -1348,26 +1420,14 @@ struct XeGPUPropagateLayoutPass final
 
 } // namespace
 
-void XeGPUPropagateLayoutPass::runOnOperation() {
-  LayoutKind layoutKind;
-  if (this->layoutKind == "lane") {
-    layoutKind = LayoutKind::Lane;
-  } else if (this->layoutKind == "inst") {
-    layoutKind = LayoutKind::InstData;
-  } else if (this->layoutKind == "subgroup") {
-    layoutKind = LayoutKind::Subgroup;
-  } else {
-    getOperation()->emitError("Unsupported layout kind option: " +
-                              this->layoutKind);
-    signalPassFailure();
-    return;
-  }
-  RunLayoutInfoPropagation analysis(getOperation(), layoutKind);
+LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
+                                      LayoutKind layoutKind, bool printOnly) {
+  RunLayoutInfoPropagation analysis(target, layoutKind);
   // Print the analysis result and exit. (for debugging purposes)
   if (printOnly) {
     auto &os = llvm::outs();
     analysis.printAnalysisResult(os);
-    return;
+    return success();
   }
   // Helper to convert LayoutInfo to xegpu::LayoutAttr.
   auto getXeGPULayoutForValue = [&](Value val) -> xegpu::DistributeLayoutAttr {
@@ -1381,8 +1441,7 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
     return cast<xegpu::LayoutAttr>(layoutAttr);
   };
 
-  mlir::OpBuilder builder(&getContext());
-  Operation *op = getOperation();
+  Operation *op = target;
   auto walkResult = op->walk([&](mlir::Block *block) -> WalkResult {
     for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
       LogicalResult r = success();
@@ -1407,7 +1466,97 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
     }
     return WalkResult::advance();
   });
-  if (walkResult.wasInterrupted()) {
+  if (walkResult.wasInterrupted())
+    return failure();
+
+  return success();
+}
+
+// LogicalResult xegpu::resolveLayoutConflicts(OpBuilder &builder,
+//                                             Operation *target) {
+//   auto r = target->walk([&](xegpu::LoadNdOp loadNdOp) -> WalkResult {
+//     // Load op has a conflict if tensor desc layout is different from the its
+//     // result layout.
+//     auto getResultLayout = [](OpResult result) {
+//       auto resultLayoutName = xegpu::getTemporaryLayoutName(result);
+//       return result.getOwner()->getAttrOfType<xegpu::DistributeLayoutAttr>(
+//           resultLayoutName);
+//     };
+//     auto hasConflict = [&getResultLayout](xegpu::LoadNdOp loadNdOp) -> bool {
+//       auto tdescType = loadNdOp.getTensorDescType();
+//       auto tdescLayout = tdescType.getLayout();
+//       auto resultLayoutName =
+//           xegpu::getTemporaryLayoutName(loadNdOp->getOpResult(0));
+//       auto resultLayout = getResultLayout(loadNdOp->getOpResult(0));
+//       return tdescLayout && resultLayout && tdescLayout != resultLayout;
+//     };
+//     if (hasConflict(loadNdOp)) {
+//       OpBuilder builder(loadNdOp);
+//       // Try to get the defining createNdDesc op.
+//       auto createNdOp =
+//           loadNdOp.getTensorDesc().getDefiningOp<xegpu::CreateNdDescOp>();
+//       if (!createNdOp) {
+//         DBGS() << "Failed to resolve LoadNdOp layout conflict: " << *loadNdOp
+//                << "\n";
+//         return WalkResult::interrupt();
+//       }
+
+//       builder.setInsertionPointAfter(createNdOp);
+//       auto tdescType = loadNdOp.getTensorDescType();
+//       auto expectedLayout = getResultLayout(loadNdOp->getOpResult(0));
+//       auto newTensorDescType = xegpu::TensorDescType::get(
+//           createNdOp.getContext(), tdescType.getShape(),
+//           tdescType.getElementType(), tdescType.getEncoding(),
+//           expectedLayout);
+//       auto newOp = xegpu::CreateNdDescOp::create(
+//           builder, loadNdOp.getLoc(), newTensorDescType,
+//           createNdOp->getOperands(), createNdOp->getAttrs());
+//       // Replace only the conflicting uses of the createNdOp that can be
+//       // resolved using the new layout.
+//       createNdOp->replaceUsesWithIf(
+//           ArrayRef<Value>(newOp.getResult()), [&](OpOperand &opnd) {
+//             auto userLoadNdOp = dyn_cast<xegpu::LoadNdOp>(opnd.getOwner());
+//             if (!userLoadNdOp)
+//               return false;
+//             auto resultLayout =
+//             getResultLayout(userLoadNdOp->getOpResult(0)); return
+//             hasConflict(userLoadNdOp) && resultLayout == expectedLayout;
+//           });
+//     }
+//     return WalkResult::advance();
+//   });
+//   if (r.wasInterrupted())
+//     return failure();
+//   return success();
+// }
+
+LogicalResult xegpu::resolveLayoutConflicts(Operation *target) {
+  ResolveLayoutConflicts resolver(target);
+  return resolver.run();
+}
+
+void XeGPUPropagateLayoutPass::runOnOperation() {
+  xegpu::LayoutKind layoutKind;
+  if (this->layoutKind == "lane") {
+    layoutKind = xegpu::LayoutKind::Lane;
+  } else if (this->layoutKind == "inst") {
+    layoutKind = xegpu::LayoutKind::InstData;
+  } else if (this->layoutKind == "subgroup") {
+    layoutKind = xegpu::LayoutKind::Subgroup;
+  } else {
+    getOperation()->emitError("Unsupported layout kind option: " +
+                              this->layoutKind);
+    signalPassFailure();
+    return;
+  }
+  OpBuilder builder(&getContext());
+  if (failed(xegpu::propagateLayouts(builder, getOperation(), layoutKind,
+                                     this->printOnly))) {
+    signalPassFailure();
+    return;
+  }
+  // Resolve layout conflicts if any.
+  if (failed(xegpu::resolveLayoutConflicts(getOperation()))) {
     signalPassFailure();
     return;
   }
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index 5f70831f45e97..5e095fe0df89e 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -xevm-attach-target='chip=pvc' -xegpu-propagate-layout="layout-kind=inst" -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -xevm-attach-target='chip=pvc' -test-xegpu-propagate-layouts="layout-kind=inst" -split-input-file %s | FileCheck %s
 
 
 // CHECK-LABEL: func.func @load_store_no_array_len(
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index 092a4cf442782..7675c44be1c61 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -xevm-attach-target='chip=pvc' -xegpu-propagate-layout="layout-kind=subgroup" -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -xevm-attach-target='chip=pvc' -test-xegpu-propagate-layouts="layout-kind=subgroup" -split-input-file %s | FileCheck %s
 
 gpu.module @test {
   // CHECK-LABEL: store_nd
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index b88d8e1a78a26..3e7f3d5156d62 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -xevm-attach-target='chip=pvc' -xegpu-propagate-layout="layout-kind=lane" -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -xevm-attach-target='chip=pvc' -test-xegpu-propagate-layouts="layout-kind=lane" -split-input-file %s | FileCheck %s
 
 gpu.module @test {
 // CHECK-LABEL: func.func @dpas_f16(
@@ -32,7 +32,7 @@ func.func @dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: me
 gpu.module @test {
 // CHECK-LABEL: func.func @dpas_i8(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: vector<8x32xi8>, %[[ARG1:[0-9a-zA-Z]+]]: vector<32x16xi8>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xi32>) {
-// CHECK: %[[T0:.*]] = xegpu.dpas %[[ARG0]], %[[ARG1]] {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} 
+// CHECK: %[[T0:.*]] = xegpu.dpas %[[ARG0]], %[[ARG1]] {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
 
 func.func @dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memref<8x16xi32>) {
   %c0 = arith.constant 0 : index
@@ -109,7 +109,7 @@ gpu.module @test {
 // CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK-NEXT: %[[T2:.*]] = xegpu.create_tdesc %[[ARG1]], %[[CST]] : memref<256xf16>, vector<16xindex> ->
 // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
-// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]]  <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}> 
+// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]]  <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
 // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x16xf16>
 func.func @load_gather_with_chunksize(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
@@ -240,7 +240,7 @@ gpu.module @test {
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
 // CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<12> : vector<16xindex>
-// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] 
+// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
 // CHECK-SAME: memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
 // CHECK: %[[ADD_RES:.*]] = arith.addf %[[LOAD_VEC]], %[[LOAD_VEC]] {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} : vector<16xf16>
 // CHECK: xegpu.store %[[ADD_RES]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
@@ -697,4 +697,4 @@ func.func @vector_broadcast_scalar_to_vector(%arg0: !xegpu.tensor_desc<16x16xf16
   xegpu.store_nd %6, %arg0  : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
   return
 }
-}
\ No newline at end of file
+}
diff --git a/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir b/mlir/test/Diale...
[truncated]

github-actions bot commented Dec 19, 2025

🐧 Linux x64 Test Results

  • 7246 tests passed
  • 598 tests skipped

✅ The build succeeded and all tests passed.
