Skip to content

Commit 7b5b5e9

Browse files
authored
Merge pull request #81171 from slavapestov/fix-issue-81036
AST: Use weighted reduction order for opaque return types
2 parents 3c26321 + 8ab5ca2 commit 7b5b5e9

16 files changed

+175
-95
lines changed

include/swift/AST/SubstitutionMap.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,6 @@ class LookUpConformanceInSubstitutionMap {
306306
};
307307

308308
struct OverrideSubsInfo {
309-
ASTContext &Ctx;
310309
unsigned BaseDepth;
311310
unsigned OrigDepth;
312311
SubstitutionMap BaseSubMap;

include/swift/AST/Types.h

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7254,9 +7254,10 @@ class GenericTypeParamType : public SubstitutableType,
72547254
Identifier Name;
72557255
};
72567256

7257-
unsigned Depth : 15;
72587257
unsigned IsDecl : 1;
7259-
unsigned Index : 16;
7258+
unsigned Depth : 15;
7259+
unsigned Weight : 1;
7260+
unsigned Index : 15;
72607261

72617262
/// The kind of generic type parameter this is.
72627263
GenericTypeParamKind ParamKind;
@@ -7281,15 +7282,21 @@ class GenericTypeParamType : public SubstitutableType,
72817282
Type valueType, const ASTContext &ctx);
72827283

72837284
/// Retrieve a canonical generic type parameter with the given kind, depth,
7284-
/// index, and optional value type.
7285+
/// index, weight, and optional value type.
72857286
static GenericTypeParamType *get(GenericTypeParamKind paramKind,
7286-
unsigned depth, unsigned index,
7287+
unsigned depth, unsigned index, unsigned weight,
72877288
Type valueType, const ASTContext &ctx);
72887289

7289-
/// Retrieve a canonical generic type parameter at the given depth and index.
7290+
/// Retrieve a canonical generic type parameter at the given depth and index,
7291+
/// with weight 0.
72907292
static GenericTypeParamType *getType(unsigned depth, unsigned index,
72917293
const ASTContext &ctx);
72927294

7295+
/// Retrieve a canonical generic type parameter at the given depth and index
7296+
/// for an opaque result type, so with weight 1.
7297+
static GenericTypeParamType *getOpaqueResultType(unsigned depth, unsigned index,
7298+
const ASTContext &ctx);
7299+
72937300
/// Retrieve a canonical generic parameter pack at the given depth and index.
72947301
static GenericTypeParamType *getPack(unsigned depth, unsigned index,
72957302
const ASTContext &ctx);
@@ -7345,6 +7352,14 @@ class GenericTypeParamType : public SubstitutableType,
73457352
return Index;
73467353
}
73477354

7355+
/// The weight of this generic parameter in the type parameter order.
7356+
///
7357+
/// Opaque result types have weight 1, while all other generic parameters
7358+
/// have weight 0.
7359+
unsigned getWeight() const {
7360+
return Weight;
7361+
}
7362+
73487363
/// Returns \c true if this type parameter is declared as a pack.
73497364
///
73507365
/// \code
@@ -7366,20 +7381,24 @@ class GenericTypeParamType : public SubstitutableType,
73667381

73677382
Type getValueType() const;
73687383

7384+
GenericTypeParamType *withDepth(unsigned depth) const;
7385+
73697386
void Profile(llvm::FoldingSetNodeID &ID) {
73707387
// Note: We explicitly don't use 'getName()' because for canonical forms
73717388
// which don't store an identifier we'll go create a tau based form. We
73727389
// really want to just plumb down the null Identifier because that's what's
73737390
// inside the cache.
7374-
Profile(ID, getParamKind(), getDepth(), getIndex(), getValueType(),
7375-
Name);
7391+
Profile(ID, getParamKind(), getDepth(), getIndex(), getWeight(),
7392+
getValueType(), Name);
73767393
}
73777394
static void Profile(llvm::FoldingSetNodeID &ID,
73787395
GenericTypeParamKind paramKind, unsigned depth,
7379-
unsigned index, Type valueType, Identifier name) {
7396+
unsigned index, unsigned weight, Type valueType,
7397+
Identifier name) {
73807398
ID.AddInteger((uint8_t)paramKind);
73817399
ID.AddInteger(depth);
73827400
ID.AddInteger(index);
7401+
ID.AddInteger(weight);
73837402
ID.AddPointer(valueType.getPointer());
73847403
ID.AddPointer(name.get());
73857404
}
@@ -7402,7 +7421,7 @@ class GenericTypeParamType : public SubstitutableType,
74027421
const ASTContext &ctx);
74037422

74047423
explicit GenericTypeParamType(GenericTypeParamKind paramKind, unsigned depth,
7405-
unsigned index, Type valueType,
7424+
unsigned index, unsigned weight, Type valueType,
74067425
RecursiveTypeProperties props,
74077426
const ASTContext &ctx);
74087427
};
@@ -7412,6 +7431,11 @@ static CanGenericTypeParamType getType(unsigned depth, unsigned index,
74127431
return CanGenericTypeParamType(
74137432
GenericTypeParamType::getType(depth, index, C));
74147433
}
7434+
static CanGenericTypeParamType getOpaqueResultType(unsigned depth, unsigned index,
7435+
const ASTContext &C) {
7436+
return CanGenericTypeParamType(
7437+
GenericTypeParamType::getOpaqueResultType(depth, index, C));
7438+
}
74157439
END_CAN_TYPE_WRAPPER(GenericTypeParamType, SubstitutableType)
74167440

74177441
/// A type that refers to a member type of some type that is dependent on a

lib/AST/ASTContext.cpp

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5099,8 +5099,8 @@ GenericTypeParamType *GenericTypeParamType::get(Identifier name,
50995099
Type valueType,
51005100
const ASTContext &ctx) {
51015101
llvm::FoldingSetNodeID id;
5102-
GenericTypeParamType::Profile(id, paramKind, depth, index, valueType,
5103-
name);
5102+
GenericTypeParamType::Profile(id, paramKind, depth, index, /*weight=*/0,
5103+
valueType, name);
51045104

51055105
void *insertPos;
51065106
if (auto gpTy = ctx.getImpl().GenericParamTypes.FindNodeOrInsertPos(id, insertPos))
@@ -5110,8 +5110,8 @@ GenericTypeParamType *GenericTypeParamType::get(Identifier name,
51105110
if (paramKind == GenericTypeParamKind::Pack)
51115111
props |= RecursiveTypeProperties::HasParameterPack;
51125112

5113-
auto canType = GenericTypeParamType::get(paramKind, depth, index, valueType,
5114-
ctx);
5113+
auto canType = GenericTypeParamType::get(paramKind, depth, index, /*weight=*/0,
5114+
valueType, ctx);
51155115

51165116
auto result = new (ctx, AllocationArena::Permanent)
51175117
GenericTypeParamType(name, canType, ctx);
@@ -5130,10 +5130,10 @@ GenericTypeParamType *GenericTypeParamType::get(GenericTypeParamDecl *param) {
51305130

51315131
GenericTypeParamType *GenericTypeParamType::get(GenericTypeParamKind paramKind,
51325132
unsigned depth, unsigned index,
5133-
Type valueType,
5133+
unsigned weight, Type valueType,
51345134
const ASTContext &ctx) {
51355135
llvm::FoldingSetNodeID id;
5136-
GenericTypeParamType::Profile(id, paramKind, depth, index, valueType,
5136+
GenericTypeParamType::Profile(id, paramKind, depth, index, weight, valueType,
51375137
Identifier());
51385138

51395139
void *insertPos;
@@ -5145,7 +5145,7 @@ GenericTypeParamType *GenericTypeParamType::get(GenericTypeParamKind paramKind,
51455145
props |= RecursiveTypeProperties::HasParameterPack;
51465146

51475147
auto result = new (ctx, AllocationArena::Permanent)
5148-
GenericTypeParamType(paramKind, depth, index, valueType, props, ctx);
5148+
GenericTypeParamType(paramKind, depth, index, weight, valueType, props, ctx);
51495149
ctx.getImpl().GenericParamTypes.InsertNode(result, insertPos);
51505150
return result;
51515151
}
@@ -5154,22 +5154,29 @@ GenericTypeParamType *GenericTypeParamType::getType(unsigned depth,
51545154
unsigned index,
51555155
const ASTContext &ctx) {
51565156
return GenericTypeParamType::get(GenericTypeParamKind::Type, depth, index,
5157-
/*valueType*/ Type(), ctx);
5157+
/*weight=*/0, /*valueType=*/Type(), ctx);
5158+
}
5159+
5160+
GenericTypeParamType *GenericTypeParamType::getOpaqueResultType(unsigned depth,
5161+
unsigned index,
5162+
const ASTContext &ctx) {
5163+
return GenericTypeParamType::get(GenericTypeParamKind::Type, depth, index,
5164+
/*weight=*/1, /*valueType=*/Type(), ctx);
51585165
}
51595166

51605167
GenericTypeParamType *GenericTypeParamType::getPack(unsigned depth,
51615168
unsigned index,
51625169
const ASTContext &ctx) {
51635170
return GenericTypeParamType::get(GenericTypeParamKind::Pack, depth, index,
5164-
/*valueType*/ Type(), ctx);
5171+
/*weight=*/0, /*valueType=*/Type(), ctx);
51655172
}
51665173

51675174
GenericTypeParamType *GenericTypeParamType::getValue(unsigned depth,
51685175
unsigned index,
51695176
Type valueType,
51705177
const ASTContext &ctx) {
51715178
return GenericTypeParamType::get(GenericTypeParamKind::Value, depth, index,
5172-
valueType, ctx);
5179+
/*weight=*/0, valueType, ctx);
51735180
}
51745181

51755182
ArrayRef<GenericTypeParamType *>

lib/AST/GenericSignature.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -831,8 +831,7 @@ int swift::compareAssociatedTypes(AssociatedTypeDecl *assocType1,
831831
return 0;
832832
}
833833

834-
/// Canonical ordering for type parameters.
835-
int swift::compareDependentTypes(Type type1, Type type2) {
834+
static int compareDependentTypesRec(Type type1, Type type2) {
836835
// Fast-path check for equality.
837836
if (type1->isEqual(type2)) return 0;
838837

@@ -853,7 +852,7 @@ int swift::compareDependentTypes(Type type1, Type type2) {
853852

854853
// - by base, so t_0_n.`P.T` < t_1_m.`P.T`
855854
if (int compareBases =
856-
compareDependentTypes(depMemTy1->getBase(), depMemTy2->getBase()))
855+
compareDependentTypesRec(depMemTy1->getBase(), depMemTy2->getBase()))
857856
return compareBases;
858857

859858
// - by name, so t_n_m.`P.T` < t_n_m.`P.U`
@@ -869,6 +868,17 @@ int swift::compareDependentTypes(Type type1, Type type2) {
869868
return 0;
870869
}
871870

871+
/// Canonical ordering for type parameters.
872+
int swift::compareDependentTypes(Type type1, Type type2) {
873+
auto *root1 = type1->getRootGenericParam();
874+
auto *root2 = type2->getRootGenericParam();
875+
if (root1->getWeight() != root2->getWeight()) {
876+
return root2->getWeight() ? -1 : +1;
877+
}
878+
879+
return compareDependentTypesRec(type1, type2);
880+
}
881+
872882
#pragma mark Generic signature verification
873883

874884
void GenericSignature::verify() const {

lib/AST/RequirementEnvironment.cpp

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,8 @@ RequirementEnvironment::RequirementEnvironment(
4646

4747
auto conformanceToWitnessThunkGenericParamFn = [&](GenericTypeParamType *genericParam)
4848
-> GenericTypeParamType * {
49-
return GenericTypeParamType::get(genericParam->getParamKind(),
50-
genericParam->getDepth() + (covariantSelf ? 1 : 0),
51-
genericParam->getIndex(),
52-
genericParam->getValueType(), ctx);
49+
return genericParam->withDepth(
50+
genericParam->getDepth() + (covariantSelf ? 1 : 0));
5351
};
5452

5553
// This is a substitution function from the generic parameters of the
@@ -109,9 +107,7 @@ RequirementEnvironment::RequirementEnvironment(
109107
// invalid code.
110108
if (genericParam->getDepth() != 1)
111109
return Type();
112-
Type substGenericParam = GenericTypeParamType::get(
113-
genericParam->getParamKind(), depth, genericParam->getIndex(),
114-
genericParam->getValueType(), ctx);
110+
Type substGenericParam = genericParam->withDepth(depth);
115111
if (genericParam->isParameterPack()) {
116112
substGenericParam = PackType::getSingletonPackExpansion(
117113
substGenericParam);
@@ -210,10 +206,7 @@ RequirementEnvironment::RequirementEnvironment(
210206
}
211207

212208
// Create an equivalent generic parameter at the next depth.
213-
auto substGenericParam = GenericTypeParamType::get(
214-
genericParam->getParamKind(), depth, genericParam->getIndex(),
215-
genericParam->getValueType(), ctx);
216-
209+
auto substGenericParam = genericParam->withDepth(depth);
217210
genericParamTypes.push_back(substGenericParam);
218211
}
219212

lib/AST/RequirementMachine/Term.cpp

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -122,19 +122,42 @@ bool Term::containsNameSymbols() const {
122122
return false;
123123
}
124124

125-
/// Shortlex order on symbol ranges.
125+
/// Weighted shortlex order on symbol ranges, used for implementing
126+
/// Term::compare() and MutableTerm::compare().
126127
///
127-
/// First we compare length, then perform a lexicographic comparison
128-
/// on symbols if the two ranges have the same length.
128+
/// We first compute a weight vector for both terms and compare the
129+
/// vectors lexicographically:
130+
/// - Weight of generic param symbols
131+
/// - Number of name symbols
132+
/// - Number of element symbols
129133
///
130-
/// This is used to implement Term::compare() and MutableTerm::compare()
131-
/// below.
132-
static std::optional<int> shortlexCompare(const Symbol *lhsBegin,
133-
const Symbol *lhsEnd,
134-
const Symbol *rhsBegin,
135-
const Symbol *rhsEnd,
136-
RewriteContext &ctx) {
137-
// First, compare the number of name and pack element symbols.
134+
/// If the terms have the same weight, we compare length.
135+
///
136+
/// If the terms have the same weight and length, we perform a
137+
/// lexicographic comparison on symbols.
138+
///
139+
static std::optional<int> compareImpl(const Symbol *lhsBegin,
140+
const Symbol *lhsEnd,
141+
const Symbol *rhsBegin,
142+
const Symbol *rhsEnd,
143+
RewriteContext &ctx) {
144+
ASSERT(lhsBegin != lhsEnd);
145+
ASSERT(rhsBegin != rhsEnd);
146+
147+
// First compare weights on generic parameters. The implicit
148+
// assumption here is we don't form terms with generic parameter
149+
// symbols in the middle, which is true. Otherwise, we'd need
150+
// to add up their weights like we do below for name symbols,
151+
// of course.
152+
if (lhsBegin->getKind() == Symbol::Kind::GenericParam &&
153+
rhsBegin->getKind() == Symbol::Kind::GenericParam) {
154+
unsigned lhsWeight = lhsBegin->getGenericParam()->getWeight();
155+
unsigned rhsWeight = rhsBegin->getGenericParam()->getWeight();
156+
if (lhsWeight != rhsWeight)
157+
return lhsWeight > rhsWeight ? 1 : -1;
158+
}
159+
160+
// Compare the number of name and pack element symbols.
138161
unsigned lhsNameCount = 0;
139162
unsigned lhsPackElementCount = 0;
140163
for (auto *iter = lhsBegin; iter != lhsEnd; ++iter) {
@@ -192,17 +215,17 @@ static std::optional<int> shortlexCompare(const Symbol *lhsBegin,
192215
return 0;
193216
}
194217

195-
/// Shortlex order on terms. Returns None if the terms are identical except
218+
/// Reduction order on terms. Returns None if the terms are identical except
196219
/// for an incomparable superclass or concrete type symbol at the end.
197220
std::optional<int> Term::compare(Term other, RewriteContext &ctx) const {
198-
return shortlexCompare(begin(), end(), other.begin(), other.end(), ctx);
221+
return compareImpl(begin(), end(), other.begin(), other.end(), ctx);
199222
}
200223

201-
/// Shortlex order on mutable terms. Returns None if the terms are identical
224+
/// Reduction order on mutable terms. Returns None if the terms are identical
202225
/// except for an incomparable superclass or concrete type symbol at the end.
203226
std::optional<int> MutableTerm::compare(const MutableTerm &other,
204227
RewriteContext &ctx) const {
205-
return shortlexCompare(begin(), end(), other.begin(), other.end(), ctx);
228+
return compareImpl(begin(), end(), other.begin(), other.end(), ctx);
206229
}
207230

208231
/// Replace the subterm in the range [from,to) of this term with \p rhs.

lib/AST/SubstitutionMap.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -427,8 +427,7 @@ OverrideSubsInfo::OverrideSubsInfo(const NominalTypeDecl *baseNominal,
427427
const NominalTypeDecl *derivedNominal,
428428
GenericSignature baseSig,
429429
const GenericParamList *derivedParams)
430-
: Ctx(baseSig->getASTContext()),
431-
BaseDepth(0),
430+
: BaseDepth(0),
432431
OrigDepth(0),
433432
DerivedParams(derivedParams) {
434433

@@ -468,10 +467,7 @@ Type QueryOverrideSubs::operator()(SubstitutableType *type) const {
468467
->getDeclaredInterfaceType();
469468
}
470469

471-
return GenericTypeParamType::get(
472-
gp->getParamKind(),
473-
gp->getDepth() + info.OrigDepth - info.BaseDepth,
474-
gp->getIndex(), gp->getValueType(), info.Ctx);
470+
return gp->withDepth(gp->getDepth() + info.OrigDepth - info.BaseDepth);
475471
}
476472
}
477473

0 commit comments

Comments
 (0)