From c2cf18abcce14ee5559ad5755ddb9e4ae3b42f3c Mon Sep 17 00:00:00 2001 From: Ahmad Ayman Mansour Date: Wed, 21 Jan 2026 22:39:24 +0200 Subject: [PATCH 1/3] Refactor ConvertToTernaryExpression --- Sources/SwiftRefactor/CMakeLists.txt | 1 + .../ConvertToTernaryExpression.swift | 485 ++++++++++++++++++ .../ConvertToTernaryExpressionTests.swift | 242 +++++++++ 3 files changed, 728 insertions(+) create mode 100644 Sources/SwiftRefactor/ConvertToTernaryExpression.swift create mode 100644 Tests/SwiftRefactorTest/ConvertToTernaryExpressionTests.swift diff --git a/Sources/SwiftRefactor/CMakeLists.txt b/Sources/SwiftRefactor/CMakeLists.txt index e478a428a79..aae774e7144 100644 --- a/Sources/SwiftRefactor/CMakeLists.txt +++ b/Sources/SwiftRefactor/CMakeLists.txt @@ -13,6 +13,7 @@ add_swift_syntax_library(SwiftRefactor ConvertComputedPropertyToStored.swift ConvertComputedPropertyToZeroParameterFunction.swift ConvertStoredPropertyToComputed.swift + ConvertToTernaryExpression.swift ConvertZeroParameterFunctionToComputedProperty.swift ExpandEditorPlaceholder.swift FormatRawStringLiteral.swift diff --git a/Sources/SwiftRefactor/ConvertToTernaryExpression.swift b/Sources/SwiftRefactor/ConvertToTernaryExpression.swift new file mode 100644 index 00000000000..f688f2931ad --- /dev/null +++ b/Sources/SwiftRefactor/ConvertToTernaryExpression.swift @@ -0,0 +1,485 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2026 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +#if swift(>=6) +public import SwiftSyntax +#else +import SwiftSyntax +#endif + +/// Converts an if-else statement that assigns to the same variable in both branches +/// into a ternary expression assignment. +/// +/// ## Before +/// +/// ```swift +/// let result: Type +/// if condition { +/// result = trueValue +/// } else { +/// result = falseValue +/// } +/// ``` +/// +/// ## After +/// +/// ```swift +/// let result: Type = condition ? trueValue : falseValue +/// ``` +/// +/// Also supports: +/// - Tuple assignments: `(a, b) = condition ? (1, "a") : (2, "b")` +/// - Assignment without declaration: `existingVar = condition ? value1 : value2` +/// +/// This refactoring is applicable when: +/// - There is an if-else statement (no else-if chains) +/// - Both branches contain a single assignment expression +/// - Both assignments target the same variable / same tuple pattern +/// - Optionally, the variable is declared immediately before the if statement +public struct ConvertToTernaryExpression: SyntaxRefactoringProvider { + + public static func refactor(syntax: CodeBlockItemListSyntax, in context: Void) throws -> CodeBlockItemListSyntax { + guard let convertible = try findConvertiblePattern(in: syntax) else { + throw RefactoringNotApplicableError( + "Cannot convert: if-else branches must each contain a single assignment to the same variable" + ) + } + + return performRefactoring(syntax: syntax, convertible: convertible) + } + + // MARK: - Models + /// ConvertibleIfElse + private struct ConvertibleIfElse { + let variableDecl: VariableDeclSyntax? + let ifExpr: IfExprSyntax + + /// LHS of the assignment (`result` or `(x, y)`). + let assignmentTargetExpr: ExprSyntax + + /// Only present when LHS is a simple identifier (e.g. `result`). + let assignmentTargetName: String? + + let condition: ExprSyntax + let trueExpr: ExprSyntax + let falseExpr: ExprSyntax + + let variableDeclIndex: Int? + let ifExprIndex: Int + let isTupleAssignment: Bool + } + + /// AssignmentInfo + private struct AssignmentInfo { + let targetExpr: ExprSyntax + let targetName: String? + let valueExpr: ExprSyntax + let isTuple: Bool + } + + // MARK: - Finding Patterns + private static func findConvertiblePattern(in codeBlock: CodeBlockItemListSyntax) throws -> ConvertibleIfElse? { + let items = Array(codeBlock) + guard !items.isEmpty else { return nil } + + /// Variable declaration followed by if statement. + if items.count >= 2 { + for i in 0..<(items.count - 1) { + if let varDecl = items[i].item.as(VariableDeclSyntax.self), + let ifExpr = extractIfExpr(from: items[i + 1]) + { + + if let convertible = try analyzePattern( + variableDecl: varDecl, + ifExpr: ifExpr, + varIndex: i, + ifIndex: i + 1 + ) { + return convertible + } + } + } + } + + /// Just if statement (variable exists elsewhere). + for (index, item) in items.enumerated() { + if let ifExpr = extractIfExpr(from: item) { + if let convertible = try analyzePattern( + variableDecl: nil, + ifExpr: ifExpr, + varIndex: nil, + ifIndex: index + ) { + return convertible + } + } + } + + return nil + } + + private static func extractIfExpr(from item: CodeBlockItemSyntax) -> IfExprSyntax? { + if let exprStmt = item.item.as(ExpressionStmtSyntax.self) { + return exprStmt.expression.as(IfExprSyntax.self) + } + return item.item.as(IfExprSyntax.self) + } + + private static func analyzePattern( + variableDecl: VariableDeclSyntax?, + ifExpr: IfExprSyntax, + varIndex: Int?, + ifIndex: Int + ) throws -> ConvertibleIfElse? { + + var expectedVariableName: String? + + if let variableDecl { + guard validateVariableDecl(variableDecl) else { + return nil + } + + guard let binding = variableDecl.bindings.first, + let identifierPattern = binding.pattern.as(IdentifierPatternSyntax.self) + else { + return nil + } + expectedVariableName = identifierPattern.identifier.text + } + + guard validateIfExpr(ifExpr) else { + return nil + } + + guard let condition = extractCondition(from: ifExpr) else { + return nil + } + + guard let elseBlock = extractElseBlock(from: ifExpr) else { + return nil + } + + guard let thenAssignment = try extractSingleAssignment(from: ifExpr.body) else { + return nil + } + + guard let elseAssignment = try extractSingleAssignment(from: elseBlock) else { + return nil + } + + guard thenAssignment.isTuple == elseAssignment.isTuple else { + return nil + } + + guard normalized(thenAssignment.targetExpr) == normalized(elseAssignment.targetExpr) else { + return nil + } + + if let expectedName = expectedVariableName { + guard let thenName = thenAssignment.targetName, thenName == expectedName else { + return nil + } + guard elseAssignment.targetName == expectedName else { + return nil + } + } + + if isExpressionTooComplexForTernary(thenAssignment.valueExpr) + || isExpressionTooComplexForTernary(elseAssignment.valueExpr) + { + return nil + } + + return ConvertibleIfElse( + variableDecl: variableDecl, + ifExpr: ifExpr, + assignmentTargetExpr: thenAssignment.targetExpr, + assignmentTargetName: thenAssignment.targetName, + condition: condition, + trueExpr: thenAssignment.valueExpr, + falseExpr: elseAssignment.valueExpr, + variableDeclIndex: varIndex, + ifExprIndex: ifIndex, + isTupleAssignment: thenAssignment.isTuple + ) + } + + // MARK: - Validation Helpers + private static func validateVariableDecl(_ decl: VariableDeclSyntax) -> Bool { + guard decl.bindings.count == 1, + let binding = decl.bindings.first, + binding.typeAnnotation?.type != nil, + binding.initializer == nil, + decl.attributes.isEmpty + else { + return false + } + + let keyword = decl.bindingSpecifier.tokenKind + return keyword == .keyword(.let) || keyword == .keyword(.var) + } + + private static func validateIfExpr(_ ifExpr: IfExprSyntax) -> Bool { + return ifExpr.conditions.count == 1 + } + + private static func extractCondition(from ifExpr: IfExprSyntax) -> ExprSyntax? { + guard let firstCondition = ifExpr.conditions.first else { + return nil + } + + guard case .expression(let condition) = firstCondition.condition else { + return nil + } + + return condition + } + + private static func extractElseBlock(from ifExpr: IfExprSyntax) -> CodeBlockSyntax? { + guard let elseBody = ifExpr.elseBody else { + return nil + } + + switch elseBody { + case .codeBlock(let block): + return block + case .ifExpr: + return nil + } + } + + private static func normalized(_ expression: ExprSyntax) -> String { + expression + .with(\.leadingTrivia, []) + .with(\.trailingTrivia, .space) + .description + } + + // MARK: - Extracting Assignments + /// Extracts the assignment from a code block. + private static func extractSingleAssignment( + from codeBlock: CodeBlockSyntax + ) throws -> AssignmentInfo? { + + guard codeBlock.statements.count == 1, + let statement = codeBlock.statements.first + else { + return nil + } + + let expr: ExprSyntax + if let exprStmt = statement.item.as(ExpressionStmtSyntax.self) { + expr = exprStmt.expression + } else if let directExpr = statement.item.as(ExprSyntax.self) { + expr = directExpr + } else { + return nil + } + + guard let sequenceExpr = expr.as(SequenceExprSyntax.self) else { + return nil + } + + return try extractFromSequenceAssignment(sequenceExpr) + } + + /// Extracts the target and value from a sequence expression containing assignment. + private static func extractFromSequenceAssignment( + _ sequenceExpr: SequenceExprSyntax + ) throws -> AssignmentInfo? { + + let elements = Array(sequenceExpr.elements) + + let expectedElementCount = 3 + let assignmentOperatorIndex = 1 + + guard elements.count == expectedElementCount, + elements[assignmentOperatorIndex].as(AssignmentExprSyntax.self) != nil + else { + return nil + } + + let lhs = ExprSyntax(elements[0]) + let rhs = ExprSyntax(elements[2]) + + if let lhsIdentifier = lhs.as(DeclReferenceExprSyntax.self) { + return AssignmentInfo( + targetExpr: lhs, + targetName: lhsIdentifier.baseName.text, + valueExpr: rhs, + isTuple: false + ) + } + + if lhs.as(TupleExprSyntax.self) != nil { + guard rhs.as(TupleExprSyntax.self) != nil else { + return nil + } + return AssignmentInfo( + targetExpr: lhs, + targetName: nil, + valueExpr: rhs, + isTuple: true + ) + } + + return nil + } + + private static func isExpressionTooComplexForTernary(_ expr: ExprSyntax) -> Bool { + if expr.as(TernaryExprSyntax.self) != nil { return true } + if expr.as(ClosureExprSyntax.self) != nil { return true } + return false + } + + // MARK: - Applying Refactoring + private static func performRefactoring( + syntax: CodeBlockItemListSyntax, + convertible: ConvertibleIfElse + ) -> CodeBlockItemListSyntax { + + var newItems: [CodeBlockItemSyntax] = [] + + for (index, item) in syntax.enumerated() { + if index == convertible.ifExprIndex { + if convertible.variableDecl == nil { + let assignmentExpr = createTernaryAssignment(from: convertible) + let assignmentStmt = ExpressionStmtSyntax(expression: assignmentExpr) + newItems.append( + withoutTrivia( + CodeBlockItemSyntax(item: .stmt(StmtSyntax(assignmentStmt))) + ) + ) + } + continue + } + + if let varDeclIndex = convertible.variableDeclIndex, index == varDeclIndex { + let newDecl = createTernaryDeclaration(from: convertible) + newItems.append( + withoutTrivia( + CodeBlockItemSyntax(item: .decl(DeclSyntax(newDecl))) + ) + ) + continue + } + + newItems.append(item) + } + + return CodeBlockItemListSyntax(newItems) + } + + // MARK: - Builders + private static func withoutTrivia(_ node: T) -> T { + node.with(\.leadingTrivia, []).with(\.trailingTrivia, []) + } + + private static func makeTernaryExpr(from convertible: ConvertibleIfElse) -> TernaryExprSyntax { + TernaryExprSyntax( + condition: convertible.condition + .with(\.leadingTrivia, []) + .with(\.trailingTrivia, .space), + questionMark: .infixQuestionMarkToken(trailingTrivia: .space), + thenExpression: convertible.trueExpr + .with(\.leadingTrivia, []) + .with(\.trailingTrivia, .space), + colon: .colonToken(trailingTrivia: .space), + elseExpression: convertible.falseExpr + .with(\.leadingTrivia, []) + .with(\.trailingTrivia, []) + ) + } + + /// Creates the new variable declaration with ternary initializer (when declaration exists). + /// Preserves the original pattern + type annotation. + private static func createTernaryDeclaration(from convertible: ConvertibleIfElse) -> VariableDeclSyntax { + guard let variableDecl = convertible.variableDecl else { + fatalError("createTernaryDeclaration called without variable declaration") + } + guard let originalBinding = variableDecl.bindings.first else { + fatalError("Invalid state: binding should exist") + } + + let ternaryExpr = makeTernaryExpr(from: convertible) + + let newBinding = + originalBinding + .with(\.typeAnnotation, nil) + .with( + \.initializer, + InitializerClauseSyntax( + equal: .equalToken(leadingTrivia: .space, trailingTrivia: .space), + value: ExprSyntax(ternaryExpr) + ) + ) + + return variableDecl.with(\.bindings, PatternBindingListSyntax([newBinding])) + } + + /// Creates a ternary assignment expression (when no declaration exists). + private static func createTernaryAssignment(from convertible: ConvertibleIfElse) -> ExprSyntax { + let ternaryExpr = makeTernaryExpr(from: convertible) + + let assignmentSeq = SequenceExprSyntax( + elements: ExprListSyntax([ + convertible.assignmentTargetExpr.with(\.trailingTrivia, .space), + ExprSyntax(AssignmentExprSyntax()).with(\.trailingTrivia, .space), + ExprSyntax(ternaryExpr), + ]) + ) + + return ExprSyntax(assignmentSeq) + } +} + +// MARK: - Alternative API for single if statement refactoring +extension ConvertToTernaryExpression { + + public static func refactor( + ifExpr: IfExprSyntax, + variableDecl: VariableDeclSyntax? = nil + ) throws -> VariableDeclSyntax? { + + guard + let convertible = try analyzePattern( + variableDecl: variableDecl, + ifExpr: ifExpr, + varIndex: variableDecl != nil ? 0 : nil, + ifIndex: variableDecl != nil ? 1 : 0 + ) + else { + throw RefactoringNotApplicableError( + "Cannot convert: if-else pattern is not suitable for ternary expression conversion" + ) + } + + if convertible.variableDecl != nil { + return createTernaryDeclaration(from: convertible) + } + + return nil + } + + public static func canRefactor( + ifExpr: IfExprSyntax, + variableDecl: VariableDeclSyntax? = nil + ) -> Bool { + return + (try? analyzePattern( + variableDecl: variableDecl, + ifExpr: ifExpr, + varIndex: variableDecl != nil ? 0 : nil, + ifIndex: variableDecl != nil ? 1 : 0 + )) != nil + } +} diff --git a/Tests/SwiftRefactorTest/ConvertToTernaryExpressionTests.swift b/Tests/SwiftRefactorTest/ConvertToTernaryExpressionTests.swift new file mode 100644 index 00000000000..37fdb2f8d12 --- /dev/null +++ b/Tests/SwiftRefactorTest/ConvertToTernaryExpressionTests.swift @@ -0,0 +1,242 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2024 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +import SwiftRefactor +import SwiftSyntax +import SwiftSyntaxBuilder +import XCTest +import _SwiftSyntaxTestSupport + +final class ConvertToTernaryExpressionTests: XCTestCase { + + // MARK: - Basic Pattern Tests + func testBasicIfElseWithLetDeclaration() throws { + let baseline: CodeBlockItemListSyntax = """ + let result: Int + if condition { + result = 10 + } else { + result = 20 + } + """ + + let expected: CodeBlockItemListSyntax = """ + let result = condition ? 10 : 20 + """ + + try assertRefactorConvert(baseline, expected: expected) + } + + func testBasicIfElseWithVarDeclaration() throws { + let baseline: CodeBlockItemListSyntax = """ + var status: String + if isValid { + status = "approved" + } else { + status = "rejected" + } + """ + + let expected: CodeBlockItemListSyntax = """ + var status = isValid ? "approved" : "rejected" + """ + + try assertRefactorConvert(baseline, expected: expected) + } + + func testStandaloneIfElseAssignment() throws { + let baseline: CodeBlockItemListSyntax = """ + if isActive { + flag = true + } else { + flag = false + } + """ + + let expected: CodeBlockItemListSyntax = """ + flag = isActive ? true : false + """ + + try assertRefactorConvert(baseline, expected: expected) + } + + func testParenthesizedCondition() throws { + let baseline: CodeBlockItemListSyntax = """ + let output: Int + if (x > 0) { + output = 1 + } else { + output = 0 + } + """ + + let expected: CodeBlockItemListSyntax = """ + let output = (x > 0) ? 1 : 0 + """ + + try assertRefactorConvert(baseline, expected: expected) + } + + // MARK: - Tuple Assignment Tests + func testSimpleTupleAssignment() throws { + let baseline: CodeBlockItemListSyntax = """ + let point: (Int, Int) + if isOrigin { + point = (0, 0) + } else { + point = (10, 20) + } + """ + + let expected: CodeBlockItemListSyntax = """ + let point = isOrigin ? (0, 0) : (10, 20) + """ + + try assertRefactorConvert(baseline, expected: expected) + } + + func testNamedTupleAssignment() throws { + let baseline: CodeBlockItemListSyntax = """ + let coordinates: (x: Int, y: Int) + if reset { + coordinates = (x: 0, y: 0) + } else { + coordinates = (x: 100, y: 200) + } + """ + + let expected: CodeBlockItemListSyntax = """ + let coordinates = reset ? (x: 0, y: 0) : (x: 100, y: 200) + """ + + try assertRefactorConvert(baseline, expected: expected) + } + + // MARK: - Complex Expression Tests + + func testFunctionCallInBranches() throws { + let baseline: CodeBlockItemListSyntax = """ + let result: String + if shouldTransform { + result = transform(input) + } else { + result = identity(input) + } + """ + + let expected: CodeBlockItemListSyntax = """ + let result = shouldTransform ? transform(input) : identity(input) + """ + + try assertRefactorConvert(baseline, expected: expected) + } + + func testDictionaryLiteralInBranches() throws { + let baseline: CodeBlockItemListSyntax = """ + let config: [String: Int] + if useDefault { + config = [:] + } else { + config = ["key": 42] + } + """ + + let expected: CodeBlockItemListSyntax = """ + let config = useDefault ? [:] : ["key": 42] + """ + + try assertRefactorConvert(baseline, expected: expected) + } + + // MARK: - Negative Tests - Should NOT Refactor + + func testRejectsElseIfChain() throws { + let baseline: CodeBlockItemListSyntax = """ + let result: Int + if condition1 { + result = 1 + } else if condition2 { + result = 2 + } else { + result = 3 + } + """ + + try assertRefactorConvert(baseline, expected: nil) + } + + func testRejectsNestedTernary() throws { + let baseline: CodeBlockItemListSyntax = """ + let result: Int + if outer { + result = inner ? 1 : 2 + } else { + result = 3 + } + """ + + try assertRefactorConvert(baseline, expected: nil) + } + + func testRejectsClosureInBranch() throws { + let baseline: CodeBlockItemListSyntax = """ + let result: () -> Void + if condition { + result = { print("hello") } + } else { + result = { print("goodbye") } + } + """ + + try assertRefactorConvert(baseline, expected: nil) + } + + func testRejectsDifferentVariablesInBranches() throws { + let baseline: CodeBlockItemListSyntax = """ + let result: Int + if condition { + result = 10 + } else { + other = 20 + } + """ + + try assertRefactorConvert(baseline, expected: nil) + } + + func testRejectsNoElseClause() throws { + let baseline: CodeBlockItemListSyntax = """ + let result: Int + if condition { + result = 10 + } + """ + + try assertRefactorConvert(baseline, expected: nil) + } +} + +private func assertRefactorConvert( + _ baseline: CodeBlockItemListSyntax, + expected: CodeBlockItemListSyntax?, + file: StaticString = #filePath, + line: UInt = #line +) throws { + try assertRefactor( + baseline, + context: (), + provider: ConvertToTernaryExpression.self, + expected: expected, + file: file, + line: line + ) +} From 2d8f14a114e3248c2135e9efd5aa5df889733745 Mon Sep 17 00:00:00 2001 From: Ahmad Ayman Mansour Date: Thu, 26 Feb 2026 23:02:22 +0200 Subject: [PATCH 2/3] Address review feedback on ConvertToTernaryExpression - Add `Collection.only` to `SyntaxUtils.swift` (mirrors the pattern in `SwiftParserDiagnostics` and `CodeGeneration`) - Preserve the type annotation in the output declaration when the declared type is a named tuple (e.g. `(x: Int, y: Int)`) - Delete `extractCondition`, `extractElseBlock` & `validateIfExpr` and inline their checks directly into `analyzePattern`. - Add inline comments to `isExpressionTooComplexForTernary` explaining why nested ternaries and closures are excluded --- .../ConvertToTernaryExpression.swift | 215 +++++++----------- Sources/SwiftRefactor/SyntaxUtils.swift | 11 + .../ConvertToTernaryExpressionTests.swift | 2 +- 3 files changed, 89 insertions(+), 139 deletions(-) diff --git a/Sources/SwiftRefactor/ConvertToTernaryExpression.swift b/Sources/SwiftRefactor/ConvertToTernaryExpression.swift index f688f2931ad..54913fceaaa 100644 --- a/Sources/SwiftRefactor/ConvertToTernaryExpression.swift +++ b/Sources/SwiftRefactor/ConvertToTernaryExpression.swift @@ -87,40 +87,33 @@ public struct ConvertToTernaryExpression: SyntaxRefactoringProvider { } // MARK: - Finding Patterns + /// Finds a convertible if-else pattern by searching for if expressions first, + /// then optionally checking for a preceding variable declaration. private static func findConvertiblePattern(in codeBlock: CodeBlockItemListSyntax) throws -> ConvertibleIfElse? { let items = Array(codeBlock) guard !items.isEmpty else { return nil } - /// Variable declaration followed by if statement. - if items.count >= 2 { - for i in 0..<(items.count - 1) { - if let varDecl = items[i].item.as(VariableDeclSyntax.self), - let ifExpr = extractIfExpr(from: items[i + 1]) - { - - if let convertible = try analyzePattern( - variableDecl: varDecl, - ifExpr: ifExpr, - varIndex: i, - ifIndex: i + 1 - ) { - return convertible - } - } + for (ifIndex, item) in items.enumerated() { + guard let ifExpr = extractIfExpr(from: item) else { continue } + + var varDecl: VariableDeclSyntax? = nil + var varIndex: Int? = nil + + if ifIndex > 0, + let previousVarDecl = items[ifIndex - 1].item.as(VariableDeclSyntax.self), + declarationMatchesIfAssignment(previousVarDecl, ifExpr: ifExpr) + { + varDecl = previousVarDecl + varIndex = ifIndex - 1 } - } - /// Just if statement (variable exists elsewhere). - for (index, item) in items.enumerated() { - if let ifExpr = extractIfExpr(from: item) { - if let convertible = try analyzePattern( - variableDecl: nil, - ifExpr: ifExpr, - varIndex: nil, - ifIndex: index - ) { - return convertible - } + if let convertible = try analyzePattern( + variableDecl: varDecl, + ifExpr: ifExpr, + varIndex: varIndex, + ifIndex: ifIndex + ) { + return convertible } } @@ -134,6 +127,30 @@ public struct ConvertToTernaryExpression: SyntaxRefactoringProvider { return item.item.as(IfExprSyntax.self) } + /// Quick check to see if a variable declaration matches the assignment in an if expression + private static func declarationMatchesIfAssignment( + _ varDecl: VariableDeclSyntax, + ifExpr: IfExprSyntax + ) -> Bool { + guard validateVariableDecl(varDecl), + varDecl.bindings.count == 1, + let binding = varDecl.bindings.first, + let identifierPattern = binding.pattern.as(IdentifierPatternSyntax.self) + else { + return false + } + + let varName = identifierPattern.identifier.text + + guard let thenAssignment = try? extractSingleAssignment(from: ifExpr.body), + let assignedName = thenAssignment.targetName + else { + return false + } + + return varName == assignedName + } + private static func analyzePattern( variableDecl: VariableDeclSyntax?, ifExpr: IfExprSyntax, @@ -148,7 +165,8 @@ public struct ConvertToTernaryExpression: SyntaxRefactoringProvider { return nil } - guard let binding = variableDecl.bindings.first, + guard variableDecl.bindings.count == 1, + let binding = variableDecl.bindings.first, let identifierPattern = binding.pattern.as(IdentifierPatternSyntax.self) else { return nil @@ -156,15 +174,15 @@ public struct ConvertToTernaryExpression: SyntaxRefactoringProvider { expectedVariableName = identifierPattern.identifier.text } - guard validateIfExpr(ifExpr) else { - return nil - } - - guard let condition = extractCondition(from: ifExpr) else { + guard let firstCondition = ifExpr.conditions.only, + case .expression(let condition) = firstCondition.condition + else { return nil } - guard let elseBlock = extractElseBlock(from: ifExpr) else { + guard let elseBody = ifExpr.elseBody, + case .codeBlock(let elseBlock) = elseBody + else { return nil } @@ -214,6 +232,11 @@ public struct ConvertToTernaryExpression: SyntaxRefactoringProvider { } // MARK: - Validation Helpers + private static func isNamedTupleType(_ type: TypeSyntax) -> Bool { + guard let tupleType = type.as(TupleTypeSyntax.self) else { return false } + return tupleType.elements.contains { $0.firstName != nil } + } + private static func validateVariableDecl(_ decl: VariableDeclSyntax) -> Bool { guard decl.bindings.count == 1, let binding = decl.bindings.first, @@ -228,40 +251,8 @@ public struct ConvertToTernaryExpression: SyntaxRefactoringProvider { return keyword == .keyword(.let) || keyword == .keyword(.var) } - private static func validateIfExpr(_ ifExpr: IfExprSyntax) -> Bool { - return ifExpr.conditions.count == 1 - } - - private static func extractCondition(from ifExpr: IfExprSyntax) -> ExprSyntax? { - guard let firstCondition = ifExpr.conditions.first else { - return nil - } - - guard case .expression(let condition) = firstCondition.condition else { - return nil - } - - return condition - } - - private static func extractElseBlock(from ifExpr: IfExprSyntax) -> CodeBlockSyntax? { - guard let elseBody = ifExpr.elseBody else { - return nil - } - - switch elseBody { - case .codeBlock(let block): - return block - case .ifExpr: - return nil - } - } - private static func normalized(_ expression: ExprSyntax) -> String { - expression - .with(\.leadingTrivia, []) - .with(\.trailingTrivia, .space) - .description + expression.trimmed.description } // MARK: - Extracting Assignments @@ -270,9 +261,7 @@ public struct ConvertToTernaryExpression: SyntaxRefactoringProvider { from codeBlock: CodeBlockSyntax ) throws -> AssignmentInfo? { - guard codeBlock.statements.count == 1, - let statement = codeBlock.statements.first - else { + guard let statement = codeBlock.statements.only else { return nil } @@ -299,11 +288,8 @@ public struct ConvertToTernaryExpression: SyntaxRefactoringProvider { let elements = Array(sequenceExpr.elements) - let expectedElementCount = 3 - let assignmentOperatorIndex = 1 - - guard elements.count == expectedElementCount, - elements[assignmentOperatorIndex].as(AssignmentExprSyntax.self) != nil + guard elements.count == 3, + elements[1].as(AssignmentExprSyntax.self) != nil else { return nil } @@ -336,8 +322,14 @@ public struct ConvertToTernaryExpression: SyntaxRefactoringProvider { } private static func isExpressionTooComplexForTernary(_ expr: ExprSyntax) -> Bool { + // Nested ternaries reduce readability: x = a ? b : (c ? d : e) if expr.as(TernaryExprSyntax.self) != nil { return true } + + // Closures in ternaries are harder to read than if-else blocks. + // Example: action = condition ? { [weak self] in self?.log() } : { print("default") } + // is less clear than an if-else with proper formatting and line breaks. if expr.as(ClosureExprSyntax.self) != nil { return true } + return false } @@ -355,9 +347,7 @@ public struct ConvertToTernaryExpression: SyntaxRefactoringProvider { let assignmentExpr = createTernaryAssignment(from: convertible) let assignmentStmt = ExpressionStmtSyntax(expression: assignmentExpr) newItems.append( - withoutTrivia( - CodeBlockItemSyntax(item: .stmt(StmtSyntax(assignmentStmt))) - ) + CodeBlockItemSyntax(item: .stmt(StmtSyntax(assignmentStmt))).trimmed ) } continue @@ -366,9 +356,7 @@ public struct ConvertToTernaryExpression: SyntaxRefactoringProvider { if let varDeclIndex = convertible.variableDeclIndex, index == varDeclIndex { let newDecl = createTernaryDeclaration(from: convertible) newItems.append( - withoutTrivia( - CodeBlockItemSyntax(item: .decl(DeclSyntax(newDecl))) - ) + CodeBlockItemSyntax(item: .decl(DeclSyntax(newDecl))).trimmed ) continue } @@ -380,23 +368,13 @@ public struct ConvertToTernaryExpression: SyntaxRefactoringProvider { } // MARK: - Builders - private static func withoutTrivia(_ node: T) -> T { - node.with(\.leadingTrivia, []).with(\.trailingTrivia, []) - } - private static func makeTernaryExpr(from convertible: ConvertibleIfElse) -> TernaryExprSyntax { TernaryExprSyntax( - condition: convertible.condition - .with(\.leadingTrivia, []) - .with(\.trailingTrivia, .space), + condition: convertible.condition.trimmed.with(\.trailingTrivia, .space), questionMark: .infixQuestionMarkToken(trailingTrivia: .space), - thenExpression: convertible.trueExpr - .with(\.leadingTrivia, []) - .with(\.trailingTrivia, .space), + thenExpression: convertible.trueExpr.trimmed.with(\.trailingTrivia, .space), colon: .colonToken(trailingTrivia: .space), - elseExpression: convertible.falseExpr - .with(\.leadingTrivia, []) - .with(\.trailingTrivia, []) + elseExpression: convertible.falseExpr.trimmed ) } @@ -412,9 +390,12 @@ public struct ConvertToTernaryExpression: SyntaxRefactoringProvider { let ternaryExpr = makeTernaryExpr(from: convertible) + let preserveAnnotation = originalBinding.typeAnnotation.map { isNamedTupleType($0.type) } ?? false + let typeAnnotation: TypeAnnotationSyntax? = preserveAnnotation ? originalBinding.typeAnnotation?.trimmed : nil + let newBinding = originalBinding - .with(\.typeAnnotation, nil) + .with(\.typeAnnotation, typeAnnotation) .with( \.initializer, InitializerClauseSyntax( @@ -441,45 +422,3 @@ public struct ConvertToTernaryExpression: SyntaxRefactoringProvider { return ExprSyntax(assignmentSeq) } } - -// MARK: - Alternative API for single if statement refactoring -extension ConvertToTernaryExpression { - - public static func refactor( - ifExpr: IfExprSyntax, - variableDecl: VariableDeclSyntax? = nil - ) throws -> VariableDeclSyntax? { - - guard - let convertible = try analyzePattern( - variableDecl: variableDecl, - ifExpr: ifExpr, - varIndex: variableDecl != nil ? 0 : nil, - ifIndex: variableDecl != nil ? 1 : 0 - ) - else { - throw RefactoringNotApplicableError( - "Cannot convert: if-else pattern is not suitable for ternary expression conversion" - ) - } - - if convertible.variableDecl != nil { - return createTernaryDeclaration(from: convertible) - } - - return nil - } - - public static func canRefactor( - ifExpr: IfExprSyntax, - variableDecl: VariableDeclSyntax? = nil - ) -> Bool { - return - (try? analyzePattern( - variableDecl: variableDecl, - ifExpr: ifExpr, - varIndex: variableDecl != nil ? 0 : nil, - ifIndex: variableDecl != nil ? 1 : 0 - )) != nil - } -} diff --git a/Sources/SwiftRefactor/SyntaxUtils.swift b/Sources/SwiftRefactor/SyntaxUtils.swift index 6f27b5dc71f..16d9f312432 100644 --- a/Sources/SwiftRefactor/SyntaxUtils.swift +++ b/Sources/SwiftRefactor/SyntaxUtils.swift @@ -32,6 +32,17 @@ extension Trivia { } } +extension Collection { + /// If the collection contains a single element, return it, otherwise `nil`. + var only: Element? { + if !isEmpty && index(after: startIndex) == endIndex { + return self.first! + } else { + return nil + } + } +} + extension TypeSyntax { var isVoid: Bool { switch self.as(TypeSyntaxEnum.self) { diff --git a/Tests/SwiftRefactorTest/ConvertToTernaryExpressionTests.swift b/Tests/SwiftRefactorTest/ConvertToTernaryExpressionTests.swift index 37fdb2f8d12..7d1988edb61 100644 --- a/Tests/SwiftRefactorTest/ConvertToTernaryExpressionTests.swift +++ b/Tests/SwiftRefactorTest/ConvertToTernaryExpressionTests.swift @@ -115,7 +115,7 @@ final class ConvertToTernaryExpressionTests: XCTestCase { """ let expected: CodeBlockItemListSyntax = """ - let coordinates = reset ? (x: 0, y: 0) : (x: 100, y: 200) + let coordinates: (x: Int, y: Int) = reset ? (x: 0, y: 0) : (x: 100, y: 200) """ try assertRefactorConvert(baseline, expected: expected) From 32f0b678d5051ce8897f030c184f098e6d0f9b16 Mon Sep 17 00:00:00 2001 From: Ahmad Ayman Mansour Date: Thu, 19 Mar 2026 07:46:04 +0200 Subject: [PATCH 3/3] address reviewer comments --- Sources/SwiftRefactor/Collection+Only.swift | 22 + .../ConvertToTernaryExpression.swift | 438 +++++++----------- Sources/SwiftRefactor/SyntaxUtils.swift | 11 - .../ConvertToTernaryExpressionTests.swift | 175 +++---- 4 files changed, 275 insertions(+), 371 deletions(-) create mode 100644 Sources/SwiftRefactor/Collection+Only.swift diff --git a/Sources/SwiftRefactor/Collection+Only.swift b/Sources/SwiftRefactor/Collection+Only.swift new file mode 100644 index 00000000000..7f4681a81b4 --- /dev/null +++ b/Sources/SwiftRefactor/Collection+Only.swift @@ -0,0 +1,22 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2026 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +extension Collection { + /// If the collection contains a single element, return it, otherwise `nil`. + var only: Element? { + if !isEmpty && index(after: startIndex) == endIndex { + return self.first! + } else { + return nil + } + } +} diff --git a/Sources/SwiftRefactor/ConvertToTernaryExpression.swift b/Sources/SwiftRefactor/ConvertToTernaryExpression.swift index 54913fceaaa..151f36a9084 100644 --- a/Sources/SwiftRefactor/ConvertToTernaryExpression.swift +++ b/Sources/SwiftRefactor/ConvertToTernaryExpression.swift @@ -45,222 +45,65 @@ import SwiftSyntax /// - Both branches contain a single assignment expression /// - Both assignments target the same variable / same tuple pattern /// - Optionally, the variable is declared immediately before the if statement -public struct ConvertToTernaryExpression: SyntaxRefactoringProvider { - - public static func refactor(syntax: CodeBlockItemListSyntax, in context: Void) throws -> CodeBlockItemListSyntax { - guard let convertible = try findConvertiblePattern(in: syntax) else { +public struct ConvertToTernaryExpression: EditRefactoringProvider { + public static func textRefactor(syntax: IfExprSyntax, in context: Void) throws -> [SourceEdit] { + guard let convertible = try ConvertibleIfElse(ifExpr: syntax) else { throw RefactoringNotApplicableError( "Cannot convert: if-else branches must each contain a single assignment to the same variable" ) } - return performRefactoring(syntax: syntax, convertible: convertible) - } - - // MARK: - Models - /// ConvertibleIfElse - private struct ConvertibleIfElse { - let variableDecl: VariableDeclSyntax? - let ifExpr: IfExprSyntax - - /// LHS of the assignment (`result` or `(x, y)`). - let assignmentTargetExpr: ExprSyntax - - /// Only present when LHS is a simple identifier (e.g. `result`). - let assignmentTargetName: String? - - let condition: ExprSyntax - let trueExpr: ExprSyntax - let falseExpr: ExprSyntax - - let variableDeclIndex: Int? - let ifExprIndex: Int - let isTupleAssignment: Bool - } - - /// AssignmentInfo - private struct AssignmentInfo { - let targetExpr: ExprSyntax - let targetName: String? - let valueExpr: ExprSyntax - let isTuple: Bool - } - - // MARK: - Finding Patterns - /// Finds a convertible if-else pattern by searching for if expressions first, - /// then optionally checking for a preceding variable declaration. - private static func findConvertiblePattern(in codeBlock: CodeBlockItemListSyntax) throws -> ConvertibleIfElse? { - let items = Array(codeBlock) - guard !items.isEmpty else { return nil } - - for (ifIndex, item) in items.enumerated() { - guard let ifExpr = extractIfExpr(from: item) else { continue } - - var varDecl: VariableDeclSyntax? = nil - var varIndex: Int? = nil - - if ifIndex > 0, - let previousVarDecl = items[ifIndex - 1].item.as(VariableDeclSyntax.self), - declarationMatchesIfAssignment(previousVarDecl, ifExpr: ifExpr) - { - varDecl = previousVarDecl - varIndex = ifIndex - 1 - } - - if let convertible = try analyzePattern( - variableDecl: varDecl, - ifExpr: ifExpr, - varIndex: varIndex, - ifIndex: ifIndex - ) { - return convertible - } - } - - return nil - } - - private static func extractIfExpr(from item: CodeBlockItemSyntax) -> IfExprSyntax? { - if let exprStmt = item.item.as(ExpressionStmtSyntax.self) { - return exprStmt.expression.as(IfExprSyntax.self) - } - return item.item.as(IfExprSyntax.self) - } - - /// Quick check to see if a variable declaration matches the assignment in an if expression - private static func declarationMatchesIfAssignment( - _ varDecl: VariableDeclSyntax, - ifExpr: IfExprSyntax - ) -> Bool { - guard validateVariableDecl(varDecl), - varDecl.bindings.count == 1, - let binding = varDecl.bindings.first, - let identifierPattern = binding.pattern.as(IdentifierPatternSyntax.self) - else { - return false - } - - let varName = identifierPattern.identifier.text - - guard let thenAssignment = try? extractSingleAssignment(from: ifExpr.body), - let assignedName = thenAssignment.targetName - else { - return false - } - - return varName == assignedName - } - - private static func analyzePattern( - variableDecl: VariableDeclSyntax?, - ifExpr: IfExprSyntax, - varIndex: Int?, - ifIndex: Int - ) throws -> ConvertibleIfElse? { - - var expectedVariableName: String? - - if let variableDecl { - guard validateVariableDecl(variableDecl) else { - return nil - } - - guard variableDecl.bindings.count == 1, - let binding = variableDecl.bindings.first, - let identifierPattern = binding.pattern.as(IdentifierPatternSyntax.self) - else { - return nil - } - expectedVariableName = identifierPattern.identifier.text - } - - guard let firstCondition = ifExpr.conditions.only, - case .expression(let condition) = firstCondition.condition - else { - return nil - } - - guard let elseBody = ifExpr.elseBody, - case .codeBlock(let elseBlock) = elseBody - else { - return nil - } - - guard let thenAssignment = try extractSingleAssignment(from: ifExpr.body) else { - return nil - } - - guard let elseAssignment = try extractSingleAssignment(from: elseBlock) else { - return nil - } - - guard thenAssignment.isTuple == elseAssignment.isTuple else { - return nil + // Walk up to find the CodeBlockItemSyntax containing this if expression + let ifItem: CodeBlockItemSyntax + if let exprStmt = syntax.parent?.as(ExpressionStmtSyntax.self), + let item = exprStmt.parent?.as(CodeBlockItemSyntax.self) + { + ifItem = item + } else if let item = syntax.parent?.as(CodeBlockItemSyntax.self) { + ifItem = item + } else { + throw RefactoringNotApplicableError("if expression is not in a code block") } - guard normalized(thenAssignment.targetExpr) == normalized(elseAssignment.targetExpr) else { - return nil + guard let codeBlockList = ifItem.parent?.as(CodeBlockItemListSyntax.self) else { + throw RefactoringNotApplicableError("if expression is not in a code block list") } - if let expectedName = expectedVariableName { - guard let thenName = thenAssignment.targetName, thenName == expectedName else { - return nil - } - guard elseAssignment.targetName == expectedName else { - return nil - } + // Check for a preceding compatible variable declaration + let items = Array(codeBlockList) + guard let ifIdx = items.firstIndex(where: { $0.id == ifItem.id }) else { + throw RefactoringNotApplicableError("cannot find if expression in code block") } - if isExpressionTooComplexForTernary(thenAssignment.valueExpr) - || isExpressionTooComplexForTernary(elseAssignment.valueExpr) + let precedingItem = ifIdx > 0 ? items[ifIdx - 1] : nil + if let precedingItem, + let varDecl = precedingItem.item.as(VariableDeclSyntax.self), + varDecl.isVariableDeclarationWithoutInitializer, + let binding = varDecl.bindings.only, + let identifierPattern = binding.pattern.as(IdentifierPatternSyntax.self), + identifierPattern.identifier.text == convertible.assignmentTargetName { - return nil - } - - return ConvertibleIfElse( - variableDecl: variableDecl, - ifExpr: ifExpr, - assignmentTargetExpr: thenAssignment.targetExpr, - assignmentTargetName: thenAssignment.targetName, - condition: condition, - trueExpr: thenAssignment.valueExpr, - falseExpr: elseAssignment.valueExpr, - variableDeclIndex: varIndex, - ifExprIndex: ifIndex, - isTupleAssignment: thenAssignment.isTuple - ) - } - - // MARK: - Validation Helpers - private static func isNamedTupleType(_ type: TypeSyntax) -> Bool { - guard let tupleType = type.as(TupleTypeSyntax.self) else { return false } - return tupleType.elements.contains { $0.firstName != nil } - } - - private static func validateVariableDecl(_ decl: VariableDeclSyntax) -> Bool { - guard decl.bindings.count == 1, - let binding = decl.bindings.first, - binding.typeAnnotation?.type != nil, - binding.initializer == nil, - decl.attributes.isEmpty - else { - return false + // Case: merge preceding decl + if → single declaration with ternary initializer. + // Replace the range from the start of the declaration to the end of the if expression + // (excluding surrounding trivia), preserving whatever whitespace is on either side. + let newDecl = createTernaryDeclaration(from: convertible, variableDecl: varDecl) + let range = precedingItem.positionAfterSkippingLeadingTrivia.. String { - expression.trimmed.description } // MARK: - Extracting Assignments - /// Extracts the assignment from a code block. - private static func extractSingleAssignment( + + fileprivate static func extractSingleAssignment( from codeBlock: CodeBlockSyntax ) throws -> AssignmentInfo? { - guard let statement = codeBlock.statements.only else { return nil } @@ -281,93 +124,32 @@ public struct ConvertToTernaryExpression: SyntaxRefactoringProvider { return try extractFromSequenceAssignment(sequenceExpr) } - /// Extracts the target and value from a sequence expression containing assignment. private static func extractFromSequenceAssignment( _ sequenceExpr: SequenceExprSyntax ) throws -> AssignmentInfo? { - let elements = Array(sequenceExpr.elements) - guard elements.count == 3, - elements[1].as(AssignmentExprSyntax.self) != nil - else { + guard elements.count >= 3, elements[1].as(AssignmentExprSyntax.self) != nil else { return nil } let lhs = ExprSyntax(elements[0]) - let rhs = ExprSyntax(elements[2]) - - if let lhsIdentifier = lhs.as(DeclReferenceExprSyntax.self) { - return AssignmentInfo( - targetExpr: lhs, - targetName: lhsIdentifier.baseName.text, - valueExpr: rhs, - isTuple: false - ) - } - - if lhs.as(TupleExprSyntax.self) != nil { - guard rhs.as(TupleExprSyntax.self) != nil else { - return nil - } - return AssignmentInfo( - targetExpr: lhs, - targetName: nil, - valueExpr: rhs, - isTuple: true - ) - } - - return nil - } - - private static func isExpressionTooComplexForTernary(_ expr: ExprSyntax) -> Bool { - // Nested ternaries reduce readability: x = a ? b : (c ? d : e) - if expr.as(TernaryExprSyntax.self) != nil { return true } - - // Closures in ternaries are harder to read than if-else blocks. - // Example: action = condition ? { [weak self] in self?.log() } : { print("default") } - // is less clear than an if-else with proper formatting and line breaks. - if expr.as(ClosureExprSyntax.self) != nil { return true } - - return false - } - - // MARK: - Applying Refactoring - private static func performRefactoring( - syntax: CodeBlockItemListSyntax, - convertible: ConvertibleIfElse - ) -> CodeBlockItemListSyntax { - - var newItems: [CodeBlockItemSyntax] = [] - - for (index, item) in syntax.enumerated() { - if index == convertible.ifExprIndex { - if convertible.variableDecl == nil { - let assignmentExpr = createTernaryAssignment(from: convertible) - let assignmentStmt = ExpressionStmtSyntax(expression: assignmentExpr) - newItems.append( - CodeBlockItemSyntax(item: .stmt(StmtSyntax(assignmentStmt))).trimmed - ) - } - continue - } - - if let varDeclIndex = convertible.variableDeclIndex, index == varDeclIndex { - let newDecl = createTernaryDeclaration(from: convertible) - newItems.append( - CodeBlockItemSyntax(item: .decl(DeclSyntax(newDecl))).trimmed - ) - continue - } - - newItems.append(item) + let rhs: ExprSyntax + if elements.count == 3 { + rhs = ExprSyntax(elements[2]) + } else { + rhs = ExprSyntax(SequenceExprSyntax(elements: ExprListSyntax(Array(elements[2...])))) } - return CodeBlockItemListSyntax(newItems) + return AssignmentInfo( + targetExpr: lhs, + targetName: lhs.as(DeclReferenceExprSyntax.self)?.baseName.text, + valueExpr: rhs + ) } // MARK: - Builders + private static func makeTernaryExpr(from convertible: ConvertibleIfElse) -> TernaryExprSyntax { TernaryExprSyntax( condition: convertible.condition.trimmed.with(\.trailingTrivia, .space), @@ -378,13 +160,13 @@ public struct ConvertToTernaryExpression: SyntaxRefactoringProvider { ) } - /// Creates the new variable declaration with ternary initializer (when declaration exists). - /// Preserves the original pattern + type annotation. - private static func createTernaryDeclaration(from convertible: ConvertibleIfElse) -> VariableDeclSyntax { - guard let variableDecl = convertible.variableDecl else { - fatalError("createTernaryDeclaration called without variable declaration") - } - guard let originalBinding = variableDecl.bindings.first else { + /// Creates the new variable declaration with ternary initializer. + /// Preserves the original type annotation for named tuples only. + private static func createTernaryDeclaration( + from convertible: ConvertibleIfElse, + variableDecl: VariableDeclSyntax + ) -> VariableDeclSyntax { + guard let originalBinding = variableDecl.bindings.only else { fatalError("Invalid state: binding should exist") } @@ -421,4 +203,104 @@ public struct ConvertToTernaryExpression: SyntaxRefactoringProvider { return ExprSyntax(assignmentSeq) } + + // MARK: - Validation Helpers + + private static func isNamedTupleType(_ type: TypeSyntax) -> Bool { + guard let tupleType = type.as(TupleTypeSyntax.self) else { return false } + return tupleType.elements.contains { $0.firstName != nil } + } +} + +// MARK: - Models + +/// Holds the extracted components of a convertible if-else pattern. +/// +/// Example: given +/// ```swift +/// if condition { +/// result = trueValue +/// } else { +/// result = falseValue +/// } +/// ``` +/// - `assignmentTargetExpr` is `result` +/// - `assignmentTargetName` is `"result"` (`nil` for tuple targets like `(x, y)`) +/// - `condition` is `condition` +/// - `trueExpr` is `trueValue` (from the then-branch) +/// - `falseExpr` is `falseValue` (from the else-branch) +private struct ConvertibleIfElse { + /// LHS of the assignment (`result` or `(x, y)`). + let assignmentTargetExpr: ExprSyntax + + /// Only present when LHS is a simple identifier (e.g. `result`). + let assignmentTargetName: String? + + let condition: ExprSyntax + let trueExpr: ExprSyntax + let falseExpr: ExprSyntax + + init?(ifExpr: IfExprSyntax) throws { + guard let conditionElement = ifExpr.conditions.only, + case .expression(let condition) = conditionElement.condition + else { + return nil + } + + guard let elseBody = ifExpr.elseBody, case .codeBlock(let elseBlock) = elseBody else { + return nil + } + + guard let thenAssignment = try ConvertToTernaryExpression.extractSingleAssignment(from: ifExpr.body) else { + return nil + } + + guard let elseAssignment = try ConvertToTernaryExpression.extractSingleAssignment(from: elseBlock) else { + return nil + } + + guard thenAssignment.targetExpr.trimmed.description == elseAssignment.targetExpr.trimmed.description else { + return nil + } + + self.assignmentTargetExpr = thenAssignment.targetExpr + self.assignmentTargetName = thenAssignment.targetName + self.condition = condition + self.trueExpr = thenAssignment.valueExpr + self.falseExpr = elseAssignment.valueExpr + } +} + +/// Represents the extracted components of a single assignment expression. +/// +/// Example: for `result = trueValue`: +/// - `targetExpr` is `result` +/// - `targetName` is `"result"` (nil for tuple targets like `(x, y)`) +/// - `valueExpr` is `trueValue` +/// +/// Example: for `(x, y) = (1, 2)`: +/// - `targetExpr` is `(x, y)` +/// - `targetName` is `nil` +/// - `valueExpr` is `(1, 2)` +private struct AssignmentInfo { + let targetExpr: ExprSyntax + let targetName: String? + let valueExpr: ExprSyntax +} + +extension VariableDeclSyntax { + /// Returns `true` if this is a `let` or `var` declaration with a type annotation + /// but no initializer and no attributes. + /// + /// Example: `let result: Int` → `true`; `let result = 0` → `false` + fileprivate var isVariableDeclarationWithoutInitializer: Bool { + guard let binding = bindings.only, + binding.typeAnnotation?.type != nil, + binding.initializer == nil, + attributes.isEmpty + else { + return false + } + return true + } } diff --git a/Sources/SwiftRefactor/SyntaxUtils.swift b/Sources/SwiftRefactor/SyntaxUtils.swift index 16d9f312432..6f27b5dc71f 100644 --- a/Sources/SwiftRefactor/SyntaxUtils.swift +++ b/Sources/SwiftRefactor/SyntaxUtils.swift @@ -32,17 +32,6 @@ extension Trivia { } } -extension Collection { - /// If the collection contains a single element, return it, otherwise `nil`. - var only: Element? { - if !isEmpty && index(after: startIndex) == endIndex { - return self.first! - } else { - return nil - } - } -} - extension TypeSyntax { var isVoid: Bool { switch self.as(TypeSyntaxEnum.self) { diff --git a/Tests/SwiftRefactorTest/ConvertToTernaryExpressionTests.swift b/Tests/SwiftRefactorTest/ConvertToTernaryExpressionTests.swift index 7d1988edb61..c07cf5fdefa 100644 --- a/Tests/SwiftRefactorTest/ConvertToTernaryExpressionTests.swift +++ b/Tests/SwiftRefactorTest/ConvertToTernaryExpressionTests.swift @@ -2,7 +2,7 @@ // // This source file is part of the Swift.org open source project // -// Copyright (c) 2014 - 2024 Apple Inc. and the Swift project authors +// Copyright (c) 2014 - 2026 Apple Inc. and the Swift project authors // Licensed under Apache License v2.0 with Runtime Library Exception // // See https://swift.org/LICENSE.txt for license information @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +import SwiftParser import SwiftRefactor import SwiftSyntax import SwiftSyntaxBuilder @@ -20,7 +21,7 @@ final class ConvertToTernaryExpressionTests: XCTestCase { // MARK: - Basic Pattern Tests func testBasicIfElseWithLetDeclaration() throws { - let baseline: CodeBlockItemListSyntax = """ + let baseline = """ let result: Int if condition { result = 10 @@ -29,58 +30,58 @@ final class ConvertToTernaryExpressionTests: XCTestCase { } """ - let expected: CodeBlockItemListSyntax = """ + let expected = """ let result = condition ? 10 : 20 """ try assertRefactorConvert(baseline, expected: expected) } - func testBasicIfElseWithVarDeclaration() throws { - let baseline: CodeBlockItemListSyntax = """ - var status: String - if isValid { - status = "approved" + func testStandaloneIfElseAssignment() throws { + let baseline = """ + if isActive { + flag = true } else { - status = "rejected" + flag = false } """ - let expected: CodeBlockItemListSyntax = """ - var status = isValid ? "approved" : "rejected" + let expected = """ + flag = isActive ? true : false """ try assertRefactorConvert(baseline, expected: expected) } - func testStandaloneIfElseAssignment() throws { - let baseline: CodeBlockItemListSyntax = """ - if isActive { - flag = true + func testParenthesizedCondition() throws { + let baseline = """ + let output: Int + if (x > 0) { + output = 1 } else { - flag = false + output = 0 } """ - let expected: CodeBlockItemListSyntax = """ - flag = isActive ? true : false + let expected = """ + let output = (x > 0) ? 1 : 0 """ try assertRefactorConvert(baseline, expected: expected) } - func testParenthesizedCondition() throws { - let baseline: CodeBlockItemListSyntax = """ + func testUnparenthesizedCondition() throws { + let baseline = """ let output: Int - if (x > 0) { + if x > 0 { output = 1 } else { output = 0 } """ - let expected: CodeBlockItemListSyntax = """ - let output = (x > 0) ? 1 : 0 + let expected = """ + let output = x > 0 ? 1 : 0 """ try assertRefactorConvert(baseline, expected: expected) @@ -88,7 +89,7 @@ final class ConvertToTernaryExpressionTests: XCTestCase { // MARK: - Tuple Assignment Tests func testSimpleTupleAssignment() throws { - let baseline: CodeBlockItemListSyntax = """ + let baseline = """ let point: (Int, Int) if isOrigin { point = (0, 0) @@ -97,7 +98,7 @@ final class ConvertToTernaryExpressionTests: XCTestCase { } """ - let expected: CodeBlockItemListSyntax = """ + let expected = """ let point = isOrigin ? (0, 0) : (10, 20) """ @@ -105,7 +106,7 @@ final class ConvertToTernaryExpressionTests: XCTestCase { } func testNamedTupleAssignment() throws { - let baseline: CodeBlockItemListSyntax = """ + let baseline = """ let coordinates: (x: Int, y: Int) if reset { coordinates = (x: 0, y: 0) @@ -114,53 +115,17 @@ final class ConvertToTernaryExpressionTests: XCTestCase { } """ - let expected: CodeBlockItemListSyntax = """ + let expected = """ let coordinates: (x: Int, y: Int) = reset ? (x: 0, y: 0) : (x: 100, y: 200) """ try assertRefactorConvert(baseline, expected: expected) } - // MARK: - Complex Expression Tests - - func testFunctionCallInBranches() throws { - let baseline: CodeBlockItemListSyntax = """ - let result: String - if shouldTransform { - result = transform(input) - } else { - result = identity(input) - } - """ - - let expected: CodeBlockItemListSyntax = """ - let result = shouldTransform ? transform(input) : identity(input) - """ - - try assertRefactorConvert(baseline, expected: expected) - } - - func testDictionaryLiteralInBranches() throws { - let baseline: CodeBlockItemListSyntax = """ - let config: [String: Int] - if useDefault { - config = [:] - } else { - config = ["key": 42] - } - """ - - let expected: CodeBlockItemListSyntax = """ - let config = useDefault ? [:] : ["key": 42] - """ - - try assertRefactorConvert(baseline, expected: expected) - } - // MARK: - Negative Tests - Should NOT Refactor func testRejectsElseIfChain() throws { - let baseline: CodeBlockItemListSyntax = """ + let baseline = """ let result: Int if condition1 { result = 1 @@ -174,8 +139,8 @@ final class ConvertToTernaryExpressionTests: XCTestCase { try assertRefactorConvert(baseline, expected: nil) } - func testRejectsNestedTernary() throws { - let baseline: CodeBlockItemListSyntax = """ + func testNestedTernaryInBranch() throws { + let baseline = """ let result: Int if outer { result = inner ? 1 : 2 @@ -184,11 +149,15 @@ final class ConvertToTernaryExpressionTests: XCTestCase { } """ - try assertRefactorConvert(baseline, expected: nil) + let expected = """ + let result = outer ? inner ? 1 : 2 : 3 + """ + + try assertRefactorConvert(baseline, expected: expected) } - func testRejectsClosureInBranch() throws { - let baseline: CodeBlockItemListSyntax = """ + func testClosureInBranch() throws { + let baseline = """ let result: () -> Void if condition { result = { print("hello") } @@ -197,11 +166,15 @@ final class ConvertToTernaryExpressionTests: XCTestCase { } """ - try assertRefactorConvert(baseline, expected: nil) + let expected = """ + let result = condition ? { print("hello") } : { print("goodbye") } + """ + + try assertRefactorConvert(baseline, expected: expected) } func testRejectsDifferentVariablesInBranches() throws { - let baseline: CodeBlockItemListSyntax = """ + let baseline = """ let result: Int if condition { result = 10 @@ -214,7 +187,7 @@ final class ConvertToTernaryExpressionTests: XCTestCase { } func testRejectsNoElseClause() throws { - let baseline: CodeBlockItemListSyntax = """ + let baseline = """ let result: Int if condition { result = 10 @@ -226,17 +199,55 @@ final class ConvertToTernaryExpressionTests: XCTestCase { } private func assertRefactorConvert( - _ baseline: CodeBlockItemListSyntax, - expected: CodeBlockItemListSyntax?, + _ baseline: String, + expected: String?, file: StaticString = #filePath, line: UInt = #line ) throws { - try assertRefactor( - baseline, - context: (), - provider: ConvertToTernaryExpression.self, - expected: expected, - file: file, - line: line - ) + let sourceFile = Parser.parse(source: baseline) + + class IfExprFinder: SyntaxVisitor { + var result: IfExprSyntax? + override func visit(_ node: IfExprSyntax) -> SyntaxVisitorContinueKind { + if result == nil { result = node } + return .skipChildren + } + } + + let finder = IfExprFinder(viewMode: .sourceAccurate) + finder.walk(sourceFile) + + guard let ifExpr = finder.result else { + XCTFail("No IfExprSyntax found in baseline", file: file, line: line) + return + } + + let edits: [SourceEdit] + do { + edits = try ConvertToTernaryExpression.textRefactor(syntax: ifExpr) + } catch { + if expected != nil { + XCTFail("Refactoring failed unexpectedly: \(error)", file: file, line: line) + } + return + } + + guard let expected = expected else { + XCTFail( + "Expected refactoring to fail, but got \(edits.count) edit(s)", + file: file, + line: line + ) + return + } + + var bytes = Array(baseline.utf8) + for edit in edits.sorted(by: { $0.range.lowerBound > $1.range.lowerBound }) { + let start = edit.range.lowerBound.utf8Offset + let end = edit.range.upperBound.utf8Offset + bytes.replaceSubrange(start..