diff --git a/Sources/MockoloFramework/Models/ConditionalImportBlock.swift b/Sources/MockoloFramework/Models/ConditionalImportBlock.swift index f5e87a91..12c987a3 100644 --- a/Sources/MockoloFramework/Models/ConditionalImportBlock.swift +++ b/Sources/MockoloFramework/Models/ConditionalImportBlock.swift @@ -17,20 +17,34 @@ /// Represents import content: either a simple import statement or a nested conditional block indirect enum ImportContent { case simple(Import) - case conditional(ConditionalImportBlock) + case conditional(ConditionalBlock) } -/// Represents a conditional import block (#if/#elseif/#else/#endif) -struct ConditionalImportBlock { - /// Represents a single clause in a conditional import block +/// Represents a conditional compilation block (#if/#elseif/#else/#endif) that owns +/// both imports and entities found within its clauses. +struct ConditionalBlock { + /// Represents a single clause in a conditional compilation block struct Clause { var type: IfClauseType - var contents: [ImportContent] + var imports: [ImportContent] + var entities: [Entity] } let clauses: [Clause] let offset: Int64 + /// Whether any clause (including nested blocks) contains entities + var containsEntities: Bool { + clauses.contains { clause in + !clause.entities.isEmpty || clause.imports.contains { content in + if case .conditional(let nested) = content { + return nested.containsEntities + } + return false + } + } + } + init(clauses: [Clause], offset: Int64) { self.clauses = clauses self.offset = offset diff --git a/Sources/MockoloFramework/Operations/Generator.swift b/Sources/MockoloFramework/Operations/Generator.swift index 6a8becb0..3d111b63 100644 --- a/Sources/MockoloFramework/Operations/Generator.swift +++ b/Sources/MockoloFramework/Operations/Generator.swift @@ -148,8 +148,30 @@ public func generate(sourceDirs: [String], signpost_begin(name: "Render models") log("Render models with templates...", level: .info) + + // Extract conditional blocks that contain entities from relevant source files + var conditionalEntityBlocks = [ConditionalBlock]() + func collectEntityBlocks(from contents: [ImportContent]) { + for content in contents { + if case .conditional(let block) = content { + if block.containsEntities { + conditionalEntityBlocks.append(block) + } + for clause in block.clauses { + collectEntityBlocks(from: clause.imports) + } + } + } + } + for (path, parsedImports) in pathToImportsMap { + guard relevantPaths.contains(path) else { continue } + collectEntityBlocks(from: parsedImports) + } + conditionalEntityBlocks.sort(by: { $0.offset < $1.offset }) + renderTemplates( entities: resolvedEntities, + conditionalBlocks: conditionalEntityBlocks, arguments: .init( useTemplateFunc: useTemplateFunc, allowSetCallCount: allowSetCallCount, diff --git a/Sources/MockoloFramework/Operations/ImportsHandler.swift b/Sources/MockoloFramework/Operations/ImportsHandler.swift index 01bbdf58..db6d4151 100644 --- a/Sources/MockoloFramework/Operations/ImportsHandler.swift +++ b/Sources/MockoloFramework/Operations/ImportsHandler.swift @@ -22,7 +22,7 @@ func handleImports(pathToImportsMap: ImportMap, testableImports: [String]?, relevantPaths: [String]) -> String { var topLevelImports: [Import] = [] - var conditionalBlocks: [ConditionalImportBlock] = [] + var conditionalBlocks: [ConditionalBlock] = [] // 1. Collect imports from all relevant files for (path, parsedImports) in pathToImportsMap { @@ -98,7 +98,12 @@ private func renderImportContents( resolveAccumulatedSimpleImports() var result = "" + var hasImportOutput = false for clause in block.clauses { + let rendered = renderImportContents(clause.imports, excludeImports: excludeImports, testableImports: testableImports) + if !rendered.isEmpty { + hasImportOutput = true + } switch clause.type { case .if(let condition): result += "#if \(condition)\n" @@ -107,12 +112,13 @@ private func renderImportContents( case .else: result += "#else\n" } - // Recursively render nested block - result += renderImportContents(clause.contents, excludeImports: excludeImports, testableImports: testableImports) + result += rendered result += "\n" } result += "#endif" - clauseLines.append(result) + if hasImportOutput { + clauseLines.append(result) + } } } resolveAccumulatedSimpleImports() @@ -126,7 +132,7 @@ private func visitModuleName(_ contents: [ImportContent]) -> [String] { case .simple(let `import`): return [`import`.moduleName] case .conditional(let block): - return visitModuleName(block.clauses.flatMap(\.contents)) + return visitModuleName(block.clauses.flatMap(\.imports)) } } } diff --git a/Sources/MockoloFramework/Operations/TemplateRenderer.swift b/Sources/MockoloFramework/Operations/TemplateRenderer.swift index 12379c14..adaa58cc 100644 --- a/Sources/MockoloFramework/Operations/TemplateRenderer.swift +++ b/Sources/MockoloFramework/Operations/TemplateRenderer.swift @@ -14,20 +14,106 @@ // limitations under the License. // +import Foundation + /// Renders models with templates for output func renderTemplates(entities: [ResolvedEntity], + conditionalBlocks: [ConditionalBlock], arguments: GenerationArguments, completion: @escaping (String, Int64) -> ()) { - scan(entities) { (resolvedEntity, lock) in + // Build lookup from entity name to resolved entity + let resolvedByName = Dictionary( + entities.map { ($0.key, $0) }, + uniquingKeysWith: { $1 } + ) + + // Collect names of entities that live inside conditional blocks + var conditionalEntityNames = Set() + func collectEntityNames(from blocks: [ConditionalBlock]) { + for block in blocks { + for clause in block.clauses { + for entity in clause.entities { + conditionalEntityNames.insert(entity.entityNode.nameText) + } + for content in clause.imports { + if case .conditional(let nested) = content { + collectEntityNames(from: [nested]) + } + } + } + } + } + collectEntityNames(from: conditionalBlocks) + + // Render conditional blocks, preserving #if/#elseif/#else/#endif structure + func renderBlock(_ block: ConditionalBlock) -> String? { + var lines = [String]() + var blockHasOutput = false + + for clause in block.clauses { + var clauseLines = [String]() + + // Render entities in this clause + for entity in clause.entities { + if let resolved = resolvedByName[entity.entityNode.nameText] { + let mockModel = resolved.model() + if let mockString = mockModel.render( + context: .init(), + arguments: arguments + ), !mockString.isEmpty { + clauseLines.append(mockString) + } + } + } + + // Recurse into nested conditional blocks + for content in clause.imports { + if case .conditional(let nested) = content { + if let nestedOutput = renderBlock(nested) { + clauseLines.append(nestedOutput) + } + } + } + + guard !clauseLines.isEmpty else { continue } + blockHasOutput = true + + switch clause.type { + case .if(let condition): + lines.append("#if \(condition)") + case .elseif(let condition): + lines.append("#elseif \(condition)") + case .else: + lines.append("#else") + } + lines.append(contentsOf: clauseLines) + } + + guard blockHasOutput else { return nil } + lines.append("#endif") + return lines.joined(separator: "\n") + } + + for block in conditionalBlocks { + if let rendered = renderBlock(block) { + completion(rendered, block.offset) + } + } + + // Render standalone entities (not inside any conditional block) + let standalone = entities.filter { !conditionalEntityNames.contains($0.key) } + + let lock = NSLock() + scan(standalone) { (resolvedEntity, _) in let mockModel = resolvedEntity.model() if let mockString = mockModel.render( context: .init(), arguments: arguments ), !mockString.isEmpty { - lock?.lock() + lock.lock() completion(mockString, mockModel.offset) - lock?.unlock() + lock.unlock() } } } diff --git a/Sources/MockoloFramework/Parsers/SwiftSyntaxExtensions.swift b/Sources/MockoloFramework/Parsers/SwiftSyntaxExtensions.swift index e4cf2798..532afa2d 100644 --- a/Sources/MockoloFramework/Parsers/SwiftSyntaxExtensions.swift +++ b/Sources/MockoloFramework/Parsers/SwiftSyntaxExtensions.swift @@ -735,8 +735,7 @@ final class EntityVisitor: SyntaxVisitor { } override func visit(_ node: ProtocolDeclSyntax) -> SyntaxVisitorContinueKind { - let metadata = node.annotationMetadata(with: annotation) - if let ent = Entity.node(with: node, filepath: path, isPrivate: node.isPrivate, isFinal: false, metadata: metadata, processed: false) { + if let ent = makeProtocolEntity(node) { entities.append(ent) } return .skipChildren @@ -751,18 +750,8 @@ final class EntityVisitor: SyntaxVisitor { } override func visit(_ node: ClassDeclSyntax) -> SyntaxVisitorContinueKind { - if scanAsMockfile || node.nameText.hasSuffix("Mock") { - // this mock class node must be public else wouldn't have compiled before - if let ent = Entity.node(with: node, filepath: path, isPrivate: node.isPrivate, isFinal: false, metadata: nil, processed: true) { - entities.append(ent) - } - } else { - if declType == .classType || declType == .all { - let metadata = node.annotationMetadata(with: annotation) - if let ent = Entity.node(with: node, filepath: path, isPrivate: node.isPrivate, isFinal: node.isFinal, metadata: metadata, processed: false) { - entities.append(ent) - } - } + if let ent = makeClassEntity(node) { + entities.append(ent) } return node.genericParameterClause != nil ? .skipChildren : .visitChildren } @@ -772,7 +761,6 @@ final class EntityVisitor: SyntaxVisitor { } override func visit(_ node: ImportDeclSyntax) -> SyntaxVisitorContinueKind { - // Top-level import (not inside #if) if let `import` = Import(line: node.trimmedDescription) { imports.append(.simple(`import`)) } @@ -782,48 +770,75 @@ final class EntityVisitor: SyntaxVisitor { override func visit(_ node: IfConfigDeclSyntax) -> SyntaxVisitorContinueKind { // Check if this is a file macro that should be ignored if let firstCondition = node.clauses.first?.condition?.trimmedDescription, - firstCondition == fileMacro { + !fileMacro.isEmpty, firstCondition == fileMacro { return .visitChildren } - // Parse conditional import block recursively - let block = parseIfConfigDecl(node) - imports.append(.conditional(block)) + let clauses = processTopLevelIfConfig(node) + let hasContent = clauses.contains { !$0.imports.isEmpty || !$0.entities.isEmpty } + if hasContent { + imports.append(.conditional(ConditionalBlock(clauses: clauses, offset: node.offset))) + } return .skipChildren } - /// Recursively parses an IfConfigDeclSyntax into a ConditionalImportBlock - private func parseIfConfigDecl(_ node: IfConfigDeclSyntax) -> ConditionalImportBlock { - var clauseList = [ConditionalImportBlock.Clause]() + /// Processes a top-level #if block, collecting imports and entities into clauses. + /// Entities are also added to `self.entities` so they appear in the protocol map. + private func processTopLevelIfConfig(_ node: IfConfigDeclSyntax) -> [ConditionalBlock.Clause] { + var result = [ConditionalBlock.Clause]() for cl in node.clauses { - guard let clauseType = IfClauseType(cl) else { - continue - } + guard let clauseType = IfClauseType(cl) else { continue } + + var clauseImports = [ImportContent]() + var clauseEntities = [Entity]() - var contents = [ImportContent]() if let list = cl.elements?.as(CodeBlockItemListSyntax.self) { for el in list { if let importItem = el.item.as(ImportDeclSyntax.self) { - // Simple import if let imp = Import(line: importItem.trimmedDescription) { - contents.append(.simple(imp)) + clauseImports.append(.simple(imp)) + } + } else if let protocolDecl = el.item.as(ProtocolDeclSyntax.self) { + if let ent = makeProtocolEntity(protocolDecl) { + clauseEntities.append(ent) + } + } else if let classDecl = el.item.as(ClassDeclSyntax.self) { + if let ent = makeClassEntity(classDecl) { + clauseEntities.append(ent) + } + } else if let nestedIfConfig = el.item.as(IfConfigDeclSyntax.self) { + let nestedClauses = processTopLevelIfConfig(nestedIfConfig) + let hasContent = nestedClauses.contains { !$0.imports.isEmpty || !$0.entities.isEmpty } + if hasContent { + clauseImports.append(.conditional(ConditionalBlock(clauses: nestedClauses, offset: nestedIfConfig.offset))) } - } else if let nested = el.item.as(IfConfigDeclSyntax.self) { - // Nested #if block (recursive) - let nestedBlock = parseIfConfigDecl(nested) - contents.append(.conditional(nestedBlock)) } } } - clauseList.append(ConditionalImportBlock.Clause( - type: clauseType, - contents: contents - )) + // Also register clause entities in the flat list for the protocol map + entities.append(contentsOf: clauseEntities) + + result.append(ConditionalBlock.Clause(type: clauseType, imports: clauseImports, entities: clauseEntities)) } - return ConditionalImportBlock(clauses: clauseList, offset: node.offset) + return result + } + + private func makeProtocolEntity(_ node: ProtocolDeclSyntax) -> Entity? { + let metadata = node.annotationMetadata(with: annotation) + return Entity.node(with: node, filepath: path, isPrivate: node.isPrivate, isFinal: false, metadata: metadata, processed: false) + } + + private func makeClassEntity(_ node: ClassDeclSyntax) -> Entity? { + if scanAsMockfile || node.nameText.hasSuffix("Mock") { + return Entity.node(with: node, filepath: path, isPrivate: node.isPrivate, isFinal: false, metadata: nil, processed: true) + } else if declType == .classType || declType == .all { + let metadata = node.annotationMetadata(with: annotation) + return Entity.node(with: node, filepath: path, isPrivate: node.isPrivate, isFinal: node.isFinal, metadata: metadata, processed: false) + } + return nil } override func visit(_ node: InitializerDeclSyntax) -> SyntaxVisitorContinueKind { diff --git a/Tests/TestConditionalImportBlocks/ConditionalImportBlocksTests.swift b/Tests/TestConditionalImportBlocks/ConditionalImportBlocksTests.swift new file mode 100644 index 00000000..6f059326 --- /dev/null +++ b/Tests/TestConditionalImportBlocks/ConditionalImportBlocksTests.swift @@ -0,0 +1,25 @@ +import XCTest +@testable import MockoloFramework + +final class ConditionalImportBlocksTests: MockoloTestCase { + func testProtocolInsideIfBlockWithNonImportDeclaration() { + verify(srcContent: FixtureConditionalImportBlocks.protocolInIfBlock, + dstContent: FixtureConditionalImportBlocks.protocolInIfBlockMock) + } + func testConditionalImportBlockPreserved() { + verify(srcContent: FixtureConditionalImportBlocks.conditionalImportBlock, + dstContent: FixtureConditionalImportBlocks.conditionalImportBlockMock) + } + func testNestedIfBlocksWithMultipleProtocols() { + verify(srcContent: FixtureConditionalImportBlocks.nestedIfBlocks, + dstContent: FixtureConditionalImportBlocks.nestedIfBlocksMock) + } + func testIfBlockWithImportsAndProtocol() { + verify(srcContent: FixtureConditionalImportBlocks.ifBlockWithImportsAndProtocol, + dstContent: FixtureConditionalImportBlocks.ifBlockWithImportsAndProtocolMock) + } + func testMixedNestedBlocks() { + verify(srcContent: FixtureConditionalImportBlocks.mixedNestedBlocks, + dstContent: FixtureConditionalImportBlocks.mixedNestedBlocksMock) + } +} diff --git a/Tests/TestConditionalImportBlocks/FixtureConditionalImportBlocks.swift b/Tests/TestConditionalImportBlocks/FixtureConditionalImportBlocks.swift new file mode 100644 index 00000000..604babf6 --- /dev/null +++ b/Tests/TestConditionalImportBlocks/FixtureConditionalImportBlocks.swift @@ -0,0 +1,197 @@ +enum FixtureConditionalImportBlocks { + + /// Protocol inside a #if block that contains non-import declarations + static let protocolInIfBlock = + """ + #if os(iOS) + /// @mockable + public protocol PlatformProtocol { + func platformFunction() + } + #endif + """ + + /// Expected mock for protocol inside #if block — mock is wrapped in the same #if + static let protocolInIfBlockMock = + """ + #if os(iOS) + public class PlatformProtocolMock: PlatformProtocol { + public init() { } + + + public private(set) var platformFunctionCallCount = 0 + public var platformFunctionHandler: (() -> ())? + public func platformFunction() { + platformFunctionCallCount += 1 + if let platformFunctionHandler = platformFunctionHandler { + platformFunctionHandler() + } + } + } + #endif + """ + + /// Protocol inside a #if block containing only imports (should be treated as conditional import) + static let conditionalImportBlock = + """ + #if canImport(Foundation) + import Foundation + #endif + + /// @mockable + public protocol ServiceProtocol { + func execute() + } + """ + + /// Expected output with conditional import preserved and protocol mocked + static let conditionalImportBlockMock = + """ + #if canImport(Foundation) + import Foundation + #endif + + + public class ServiceProtocolMock: ServiceProtocol { + public init() { } + + + public private(set) var executeCallCount = 0 + public var executeHandler: (() -> ())? + public func execute() { + executeCallCount += 1 + if let executeHandler = executeHandler { + executeHandler() + } + } + } + """ + + /// Multiple protocols in nested #if blocks with mixed content + static let nestedIfBlocks = + """ + #if os(iOS) + /// @mockable + public protocol iOSProtocol { + func iosMethod() + } + #elseif os(macOS) + /// @mockable + public protocol macOSProtocol { + func macosMethod() + } + #endif + """ + + /// Expected mocks for both protocols, preserving #if/#elseif structure + static let nestedIfBlocksMock = + """ + #if os(iOS) + public class iOSProtocolMock: iOSProtocol { + public init() { } + + + public private(set) var iosMethodCallCount = 0 + public var iosMethodHandler: (() -> ())? + public func iosMethod() { + iosMethodCallCount += 1 + if let iosMethodHandler = iosMethodHandler { + iosMethodHandler() + } + } + } + #elseif os(macOS) + public class macOSProtocolMock: macOSProtocol { + public init() { } + + + public private(set) var macosMethodCallCount = 0 + public var macosMethodHandler: (() -> ())? + public func macosMethod() { + macosMethodCallCount += 1 + if let macosMethodHandler = macosMethodHandler { + macosMethodHandler() + } + } + } + #endif + """ + + /// #if block with imports and a protocol (should visit children and discover protocol) + static let ifBlockWithImportsAndProtocol = + """ + #if DEBUG + import XCTest + /// @mockable + public protocol DebugProtocol { + func debugFunction() + } + #endif + """ + + /// Import is captured as conditional import, mock is wrapped in #if + static let ifBlockWithImportsAndProtocolMock = + """ + #if DEBUG + import XCTest + #endif + + + #if DEBUG + public class DebugProtocolMock: DebugProtocol { + public init() { } + + + public private(set) var debugFunctionCallCount = 0 + public var debugFunctionHandler: (() -> ())? + public func debugFunction() { + debugFunctionCallCount += 1 + if let debugFunctionHandler = debugFunctionHandler { + debugFunctionHandler() + } + } + } + #endif + """ + + /// Nested #if blocks where inner only contains imports + static let mixedNestedBlocks = + """ + #if os(iOS) + #if DEBUG + import XCTest + #endif + /// @mockable + public protocol MixedProtocol { + func mixedMethod() + } + #endif + """ + + /// Nested import block preserved, mock wrapped in outer #if + static let mixedNestedBlocksMock = + """ + #if os(iOS) + #if DEBUG + import XCTest + #endif + #endif + + + #if os(iOS) + public class MixedProtocolMock: MixedProtocol { + public init() { } + + + public private(set) var mixedMethodCallCount = 0 + public var mixedMethodHandler: (() -> ())? + public func mixedMethod() { + mixedMethodCallCount += 1 + if let mixedMethodHandler = mixedMethodHandler { + mixedMethodHandler() + } + } + } + #endif + """ +}