// =============================================================================
// Scala
// =============================================================================

/**
 * End-to-end checks for the Scala extractor: each case feeds a small Scala
 * snippet through `extractFromSource` and inspects the resulting nodes or
 * unresolved references.
 */
describe('Scala Extraction', () => {
  describe('Language detection', () => {
    it('should detect Scala files', () => {
      expect(detectLanguage('Main.scala')).toBe('scala');
      expect(detectLanguage('script.sc')).toBe('scala');
      expect(detectLanguage('src/UserService.scala')).toBe('scala');
    });

    it('should report Scala as supported', () => {
      expect(isLanguageSupported('scala')).toBe(true);
      expect(getSupportedLanguages()).toContain('scala');
    });
  });

  describe('Class extraction', () => {
    it('should extract class definitions', () => {
      const code = `
class UserService(private val repo: UserRepository) {
  def findUser(id: String): Option[String] = Some(id)
}
`;
      const result = extractFromSource('UserService.scala', code);
      const cls = result.nodes.find((n) => n.kind === 'class' && n.name === 'UserService');
      expect(cls).toBeDefined();
      expect(cls?.language).toBe('scala');
    });

    it('should extract object definitions as class kind', () => {
      const code = `
object DatabaseConfig {
  val url = "jdbc:postgresql://localhost/mydb"
}
`;
      const result = extractFromSource('Config.scala', code);
      const obj = result.nodes.find((n) => n.kind === 'class' && n.name === 'DatabaseConfig');
      expect(obj).toBeDefined();
    });

    it('should extract trait definitions as trait kind', () => {
      const code = `
trait Repository[A] {
  def findById(id: String): Option[A]
  def save(entity: A): Unit
}
`;
      const result = extractFromSource('Repository.scala', code);
      const trait_ = result.nodes.find((n) => n.kind === 'trait' && n.name === 'Repository');
      expect(trait_).toBeDefined();
    });
  });

  describe('Method and function extraction', () => {
    it('should extract method definitions inside a class', () => {
      const code = `
class Calculator {
  def add(a: Int, b: Int): Int = a + b
  def divide(a: Double, b: Double): Double = a / b
}
`;
      const result = extractFromSource('Calculator.scala', code);
      const methods = result.nodes.filter((n) => n.kind === 'method');
      expect(methods.find((m) => m.name === 'add')).toBeDefined();
      expect(methods.find((m) => m.name === 'divide')).toBeDefined();
    });

    it('should extract method signatures', () => {
      const code = `
class Greeter {
  def greet(name: String): String = s"Hello, \${name}!"
}
`;
      const result = extractFromSource('Greeter.scala', code);
      const method = result.nodes.find((n) => n.name === 'greet');
      expect(method?.signature).toContain('name: String');
      // A bare 'String' check is already implied by the parameter assertion
      // above; match the return type after the closing parenthesis instead.
      expect(method?.signature).toMatch(/\)\s*:\s*String/);
    });

    it('should extract top-level function definitions as functions', () => {
      const code = `
def factorial(n: Int): Int = if (n <= 1) 1 else n * factorial(n - 1)
def greet(name: String): String = s"Hello, \${name}!"
`;
      const result = extractFromSource('utils.scala', code);
      const fns = result.nodes.filter((n) => n.kind === 'function');
      expect(fns.find((f) => f.name === 'factorial')).toBeDefined();
      expect(fns.find((f) => f.name === 'greet')).toBeDefined();
    });
  });

  describe('Val and var extraction', () => {
    it('should extract val inside a class as field', () => {
      const code = `
class Config {
  val timeout: Int = 30
  val host: String = "localhost"
}
`;
      const result = extractFromSource('Config.scala', code);
      const fields = result.nodes.filter((n) => n.kind === 'field');
      expect(fields.find((f) => f.name === 'timeout')).toBeDefined();
      expect(fields.find((f) => f.name === 'host')).toBeDefined();
    });

    it('should extract var inside a class as field', () => {
      const code = `
class Counter {
  var count: Int = 0
}
`;
      const result = extractFromSource('Counter.scala', code);
      const field = result.nodes.find((n) => n.kind === 'field' && n.name === 'count');
      expect(field).toBeDefined();
    });

    it('should extract top-level val as constant', () => {
      const code = `
val MaxConnections: Int = 100
val DefaultTimeout = 30
`;
      const result = extractFromSource('constants.scala', code);
      const consts = result.nodes.filter((n) => n.kind === 'constant');
      expect(consts.find((c) => c.name === 'MaxConnections')).toBeDefined();
      // Also cover the val declared without a type annotation.
      expect(consts.find((c) => c.name === 'DefaultTimeout')).toBeDefined();
    });

    it('should extract top-level var as variable', () => {
      const code = `
var retries: Int = 3
`;
      const result = extractFromSource('state.scala', code);
      const v = result.nodes.find((n) => n.kind === 'variable' && n.name === 'retries');
      expect(v).toBeDefined();
    });

    it('should include type in val/var signature', () => {
      const code = `
class Service {
  val timeout: Int = 30
}
`;
      const result = extractFromSource('Service.scala', code);
      const field = result.nodes.find((n) => n.name === 'timeout');
      expect(field?.signature).toContain('timeout');
      expect(field?.signature).toContain('Int');
    });
  });

  describe('Enum extraction', () => {
    it('should extract enum definitions', () => {
      const code = `
enum Color:
  case Red
  case Green
  case Blue
`;
      const result = extractFromSource('Color.scala', code);
      const enumNode = result.nodes.find((n) => n.kind === 'enum' && n.name === 'Color');
      expect(enumNode).toBeDefined();
    });

    it('should extract enum cases as enum_member', () => {
      const code = `
enum Direction:
  case North
  case South
  case East
  case West
`;
      const result = extractFromSource('Direction.scala', code);
      const members = result.nodes.filter((n) => n.kind === 'enum_member');
      // Name every declared case so a dropped member cannot hide behind the count.
      for (const expected of ['North', 'South', 'East', 'West']) {
        expect(members.find((m) => m.name === expected)).toBeDefined();
      }
      expect(members.length).toBeGreaterThanOrEqual(4);
    });
  });

  describe('Type alias extraction', () => {
    it('should extract type aliases', () => {
      const code = `
type UserId = String
type UserMap = Map[String, String]
`;
      const result = extractFromSource('types.scala', code);
      const aliases = result.nodes.filter((n) => n.kind === 'type_alias');
      expect(aliases.find((a) => a.name === 'UserId')).toBeDefined();
      expect(aliases.find((a) => a.name === 'UserMap')).toBeDefined();
    });
  });

  describe('Import extraction', () => {
    it('should extract import declarations', () => {
      const code = `
import scala.collection.mutable.ListBuffer
import scala.concurrent.Future
`;
      const result = extractFromSource('imports.scala', code);
      const imports = result.nodes.filter((n) => n.kind === 'import');
      expect(imports.length).toBeGreaterThanOrEqual(2);
    });
  });

  describe('Visibility modifiers', () => {
    it('should extract private visibility', () => {
      const code = `
class Service {
  private val secret: String = "abc"
  private def helper(): Unit = {}
}
`;
      const result = extractFromSource('Service.scala', code);
      const secretField = result.nodes.find((n) => n.name === 'secret');
      expect(secretField?.visibility).toBe('private');
      const helperMethod = result.nodes.find((n) => n.name === 'helper');
      expect(helperMethod?.visibility).toBe('private');
    });

    it('should extract protected visibility', () => {
      const code = `
class Base {
  protected def helperMethod(): Unit = {}
}
`;
      const result = extractFromSource('Base.scala', code);
      const method = result.nodes.find((n) => n.name === 'helperMethod');
      expect(method?.visibility).toBe('protected');
    });

    it('should default to public visibility', () => {
      const code = `
class Greeter {
  def hello(): Unit = {}
}
`;
      const result = extractFromSource('Greeter.scala', code);
      const method = result.nodes.find((n) => n.name === 'hello');
      expect(method?.visibility).toBe('public');
    });
  });

  describe('Inheritance', () => {
    it('should extract extends relationships', () => {
      const code = `
class AdminUser extends User {
  def adminAction(): Unit = {}
}
`;
      const result = extractFromSource('AdminUser.scala', code);
      const extendsRefs = result.unresolvedReferences.filter((r) => r.referenceKind === 'extends');
      expect(extendsRefs.find((r) => r.referenceName === 'User')).toBeDefined();
    });
  });

  describe('Call extraction', () => {
    it('should extract function call expressions', () => {
      const code = `
def processData(): Unit = {
  val result = computeResult()
  println(result)
}
`;
      const result = extractFromSource('processor.scala', code);
      const calls = result.unresolvedReferences.filter((r) => r.referenceKind === 'calls');
      expect(calls.length).toBeGreaterThan(0);
    });
  });
});
import type { Node as SyntaxNode } from 'web-tree-sitter';
import { getNodeText } from '../tree-sitter-helpers';
import type { LanguageExtractor } from '../tree-sitter-types';

/** Kinds of enclosing node under which a val/var is reported as a `field`. */
const MEMBER_OWNER_KINDS: ReadonlySet<string> = new Set([
  'class',
  'trait',
  'interface',
  'struct',
  'enum',
  'module',
]);

/**
 * Returns every named child of `node` carrying the field `fieldName`.
 *
 * `childForFieldName` yields only the first match; this helper is for fields
 * that may be attached to several children. Anonymous tokens and null entries
 * are dropped.
 */
function namedChildrenForField(node: SyntaxNode, fieldName: string): SyntaxNode[] {
  return node
    .childrenForFieldName(fieldName)
    .filter((child): child is SyntaxNode => child != null && child.isNamed);
}

/**
 * Resolves the declared name of a `val_definition` / `var_definition`.
 *
 * The name is read from the `pattern` field: either the pattern itself when it
 * is an identifier, or the first identifier among its named children.
 *
 * @returns the name, or `null` when there is no pattern or no identifier in it.
 */
function getValVarName(node: SyntaxNode, source: string): string | null {
  const pattern = node.childForFieldName('pattern');
  if (!pattern) return null;
  if (pattern.type === 'identifier') return getNodeText(pattern, source);

  const firstIdentifier = pattern.namedChildren.find((c: SyntaxNode) => c.type === 'identifier');
  return firstIdentifier ? getNodeText(firstIdentifier, source) : null;
}

/**
 * Derives visibility from the `modifiers` / `access_modifier` named children
 * of `node`. The first such child whose text mentions `private` or `protected`
 * decides the result; otherwise the node is treated as public.
 */
function extractVisibility(node: SyntaxNode): 'public' | 'private' | 'protected' {
  for (let i = 0; i < node.namedChildCount; i++) {
    const child = node.namedChild(i);
    if (!child) continue;
    if (child.type !== 'modifiers' && child.type !== 'access_modifier') continue;

    const text = child.text;
    if (text.includes('private')) return 'private';
    if (text.includes('protected')) return 'protected';
  }
  return 'public';
}

export const scalaExtractor: LanguageExtractor = {
  // top-level function_definition is handled via methodTypes (same pattern as Kotlin)
  functionTypes: [],
  classTypes: ['class_definition', 'object_definition', 'trait_definition'],
  methodTypes: ['function_definition', 'function_declaration'],
  interfaceTypes: [],
  structTypes: [],
  enumTypes: ['enum_definition'],
  enumMemberTypes: [], // handled in visitNode — enum_case_definitions wraps the cases
  typeAliasTypes: ['type_definition'],
  importTypes: ['import_declaration'],
  callTypes: ['call_expression'],
  variableTypes: [], // val/var handled in visitNode (use `pattern` field, not `name`)
  fieldTypes: [],
  extraClassNodeTypes: [],

  nameField: 'name',
  bodyField: 'body',
  paramsField: 'parameters',
  returnField: 'return_type',
  interfaceKind: 'trait',

  /** Traits are reported as `trait`; classes and objects as `class`. */
  classifyClassNode: (node: SyntaxNode) => {
    if (node.type === 'trait_definition') return 'trait';
    return 'class';
  },

  /**
   * Builds `<parameter lists>: <return type>` for a definition.
   *
   * All `parameters` children are concatenated so that curried definitions
   * such as `def f(a: Int)(b: Int): Int` keep every parameter list instead of
   * only the first one.
   *
   * @returns the signature, or `undefined` when the node has neither
   *          parameters nor a return type.
   */
  getSignature: (node: SyntaxNode, source: string) => {
    const paramLists = namedChildrenForField(node, 'parameters');
    const returnType = node.childForFieldName('return_type');
    if (paramLists.length === 0 && !returnType) return undefined;

    let sig = paramLists.map((list) => getNodeText(list, source)).join('');
    if (returnType) sig += ': ' + getNodeText(returnType, source);
    return sig || undefined;
  },

  getVisibility: (node: SyntaxNode) => extractVisibility(node),

  isAsync: () => false,

  /** True only when a `modifiers` child's text contains `static`. */
  isStatic: (node: SyntaxNode) => {
    for (let i = 0; i < node.namedChildCount; i++) {
      const child = node.namedChild(i);
      if (child?.type === 'modifiers' && child.text.includes('static')) return true;
    }
    return false;
  },

  /**
   * Custom handling for node types the generic walker cannot classify.
   *
   * @returns `true` when the node was fully handled here, `false` to let the
   *          generic extraction continue.
   */
  visitNode: (node: SyntaxNode, ctx) => {
    const t = node.type;

    // val/var: name is in `pattern` field (identifier), not `name`
    if (t === 'val_definition' || t === 'var_definition') {
      const name = getValVarName(node, ctx.source);
      if (!name) return false;

      // The innermost entry of the node stack decides field vs. top-level.
      const ownerId = ctx.nodeStack.length > 0 ? ctx.nodeStack[ctx.nodeStack.length - 1] : undefined;
      const owner = ownerId === undefined ? undefined : ctx.nodes.find((n) => n.id === ownerId);
      const isInClass = owner != null && MEMBER_OWNER_KINDS.has(owner.kind);

      const kind = isInClass ? 'field' : (t === 'val_definition' ? 'constant' : 'variable');
      const keyword = t === 'val_definition' ? 'val' : 'var';
      const typeNode = node.childForFieldName('type');
      // A signature is only produced when the declaration carries an explicit type.
      const sig = typeNode ? `${keyword} ${name}: ${getNodeText(typeNode, ctx.source)}` : undefined;

      ctx.createNode(kind, name, node, { signature: sig, visibility: extractVisibility(node) });
      return true;
    }

    // enum_case_definitions wraps simple_enum_case / full_enum_case children
    if (t === 'enum_case_definitions') {
      for (let i = 0; i < node.namedChildCount; i++) {
        const child = node.namedChild(i);
        if (!child) continue;
        if (child.type === 'simple_enum_case' || child.type === 'full_enum_case') {
          const nameNode = child.childForFieldName('name');
          if (nameNode) ctx.createNode('enum_member', getNodeText(nameNode, ctx.source), child);
        }
      }
      return true;
    }

    // extension_definition: visit body children directly, no container node
    if (t === 'extension_definition') {
      const body = node.childForFieldName('body');
      if (body) {
        for (let i = 0; i < body.namedChildCount; i++) {
          const child = body.namedChild(i);
          if (child) ctx.visitNode(child);
        }
      }
      return true;
    }

    return false;
  },

  /**
   * Extracts the imported module path and the full import text.
   *
   * The `path` field may be attached to one node per dotted segment, so all
   * segments are joined with `.` rather than taking only the first. When the
   * declaration lists several comma-separated import expressions, only the
   * segments of the first expression are used. Trees without a `path` field
   * fall back to the first identifier / stable_identifier child.
   *
   * @returns `{ moduleName, signature }`, or `null` when no path is found.
   */
  extractImport: (node: SyntaxNode, source: string) => {
    const importText = getNodeText(node, source).trim();

    const pathNodes = namedChildrenForField(node, 'path');
    if (pathNodes.length > 0) {
      // A direct ',' child marks the start of a second import expression.
      const firstComma = node.children.find((c) => c?.type === ',');
      const segments = pathNodes
        .filter((segment) => !firstComma || segment.startIndex < firstComma.startIndex)
        .map((segment) => getNodeText(segment, source));
      if (segments.length > 0) return { moduleName: segments.join('.'), signature: importText };
    }

    for (let i = 0; i < node.namedChildCount; i++) {
      const child = node.namedChild(i);
      if (child?.type === 'identifier' || child?.type === 'stable_identifier') {
        return { moduleName: getNodeText(child, source), signature: importText };
      }
    }
    return null;
  },
};
============================================================================= @@ -527,6 +528,9 @@ export const DEFAULT_CONFIG: CodeGraphConfig = { '**/*.lpr', '**/*.dfm', '**/*.fmx', + // Scala + '**/*.scala', + '**/*.sc', ], exclude: [ // Version control