Skip to content

Commit ddf67d4

Browse files
authored
Add support for module selectors in types during macro expansion. (#1649)
This PR modifies our code that detects types so that it correctly handles the module selector if present. For example, we currently misfire on a pattern like this: ```swift try #require(throws: Swift::Never.self) { ... } ``` We would expect to diagnose with a warning of the form: > ⚠️ Passing 'Never.self' to 'require(\_:\_:)' is redundant; invoke non-throwing test code directly instead But we don't because we don't recognize that the listed error type is equivalent to `Never`. This PR fixes that, as well as similar problems with `Tag`, `Tag.List`, `ParallelizationTrait`, and the use of `as` in an expectation expression. Migrating macro expansion code of the form `Testing.T` to `Testing::T` is a future direction. ### Checklist: - [x] Code and documentation should follow the style of the [Style Guide](https://github.com/apple/swift-testing/blob/main/Documentation/StyleGuide.md). - [x] If public symbols are renamed or modified, DocC references should be updated.
1 parent 056b7fd commit ddf67d4

11 files changed

Lines changed: 103 additions & 13 deletions

Sources/TestingMacros/ConditionMacro.swift

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -397,11 +397,11 @@ public struct RequireThrowsMacro: RefinedConditionMacro {
397397
let arguments = argumentList(of: macro, in: context)
398398
let errorExpr = arguments.first { $0.label?.tokenKind == .identifier("throws") }?.expression
399399

400-
if let errorExpr {
401-
let argumentTokens: [String] = errorExpr.tokens(viewMode: .fixedUp).lazy
402-
.filter { $0.tokenKind != .period }
403-
.map(\.textWithoutBackticks)
404-
if argumentTokens == ["Swift", "Never", "self"] || argumentTokens == ["Never", "self"] {
400+
if let errorExpr = errorExpr?.as(MemberAccessExprSyntax.self),
401+
errorExpr.declName.argumentNames == nil,
402+
errorExpr.declName.baseName.tokenKind == .keyword(.self) {
403+
let errorType = "\(errorExpr.base)" as TypeSyntax
404+
if errorType.isNamed("Never", inModuleNamed: "Swift") {
405405
context.diagnose(.requireThrowsNeverIsRedundant(errorExpr, in: macro))
406406
}
407407
}

Sources/TestingMacros/Support/Additions/TypeSyntaxProtocolAdditions.swift

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,40 @@ private let _knownGenericTypeKinds: [SyntaxKind] = [
1717
]
1818

1919
extension TypeSyntaxProtocol {
20+
/// This type's module selector, if any.
21+
var moduleSelector: ModuleSelectorSyntax? {
22+
if let type = self.as(IdentifierTypeSyntax.self) {
23+
return type.moduleSelector
24+
} else if let type = self.as(MemberTypeSyntax.self) {
25+
return type.moduleSelector ?? type.baseType.moduleSelector
26+
}
27+
return nil
28+
}
29+
30+
/// Copy this instance and remove its module selector if present.
31+
///
32+
/// - Returns: A copy of this instance with a `nil` module selector. If this
33+
/// instance does not specify a module selector, returns `self` verbatim.
34+
func removingModuleSelector() -> some TypeSyntaxProtocol {
35+
var result = TypeSyntax(self)
36+
37+
if var type = self.as(IdentifierTypeSyntax.self) {
38+
if type.moduleSelector != nil {
39+
type.moduleSelector = nil
40+
result = TypeSyntax(type)
41+
}
42+
} else if var type = self.as(MemberTypeSyntax.self) {
43+
if type.moduleSelector != nil {
44+
type.moduleSelector = nil
45+
} else if type.baseType.moduleSelector != nil {
46+
type.baseType = TypeSyntax(type.baseType.removingModuleSelector())
47+
}
48+
result = TypeSyntax(type)
49+
}
50+
51+
return result
52+
}
53+
2054
/// Whether or not this type is an optional type (`T?`, `Optional<T>`, etc.)
2155
var isOptional: Bool {
2256
if `is`(OptionalTypeSyntax.self) {
@@ -91,6 +125,19 @@ extension TypeSyntaxProtocol {
91125
///
92126
/// - Returns: Whether or not this type has the given name.
93127
func isNamed(_ name: String, inModuleNamed moduleName: String) -> Bool {
128+
// NOTE: the syntax M::M.T is ambiguous without type checking. We don't know
129+
// from syntax alone if the second M is the module name (repeated) or if it
130+
// is a type in module M with the same name. For example, XCTest::XCTest.T.
131+
// Because it's ambiguous, we don't clear the moduleName argument after we
132+
// strip the module selector and before we recursively call isNamed().
133+
if let moduleSelector {
134+
guard moduleName == moduleSelector.moduleName.textWithoutBackticks else {
135+
return false
136+
}
137+
let selfCopy = self.removingModuleSelector()
138+
return selfCopy.isNamed(name, inModuleNamed: moduleName)
139+
}
140+
94141
// Form a string of the fixed-up tokens representing the type name,
95142
// omitting any generic type parameters.
96143
let nameWithoutGenericParameters = tokens(viewMode: .fixedUp)

Sources/TestingMacros/Support/DiagnosticMessage+Diagnosing.swift

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,17 @@ extension AttributeInfo {
2626
let calledExpr = functionCallExpr.calledExpression.as(MemberAccessExprSyntax.self) {
2727
// Check for .tags() traits.
2828
switch calledExpr.tokens(viewMode: .fixedUp).map(\.textWithoutBackticks).joined() {
29-
case ".tags", "Tag.List.tags", "Testing.Tag.List.tags":
29+
case ".tags", "Tag.List.tags", "Testing.Tag.List.tags", "Testing::Tag.List.tags", "Testing::Testing.Tag.List.tags":
3030
_diagnoseIssuesWithTagsTrait(functionCallExpr, addedTo: self, in: context)
31-
case ".bug", "Bug.bug", "Testing.Bug.bug":
31+
case ".bug", "Bug.bug", "Testing.Bug.bug", "Testing::Bug.bug", "Testing::Testing.Bug.bug":
3232
_diagnoseIssuesWithBugTrait(functionCallExpr, addedTo: self, in: context)
3333
default:
3434
// This is not a trait we can parse.
3535
break
3636
}
3737
} else if let memberAccessExpr = traitExpr.as(MemberAccessExprSyntax.self) {
3838
switch memberAccessExpr.tokens(viewMode: .fixedUp).map(\.textWithoutBackticks).joined() {
39-
case ".serialized", "ParallelizationTrait.serialized", "Testing.ParallelizationTrait.serialized":
39+
case ".serialized", "ParallelizationTrait.serialized", "Testing.ParallelizationTrait.serialized", "Testing::ParallelizationTrait.serialized", "Testing::Testing.ParallelizationTrait.serialized":
4040
_diagnoseIssuesWithParallelizationTrait(memberAccessExpr, addedTo: self, in: context)
4141
default:
4242
// This is not a trait we can parse.
@@ -61,7 +61,7 @@ private func _diagnoseIssuesWithTagsTrait(_ traitExpr: FunctionCallExprSyntax, a
6161
// String literals are supported tags.
6262
} else if let tagExpr = tagExpr.as(MemberAccessExprSyntax.self) {
6363
let joinedTokens = tagExpr.tokens(viewMode: .fixedUp).map(\.textWithoutBackticks).joined()
64-
if joinedTokens.hasPrefix(".") || joinedTokens.hasPrefix("Tag.") || joinedTokens.hasPrefix("Testing.Tag.") {
64+
if joinedTokens.hasPrefix(".") || joinedTokens.hasPrefix("Tag.") || joinedTokens.hasPrefix("Testing.Tag.") || joinedTokens.hasPrefix("Testing::Tag.") || joinedTokens.hasPrefix("Testing::Testing.Tag.") {
6565
// These prefixes are all allowed as they specify a member access
6666
// into the Tag type.
6767
} else {
@@ -338,7 +338,8 @@ func diagnoseExpansionInLibraryTarget(of macro: some FreestandingMacroExpansionS
338338
}
339339

340340
var targetName = "<unknown>"
341-
if let fileID = context.location(of: macro, at: .afterLeadingTrivia, filePathMode: .fileID)?.file.trimmedDescription,
341+
if let location = context.location(of: macro, at: .afterLeadingTrivia, filePathMode: .fileID),
342+
let fileID = location.file.as(StringLiteralExprSyntax.self)?.representedLiteralValue,
342343
let slashIndex = fileID.firstIndex(of: "/") {
343344
targetName = String(fileID[..<slashIndex])
344345
}

Sources/TestingMacros/Support/DiagnosticMessage.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -829,7 +829,7 @@ struct DiagnosticMessage: SwiftDiagnostics.DiagnosticMessage {
829829
/// - expr: The error type expression.
830830
///
831831
/// - Returns: A diagnostic message.
832-
static func requireThrowsNeverIsRedundant(_ expr: ExprSyntax, in macro: some FreestandingMacroExpansionSyntax) -> Self {
832+
static func requireThrowsNeverIsRedundant(_ expr: some ExprSyntaxProtocol, in macro: some FreestandingMacroExpansionSyntax) -> Self {
833833
// We do not provide fix-its because we cannot see the leading "try" keyword
834834
// so we can't provide a valid fix-it to remove the macro either. We can
835835
// provide a fix-it to add "as Optional", but only providing that fix-it may

Sources/TestingMacros/TagMacro.swift

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,13 @@ public struct TagMacro: PeerMacro, AccessorMacro, Sendable {
5555
let typeNameTokens: [String] = type.tokens(viewMode: .fixedUp).lazy
5656
.filter { $0.tokenKind != .period }
5757
.map(\.textWithoutBackticks)
58-
guard typeNameTokens.first == "Tag" || typeNameTokens.starts(with: ["Testing", "Tag"]) else {
58+
let validTypeNameTokens = [
59+
["Tag"],
60+
["Testing", "Tag"],
61+
["Testing", "::", "Tag"],
62+
["Testing", "::", "Testing", "Tag"],
63+
]
64+
guard validTypeNameTokens.contains(where: typeNameTokens.starts(with:)) else {
5965
context.diagnose(.attributeNotSupportedOutsideTagExtension(node, on: variableDecl))
6066
return _fallbackAccessorDecls
6167
}

Tests/TestingMacrosTests/ConditionMacroTests.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,8 @@ struct ConditionMacroTests {
346346
"#expect(x as! T!)", "#require(x as! T!)",
347347
"#expect(x as! Optional<T>)", "#require(x as! Optional<T>)",
348348
"#expect(x as! Swift.Optional<T>)", "#require(x as! Swift.Optional<T>)",
349+
"#expect(x as! Swift::Optional<T>)", "#require(x as! Swift::Optional<T>)",
350+
"#expect(x as! Swift::Swift.Optional<T>)", "#require(x as! Swift::Swift.Optional<T>)",
349351
]
350352
)
351353
func asExclamationMarkSuppressedForBoolAndOptional(input: String) throws {
@@ -424,6 +426,8 @@ struct ConditionMacroTests {
424426
@Test("#require(throws: Never.self) produces a diagnostic",
425427
arguments: [
426428
"#requireThrows(throws: Swift.Never.self)",
429+
"#requireThrows(throws: Swift::Never.self)",
430+
"#requireThrows(throws: Swift::Swift.Never.self)",
427431
"#requireThrows(throws: Never.self)",
428432
"#requireThrowsNever(throws: Never.self)",
429433
]

Tests/TestingMacrosTests/TagMacroTests.swift

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,17 @@ struct TagMacroTests {
2424
("extension Tag { @Tag static var x: Tag }", "Tag"),
2525
("extension Tag { @Tag static var x: Self }", "Tag"),
2626
("extension Testing.Tag { @Tag static var x: Testing.Tag }", "Testing.Tag"),
27+
("extension Testing::Tag { @Tag static var x: Testing::Tag }", "Testing::Tag"),
28+
("extension Testing::Testing.Tag { @Tag static var x: Testing::Testing.Tag }", "Testing::Testing.Tag"),
29+
("extension Testing::Testing.Tag { @Tag static var x: Testing.Tag }", "Testing::Testing.Tag"),
30+
("extension Testing.Tag { @Tag static var x: Testing::Testing.Tag }", "Testing.Tag"),
31+
2732
("extension Tag.A.B { @Tag static var x: Tag }", "Tag.A.B"),
2833
("extension Testing.Tag.A.B { @Tag static var x: Tag }", "Testing.Tag.A.B"),
2934
("extension Tag { struct S { @Tag static var x: Tag } }", "Tag.S"),
3035
("extension Testing.Tag { enum E { @Tag static var x: Tag } }", "Testing.Tag.E"),
36+
("extension Testing::Tag { enum E { @Tag static var x: Tag } }", "Testing::Tag.E"),
37+
("extension Testing::Testing.Tag { enum E { @Tag static var x: Tag } }", "Testing::Testing.Tag.E"),
3138
]
3239
)
3340
func tagMacro(input: String, typeName: String) throws {

Tests/TestingMacrosTests/TestDeclarationMacroTests.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,13 +559,21 @@ struct TestDeclarationMacroTests {
559559
#"@Test(.tags(.f)) func f() {}"#,
560560
#"@Test(Tag.List.tags(.f)) func f() {}"#,
561561
#"@Test(Testing.Tag.List.tags(.f)) func f() {}"#,
562+
#"@Test(Testing::Tag.List.tags(.f)) func f() {}"#,
563+
#"@Test(Testing::Testing.Tag.List.tags(.f)) func f() {}"#,
562564
#"@Test(.tags("abc")) func f() {}"#,
563565
#"@Test(Tag.List.tags("abc")) func f() {}"#,
564566
#"@Test(Testing.Tag.List.tags("abc")) func f() {}"#,
567+
#"@Test(Testing::Tag.List.tags("abc")) func f() {}"#,
568+
#"@Test(Testing::Testing.Tag.List.tags("abc")) func f() {}"#,
565569
#"@Test(.tags(Tag.f)) func f() {}"#,
566570
#"@Test(.tags(Testing.Tag.f)) func f() {}"#,
571+
#"@Test(.tags(Testing::Tag.f)) func f() {}"#,
572+
#"@Test(.tags(Testing::Testing.Tag.f)) func f() {}"#,
567573
#"@Test(.tags(.Foo.Bar.f)) func f() {}"#,
568574
#"@Test(.tags(Testing.Tag.Foo.Bar.f)) func f() {}"#,
575+
#"@Test(.tags(Testing::Tag.Foo.Bar.f)) func f() {}"#,
576+
#"@Test(.tags(Testing::Testing.Tag.Foo.Bar.f)) func f() {}"#,
569577
]
570578
)
571579
func validTagExpressions(input: String) throws {

Tests/TestingTests/IssueTests.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,10 +539,14 @@ final class IssueTests: XCTestCase {
539539
throw MyParameterizedError(index: randomNumber)
540540
}
541541
#expect(throws: Never.self) {}
542+
#expect(throws: Swift::Never.self) {}
543+
#expect(throws: Swift::Swift.Never.self) {}
542544
func genericExpectThrows(_ type: (some Error).Type) {
543545
#expect(throws: type) {}
544546
}
545547
genericExpectThrows(Never.self)
548+
genericExpectThrows(Swift::Never.self)
549+
genericExpectThrows(Swift::Swift.Never.self)
546550
func nonVoidReturning() throws -> Int { throw MyError() }
547551
#expect(throws: MyError.self) {
548552
try nonVoidReturning()

Tests/TestingTests/MiscellaneousTests.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,15 @@ private import Foundation
2828

2929
@Sendable func freeSyncFunctionParameterized2(_ i: Int, _ j: String) {}
3030

31+
struct SuiteTypeWithModuleSelector {}
32+
33+
extension TestingTests::SuiteTypeWithModuleSelector {
34+
@Test(.hidden) func withModuleSelector() {}
35+
@Suite(.hidden) struct NestedType {
36+
@Test(.hidden) func nestedFunction() {}
37+
}
38+
}
39+
3140
// This type ensures the parser can correctly infer that f() is a member
3241
// function even though @Test is preceded by another attribute or is embedded in
3342
// a #if statement.

0 commit comments

Comments
 (0)