diff --git a/Sources/SwiftRefactor/CMakeLists.txt b/Sources/SwiftRefactor/CMakeLists.txt index e478a428a79..aae774e7144 100644 --- a/Sources/SwiftRefactor/CMakeLists.txt +++ b/Sources/SwiftRefactor/CMakeLists.txt @@ -13,6 +13,7 @@ add_swift_syntax_library(SwiftRefactor ConvertComputedPropertyToStored.swift ConvertComputedPropertyToZeroParameterFunction.swift ConvertStoredPropertyToComputed.swift + ConvertToTernaryExpression.swift ConvertZeroParameterFunctionToComputedProperty.swift ExpandEditorPlaceholder.swift FormatRawStringLiteral.swift diff --git a/Sources/SwiftRefactor/Collection+Only.swift b/Sources/SwiftRefactor/Collection+Only.swift new file mode 100644 index 00000000000..7f4681a81b4 --- /dev/null +++ b/Sources/SwiftRefactor/Collection+Only.swift @@ -0,0 +1,22 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2026 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +extension Collection { + /// If the collection contains a single element, return it, otherwise `nil`. + var only: Element? { + if !isEmpty && index(after: startIndex) == endIndex { + return self.first! + } else { + return nil + } + } +} diff --git a/Sources/SwiftRefactor/ConvertToTernaryExpression.swift b/Sources/SwiftRefactor/ConvertToTernaryExpression.swift new file mode 100644 index 00000000000..151f36a9084 --- /dev/null +++ b/Sources/SwiftRefactor/ConvertToTernaryExpression.swift @@ -0,0 +1,306 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2026 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +#if swift(>=6) +public import SwiftSyntax +#else +import SwiftSyntax +#endif + +/// Converts an if-else statement that assigns to the same variable in both branches +/// into a ternary expression assignment. +/// +/// ## Before +/// +/// ```swift +/// let result: Type +/// if condition { +/// result = trueValue +/// } else { +/// result = falseValue +/// } +/// ``` +/// +/// ## After +/// +/// ```swift +/// let result: Type = condition ? trueValue : falseValue +/// ``` +/// +/// Also supports: +/// - Tuple assignments: `(a, b) = condition ? (1, "a") : (2, "b")` +/// - Assignment without declaration: `existingVar = condition ? value1 : value2` +/// +/// This refactoring is applicable when: +/// - There is an if-else statement (no else-if chains) +/// - Both branches contain a single assignment expression +/// - Both assignments target the same variable / same tuple pattern +/// - Optionally, the variable is declared immediately before the if statement +public struct ConvertToTernaryExpression: EditRefactoringProvider { + public static func textRefactor(syntax: IfExprSyntax, in context: Void) throws -> [SourceEdit] { + guard let convertible = try ConvertibleIfElse(ifExpr: syntax) else { + throw RefactoringNotApplicableError( + "Cannot convert: if-else branches must each contain a single assignment to the same variable" + ) + } + + // Walk up to find the CodeBlockItemSyntax containing this if expression + let ifItem: CodeBlockItemSyntax + if let exprStmt = syntax.parent?.as(ExpressionStmtSyntax.self), + let item = exprStmt.parent?.as(CodeBlockItemSyntax.self) + { + ifItem = item + } else if let item = syntax.parent?.as(CodeBlockItemSyntax.self) { + ifItem = item + } else { + throw RefactoringNotApplicableError("if expression is not in a code block") + } + + guard let codeBlockList = ifItem.parent?.as(CodeBlockItemListSyntax.self) else { + throw RefactoringNotApplicableError("if expression is not in a code block list") + } + + // Check for a preceding compatible variable declaration + let items = Array(codeBlockList) + guard let ifIdx = items.firstIndex(where: { $0.id == ifItem.id }) else { + throw RefactoringNotApplicableError("cannot find if expression in code block") + } + + let precedingItem = ifIdx > 0 ? items[ifIdx - 1] : nil + if let precedingItem, + let varDecl = precedingItem.item.as(VariableDeclSyntax.self), + varDecl.isVariableDeclarationWithoutInitializer, + let binding = varDecl.bindings.only, + let identifierPattern = binding.pattern.as(IdentifierPatternSyntax.self), + identifierPattern.identifier.text == convertible.assignmentTargetName + { + // Case: merge preceding decl + if → single declaration with ternary initializer. + // Replace the range from the start of the declaration to the end of the if expression + // (excluding surrounding trivia), preserving whatever whitespace is on either side. + let newDecl = createTernaryDeclaration(from: convertible, variableDecl: varDecl) + let range = precedingItem.positionAfterSkippingLeadingTrivia.. AssignmentInfo? { + guard let statement = codeBlock.statements.only else { + return nil + } + + let expr: ExprSyntax + if let exprStmt = statement.item.as(ExpressionStmtSyntax.self) { + expr = exprStmt.expression + } else if let directExpr = statement.item.as(ExprSyntax.self) { + expr = directExpr + } else { + return nil + } + + guard let sequenceExpr = expr.as(SequenceExprSyntax.self) else { + return nil + } + + return try extractFromSequenceAssignment(sequenceExpr) + } + + private static func extractFromSequenceAssignment( + _ sequenceExpr: SequenceExprSyntax + ) throws -> AssignmentInfo? { + let elements = Array(sequenceExpr.elements) + + guard elements.count >= 3, elements[1].as(AssignmentExprSyntax.self) != nil else { + return nil + } + + let lhs = ExprSyntax(elements[0]) + let rhs: ExprSyntax + if elements.count == 3 { + rhs = ExprSyntax(elements[2]) + } else { + rhs = ExprSyntax(SequenceExprSyntax(elements: ExprListSyntax(Array(elements[2...])))) + } + + return AssignmentInfo( + targetExpr: lhs, + targetName: lhs.as(DeclReferenceExprSyntax.self)?.baseName.text, + valueExpr: rhs + ) + } + + // MARK: - Builders + + private static func makeTernaryExpr(from convertible: ConvertibleIfElse) -> TernaryExprSyntax { + TernaryExprSyntax( + condition: convertible.condition.trimmed.with(\.trailingTrivia, .space), + questionMark: .infixQuestionMarkToken(trailingTrivia: .space), + thenExpression: convertible.trueExpr.trimmed.with(\.trailingTrivia, .space), + colon: .colonToken(trailingTrivia: .space), + elseExpression: convertible.falseExpr.trimmed + ) + } + + /// Creates the new variable declaration with ternary initializer. + /// Preserves the original type annotation for named tuples only. + private static func createTernaryDeclaration( + from convertible: ConvertibleIfElse, + variableDecl: VariableDeclSyntax + ) -> VariableDeclSyntax { + guard let originalBinding = variableDecl.bindings.only else { + fatalError("Invalid state: binding should exist") + } + + let ternaryExpr = makeTernaryExpr(from: convertible) + + let preserveAnnotation = originalBinding.typeAnnotation.map { isNamedTupleType($0.type) } ?? false + let typeAnnotation: TypeAnnotationSyntax? = preserveAnnotation ? originalBinding.typeAnnotation?.trimmed : nil + + let newBinding = + originalBinding + .with(\.typeAnnotation, typeAnnotation) + .with( + \.initializer, + InitializerClauseSyntax( + equal: .equalToken(leadingTrivia: .space, trailingTrivia: .space), + value: ExprSyntax(ternaryExpr) + ) + ) + + return variableDecl.with(\.bindings, PatternBindingListSyntax([newBinding])) + } + + /// Creates a ternary assignment expression (when no declaration exists). + private static func createTernaryAssignment(from convertible: ConvertibleIfElse) -> ExprSyntax { + let ternaryExpr = makeTernaryExpr(from: convertible) + + let assignmentSeq = SequenceExprSyntax( + elements: ExprListSyntax([ + convertible.assignmentTargetExpr.with(\.trailingTrivia, .space), + ExprSyntax(AssignmentExprSyntax()).with(\.trailingTrivia, .space), + ExprSyntax(ternaryExpr), + ]) + ) + + return ExprSyntax(assignmentSeq) + } + + // MARK: - Validation Helpers + + private static func isNamedTupleType(_ type: TypeSyntax) -> Bool { + guard let tupleType = type.as(TupleTypeSyntax.self) else { return false } + return tupleType.elements.contains { $0.firstName != nil } + } +} + +// MARK: - Models + +/// Holds the extracted components of a convertible if-else pattern. +/// +/// Example: given +/// ```swift +/// if condition { +/// result = trueValue +/// } else { +/// result = falseValue +/// } +/// ``` +/// - `assignmentTargetExpr` is `result` +/// - `assignmentTargetName` is `"result"` (`nil` for tuple targets like `(x, y)`) +/// - `condition` is `condition` +/// - `trueExpr` is `trueValue` (from the then-branch) +/// - `falseExpr` is `falseValue` (from the else-branch) +private struct ConvertibleIfElse { + /// LHS of the assignment (`result` or `(x, y)`). + let assignmentTargetExpr: ExprSyntax + + /// Only present when LHS is a simple identifier (e.g. `result`). + let assignmentTargetName: String? + + let condition: ExprSyntax + let trueExpr: ExprSyntax + let falseExpr: ExprSyntax + + init?(ifExpr: IfExprSyntax) throws { + guard let conditionElement = ifExpr.conditions.only, + case .expression(let condition) = conditionElement.condition + else { + return nil + } + + guard let elseBody = ifExpr.elseBody, case .codeBlock(let elseBlock) = elseBody else { + return nil + } + + guard let thenAssignment = try ConvertToTernaryExpression.extractSingleAssignment(from: ifExpr.body) else { + return nil + } + + guard let elseAssignment = try ConvertToTernaryExpression.extractSingleAssignment(from: elseBlock) else { + return nil + } + + guard thenAssignment.targetExpr.trimmed.description == elseAssignment.targetExpr.trimmed.description else { + return nil + } + + self.assignmentTargetExpr = thenAssignment.targetExpr + self.assignmentTargetName = thenAssignment.targetName + self.condition = condition + self.trueExpr = thenAssignment.valueExpr + self.falseExpr = elseAssignment.valueExpr + } +} + +/// Represents the extracted components of a single assignment expression. +/// +/// Example: for `result = trueValue`: +/// - `targetExpr` is `result` +/// - `targetName` is `"result"` (nil for tuple targets like `(x, y)`) +/// - `valueExpr` is `trueValue` +/// +/// Example: for `(x, y) = (1, 2)`: +/// - `targetExpr` is `(x, y)` +/// - `targetName` is `nil` +/// - `valueExpr` is `(1, 2)` +private struct AssignmentInfo { + let targetExpr: ExprSyntax + let targetName: String? + let valueExpr: ExprSyntax +} + +extension VariableDeclSyntax { + /// Returns `true` if this is a `let` or `var` declaration with a type annotation + /// but no initializer and no attributes. + /// + /// Example: `let result: Int` → `true`; `let result = 0` → `false` + fileprivate var isVariableDeclarationWithoutInitializer: Bool { + guard let binding = bindings.only, + binding.typeAnnotation?.type != nil, + binding.initializer == nil, + attributes.isEmpty + else { + return false + } + return true + } +} diff --git a/Tests/SwiftRefactorTest/ConvertToTernaryExpressionTests.swift b/Tests/SwiftRefactorTest/ConvertToTernaryExpressionTests.swift new file mode 100644 index 00000000000..c07cf5fdefa --- /dev/null +++ b/Tests/SwiftRefactorTest/ConvertToTernaryExpressionTests.swift @@ -0,0 +1,253 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2026 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +import SwiftParser +import SwiftRefactor +import SwiftSyntax +import SwiftSyntaxBuilder +import XCTest +import _SwiftSyntaxTestSupport + +final class ConvertToTernaryExpressionTests: XCTestCase { + + // MARK: - Basic Pattern Tests + func testBasicIfElseWithLetDeclaration() throws { + let baseline = """ + let result: Int + if condition { + result = 10 + } else { + result = 20 + } + """ + + let expected = """ + let result = condition ? 10 : 20 + """ + + try assertRefactorConvert(baseline, expected: expected) + } + + func testStandaloneIfElseAssignment() throws { + let baseline = """ + if isActive { + flag = true + } else { + flag = false + } + """ + + let expected = """ + flag = isActive ? true : false + """ + + try assertRefactorConvert(baseline, expected: expected) + } + + func testParenthesizedCondition() throws { + let baseline = """ + let output: Int + if (x > 0) { + output = 1 + } else { + output = 0 + } + """ + + let expected = """ + let output = (x > 0) ? 1 : 0 + """ + + try assertRefactorConvert(baseline, expected: expected) + } + + func testUnparenthesizedCondition() throws { + let baseline = """ + let output: Int + if x > 0 { + output = 1 + } else { + output = 0 + } + """ + + let expected = """ + let output = x > 0 ? 1 : 0 + """ + + try assertRefactorConvert(baseline, expected: expected) + } + + // MARK: - Tuple Assignment Tests + func testSimpleTupleAssignment() throws { + let baseline = """ + let point: (Int, Int) + if isOrigin { + point = (0, 0) + } else { + point = (10, 20) + } + """ + + let expected = """ + let point = isOrigin ? (0, 0) : (10, 20) + """ + + try assertRefactorConvert(baseline, expected: expected) + } + + func testNamedTupleAssignment() throws { + let baseline = """ + let coordinates: (x: Int, y: Int) + if reset { + coordinates = (x: 0, y: 0) + } else { + coordinates = (x: 100, y: 200) + } + """ + + let expected = """ + let coordinates: (x: Int, y: Int) = reset ? (x: 0, y: 0) : (x: 100, y: 200) + """ + + try assertRefactorConvert(baseline, expected: expected) + } + + // MARK: - Negative Tests - Should NOT Refactor + + func testRejectsElseIfChain() throws { + let baseline = """ + let result: Int + if condition1 { + result = 1 + } else if condition2 { + result = 2 + } else { + result = 3 + } + """ + + try assertRefactorConvert(baseline, expected: nil) + } + + func testNestedTernaryInBranch() throws { + let baseline = """ + let result: Int + if outer { + result = inner ? 1 : 2 + } else { + result = 3 + } + """ + + let expected = """ + let result = outer ? inner ? 1 : 2 : 3 + """ + + try assertRefactorConvert(baseline, expected: expected) + } + + func testClosureInBranch() throws { + let baseline = """ + let result: () -> Void + if condition { + result = { print("hello") } + } else { + result = { print("goodbye") } + } + """ + + let expected = """ + let result = condition ? { print("hello") } : { print("goodbye") } + """ + + try assertRefactorConvert(baseline, expected: expected) + } + + func testRejectsDifferentVariablesInBranches() throws { + let baseline = """ + let result: Int + if condition { + result = 10 + } else { + other = 20 + } + """ + + try assertRefactorConvert(baseline, expected: nil) + } + + func testRejectsNoElseClause() throws { + let baseline = """ + let result: Int + if condition { + result = 10 + } + """ + + try assertRefactorConvert(baseline, expected: nil) + } +} + +private func assertRefactorConvert( + _ baseline: String, + expected: String?, + file: StaticString = #filePath, + line: UInt = #line +) throws { + let sourceFile = Parser.parse(source: baseline) + + class IfExprFinder: SyntaxVisitor { + var result: IfExprSyntax? + override func visit(_ node: IfExprSyntax) -> SyntaxVisitorContinueKind { + if result == nil { result = node } + return .skipChildren + } + } + + let finder = IfExprFinder(viewMode: .sourceAccurate) + finder.walk(sourceFile) + + guard let ifExpr = finder.result else { + XCTFail("No IfExprSyntax found in baseline", file: file, line: line) + return + } + + let edits: [SourceEdit] + do { + edits = try ConvertToTernaryExpression.textRefactor(syntax: ifExpr) + } catch { + if expected != nil { + XCTFail("Refactoring failed unexpectedly: \(error)", file: file, line: line) + } + return + } + + guard let expected = expected else { + XCTFail( + "Expected refactoring to fail, but got \(edits.count) edit(s)", + file: file, + line: line + ) + return + } + + var bytes = Array(baseline.utf8) + for edit in edits.sorted(by: { $0.range.lowerBound > $1.range.lowerBound }) { + let start = edit.range.lowerBound.utf8Offset + let end = edit.range.upperBound.utf8Offset + bytes.replaceSubrange(start..