Skip to content

Commit eb3ca15

Browse files
committed
[mlir] Add a builtin distinct attribute
A distinct attribute associates a referenced attribute to a unique identifier. Every call to its create function allocates a new distinct attribute instance. The address of the attribute instance temporarily serves as its unique identifier. Similar to the names of SSA values, the final unique identifiers are generated during pretty printing. Examples: #distinct = distinct[0]<42.0 : f32> #distinct1 = distinct[1]<42.0 : f32> #distinct2 = distinct[2]<array<i32: 10, 42>> This mechanism is meant to generate attributes with a unique identifier, which can be used to mark groups of operations that share a common properties such as if they are aliasing. The design of the distinct attribute ensures minimal memory footprint per distinct attribute since it only contains a reference to another attribute. All distinct attributes are stored outside of the storage uniquer in a thread local store that is part of the context. It uses one bump pointer allocator per thread to ensure distinct attributes can be created in-parallel. Differential Revision: https://reviews.llvm.org/D153360
1 parent 3dd319e commit eb3ca15

23 files changed

+487
-7
lines changed

mlir/docs/Dialects/Builtin.md

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,37 @@ Operations.
2323

2424
[include "Dialects/BuiltinLocationAttributes.md"]
2525

26+
## DistinctAttribute
27+
28+
A DistinctAttribute associates an attribute to a unique identifier.
29+
As a result, multiple DistinctAttribute instances may point to the same
30+
attribute. Every call to the `create` function allocates a new
31+
DistinctAttribute instance. The address of the attribute instance serves as a
32+
temporary unique identifier. Similar to the names of SSA values, the final
33+
unique identifiers are generated during pretty printing. This delayed
34+
numbering ensures the printed identifiers are deterministic even if
35+
multiple DistinctAttribute instances are created in-parallel.
36+
37+
Syntax:
38+
39+
```
40+
distinct-id ::= integer-literal
41+
distinct-attribute ::= `distinct` `[` distinct-id `]<` attribute `>`
42+
```
43+
44+
Examples:
45+
46+
```mlir
47+
#distinct = distinct[0]<42.0 : f32>
48+
#distinct1 = distinct[1]<42.0 : f32>
49+
#distinct2 = distinct[2]<array<i32: 10, 42>>
50+
```
51+
52+
This mechanism is meant to generate attributes with a unique
53+
identifier, which can be used to mark groups of operations that share a
54+
common property. For example, groups of aliasing memory operations may be
55+
marked using one DistinctAttribute instance per alias group.
56+
2657
## Operations
2758

2859
[include "Dialects/BuiltinOps.md"]

mlir/include/mlir/IR/AttributeSupport.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,12 +149,14 @@ class AbstractAttribute {
149149

150150
namespace detail {
151151
class AttributeUniquer;
152+
class DistinctAttributeUniquer;
152153
} // namespace detail
153154

154155
/// Base storage class appearing in an attribute. Derived storage classes should
155156
/// only be constructed within the context of the AttributeUniquer.
156157
class alignas(8) AttributeStorage : public StorageUniquer::BaseStorage {
157158
friend detail::AttributeUniquer;
159+
friend detail::DistinctAttributeUniquer;
158160
friend StorageUniquer;
159161

160162
public:

mlir/include/mlir/IR/BuiltinAttributes.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,6 +1000,43 @@ auto SparseElementsAttr::try_value_begin_impl(OverloadToken<T>) const
10001000
return iterator<T>(llvm::seq<ptrdiff_t>(0, getNumElements()).begin(), mapFn);
10011001
}
10021002

1003+
//===----------------------------------------------------------------------===//
1004+
// DistinctAttr
1005+
//===----------------------------------------------------------------------===//
1006+
1007+
namespace detail {
1008+
struct DistinctAttrStorage;
1009+
class DistinctAttributeUniquer;
1010+
} // namespace detail
1011+
1012+
/// An attribute that associates a referenced attribute to a unique identifier.
1013+
/// Every call to the create function allocates a new distinct attribute
1014+
/// instance. The address of the attribute instance serves as a temporary
1015+
/// identifier. Similar to the names of SSA values, the final identifiers are
1016+
/// generated during pretty printing. This delayed numbering ensures the
1017+
/// printed identifiers are deterministic even if multiple distinct attribute
1018+
/// instances are created in-parallel.
1019+
///
1020+
/// Examples:
1021+
///
1022+
/// #distinct = distinct[0]<42.0 : f32>
1023+
/// #distinct1 = distinct[1]<42.0 : f32>
1024+
/// #distinct2 = distinct[2]<array<i32: 10, 42>>
1025+
class DistinctAttr
1026+
: public detail::StorageUserBase<DistinctAttr, Attribute,
1027+
detail::DistinctAttrStorage,
1028+
detail::DistinctAttributeUniquer> {
1029+
public:
1030+
using Base::Base;
1031+
1032+
/// Returns the referenced attribute.
1033+
Attribute getReferencedAttr() const;
1034+
1035+
/// Creates a distinct attribute that associates a referenced attribute to a
1036+
/// unique identifier.
1037+
static DistinctAttr create(Attribute referencedAttr);
1038+
};
1039+
10031040
//===----------------------------------------------------------------------===//
10041041
// StringAttr
10051042
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/BuiltinDialectBytecode.td

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,10 @@ def SparseElementsAttr : DialectAttribute<(attr
181181
DenseElementsAttr:$values
182182
)>;
183183

184+
def DistinctAttr : DialectAttribute<(attr
185+
Attribute:$referencedAttr
186+
)>;
187+
184188
// Types
185189
// -----
186190

@@ -316,7 +320,8 @@ def BuiltinDialectAttributes : DialectAttributes<"Builtin"> {
316320
DenseArrayAttr,
317321
DenseIntOrFPElementsAttr,
318322
DenseStringElementsAttr,
319-
SparseElementsAttr
323+
SparseElementsAttr,
324+
DistinctAttr
320325
];
321326
}
322327

mlir/lib/AsmParser/AttributeParser.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ using namespace mlir::detail;
4646
/// `:` (tensor-type | vector-type)
4747
/// | `strided` `<` `[` comma-separated-int-or-question `]`
4848
/// (`,` `offset` `:` integer-literal)? `>`
49+
/// | distinct-attribute
4950
/// | extended-attribute
5051
///
5152
Attribute Parser::parseAttribute(Type type) {
@@ -155,6 +156,10 @@ Attribute Parser::parseAttribute(Type type) {
155156
case Token::kw_strided:
156157
return parseStridedLayoutAttr();
157158

159+
// Parse a distinct attribute.
160+
case Token::kw_distinct:
161+
return parseDistinctAttr(type);
162+
158163
// Parse a string attribute.
159164
case Token::string: {
160165
auto val = getToken().getStringValue();
@@ -1214,3 +1219,54 @@ Attribute Parser::parseStridedLayoutAttr() {
12141219
return StridedLayoutAttr::get(getContext(), *offset, strides);
12151220
// return getChecked<StridedLayoutAttr>(loc,getContext(), *offset, strides);
12161221
}
1222+
1223+
/// Parse a distinct attribute.
1224+
///
1225+
/// distinct-attribute ::= `distinct`
1226+
/// `[` integer-literal `]<` attribute-value `>`
1227+
///
1228+
Attribute Parser::parseDistinctAttr(Type type) {
1229+
consumeToken(Token::kw_distinct);
1230+
if (parseToken(Token::l_square, "expected '[' after 'distinct'"))
1231+
return {};
1232+
1233+
// Parse the distinct integer identifier.
1234+
Token token = getToken();
1235+
if (parseToken(Token::integer, "expected distinct ID"))
1236+
return {};
1237+
std::optional<uint64_t> value = token.getUInt64IntegerValue();
1238+
if (!value) {
1239+
emitError("expected an unsigned 64-bit integer");
1240+
return {};
1241+
}
1242+
1243+
// Parse the referenced attribute.
1244+
if (parseToken(Token::r_square, "expected ']' to close distinct ID") ||
1245+
parseToken(Token::less, "expected '<' after distinct ID"))
1246+
return {};
1247+
Attribute referencedAttr = parseAttribute(type);
1248+
if (!referencedAttr) {
1249+
emitError("expected attribute");
1250+
return {};
1251+
}
1252+
1253+
// Add the distinct attribute to the parser state, if it has not been parsed
1254+
// before. Otherwise, check if the parsed reference attribute matches the one
1255+
// found in the parser state.
1256+
DenseMap<uint64_t, DistinctAttr> &distinctAttrs =
1257+
state.symbols.distinctAttributes;
1258+
auto it = distinctAttrs.find(*value);
1259+
if (it == distinctAttrs.end()) {
1260+
DistinctAttr distinctAttr = DistinctAttr::create(referencedAttr);
1261+
it = distinctAttrs.try_emplace(*value, distinctAttr).first;
1262+
} else if (it->getSecond().getReferencedAttr() != referencedAttr) {
1263+
emitError("referenced attribute does not match previous definition: ")
1264+
<< it->getSecond().getReferencedAttr();
1265+
return {};
1266+
}
1267+
1268+
if (parseToken(Token::greater, "expected '>' to close distinct attribute"))
1269+
return {};
1270+
1271+
return it->getSecond();
1272+
}

mlir/lib/AsmParser/Parser.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,9 @@ class Parser {
250250
/// Parse an attribute dictionary.
251251
ParseResult parseAttributeDict(NamedAttrList &attributes);
252252

253+
/// Parse a distinct attribute.
254+
Attribute parseDistinctAttr(Type type);
255+
253256
/// Parse an extended attribute.
254257
Attribute parseExtendedAttr(Type type);
255258

mlir/lib/AsmParser/ParserState.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ struct SymbolState {
3636
DenseMap<const OpAsmDialectInterface *,
3737
llvm::StringMap<std::pair<std::string, AsmDialectResourceHandle>>>
3838
dialectResources;
39+
40+
/// A map from unique integer identifier to DistinctAttr.
41+
DenseMap<uint64_t, DistinctAttr> distinctAttributes;
3942
};
4043

4144
//===----------------------------------------------------------------------===//

mlir/lib/AsmParser/TokenKinds.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ TOK_KEYWORD(ceildiv)
8989
TOK_KEYWORD(complex)
9090
TOK_KEYWORD(dense)
9191
TOK_KEYWORD(dense_resource)
92+
TOK_KEYWORD(distinct)
9293
TOK_KEYWORD(f16)
9394
TOK_KEYWORD(f32)
9495
TOK_KEYWORD(f64)

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -806,6 +806,8 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
806806
} else if (llvm::isa<AffineMapAttr, DenseArrayAttr, FloatAttr, IntegerAttr,
807807
IntegerSetAttr, UnitAttr>(attr)) {
808808
return;
809+
} else if (auto distinctAttr = dyn_cast<DistinctAttr>(attr)) {
810+
printAttribute(distinctAttr.getReferencedAttr());
809811
} else if (auto dictAttr = dyn_cast<DictionaryAttr>(attr)) {
810812
for (const NamedAttribute &nestedAttr : dictAttr.getValue()) {
811813
printAttribute(nestedAttr.getName());
@@ -1604,6 +1606,31 @@ StringRef SSANameState::uniqueValueName(StringRef name) {
16041606
return name;
16051607
}
16061608

1609+
//===----------------------------------------------------------------------===//
1610+
// DistinctState
1611+
//===----------------------------------------------------------------------===//
1612+
1613+
namespace {
1614+
/// This class manages the state for distinct attributes.
1615+
class DistinctState {
1616+
public:
1617+
/// Returns a unique identifier for the given distinct attribute.
1618+
uint64_t getId(DistinctAttr distinctAttr);
1619+
1620+
private:
1621+
uint64_t distinctCounter = 0;
1622+
DenseMap<DistinctAttr, uint64_t> distinctAttrMap;
1623+
};
1624+
} // namespace
1625+
1626+
uint64_t DistinctState::getId(DistinctAttr distinctAttr) {
1627+
auto [it, inserted] =
1628+
distinctAttrMap.try_emplace(distinctAttr, distinctCounter);
1629+
if (inserted)
1630+
distinctCounter++;
1631+
return it->getSecond();
1632+
}
1633+
16071634
//===----------------------------------------------------------------------===//
16081635
// Resources
16091636
//===----------------------------------------------------------------------===//
@@ -1715,6 +1742,9 @@ class AsmStateImpl {
17151742
/// Get the state used for SSA names.
17161743
SSANameState &getSSANameState() { return nameState; }
17171744

1745+
/// Get the state used for distinct attribute identifiers.
1746+
DistinctState &getDistinctState() { return distinctState; }
1747+
17181748
/// Return the dialects within the context that implement
17191749
/// OpAsmDialectInterface.
17201750
DialectInterfaceCollection<OpAsmDialectInterface> &getDialectInterfaces() {
@@ -1758,6 +1788,9 @@ class AsmStateImpl {
17581788
/// The state used for SSA value names.
17591789
SSANameState nameState;
17601790

1791+
/// The state used for distinct attribute identifiers.
1792+
DistinctState distinctState;
1793+
17611794
/// Flags that control op output.
17621795
OpPrintingFlags printerFlags;
17631796

@@ -2106,6 +2139,11 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
21062139
} else if (llvm::isa<UnitAttr>(attr)) {
21072140
os << "unit";
21082141
return;
2142+
} else if (auto distinctAttr = llvm::dyn_cast<DistinctAttr>(attr)) {
2143+
os << "distinct[" << state.getDistinctState().getId(distinctAttr) << "]<";
2144+
printAttribute(distinctAttr.getReferencedAttr());
2145+
os << '>';
2146+
return;
21092147
} else if (auto dictAttr = llvm::dyn_cast<DictionaryAttr>(attr)) {
21102148
os << '{';
21112149
interleaveComma(dictAttr.getValue(),

mlir/lib/IR/AttributeDetail.h

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@
1414
#define ATTRIBUTEDETAIL_H_
1515

1616
#include "mlir/IR/AffineMap.h"
17+
#include "mlir/IR/AttributeSupport.h"
1718
#include "mlir/IR/BuiltinAttributes.h"
1819
#include "mlir/IR/BuiltinTypes.h"
1920
#include "mlir/IR/IntegerSet.h"
2021
#include "mlir/IR/MLIRContext.h"
2122
#include "mlir/Support/StorageUniquer.h"
23+
#include "mlir/Support/ThreadLocalCache.h"
2224
#include "llvm/ADT/APFloat.h"
2325
#include "llvm/ADT/PointerIntPair.h"
2426
#include "llvm/Support/TrailingObjects.h"
@@ -349,6 +351,70 @@ struct StringAttrStorage : public AttributeStorage {
349351
Dialect *referencedDialect;
350352
};
351353

354+
//===----------------------------------------------------------------------===//
355+
// DistinctAttr
356+
//===----------------------------------------------------------------------===//
357+
358+
/// An attribute to store a distinct reference to another attribute.
359+
struct DistinctAttrStorage : public AttributeStorage {
360+
using KeyTy = Attribute;
361+
362+
DistinctAttrStorage(Attribute referencedAttr)
363+
: referencedAttr(referencedAttr) {}
364+
365+
/// Returns the referenced attribute as key.
366+
KeyTy getAsKey() const { return KeyTy(referencedAttr); }
367+
368+
/// The referenced attribute.
369+
Attribute referencedAttr;
370+
};
371+
372+
/// A specialized attribute uniquer for distinct attributes that always
373+
/// allocates since the distinct attribute instances use the address of their
374+
/// storage as unique identifier.
375+
class DistinctAttributeUniquer {
376+
public:
377+
/// Creates a distinct attribute storage. Allocates every time since the
378+
/// address of the storage serves as unique identifier.
379+
template <typename T, typename... Args>
380+
static T get(MLIRContext *context, Args &&...args) {
381+
static_assert(std::is_same_v<typename T::ImplType, DistinctAttrStorage>,
382+
"expects a distinct attribute storage");
383+
DistinctAttrStorage *storage = DistinctAttributeUniquer::allocateStorage(
384+
context, std::forward<Args>(args)...);
385+
storage->initializeAbstractAttribute(
386+
AbstractAttribute::lookup(DistinctAttr::getTypeID(), context));
387+
return storage;
388+
}
389+
390+
private:
391+
/// Allocates a distinct attribute storage.
392+
static DistinctAttrStorage *allocateStorage(MLIRContext *context,
393+
Attribute referencedAttr);
394+
};
395+
396+
/// An allocator for distinct attribute storage instances. It uses thread local
397+
/// bump pointer allocators stored in a thread local cache to ensure the storage
398+
/// is freed after the destruction of the distinct attribute allocator.
399+
class DistinctAttributeAllocator {
400+
public:
401+
DistinctAttributeAllocator() = default;
402+
403+
DistinctAttributeAllocator(DistinctAttributeAllocator &&) = delete;
404+
DistinctAttributeAllocator(const DistinctAttributeAllocator &) = delete;
405+
DistinctAttributeAllocator &
406+
operator=(const DistinctAttributeAllocator &) = delete;
407+
408+
/// Allocates a distinct attribute storage using a thread local bump pointer
409+
/// allocator to enable synchronization free parallel allocations.
410+
DistinctAttrStorage *allocate(Attribute referencedAttr) {
411+
return new (allocatorCache.get().Allocate<DistinctAttrStorage>())
412+
DistinctAttrStorage(referencedAttr);
413+
}
414+
415+
private:
416+
ThreadLocalCache<llvm::BumpPtrAllocator> allocatorCache;
417+
};
352418
} // namespace detail
353419
} // namespace mlir
354420

0 commit comments

Comments
 (0)