From 865a4e698854fc0938bc50a82936475ff5dab689 Mon Sep 17 00:00:00 2001 From: Marc Prud'hommeaux Date: Thu, 12 Mar 2026 11:22:56 -0400 Subject: [PATCH] Support freestanding @Test functions in transpiled tests --- .../Kotlin/KotlinUnitTestTransformer.swift | 91 ++++++++++++++----- Tests/SkipSyntaxTests/TransformerTests.swift | 78 ++++++++++++++-- 2 files changed, 140 insertions(+), 29 deletions(-) diff --git a/Sources/SkipSyntax/Kotlin/KotlinUnitTestTransformer.swift b/Sources/SkipSyntax/Kotlin/KotlinUnitTestTransformer.swift index f34a0776..de9b59cd 100644 --- a/Sources/SkipSyntax/Kotlin/KotlinUnitTestTransformer.swift +++ b/Sources/SkipSyntax/Kotlin/KotlinUnitTestTransformer.swift @@ -8,7 +8,7 @@ /// 1. **XCTest**: Classes inheriting from `XCTestCase` with `test`-prefixed methods /// 2. **Swift Testing**: Functions annotated with `@Test` and types annotated with `@Suite` /// -/// In both cases, JUnit `@Test` annotations and the AndroidJUnit4 runner annotation are applied. +/// In both cases, JUnit `@Test` annotation is applied. /// Async test functions are wrapped with coroutine test dispatchers. /// /// - Seealso: `SkipUnit/XCTest.kt` @@ -19,6 +19,8 @@ final class KotlinUnitTestTransformer: KotlinTransformer { /// Types annotated with `@Suite` in Swift source. private var swiftTestingSuites: [Source.FilePath: Set] = [:] + static let testRunnerAnnotation: String? = nil // was: "@org.junit.runner.RunWith(androidx.test.ext.junit.runners.AndroidJUnit4::class)" + func gather(from syntaxTree: SyntaxTree) { var testFunctions: Set = [] var suiteTypes: Set = [] @@ -73,27 +75,36 @@ final class KotlinUnitTestTransformer: KotlinTransformer { } private func visit(_ node: KotlinSyntaxNode, codebaseInfo: CodebaseInfo.Context, testFuncNames: Set, importPackages: inout Set) -> VisitResult { - if let functionDeclaration = node as? KotlinFunctionDeclaration, let owningClass = functionDeclaration.parent as? KotlinClassDeclaration { - // Check for XCTest-style test functions (name-based detection) - let isXCTest = Self.isXCTestFunction(functionDeclaration, owningClass: owningClass, codebaseInfo: codebaseInfo) - // Check for Swift Testing @Test functions (attribute-based detection) - let isSwiftTesting = testFuncNames.contains(functionDeclaration.name) - - if isXCTest || isSwiftTesting { - if functionDeclaration.apiFlags.options.contains(.async) { - transformAsyncTest(functionDeclaration: functionDeclaration, owningClass: owningClass, importPackages: &importPackages) - } else { - functionDeclaration.annotations += ["@Test"] - } - let testRunner = "@org.junit.runner.RunWith(androidx.test.ext.junit.runners.AndroidJUnit4::class)" - if !owningClass.annotations.contains(testRunner) { - owningClass.annotations += [testRunner] - } - // For Swift Testing @Suite types that don't extend XCTestCase, - // make them implement the XCTestCase interface for assertion access - if isSwiftTesting && !isXCTest { - ensureXCTestCaseConformance(owningClass) + if let functionDeclaration = node as? KotlinFunctionDeclaration { + if let owningClass = functionDeclaration.parent as? KotlinClassDeclaration { + // Check for XCTest-style test functions (name-based detection) + let isXCTest = Self.isXCTestFunction(functionDeclaration, owningClass: owningClass, codebaseInfo: codebaseInfo) + // Check for Swift Testing @Test functions (attribute-based detection) + let isSwiftTesting = testFuncNames.contains(functionDeclaration.name) + + if isXCTest || isSwiftTesting { + if functionDeclaration.apiFlags.options.contains(.async) { + transformAsyncTest(functionDeclaration: functionDeclaration, owningClass: owningClass, importPackages: &importPackages) + } else { + functionDeclaration.annotations += ["@Test"] + } + if let testRunnerAnnotation = Self.testRunnerAnnotation { + if !owningClass.annotations.contains(testRunnerAnnotation) { + owningClass.annotations += [testRunnerAnnotation] + } + } + // For Swift Testing @Suite types that don't extend XCTestCase, + // make them implement the XCTestCase interface for assertion access + if isSwiftTesting && !isXCTest { + ensureXCTestCaseConformance(owningClass) + } + return .skip } + } else if let owningCodeBlock = functionDeclaration.parent as? KotlinCodeBlock, + functionDeclaration.role == .global, + testFuncNames.contains(functionDeclaration.name) { + // Freestanding @Test function — wrap in a generated test class + wrapFreestandingTestFunction(functionDeclaration, in: owningCodeBlock, importPackages: &importPackages) return .skip } } @@ -115,6 +126,44 @@ final class KotlinUnitTestTransformer: KotlinTransformer { } } + /// Wraps a freestanding `@Test` function in a generated JUnit test class. + /// e.g., `@Test func addition() { ... }` becomes: + /// ``` + /// class AdditionTests: XCTestCase { + /// @Test fun addition() { ... } + /// } + /// ``` + private func wrapFreestandingTestFunction(_ functionDeclaration: KotlinFunctionDeclaration, in codeBlock: KotlinCodeBlock, importPackages: inout Set) { + // Generate a class name from the function name (e.g., "addition" -> "AdditionTests") + let className = functionDeclaration.name.prefix(1).uppercased() + functionDeclaration.name.dropFirst() + "Tests" + + // Create a wrapper class (final, not open) + let classDeclaration = KotlinClassDeclaration(name: className, signature: .named(className, []), declarationType: .classDeclaration) + classDeclaration.modifiers = Modifiers(isFinal: true) + classDeclaration.inherits = [.named("XCTestCase", [])] + if let testRunnerAnnotation = Self.testRunnerAnnotation { + classDeclaration.annotations = [testRunnerAnnotation] + } + classDeclaration.extras = functionDeclaration.extras + + // Move the function into the class + if let index = codeBlock.statements.firstIndex(where: { $0 === functionDeclaration }) { + functionDeclaration.role = .member + functionDeclaration.extras = nil + if functionDeclaration.apiFlags.options.contains(.async) { + transformAsyncTest(functionDeclaration: functionDeclaration, owningClass: classDeclaration, importPackages: &importPackages) + } else { + functionDeclaration.annotations += ["@Test"] + } + classDeclaration.members = [functionDeclaration] + functionDeclaration.parent = classDeclaration + + codeBlock.statements[index] = classDeclaration + classDeclaration.parent = codeBlock + classDeclaration.assignParentReferences() + } + } + private func transformAsyncTest(functionDeclaration: KotlinFunctionDeclaration, owningClass: KotlinClassDeclaration, importPackages: inout Set) { importPackages.insert("kotlinx.coroutines.*") importPackages.insert("kotlinx.coroutines.test.*") diff --git a/Tests/SkipSyntaxTests/TransformerTests.swift b/Tests/SkipSyntaxTests/TransformerTests.swift index 75928ff7..f7b87ec9 100644 --- a/Tests/SkipSyntaxTests/TransformerTests.swift +++ b/Tests/SkipSyntaxTests/TransformerTests.swift @@ -22,7 +22,6 @@ final class TransformerTests: XCTestCase { """, kotlin: """ import skip.unit.* - @org.junit.runner.RunWith(androidx.test.ext.junit.runners.AndroidJUnit4::class) internal open class TestCase: XCTestCase { @Test internal open fun testSomeTest() = Unit @@ -53,7 +52,6 @@ final class TransformerTests: XCTestCase { import skip.unit.* - @org.junit.runner.RunWith(androidx.test.ext.junit.runners.AndroidJUnit4::class) internal open class TestCase: XCTestCase { @OptIn(ExperimentalCoroutinesApi::class) @@ -88,7 +86,6 @@ final class TransformerTests: XCTestCase { """, kotlin: """ import skip.unit.* - @org.junit.runner.RunWith(androidx.test.ext.junit.runners.AndroidJUnit4::class) internal class MyTests: XCTestCase { @Test internal fun addition(): Unit = expectEqual(1 + 1, 2) @@ -109,7 +106,6 @@ final class TransformerTests: XCTestCase { """, kotlin: """ import skip.unit.* - @org.junit.runner.RunWith(androidx.test.ext.junit.runners.AndroidJUnit4::class) internal class MyTests: XCTestCase { @Test internal fun boolCheck() { @@ -132,7 +128,6 @@ final class TransformerTests: XCTestCase { """, kotlin: """ import skip.unit.* - @org.junit.runner.RunWith(androidx.test.ext.junit.runners.AndroidJUnit4::class) internal class MyTests: XCTestCase { @Test internal fun inequality(): Unit = expectNotEqual(1, 2) @@ -153,7 +148,6 @@ final class TransformerTests: XCTestCase { """, kotlin: """ import skip.unit.* - @org.junit.runner.RunWith(androidx.test.ext.junit.runners.AndroidJUnit4::class) internal class MyTests: XCTestCase { @Test internal fun unwrap() { @@ -184,7 +178,6 @@ final class TransformerTests: XCTestCase { """, kotlin: """ import skip.unit.* - @org.junit.runner.RunWith(androidx.test.ext.junit.runners.AndroidJUnit4::class) internal class MathTests: XCTestCase { @Test internal fun addition(): Unit = expectEqual(2 + 2, 4) @@ -212,7 +205,6 @@ final class TransformerTests: XCTestCase { """, kotlin: """ import skip.unit.* - @org.junit.runner.RunWith(androidx.test.ext.junit.runners.AndroidJUnit4::class) internal class CompTests: XCTestCase { @Test internal fun comparisons() { @@ -225,6 +217,76 @@ final class TransformerTests: XCTestCase { """) } + func testSwiftTestingFreestandingFunction() async throws { + try await check(swift: """ + import Testing + + @Test func addition() { + #expect(1 + 2 == 3) + } + """, kotlin: """ + import skip.unit.* + + internal class AdditionTests: XCTestCase { + @Test + internal fun addition(): Unit = expectEqual(1 + 2, 3) + } + """) + } + + func testSwiftTestingMultipleFreestandingFunctions() async throws { + try await check(swift: """ + import Testing + + @Test func addition() { + #expect(1 + 1 == 2) + } + + @Test func subtraction() { + #expect(5 - 3 == 2) + } + + func helperNotATest() -> Int { + return 42 + } + """, kotlin: """ + import skip.unit.* + + internal class AdditionTests: XCTestCase { + @Test + internal fun addition(): Unit = expectEqual(1 + 1, 2) + } + + internal class SubtractionTests: XCTestCase { + @Test + internal fun subtraction(): Unit = expectEqual(5 - 3, 2) + } + + internal fun helperNotATest(): Int = 42 + """) + } + + func testSwiftTestingFreestandingExpectTrue() async throws { + try await check(swift: """ + import Testing + + @Test func boolCheck() { + let x = true + #expect(x) + } + """, kotlin: """ + import skip.unit.* + + internal class BoolCheckTests: XCTestCase { + @Test + internal fun boolCheck() { + val x = true + expectTrue(x) + } + } + """) + } + func testModuleBundleTransformer() async throws { try await check(swift: """ import Foundation