Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 70 additions & 21 deletions Sources/SkipSyntax/Kotlin/KotlinUnitTestTransformer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
/// 1. **XCTest**: Classes inheriting from `XCTestCase` with `test`-prefixed methods
/// 2. **Swift Testing**: Functions annotated with `@Test` and types annotated with `@Suite`
///
/// In both cases, JUnit `@Test` annotations and the AndroidJUnit4 runner annotation are applied.
/// In both cases, JUnit `@Test` annotation is applied.
/// Async test functions are wrapped with coroutine test dispatchers.
///
/// - Seealso: `SkipUnit/XCTest.kt`
Expand All @@ -19,6 +19,8 @@ final class KotlinUnitTestTransformer: KotlinTransformer {
/// Types annotated with `@Suite` in Swift source.
private var swiftTestingSuites: [Source.FilePath: Set<String>] = [:]

static let testRunnerAnnotation: String? = nil // was: "@org.junit.runner.RunWith(androidx.test.ext.junit.runners.AndroidJUnit4::class)"

func gather(from syntaxTree: SyntaxTree) {
var testFunctions: Set<String> = []
var suiteTypes: Set<String> = []
Expand Down Expand Up @@ -73,27 +75,36 @@ final class KotlinUnitTestTransformer: KotlinTransformer {
}

private func visit(_ node: KotlinSyntaxNode, codebaseInfo: CodebaseInfo.Context, testFuncNames: Set<String>, importPackages: inout Set<String>) -> VisitResult<KotlinSyntaxNode> {
if let functionDeclaration = node as? KotlinFunctionDeclaration, let owningClass = functionDeclaration.parent as? KotlinClassDeclaration {
// Check for XCTest-style test functions (name-based detection)
let isXCTest = Self.isXCTestFunction(functionDeclaration, owningClass: owningClass, codebaseInfo: codebaseInfo)
// Check for Swift Testing @Test functions (attribute-based detection)
let isSwiftTesting = testFuncNames.contains(functionDeclaration.name)

if isXCTest || isSwiftTesting {
if functionDeclaration.apiFlags.options.contains(.async) {
transformAsyncTest(functionDeclaration: functionDeclaration, owningClass: owningClass, importPackages: &importPackages)
} else {
functionDeclaration.annotations += ["@Test"]
}
let testRunner = "@org.junit.runner.RunWith(androidx.test.ext.junit.runners.AndroidJUnit4::class)"
if !owningClass.annotations.contains(testRunner) {
owningClass.annotations += [testRunner]
}
// For Swift Testing @Suite types that don't extend XCTestCase,
// make them implement the XCTestCase interface for assertion access
if isSwiftTesting && !isXCTest {
ensureXCTestCaseConformance(owningClass)
if let functionDeclaration = node as? KotlinFunctionDeclaration {
if let owningClass = functionDeclaration.parent as? KotlinClassDeclaration {
// Check for XCTest-style test functions (name-based detection)
let isXCTest = Self.isXCTestFunction(functionDeclaration, owningClass: owningClass, codebaseInfo: codebaseInfo)
// Check for Swift Testing @Test functions (attribute-based detection)
let isSwiftTesting = testFuncNames.contains(functionDeclaration.name)

if isXCTest || isSwiftTesting {
if functionDeclaration.apiFlags.options.contains(.async) {
transformAsyncTest(functionDeclaration: functionDeclaration, owningClass: owningClass, importPackages: &importPackages)
} else {
functionDeclaration.annotations += ["@Test"]
}
if let testRunnerAnnotation = Self.testRunnerAnnotation {
if !owningClass.annotations.contains(testRunnerAnnotation) {
owningClass.annotations += [testRunnerAnnotation]
}
}
// For Swift Testing @Suite types that don't extend XCTestCase,
// make them implement the XCTestCase interface for assertion access
if isSwiftTesting && !isXCTest {
ensureXCTestCaseConformance(owningClass)
}
return .skip
}
} else if let owningCodeBlock = functionDeclaration.parent as? KotlinCodeBlock,
functionDeclaration.role == .global,
testFuncNames.contains(functionDeclaration.name) {
// Freestanding @Test function — wrap in a generated test class
wrapFreestandingTestFunction(functionDeclaration, in: owningCodeBlock, importPackages: &importPackages)
return .skip
}
}
Expand All @@ -115,6 +126,44 @@ final class KotlinUnitTestTransformer: KotlinTransformer {
}
}

/// Wraps a freestanding `@Test` function in a generated JUnit test class.
/// e.g., `@Test func addition() { ... }` becomes:
/// ```
/// class AdditionTests: XCTestCase {
/// @Test fun addition() { ... }
/// }
/// ```
private func wrapFreestandingTestFunction(_ functionDeclaration: KotlinFunctionDeclaration, in codeBlock: KotlinCodeBlock, importPackages: inout Set<String>) {
// Generate a class name from the function name (e.g., "addition" -> "AdditionTests")
let className = functionDeclaration.name.prefix(1).uppercased() + functionDeclaration.name.dropFirst() + "Tests"

// Create a wrapper class (final, not open)
let classDeclaration = KotlinClassDeclaration(name: className, signature: .named(className, []), declarationType: .classDeclaration)
classDeclaration.modifiers = Modifiers(isFinal: true)
classDeclaration.inherits = [.named("XCTestCase", [])]
if let testRunnerAnnotation = Self.testRunnerAnnotation {
classDeclaration.annotations = [testRunnerAnnotation]
}
classDeclaration.extras = functionDeclaration.extras

// Move the function into the class
if let index = codeBlock.statements.firstIndex(where: { $0 === functionDeclaration }) {
functionDeclaration.role = .member
functionDeclaration.extras = nil
if functionDeclaration.apiFlags.options.contains(.async) {
transformAsyncTest(functionDeclaration: functionDeclaration, owningClass: classDeclaration, importPackages: &importPackages)
} else {
functionDeclaration.annotations += ["@Test"]
}
classDeclaration.members = [functionDeclaration]
functionDeclaration.parent = classDeclaration

codeBlock.statements[index] = classDeclaration
classDeclaration.parent = codeBlock
classDeclaration.assignParentReferences()
}
}

private func transformAsyncTest(functionDeclaration: KotlinFunctionDeclaration, owningClass: KotlinClassDeclaration, importPackages: inout Set<String>) {
importPackages.insert("kotlinx.coroutines.*")
importPackages.insert("kotlinx.coroutines.test.*")
Expand Down
78 changes: 70 additions & 8 deletions Tests/SkipSyntaxTests/TransformerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ final class TransformerTests: XCTestCase {
""", kotlin: """
import skip.unit.*

@org.junit.runner.RunWith(androidx.test.ext.junit.runners.AndroidJUnit4::class)
internal open class TestCase: XCTestCase {
@Test
internal open fun testSomeTest() = Unit
Expand Down Expand Up @@ -53,7 +52,6 @@ final class TransformerTests: XCTestCase {

import skip.unit.*

@org.junit.runner.RunWith(androidx.test.ext.junit.runners.AndroidJUnit4::class)
internal open class TestCase: XCTestCase {

@OptIn(ExperimentalCoroutinesApi::class)
Expand Down Expand Up @@ -88,7 +86,6 @@ final class TransformerTests: XCTestCase {
""", kotlin: """
import skip.unit.*

@org.junit.runner.RunWith(androidx.test.ext.junit.runners.AndroidJUnit4::class)
internal class MyTests: XCTestCase {
@Test
internal fun addition(): Unit = expectEqual(1 + 1, 2)
Expand All @@ -109,7 +106,6 @@ final class TransformerTests: XCTestCase {
""", kotlin: """
import skip.unit.*

@org.junit.runner.RunWith(androidx.test.ext.junit.runners.AndroidJUnit4::class)
internal class MyTests: XCTestCase {
@Test
internal fun boolCheck() {
Expand All @@ -132,7 +128,6 @@ final class TransformerTests: XCTestCase {
""", kotlin: """
import skip.unit.*

@org.junit.runner.RunWith(androidx.test.ext.junit.runners.AndroidJUnit4::class)
internal class MyTests: XCTestCase {
@Test
internal fun inequality(): Unit = expectNotEqual(1, 2)
Expand All @@ -153,7 +148,6 @@ final class TransformerTests: XCTestCase {
""", kotlin: """
import skip.unit.*

@org.junit.runner.RunWith(androidx.test.ext.junit.runners.AndroidJUnit4::class)
internal class MyTests: XCTestCase {
@Test
internal fun unwrap() {
Expand Down Expand Up @@ -184,7 +178,6 @@ final class TransformerTests: XCTestCase {
""", kotlin: """
import skip.unit.*

@org.junit.runner.RunWith(androidx.test.ext.junit.runners.AndroidJUnit4::class)
internal class MathTests: XCTestCase {
@Test
internal fun addition(): Unit = expectEqual(2 + 2, 4)
Expand Down Expand Up @@ -212,7 +205,6 @@ final class TransformerTests: XCTestCase {
""", kotlin: """
import skip.unit.*

@org.junit.runner.RunWith(androidx.test.ext.junit.runners.AndroidJUnit4::class)
internal class CompTests: XCTestCase {
@Test
internal fun comparisons() {
Expand All @@ -225,6 +217,76 @@ final class TransformerTests: XCTestCase {
""")
}

func testSwiftTestingFreestandingFunction() async throws {
try await check(swift: """
import Testing

@Test func addition() {
#expect(1 + 2 == 3)
}
""", kotlin: """
import skip.unit.*

internal class AdditionTests: XCTestCase {
@Test
internal fun addition(): Unit = expectEqual(1 + 2, 3)
}
""")
}

func testSwiftTestingMultipleFreestandingFunctions() async throws {
try await check(swift: """
import Testing

@Test func addition() {
#expect(1 + 1 == 2)
}

@Test func subtraction() {
#expect(5 - 3 == 2)
}

func helperNotATest() -> Int {
return 42
}
""", kotlin: """
import skip.unit.*

internal class AdditionTests: XCTestCase {
@Test
internal fun addition(): Unit = expectEqual(1 + 1, 2)
}

internal class SubtractionTests: XCTestCase {
@Test
internal fun subtraction(): Unit = expectEqual(5 - 3, 2)
}

internal fun helperNotATest(): Int = 42
""")
}

func testSwiftTestingFreestandingExpectTrue() async throws {
try await check(swift: """
import Testing

@Test func boolCheck() {
let x = true
#expect(x)
}
""", kotlin: """
import skip.unit.*

internal class BoolCheckTests: XCTestCase {
@Test
internal fun boolCheck() {
val x = true
expectTrue(x)
}
}
""")
}

func testModuleBundleTransformer() async throws {
try await check(swift: """
import Foundation
Expand Down
Loading