Skip to content

Commit 49f28c8

Browse files
authored
Merge pull request #69894 from rintaro/astgen-choice-optional
[ASTGen] Several improvements to generalize node handling
2 parents 9b10ab3 + aeb7a21 commit 49f28c8

File tree

8 files changed

+146
-187
lines changed

8 files changed

+146
-187
lines changed

lib/ASTGen/Sources/ASTGen/ASTGen.swift

Lines changed: 62 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ enum ASTNode {
2828
case decl(BridgedDecl)
2929
case stmt(BridgedStmt)
3030
case expr(BridgedExpr)
31-
case type(BridgedTypeRepr)
3231

3332
var castToExpr: BridgedExpr {
3433
guard case .expr(let bridged) = self else {
@@ -51,13 +50,6 @@ enum ASTNode {
5150
return bridged
5251
}
5352

54-
var castToType: BridgedTypeRepr {
55-
guard case .type(let bridged) = self else {
56-
fatalError("Expected a type")
57-
}
58-
return bridged
59-
}
60-
6153
var bridged: BridgedASTNode {
6254
switch self {
6355
case .expr(let e):
@@ -66,21 +58,6 @@ enum ASTNode {
6658
return BridgedASTNode(raw: s.raw, kind: .stmt)
6759
case .decl(let d):
6860
return BridgedASTNode(raw: d.raw, kind: .decl)
69-
default:
70-
fatalError("Must be expr, stmt, or decl.")
71-
}
72-
}
73-
74-
var raw: UnsafeMutableRawPointer {
75-
switch self {
76-
case .expr(let e):
77-
return e.raw
78-
case .stmt(let s):
79-
return s.raw
80-
case .decl(let d):
81-
return d.raw
82-
case .type(let t):
83-
return t.raw
8461
}
8562
}
8663
}
@@ -97,8 +74,6 @@ class Boxed<Value> {
9774
}
9875

9976
struct ASTGenVisitor {
100-
typealias ResultType = ASTNode
101-
10277
fileprivate let diagnosticEngine: BridgedDiagnosticEngine
10378

10479
let base: UnsafeBufferPointer<UInt8>
@@ -187,47 +162,9 @@ extension ASTGenVisitor {
187162
}
188163
}
189164

190-
extension ASTGenVisitor {
191-
/// Generate ASTNode from a Syntax node. The node must be a decl, stmt, expr, or
192-
/// type.
193-
func generate(_ node: Syntax) -> ASTNode {
194-
if let decl = node.as(DeclSyntax.self) {
195-
return .decl(self.generate(decl: decl))
196-
}
197-
if let stmt = node.as(StmtSyntax.self) {
198-
return .stmt(self.generate(stmt: stmt))
199-
}
200-
if let expr = node.as(ExprSyntax.self) {
201-
return .expr(self.generate(expr: expr))
202-
}
203-
if let type = node.as(TypeSyntax.self) {
204-
return .type(self.generate(type: type))
205-
}
206-
207-
// --- Special cases where `node` doesn't belong to one of the base kinds.
208-
209-
// CodeBlockSyntax -> BraceStmt.
210-
if let node = node.as(CodeBlockSyntax.self) {
211-
return .stmt(self.generate(codeBlock: node).asStmt)
212-
}
213-
// CodeBlockItemSyntax -> ASTNode.
214-
if let node = node.as(CodeBlockItemSyntax.self) {
215-
return self.generate(codeBlockItem: node)
216-
}
217-
218-
fatalError("node does not correspond to an ASTNode \(node.kind)")
219-
}
220-
}
221-
222165
// Misc visits.
223166
// TODO: Some of these are called within a single file/method; we may want to move them to the respective files.
224167
extension ASTGenVisitor {
225-
226-
/// Do NOT introduce another usage of this. Not all choices can produce 'ASTNode'.
227-
func generate(choices node: some SyntaxChildChoices) -> ASTNode {
228-
return self.generate(Syntax(node))
229-
}
230-
231168
public func generate(memberBlockItem node: MemberBlockItemSyntax) -> BridgedDecl {
232169
generate(decl: node.decl)
233170
}
@@ -237,11 +174,29 @@ extension ASTGenVisitor {
237174
}
238175

239176
public func generate(conditionElement node: ConditionElementSyntax) -> ASTNode {
240-
generate(choices: node.condition)
177+
// FIXME: returning ASTNode is wrong, non-expression conditions are not ASTNode.
178+
switch node.condition {
179+
case .availability(_):
180+
break
181+
case .expression(let node):
182+
return .expr(self.generate(expr: node))
183+
case .matchingPattern(_):
184+
break
185+
case .optionalBinding(_):
186+
break
187+
}
188+
fatalError("unimplemented")
241189
}
242190

243191
public func generate(codeBlockItem node: CodeBlockItemSyntax) -> ASTNode {
244-
generate(choices: node.item)
192+
switch node.item {
193+
case .decl(let node):
194+
return .decl(self.generate(decl: node))
195+
case .stmt(let node):
196+
return .stmt(self.generate(stmt: node))
197+
case .expr(let node):
198+
return .expr(self.generate(expr: node))
199+
}
245200
}
246201

247202
public func generate(arrayElement node: ArrayElementSyntax) -> BridgedExpr {
@@ -255,79 +210,59 @@ extension ASTGenVisitor {
255210
}
256211

257212
// Forwarding overloads that take optional syntax nodes. These are defined on demand to achieve a consistent
258-
// 'self.visit(<expr>)' recursion pattern between optional and non-optional inputs.
213+
// 'self.generate(foo: FooSyntax)' recursion pattern between optional and non-optional inputs.
259214
extension ASTGenVisitor {
260215
@inline(__always)
261-
func generate(optional node: TypeSyntax?) -> BridgedTypeRepr? {
262-
guard let node else {
263-
return nil
264-
}
265-
266-
return self.generate(type: node)
216+
func generate(type node: TypeSyntax?) -> BridgedNullableTypeRepr {
217+
self.map(node, generate(type:))
267218
}
268219

269220
@inline(__always)
270-
func generate(optional node: ExprSyntax?) -> BridgedExpr? {
271-
guard let node else {
272-
return nil
273-
}
274-
275-
return self.generate(expr: node)
221+
func generate(expr node: ExprSyntax?) -> BridgedNullableExpr {
222+
self.map(node, generate(expr:))
276223
}
277224

278-
/// DO NOT introduce another usage of this. Not all choices can produce 'ASTNode'.
279225
@inline(__always)
280-
func generate(optional node: (some SyntaxChildChoices)?) -> ASTNode? {
281-
guard let node else {
282-
return nil
283-
}
284-
285-
return self.generate(choices: node)
226+
func generate(genericParameterClause node: GenericParameterClauseSyntax?) -> BridgedNullableGenericParamList {
227+
self.map(node, generate(genericParameterClause:))
286228
}
287229

288230
@inline(__always)
289-
func generate(optional node: GenericParameterClauseSyntax?) -> BridgedGenericParamList? {
290-
guard let node else {
291-
return nil
292-
}
293-
294-
return self.generate(genericParameterClause: node)
231+
func generate(genericWhereClause node: GenericWhereClauseSyntax?) -> BridgedNullableTrailingWhereClause {
232+
self.map(node, generate(genericWhereClause:))
295233
}
296234

297235
@inline(__always)
298-
func generate(optional node: GenericWhereClauseSyntax?) -> BridgedTrailingWhereClause? {
299-
guard let node else {
300-
return nil
301-
}
302-
303-
return self.generate(genericWhereClause: node)
236+
func generate(enumCaseParameterClause node: EnumCaseParameterClauseSyntax?) -> BridgedNullableParameterList {
237+
self.map(node, generate(enumCaseParameterClause:))
304238
}
305239

306240
@inline(__always)
307-
func generate(optional node: EnumCaseParameterClauseSyntax?) -> BridgedParameterList? {
308-
guard let node else {
309-
return nil
310-
}
311-
312-
return self.generate(enumCaseParameterClause: node)
241+
func generate(inheritedTypeList node: InheritedTypeListSyntax?) -> BridgedArrayRef {
242+
self.map(node, generate(inheritedTypeList:))
313243
}
314244

315245
@inline(__always)
316-
func generate(optional node: InheritedTypeListSyntax?) -> BridgedArrayRef {
317-
guard let node else {
318-
return .init()
319-
}
320-
321-
return self.generate(inheritedTypeList: node)
246+
func generate(precedenceGroupNameList node: PrecedenceGroupNameListSyntax?) -> BridgedArrayRef {
247+
self.map(node, generate(precedenceGroupNameList:))
322248
}
323249

250+
// Helper function for `generate(foo: FooSyntax?)` methods.
324251
@inline(__always)
325-
func generate(optional node: PrecedenceGroupNameListSyntax?) -> BridgedArrayRef {
326-
guard let node else {
327-
return .init()
328-
}
252+
private func map<Node: SyntaxProtocol, Result: HasNullable>(
253+
_ node: Node?,
254+
_ body: (Node) -> Result
255+
) -> Result.Nullable {
256+
return Result.asNullable(node.map(body))
257+
}
329258

330-
return self.generate(precedenceGroupNameList: node)
259+
// Helper function for `generate(barList: BarListSyntax?)` methods for collection nodes.
260+
@inline(__always)
261+
private func map<Node: SyntaxCollection>(
262+
_ node: Node?,
263+
_ body: (Node) -> BridgedArrayRef
264+
) -> BridgedArrayRef {
265+
return node.map(body) ?? .init()
331266
}
332267
}
333268

@@ -422,16 +357,16 @@ public func buildTopLevelASTNodes(
422357

423358
/// Generate an AST node at the given source location. Returns the generated
424359
/// ASTNode and mutate the pointee of `endLocPtr` to the end of the node.
425-
private func _build<Node: SyntaxProtocol>(
426-
kind: Node.Type,
360+
private func _build<Node: SyntaxProtocol, Result>(
361+
generator: (ASTGenVisitor) -> (Node) -> Result,
427362
diagEngine: BridgedDiagnosticEngine,
428363
sourceFilePtr: UnsafeRawPointer,
429364
sourceLoc: BridgedSourceLoc,
430365
declContext: BridgedDeclContext,
431366
astContext: BridgedASTContext,
432367
legacyParser: BridgedLegacyParser,
433368
endLocPtr: UnsafeMutablePointer<BridgedSourceLoc>
434-
) -> UnsafeMutableRawPointer? {
369+
) -> Result? {
435370
let sourceFile = sourceFilePtr.assumingMemoryBound(to: ExportedSourceFile.self)
436371

437372
// Find the type syntax node.
@@ -452,13 +387,13 @@ private func _build<Node: SyntaxProtocol>(
452387
endLocPtr.pointee = sourceLoc.advanced(by: node.totalLength.utf8Length)
453388

454389
// Convert the syntax node.
455-
return ASTGenVisitor(
390+
return generator(ASTGenVisitor(
456391
diagnosticEngine: diagEngine,
457392
sourceBuffer: sourceFile.pointee.buffer,
458393
declContext: declContext,
459394
astContext: astContext,
460395
legacyParser: legacyParser
461-
).generate(Syntax(node)).raw
396+
))(node)
462397
}
463398

464399
@_cdecl("swift_ASTGen_buildTypeRepr")
@@ -473,15 +408,15 @@ func buildTypeRepr(
473408
endLocPtr: UnsafeMutablePointer<BridgedSourceLoc>
474409
) -> UnsafeMutableRawPointer? {
475410
return _build(
476-
kind: TypeSyntax.self,
411+
generator: ASTGenVisitor.generate(type:),
477412
diagEngine: diagEngine,
478413
sourceFilePtr: sourceFilePtr,
479414
sourceLoc: sourceLoc,
480415
declContext: declContext,
481416
astContext: astContext,
482417
legacyParser: legacyParser,
483418
endLocPtr: endLocPtr
484-
)
419+
)?.raw
485420
}
486421

487422
@_cdecl("swift_ASTGen_buildDecl")
@@ -496,15 +431,15 @@ func buildDecl(
496431
endLocPtr: UnsafeMutablePointer<BridgedSourceLoc>
497432
) -> UnsafeMutableRawPointer? {
498433
return _build(
499-
kind: DeclSyntax.self,
434+
generator: ASTGenVisitor.generate(decl:),
500435
diagEngine: diagEngine,
501436
sourceFilePtr: sourceFilePtr,
502437
sourceLoc: sourceLoc,
503438
declContext: declContext,
504439
astContext: astContext,
505440
legacyParser: legacyParser,
506441
endLocPtr: endLocPtr
507-
)
442+
)?.raw
508443
}
509444

510445
@_cdecl("swift_ASTGen_buildExpr")
@@ -519,15 +454,15 @@ func buildExpr(
519454
endLocPtr: UnsafeMutablePointer<BridgedSourceLoc>
520455
) -> UnsafeMutableRawPointer? {
521456
return _build(
522-
kind: ExprSyntax.self,
457+
generator: ASTGenVisitor.generate(expr:),
523458
diagEngine: diagEngine,
524459
sourceFilePtr: sourceFilePtr,
525460
sourceLoc: sourceLoc,
526461
declContext: declContext,
527462
astContext: astContext,
528463
legacyParser: legacyParser,
529464
endLocPtr: endLocPtr
530-
)
465+
)?.raw
531466
}
532467

533468
@_cdecl("swift_ASTGen_buildStmt")
@@ -542,13 +477,13 @@ func buildStmt(
542477
endLocPtr: UnsafeMutablePointer<BridgedSourceLoc>
543478
) -> UnsafeMutableRawPointer? {
544479
return _build(
545-
kind: StmtSyntax.self,
480+
generator: ASTGenVisitor.generate(stmt:),
546481
diagEngine: diagEngine,
547482
sourceFilePtr: sourceFilePtr,
548483
sourceLoc: sourceLoc,
549484
declContext: declContext,
550485
astContext: astContext,
551486
legacyParser: legacyParser,
552487
endLocPtr: endLocPtr
553-
)
488+
)?.raw
554489
}

0 commit comments

Comments
 (0)