diff --git a/CLAUDE.md b/CLAUDE.md index 67cbee8..0509e41 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,9 +1,15 @@ See @README for a general project overview. -# Additional Instructions +## Additional Instructions - Write idiomatic rust code that is simple as possible. Avoid excessive abstraction. -- All code should have tests (unit and/or integration tests). -- Look for opportunities to refactor, simplify and improve your code. -- Ensure that documentation is up to date with the code. +- Look for opportunities to refactor, simplify and improve the code. +- Ensure that code is documented, and that the documentation is up to date with the code. - When you're done making changes test, lint and format the code to check everything. + +## Testing Guidelines + +- All code should have tests (unit and/or integration tests). +- Use the assert_fs and assert_cmd crates for integration testing. +- Where parametric tests are appropriate, use the yare crate. + Always create a struct (e.g. `Case`) to hold test case data. diff --git a/src/graph.rs b/src/graph.rs index fbe8c9e..6a9f6f1 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -1,40 +1,57 @@ //! Inheritance graph construction and traversal. +//! +//! This module provides functionality to build and query a class inheritance graph. +//! The graph maps parent classes to their direct and transitive children, enabling +//! efficient subclass discovery. -use std::collections::{HashMap, HashSet, VecDeque}; +use std::collections::{HashMap, HashSet}; -use crate::registry::{ClassId, ClassRegistry}; +use crate::registry::{ClassId, Registry}; -/// An inheritance graph mapping classes to their children. +/// An inheritance graph representing parent-child relationships between classes. +/// +/// This structure maps each class to the set of classes that directly inherit from it. +/// The graph can be used to efficiently find all descendants of a given class using +/// breadth-first search. pub struct InheritanceGraph { - /// Map from parent ClassId to child ClassIds - children: HashMap>, + /// Maps parent classes to their direct children. + pub children: HashMap>, } impl InheritanceGraph { - /// Builds an inheritance graph from a class registry. + /// Builds an inheritance graph from a registry. + /// + /// This resolves all base class references and constructs parent-to-child + /// relationships. The registry's `resolve_class` method is used to handle + /// imports and re-exports correctly. /// /// # Arguments /// - /// * `registry` - The class registry containing all classes + /// * `registry` - The registry containing all classes and their base class references /// /// # Returns /// - /// An inheritance graph with parent-child relationships. - pub fn build(registry: &ClassRegistry) -> Self { - let mut children: HashMap> = HashMap::new(); - - // For each class, resolve its base classes and add edges - for class_id in registry.all_class_ids() { - if let Some(info) = registry.get(&class_id) { - for base in &info.bases { - // Try to resolve the base class - if let Some(parent_id) = registry.resolve_base(base, &class_id.module_path) { - // Add edge from parent to child - children - .entry(parent_id) - .or_default() - .push(class_id.clone()); - } + /// An inheritance graph ready for traversal. + /// + /// # Algorithm + /// + /// For each class in the registry: + /// 1. Resolve each of its base class names to a `ClassId` + /// 2. Add this class to the parent's children set + /// 3. Skip any base classes that cannot be resolved (e.g., external dependencies) + pub fn build(registry: &Registry) -> Self { + let mut children: HashMap> = HashMap::new(); + + // Build parent → children edges by examining each class's bases + for (child_id, metadata) in ®istry.classes { + for base_name in &metadata.bases { + // Resolve the base class reference in this class's module context + if let Some(parent_id) = registry.resolve_class(&child_id.module, base_name) { + // Add this class as a child of its parent + children + .entry(parent_id) + .or_default() + .insert(child_id.clone()); } } } @@ -44,40 +61,53 @@ impl InheritanceGraph { /// Finds all transitive subclasses of a given class. /// - /// Uses BFS to traverse the inheritance graph and collect all descendants. + /// Performs a breadth-first search to discover all classes that directly or + /// indirectly inherit from the specified root class. This includes: + /// - Direct children (one level of inheritance) + /// - Grandchildren (two levels) + /// - Great-grandchildren, etc. (any depth) /// /// # Arguments /// - /// * `root` - The root class to find subclasses of + /// * `root` - The class to find subclasses for /// /// # Returns /// - /// A vector of all transitive subclasses (not including the root class itself). + /// A vector containing all transitive subclasses. The root class itself is not + /// included in the result. The order of classes in the vector is determined by + /// the BFS traversal order. + /// + /// # Examples + /// + /// ```text + /// Given: + /// class Animal: pass + /// class Mammal(Animal): pass + /// class Dog(Mammal): pass + /// class Cat(Mammal): pass + /// + /// find_all_subclasses(Animal) → [Mammal, Dog, Cat] + /// ``` pub fn find_all_subclasses(&self, root: &ClassId) -> Vec { + use std::collections::VecDeque; + let mut result = Vec::new(); let mut visited = HashSet::new(); let mut queue = VecDeque::new(); - // Start with the root's immediate children - if let Some(children) = self.children.get(root) { - for child in children { - queue.push_back(child.clone()); - } - } + // Initialize BFS with the root class + queue.push_back(root.clone()); + visited.insert((root.module.clone(), root.name.clone())); // BFS traversal - while let Some(class_id) = queue.pop_front() { - // Skip if already visited (handles potential cycles) - if !visited.insert(class_id.clone()) { - continue; - } - - result.push(class_id.clone()); - - // Add this class's children to the queue - if let Some(children) = self.children.get(&class_id) { + while let Some(current) = queue.pop_front() { + // Examine all direct children of the current class + if let Some(children) = self.children.get(¤t) { for child in children { - if !visited.contains(child) { + let key = (child.module.clone(), child.name.clone()); + if !visited.contains(&key) { + visited.insert(key); + result.push(child.clone()); queue.push_back(child.clone()); } } @@ -87,183 +117,3 @@ impl InheritanceGraph { result } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::parser::{BaseClass, ClassDefinition}; - use crate::registry::ClassRegistry; - use std::path::PathBuf; - - #[test] - fn test_simple_inheritance() { - let registry = ClassRegistry::new(vec![ - // Animal (base class) - crate::parser::ParsedFile { - module_path: "animals".to_string(), - classes: vec![ClassDefinition { - name: "Animal".to_string(), - module_path: "animals".to_string(), - file_path: PathBuf::from("animals.py"), - bases: vec![], - }], - imports: vec![], - is_package: false, - }, - // Dog(Animal) - crate::parser::ParsedFile { - module_path: "animals".to_string(), - classes: vec![ClassDefinition { - name: "Dog".to_string(), - module_path: "animals".to_string(), - file_path: PathBuf::from("animals.py"), - bases: vec![BaseClass::Simple("Animal".to_string())], - }], - imports: vec![], - is_package: false, - }, - ]); - - let graph = InheritanceGraph::build(®istry); - - let animal_id = ClassId::new("animals".to_string(), "Animal".to_string()); - let subclasses = graph.find_all_subclasses(&animal_id); - - assert_eq!(subclasses.len(), 1); - assert_eq!(subclasses[0].class_name, "Dog"); - } - - #[test] - fn test_transitive_inheritance() { - let registry = ClassRegistry::new(vec![ - // Animal - crate::parser::ParsedFile { - module_path: "base".to_string(), - classes: vec![ClassDefinition { - name: "Animal".to_string(), - module_path: "base".to_string(), - file_path: PathBuf::from("base.py"), - bases: vec![], - }], - imports: vec![], - is_package: false, - }, - // Mammal(Animal) - crate::parser::ParsedFile { - module_path: "base".to_string(), - classes: vec![ClassDefinition { - name: "Mammal".to_string(), - module_path: "base".to_string(), - file_path: PathBuf::from("base.py"), - bases: vec![BaseClass::Simple("Animal".to_string())], - }], - imports: vec![], - is_package: false, - }, - // Dog(Mammal) - crate::parser::ParsedFile { - module_path: "base".to_string(), - classes: vec![ClassDefinition { - name: "Dog".to_string(), - module_path: "base".to_string(), - file_path: PathBuf::from("base.py"), - bases: vec![BaseClass::Simple("Mammal".to_string())], - }], - imports: vec![], - is_package: false, - }, - ]); - - let graph = InheritanceGraph::build(®istry); - - let animal_id = ClassId::new("base".to_string(), "Animal".to_string()); - let subclasses = graph.find_all_subclasses(&animal_id); - - // Should find both Mammal and Dog - assert_eq!(subclasses.len(), 2); - - let names: HashSet<_> = subclasses.iter().map(|c| c.class_name.as_str()).collect(); - assert!(names.contains("Mammal")); - assert!(names.contains("Dog")); - } - - #[test] - fn test_multiple_inheritance() { - let registry = ClassRegistry::new(vec![ - // Animal - crate::parser::ParsedFile { - module_path: "base".to_string(), - classes: vec![ClassDefinition { - name: "Animal".to_string(), - module_path: "base".to_string(), - file_path: PathBuf::from("base.py"), - bases: vec![], - }], - imports: vec![], - is_package: false, - }, - // Pet - crate::parser::ParsedFile { - module_path: "base".to_string(), - classes: vec![ClassDefinition { - name: "Pet".to_string(), - module_path: "base".to_string(), - file_path: PathBuf::from("base.py"), - bases: vec![], - }], - imports: vec![], - is_package: false, - }, - // Dog(Animal, Pet) - crate::parser::ParsedFile { - module_path: "base".to_string(), - classes: vec![ClassDefinition { - name: "Dog".to_string(), - module_path: "base".to_string(), - file_path: PathBuf::from("base.py"), - bases: vec![ - BaseClass::Simple("Animal".to_string()), - BaseClass::Simple("Pet".to_string()), - ], - }], - imports: vec![], - is_package: false, - }, - ]); - - let graph = InheritanceGraph::build(®istry); - - // Dog should be a subclass of both Animal and Pet - let animal_id = ClassId::new("base".to_string(), "Animal".to_string()); - let animal_subclasses = graph.find_all_subclasses(&animal_id); - assert_eq!(animal_subclasses.len(), 1); - assert_eq!(animal_subclasses[0].class_name, "Dog"); - - let pet_id = ClassId::new("base".to_string(), "Pet".to_string()); - let pet_subclasses = graph.find_all_subclasses(&pet_id); - assert_eq!(pet_subclasses.len(), 1); - assert_eq!(pet_subclasses[0].class_name, "Dog"); - } - - #[test] - fn test_no_subclasses() { - let registry = ClassRegistry::new(vec![crate::parser::ParsedFile { - module_path: "base".to_string(), - classes: vec![ClassDefinition { - name: "Animal".to_string(), - module_path: "base".to_string(), - file_path: PathBuf::from("base.py"), - bases: vec![], - }], - imports: vec![], - is_package: false, - }]); - - let graph = InheritanceGraph::build(®istry); - - let animal_id = ClassId::new("base".to_string(), "Animal".to_string()); - let subclasses = graph.find_all_subclasses(&animal_id); - - assert_eq!(subclasses.len(), 0); - } -} diff --git a/src/lib.rs b/src/lib.rs index 793828a..8b5aa18 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,18 +21,18 @@ //! # } //! ``` -pub(crate) mod discovery; -pub(crate) mod error; -pub(crate) mod graph; -pub(crate) mod parser; -pub(crate) mod registry; -pub(crate) mod utils; +pub mod discovery; +pub mod error; +pub mod graph; +pub mod parser; +pub mod registry; use std::path::PathBuf; pub use error::{Error, Result}; use graph::InheritanceGraph; -use registry::{ClassId, ClassRegistry}; + +use crate::registry::Registry; /// A reference to a Python class. #[derive(Debug, Clone, PartialEq, Eq)] @@ -73,8 +73,7 @@ impl ClassReference { /// # } /// ``` pub struct SubclassFinder { - root_dir: PathBuf, - registry: ClassRegistry, + registry: Registry, graph: InheritanceGraph, } @@ -111,17 +110,12 @@ impl SubclassFinder { }) .collect(); - // Build registry from parsed files - let registry = ClassRegistry::new(parsed_files); + let registry = Registry::build(&parsed_files)?; // Build the inheritance graph - let graph = graph::InheritanceGraph::build(®istry); + let graph = InheritanceGraph::build(®istry); - Ok(Self { - root_dir, - registry, - graph, - }) + Ok(Self { registry, graph }) } /// Finds all transitive subclasses of a given class. @@ -166,20 +160,59 @@ impl SubclassFinder { module_path: Option<&str>, ) -> Result> { // Find the target class - let target_id = self.find_target_class(class_name, module_path)?; + let target_id = if let Some(module) = module_path { + // Module specified - look for class in that module (including re-exports) + if let Some(resolved_id) = self.registry.resolve_class(module, class_name) { + resolved_id + } else { + return Err(Error::ClassNotFound { + name: class_name.to_string(), + module_path: Some(module.to_string()), + }); + } + } else { + // No module specified - search for class by name + let mut matches = Vec::new(); + for class_id in self.registry.classes.keys() { + if class_id.name == class_name { + matches.push(class_id.clone()); + } + } + + match matches.len() { + 0 => { + return Err(Error::ClassNotFound { + name: class_name.to_string(), + module_path: None, + }); + } + 1 => matches.into_iter().next().unwrap(), + _ => { + let candidates: Vec = + matches.iter().map(|id| id.module.clone()).collect(); + return Err(Error::AmbiguousClassName { + name: class_name.to_string(), + candidates, + }); + } + } + }; - // Find all subclasses + // Find all subclasses using the graph let subclass_ids = self.graph.find_all_subclasses(&target_id); - // Convert to ClassReferences + // Convert to ClassReference let mut results: Vec = subclass_ids .into_iter() .filter_map(|id| { - self.registry.get(&id).map(|info| ClassReference { - class_name: id.class_name.clone(), - module_path: id.module_path.clone(), - file_path: info.file_path.clone(), - }) + self.registry + .modules + .get(&id.module) + .map(|metadata| ClassReference { + class_name: id.name.clone(), + module_path: id.module.clone(), + file_path: metadata.file_path.clone(), + }) }) .collect(); @@ -193,63 +226,8 @@ impl SubclassFinder { Ok(results) } - /// Finds the ClassId for the target class. - fn find_target_class(&self, class_name: &str, module_path: Option<&str>) -> Result { - // If module path provided, look for exact match or re-export - if let Some(module) = module_path { - let id = ClassId::new(module.to_string(), class_name.to_string()); - - // Try direct lookup first - if self.registry.get(&id).is_some() { - return Ok(id); - } - - // Try resolving through re-exports - if let Some(resolved) = self.registry.resolve_class_through_reexports(&id) { - return Ok(resolved); - } - - return Err(Error::ClassNotFound { - name: class_name.to_string(), - module_path: Some(module.to_string()), - }); - } - - // Otherwise find by name - let matches = self - .registry - .find_by_name(class_name) - .filter(|ids| !ids.is_empty()) - .ok_or_else(|| Error::ClassNotFound { - name: class_name.to_string(), - module_path: None, - })?; - - match matches.len() { - 1 => Ok(matches[0].clone()), - _ => { - let candidates = matches.iter().map(|id| id.module_path.clone()).collect(); - Err(Error::AmbiguousClassName { - name: class_name.to_string(), - candidates, - }) - } - } - } - /// Returns the number of classes found in the codebase. pub fn class_count(&self) -> usize { - self.registry.len() + self.registry.classes.len() } - - /// Returns the root directory being searched. - pub fn root_dir(&self) -> &PathBuf { - &self.root_dir - } -} - -#[cfg(test)] -mod tests { - - // Integration tests will be in tests/ directory } diff --git a/src/parser.rs b/src/parser.rs index 2b561a5..df3254e 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -8,15 +8,6 @@ use std::path::{Path, PathBuf}; use crate::error::{Error, Result}; -/// Represents a base class reference in a class definition. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum BaseClass { - /// Simple name reference (e.g., `Animal`) - Simple(String), - /// Attribute reference (e.g., `module.Animal` or `package.module.Animal`) - Attribute(Vec), -} - /// Represents a Python class definition. #[derive(Debug, Clone)] pub struct ClassDefinition { @@ -26,39 +17,35 @@ pub struct ClassDefinition { pub module_path: String, /// The file path where the class is defined pub file_path: PathBuf, - /// The base classes this class inherits from (unresolved) - pub bases: Vec, + /// The base classes this class inherits from e.g. "Foo" or "foo.Foo" + pub bases: Vec, } -/// Represents an import statement. +/// An import. +/// +/// E.g. +/// `import a` => { imported_item=a, imported_as=a } +/// `import a.b` => { imported_item=a.b, imported_as=a.b } +/// `import a.b as c` => { imported_item=a.b, imported_as=c } +/// `from a import b` => { imported_item=a.b, imported_as=b } +/// `from a import b as c` => { imported_item=a.b, imported_as=c } +/// `from a.b import c` => { imported_item=a.b.c, imported_as=c } #[derive(Debug, Clone)] -pub enum Import { - /// `import foo` or `import foo.bar` - Module { - module: String, - alias: Option, - }, - /// `from foo import bar` or `from foo import bar as baz` - From { - module: String, - names: Vec<(String, Option)>, // (name, alias) - }, - /// `from .relative import foo` (relative import) - RelativeFrom { - level: usize, // Number of dots - module: Option, - names: Vec<(String, Option)>, - }, +pub struct Import { + pub imported_item: String, + pub imported_as: String, } /// The result of parsing a Python file. #[derive(Debug)] pub struct ParsedFile { + /// The path of this file + pub file_path: PathBuf, /// The module path of this file pub module_path: String, /// Class definitions found in this file pub classes: Vec, - /// Import statements found in this file + /// Import statements found in this file (relative imports already resolved) pub imports: Vec, /// Whether this is a package (__init__.py file) pub is_package: bool, @@ -138,6 +125,7 @@ pub fn parse_file(file_path: &Path, module_path: &str) -> Result { .unwrap_or(false); Ok(ParsedFile { + file_path: file_path.to_path_buf(), module_path: module_path.to_string(), classes, imports, @@ -147,16 +135,20 @@ pub fn parse_file(file_path: &Path, module_path: &str) -> Result { /// Recursively extracts classes and imports from a list of statements. /// -/// This function handles both top-level and nested class definitions. +/// This function walks the AST and extracts: +/// - Class definitions (including nested classes) +/// - Import statements (both `import` and `from...import` forms) +/// +/// Nested classes are represented with dot notation (e.g., "Outer.Inner"). /// /// # Arguments /// -/// * `stmts` - The statements to process -/// * `parent_class` - The parent class name for nested classes (e.g., Some("Bar")) -/// * `module_path` - The module path for this file -/// * `file_path` - The file path -/// * `classes` - Mutable vector to accumulate class definitions -/// * `imports` - Mutable vector to accumulate imports (only at top level) +/// * `stmts` - The AST statements to process +/// * `parent_class` - The parent class name if processing nested classes (e.g., `Some("Outer")`) +/// * `module_path` - The module path for this file (e.g., "foo.bar") +/// * `file_path` - The file path (used for resolving relative imports) +/// * `classes` - Mutable vector to accumulate discovered class definitions +/// * `imports` - Mutable vector to accumulate discovered imports (only at top level) fn extract_from_statements( stmts: &[Stmt], parent_class: Option<&str>, @@ -168,13 +160,14 @@ fn extract_from_statements( for stmt in stmts { match stmt { Stmt::ClassDef(class_def) => { - // Build the full class name (with parent prefix if nested) + // Build fully qualified class name (e.g., "Outer.Inner" for nested classes) let full_name = if let Some(parent) = parent_class { format!("{}.{}", parent, class_def.name) } else { class_def.name.to_string() }; + // Extract base classes, filtering out unresolvable references let bases = class_def .bases() .iter() @@ -188,7 +181,7 @@ fn extract_from_statements( bases, }); - // Recursively extract nested classes + // Recursively process nested classes extract_from_statements( class_def.body.as_slice(), Some(&full_name), @@ -199,47 +192,61 @@ fn extract_from_statements( ); } Stmt::Import(import_stmt) => { - // Only process imports at the top level (not inside classes) - if parent_class.is_none() { - for alias in &import_stmt.names { - imports.push(Import::Module { - module: alias.name.to_string(), - alias: alias.asname.as_ref().map(|a| a.to_string()), - }); - } + // Process `import foo` or `import foo as bar` statements + // Format: { imported_item: "foo", imported_as: "bar" } + for alias in &import_stmt.names { + let imported_item = alias.name.to_string(); + let imported_as = alias + .asname + .as_ref() + .map(|a| a.to_string()) + .unwrap_or_else(|| imported_item.clone()); + imports.push(Import { + imported_item, + imported_as, + }); } } Stmt::ImportFrom(import_from) => { - // Only process imports at the top level (not inside classes) - if parent_class.is_none() { - let level = import_from.level as usize; - let names: Vec<(String, Option)> = import_from - .names - .iter() - .map(|alias| { - ( - alias.name.to_string(), - alias.asname.as_ref().map(|a| a.to_string()), - ) - }) - .collect(); - - if level > 0 { - imports.push(Import::RelativeFrom { - level, - module: import_from.module.as_ref().map(|m| m.to_string()), - names, - }); + // Process `from foo import bar` or `from .foo import bar` statements + let level = import_from.level as usize; + + for alias in &import_from.names { + let name = alias.name.to_string(); + let imported_as = alias + .asname + .as_ref() + .map(|a| a.to_string()) + .unwrap_or_else(|| name.clone()); + + let imported_item = if level > 0 { + // Relative import: resolve dots to absolute module path + // e.g., `from ..pkg import Foo` → "parent.pkg.Foo" + let base_module = resolve_relative_module(module_path, level, file_path); + if let Some(from_module) = import_from.module.as_ref() { + format!("{base_module}.{from_module}.{name}") + } else { + format!("{base_module}.{name}") + } } else { - imports.push(Import::From { - module: import_from - .module - .as_ref() - .map(|m| m.to_string()) - .unwrap_or_default(), - names, - }); - } + // Absolute import: combine module and name + // e.g., `from foo import Bar` → "foo.Bar" + let from_module = import_from + .module + .as_ref() + .map(|m| m.to_string()) + .unwrap_or_default(); + if from_module.is_empty() { + name.clone() + } else { + format!("{from_module}.{name}") + } + }; + + imports.push(Import { + imported_item, + imported_as, + }); } } _ => {} @@ -247,11 +254,65 @@ fn extract_from_statements( } } -/// Extracts a base class reference from an expression. -fn extract_base_class(expr: &Expr) -> Option { +/// Resolves a relative import to an absolute module path. +/// +/// Relative imports in Python use dots to indicate the starting point: +/// - `.foo` means "foo in the current package" +/// - `..foo` means "foo in the parent package" +/// - `...foo` means "foo in the grandparent package" +/// +/// This function converts the relative path to an absolute module path based on +/// the current module's location. +/// +/// # Arguments +/// +/// * `current_module` - The module path of the file containing the import (e.g., "pkg.sub.module") +/// * `level` - The number of leading dots in the relative import (e.g., 2 for `..foo`) +/// * `file_path` - The file path (used to check if this is a `__init__.py` package file) +/// +/// # Returns +/// +/// The absolute module path that the relative import refers to. +/// +/// # Examples +/// +/// ```text +/// # In pkg/sub/module.py (module path "pkg.sub.module"): +/// from ..other import Foo # level=2 → resolves to "pkg.other" +/// +/// # In pkg/sub/__init__.py (module path "pkg.sub"): +/// from ..other import Foo # level=2 → resolves to "pkg.other" +/// from .local import Bar # level=1 → resolves to "pkg.sub.local" +/// ``` +fn resolve_relative_module(current_module: &str, level: usize, file_path: &Path) -> String { + let is_package = file_path + .file_name() + .and_then(|name| name.to_str()) + .map(|name| name == "__init__.py") + .unwrap_or(false); + + let parts: Vec<&str> = current_module.split('.').collect(); + + // Packages start at themselves, modules start at their parent + // e.g., In "pkg.sub" package: level 1 = "pkg.sub", level 2 = "pkg" + // In "pkg.sub.mod" module: level 1 = "pkg.sub", level 2 = "pkg" + let base_level = if is_package { level - 1 } else { level }; + + // Navigate up the package hierarchy + if base_level >= parts.len() { + // Trying to go above the root - return empty (error case) + String::new() + } else { + parts[..parts.len() - base_level].join(".") + } +} + +/// Extracts a base class reference from an expression as a string. +/// Returns strings like "Foo" or "module.Foo" or "pkg.mod.Foo" +fn extract_base_class(expr: &Expr) -> Option { match expr { // Simple name: class Foo(Bar) - Expr::Name(name) => Some(BaseClass::Simple(name.id.to_string())), + Expr::Name(name) => Some(name.id.to_string()), // Attribute: class Foo(module.Bar) or class Foo(pkg.mod.Bar) Expr::Attribute(_) => { @@ -274,7 +335,7 @@ fn extract_base_class(expr: &Expr) -> Option { } parts.reverse(); - Some(BaseClass::Attribute(parts)) + Some(parts.join(".")) } // Subscript: class Foo(Generic[T]) - extract the base without the subscript @@ -423,7 +484,7 @@ class TopLevel(Foo): .find(|c| c.name == "Bar.NestedInBar") .unwrap(); assert_eq!(nested_in_bar.bases.len(), 1); - assert_eq!(nested_in_bar.bases[0], BaseClass::Simple("Foo".to_string())); + assert_eq!(nested_in_bar.bases[0], "Foo"); // Verify that DoublyNested has Foo as a base let doubly_nested = parsed @@ -432,6 +493,287 @@ class TopLevel(Foo): .find(|c| c.name == "Bar.AnotherNested.DoublyNested") .unwrap(); assert_eq!(doubly_nested.bases.len(), 1); - assert_eq!(doubly_nested.bases[0], BaseClass::Simple("Foo".to_string())); + assert_eq!(doubly_nested.bases[0], "Foo"); + } + + // Parametric tests for import parsing + #[derive(Debug)] + struct ImportCase { + name: &'static str, + python_code: &'static str, + module_path: &'static str, + is_package: bool, + expected_imports: Vec<(&'static str, &'static str)>, // (imported_item, imported_as) + } + + #[yare::parameterized( + absolute_import_simple = { ImportCase { + name: "absolute import simple", + python_code: "import foo", + module_path: "mymodule", + is_package: false, + expected_imports: vec![("foo", "foo")], + } }, + absolute_import_dotted = { ImportCase { + name: "absolute import dotted", + python_code: "import foo.bar.baz", + module_path: "mymodule", + is_package: false, + expected_imports: vec![("foo.bar.baz", "foo.bar.baz")], + } }, + absolute_import_alias = { ImportCase { + name: "absolute import with alias", + python_code: "import foo as f", + module_path: "mymodule", + is_package: false, + expected_imports: vec![("foo", "f")], + } }, + absolute_import_dotted_alias = { ImportCase { + name: "absolute import dotted with alias", + python_code: "import foo.bar as fb", + module_path: "mymodule", + is_package: false, + expected_imports: vec![("foo.bar", "fb")], + } }, + from_import_simple = { ImportCase { + name: "from import simple", + python_code: "from foo import Bar", + module_path: "mymodule", + is_package: false, + expected_imports: vec![("foo.Bar", "Bar")], + } }, + from_import_dotted = { ImportCase { + name: "from import dotted", + python_code: "from foo.bar import Baz", + module_path: "mymodule", + is_package: false, + expected_imports: vec![("foo.bar.Baz", "Baz")], + } }, + from_import_alias = { ImportCase { + name: "from import with alias", + python_code: "from foo import Bar as B", + module_path: "mymodule", + is_package: false, + expected_imports: vec![("foo.Bar", "B")], + } }, + from_import_multiple = { ImportCase { + name: "from import multiple", + python_code: "from foo import Bar, Baz", + module_path: "mymodule", + is_package: false, + expected_imports: vec![("foo.Bar", "Bar"), ("foo.Baz", "Baz")], + } }, + from_import_multiple_alias = { ImportCase { + name: "from import multiple with alias", + python_code: "from foo import Bar as B, Baz as Z", + module_path: "mymodule", + is_package: false, + expected_imports: vec![("foo.Bar", "B"), ("foo.Baz", "Z")], + } }, + relative_import_one_level_module = { ImportCase { + name: "relative import one level from module", + python_code: "from .sibling import Foo", + module_path: "pkg.mymodule", + is_package: false, + expected_imports: vec![("pkg.sibling.Foo", "Foo")], + } }, + relative_import_one_level_package = { ImportCase { + name: "relative import one level from package", + python_code: "from .sibling import Foo", + module_path: "pkg", + is_package: true, + expected_imports: vec![("pkg.sibling.Foo", "Foo")], + } }, + relative_import_two_levels_module = { ImportCase { + name: "relative import two levels from module", + python_code: "from ..other import Foo", + module_path: "pkg.sub.mymodule", + is_package: false, + expected_imports: vec![("pkg.other.Foo", "Foo")], + } }, + relative_import_two_levels_package = { ImportCase { + name: "relative import two levels from package", + python_code: "from ..other import Foo", + module_path: "pkg.sub", + is_package: true, + expected_imports: vec![("pkg.other.Foo", "Foo")], + } }, + relative_import_no_module = { ImportCase { + name: "relative import without from module", + python_code: "from . import Foo", + module_path: "pkg.mymodule", + is_package: false, + expected_imports: vec![("pkg.Foo", "Foo")], + } }, + relative_import_alias = { ImportCase { + name: "relative import with alias", + python_code: "from .sibling import Foo as F", + module_path: "pkg.mymodule", + is_package: false, + expected_imports: vec![("pkg.sibling.Foo", "F")], + } }, + multiple_imports = { ImportCase { + name: "multiple import statements", + python_code: "import foo\nfrom bar import Baz", + module_path: "mymodule", + is_package: false, + expected_imports: vec![("foo", "foo"), ("bar.Baz", "Baz")], + } }, + )] + fn test_import_parsing(case: ImportCase) { + let temp_dir = std::env::temp_dir(); + // Create a proper directory structure for package tests + let test_id = case.name.replace(' ', "_"); + let test_root = temp_dir.join(format!("test_imports_{test_id}")); + std::fs::create_dir_all(&test_root).unwrap(); + + let temp_file = if case.is_package { + test_root.join("__init__.py") + } else { + test_root.join("test.py") + }; + + std::fs::write(&temp_file, case.python_code).unwrap(); + + let parsed = parse_file(&temp_file, case.module_path).unwrap(); + + // Clean up + let _ = std::fs::remove_dir_all(&test_root); + + // Verify imports + assert_eq!( + parsed.imports.len(), + case.expected_imports.len(), + "Case '{}': expected {} imports, got {}", + case.name, + case.expected_imports.len(), + parsed.imports.len() + ); + + for (i, (expected_item, expected_as)) in case.expected_imports.iter().enumerate() { + assert_eq!( + parsed.imports[i].imported_item, *expected_item, + "Case '{}': import {} - expected imported_item '{}', got '{}'", + case.name, i, expected_item, parsed.imports[i].imported_item + ); + assert_eq!( + parsed.imports[i].imported_as, *expected_as, + "Case '{}': import {} - expected imported_as '{}', got '{}'", + case.name, i, expected_as, parsed.imports[i].imported_as + ); + } + } + + // Parametric tests for base class extraction + #[derive(Debug)] + struct BaseClassCase { + name: &'static str, + python_code: &'static str, + class_name: &'static str, + expected_bases: Vec<&'static str>, + } + + #[yare::parameterized( + simple_base = { BaseClassCase { + name: "simple base class", + python_code: "class Foo(Bar): pass", + class_name: "Foo", + expected_bases: vec!["Bar"], + } }, + attribute_base = { BaseClassCase { + name: "attribute base class", + python_code: "class Foo(module.Bar): pass", + class_name: "Foo", + expected_bases: vec!["module.Bar"], + } }, + nested_attribute_base = { BaseClassCase { + name: "nested attribute base class", + python_code: "class Foo(pkg.module.Bar): pass", + class_name: "Foo", + expected_bases: vec!["pkg.module.Bar"], + } }, + multiple_bases_simple = { BaseClassCase { + name: "multiple simple bases", + python_code: "class Foo(Bar, Baz): pass", + class_name: "Foo", + expected_bases: vec!["Bar", "Baz"], + } }, + multiple_bases_mixed = { BaseClassCase { + name: "multiple mixed bases", + python_code: "class Foo(Bar, pkg.Baz): pass", + class_name: "Foo", + expected_bases: vec!["Bar", "pkg.Baz"], + } }, + generic_base = { BaseClassCase { + name: "generic base class", + python_code: "class Foo(Generic[T]): pass", + class_name: "Foo", + expected_bases: vec!["Generic"], + } }, + generic_with_multiple = { BaseClassCase { + name: "generic with multiple type params", + python_code: "class Foo(Dict[str, int]): pass", + class_name: "Foo", + expected_bases: vec!["Dict"], + } }, + mixed_generic_and_simple = { BaseClassCase { + name: "mixed generic and simple", + python_code: "class Foo(Bar, Generic[T]): pass", + class_name: "Foo", + expected_bases: vec!["Bar", "Generic"], + } }, + no_bases = { BaseClassCase { + name: "no base classes", + python_code: "class Foo: pass", + class_name: "Foo", + expected_bases: vec![], + } }, + attribute_generic = { BaseClassCase { + name: "attribute with generic", + python_code: "class Foo(typing.Generic[T]): pass", + class_name: "Foo", + expected_bases: vec!["typing.Generic"], + } }, + )] + fn test_base_class_extraction(case: BaseClassCase) { + let temp_dir = std::env::temp_dir(); + let temp_file = temp_dir.join(format!("test_bases_{}.py", case.name)); + + std::fs::write(&temp_file, case.python_code).unwrap(); + + let parsed = parse_file(&temp_file, "test_module").unwrap(); + + // Clean up + let _ = std::fs::remove_file(&temp_file); + + // Find the class + let class = parsed + .classes + .iter() + .find(|c| c.name == case.class_name) + .unwrap_or_else(|| { + panic!( + "Case '{}': class '{}' not found", + case.name, case.class_name + ) + }); + + // Verify bases + assert_eq!( + class.bases.len(), + case.expected_bases.len(), + "Case '{}': expected {} bases, got {}", + case.name, + case.expected_bases.len(), + class.bases.len() + ); + + for (i, expected_base) in case.expected_bases.iter().enumerate() { + assert_eq!( + class.bases[i], *expected_base, + "Case '{}': base {} - expected '{}', got '{}'", + case.name, i, expected_base, class.bases[i] + ); + } } } diff --git a/src/registry.rs b/src/registry.rs index 0e9b235..f1e89eb 100644 --- a/src/registry.rs +++ b/src/registry.rs @@ -1,582 +1,220 @@ //! Class registry for tracking class definitions and resolving references. +//! +//! The registry maintains a complete index of all Python modules, classes, and imports +//! found in a codebase. It provides functionality to resolve class references across +//! modules, handling imports and re-exports correctly. -use std::collections::HashMap; -use std::path::PathBuf; +use std::{ + collections::{HashMap, HashSet}, + path::PathBuf, +}; -use crate::parser::{BaseClass, ClassDefinition, Import, ParsedFile}; +use crate::{ + error::Result, + parser::{Import, ParsedFile}, +}; -/// A unique identifier for a class. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct ClassId { - pub module_path: String, - pub class_name: String, -} - -impl ClassId { - pub fn new(module_path: String, class_name: String) -> Self { - Self { - module_path, - class_name, - } - } -} +/// Type alias for Python module names (e.g., "foo.bar.baz"). +pub type ModuleName = String; -/// Information about where a class is defined. +/// Metadata about a Python module. #[derive(Debug, Clone)] -pub struct ClassInfo { +pub struct ModuleMetadata { + /// The file system path to this module. pub file_path: PathBuf, - pub bases: Vec, + /// Whether this is a package (`__init__.py` file). + pub is_package: bool, } -/// A registry of all classes found in a codebase. -#[derive(Default)] -pub struct ClassRegistry { - /// Map from ClassId to ClassInfo - classes: HashMap, - - /// Map from simple class name to all ClassIds with that name - /// (for ambiguity detection) - name_index: HashMap>, - - /// Map from module path to imports in that module - imports: HashMap>, +/// Metadata about a Python class definition. +#[derive(Debug, Clone)] +pub struct ClassMetadata { + /// The base classes this class inherits from. + /// + /// These are stored as unresolved strings (e.g., "Foo" or "module.Foo") + /// and must be resolved using the registry's import information. + pub bases: Vec, +} - /// Map from ClassId (re-exported location) to ClassId (original location) - /// E.g., foo.Bar -> foo._internal.Bar - re_exports: HashMap, +/// A unique identifier for a class within the codebase. +/// +/// Consists of the module path and class name. Note that nested classes +/// are represented with dot notation (e.g., "OuterClass.InnerClass"). +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ClassId { + /// The module path (e.g., "foo.bar"). + pub module: ModuleName, + /// The class name, potentially including nesting (e.g., "Outer.Inner"). + pub name: String, +} - /// Set of module paths that are packages (__init__.py files) - packages: std::collections::HashSet, +/// A registry of all modules, classes, and imports in a Python codebase. +/// +/// The registry is built from parsed files and provides methods to: +/// - Look up class definitions by name or module +/// - Resolve class references through imports and re-exports +/// - Track the inheritance relationships between classes +pub struct Registry { + /// Metadata for each module in the codebase. + pub modules: HashMap, + /// Metadata for each class, indexed by ClassId. + pub classes: HashMap, + /// Index of classes organized by module for efficient lookup. + pub classes_by_module: HashMap>, + /// Import statements for each module. + pub imports: HashMap>, } -impl ClassRegistry { - /// Creates a new class registry from a vector of parsed files. +impl Registry { + /// Builds a registry from parsed Python files. + /// + /// This processes all parsed files to create indexes of modules, classes, and imports. + /// + /// # Arguments + /// + /// * `parsed_files` - The collection of parsed Python files /// - /// This will build the registry and resolve all re-exports. - pub fn new(parsed_files: Vec) -> Self { - let mut registry = Self::default(); + /// # Returns + /// + /// A fully constructed registry ready for class resolution. + pub fn build(parsed_files: &[ParsedFile]) -> Result { + let mut modules = HashMap::new(); + let mut classes = HashMap::new(); + let mut classes_by_module: HashMap> = HashMap::new(); + let mut imports = HashMap::new(); - // Add all files to the registry for parsed in parsed_files { - registry.add_file(parsed); - } - - // Build re-export mappings now that all classes are registered - registry.build_reexports(); - - registry - } - - /// Adds a parsed file to the registry. - fn add_file(&mut self, parsed: ParsedFile) { - // Store imports for this module - self.imports - .insert(parsed.module_path.clone(), parsed.imports.clone()); - - // Track if this is a package - if parsed.is_package { - self.packages.insert(parsed.module_path.clone()); - } + // Record module metadata + modules.insert( + parsed.module_path.clone(), + ModuleMetadata { + file_path: parsed.file_path.clone(), + is_package: parsed.is_package, + }, + ); + + // Index all class definitions from this module + for class in &parsed.classes { + let class_id = ClassId { + module: parsed.module_path.clone(), + name: class.name.clone(), + }; + + classes.insert( + class_id.clone(), + ClassMetadata { + bases: class.bases.clone(), + }, + ); + + classes_by_module + .entry(parsed.module_path.clone()) + .or_default() + .insert(class_id); + } - // Add all classes - for class in parsed.classes { - self.add_class(class); + // Store import statements for later resolution + imports.insert(parsed.module_path.clone(), parsed.imports.clone()); } - // Track re-exports: when we import a class, it may be re-exported - // We'll do a second pass after all files are added + Ok(Self { + modules, + classes, + classes_by_module, + imports, + }) } - /// Second pass: build re-export mappings after all classes are registered. - /// This should be called after all files have been added. + /// Resolves a class name within a given module's context. /// - /// This may need multiple iterations to resolve transitive re-exports - /// (e.g., A re-exports from B, B re-exports from C). - fn build_reexports(&mut self) { - // Iterate until no new re-exports are added (fixed-point iteration) - let mut changed = true; - while changed { - changed = false; - let mut reexports_to_add = Vec::new(); - - for (module_path, imports) in &self.imports { - for import in imports { - match import { - Import::From { module, names } => { - for (name, alias) in names { - // The exported name is the alias if present, otherwise the original name - let exported_name = alias.as_ref().unwrap_or(name); - - // Try to find the original class - let original_module = module.clone(); - let original_id = - ClassId::new(original_module.clone(), name.clone()); - - // Check if the class exists directly or as a re-export - if self.classes.contains_key(&original_id) - || self.re_exports.contains_key(&original_id) - { - // Register the re-export - let reexport_id = - ClassId::new(module_path.clone(), exported_name.clone()); - reexports_to_add.push((reexport_id, original_id)); - } - } - } - Import::RelativeFrom { - level, - module: rel_module, - names, - } => { - // Resolve the relative import - let is_package = self.packages.contains(module_path); - if let Some(base) = crate::utils::resolve_relative_import_base( - module_path, - *level, - is_package, - ) { - for (name, alias) in names { - let exported_name = alias.as_ref().unwrap_or(name); - - // Build the full module path - let original_module = if let Some(m) = rel_module { - if base.is_empty() { - m.clone() - } else { - format!("{base}.{m}") - } - } else { - base.clone() - }; - - let original_id = ClassId::new(original_module, name.clone()); - - // Check if the class exists directly or as a re-export - // (we'll resolve the chain later) - if self.classes.contains_key(&original_id) - || self.re_exports.contains_key(&original_id) - { - let reexport_id = ClassId::new( - module_path.clone(), - exported_name.clone(), - ); - reexports_to_add.push((reexport_id, original_id)); - } - } - } - } - Import::Module { .. } => { - // Module imports don't re-export classes - } - } - } - } - - // Add all the re-exports - for (reexport_id, original_id) in reexports_to_add { - // Only add if it's new - if let std::collections::hash_map::Entry::Vacant(e) = - self.re_exports.entry(reexport_id) - { - e.insert(original_id); - changed = true; - } - } - } - } - - /// Adds a class definition to the registry. - fn add_class(&mut self, class: ClassDefinition) { - let id = ClassId::new(class.module_path, class.name.clone()); - - let info = ClassInfo { - file_path: class.file_path, - bases: class.bases, + /// This method handles the complexity of Python's import system, including: + /// - Direct class references within the same module + /// - Classes imported from other modules + /// - Classes re-exported through `__init__.py` files + /// - Attribute-style class references (e.g., "module.Class") + /// + /// # Algorithm + /// + /// 1. Check if the class exists directly in the specified module + /// 2. If not, consult the module's imports to resolve the name: + /// - Match against `imported_as` names from import statements + /// - Substitute with the actual `imported_item` path + /// 3. Parse the resolved name to find the defining module: + /// - Try progressively shorter prefixes (e.g., "a.b.c" → "a.b" → "a") + /// - Stop when we find a module that exists + /// 4. Recursively resolve the remaining name within that module + /// + /// # Arguments + /// + /// * `module` - The module path providing the context for resolution + /// * `name` - The class name to resolve (may include dots for attribute access) + /// + /// # Returns + /// + /// The resolved `ClassId` if the class can be found, or `None` if resolution fails. + /// + /// # Examples + /// + /// ```text + /// # In module "zoo": + /// from animals import Dog + /// class Puppy(Dog): # Resolves "Dog" to ClassId { module: "animals", name: "Dog" } + /// pass + /// ``` + pub fn resolve_class(&self, module: &str, name: &str) -> Option { + // First, check for a direct class definition in this module + let direct_id = ClassId { + module: module.to_string(), + name: name.to_string(), }; + if self.classes.contains_key(&direct_id) { + return Some(direct_id); + } - // Add to name index - self.name_index - .entry(class.name) - .or_default() - .push(id.clone()); - - // Add to main registry - self.classes.insert(id, info); - } - - /// Finds all classes with a given name. - pub fn find_by_name(&self, name: &str) -> Option<&Vec> { - self.name_index.get(name) - } - - /// Gets class info by ClassId. - pub fn get(&self, id: &ClassId) -> Option<&ClassInfo> { - self.classes.get(id) - } - - /// Resolves a ClassId through re-exports to find the canonical ClassId. - /// If the given ClassId is a re-export, follows the chain to find the original. - /// Otherwise returns the input ClassId if it exists in the registry. - /// - /// This is the public API for resolving re-exports. - pub fn resolve_class_through_reexports(&self, id: &ClassId) -> Option { - self.resolve_through_reexports(id) - } + // Not found directly - use imports to resolve the reference + let imports = self.imports.get(module)?; - /// Internal method to resolve through re-exports. - fn resolve_through_reexports(&self, id: &ClassId) -> Option { - let mut current = id.clone(); - let mut visited = std::collections::HashSet::new(); + // Substitute imported names with their actual module paths + // Example: If "Dog" is imported as "from animals import Dog", + // then "Dog" becomes "animals.Dog" + let mut resolved_name = name.to_string(); - // Follow the re-export chain - loop { - // Prevent infinite loops - if visited.contains(¤t) { + for import in imports { + if name == import.imported_as { + // Exact match: "Dog" → "animals.Dog" + resolved_name = import.imported_item.clone(); break; - } - visited.insert(current.clone()); - - // Check if this is a re-export - if let Some(original) = self.re_exports.get(¤t) { - current = original.clone(); - } else { - // No more re-exports, check if this exists - if self.classes.contains_key(¤t) { - return Some(current); - } + } else if let Some(remainder) = name.strip_prefix(&format!("{}.", import.imported_as)) { + // Prefix match: "Dog.Puppy" → "animals.Dog.Puppy" + resolved_name = format!("{}.{}", import.imported_item, remainder); break; } } - None - } - - /// Returns all class IDs in the registry. - pub fn all_class_ids(&self) -> Vec { - self.classes.keys().cloned().collect() - } - - /// Resolves a base class reference to a ClassId. - /// - /// Given a base class reference in a class definition, determine which - /// actual class it refers to based on imports and available classes. - pub fn resolve_base(&self, base: &BaseClass, context_module: &str) -> Option { - match base { - BaseClass::Simple(name) => { - // Look up the name in imports for this module - if let Some(imports) = self.imports.get(context_module) { - let is_package = self.packages.contains(context_module); - if let Some(qualified) = - crate::utils::resolve_name(name, imports, context_module, is_package) - { - // The import tells us it's from a specific module - // Try to find a class with this name in that module - return self.find_class_by_qualified_name(&qualified); - } - } - - // Not imported - might be in the same module - let id = ClassId::new(context_module.to_string(), name.clone()); - if let Some(resolved) = self.resolve_through_reexports(&id) { - return Some(resolved); - } + // Parse the resolved name to find the defining module and class + // Example: "animals.Dog" needs to be split into module "animals" and class "Dog" + let parts: Vec<&str> = resolved_name.split('.').collect(); - // Try to find by name alone (if unambiguous) - let matches = self.find_by_name(name)?; - if matches.len() == 1 { - return Some(matches[0].clone()); - } + // Try each possible split point from longest to shortest prefix + // This handles cases like "a.b.c.Class" where "a.b.c" might be the module + for i in (1..=parts.len()).rev() { + let module_candidate = parts[..i].join("."); - None - } - BaseClass::Attribute(parts) => { - // For attribute references like `module.Class`, we need to figure out - // what `module` refers to based on imports. - - if parts.len() < 2 { + if self.modules.contains_key(&module_candidate) { + if i == parts.len() { + // The entire resolved name is just a module, not a class return None; } - // The last part is the class name, everything before is the module/package - let class_name = parts.last().unwrap(); - - // Check if this is a fully qualified reference - // Try progressively shorter module paths - for i in (0..parts.len() - 1).rev() { - let module_path = parts[..=i].join("."); - let _remaining_parts = &parts[i + 1..]; - - // Check if this module path matches an import - if let Some(imports) = self.imports.get(context_module) { - let is_package = self.packages.contains(context_module); - if let Some(resolved) = crate::utils::resolve_name( - &parts[0], - imports, - context_module, - is_package, - ) { - // Build the full path - let full_module = if parts.len() > 2 { - format!("{}.{}", resolved, parts[1..parts.len() - 1].join(".")) - } else { - resolved - }; + // Found the module! The remainder is the class name within it + let remainder = parts[i..].join("."); - let id = ClassId::new(full_module, class_name.to_string()); - if let Some(resolved) = self.resolve_through_reexports(&id) { - return Some(resolved); - } - } - } - - // Try as a direct module path - let id = ClassId::new(module_path, class_name.to_string()); - if let Some(resolved) = self.resolve_through_reexports(&id) { - return Some(resolved); - } - } - - None - } - } - } - - /// Finds a class by its qualified name (e.g., "foo.bar.ClassName"). - fn find_class_by_qualified_name(&self, qualified: &str) -> Option { - // Split into module path and class name - let parts: Vec<&str> = qualified.split('.').collect(); - if parts.is_empty() { - return None; - } - - let class_name = parts.last().unwrap(); - - // Try progressively shorter module paths - for i in (0..parts.len() - 1).rev() { - let module_path = parts[..=i].join("."); - let id = ClassId::new(module_path, class_name.to_string()); - if let Some(resolved) = self.resolve_through_reexports(&id) { - return Some(resolved); + // Recursively resolve in case the class itself is re-exported + return self.resolve_class(&module_candidate, &remainder); } } + // Unable to resolve this class reference None } - - /// Returns the number of classes in the registry. - pub fn len(&self) -> usize { - self.classes.len() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::parser::ClassDefinition; - - #[test] - fn test_add_and_find_class() { - let registry = ClassRegistry::new(vec![ParsedFile { - module_path: "animals".to_string(), - classes: vec![ClassDefinition { - name: "Dog".to_string(), - module_path: "animals".to_string(), - file_path: PathBuf::from("animals.py"), - bases: vec![], - }], - imports: vec![], - is_package: false, - }]); - - // Find by name only - let matches = registry.find_by_name("Dog").unwrap(); - assert_eq!(matches.len(), 1); - assert_eq!(matches[0].class_name, "Dog"); - } - - #[test] - fn test_ambiguous_class_name() { - let registry = ClassRegistry::new(vec![ - ParsedFile { - module_path: "zoo".to_string(), - classes: vec![ClassDefinition { - name: "Animal".to_string(), - module_path: "zoo".to_string(), - file_path: PathBuf::from("zoo.py"), - bases: vec![], - }], - imports: vec![], - is_package: false, - }, - ParsedFile { - module_path: "farm".to_string(), - classes: vec![ClassDefinition { - name: "Animal".to_string(), - module_path: "farm".to_string(), - file_path: PathBuf::from("farm.py"), - bases: vec![], - }], - imports: vec![], - is_package: false, - }, - ]); - - // Should find both - let matches = registry.find_by_name("Animal").unwrap(); - assert_eq!(matches.len(), 2); - } - - #[test] - fn test_simple_reexport() { - use crate::parser::{Import, ParsedFile}; - - let registry = ClassRegistry::new(vec![ - // Define a class in base module - ParsedFile { - module_path: "mypackage.base".to_string(), - classes: vec![ClassDefinition { - name: "Animal".to_string(), - module_path: "mypackage.base".to_string(), - file_path: PathBuf::from("mypackage/base.py"), - bases: vec![], - }], - imports: vec![], - is_package: false, - }, - // Re-export it from __init__.py - ParsedFile { - module_path: "mypackage".to_string(), - classes: vec![], - imports: vec![Import::RelativeFrom { - level: 1, - module: Some("base".to_string()), - names: vec![("Animal".to_string(), None)], - }], - is_package: true, - }, - ]); - - // Should be able to find it via the re-exported path - let id = ClassId::new("mypackage".to_string(), "Animal".to_string()); - let resolved = registry.resolve_through_reexports(&id); - assert!(resolved.is_some()); - let resolved = resolved.unwrap(); - assert_eq!(resolved.module_path, "mypackage.base"); - assert_eq!(resolved.class_name, "Animal"); - } - - #[test] - fn test_transitive_reexport() { - use crate::parser::{Import, ParsedFile}; - - let registry = ClassRegistry::new(vec![ - // Define a class in _base module - ParsedFile { - module_path: "pkg._nodes._base".to_string(), - classes: vec![ClassDefinition { - name: "Node".to_string(), - module_path: "pkg._nodes._base".to_string(), - file_path: PathBuf::from("pkg/_nodes/_base.py"), - bases: vec![], - }], - imports: vec![], - is_package: false, - }, - // Re-export from _nodes/__init__.py - ParsedFile { - module_path: "pkg._nodes".to_string(), - classes: vec![], - imports: vec![Import::RelativeFrom { - level: 1, - module: Some("_base".to_string()), - names: vec![("Node".to_string(), None)], - }], - is_package: true, - }, - // Re-export from pkg/__init__.py - ParsedFile { - module_path: "pkg".to_string(), - classes: vec![], - imports: vec![Import::RelativeFrom { - level: 1, - module: Some("_nodes".to_string()), - names: vec![("Node".to_string(), None)], - }], - is_package: true, - }, - ]); - - // Should be able to find it via the top-level re-exported path - let id = ClassId::new("pkg".to_string(), "Node".to_string()); - let resolved = registry.resolve_through_reexports(&id); - assert!(resolved.is_some()); - let resolved = resolved.unwrap(); - assert_eq!(resolved.module_path, "pkg._nodes._base"); - assert_eq!(resolved.class_name, "Node"); - - // Should also work via the intermediate path - let id = ClassId::new("pkg._nodes".to_string(), "Node".to_string()); - let resolved = registry.resolve_through_reexports(&id); - assert!(resolved.is_some()); - let resolved = resolved.unwrap(); - assert_eq!(resolved.module_path, "pkg._nodes._base"); - } - - #[test] - fn test_reexport_with_inheritance() { - use crate::parser::{BaseClass, Import, ParsedFile}; - - let registry = ClassRegistry::new(vec![ - // Define Animal in base module - ParsedFile { - module_path: "animals.base".to_string(), - classes: vec![ClassDefinition { - name: "Animal".to_string(), - module_path: "animals.base".to_string(), - file_path: PathBuf::from("animals/base.py"), - bases: vec![], - }], - imports: vec![], - is_package: false, - }, - // Re-export Animal from animals/__init__.py - ParsedFile { - module_path: "animals".to_string(), - classes: vec![], - imports: vec![Import::RelativeFrom { - level: 1, - module: Some("base".to_string()), - names: vec![("Animal".to_string(), None)], - }], - is_package: true, - }, - // Define Dog that inherits from re-exported Animal - ParsedFile { - module_path: "pets".to_string(), - classes: vec![ClassDefinition { - name: "Dog".to_string(), - module_path: "pets".to_string(), - file_path: PathBuf::from("pets.py"), - bases: vec![BaseClass::Simple("Animal".to_string())], - }], - imports: vec![Import::From { - module: "animals".to_string(), - names: vec![("Animal".to_string(), None)], - }], - is_package: false, - }, - ]); - - // Resolve Dog's base class - let dog_info = registry - .get(&ClassId::new("pets".to_string(), "Dog".to_string())) - .unwrap(); - let base = &dog_info.bases[0]; - let resolved_base = registry.resolve_base(base, "pets"); - - assert!(resolved_base.is_some()); - let resolved_base = resolved_base.unwrap(); - assert_eq!(resolved_base.module_path, "animals.base"); - assert_eq!(resolved_base.class_name, "Animal"); - } } diff --git a/src/utils.rs b/src/utils.rs deleted file mode 100644 index d7712d5..0000000 --- a/src/utils.rs +++ /dev/null @@ -1,308 +0,0 @@ -//! Utility functions for module path and import resolution. - -use crate::parser::Import; - -/// Resolves a name to a fully qualified module path. -/// -/// Given a name used in the code and the imports in the file, determine -/// what module path it refers to. -/// -/// # Arguments -/// -/// * `name` - The name to resolve (e.g., "Animal") -/// * `imports` - The imports in the current file -/// * `current_module` - The current module path -/// * `is_package` - Whether the current module is a package (__init__.py) -/// -/// # Returns -/// -/// The fully qualified name, or None if it cannot be resolved. -pub fn resolve_name( - name: &str, - imports: &[Import], - current_module: &str, - is_package: bool, -) -> Option { - for import in imports { - match import { - Import::Module { module, alias } => { - // import foo.bar as baz OR import foo.bar - if alias.as_ref().unwrap_or(module) == name { - return Some(module.clone()); - } - } - Import::From { module, names } => { - // from foo import Bar [as Baz] - if let Some((n, _)) = names - .iter() - .find(|(n, alias)| alias.as_ref().unwrap_or(n) == name) - { - return Some(format!("{module}.{n}")); - } - } - Import::RelativeFrom { - level, - module: rel_module, - names, - } => { - // from .relative import Bar - if let Some((n, _)) = names - .iter() - .find(|(n, alias)| alias.as_ref().unwrap_or(n) == name) - { - let base = resolve_relative_import_base(current_module, *level, is_package)?; - return Some(match (base.is_empty(), rel_module) { - (true, None) => n.to_string(), - (true, Some(m)) => format!("{m}.{n}"), - (false, None) => format!("{base}.{n}"), - (false, Some(m)) => format!("{base}.{m}.{n}"), - }); - } - } - } - } - None -} - -/// Resolves a relative import to the base module path. -/// -/// This follows Python's relative import semantics as described in PEP 328. -/// -/// # Arguments -/// -/// * `current_module` - The current module path (e.g., "foo.bar.baz") -/// * `level` - Number of dots in the relative import (from Python AST) -/// * `is_package` - Whether the current module is a package (__init__.py file) -/// -/// # Returns -/// -/// The base module path to which the relative import is resolved. -pub fn resolve_relative_import_base( - current_module: &str, - level: usize, - is_package: bool, -) -> Option { - if level == 0 { - return None; // Not a relative import - } - - let parts: Vec<&str> = current_module.split('.').collect(); - - let base = if is_package { - // For packages (__init__.py files) - if level == 1 { - // Single dot means "this package" - current_module.to_string() - } else { - // Multiple dots: go up (level - 1) parent packages - let components_to_keep = parts.len().saturating_sub(level - 1); - if components_to_keep == 0 { - String::new() - } else { - parts[..components_to_keep].join(".") - } - } - } else { - // For regular modules - let components_to_keep = parts.len().saturating_sub(level); - if components_to_keep == 0 { - String::new() - } else { - parts[..components_to_keep].join(".") - } - }; - - Some(base) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_resolve_relative_import_base() { - // Package with level=1: stay in current package - // "foo.bar.baz" package (__init__.py) with level=1 stays at "foo.bar.baz" - assert_eq!( - resolve_relative_import_base("foo.bar.baz", 1, true), - Some("foo.bar.baz".to_string()) - ); - // Package with level=2: go up 1 parent - assert_eq!( - resolve_relative_import_base("foo.bar.baz", 2, true), - Some("foo.bar".to_string()) - ); - // Package with level=3: go up 2 parents - assert_eq!( - resolve_relative_import_base("foo.bar.baz", 3, true), - Some("foo".to_string()) - ); - - // Regular module with level=1: go to containing package - // "foo.bar.baz" module with level=1 goes to "foo.bar" - assert_eq!( - resolve_relative_import_base("foo.bar.baz", 1, false), - Some("foo.bar".to_string()) - ); - - // Single-component package with level=1: stay in package - assert_eq!( - resolve_relative_import_base("mypackage", 1, true), - Some("mypackage".to_string()) - ); - // Single-component module with level=1: go to empty (top level) - assert_eq!( - resolve_relative_import_base("mypackage", 1, false), - Some(String::new()) - ); - } - - /// Test case for resolve_name function - struct Case { - name: &'static str, - imports: Vec, - current_module: &'static str, - is_package: bool, - expected: Option<&'static str>, - } - - #[yare::parameterized( - from_import_direct = { - Case { - name: "Dog", - imports: vec![Import::From { - module: "animals".to_string(), - names: vec![("Dog".to_string(), None)], - }], - current_module: "test.module", - is_package: false, - expected: Some("animals.Dog"), - } - }, - from_import_with_alias = { - Case { - name: "Kitty", - imports: vec![Import::From { - module: "pets".to_string(), - names: vec![("Cat".to_string(), Some("Kitty".to_string()))], - }], - current_module: "test.module", - is_package: false, - expected: Some("pets.Cat"), - } - }, - from_import_alias_no_match = { - Case { - name: "Cat", - imports: vec![Import::From { - module: "pets".to_string(), - names: vec![("Cat".to_string(), Some("Kitty".to_string()))], - }], - current_module: "test.module", - is_package: false, - expected: None, - } - }, - module_import = { - Case { - name: "zoo", - imports: vec![Import::Module { - module: "zoo".to_string(), - alias: None, - }], - current_module: "test.module", - is_package: false, - expected: Some("zoo"), - } - }, - relative_from_module_import = { - Case { - name: "Animal", - imports: vec![Import::RelativeFrom { - level: 1, - module: Some("base".to_string()), - names: vec![("Animal".to_string(), None)], - }], - current_module: "mypackage.dog", - is_package: false, - expected: Some("mypackage.base.Animal"), - } - }, - relative_from_current_package = { - Case { - name: "Cat", - imports: vec![Import::RelativeFrom { - level: 1, - module: None, - names: vec![("Cat".to_string(), None)], - }], - current_module: "mypackage.dog", - is_package: false, - expected: Some("mypackage.Cat"), - } - }, - relative_two_levels_up = { - Case { - name: "Helper", - imports: vec![Import::RelativeFrom { - level: 2, - module: Some("utils".to_string()), - names: vec![("Helper".to_string(), None)], - }], - current_module: "pkg.sub.module", - is_package: false, - expected: Some("pkg.utils.Helper"), - } - }, - relative_three_levels_up = { - Case { - name: "Config", - imports: vec![Import::RelativeFrom { - level: 3, - module: None, - names: vec![("Config".to_string(), None)], - }], - current_module: "pkg.sub.module", - is_package: false, - expected: Some("Config"), - } - }, - relative_from_init = { - Case { - name: "Node", - imports: vec![Import::RelativeFrom { - level: 1, - module: Some("_core".to_string()), - names: vec![("Node".to_string(), None)], - }], - current_module: "mypackage", - is_package: true, - expected: Some("mypackage._core.Node"), - } - }, - relative_from_toplevel = { - Case { - name: "Foo", - imports: vec![Import::RelativeFrom { - level: 1, - module: None, - names: vec![("Foo".to_string(), None)], - }], - current_module: "toplevel", - is_package: false, - expected: Some("Foo"), - } - }, - )] - fn test_resolve_name(case: Case) { - assert_eq!( - resolve_name( - case.name, - &case.imports, - case.current_module, - case.is_package - ), - case.expected.map(|s| s.to_string()) - ); - } -} diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 48f7e31..793a13d 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -661,3 +661,95 @@ class C(testpkg.b.testpkg.a.A): ... temp.close().unwrap(); } + +#[test] +fn test_reexport_not_via_init() { + let temp = assert_fs::TempDir::new().unwrap(); + + temp.child("testpkg/a.py") + .write_str( + r#" +class A: ... +"#, + ) + .unwrap(); + + temp.child("testpkg/b.py") + .write_str( + r#" +from testpkg.a import A + +class B(A): ... +"#, + ) + .unwrap(); + + temp.child("testpkg/c.py") + .write_str( + r#" +from testpkg.b import A + +class C(A): ... +"#, + ) + .unwrap(); + + let mut cmd = Command::cargo_bin("pysubclasses").unwrap(); + cmd.arg("A") + .arg("--directory") + .arg(temp.path()) + .assert() + .success() + .stdout(predicate::str::contains("Found 2 subclass(es) of 'A'")) + .stdout(predicate::str::contains("B")) + .stdout(predicate::str::contains("C")); + + temp.close().unwrap(); +} + +#[test] +fn test_reexport_not_via_init_pass_reexport_module() { + let temp = assert_fs::TempDir::new().unwrap(); + + temp.child("testpkg/a.py") + .write_str( + r#" +class A: ... +"#, + ) + .unwrap(); + + temp.child("testpkg/b.py") + .write_str( + r#" +from testpkg.a import A + +class B(A): ... +"#, + ) + .unwrap(); + + temp.child("testpkg/c.py") + .write_str( + r#" +from testpkg.b import A + +class C(A): ... +"#, + ) + .unwrap(); + + let mut cmd = Command::cargo_bin("pysubclasses").unwrap(); + cmd.arg("A") + .arg("--directory") + .arg(temp.path()) + .arg("--module") + .arg("testpkg.b") + .assert() + .success() + .stdout(predicate::str::contains("Found 2 subclass(es) of 'A'")) + .stdout(predicate::str::contains("B")) + .stdout(predicate::str::contains("C")); + + temp.close().unwrap(); +}