8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -9,6 +9,8 @@
#ifndef MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/LogicalResult.h"
@@ -91,6 +93,12 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns);
void populateXeGPUUnrollPatterns(RewritePatternSet &patterns,
const UnrollOptions &options);

enum class LayoutKind { Lane, InstData, Subgroup };
LogicalResult propagateLayouts(OpBuilder &builder, Operation *target,
LayoutKind layoutKind, bool printOnly = false);

LogicalResult resolveLayoutConflicts(Operation *target);

} // namespace xegpu
} // namespace mlir

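Note: the two entry points declared above (xegpu::propagateLayouts and xegpu::resolveLayoutConflicts) are the same ones the reworked XeGPUPropagateLayoutPass invokes later in this diff. A minimal sketch of how downstream code might drive them, assuming only the signatures shown above; the helper name assignAndResolveLayouts is illustrative and not part of this patch:

#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Propagate lane-level layouts over the region rooted at `root`, then rewrite
// any producers whose tensor-descriptor layout disagrees with the consumer's
// anchor layout.
static LogicalResult assignAndResolveLayouts(Operation *root) {
  OpBuilder builder(root->getContext());
  if (failed(xegpu::propagateLayouts(builder, root, xegpu::LayoutKind::Lane,
                                     /*printOnly=*/false)))
    return failure();
  return xegpu::resolveLayoutConflicts(root);
}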
157 changes: 124 additions & 33 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -37,6 +38,7 @@
#include "llvm/Support/raw_ostream.h"

#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/Support/WalkResult.h"

namespace mlir {
namespace xegpu {
@@ -53,8 +55,6 @@ using namespace mlir::dataflow;

namespace {

enum class LayoutKind { Lane, InstData, Subgroup };

//===----------------------------------------------------------------------===//
// LayoutInfo
//===----------------------------------------------------------------------===//
@@ -380,7 +380,7 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
class LayoutInfoPropagation
: public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
private:
LayoutKind layoutKind;
xegpu::LayoutKind layoutKind;
void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);

@@ -436,7 +436,7 @@ class LayoutInfoPropagation
public:
LayoutInfoPropagation(DataFlowSolver &solver,
SymbolTableCollection &symbolTable,
LayoutKind layoutKind)
xegpu::LayoutKind layoutKind)
: SparseBackwardDataFlowAnalysis(solver, symbolTable),
layoutKind(layoutKind) {}
using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
@@ -526,12 +526,12 @@ bool LayoutInfoPropagation::hasParamsOfLayoutKind(
if (anchorLayout == nullptr) {
return false;
}
if (layoutKind == LayoutKind::InstData) {
if (layoutKind == xegpu::LayoutKind::InstData) {
return !(anchorLayout.getEffectiveInstDataAsInt().empty());
} else if (layoutKind == LayoutKind::Lane) {
} else if (layoutKind == xegpu::LayoutKind::Lane) {
return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
anchorLayout.getEffectiveLaneDataAsInt().empty());
} else if (layoutKind == LayoutKind::Subgroup) {
} else if (layoutKind == xegpu::LayoutKind::Subgroup) {
return !(anchorLayout.getEffectiveSgLayoutAsInt().empty() ||
anchorLayout.getEffectiveSgDataAsInt().empty());
}
@@ -579,7 +579,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
instData = {instHeight, instWidth};
}

if (layoutKind == LayoutKind::InstData)
if (layoutKind == xegpu::LayoutKind::InstData)
prefetchLayout =
LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
else
@@ -748,7 +748,7 @@ void LayoutInfoPropagation::visitDpasOp(
SmallVector<int> instDataA = {maxALen, subgroupSize};
SmallVector<int> instDataB = {subgroupSize, maxBLen};

if (layoutKind == LayoutKind::InstData) {
if (layoutKind == xegpu::LayoutKind::InstData) {
dpasALayout =
LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
dpasBLayout =
@@ -762,7 +762,7 @@

if (operands.size() > 2) {
VectorType cTy = dpas.getAccType();
if (layoutKind == LayoutKind::InstData) {
if (layoutKind == xegpu::LayoutKind::InstData) {
const unsigned dataCLen = bTy.getShape().back();
auto supportedCLen =
uArchInstruction->getSupportedN(bTy.getElementType());
@@ -832,7 +832,7 @@ void LayoutInfoPropagation::visitStoreNdOp(
instData = {instHeight, instWidth};
}

if (layoutKind == LayoutKind::InstData)
if (layoutKind == xegpu::LayoutKind::InstData)
storeLayout =
LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
else
@@ -992,7 +992,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
instData.push_back(chunkSize);
}

if (layoutKind == LayoutKind::InstData)
if (layoutKind == xegpu::LayoutKind::InstData)
loadLayout =
LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
else
@@ -1055,7 +1055,7 @@ void LayoutInfoPropagation::visitStoreScatterOp(
auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
const int subgroupSize = uArch->getSubgroupSize();

if (layoutKind == LayoutKind::InstData) {
if (layoutKind == xegpu::LayoutKind::InstData) {
SmallVector<int> instData{subgroupSize};
if (auto chunkSize = storeScatter.getChunkSize().value_or(0);
chunkSize > 1)
@@ -1106,7 +1106,8 @@ class RunLayoutInfoPropagation
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)

RunLayoutInfoPropagation(Operation *op, LayoutKind layoutKind) : target(op) {
RunLayoutInfoPropagation(Operation *op, xegpu::LayoutKind layoutKind)
: target(op) {
SymbolTableCollection symbolTable;
loadBaselineAnalyses(solver);
solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
@@ -1180,6 +1181,77 @@ void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
printFunctionResult(funcOp);
}

namespace {

//===----------------------------------------------------------------------===//
// ResolveLayoutConflicts
//===----------------------------------------------------------------------===//
struct ResolveLayoutConflicts {
ResolveLayoutConflicts(Operation *parentOp)
: parentOp(parentOp), builder(parentOp->getContext()) {}
LogicalResult run();

private:
Operation *parentOp;
OpBuilder builder;
LogicalResult resolveLoadNdOp(xegpu::LoadNdOp loadNdOp);
};

} // namespace

LogicalResult ResolveLayoutConflicts::run() {
auto r = parentOp->walk([&](Operation *op) -> WalkResult {
TypeSwitch<Operation *>(op).Case([&](xegpu::LoadNdOp loadNdOp) {
return failed(resolveLoadNdOp(loadNdOp)) ? WalkResult::interrupt()
: WalkResult::advance();
});
// TODO: Add other layout conflict resolution methods as needed.
return WalkResult::advance();
});

return r.wasInterrupted() ? failure() : success();
}

/// LoadNd has a conflict if the tensor descriptor layout is different from the
/// load's anchor layout.
LogicalResult
ResolveLayoutConflicts::resolveLoadNdOp(xegpu::LoadNdOp loadNdOp) {
Attribute anchorLayout = loadNdOp.getLayoutAttr();
Attribute tdescLayout = loadNdOp.getTensorDescType().getLayout();

if (anchorLayout && tdescLayout && anchorLayout != tdescLayout) {
// Try to get the defining CreateNdDescOp of the tensor descriptor.
auto conflictingCreateNdOp =
loadNdOp.getTensorDesc().getDefiningOp<xegpu::CreateNdDescOp>();
if (!conflictingCreateNdOp) {
DBGS() << "Unable to find defining CreateNdDescOp for tensor descriptor: "
<< loadNdOp.getTensorDesc() << "\n";
return failure();
}
// Duplicate the CreateNdDescOp with the expected layout.
builder.setInsertionPointAfter(conflictingCreateNdOp);
xegpu::TensorDescType tdescType = loadNdOp.getTensorDescType();
auto expectedLayout = anchorLayout;
auto newTensorDescType = xegpu::TensorDescType::get(
conflictingCreateNdOp.getContext(), tdescType.getShape(),
tdescType.getElementType(), tdescType.getEncoding(), expectedLayout);
xegpu::CreateNdDescOp newOp = xegpu::CreateNdDescOp::create(
builder, loadNdOp.getLoc(), newTensorDescType,
conflictingCreateNdOp->getOperands(),
conflictingCreateNdOp->getAttrs());
// Replace only the conflicting uses of the createNdOp that can be
// resolved using the new layout.
conflictingCreateNdOp->replaceUsesWithIf(
ArrayRef<Value>(newOp.getResult()), [&](OpOperand &opnd) {
auto userLoadNdOp = dyn_cast<xegpu::LoadNdOp>(opnd.getOwner());
if (!userLoadNdOp)
return false;
return userLoadNdOp.getLayoutAttr() == expectedLayout;
});
}
return success();
}

using GetLayoutFnTy = function_ref<xegpu::DistributeLayoutAttr(Value)>;
/// Update an operation with the layout of its results. If the result type is
/// a vector type, a temporary layout attribute is added to the operation. If
@@ -1348,26 +1420,14 @@ struct XeGPUPropagateLayoutPass final

} // namespace

void XeGPUPropagateLayoutPass::runOnOperation() {
LayoutKind layoutKind;
if (this->layoutKind == "lane") {
layoutKind = LayoutKind::Lane;
} else if (this->layoutKind == "inst") {
layoutKind = LayoutKind::InstData;
} else if (this->layoutKind == "subgroup") {
layoutKind = LayoutKind::Subgroup;
} else {
getOperation()->emitError("Unsupported layout kind option: " +
this->layoutKind);
signalPassFailure();
return;
}
RunLayoutInfoPropagation analysis(getOperation(), layoutKind);
LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
LayoutKind layoutKind, bool printOnly) {
RunLayoutInfoPropagation analysis(target, layoutKind);
// Print the analysis result and exit. (for debugging purposes)
if (printOnly) {
auto &os = llvm::outs();
analysis.printAnalysisResult(os);
return;
return success();
}
// Helper to convert LayoutInfo to xegpu::LayoutAttr.
auto getXeGPULayoutForValue = [&](Value val) -> xegpu::DistributeLayoutAttr {
@@ -1381,8 +1441,7 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
return cast<xegpu::LayoutAttr>(layoutAttr);
};

mlir::OpBuilder builder(&getContext());
Operation *op = getOperation();
Operation *op = target;
auto walkResult = op->walk([&](mlir::Block *block) -> WalkResult {
for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
LogicalResult r = success();
@@ -1407,7 +1466,39 @@
}
return WalkResult::advance();
});
if (walkResult.wasInterrupted()) {
if (walkResult.wasInterrupted())
return failure();

return success();
}

LogicalResult xegpu::resolveLayoutConflicts(Operation *target) {
ResolveLayoutConflicts resolver(target);
return resolver.run();
}

void XeGPUPropagateLayoutPass::runOnOperation() {
xegpu::LayoutKind layoutKind;
if (this->layoutKind == "lane") {
layoutKind = xegpu::LayoutKind::Lane;
} else if (this->layoutKind == "inst") {
layoutKind = xegpu::LayoutKind::InstData;
} else if (this->layoutKind == "subgroup") {
layoutKind = xegpu::LayoutKind::Subgroup;
} else {
getOperation()->emitError("Unsupported layout kind option: " +
this->layoutKind);
signalPassFailure();
return;
}
OpBuilder builder(&getContext());
if (failed(xegpu::propagateLayouts(builder, getOperation(), layoutKind,
this->printOnly))) {
signalPassFailure();
return;
}
// Resolve layout conflicts if any.
if (failed(xegpu::resolveLayoutConflicts(getOperation()))) {
signalPassFailure();
return;
}
2 changes: 1 addition & 1 deletion mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -1,4 +1,4 @@
// RUN: mlir-opt -xevm-attach-target='chip=pvc' -xegpu-propagate-layout="layout-kind=inst" -split-input-file %s | FileCheck %s
// RUN: mlir-opt -xevm-attach-target='chip=pvc' -test-xegpu-propagate-layouts="layout-kind=inst" -split-input-file %s | FileCheck %s


// CHECK-LABEL: func.func @load_store_no_array_len(
2 changes: 1 addition & 1 deletion mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -1,4 +1,4 @@
// RUN: mlir-opt -xevm-attach-target='chip=pvc' -xegpu-propagate-layout="layout-kind=subgroup" -split-input-file %s | FileCheck %s
// RUN: mlir-opt -xevm-attach-target='chip=pvc' -test-xegpu-propagate-layouts="layout-kind=subgroup" -split-input-file %s | FileCheck %s

gpu.module @test {
// CHECK-LABEL: store_nd
10 changes: 5 additions & 5 deletions mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -1,4 +1,4 @@
// RUN: mlir-opt -xevm-attach-target='chip=pvc' -xegpu-propagate-layout="layout-kind=lane" -split-input-file %s | FileCheck %s
// RUN: mlir-opt -xevm-attach-target='chip=pvc' -test-xegpu-propagate-layouts="layout-kind=lane" -split-input-file %s | FileCheck %s

gpu.module @test {
// CHECK-LABEL: func.func @dpas_f16(
@@ -32,7 +32,7 @@ func.func @dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: me
gpu.module @test {
// CHECK-LABEL: func.func @dpas_i8(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: vector<8x32xi8>, %[[ARG1:[0-9a-zA-Z]+]]: vector<32x16xi8>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xi32>) {
// CHECK: %[[T0:.*]] = xegpu.dpas %[[ARG0]], %[[ARG1]] {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
// CHECK: %[[T0:.*]] = xegpu.dpas %[[ARG0]], %[[ARG1]] {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}

func.func @dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memref<8x16xi32>) {
%c0 = arith.constant 0 : index
@@ -109,7 +109,7 @@ gpu.module @test {
// CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
// CHECK-NEXT: %[[T2:.*]] = xegpu.create_tdesc %[[ARG1]], %[[CST]] : memref<256xf16>, vector<16xindex> ->
// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x16xf16>
func.func @load_gather_with_chunksize(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
@@ -240,7 +240,7 @@ gpu.module @test {
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<true> : vector<16xi1>
// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<12> : vector<16xindex>
// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
// CHECK-SAME: memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
// CHECK: %[[ADD_RES:.*]] = arith.addf %[[LOAD_VEC]], %[[LOAD_VEC]] {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} : vector<16xf16>
// CHECK: xegpu.store %[[ADD_RES]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
@@ -697,4 +697,4 @@ func.func @vector_broadcast_scalar_to_vector(%arg0: !xegpu.tensor_desc<16x16xf16
xegpu.store_nd %6, %arg0 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
return
}
}
}
23 changes: 23 additions & 0 deletions mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir
@@ -0,0 +1,23 @@
// RUN: mlir-opt --test-xegpu-resolve-layout-conflicts -split-input-file %s | FileCheck %s

#load_lo = #xegpu.layout<inst_data = [8, 16]>
#prefetch_lo = #xegpu.layout<inst_data = [16, 16]>
gpu.module @test {

// CHECK-LABEL: func.func @load_nd_with_conflicting_tensor_desc
// CHECK: %{{.*}} = xegpu.create_nd_tdesc %{{.*}} : memref<64x64xf16>
// CHECK-SAME: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [16, 16]>>
// CHECK-NEXT: %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<64x64xf16>
// CHECK-SAME: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [8, 16]>>
// CHECK-NEXT: %{{.*}} = xegpu.load_nd %[[T1]][%{{.*}}, %{{.*}}] <{layout = #xegpu.layout<inst_data = [8, 16]>}> :
// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [8, 16]>> -> vector<16x16xf16>
func.func @load_nd_with_conflicting_tensor_desc(%arg0: memref<64x64xf16>) -> vector<16x16xf16> {
%c0 = arith.constant 0 : index
%0 = xegpu.create_nd_tdesc %arg0 : memref<64x64xf16>
-> !xegpu.tensor_desc<16x16xf16, #prefetch_lo>
%1 = xegpu.load_nd %0 [%c0, %c0] {layout = #load_lo} : !xegpu.tensor_desc<16x16xf16, #prefetch_lo>
-> vector<16x16xf16>
xegpu.prefetch_nd %0 [%c0, %c0] {layout = #prefetch_lo} : !xegpu.tensor_desc<16x16xf16, #prefetch_lo>
return %1 : vector<16x16xf16>
}
}
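Note: the -test-xegpu-resolve-layout-conflicts pass named in the RUN line of resolve-layout-conflicts.mlir above is not part of this excerpt. A minimal sketch of what such a test wrapper might look like, built only on the xegpu::resolveLayoutConflicts entry point declared in Transforms.h; the class name, description string, and choice of anchor operation are assumptions, not taken from the patch:

#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {
// Hypothetical test-only wrapper pass around the new conflict-resolution hook.
struct TestXeGPUResolveLayoutConflicts
    : public PassWrapper<TestXeGPUResolveLayoutConflicts, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPUResolveLayoutConflicts)

  StringRef getArgument() const final {
    return "test-xegpu-resolve-layout-conflicts";
  }
  StringRef getDescription() const final {
    return "Test XeGPU layout conflict resolution";
  }
  void runOnOperation() override {
    // Delegate to the transform entry point added in this patch.
    if (failed(xegpu::resolveLayoutConflicts(getOperation())))
      signalPassFailure();
  }
};
} // namespace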