-
Notifications
You must be signed in to change notification settings - Fork 15.5k
[mlir] Add dialect hooks for registering custom type and attribute alias printers #173091
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
…ias printers This patch introduces a mechanism for dialects to register custom alias printers for types and attributes via the `OpAsmDialectInterface`. This allows dialects to provide alternative printed representations for types and attributes based on their TypeID, including types/attributes from other dialects. The new `registerAttrAliasPrinter` and `registerTypeAliasPrinter` virtual methods accept callbacks that register printers for specific TypeIDs. When printing, these custom printers are invoked in registration order, and the first one to produce output is used. The precedence for alias resolution is: 1. Explicit type/attribute aliases returned by `getAlias` 2. Dialect-specific alias printers registered via the new hooks 3. Default type/attribute printers Signed-off-by: Fabian Mora <[email protected]>
3fc2f91 to
797c073
Compare
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core Author: Fabian Mora (fabianmcg) ChangesThis patch introduces a mechanism for dialects to register custom alias printers for types and attributes via the The new The precedence for alias resolution is:
Example: struct TestOpAsmInterface : public OpAsmDialectInterface {
// ...
void registerAttrAliasPrinter(InsertAttrAliasPrinter insertFn) const final {
insertFn(TypeID::get<TypeAttr>(),
[](Attribute attr, AsmPrinter &printer, bool printStripped) {
auto tTy = dyn_cast<TupleType>(cast<TypeAttr>(attr).getValue());
if (!tTy || tTy.size() != 1)
return;
if (auto iTy = dyn_cast<TestIntegerType>(tTy.getType(0));
!iTy || iTy.getWidth() != 7)
return;
printer.getStream() << "tuple_i7_from_attr";
});
}
void registerTypeAliasPrinter(InsertTypeAliasPrinter insertFn) const final {
insertFn(TypeID::get<TupleType>(),
[](Type type, AsmPrinter &printer, bool printStripped) {
auto tTy = dyn_cast<TupleType>(type);
if (!tTy || tTy.size() != 1)
return;
if (auto iTy = dyn_cast<TestIntegerType>(tTy.getType(0));
!iTy || iTy.getWidth() != 7)
return;
printer.getStream() << "tuple_i7";
});
}
// ...
};Resulting MLIR printing behavior. "test.op"() {types = [
tuple<!test.int<s, 7>>
]} : () -> (
tuple<!test.int<s, 7>>
)
// Prints as:
%0 = "test.op"() {types = [!test.tuple_i7_from_attr]} : () -> !test.tuple_i7NOTE: It's future work integrating this into ODS. Patch is 20.21 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/173091.diff 4 Files Affected:
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index d70aa346eaa1f..a6d39c4d17a8d 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -17,6 +17,8 @@
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/OpAsmSupport.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/TypeID.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/SMLoc.h"
#include <optional>
@@ -169,6 +171,9 @@ class AsmPrinter {
if (succeeded(printAlias(attrOrType)))
return;
+ if (succeeded(printDialectAlias(attrOrType, /*printStripped=*/true)))
+ return;
+
raw_ostream &os = getStream();
uint64_t posPrior = os.tell();
attrOrType.print(*this);
@@ -218,6 +223,14 @@ class AsmPrinter {
/// be printed.
virtual LogicalResult printAlias(Type type);
+ /// Print the alias for the given attribute, return failure if no alias could
+ /// be printed.
+ virtual LogicalResult printDialectAlias(Attribute attr, bool printStripped);
+
+ /// Print the alias for the given type, return failure if no alias could
+ /// be printed.
+ virtual LogicalResult printDialectAlias(Type type, bool printStripped);
+
/// Print the given string as a keyword, or a quoted and escaped string if it
/// has any special or non-printable characters in it.
virtual void printKeywordOrString(StringRef keyword);
@@ -1799,6 +1812,30 @@ class OpAsmDialectInterface
return AliasResult::NoAlias;
}
+ /// Hooks for registering alias printers for types and attributes. These
+ /// printers are invoked when printing types or attributes of the given
+ /// TypeID. Printers are invoked in the order they are registered, and the
+ /// first one to print an alias is used.
+ /// The precedence of these printers is as follow:
+ /// 1. The type and attribute aliases returned by `getAlias`.
+ /// 2. Dialect-specific alias printers registered here.
+ /// 3. The type and attribute printers.
+ /// The boolean argument to the printer indicates whether the stripped form
+ /// of the type or attribute is being printed.
+ /// NOTE: This mechanism caches the printed object, therefore the printer
+ /// must always produce the same output for the same input.
+ using AttributeAliasPrinter =
+ llvm::function_ref<void(Attribute, AsmPrinter &, bool)>;
+ using InsertAttrAliasPrinter =
+ llvm::function_ref<void(TypeID, AttributeAliasPrinter)>;
+ virtual void registerAttrAliasPrinter(InsertAttrAliasPrinter insertFn) const {
+ }
+ using TypeAliasPrinter = llvm::function_ref<void(Type, AsmPrinter &, bool)>;
+ using InsertTypeAliasPrinter =
+ llvm::function_ref<void(TypeID, TypeAliasPrinter)>;
+ virtual void registerTypeAliasPrinter(InsertTypeAliasPrinter insertFn) const {
+ }
+
//===--------------------------------------------------------------------===//
// Resources
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 7d991cea6c468..86fa754b3bb54 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -32,13 +32,16 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/PointerIntPair.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/Endian.h"
@@ -47,6 +50,7 @@
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/Threading.h"
#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
#include <type_traits>
#include <optional>
@@ -414,6 +418,7 @@ class AsmPrinter::Impl {
public:
Impl(raw_ostream &os, AsmStateImpl &state);
explicit Impl(Impl &other) : Impl(other.os, other.state) {}
+ explicit Impl(raw_ostream &os, Impl &other) : Impl(os, other.state) {}
/// Returns the output stream of the printer.
raw_ostream &getStream() { return os; }
@@ -447,6 +452,10 @@ class AsmPrinter::Impl {
/// be printed.
LogicalResult printAlias(Attribute attr);
+ /// Print the dialect alias for the given attribute, return failure if no
+ /// alias could be printed.
+ LogicalResult printDialectAlias(Attribute attr, bool printStripped);
+
/// Print the given type or an alias.
void printType(Type type);
/// Print the given type.
@@ -456,6 +465,10 @@ class AsmPrinter::Impl {
/// be printed.
LogicalResult printAlias(Type type);
+ /// Print the dialect alias for the given type, return failure if no alias
+ /// could be printed.
+ LogicalResult printDialectAlias(Type type, bool printStripped);
+
/// Print the given location to the stream. If `allowAlias` is true, this
/// allows for the internal location to use an attribute alias.
void printLocation(LocationAttr loc, bool allowAlias = false);
@@ -812,6 +825,12 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
initializer.visit(type);
return success();
}
+ LogicalResult printDialectAlias(Attribute attr, bool printStripped) override {
+ return failure();
+ }
+ LogicalResult printDialectAlias(Type type, bool printStripped) override {
+ return failure();
+ }
/// Consider the given location to be printed for an alias.
void printOptionalLocationSpecifier(Location loc) override {
@@ -991,6 +1010,11 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
return success();
}
+ LogicalResult printDialectAlias(Attribute, bool) override {
+ return failure();
+ }
+ LogicalResult printDialectAlias(Type, bool) override { return failure(); }
+
/// Record the alias result of a child element.
void recordAliasResult(std::pair<size_t, size_t> aliasDepthAndIndex) {
childIndices.push_back(aliasDepthAndIndex.second);
@@ -1251,7 +1275,10 @@ namespace {
/// This class manages the state for type and attribute aliases.
class AliasState {
public:
- // Initialize the internal aliases.
+ /// Initialize the alias state for custom dialect aliases.
+ AliasState(DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
+
+ /// Initialize the internal aliases.
void
initialize(Operation *op, const OpPrintingFlags &printerFlags,
DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
@@ -1275,20 +1302,99 @@ class AliasState {
printAliases(p, newLine, /*isDeferred=*/true);
}
+ /// Get an attribute alias if it exists. Returns the alias if found,
+ /// or a default constructed pair otherwise.
+ std::pair<const Dialect *, StringRef>
+ getAttrAlias(AsmPrinter::Impl &p, Attribute attr, bool printStripped) {
+ return getAlias(p, attr.getTypeID(), attr.getAsOpaquePointer(),
+ printStripped);
+ }
+
+ /// Get a type alias if it exists. Returns the alias if found,
+ /// or a default constructed pair otherwise.
+ std::pair<const Dialect *, StringRef>
+ getTypeAlias(AsmPrinter::Impl &p, Type type, bool printStripped) {
+ return getAlias(p, type.getTypeID(), type.getAsOpaquePointer(),
+ printStripped);
+ }
+
private:
+ using TypeIDPrinter =
+ std::tuple<TypeID, const Dialect *,
+ std::function<void(const void *, AsmPrinter &, bool)>>;
+ using PrinterIterator = SmallVectorImpl<TypeIDPrinter>::iterator;
+
/// Print all of the referenced aliases that support the provided resolution
/// behavior.
void printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
bool isDeferred);
+ /// Comparison function for TypeIDPrinter.
+ static bool comparePrinters(const TypeIDPrinter &lhs,
+ const TypeIDPrinter &rhs);
+
+ /// Find custom printers for the given TypeID.
+ llvm::iterator_range<PrinterIterator> findPrinters(TypeID typeID);
+
+ /// Get an attribute or type alias if it exists. Returns the alias if found,
+ /// or a default constructed pair otherwise.
+ std::pair<const Dialect *, StringRef> getAlias(AsmPrinter::Impl &p,
+ TypeID typeID,
+ const void *opaqueAttrType,
+ bool printStripped);
+
/// Mapping between attribute/type and alias.
llvm::MapVector<const void *, SymbolAlias> attrTypeToAlias;
/// An allocator used for alias names.
llvm::BumpPtrAllocator aliasAllocator;
+
+ /// Mapping between attribute/type ID and custom printers for them.
+ SmallVector<TypeIDPrinter> attrTypePrinters;
+
+ /// Cache for custom printed attributes/types.
+ DenseMap<llvm::PointerIntPair<const void *, 1, bool>,
+ std::pair<const Dialect *, std::string>>
+ printCache;
};
} // namespace
+AliasState::AliasState(
+ DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
+ // Collect all of the custom alias printers.
+ for (const OpAsmDialectInterface &interface : interfaces) {
+ auto insertAliasAttrFn =
+ [&](TypeID typeID,
+ OpAsmDialectInterface::AttributeAliasPrinter printer) {
+ if (!printer)
+ return;
+ attrTypePrinters.emplace_back(
+ typeID, interface.getDialect(),
+ [printer](const void *attr, AsmPrinter &p, bool printStripped) {
+ printer(Attribute::getFromOpaquePointer(attr), p,
+ printStripped);
+ });
+ };
+ auto insertAliasTypeFn =
+ [&](TypeID typeID, OpAsmDialectInterface::TypeAliasPrinter printer) {
+ if (!printer)
+ return;
+ attrTypePrinters.emplace_back(
+ typeID, interface.getDialect(),
+ [printer](const void *attr, AsmPrinter &p, bool printStripped) {
+ printer(Type::getFromOpaquePointer(attr), p, printStripped);
+ });
+ };
+ interface.registerAttrAliasPrinter(insertAliasAttrFn);
+ interface.registerTypeAliasPrinter(insertAliasTypeFn);
+ }
+
+ // Sort the printers by TypeID for efficient lookup.
+ // Stable sort guarantees that the order of registration is preserved.
+ std::stable_sort(attrTypePrinters.begin(), attrTypePrinters.end(),
+ comparePrinters);
+}
+
void AliasState::initialize(
Operation *op, const OpPrintingFlags &printerFlags,
DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
@@ -1315,6 +1421,24 @@ LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
return success();
}
+bool AliasState::comparePrinters(const TypeIDPrinter &lhs,
+ const TypeIDPrinter &rhs) {
+ return std::get<0>(lhs).getAsOpaquePointer() <
+ std::get<0>(rhs).getAsOpaquePointer();
+}
+
+llvm::iterator_range<AliasState::PrinterIterator>
+AliasState::findPrinters(TypeID typeID) {
+ TypeIDPrinter key = std::make_tuple(
+ typeID, /*unused*/ nullptr,
+ /*unused*/ std::function<void(const void *, AsmPrinter &, bool)>());
+ PrinterIterator lb = std::lower_bound(
+ attrTypePrinters.begin(), attrTypePrinters.end(), key, comparePrinters);
+ PrinterIterator ub = std::upper_bound(
+ attrTypePrinters.begin(), attrTypePrinters.end(), key, comparePrinters);
+ return llvm::make_range(lb, ub);
+}
+
void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
bool isDeferred) {
auto filterFn = [=](const auto &aliasIt) {
@@ -1342,6 +1466,40 @@ void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
}
}
+std::pair<const Dialect *, StringRef>
+AliasState::getAlias(AsmPrinter::Impl &p, TypeID typeID,
+ const void *opaqueAttrType, bool printStripped) {
+ llvm::PointerIntPair<const void *, 1, bool> key(opaqueAttrType,
+ printStripped);
+ // Check the cache first.
+ if (auto it = printCache.find(key); it != printCache.end())
+ return it->second;
+
+ // Try to get the alias using custom printers.
+ std::string buffer;
+ llvm::raw_string_ostream os(buffer);
+ AsmPrinter::Impl printImpl(os, p);
+ DialectAsmPrinter printer(printImpl);
+ for (const auto &printInfo : findPrinters(typeID)) {
+ // Invoke the printer.
+ std::get<2>(printInfo)(opaqueAttrType, printer, printStripped);
+
+ // Trim any whitespace.
+ if (StringRef str = StringRef(buffer).trim(); str != buffer)
+ buffer = str.str();
+
+ // If we printed something, cache and return.
+ if (!buffer.empty()) {
+ StringRef alias = (printCache[key] = std::make_pair(
+ std::get<1>(printInfo), std::move(buffer)))
+ .second;
+ return std::make_pair(std::get<1>(printInfo), alias);
+ }
+ buffer.clear();
+ }
+ return {nullptr, StringRef()};
+}
+
//===----------------------------------------------------------------------===//
// SSANameState
//===----------------------------------------------------------------------===//
@@ -1948,11 +2106,13 @@ class AsmStateImpl {
public:
explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags,
AsmState::LocationMap *locationMap)
- : interfaces(op->getContext()), nameState(op, printerFlags),
- printerFlags(printerFlags), locationMap(locationMap) {}
+ : interfaces(op->getContext()), aliasState(interfaces),
+ nameState(op, printerFlags), printerFlags(printerFlags),
+ locationMap(locationMap) {}
explicit AsmStateImpl(MLIRContext *ctx, const OpPrintingFlags &printerFlags,
AsmState::LocationMap *locationMap)
- : interfaces(ctx), printerFlags(printerFlags), locationMap(locationMap) {}
+ : interfaces(ctx), aliasState(interfaces), printerFlags(printerFlags),
+ locationMap(locationMap) {}
/// Initialize the alias state to enable the printing of aliases.
void initializeAliases(Operation *op) {
@@ -2377,6 +2537,38 @@ LogicalResult AsmPrinter::Impl::printAlias(Type type) {
return state.getAliasState().getAlias(type, os);
}
+LogicalResult AsmPrinter::Impl::printDialectAlias(Attribute attr,
+ bool printStripped) {
+ // Check to see if there is a dialect alias for this attribute.
+ auto [aliasDialect, alias] =
+ state.getAliasState().getAttrAlias(*this, attr, printStripped);
+ if (aliasDialect && !alias.empty()) {
+ if (printStripped) {
+ os << alias;
+ return success();
+ }
+ printDialectSymbol(os, "!", aliasDialect->getNamespace(), alias);
+ return success();
+ }
+ return failure();
+}
+
+LogicalResult AsmPrinter::Impl::printDialectAlias(Type type,
+ bool printStripped) {
+ // Check to see if there is a dialect alias for this type.
+ auto [aliasDialect, alias] =
+ state.getAliasState().getTypeAlias(*this, type, printStripped);
+ if (aliasDialect && !alias.empty()) {
+ if (printStripped) {
+ os << alias;
+ return success();
+ }
+ printDialectSymbol(os, "!", aliasDialect->getNamespace(), alias);
+ return success();
+ }
+ return failure();
+}
+
void AsmPrinter::Impl::printAttribute(Attribute attr,
AttrTypeElision typeElision) {
if (!attr) {
@@ -2387,6 +2579,8 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
// Try to print an alias for this attribute.
if (succeeded(printAlias(attr)))
return;
+ if (succeeded(printDialectAlias(attr, /*printStripped=*/false)))
+ return;
return printAttributeImpl(attr, typeElision);
}
void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
@@ -2715,6 +2909,8 @@ void AsmPrinter::Impl::printType(Type type) {
// Try to print an alias for this type.
if (succeeded(printAlias(type)))
return;
+ if (succeeded(printDialectAlias(type, /*printStripped=*/false)))
+ return;
return printTypeImpl(type);
}
@@ -2987,6 +3183,17 @@ LogicalResult AsmPrinter::printAlias(Type type) {
return impl->printAlias(type);
}
+LogicalResult AsmPrinter::printDialectAlias(Attribute attr,
+ bool printStripped) {
+ assert(impl && "expected AsmPrinter::printDialectAlias to be overriden");
+ return impl->printAlias(attr);
+}
+
+LogicalResult AsmPrinter::printDialectAlias(Type type, bool printStripped) {
+ assert(impl && "expected AsmPrinter::printDialectAlias to be overriden");
+ return impl->printAlias(type);
+}
+
void AsmPrinter::printAttributeWithoutType(Attribute attr) {
assert(impl &&
"expected AsmPrinter::printAttributeWithoutType to be overriden");
diff --git a/mlir/test/IR/print-attr-type-dialect-aliases.mlir b/mlir/test/IR/print-attr-type-dialect-aliases.mlir
new file mode 100644
index 0000000000000..95adb49d0a59a
--- /dev/null
+++ b/mlir/test/IR/print-attr-type-dialect-aliases.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+// Check that attr and type aliases are properly printed.
+
+// CHECK: {types = [!test.tuple_i7_from_attr, !test.tuple_i6_from_attr, tuple<!test.int<signed, 5>>]} : () -> (!test.tuple_i7, !test.tuple_i6, tuple<!test.int<signed, 5>>)
+"test.op"() {types = [
+ tuple<!test.int<s, 7>>,
+ tuple<!test.int<s, 6>>,
+ tuple<!test.int<s, 5>>
+ ]} : () -> (
+ tuple<!test.int<s, 7>>,
+ tuple<!test.int<s, 6>>,
+ tuple<!test.int<s, 5>>
+ )
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 3d4aa23ebe78a..8061fadb1da95 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -8,6 +8,8 @@
#include "TestDialect.h"
#include "TestOps.h"
+#include "TestTypes.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Interfaces/FoldInterfaces.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Transforms/InliningUtils.h"
@@ -262,6 +264,52 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
return AliasResult::NoAlias;
}
+ void registerAttrAliasPrinter(InsertAttrAliasPrinter insertFn) const final {
+ insertFn(TypeID::get<TypeAttr>(),
+ [](Attribute attr, AsmPrinter &printer, bool printStripped) {
+ auto tTy = dyn_cast<TupleType>(cast<TypeAttr>(attr).getValue());
+ if (!tTy || tTy.size() != 1)
+ return;
+ if (auto iTy = dyn_cast<TestIntegerType>(tTy.getType(0));
+ !iTy || iTy.getWidth() != 7)
+ return;
+ printer.getStream() << "tuple_i7_from_attr";
+ });
+ insertFn(TypeID::get<TypeAttr>(),
+ [](Attribute attr, AsmPrinter &printer, bool printStripped) {
+ auto tTy = dyn_cast<TupleType>(cast<TypeAttr>(attr).getValue());
+ if (!tTy || tTy.size() != 1)
+ return;
+ if (auto iTy = dyn_cast<TestIntegerType>(tTy.getType(0));
+ !iTy || iTy.getWidth() != 6)
+ return;
+ printer.getStream() << "tuple_i6_from_attr";
+ });
+ }
+
+ void registerTypeAliasPrinter(InsertTypeAliasPrinter insertFn) const final {
+ insertFn(TypeID::get<TupleType>(),
+ [](Type type, AsmPrinter &printer, bool printStripped) {
+ auto tTy = dyn_cast<TupleType>(type);
+ if (!tTy || tTy.size() != 1)
+ return;
+ if (auto iTy = dyn_cast<TestIntegerType>(tTy.getType(0));
+ !iTy || iTy.getWidth() != 7)
+ return;
+ printer.getStream() << "tuple_i7";
+ });
+ insertFn(TypeID::get<TupleType>(),
+ [](Type type, AsmPrinter &printer, bool printStripped) {
+ auto tTy = dyn_cast<TupleType>(type);
+ if (!tTy || tTy.size() != 1)
+ return;
+ if (auto iTy = dyn_cast<TestIntegerType>(tTy.getType(0));
+ !iTy || iTy.getWidth() != 6)
+ return;
+ printer.getStream(...
[truncated]
|
joker-eph
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add documentation please?
|
So this would mean that depending on what dialect is registered one gets different output. You have a file with only Ops from dialect A, but depending on if the tool registers dialect B or not, one gets different output? |
| // Check that attr and type aliases are properly printed. | ||
|
|
||
| // CHECK: {types = [!test.tuple_i7_from_attr, !test.tuple_i6_from_attr, tuple<!test.int<signed, 5>>]} : () -> (!test.tuple_i7, !test.tuple_i6, tuple<!test.int<signed, 5>>) | ||
| "test.op"() {types = [ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you attach this to a non-test op (like module), we need to verify that the parse can load the test dialect when it encounters an alias for an on-loaded dialect.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This patch doesn't affect parsing behavior. That is: !test.tuple_i7_from_attrcan't be parsed without additional logic and telling the dialect how to parse. This patch only affects how to print.
Parsing and automatic hook creation is part of the future work with ODS.
It is even further: it'll depends on if B gets actually loaded. I would be more comfortable with a solution that restricts this to only entities of the dialect itself. This kind of cross-dialect interactions have too many pitfalls. |
On the website pages or where? Because the functions have doc comments.
Yes, that's why the public functions are called dialect alias. The alias is only printed if the dialect adding the alias is loaded. |
I meant the markdown doc.
That's a problem to me, this is the kind of thing that belongs to the |
| /// Hooks for registering alias printers for types and attributes. These | ||
| /// printers are invoked when printing types or attributes of the given | ||
| /// TypeID. Printers are invoked in the order they are registered, and the | ||
| /// first one to print an alias is used. | ||
| /// The precedence of these printers is as follow: | ||
| /// 1. The type and attribute aliases returned by `getAlias`. | ||
| /// 2. Dialect-specific alias printers registered here. | ||
| /// 3. The type and attribute printers. | ||
| /// The boolean argument to the printer indicates whether the stripped form | ||
| /// of the type or attribute is being printed. | ||
| /// NOTE: This mechanism caches the printed object, therefore the printer | ||
| /// must always produce the same output for the same input. | ||
| using AttributeAliasPrinter = | ||
| llvm::function_ref<void(Attribute, AsmPrinter &, bool)>; | ||
| using InsertAttrAliasPrinter = | ||
| llvm::function_ref<void(TypeID, AttributeAliasPrinter)>; | ||
| virtual void registerAttrAliasPrinter(InsertAttrAliasPrinter insertFn) const { | ||
| } | ||
| using TypeAliasPrinter = llvm::function_ref<void(Type, AsmPrinter &, bool)>; | ||
| using InsertTypeAliasPrinter = | ||
| llvm::function_ref<void(TypeID, TypeAliasPrinter)>; | ||
| virtual void registerTypeAliasPrinter(InsertTypeAliasPrinter insertFn) const { | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@joker-eph answering here, the functions are part of OpAsmDialectInterface. The way it works is, the printer collects this hooks, and then invokes them.
This patch introduces a mechanism for dialects to register custom alias printers for types and attributes via the
OpAsmDialectInterface. This allows dialects to provide alternative printed representations for types and attributes based on their TypeID, including types/attributes from other dialects.The new
registerAttrAliasPrinterandregisterTypeAliasPrintervirtual methods accept callbacks that register printers for specific TypeIDs. When printing, these custom printers are invoked in registration order, and the first one to produce output is used.The precedence for alias resolution is:
getAliasExample:
Overriding code:
Resulting MLIR printing behavior.
NOTE: It's future work integrating this into ODS.