From 5b4b3a085d8f7212fd3b9deb95ed9239bde6e608 Mon Sep 17 00:00:00 2001 From: Peter Byfield Date: Sat, 25 Oct 2025 17:24:29 +0200 Subject: [PATCH 1/9] WIP --- src/graph.rs | 266 +++------------------- src/lib.rs | 109 +-------- src/registry.rs | 582 ------------------------------------------------ src/utils.rs | 308 ------------------------- 4 files changed, 32 insertions(+), 1233 deletions(-) delete mode 100644 src/registry.rs delete mode 100644 src/utils.rs diff --git a/src/graph.rs b/src/graph.rs index fbe8c9e..44381f6 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -1,46 +1,44 @@ //! Inheritance graph construction and traversal. -use std::collections::{HashMap, HashSet, VecDeque}; +use std::{ + collections::{HashMap, HashSet}, + path::PathBuf, +}; -use crate::registry::{ClassId, ClassRegistry}; +use crate::parser::ParsedFile; + +pub type ModuleName = String; + +#[derive(Debug)] +pub struct ModuleMetadata { + pub file_path: PathBuf, + pub is_package: bool, +} + +#[derive(Debug)] +pub struct ClassId { + pub module: ModuleName, + pub name: String, +} /// An inheritance graph mapping classes to their children. pub struct InheritanceGraph { - /// Map from parent ClassId to child ClassIds - children: HashMap>, + pub modules: HashMap, + pub classes: HashMap>, + pub children: HashMap>, } impl InheritanceGraph { - /// Builds an inheritance graph from a class registry. + /// Builds an inheritance graph from parsed files. /// /// # Arguments /// - /// * `registry` - The class registry containing all classes + /// * `parsed_files` - The parsed python files. /// /// # Returns /// /// An inheritance graph with parent-child relationships. - pub fn build(registry: &ClassRegistry) -> Self { - let mut children: HashMap> = HashMap::new(); - - // For each class, resolve its base classes and add edges - for class_id in registry.all_class_ids() { - if let Some(info) = registry.get(&class_id) { - for base in &info.bases { - // Try to resolve the base class - if let Some(parent_id) = registry.resolve_base(base, &class_id.module_path) { - // Add edge from parent to child - children - .entry(parent_id) - .or_default() - .push(class_id.clone()); - } - } - } - } - - Self { children } - } + pub fn build(parsed_files: &ParsedFile) -> Self {} /// Finds all transitive subclasses of a given class. /// @@ -53,217 +51,5 @@ impl InheritanceGraph { /// # Returns /// /// A vector of all transitive subclasses (not including the root class itself). - pub fn find_all_subclasses(&self, root: &ClassId) -> Vec { - let mut result = Vec::new(); - let mut visited = HashSet::new(); - let mut queue = VecDeque::new(); - - // Start with the root's immediate children - if let Some(children) = self.children.get(root) { - for child in children { - queue.push_back(child.clone()); - } - } - - // BFS traversal - while let Some(class_id) = queue.pop_front() { - // Skip if already visited (handles potential cycles) - if !visited.insert(class_id.clone()) { - continue; - } - - result.push(class_id.clone()); - - // Add this class's children to the queue - if let Some(children) = self.children.get(&class_id) { - for child in children { - if !visited.contains(child) { - queue.push_back(child.clone()); - } - } - } - } - - result - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::parser::{BaseClass, ClassDefinition}; - use crate::registry::ClassRegistry; - use std::path::PathBuf; - - #[test] - fn test_simple_inheritance() { - let registry = ClassRegistry::new(vec![ - // Animal (base class) - crate::parser::ParsedFile { - module_path: "animals".to_string(), - classes: vec![ClassDefinition { - name: "Animal".to_string(), - module_path: "animals".to_string(), - file_path: PathBuf::from("animals.py"), - bases: vec![], - }], - imports: vec![], - is_package: false, - }, - // Dog(Animal) - crate::parser::ParsedFile { - module_path: "animals".to_string(), - classes: vec![ClassDefinition { - name: "Dog".to_string(), - module_path: "animals".to_string(), - file_path: PathBuf::from("animals.py"), - bases: vec![BaseClass::Simple("Animal".to_string())], - }], - imports: vec![], - is_package: false, - }, - ]); - - let graph = InheritanceGraph::build(®istry); - - let animal_id = ClassId::new("animals".to_string(), "Animal".to_string()); - let subclasses = graph.find_all_subclasses(&animal_id); - - assert_eq!(subclasses.len(), 1); - assert_eq!(subclasses[0].class_name, "Dog"); - } - - #[test] - fn test_transitive_inheritance() { - let registry = ClassRegistry::new(vec![ - // Animal - crate::parser::ParsedFile { - module_path: "base".to_string(), - classes: vec![ClassDefinition { - name: "Animal".to_string(), - module_path: "base".to_string(), - file_path: PathBuf::from("base.py"), - bases: vec![], - }], - imports: vec![], - is_package: false, - }, - // Mammal(Animal) - crate::parser::ParsedFile { - module_path: "base".to_string(), - classes: vec![ClassDefinition { - name: "Mammal".to_string(), - module_path: "base".to_string(), - file_path: PathBuf::from("base.py"), - bases: vec![BaseClass::Simple("Animal".to_string())], - }], - imports: vec![], - is_package: false, - }, - // Dog(Mammal) - crate::parser::ParsedFile { - module_path: "base".to_string(), - classes: vec![ClassDefinition { - name: "Dog".to_string(), - module_path: "base".to_string(), - file_path: PathBuf::from("base.py"), - bases: vec![BaseClass::Simple("Mammal".to_string())], - }], - imports: vec![], - is_package: false, - }, - ]); - - let graph = InheritanceGraph::build(®istry); - - let animal_id = ClassId::new("base".to_string(), "Animal".to_string()); - let subclasses = graph.find_all_subclasses(&animal_id); - - // Should find both Mammal and Dog - assert_eq!(subclasses.len(), 2); - - let names: HashSet<_> = subclasses.iter().map(|c| c.class_name.as_str()).collect(); - assert!(names.contains("Mammal")); - assert!(names.contains("Dog")); - } - - #[test] - fn test_multiple_inheritance() { - let registry = ClassRegistry::new(vec![ - // Animal - crate::parser::ParsedFile { - module_path: "base".to_string(), - classes: vec![ClassDefinition { - name: "Animal".to_string(), - module_path: "base".to_string(), - file_path: PathBuf::from("base.py"), - bases: vec![], - }], - imports: vec![], - is_package: false, - }, - // Pet - crate::parser::ParsedFile { - module_path: "base".to_string(), - classes: vec![ClassDefinition { - name: "Pet".to_string(), - module_path: "base".to_string(), - file_path: PathBuf::from("base.py"), - bases: vec![], - }], - imports: vec![], - is_package: false, - }, - // Dog(Animal, Pet) - crate::parser::ParsedFile { - module_path: "base".to_string(), - classes: vec![ClassDefinition { - name: "Dog".to_string(), - module_path: "base".to_string(), - file_path: PathBuf::from("base.py"), - bases: vec![ - BaseClass::Simple("Animal".to_string()), - BaseClass::Simple("Pet".to_string()), - ], - }], - imports: vec![], - is_package: false, - }, - ]); - - let graph = InheritanceGraph::build(®istry); - - // Dog should be a subclass of both Animal and Pet - let animal_id = ClassId::new("base".to_string(), "Animal".to_string()); - let animal_subclasses = graph.find_all_subclasses(&animal_id); - assert_eq!(animal_subclasses.len(), 1); - assert_eq!(animal_subclasses[0].class_name, "Dog"); - - let pet_id = ClassId::new("base".to_string(), "Pet".to_string()); - let pet_subclasses = graph.find_all_subclasses(&pet_id); - assert_eq!(pet_subclasses.len(), 1); - assert_eq!(pet_subclasses[0].class_name, "Dog"); - } - - #[test] - fn test_no_subclasses() { - let registry = ClassRegistry::new(vec![crate::parser::ParsedFile { - module_path: "base".to_string(), - classes: vec![ClassDefinition { - name: "Animal".to_string(), - module_path: "base".to_string(), - file_path: PathBuf::from("base.py"), - bases: vec![], - }], - imports: vec![], - is_package: false, - }]); - - let graph = InheritanceGraph::build(®istry); - - let animal_id = ClassId::new("base".to_string(), "Animal".to_string()); - let subclasses = graph.find_all_subclasses(&animal_id); - - assert_eq!(subclasses.len(), 0); - } + pub fn find_all_subclasses(&self, root: &ClassId) -> Vec {} } diff --git a/src/lib.rs b/src/lib.rs index 793828a..71d3ec4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,18 +21,15 @@ //! # } //! ``` -pub(crate) mod discovery; -pub(crate) mod error; -pub(crate) mod graph; -pub(crate) mod parser; -pub(crate) mod registry; -pub(crate) mod utils; +pub mod discovery; +pub mod error; +pub mod graph; +pub mod parser; use std::path::PathBuf; pub use error::{Error, Result}; use graph::InheritanceGraph; -use registry::{ClassId, ClassRegistry}; /// A reference to a Python class. #[derive(Debug, Clone, PartialEq, Eq)] @@ -74,7 +71,6 @@ impl ClassReference { /// ``` pub struct SubclassFinder { root_dir: PathBuf, - registry: ClassRegistry, graph: InheritanceGraph, } @@ -111,17 +107,10 @@ impl SubclassFinder { }) .collect(); - // Build registry from parsed files - let registry = ClassRegistry::new(parsed_files); - // Build the inheritance graph - let graph = graph::InheritanceGraph::build(®istry); + let graph = graph::InheritanceGraph::build(&parsed_files); - Ok(Self { - root_dir, - registry, - graph, - }) + Ok(Self { root_dir, graph }) } /// Finds all transitive subclasses of a given class. @@ -165,91 +154,5 @@ impl SubclassFinder { class_name: &str, module_path: Option<&str>, ) -> Result> { - // Find the target class - let target_id = self.find_target_class(class_name, module_path)?; - - // Find all subclasses - let subclass_ids = self.graph.find_all_subclasses(&target_id); - - // Convert to ClassReferences - let mut results: Vec = subclass_ids - .into_iter() - .filter_map(|id| { - self.registry.get(&id).map(|info| ClassReference { - class_name: id.class_name.clone(), - module_path: id.module_path.clone(), - file_path: info.file_path.clone(), - }) - }) - .collect(); - - // Sort by module path for consistent output - results.sort_by(|a, b| { - a.module_path - .cmp(&b.module_path) - .then(a.class_name.cmp(&b.class_name)) - }); - - Ok(results) - } - - /// Finds the ClassId for the target class. - fn find_target_class(&self, class_name: &str, module_path: Option<&str>) -> Result { - // If module path provided, look for exact match or re-export - if let Some(module) = module_path { - let id = ClassId::new(module.to_string(), class_name.to_string()); - - // Try direct lookup first - if self.registry.get(&id).is_some() { - return Ok(id); - } - - // Try resolving through re-exports - if let Some(resolved) = self.registry.resolve_class_through_reexports(&id) { - return Ok(resolved); - } - - return Err(Error::ClassNotFound { - name: class_name.to_string(), - module_path: Some(module.to_string()), - }); - } - - // Otherwise find by name - let matches = self - .registry - .find_by_name(class_name) - .filter(|ids| !ids.is_empty()) - .ok_or_else(|| Error::ClassNotFound { - name: class_name.to_string(), - module_path: None, - })?; - - match matches.len() { - 1 => Ok(matches[0].clone()), - _ => { - let candidates = matches.iter().map(|id| id.module_path.clone()).collect(); - Err(Error::AmbiguousClassName { - name: class_name.to_string(), - candidates, - }) - } - } - } - - /// Returns the number of classes found in the codebase. - pub fn class_count(&self) -> usize { - self.registry.len() } - - /// Returns the root directory being searched. - pub fn root_dir(&self) -> &PathBuf { - &self.root_dir - } -} - -#[cfg(test)] -mod tests { - - // Integration tests will be in tests/ directory } diff --git a/src/registry.rs b/src/registry.rs deleted file mode 100644 index 0e9b235..0000000 --- a/src/registry.rs +++ /dev/null @@ -1,582 +0,0 @@ -//! Class registry for tracking class definitions and resolving references. - -use std::collections::HashMap; -use std::path::PathBuf; - -use crate::parser::{BaseClass, ClassDefinition, Import, ParsedFile}; - -/// A unique identifier for a class. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct ClassId { - pub module_path: String, - pub class_name: String, -} - -impl ClassId { - pub fn new(module_path: String, class_name: String) -> Self { - Self { - module_path, - class_name, - } - } -} - -/// Information about where a class is defined. -#[derive(Debug, Clone)] -pub struct ClassInfo { - pub file_path: PathBuf, - pub bases: Vec, -} - -/// A registry of all classes found in a codebase. -#[derive(Default)] -pub struct ClassRegistry { - /// Map from ClassId to ClassInfo - classes: HashMap, - - /// Map from simple class name to all ClassIds with that name - /// (for ambiguity detection) - name_index: HashMap>, - - /// Map from module path to imports in that module - imports: HashMap>, - - /// Map from ClassId (re-exported location) to ClassId (original location) - /// E.g., foo.Bar -> foo._internal.Bar - re_exports: HashMap, - - /// Set of module paths that are packages (__init__.py files) - packages: std::collections::HashSet, -} - -impl ClassRegistry { - /// Creates a new class registry from a vector of parsed files. - /// - /// This will build the registry and resolve all re-exports. - pub fn new(parsed_files: Vec) -> Self { - let mut registry = Self::default(); - - // Add all files to the registry - for parsed in parsed_files { - registry.add_file(parsed); - } - - // Build re-export mappings now that all classes are registered - registry.build_reexports(); - - registry - } - - /// Adds a parsed file to the registry. - fn add_file(&mut self, parsed: ParsedFile) { - // Store imports for this module - self.imports - .insert(parsed.module_path.clone(), parsed.imports.clone()); - - // Track if this is a package - if parsed.is_package { - self.packages.insert(parsed.module_path.clone()); - } - - // Add all classes - for class in parsed.classes { - self.add_class(class); - } - - // Track re-exports: when we import a class, it may be re-exported - // We'll do a second pass after all files are added - } - - /// Second pass: build re-export mappings after all classes are registered. - /// This should be called after all files have been added. - /// - /// This may need multiple iterations to resolve transitive re-exports - /// (e.g., A re-exports from B, B re-exports from C). - fn build_reexports(&mut self) { - // Iterate until no new re-exports are added (fixed-point iteration) - let mut changed = true; - while changed { - changed = false; - let mut reexports_to_add = Vec::new(); - - for (module_path, imports) in &self.imports { - for import in imports { - match import { - Import::From { module, names } => { - for (name, alias) in names { - // The exported name is the alias if present, otherwise the original name - let exported_name = alias.as_ref().unwrap_or(name); - - // Try to find the original class - let original_module = module.clone(); - let original_id = - ClassId::new(original_module.clone(), name.clone()); - - // Check if the class exists directly or as a re-export - if self.classes.contains_key(&original_id) - || self.re_exports.contains_key(&original_id) - { - // Register the re-export - let reexport_id = - ClassId::new(module_path.clone(), exported_name.clone()); - reexports_to_add.push((reexport_id, original_id)); - } - } - } - Import::RelativeFrom { - level, - module: rel_module, - names, - } => { - // Resolve the relative import - let is_package = self.packages.contains(module_path); - if let Some(base) = crate::utils::resolve_relative_import_base( - module_path, - *level, - is_package, - ) { - for (name, alias) in names { - let exported_name = alias.as_ref().unwrap_or(name); - - // Build the full module path - let original_module = if let Some(m) = rel_module { - if base.is_empty() { - m.clone() - } else { - format!("{base}.{m}") - } - } else { - base.clone() - }; - - let original_id = ClassId::new(original_module, name.clone()); - - // Check if the class exists directly or as a re-export - // (we'll resolve the chain later) - if self.classes.contains_key(&original_id) - || self.re_exports.contains_key(&original_id) - { - let reexport_id = ClassId::new( - module_path.clone(), - exported_name.clone(), - ); - reexports_to_add.push((reexport_id, original_id)); - } - } - } - } - Import::Module { .. } => { - // Module imports don't re-export classes - } - } - } - } - - // Add all the re-exports - for (reexport_id, original_id) in reexports_to_add { - // Only add if it's new - if let std::collections::hash_map::Entry::Vacant(e) = - self.re_exports.entry(reexport_id) - { - e.insert(original_id); - changed = true; - } - } - } - } - - /// Adds a class definition to the registry. - fn add_class(&mut self, class: ClassDefinition) { - let id = ClassId::new(class.module_path, class.name.clone()); - - let info = ClassInfo { - file_path: class.file_path, - bases: class.bases, - }; - - // Add to name index - self.name_index - .entry(class.name) - .or_default() - .push(id.clone()); - - // Add to main registry - self.classes.insert(id, info); - } - - /// Finds all classes with a given name. - pub fn find_by_name(&self, name: &str) -> Option<&Vec> { - self.name_index.get(name) - } - - /// Gets class info by ClassId. - pub fn get(&self, id: &ClassId) -> Option<&ClassInfo> { - self.classes.get(id) - } - - /// Resolves a ClassId through re-exports to find the canonical ClassId. - /// If the given ClassId is a re-export, follows the chain to find the original. - /// Otherwise returns the input ClassId if it exists in the registry. - /// - /// This is the public API for resolving re-exports. - pub fn resolve_class_through_reexports(&self, id: &ClassId) -> Option { - self.resolve_through_reexports(id) - } - - /// Internal method to resolve through re-exports. - fn resolve_through_reexports(&self, id: &ClassId) -> Option { - let mut current = id.clone(); - let mut visited = std::collections::HashSet::new(); - - // Follow the re-export chain - loop { - // Prevent infinite loops - if visited.contains(¤t) { - break; - } - visited.insert(current.clone()); - - // Check if this is a re-export - if let Some(original) = self.re_exports.get(¤t) { - current = original.clone(); - } else { - // No more re-exports, check if this exists - if self.classes.contains_key(¤t) { - return Some(current); - } - break; - } - } - - None - } - - /// Returns all class IDs in the registry. - pub fn all_class_ids(&self) -> Vec { - self.classes.keys().cloned().collect() - } - - /// Resolves a base class reference to a ClassId. - /// - /// Given a base class reference in a class definition, determine which - /// actual class it refers to based on imports and available classes. - pub fn resolve_base(&self, base: &BaseClass, context_module: &str) -> Option { - match base { - BaseClass::Simple(name) => { - // Look up the name in imports for this module - if let Some(imports) = self.imports.get(context_module) { - let is_package = self.packages.contains(context_module); - if let Some(qualified) = - crate::utils::resolve_name(name, imports, context_module, is_package) - { - // The import tells us it's from a specific module - // Try to find a class with this name in that module - return self.find_class_by_qualified_name(&qualified); - } - } - - // Not imported - might be in the same module - let id = ClassId::new(context_module.to_string(), name.clone()); - if let Some(resolved) = self.resolve_through_reexports(&id) { - return Some(resolved); - } - - // Try to find by name alone (if unambiguous) - let matches = self.find_by_name(name)?; - if matches.len() == 1 { - return Some(matches[0].clone()); - } - - None - } - BaseClass::Attribute(parts) => { - // For attribute references like `module.Class`, we need to figure out - // what `module` refers to based on imports. - - if parts.len() < 2 { - return None; - } - - // The last part is the class name, everything before is the module/package - let class_name = parts.last().unwrap(); - - // Check if this is a fully qualified reference - // Try progressively shorter module paths - for i in (0..parts.len() - 1).rev() { - let module_path = parts[..=i].join("."); - let _remaining_parts = &parts[i + 1..]; - - // Check if this module path matches an import - if let Some(imports) = self.imports.get(context_module) { - let is_package = self.packages.contains(context_module); - if let Some(resolved) = crate::utils::resolve_name( - &parts[0], - imports, - context_module, - is_package, - ) { - // Build the full path - let full_module = if parts.len() > 2 { - format!("{}.{}", resolved, parts[1..parts.len() - 1].join(".")) - } else { - resolved - }; - - let id = ClassId::new(full_module, class_name.to_string()); - if let Some(resolved) = self.resolve_through_reexports(&id) { - return Some(resolved); - } - } - } - - // Try as a direct module path - let id = ClassId::new(module_path, class_name.to_string()); - if let Some(resolved) = self.resolve_through_reexports(&id) { - return Some(resolved); - } - } - - None - } - } - } - - /// Finds a class by its qualified name (e.g., "foo.bar.ClassName"). - fn find_class_by_qualified_name(&self, qualified: &str) -> Option { - // Split into module path and class name - let parts: Vec<&str> = qualified.split('.').collect(); - if parts.is_empty() { - return None; - } - - let class_name = parts.last().unwrap(); - - // Try progressively shorter module paths - for i in (0..parts.len() - 1).rev() { - let module_path = parts[..=i].join("."); - let id = ClassId::new(module_path, class_name.to_string()); - if let Some(resolved) = self.resolve_through_reexports(&id) { - return Some(resolved); - } - } - - None - } - - /// Returns the number of classes in the registry. - pub fn len(&self) -> usize { - self.classes.len() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::parser::ClassDefinition; - - #[test] - fn test_add_and_find_class() { - let registry = ClassRegistry::new(vec![ParsedFile { - module_path: "animals".to_string(), - classes: vec![ClassDefinition { - name: "Dog".to_string(), - module_path: "animals".to_string(), - file_path: PathBuf::from("animals.py"), - bases: vec![], - }], - imports: vec![], - is_package: false, - }]); - - // Find by name only - let matches = registry.find_by_name("Dog").unwrap(); - assert_eq!(matches.len(), 1); - assert_eq!(matches[0].class_name, "Dog"); - } - - #[test] - fn test_ambiguous_class_name() { - let registry = ClassRegistry::new(vec![ - ParsedFile { - module_path: "zoo".to_string(), - classes: vec![ClassDefinition { - name: "Animal".to_string(), - module_path: "zoo".to_string(), - file_path: PathBuf::from("zoo.py"), - bases: vec![], - }], - imports: vec![], - is_package: false, - }, - ParsedFile { - module_path: "farm".to_string(), - classes: vec![ClassDefinition { - name: "Animal".to_string(), - module_path: "farm".to_string(), - file_path: PathBuf::from("farm.py"), - bases: vec![], - }], - imports: vec![], - is_package: false, - }, - ]); - - // Should find both - let matches = registry.find_by_name("Animal").unwrap(); - assert_eq!(matches.len(), 2); - } - - #[test] - fn test_simple_reexport() { - use crate::parser::{Import, ParsedFile}; - - let registry = ClassRegistry::new(vec![ - // Define a class in base module - ParsedFile { - module_path: "mypackage.base".to_string(), - classes: vec![ClassDefinition { - name: "Animal".to_string(), - module_path: "mypackage.base".to_string(), - file_path: PathBuf::from("mypackage/base.py"), - bases: vec![], - }], - imports: vec![], - is_package: false, - }, - // Re-export it from __init__.py - ParsedFile { - module_path: "mypackage".to_string(), - classes: vec![], - imports: vec![Import::RelativeFrom { - level: 1, - module: Some("base".to_string()), - names: vec![("Animal".to_string(), None)], - }], - is_package: true, - }, - ]); - - // Should be able to find it via the re-exported path - let id = ClassId::new("mypackage".to_string(), "Animal".to_string()); - let resolved = registry.resolve_through_reexports(&id); - assert!(resolved.is_some()); - let resolved = resolved.unwrap(); - assert_eq!(resolved.module_path, "mypackage.base"); - assert_eq!(resolved.class_name, "Animal"); - } - - #[test] - fn test_transitive_reexport() { - use crate::parser::{Import, ParsedFile}; - - let registry = ClassRegistry::new(vec![ - // Define a class in _base module - ParsedFile { - module_path: "pkg._nodes._base".to_string(), - classes: vec![ClassDefinition { - name: "Node".to_string(), - module_path: "pkg._nodes._base".to_string(), - file_path: PathBuf::from("pkg/_nodes/_base.py"), - bases: vec![], - }], - imports: vec![], - is_package: false, - }, - // Re-export from _nodes/__init__.py - ParsedFile { - module_path: "pkg._nodes".to_string(), - classes: vec![], - imports: vec![Import::RelativeFrom { - level: 1, - module: Some("_base".to_string()), - names: vec![("Node".to_string(), None)], - }], - is_package: true, - }, - // Re-export from pkg/__init__.py - ParsedFile { - module_path: "pkg".to_string(), - classes: vec![], - imports: vec![Import::RelativeFrom { - level: 1, - module: Some("_nodes".to_string()), - names: vec![("Node".to_string(), None)], - }], - is_package: true, - }, - ]); - - // Should be able to find it via the top-level re-exported path - let id = ClassId::new("pkg".to_string(), "Node".to_string()); - let resolved = registry.resolve_through_reexports(&id); - assert!(resolved.is_some()); - let resolved = resolved.unwrap(); - assert_eq!(resolved.module_path, "pkg._nodes._base"); - assert_eq!(resolved.class_name, "Node"); - - // Should also work via the intermediate path - let id = ClassId::new("pkg._nodes".to_string(), "Node".to_string()); - let resolved = registry.resolve_through_reexports(&id); - assert!(resolved.is_some()); - let resolved = resolved.unwrap(); - assert_eq!(resolved.module_path, "pkg._nodes._base"); - } - - #[test] - fn test_reexport_with_inheritance() { - use crate::parser::{BaseClass, Import, ParsedFile}; - - let registry = ClassRegistry::new(vec![ - // Define Animal in base module - ParsedFile { - module_path: "animals.base".to_string(), - classes: vec![ClassDefinition { - name: "Animal".to_string(), - module_path: "animals.base".to_string(), - file_path: PathBuf::from("animals/base.py"), - bases: vec![], - }], - imports: vec![], - is_package: false, - }, - // Re-export Animal from animals/__init__.py - ParsedFile { - module_path: "animals".to_string(), - classes: vec![], - imports: vec![Import::RelativeFrom { - level: 1, - module: Some("base".to_string()), - names: vec![("Animal".to_string(), None)], - }], - is_package: true, - }, - // Define Dog that inherits from re-exported Animal - ParsedFile { - module_path: "pets".to_string(), - classes: vec![ClassDefinition { - name: "Dog".to_string(), - module_path: "pets".to_string(), - file_path: PathBuf::from("pets.py"), - bases: vec![BaseClass::Simple("Animal".to_string())], - }], - imports: vec![Import::From { - module: "animals".to_string(), - names: vec![("Animal".to_string(), None)], - }], - is_package: false, - }, - ]); - - // Resolve Dog's base class - let dog_info = registry - .get(&ClassId::new("pets".to_string(), "Dog".to_string())) - .unwrap(); - let base = &dog_info.bases[0]; - let resolved_base = registry.resolve_base(base, "pets"); - - assert!(resolved_base.is_some()); - let resolved_base = resolved_base.unwrap(); - assert_eq!(resolved_base.module_path, "animals.base"); - assert_eq!(resolved_base.class_name, "Animal"); - } -} diff --git a/src/utils.rs b/src/utils.rs deleted file mode 100644 index d7712d5..0000000 --- a/src/utils.rs +++ /dev/null @@ -1,308 +0,0 @@ -//! Utility functions for module path and import resolution. - -use crate::parser::Import; - -/// Resolves a name to a fully qualified module path. -/// -/// Given a name used in the code and the imports in the file, determine -/// what module path it refers to. -/// -/// # Arguments -/// -/// * `name` - The name to resolve (e.g., "Animal") -/// * `imports` - The imports in the current file -/// * `current_module` - The current module path -/// * `is_package` - Whether the current module is a package (__init__.py) -/// -/// # Returns -/// -/// The fully qualified name, or None if it cannot be resolved. -pub fn resolve_name( - name: &str, - imports: &[Import], - current_module: &str, - is_package: bool, -) -> Option { - for import in imports { - match import { - Import::Module { module, alias } => { - // import foo.bar as baz OR import foo.bar - if alias.as_ref().unwrap_or(module) == name { - return Some(module.clone()); - } - } - Import::From { module, names } => { - // from foo import Bar [as Baz] - if let Some((n, _)) = names - .iter() - .find(|(n, alias)| alias.as_ref().unwrap_or(n) == name) - { - return Some(format!("{module}.{n}")); - } - } - Import::RelativeFrom { - level, - module: rel_module, - names, - } => { - // from .relative import Bar - if let Some((n, _)) = names - .iter() - .find(|(n, alias)| alias.as_ref().unwrap_or(n) == name) - { - let base = resolve_relative_import_base(current_module, *level, is_package)?; - return Some(match (base.is_empty(), rel_module) { - (true, None) => n.to_string(), - (true, Some(m)) => format!("{m}.{n}"), - (false, None) => format!("{base}.{n}"), - (false, Some(m)) => format!("{base}.{m}.{n}"), - }); - } - } - } - } - None -} - -/// Resolves a relative import to the base module path. -/// -/// This follows Python's relative import semantics as described in PEP 328. -/// -/// # Arguments -/// -/// * `current_module` - The current module path (e.g., "foo.bar.baz") -/// * `level` - Number of dots in the relative import (from Python AST) -/// * `is_package` - Whether the current module is a package (__init__.py file) -/// -/// # Returns -/// -/// The base module path to which the relative import is resolved. -pub fn resolve_relative_import_base( - current_module: &str, - level: usize, - is_package: bool, -) -> Option { - if level == 0 { - return None; // Not a relative import - } - - let parts: Vec<&str> = current_module.split('.').collect(); - - let base = if is_package { - // For packages (__init__.py files) - if level == 1 { - // Single dot means "this package" - current_module.to_string() - } else { - // Multiple dots: go up (level - 1) parent packages - let components_to_keep = parts.len().saturating_sub(level - 1); - if components_to_keep == 0 { - String::new() - } else { - parts[..components_to_keep].join(".") - } - } - } else { - // For regular modules - let components_to_keep = parts.len().saturating_sub(level); - if components_to_keep == 0 { - String::new() - } else { - parts[..components_to_keep].join(".") - } - }; - - Some(base) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_resolve_relative_import_base() { - // Package with level=1: stay in current package - // "foo.bar.baz" package (__init__.py) with level=1 stays at "foo.bar.baz" - assert_eq!( - resolve_relative_import_base("foo.bar.baz", 1, true), - Some("foo.bar.baz".to_string()) - ); - // Package with level=2: go up 1 parent - assert_eq!( - resolve_relative_import_base("foo.bar.baz", 2, true), - Some("foo.bar".to_string()) - ); - // Package with level=3: go up 2 parents - assert_eq!( - resolve_relative_import_base("foo.bar.baz", 3, true), - Some("foo".to_string()) - ); - - // Regular module with level=1: go to containing package - // "foo.bar.baz" module with level=1 goes to "foo.bar" - assert_eq!( - resolve_relative_import_base("foo.bar.baz", 1, false), - Some("foo.bar".to_string()) - ); - - // Single-component package with level=1: stay in package - assert_eq!( - resolve_relative_import_base("mypackage", 1, true), - Some("mypackage".to_string()) - ); - // Single-component module with level=1: go to empty (top level) - assert_eq!( - resolve_relative_import_base("mypackage", 1, false), - Some(String::new()) - ); - } - - /// Test case for resolve_name function - struct Case { - name: &'static str, - imports: Vec, - current_module: &'static str, - is_package: bool, - expected: Option<&'static str>, - } - - #[yare::parameterized( - from_import_direct = { - Case { - name: "Dog", - imports: vec![Import::From { - module: "animals".to_string(), - names: vec![("Dog".to_string(), None)], - }], - current_module: "test.module", - is_package: false, - expected: Some("animals.Dog"), - } - }, - from_import_with_alias = { - Case { - name: "Kitty", - imports: vec![Import::From { - module: "pets".to_string(), - names: vec![("Cat".to_string(), Some("Kitty".to_string()))], - }], - current_module: "test.module", - is_package: false, - expected: Some("pets.Cat"), - } - }, - from_import_alias_no_match = { - Case { - name: "Cat", - imports: vec![Import::From { - module: "pets".to_string(), - names: vec![("Cat".to_string(), Some("Kitty".to_string()))], - }], - current_module: "test.module", - is_package: false, - expected: None, - } - }, - module_import = { - Case { - name: "zoo", - imports: vec![Import::Module { - module: "zoo".to_string(), - alias: None, - }], - current_module: "test.module", - is_package: false, - expected: Some("zoo"), - } - }, - relative_from_module_import = { - Case { - name: "Animal", - imports: vec![Import::RelativeFrom { - level: 1, - module: Some("base".to_string()), - names: vec![("Animal".to_string(), None)], - }], - current_module: "mypackage.dog", - is_package: false, - expected: Some("mypackage.base.Animal"), - } - }, - relative_from_current_package = { - Case { - name: "Cat", - imports: vec![Import::RelativeFrom { - level: 1, - module: None, - names: vec![("Cat".to_string(), None)], - }], - current_module: "mypackage.dog", - is_package: false, - expected: Some("mypackage.Cat"), - } - }, - relative_two_levels_up = { - Case { - name: "Helper", - imports: vec![Import::RelativeFrom { - level: 2, - module: Some("utils".to_string()), - names: vec![("Helper".to_string(), None)], - }], - current_module: "pkg.sub.module", - is_package: false, - expected: Some("pkg.utils.Helper"), - } - }, - relative_three_levels_up = { - Case { - name: "Config", - imports: vec![Import::RelativeFrom { - level: 3, - module: None, - names: vec![("Config".to_string(), None)], - }], - current_module: "pkg.sub.module", - is_package: false, - expected: Some("Config"), - } - }, - relative_from_init = { - Case { - name: "Node", - imports: vec![Import::RelativeFrom { - level: 1, - module: Some("_core".to_string()), - names: vec![("Node".to_string(), None)], - }], - current_module: "mypackage", - is_package: true, - expected: Some("mypackage._core.Node"), - } - }, - relative_from_toplevel = { - Case { - name: "Foo", - imports: vec![Import::RelativeFrom { - level: 1, - module: None, - names: vec![("Foo".to_string(), None)], - }], - current_module: "toplevel", - is_package: false, - expected: Some("Foo"), - } - }, - )] - fn test_resolve_name(case: Case) { - assert_eq!( - resolve_name( - case.name, - &case.imports, - case.current_module, - case.is_package - ), - case.expected.map(|s| s.to_string()) - ); - } -} From 0a0ff4504c84c1c79fb68fc9bab13310714bfcd6 Mon Sep 17 00:00:00 2001 From: Peter Byfield Date: Sat, 25 Oct 2025 18:08:56 +0200 Subject: [PATCH 2/9] Add file_path to parsed file --- src/parser.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/parser.rs b/src/parser.rs index 2b561a5..6302b7e 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -54,6 +54,8 @@ pub enum Import { /// The result of parsing a Python file. #[derive(Debug)] pub struct ParsedFile { + /// The path of this file + pub file_path: PathBuf, /// The module path of this file pub module_path: String, /// Class definitions found in this file @@ -138,6 +140,7 @@ pub fn parse_file(file_path: &Path, module_path: &str) -> Result { .unwrap_or(false); Ok(ParsedFile { + file_path: file_path.to_path_buf(), module_path: module_path.to_string(), classes, imports, From 6d4b44fe6d06c4831c1f370b658743c5ff4b2559 Mon Sep 17 00:00:00 2001 From: Peter Byfield Date: Sat, 25 Oct 2025 18:16:03 +0200 Subject: [PATCH 3/9] WIP --- src/graph.rs | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/src/graph.rs b/src/graph.rs index 44381f6..e70b383 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -25,7 +25,27 @@ pub struct ClassId { pub struct InheritanceGraph { pub modules: HashMap, pub classes: HashMap>, - pub children: HashMap>, + pub imports: HashMap>, + pub class_children: HashMap>, +} + +/// An enum representing a resolved import. +/// +/// `import X` is always an imported module. +/// +/// `from X import Y` can be either a module import, or a module member import. +/// This can be determined by first seeing if the module X.Y exists. If so then this is a module import of module X.Y. +/// If not we check if the module X exists. If so then this is an import of the member Y from the module X. +pub enum ResolvedImport { + Module { + module: ModuleName, + imported_as: String, + }, + ModuleMember { + module: ModuleName, + member: String, + imported_as: String, + }, } impl InheritanceGraph { From 97ad566009ad25a719c82c5e1a1e267f41aad17d Mon Sep 17 00:00:00 2001 From: Peter Byfield Date: Sat, 25 Oct 2025 20:18:20 +0200 Subject: [PATCH 4/9] WIP --- src/graph.rs | 214 ++++++++++++++++++++++++++++++++++++++++++++++++++- src/lib.rs | 107 +++++++++++++++++++++++++- 2 files changed, 315 insertions(+), 6 deletions(-) diff --git a/src/graph.rs b/src/graph.rs index e70b383..8609978 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -15,7 +15,7 @@ pub struct ModuleMetadata { pub is_package: bool, } -#[derive(Debug)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct ClassId { pub module: ModuleName, pub name: String, @@ -24,7 +24,7 @@ pub struct ClassId { /// An inheritance graph mapping classes to their children. pub struct InheritanceGraph { pub modules: HashMap, - pub classes: HashMap>, + pub classes: HashSet, pub imports: HashMap>, pub class_children: HashMap>, } @@ -58,7 +58,124 @@ impl InheritanceGraph { /// # Returns /// /// An inheritance graph with parent-child relationships. - pub fn build(parsed_files: &ParsedFile) -> Self {} + pub fn build(parsed_files: &[ParsedFile]) -> Self { + let mut modules = HashMap::new(); + let mut classes: HashSet = HashSet::new(); + let mut imports: HashMap> = HashMap::new(); + let mut class_children: HashMap> = HashMap::new(); + + // First pass: collect all modules and classes + for file in parsed_files { + // Register module metadata + modules.insert( + file.module_path.clone(), + ModuleMetadata { + file_path: file.file_path.clone(), + is_package: file.is_package, + }, + ); + + // Register classes in this module + for class in &file.classes { + classes.insert(ClassId { + module: file.module_path.clone(), + name: class.name.clone(), + }); + } + } + + // Second pass: resolve imports + for file in parsed_files { + let mut resolved_imports = Vec::new(); + + for import in &file.imports { + match import { + crate::parser::Import::Module { module, alias } => { + let imported_as = if let Some(alias) = alias { + alias.clone() + } else { + module.clone() + }; + resolved_imports.push(ResolvedImport::Module { + module: module.clone(), + imported_as, + }); + } + crate::parser::Import::From { module, names } => { + for (name, alias) in names { + // Check if this is a module import or member import + let full_module = format!("{module}.{name}"); + if modules.contains_key(&full_module) { + // This is a module import + resolved_imports.push(ResolvedImport::Module { + module: full_module, + imported_as: alias.clone().unwrap_or_else(|| name.clone()), + }); + } else { + // This is a member import + resolved_imports.push(ResolvedImport::ModuleMember { + module: module.clone(), + member: name.clone(), + imported_as: alias.clone().unwrap_or_else(|| name.clone()), + }); + } + } + } + crate::parser::Import::RelativeFrom { + level, + module, + names, + } => { + // Resolve relative import to absolute module path + if let Some(abs_module) = resolve_relative_import( + &file.module_path, + *level, + module.as_deref(), + file.is_package, + ) { + for (name, alias) in names { + // Check if this is a module import or member import + let full_module = format!("{abs_module}.{name}"); + if modules.contains_key(&full_module) { + // This is a module import + resolved_imports.push(ResolvedImport::Module { + module: full_module, + imported_as: alias.clone().unwrap_or_else(|| name.clone()), + }); + } else { + // This is a member import + resolved_imports.push(ResolvedImport::ModuleMember { + module: abs_module.clone(), + member: name.clone(), + imported_as: alias.clone().unwrap_or_else(|| name.clone()), + }); + } + } + } + } + } + } + + imports.insert(file.module_path.clone(), resolved_imports); + } + + // Third pass: build parent-child relationships + for file in parsed_files { + for class in &file.classes { + for base in &class.bases { + // Resolve the base class to a ClassId + // TODO + } + } + } + + Self { + modules, + classes, + imports, + class_children, + } + } /// Finds all transitive subclasses of a given class. /// @@ -71,5 +188,94 @@ impl InheritanceGraph { /// # Returns /// /// A vector of all transitive subclasses (not including the root class itself). - pub fn find_all_subclasses(&self, root: &ClassId) -> Vec {} + pub fn find_all_subclasses(&self, root: &ClassId) -> Vec { + use std::collections::VecDeque; + + let mut result = Vec::new(); + let mut visited = HashSet::new(); + let mut queue = VecDeque::new(); + + // Start BFS from the root + queue.push_back(root.clone()); + visited.insert((root.module.clone(), root.name.clone())); + + while let Some(current) = queue.pop_front() { + // Find all direct children + if let Some(children) = self.class_children.get(¤t) { + for child in children { + let key = (child.module.clone(), child.name.clone()); + if !visited.contains(&key) { + visited.insert(key); + result.push(child.clone()); + queue.push_back(child.clone()); + } + } + } + } + + result + } +} + +/// Resolves a relative import to an absolute module path. +/// +/// # Arguments +/// +/// * `current_module` - The module path where the import occurs +/// * `level` - Number of dots in the relative import +/// * `relative_module` - Optional module name after the dots +/// * `is_package` - Whether the current module is a package (__init__.py) +/// +/// # Returns +/// +/// The absolute module path, or None if the import cannot be resolved +fn resolve_relative_import( + current_module: &str, + level: usize, + relative_module: Option<&str>, + is_package: bool, +) -> Option { + let parts: Vec<&str> = current_module.split('.').collect(); + + // level determines how many parent levels to go up + // level=1: from . import x (current package) + // level=2: from .. import x (parent package) + // + // Key insight: "current package" means different things: + // - For pkg/__init__.py (is_package=true), current package is "pkg" + // - For pkg/module.py (is_package=false), current package is also "pkg" + // + // So for a package, level=1 should not remove anything + // For a module, level=1 should remove the last component + if level == 0 || level > parts.len() { + return None; + } + + // Calculate how many components to remove + // For packages: level=1 removes 0, level=2 removes 1, etc. + // For modules: level=1 removes 1, level=2 removes 2, etc. + let levels_to_remove = if is_package { + level.saturating_sub(1) + } else { + level + }; + + if levels_to_remove > parts.len() { + return None; + } + + let base_parts = &parts[..parts.len() - levels_to_remove]; + + let base = if base_parts.is_empty() { + None + } else { + Some(base_parts.join(".")) + }; + + match (base, relative_module) { + (Some(base_str), Some(module)) => Some(format!("{base_str}.{module}")), + (Some(base_str), None) => Some(base_str), + (None, Some(module)) => Some(module.to_string()), + (None, None) => None, + } } diff --git a/src/lib.rs b/src/lib.rs index 71d3ec4..6490abc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -70,7 +70,6 @@ impl ClassReference { /// # } /// ``` pub struct SubclassFinder { - root_dir: PathBuf, graph: InheritanceGraph, } @@ -110,7 +109,12 @@ impl SubclassFinder { // Build the inheritance graph let graph = graph::InheritanceGraph::build(&parsed_files); - Ok(Self { root_dir, graph }) + Ok(Self { graph }) + } + + /// Returns the total number of classes found in the codebase. + pub fn class_count(&self) -> usize { + self.graph.classes.values().map(|v| v.len()).sum() } /// Finds all transitive subclasses of a given class. @@ -154,5 +158,104 @@ impl SubclassFinder { class_name: &str, module_path: Option<&str>, ) -> Result> { + // Find all classes with the given name + let mut candidates = Vec::new(); + for class_ids in self.graph.classes.values() { + for class_id in class_ids { + if class_id.name == class_name { + candidates.push(class_id); + } + } + } + + // Filter by module path if provided + let root_class = if let Some(module) = module_path { + // Try exact match first + let exact_match = candidates.iter().find(|c| c.module == module).copied(); + + if let Some(class_id) = exact_match { + class_id + } else { + // Check if the module re-exports the class + // Look for imports in the specified module that import this class + if let Some(imports) = self.graph.imports.get(module) { + let mut found = None; + for import in imports { + if let graph::ResolvedImport::ModuleMember { + module: source_module, + member, + .. + } = import + { + if member == class_name { + // Check if the class exists in the source module + found = candidates + .iter() + .find(|c| c.module == *source_module) + .copied(); + if found.is_some() { + break; + } + } + } + } + + if let Some(class_id) = found { + class_id + } else { + return Err(Error::ClassNotFound { + name: class_name.to_string(), + module_path: Some(module.to_string()), + }); + } + } else { + return Err(Error::ClassNotFound { + name: class_name.to_string(), + module_path: Some(module.to_string()), + }); + } + } + } else { + // No module path specified + if candidates.is_empty() { + return Err(Error::ClassNotFound { + name: class_name.to_string(), + module_path: None, + }); + } else if candidates.len() > 1 { + let module_names: Vec = + candidates.iter().map(|c| c.module.clone()).collect(); + return Err(Error::AmbiguousClassName { + name: class_name.to_string(), + candidates: module_names, + }); + } else { + candidates[0] + } + }; + + // Find all subclasses using BFS + let subclass_ids = self.graph.find_all_subclasses(root_class); + + // Convert ClassIds to ClassReferences + let mut references = Vec::new(); + for class_id in subclass_ids { + if let Some(metadata) = self.graph.modules.get(&class_id.module) { + references.push(ClassReference { + class_name: class_id.name.clone(), + module_path: class_id.module.clone(), + file_path: metadata.file_path.clone(), + }); + } + } + + // Sort by module path for consistent output + references.sort_by(|a, b| { + a.module_path + .cmp(&b.module_path) + .then_with(|| a.class_name.cmp(&b.class_name)) + }); + + Ok(references) } } From daaa85a95b1dacbbb0adb67704d3606e9e69d499 Mon Sep 17 00:00:00 2001 From: Peter Byfield Date: Sun, 26 Oct 2025 06:33:12 +0100 Subject: [PATCH 5/9] WIP --- src/graph.rs | 218 ++---------------------------------------------- src/lib.rs | 115 ++----------------------- src/parser.rs | 47 ++++------- src/registry.rs | 61 ++++++++++++++ 4 files changed, 93 insertions(+), 348 deletions(-) create mode 100644 src/registry.rs diff --git a/src/graph.rs b/src/graph.rs index 8609978..7adfa3c 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -1,32 +1,14 @@ //! Inheritance graph construction and traversal. -use std::{ - collections::{HashMap, HashSet}, - path::PathBuf, -}; +use std::collections::{HashMap, HashSet}; -use crate::parser::ParsedFile; +use crate::registry::{ClassId, Registry}; pub type ModuleName = String; -#[derive(Debug)] -pub struct ModuleMetadata { - pub file_path: PathBuf, - pub is_package: bool, -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct ClassId { - pub module: ModuleName, - pub name: String, -} - /// An inheritance graph mapping classes to their children. pub struct InheritanceGraph { - pub modules: HashMap, - pub classes: HashSet, - pub imports: HashMap>, - pub class_children: HashMap>, + pub children: HashMap>, } /// An enum representing a resolved import. @@ -49,132 +31,9 @@ pub enum ResolvedImport { } impl InheritanceGraph { - /// Builds an inheritance graph from parsed files. - /// - /// # Arguments - /// - /// * `parsed_files` - The parsed python files. - /// - /// # Returns - /// - /// An inheritance graph with parent-child relationships. - pub fn build(parsed_files: &[ParsedFile]) -> Self { - let mut modules = HashMap::new(); - let mut classes: HashSet = HashSet::new(); - let mut imports: HashMap> = HashMap::new(); - let mut class_children: HashMap> = HashMap::new(); - - // First pass: collect all modules and classes - for file in parsed_files { - // Register module metadata - modules.insert( - file.module_path.clone(), - ModuleMetadata { - file_path: file.file_path.clone(), - is_package: file.is_package, - }, - ); - - // Register classes in this module - for class in &file.classes { - classes.insert(ClassId { - module: file.module_path.clone(), - name: class.name.clone(), - }); - } - } - - // Second pass: resolve imports - for file in parsed_files { - let mut resolved_imports = Vec::new(); - - for import in &file.imports { - match import { - crate::parser::Import::Module { module, alias } => { - let imported_as = if let Some(alias) = alias { - alias.clone() - } else { - module.clone() - }; - resolved_imports.push(ResolvedImport::Module { - module: module.clone(), - imported_as, - }); - } - crate::parser::Import::From { module, names } => { - for (name, alias) in names { - // Check if this is a module import or member import - let full_module = format!("{module}.{name}"); - if modules.contains_key(&full_module) { - // This is a module import - resolved_imports.push(ResolvedImport::Module { - module: full_module, - imported_as: alias.clone().unwrap_or_else(|| name.clone()), - }); - } else { - // This is a member import - resolved_imports.push(ResolvedImport::ModuleMember { - module: module.clone(), - member: name.clone(), - imported_as: alias.clone().unwrap_or_else(|| name.clone()), - }); - } - } - } - crate::parser::Import::RelativeFrom { - level, - module, - names, - } => { - // Resolve relative import to absolute module path - if let Some(abs_module) = resolve_relative_import( - &file.module_path, - *level, - module.as_deref(), - file.is_package, - ) { - for (name, alias) in names { - // Check if this is a module import or member import - let full_module = format!("{abs_module}.{name}"); - if modules.contains_key(&full_module) { - // This is a module import - resolved_imports.push(ResolvedImport::Module { - module: full_module, - imported_as: alias.clone().unwrap_or_else(|| name.clone()), - }); - } else { - // This is a member import - resolved_imports.push(ResolvedImport::ModuleMember { - module: abs_module.clone(), - member: name.clone(), - imported_as: alias.clone().unwrap_or_else(|| name.clone()), - }); - } - } - } - } - } - } - - imports.insert(file.module_path.clone(), resolved_imports); - } - - // Third pass: build parent-child relationships - for file in parsed_files { - for class in &file.classes { - for base in &class.bases { - // Resolve the base class to a ClassId - // TODO - } - } - } - - Self { - modules, - classes, - imports, - class_children, - } + pub fn build(registry: Registry) -> Self { + // TODO Use `registry.resolve_class`. + todo!() } /// Finds all transitive subclasses of a given class. @@ -201,7 +60,7 @@ impl InheritanceGraph { while let Some(current) = queue.pop_front() { // Find all direct children - if let Some(children) = self.class_children.get(¤t) { + if let Some(children) = self.children.get(¤t) { for child in children { let key = (child.module.clone(), child.name.clone()); if !visited.contains(&key) { @@ -216,66 +75,3 @@ impl InheritanceGraph { result } } - -/// Resolves a relative import to an absolute module path. -/// -/// # Arguments -/// -/// * `current_module` - The module path where the import occurs -/// * `level` - Number of dots in the relative import -/// * `relative_module` - Optional module name after the dots -/// * `is_package` - Whether the current module is a package (__init__.py) -/// -/// # Returns -/// -/// The absolute module path, or None if the import cannot be resolved -fn resolve_relative_import( - current_module: &str, - level: usize, - relative_module: Option<&str>, - is_package: bool, -) -> Option { - let parts: Vec<&str> = current_module.split('.').collect(); - - // level determines how many parent levels to go up - // level=1: from . import x (current package) - // level=2: from .. import x (parent package) - // - // Key insight: "current package" means different things: - // - For pkg/__init__.py (is_package=true), current package is "pkg" - // - For pkg/module.py (is_package=false), current package is also "pkg" - // - // So for a package, level=1 should not remove anything - // For a module, level=1 should remove the last component - if level == 0 || level > parts.len() { - return None; - } - - // Calculate how many components to remove - // For packages: level=1 removes 0, level=2 removes 1, etc. - // For modules: level=1 removes 1, level=2 removes 2, etc. - let levels_to_remove = if is_package { - level.saturating_sub(1) - } else { - level - }; - - if levels_to_remove > parts.len() { - return None; - } - - let base_parts = &parts[..parts.len() - levels_to_remove]; - - let base = if base_parts.is_empty() { - None - } else { - Some(base_parts.join(".")) - }; - - match (base, relative_module) { - (Some(base_str), Some(module)) => Some(format!("{base_str}.{module}")), - (Some(base_str), None) => Some(base_str), - (None, Some(module)) => Some(module.to_string()), - (None, None) => None, - } -} diff --git a/src/lib.rs b/src/lib.rs index 6490abc..246aec0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -25,12 +25,15 @@ pub mod discovery; pub mod error; pub mod graph; pub mod parser; +pub mod registry; use std::path::PathBuf; pub use error::{Error, Result}; use graph::InheritanceGraph; +use crate::registry::Registry; + /// A reference to a Python class. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ClassReference { @@ -70,6 +73,7 @@ impl ClassReference { /// # } /// ``` pub struct SubclassFinder { + registry: Registry, graph: InheritanceGraph, } @@ -106,15 +110,12 @@ impl SubclassFinder { }) .collect(); - // Build the inheritance graph - let graph = graph::InheritanceGraph::build(&parsed_files); + let registry = Registry::build(&parsed_files)?; - Ok(Self { graph }) - } + // Build the inheritance graph + let graph = InheritanceGraph::build(®istry); - /// Returns the total number of classes found in the codebase. - pub fn class_count(&self) -> usize { - self.graph.classes.values().map(|v| v.len()).sum() + Ok(Self { registry, graph }) } /// Finds all transitive subclasses of a given class. @@ -158,104 +159,6 @@ impl SubclassFinder { class_name: &str, module_path: Option<&str>, ) -> Result> { - // Find all classes with the given name - let mut candidates = Vec::new(); - for class_ids in self.graph.classes.values() { - for class_id in class_ids { - if class_id.name == class_name { - candidates.push(class_id); - } - } - } - - // Filter by module path if provided - let root_class = if let Some(module) = module_path { - // Try exact match first - let exact_match = candidates.iter().find(|c| c.module == module).copied(); - - if let Some(class_id) = exact_match { - class_id - } else { - // Check if the module re-exports the class - // Look for imports in the specified module that import this class - if let Some(imports) = self.graph.imports.get(module) { - let mut found = None; - for import in imports { - if let graph::ResolvedImport::ModuleMember { - module: source_module, - member, - .. - } = import - { - if member == class_name { - // Check if the class exists in the source module - found = candidates - .iter() - .find(|c| c.module == *source_module) - .copied(); - if found.is_some() { - break; - } - } - } - } - - if let Some(class_id) = found { - class_id - } else { - return Err(Error::ClassNotFound { - name: class_name.to_string(), - module_path: Some(module.to_string()), - }); - } - } else { - return Err(Error::ClassNotFound { - name: class_name.to_string(), - module_path: Some(module.to_string()), - }); - } - } - } else { - // No module path specified - if candidates.is_empty() { - return Err(Error::ClassNotFound { - name: class_name.to_string(), - module_path: None, - }); - } else if candidates.len() > 1 { - let module_names: Vec = - candidates.iter().map(|c| c.module.clone()).collect(); - return Err(Error::AmbiguousClassName { - name: class_name.to_string(), - candidates: module_names, - }); - } else { - candidates[0] - } - }; - - // Find all subclasses using BFS - let subclass_ids = self.graph.find_all_subclasses(root_class); - - // Convert ClassIds to ClassReferences - let mut references = Vec::new(); - for class_id in subclass_ids { - if let Some(metadata) = self.graph.modules.get(&class_id.module) { - references.push(ClassReference { - class_name: class_id.name.clone(), - module_path: class_id.module.clone(), - file_path: metadata.file_path.clone(), - }); - } - } - - // Sort by module path for consistent output - references.sort_by(|a, b| { - a.module_path - .cmp(&b.module_path) - .then_with(|| a.class_name.cmp(&b.class_name)) - }); - - Ok(references) + todo!() } } diff --git a/src/parser.rs b/src/parser.rs index 6302b7e..01c8337 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -8,15 +8,6 @@ use std::path::{Path, PathBuf}; use crate::error::{Error, Result}; -/// Represents a base class reference in a class definition. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum BaseClass { - /// Simple name reference (e.g., `Animal`) - Simple(String), - /// Attribute reference (e.g., `module.Animal` or `package.module.Animal`) - Attribute(Vec), -} - /// Represents a Python class definition. #[derive(Debug, Clone)] pub struct ClassDefinition { @@ -26,29 +17,23 @@ pub struct ClassDefinition { pub module_path: String, /// The file path where the class is defined pub file_path: PathBuf, - /// The base classes this class inherits from (unresolved) - pub bases: Vec, + /// The base classes this class inherits from e.g. "Foo" or "foo.Foo" + pub bases: Vec, } -/// Represents an import statement. -#[derive(Debug, Clone)] -pub enum Import { - /// `import foo` or `import foo.bar` - Module { - module: String, - alias: Option, - }, - /// `from foo import bar` or `from foo import bar as baz` - From { - module: String, - names: Vec<(String, Option)>, // (name, alias) - }, - /// `from .relative import foo` (relative import) - RelativeFrom { - level: usize, // Number of dots - module: Option, - names: Vec<(String, Option)>, - }, +/// An import. +/// +/// E.g. +/// `import a` => { imported_item=a, imported_as=a } +/// `import a.b` => { imported_item=a.b, imported_as=a.b } +/// `import a.b as c` => { imported_item=a.b, imported_as=c } +/// `from a import b` => { imported_item=a.b, imported_as=b } +/// `from a import b as c` => { imported_item=a.b, imported_as=c } +/// `from a.b import c` => { imported_item=a.b.c, imported_as=c } +#[derive(Debug)] +pub struct Import { + pub imported_item: String, + pub imported_as: String, } /// The result of parsing a Python file. @@ -60,7 +45,7 @@ pub struct ParsedFile { pub module_path: String, /// Class definitions found in this file pub classes: Vec, - /// Import statements found in this file + /// Import statements found in this file (relative imports already resolved) pub imports: Vec, /// Whether this is a package (__init__.py file) pub is_package: bool, diff --git a/src/registry.rs b/src/registry.rs new file mode 100644 index 0000000..36bd6a1 --- /dev/null +++ b/src/registry.rs @@ -0,0 +1,61 @@ +use std::{ + collections::{HashMap, HashSet}, + path::PathBuf, +}; + +use crate::parser::Import; + +pub type ModuleName = String; + +#[derive(Debug, Clone)] +pub struct ModuleMetadata { + pub file_path: PathBuf, + pub is_package: bool, +} + +#[derive(Debug, Clone)] +pub struct ClassMetadata { + // Base classes (unresolved) + pub bases: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ClassId { + pub module: ModuleName, + pub name: String, +} + +pub struct Registry { + pub modules: HashMap, + pub classes: HashMap, + pub classes_by_module: HashMap>, + pub imports: HashMap>, + pub class_children: HashMap>, +} + +impl Registry { + pub fn resolve_class(&self, module: &str, name: &str) -> Option { + // Resolve a class name in a given module. + // + // First, check whether a class with this name exists in the module. If so then we're done! + // If not, we need to use the imports to resolve the class. + // + // Suppose we're trying to resolve the name a.b.c.d. + // + // First, resolve based on the imports in the module. + // Look for a prefix based on `imported_as` and then substitute the related `imported_item`. + // E.g. Suppose { imported_item=pkg.foo, imported_as=a.b } + // Then the prefix a.b matches, so we substitute and the name becomes pkg.foo.c.d. + // + // Next, we have to find which module pkg.foo.c.d refers to. To do this consider in order: + // - pkg.foo.c.d + // - pkg.foo.c (remainder: d) + // - pkg.foo (remainder: c.d) + // - pkg (remainder: foo.c.d) + // Once we have a match, then we know we've found the module. + // We then go to that module and resolve the remainder name (recurse into `resolve_class`). + // E.g. Suppose pkg.foo matches a module. + // Then we call `resolve_class` with module="pkg.foo" and remainder="c.d". + todo!() + } +} From 5c94f3ca927bc8dd9983a4ee910e295b2ca46449 Mon Sep 17 00:00:00 2001 From: Peter Byfield Date: Sun, 26 Oct 2025 08:20:46 +0100 Subject: [PATCH 6/9] Maybe this is it? --- src/graph.rs | 42 +++++++-------- src/lib.rs | 71 +++++++++++++++++++++++++- src/parser.rs | 133 ++++++++++++++++++++++++++++++++---------------- src/registry.rs | 120 ++++++++++++++++++++++++++++++++++++------- 4 files changed, 279 insertions(+), 87 deletions(-) diff --git a/src/graph.rs b/src/graph.rs index 7adfa3c..41000fd 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -4,36 +4,30 @@ use std::collections::{HashMap, HashSet}; use crate::registry::{ClassId, Registry}; -pub type ModuleName = String; - /// An inheritance graph mapping classes to their children. pub struct InheritanceGraph { pub children: HashMap>, } -/// An enum representing a resolved import. -/// -/// `import X` is always an imported module. -/// -/// `from X import Y` can be either a module import, or a module member import. -/// This can be determined by first seeing if the module X.Y exists. If so then this is a module import of module X.Y. -/// If not we check if the module X exists. If so then this is an import of the member Y from the module X. -pub enum ResolvedImport { - Module { - module: ModuleName, - imported_as: String, - }, - ModuleMember { - module: ModuleName, - member: String, - imported_as: String, - }, -} - impl InheritanceGraph { - pub fn build(registry: Registry) -> Self { - // TODO Use `registry.resolve_class`. - todo!() + pub fn build(registry: &Registry) -> Self { + let mut children: HashMap> = HashMap::new(); + + // Iterate through all classes and resolve their base classes + for (child_id, metadata) in ®istry.classes { + for base_name in &metadata.bases { + // Try to resolve the base class using the registry + if let Some(parent_id) = registry.resolve_class(&child_id.module, base_name) { + // Add child to parent's children set + children + .entry(parent_id) + .or_default() + .insert(child_id.clone()); + } + } + } + + Self { children } } /// Finds all transitive subclasses of a given class. diff --git a/src/lib.rs b/src/lib.rs index 246aec0..8b5aa18 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -159,6 +159,75 @@ impl SubclassFinder { class_name: &str, module_path: Option<&str>, ) -> Result> { - todo!() + // Find the target class + let target_id = if let Some(module) = module_path { + // Module specified - look for class in that module (including re-exports) + if let Some(resolved_id) = self.registry.resolve_class(module, class_name) { + resolved_id + } else { + return Err(Error::ClassNotFound { + name: class_name.to_string(), + module_path: Some(module.to_string()), + }); + } + } else { + // No module specified - search for class by name + let mut matches = Vec::new(); + for class_id in self.registry.classes.keys() { + if class_id.name == class_name { + matches.push(class_id.clone()); + } + } + + match matches.len() { + 0 => { + return Err(Error::ClassNotFound { + name: class_name.to_string(), + module_path: None, + }); + } + 1 => matches.into_iter().next().unwrap(), + _ => { + let candidates: Vec = + matches.iter().map(|id| id.module.clone()).collect(); + return Err(Error::AmbiguousClassName { + name: class_name.to_string(), + candidates, + }); + } + } + }; + + // Find all subclasses using the graph + let subclass_ids = self.graph.find_all_subclasses(&target_id); + + // Convert to ClassReference + let mut results: Vec = subclass_ids + .into_iter() + .filter_map(|id| { + self.registry + .modules + .get(&id.module) + .map(|metadata| ClassReference { + class_name: id.name.clone(), + module_path: id.module.clone(), + file_path: metadata.file_path.clone(), + }) + }) + .collect(); + + // Sort by module path for consistent output + results.sort_by(|a, b| { + a.module_path + .cmp(&b.module_path) + .then(a.class_name.cmp(&b.class_name)) + }); + + Ok(results) + } + + /// Returns the number of classes found in the codebase. + pub fn class_count(&self) -> usize { + self.registry.classes.len() } } diff --git a/src/parser.rs b/src/parser.rs index 01c8337..6a58052 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -30,7 +30,7 @@ pub struct ClassDefinition { /// `from a import b` => { imported_item=a.b, imported_as=b } /// `from a import b as c` => { imported_item=a.b, imported_as=c } /// `from a.b import c` => { imported_item=a.b.c, imported_as=c } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Import { pub imported_item: String, pub imported_as: String, @@ -187,47 +187,58 @@ fn extract_from_statements( ); } Stmt::Import(import_stmt) => { - // Only process imports at the top level (not inside classes) - if parent_class.is_none() { - for alias in &import_stmt.names { - imports.push(Import::Module { - module: alias.name.to_string(), - alias: alias.asname.as_ref().map(|a| a.to_string()), - }); - } + for alias in &import_stmt.names { + // `import a` => { imported_item=a, imported_as=a } + // `import a.b as c` => { imported_item=a.b, imported_as=c } + let imported_item = alias.name.to_string(); + let imported_as = alias + .asname + .as_ref() + .map(|a| a.to_string()) + .unwrap_or_else(|| imported_item.clone()); + imports.push(Import { + imported_item, + imported_as, + }); } } Stmt::ImportFrom(import_from) => { - // Only process imports at the top level (not inside classes) - if parent_class.is_none() { - let level = import_from.level as usize; - let names: Vec<(String, Option)> = import_from - .names - .iter() - .map(|alias| { - ( - alias.name.to_string(), - alias.asname.as_ref().map(|a| a.to_string()), - ) - }) - .collect(); - - if level > 0 { - imports.push(Import::RelativeFrom { - level, - module: import_from.module.as_ref().map(|m| m.to_string()), - names, - }); + let level = import_from.level as usize; + + for alias in &import_from.names { + let name = alias.name.to_string(); + let imported_as = alias + .asname + .as_ref() + .map(|a| a.to_string()) + .unwrap_or_else(|| name.clone()); + + let imported_item = if level > 0 { + // Relative import: resolve to absolute module path + let base_module = resolve_relative_module(module_path, level, file_path); + if let Some(from_module) = import_from.module.as_ref() { + format!("{base_module}.{from_module}.{name}") + } else { + format!("{base_module}.{name}") + } } else { - imports.push(Import::From { - module: import_from - .module - .as_ref() - .map(|m| m.to_string()) - .unwrap_or_default(), - names, - }); - } + // Absolute import: from a.b import c => a.b.c + let from_module = import_from + .module + .as_ref() + .map(|m| m.to_string()) + .unwrap_or_default(); + if from_module.is_empty() { + name.clone() + } else { + format!("{from_module}.{name}") + } + }; + + imports.push(Import { + imported_item, + imported_as, + }); } } _ => {} @@ -235,11 +246,45 @@ fn extract_from_statements( } } -/// Extracts a base class reference from an expression. -fn extract_base_class(expr: &Expr) -> Option { +/// Resolves a relative import to an absolute module path. +/// +/// # Arguments +/// +/// * `current_module` - The module path of the file doing the import +/// * `level` - The number of dots in the relative import +/// * `file_path` - The file path (used to determine if this is a package) +/// +/// # Returns +/// +/// The base module path for the relative import +fn resolve_relative_module(current_module: &str, level: usize, file_path: &Path) -> String { + let is_package = file_path + .file_name() + .and_then(|name| name.to_str()) + .map(|name| name == "__init__.py") + .unwrap_or(false); + + let parts: Vec<&str> = current_module.split('.').collect(); + + // If this is a package (__init__.py), level 1 means current package + // If this is a module, level 1 means parent package + let base_level = if is_package { level - 1 } else { level }; + + // Go up 'base_level' directories + if base_level >= parts.len() { + // Going too far up - return empty string + String::new() + } else { + parts[..parts.len() - base_level].join(".") + } +} + +/// Extracts a base class reference from an expression as a string. +/// Returns strings like "Foo" or "module.Foo" or "pkg.mod.Foo" +fn extract_base_class(expr: &Expr) -> Option { match expr { // Simple name: class Foo(Bar) - Expr::Name(name) => Some(BaseClass::Simple(name.id.to_string())), + Expr::Name(name) => Some(name.id.to_string()), // Attribute: class Foo(module.Bar) or class Foo(pkg.mod.Bar) Expr::Attribute(_) => { @@ -262,7 +307,7 @@ fn extract_base_class(expr: &Expr) -> Option { } parts.reverse(); - Some(BaseClass::Attribute(parts)) + Some(parts.join(".")) } // Subscript: class Foo(Generic[T]) - extract the base without the subscript @@ -411,7 +456,7 @@ class TopLevel(Foo): .find(|c| c.name == "Bar.NestedInBar") .unwrap(); assert_eq!(nested_in_bar.bases.len(), 1); - assert_eq!(nested_in_bar.bases[0], BaseClass::Simple("Foo".to_string())); + assert_eq!(nested_in_bar.bases[0], "Foo"); // Verify that DoublyNested has Foo as a base let doubly_nested = parsed @@ -420,6 +465,6 @@ class TopLevel(Foo): .find(|c| c.name == "Bar.AnotherNested.DoublyNested") .unwrap(); assert_eq!(doubly_nested.bases.len(), 1); - assert_eq!(doubly_nested.bases[0], BaseClass::Simple("Foo".to_string())); + assert_eq!(doubly_nested.bases[0], "Foo"); } } diff --git a/src/registry.rs b/src/registry.rs index 36bd6a1..49dccd3 100644 --- a/src/registry.rs +++ b/src/registry.rs @@ -3,7 +3,10 @@ use std::{ path::PathBuf, }; -use crate::parser::Import; +use crate::{ + error::Result, + parser::{Import, ParsedFile}, +}; pub type ModuleName = String; @@ -30,32 +33,113 @@ pub struct Registry { pub classes: HashMap, pub classes_by_module: HashMap>, pub imports: HashMap>, - pub class_children: HashMap>, } impl Registry { + /// Builds a registry from parsed Python files. + pub fn build(parsed_files: &[ParsedFile]) -> Result { + let mut modules = HashMap::new(); + let mut classes = HashMap::new(); + let mut classes_by_module: HashMap> = HashMap::new(); + let mut imports = HashMap::new(); + + for parsed in parsed_files { + // Add module metadata + modules.insert( + parsed.module_path.clone(), + ModuleMetadata { + file_path: parsed.file_path.clone(), + is_package: parsed.is_package, + }, + ); + + // Add classes + for class in &parsed.classes { + let class_id = ClassId { + module: parsed.module_path.clone(), + name: class.name.clone(), + }; + + classes.insert( + class_id.clone(), + ClassMetadata { + bases: class.bases.clone(), + }, + ); + + classes_by_module + .entry(parsed.module_path.clone()) + .or_default() + .insert(class_id); + } + + // Add imports + imports.insert(parsed.module_path.clone(), parsed.imports.clone()); + } + + Ok(Self { + modules, + classes, + classes_by_module, + imports, + }) + } + pub fn resolve_class(&self, module: &str, name: &str) -> Option { // Resolve a class name in a given module. // // First, check whether a class with this name exists in the module. If so then we're done! + let direct_id = ClassId { + module: module.to_string(), + name: name.to_string(), + }; + if self.classes.contains_key(&direct_id) { + return Some(direct_id); + } + // If not, we need to use the imports to resolve the class. - // - // Suppose we're trying to resolve the name a.b.c.d. - // + let imports = self.imports.get(module)?; + // First, resolve based on the imports in the module. // Look for a prefix based on `imported_as` and then substitute the related `imported_item`. - // E.g. Suppose { imported_item=pkg.foo, imported_as=a.b } - // Then the prefix a.b matches, so we substitute and the name becomes pkg.foo.c.d. - // - // Next, we have to find which module pkg.foo.c.d refers to. To do this consider in order: - // - pkg.foo.c.d - // - pkg.foo.c (remainder: d) - // - pkg.foo (remainder: c.d) - // - pkg (remainder: foo.c.d) - // Once we have a match, then we know we've found the module. - // We then go to that module and resolve the remainder name (recurse into `resolve_class`). - // E.g. Suppose pkg.foo matches a module. - // Then we call `resolve_class` with module="pkg.foo" and remainder="c.d". - todo!() + let mut resolved_name = name.to_string(); + + for import in imports { + // Check if name starts with imported_as + if name == import.imported_as { + // Exact match: replace entire name + resolved_name = import.imported_item.clone(); + break; + } else if let Some(remainder) = name.strip_prefix(&format!("{}.", import.imported_as)) { + // Prefix match: substitute prefix + resolved_name = format!("{}.{}", import.imported_item, remainder); + break; + } + } + + // Now we have to find which module resolved_name refers to. + // Split the name into parts + let parts: Vec<&str> = resolved_name.split('.').collect(); + + // Try progressively shorter prefixes to find a matching module + for i in (1..=parts.len()).rev() { + let module_candidate = parts[..i].join("."); + + // Check if this module exists + if self.modules.contains_key(&module_candidate) { + if i == parts.len() { + // The entire name is a module - this shouldn't be a class + return None; + } + + // We found the module, the rest is the class name within that module + let remainder = parts[i..].join("."); + // Recurse to resolve the class in that module + return self.resolve_class(&module_candidate, &remainder); + } + } + + // Couldn't resolve + None } } From be6dd573181c89ee07f1272ff7c2c3e667039786 Mon Sep 17 00:00:00 2001 From: Peter Byfield Date: Sun, 26 Oct 2025 08:36:59 +0100 Subject: [PATCH 7/9] Doc + further tests --- src/graph.rs | 66 ++++++++++++++++++--- src/parser.rs | 70 ++++++++++++++++------- src/registry.rs | 117 +++++++++++++++++++++++++++++++------- tests/integration_test.rs | 92 ++++++++++++++++++++++++++++++ 4 files changed, 294 insertions(+), 51 deletions(-) diff --git a/src/graph.rs b/src/graph.rs index 41000fd..6a9f6f1 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -1,24 +1,53 @@ //! Inheritance graph construction and traversal. +//! +//! This module provides functionality to build and query a class inheritance graph. +//! The graph maps parent classes to their direct and transitive children, enabling +//! efficient subclass discovery. use std::collections::{HashMap, HashSet}; use crate::registry::{ClassId, Registry}; -/// An inheritance graph mapping classes to their children. +/// An inheritance graph representing parent-child relationships between classes. +/// +/// This structure maps each class to the set of classes that directly inherit from it. +/// The graph can be used to efficiently find all descendants of a given class using +/// breadth-first search. pub struct InheritanceGraph { + /// Maps parent classes to their direct children. pub children: HashMap>, } impl InheritanceGraph { + /// Builds an inheritance graph from a registry. + /// + /// This resolves all base class references and constructs parent-to-child + /// relationships. The registry's `resolve_class` method is used to handle + /// imports and re-exports correctly. + /// + /// # Arguments + /// + /// * `registry` - The registry containing all classes and their base class references + /// + /// # Returns + /// + /// An inheritance graph ready for traversal. + /// + /// # Algorithm + /// + /// For each class in the registry: + /// 1. Resolve each of its base class names to a `ClassId` + /// 2. Add this class to the parent's children set + /// 3. Skip any base classes that cannot be resolved (e.g., external dependencies) pub fn build(registry: &Registry) -> Self { let mut children: HashMap> = HashMap::new(); - // Iterate through all classes and resolve their base classes + // Build parent → children edges by examining each class's bases for (child_id, metadata) in ®istry.classes { for base_name in &metadata.bases { - // Try to resolve the base class using the registry + // Resolve the base class reference in this class's module context if let Some(parent_id) = registry.resolve_class(&child_id.module, base_name) { - // Add child to parent's children set + // Add this class as a child of its parent children .entry(parent_id) .or_default() @@ -32,15 +61,33 @@ impl InheritanceGraph { /// Finds all transitive subclasses of a given class. /// - /// Uses BFS to traverse the inheritance graph and collect all descendants. + /// Performs a breadth-first search to discover all classes that directly or + /// indirectly inherit from the specified root class. This includes: + /// - Direct children (one level of inheritance) + /// - Grandchildren (two levels) + /// - Great-grandchildren, etc. (any depth) /// /// # Arguments /// - /// * `root` - The root class to find subclasses of + /// * `root` - The class to find subclasses for /// /// # Returns /// - /// A vector of all transitive subclasses (not including the root class itself). + /// A vector containing all transitive subclasses. The root class itself is not + /// included in the result. The order of classes in the vector is determined by + /// the BFS traversal order. + /// + /// # Examples + /// + /// ```text + /// Given: + /// class Animal: pass + /// class Mammal(Animal): pass + /// class Dog(Mammal): pass + /// class Cat(Mammal): pass + /// + /// find_all_subclasses(Animal) → [Mammal, Dog, Cat] + /// ``` pub fn find_all_subclasses(&self, root: &ClassId) -> Vec { use std::collections::VecDeque; @@ -48,12 +95,13 @@ impl InheritanceGraph { let mut visited = HashSet::new(); let mut queue = VecDeque::new(); - // Start BFS from the root + // Initialize BFS with the root class queue.push_back(root.clone()); visited.insert((root.module.clone(), root.name.clone())); + // BFS traversal while let Some(current) = queue.pop_front() { - // Find all direct children + // Examine all direct children of the current class if let Some(children) = self.children.get(¤t) { for child in children { let key = (child.module.clone(), child.name.clone()); diff --git a/src/parser.rs b/src/parser.rs index 6a58052..98a57b3 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -135,16 +135,20 @@ pub fn parse_file(file_path: &Path, module_path: &str) -> Result { /// Recursively extracts classes and imports from a list of statements. /// -/// This function handles both top-level and nested class definitions. +/// This function walks the AST and extracts: +/// - Class definitions (including nested classes) +/// - Import statements (both `import` and `from...import` forms) +/// +/// Nested classes are represented with dot notation (e.g., "Outer.Inner"). /// /// # Arguments /// -/// * `stmts` - The statements to process -/// * `parent_class` - The parent class name for nested classes (e.g., Some("Bar")) -/// * `module_path` - The module path for this file -/// * `file_path` - The file path -/// * `classes` - Mutable vector to accumulate class definitions -/// * `imports` - Mutable vector to accumulate imports (only at top level) +/// * `stmts` - The AST statements to process +/// * `parent_class` - The parent class name if processing nested classes (e.g., `Some("Outer")`) +/// * `module_path` - The module path for this file (e.g., "foo.bar") +/// * `file_path` - The file path (used for resolving relative imports) +/// * `classes` - Mutable vector to accumulate discovered class definitions +/// * `imports` - Mutable vector to accumulate discovered imports (only at top level) fn extract_from_statements( stmts: &[Stmt], parent_class: Option<&str>, @@ -156,13 +160,14 @@ fn extract_from_statements( for stmt in stmts { match stmt { Stmt::ClassDef(class_def) => { - // Build the full class name (with parent prefix if nested) + // Build fully qualified class name (e.g., "Outer.Inner" for nested classes) let full_name = if let Some(parent) = parent_class { format!("{}.{}", parent, class_def.name) } else { class_def.name.to_string() }; + // Extract base classes, filtering out unresolvable references let bases = class_def .bases() .iter() @@ -176,7 +181,7 @@ fn extract_from_statements( bases, }); - // Recursively extract nested classes + // Recursively process nested classes extract_from_statements( class_def.body.as_slice(), Some(&full_name), @@ -187,9 +192,9 @@ fn extract_from_statements( ); } Stmt::Import(import_stmt) => { + // Process `import foo` or `import foo as bar` statements + // Format: { imported_item: "foo", imported_as: "bar" } for alias in &import_stmt.names { - // `import a` => { imported_item=a, imported_as=a } - // `import a.b as c` => { imported_item=a.b, imported_as=c } let imported_item = alias.name.to_string(); let imported_as = alias .asname @@ -203,6 +208,7 @@ fn extract_from_statements( } } Stmt::ImportFrom(import_from) => { + // Process `from foo import bar` or `from .foo import bar` statements let level = import_from.level as usize; for alias in &import_from.names { @@ -214,7 +220,8 @@ fn extract_from_statements( .unwrap_or_else(|| name.clone()); let imported_item = if level > 0 { - // Relative import: resolve to absolute module path + // Relative import: resolve dots to absolute module path + // e.g., `from ..pkg import Foo` → "parent.pkg.Foo" let base_module = resolve_relative_module(module_path, level, file_path); if let Some(from_module) = import_from.module.as_ref() { format!("{base_module}.{from_module}.{name}") @@ -222,7 +229,8 @@ fn extract_from_statements( format!("{base_module}.{name}") } } else { - // Absolute import: from a.b import c => a.b.c + // Absolute import: combine module and name + // e.g., `from foo import Bar` → "foo.Bar" let from_module = import_from .module .as_ref() @@ -248,15 +256,34 @@ fn extract_from_statements( /// Resolves a relative import to an absolute module path. /// +/// Relative imports in Python use dots to indicate the starting point: +/// - `.foo` means "foo in the current package" +/// - `..foo` means "foo in the parent package" +/// - `...foo` means "foo in the grandparent package" +/// +/// This function converts the relative path to an absolute module path based on +/// the current module's location. +/// /// # Arguments /// -/// * `current_module` - The module path of the file doing the import -/// * `level` - The number of dots in the relative import -/// * `file_path` - The file path (used to determine if this is a package) +/// * `current_module` - The module path of the file containing the import (e.g., "pkg.sub.module") +/// * `level` - The number of leading dots in the relative import (e.g., 2 for `..foo`) +/// * `file_path` - The file path (used to check if this is a `__init__.py` package file) /// /// # Returns /// -/// The base module path for the relative import +/// The absolute module path that the relative import refers to. +/// +/// # Examples +/// +/// ```text +/// # In pkg/sub/module.py (module path "pkg.sub.module"): +/// from ..other import Foo # level=2 → resolves to "pkg.other" +/// +/// # In pkg/sub/__init__.py (module path "pkg.sub"): +/// from ..other import Foo # level=2 → resolves to "pkg.other" +/// from .local import Bar # level=1 → resolves to "pkg.sub.local" +/// ``` fn resolve_relative_module(current_module: &str, level: usize, file_path: &Path) -> String { let is_package = file_path .file_name() @@ -266,13 +293,14 @@ fn resolve_relative_module(current_module: &str, level: usize, file_path: &Path) let parts: Vec<&str> = current_module.split('.').collect(); - // If this is a package (__init__.py), level 1 means current package - // If this is a module, level 1 means parent package + // Packages start at themselves, modules start at their parent + // e.g., In "pkg.sub" package: level 1 = "pkg.sub", level 2 = "pkg" + // In "pkg.sub.mod" module: level 1 = "pkg.sub", level 2 = "pkg" let base_level = if is_package { level - 1 } else { level }; - // Go up 'base_level' directories + // Navigate up the package hierarchy if base_level >= parts.len() { - // Going too far up - return empty string + // Trying to go above the root - return empty (error case) String::new() } else { parts[..parts.len() - base_level].join(".") diff --git a/src/registry.rs b/src/registry.rs index 49dccd3..f1e89eb 100644 --- a/src/registry.rs +++ b/src/registry.rs @@ -1,3 +1,9 @@ +//! Class registry for tracking class definitions and resolving references. +//! +//! The registry maintains a complete index of all Python modules, classes, and imports +//! found in a codebase. It provides functionality to resolve class references across +//! modules, handling imports and re-exports correctly. + use std::{ collections::{HashMap, HashSet}, path::PathBuf, @@ -8,35 +14,69 @@ use crate::{ parser::{Import, ParsedFile}, }; +/// Type alias for Python module names (e.g., "foo.bar.baz"). pub type ModuleName = String; +/// Metadata about a Python module. #[derive(Debug, Clone)] pub struct ModuleMetadata { + /// The file system path to this module. pub file_path: PathBuf, + /// Whether this is a package (`__init__.py` file). pub is_package: bool, } +/// Metadata about a Python class definition. #[derive(Debug, Clone)] pub struct ClassMetadata { - // Base classes (unresolved) + /// The base classes this class inherits from. + /// + /// These are stored as unresolved strings (e.g., "Foo" or "module.Foo") + /// and must be resolved using the registry's import information. pub bases: Vec, } +/// A unique identifier for a class within the codebase. +/// +/// Consists of the module path and class name. Note that nested classes +/// are represented with dot notation (e.g., "OuterClass.InnerClass"). #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct ClassId { + /// The module path (e.g., "foo.bar"). pub module: ModuleName, + /// The class name, potentially including nesting (e.g., "Outer.Inner"). pub name: String, } +/// A registry of all modules, classes, and imports in a Python codebase. +/// +/// The registry is built from parsed files and provides methods to: +/// - Look up class definitions by name or module +/// - Resolve class references through imports and re-exports +/// - Track the inheritance relationships between classes pub struct Registry { + /// Metadata for each module in the codebase. pub modules: HashMap, + /// Metadata for each class, indexed by ClassId. pub classes: HashMap, + /// Index of classes organized by module for efficient lookup. pub classes_by_module: HashMap>, + /// Import statements for each module. pub imports: HashMap>, } impl Registry { /// Builds a registry from parsed Python files. + /// + /// This processes all parsed files to create indexes of modules, classes, and imports. + /// + /// # Arguments + /// + /// * `parsed_files` - The collection of parsed Python files + /// + /// # Returns + /// + /// A fully constructed registry ready for class resolution. pub fn build(parsed_files: &[ParsedFile]) -> Result { let mut modules = HashMap::new(); let mut classes = HashMap::new(); @@ -44,7 +84,7 @@ impl Registry { let mut imports = HashMap::new(); for parsed in parsed_files { - // Add module metadata + // Record module metadata modules.insert( parsed.module_path.clone(), ModuleMetadata { @@ -53,7 +93,7 @@ impl Registry { }, ); - // Add classes + // Index all class definitions from this module for class in &parsed.classes { let class_id = ClassId { module: parsed.module_path.clone(), @@ -73,7 +113,7 @@ impl Registry { .insert(class_id); } - // Add imports + // Store import statements for later resolution imports.insert(parsed.module_path.clone(), parsed.imports.clone()); } @@ -85,10 +125,44 @@ impl Registry { }) } + /// Resolves a class name within a given module's context. + /// + /// This method handles the complexity of Python's import system, including: + /// - Direct class references within the same module + /// - Classes imported from other modules + /// - Classes re-exported through `__init__.py` files + /// - Attribute-style class references (e.g., "module.Class") + /// + /// # Algorithm + /// + /// 1. Check if the class exists directly in the specified module + /// 2. If not, consult the module's imports to resolve the name: + /// - Match against `imported_as` names from import statements + /// - Substitute with the actual `imported_item` path + /// 3. Parse the resolved name to find the defining module: + /// - Try progressively shorter prefixes (e.g., "a.b.c" → "a.b" → "a") + /// - Stop when we find a module that exists + /// 4. Recursively resolve the remaining name within that module + /// + /// # Arguments + /// + /// * `module` - The module path providing the context for resolution + /// * `name` - The class name to resolve (may include dots for attribute access) + /// + /// # Returns + /// + /// The resolved `ClassId` if the class can be found, or `None` if resolution fails. + /// + /// # Examples + /// + /// ```text + /// # In module "zoo": + /// from animals import Dog + /// class Puppy(Dog): # Resolves "Dog" to ClassId { module: "animals", name: "Dog" } + /// pass + /// ``` pub fn resolve_class(&self, module: &str, name: &str) -> Option { - // Resolve a class name in a given module. - // - // First, check whether a class with this name exists in the module. If so then we're done! + // First, check for a direct class definition in this module let direct_id = ClassId { module: module.to_string(), name: name.to_string(), @@ -97,49 +171,50 @@ impl Registry { return Some(direct_id); } - // If not, we need to use the imports to resolve the class. + // Not found directly - use imports to resolve the reference let imports = self.imports.get(module)?; - // First, resolve based on the imports in the module. - // Look for a prefix based on `imported_as` and then substitute the related `imported_item`. + // Substitute imported names with their actual module paths + // Example: If "Dog" is imported as "from animals import Dog", + // then "Dog" becomes "animals.Dog" let mut resolved_name = name.to_string(); for import in imports { - // Check if name starts with imported_as if name == import.imported_as { - // Exact match: replace entire name + // Exact match: "Dog" → "animals.Dog" resolved_name = import.imported_item.clone(); break; } else if let Some(remainder) = name.strip_prefix(&format!("{}.", import.imported_as)) { - // Prefix match: substitute prefix + // Prefix match: "Dog.Puppy" → "animals.Dog.Puppy" resolved_name = format!("{}.{}", import.imported_item, remainder); break; } } - // Now we have to find which module resolved_name refers to. - // Split the name into parts + // Parse the resolved name to find the defining module and class + // Example: "animals.Dog" needs to be split into module "animals" and class "Dog" let parts: Vec<&str> = resolved_name.split('.').collect(); - // Try progressively shorter prefixes to find a matching module + // Try each possible split point from longest to shortest prefix + // This handles cases like "a.b.c.Class" where "a.b.c" might be the module for i in (1..=parts.len()).rev() { let module_candidate = parts[..i].join("."); - // Check if this module exists if self.modules.contains_key(&module_candidate) { if i == parts.len() { - // The entire name is a module - this shouldn't be a class + // The entire resolved name is just a module, not a class return None; } - // We found the module, the rest is the class name within that module + // Found the module! The remainder is the class name within it let remainder = parts[i..].join("."); - // Recurse to resolve the class in that module + + // Recursively resolve in case the class itself is re-exported return self.resolve_class(&module_candidate, &remainder); } } - // Couldn't resolve + // Unable to resolve this class reference None } } diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 48f7e31..793a13d 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -661,3 +661,95 @@ class C(testpkg.b.testpkg.a.A): ... temp.close().unwrap(); } + +#[test] +fn test_reexport_not_via_init() { + let temp = assert_fs::TempDir::new().unwrap(); + + temp.child("testpkg/a.py") + .write_str( + r#" +class A: ... +"#, + ) + .unwrap(); + + temp.child("testpkg/b.py") + .write_str( + r#" +from testpkg.a import A + +class B(A): ... +"#, + ) + .unwrap(); + + temp.child("testpkg/c.py") + .write_str( + r#" +from testpkg.b import A + +class C(A): ... +"#, + ) + .unwrap(); + + let mut cmd = Command::cargo_bin("pysubclasses").unwrap(); + cmd.arg("A") + .arg("--directory") + .arg(temp.path()) + .assert() + .success() + .stdout(predicate::str::contains("Found 2 subclass(es) of 'A'")) + .stdout(predicate::str::contains("B")) + .stdout(predicate::str::contains("C")); + + temp.close().unwrap(); +} + +#[test] +fn test_reexport_not_via_init_pass_reexport_module() { + let temp = assert_fs::TempDir::new().unwrap(); + + temp.child("testpkg/a.py") + .write_str( + r#" +class A: ... +"#, + ) + .unwrap(); + + temp.child("testpkg/b.py") + .write_str( + r#" +from testpkg.a import A + +class B(A): ... +"#, + ) + .unwrap(); + + temp.child("testpkg/c.py") + .write_str( + r#" +from testpkg.b import A + +class C(A): ... +"#, + ) + .unwrap(); + + let mut cmd = Command::cargo_bin("pysubclasses").unwrap(); + cmd.arg("A") + .arg("--directory") + .arg(temp.path()) + .arg("--module") + .arg("testpkg.b") + .assert() + .success() + .stdout(predicate::str::contains("Found 2 subclass(es) of 'A'")) + .stdout(predicate::str::contains("B")) + .stdout(predicate::str::contains("C")); + + temp.close().unwrap(); +} From 4d61b344febfbf04b0c05d661b80d081e54353d2 Mon Sep 17 00:00:00 2001 From: Peter Byfield Date: Sun, 26 Oct 2025 13:07:51 +0100 Subject: [PATCH 8/9] Update CLAUDE.md --- CLAUDE.md | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 67cbee8..0509e41 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,9 +1,15 @@ See @README for a general project overview. -# Additional Instructions +## Additional Instructions - Write idiomatic rust code that is simple as possible. Avoid excessive abstraction. -- All code should have tests (unit and/or integration tests). -- Look for opportunities to refactor, simplify and improve your code. -- Ensure that documentation is up to date with the code. +- Look for opportunities to refactor, simplify and improve the code. +- Ensure that code is documented, and that the documentation is up to date with the code. - When you're done making changes test, lint and format the code to check everything. + +## Testing Guidelines + +- All code should have tests (unit and/or integration tests). +- Use the assert_fs and assert_cmd crates for integration testing. +- Where parametric tests are appropriate, use the yare crate. + Always create a struct (e.g. `Case`) to hold test case data. From 58052a5696b4957d59104b069556cff9fa8d7598 Mon Sep 17 00:00:00 2001 From: Peter Byfield Date: Sun, 26 Oct 2025 13:20:28 +0100 Subject: [PATCH 9/9] Add tests for parser --- src/parser.rs | 281 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 281 insertions(+) diff --git a/src/parser.rs b/src/parser.rs index 98a57b3..df3254e 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -495,4 +495,285 @@ class TopLevel(Foo): assert_eq!(doubly_nested.bases.len(), 1); assert_eq!(doubly_nested.bases[0], "Foo"); } + + // Parametric tests for import parsing + #[derive(Debug)] + struct ImportCase { + name: &'static str, + python_code: &'static str, + module_path: &'static str, + is_package: bool, + expected_imports: Vec<(&'static str, &'static str)>, // (imported_item, imported_as) + } + + #[yare::parameterized( + absolute_import_simple = { ImportCase { + name: "absolute import simple", + python_code: "import foo", + module_path: "mymodule", + is_package: false, + expected_imports: vec![("foo", "foo")], + } }, + absolute_import_dotted = { ImportCase { + name: "absolute import dotted", + python_code: "import foo.bar.baz", + module_path: "mymodule", + is_package: false, + expected_imports: vec![("foo.bar.baz", "foo.bar.baz")], + } }, + absolute_import_alias = { ImportCase { + name: "absolute import with alias", + python_code: "import foo as f", + module_path: "mymodule", + is_package: false, + expected_imports: vec![("foo", "f")], + } }, + absolute_import_dotted_alias = { ImportCase { + name: "absolute import dotted with alias", + python_code: "import foo.bar as fb", + module_path: "mymodule", + is_package: false, + expected_imports: vec![("foo.bar", "fb")], + } }, + from_import_simple = { ImportCase { + name: "from import simple", + python_code: "from foo import Bar", + module_path: "mymodule", + is_package: false, + expected_imports: vec![("foo.Bar", "Bar")], + } }, + from_import_dotted = { ImportCase { + name: "from import dotted", + python_code: "from foo.bar import Baz", + module_path: "mymodule", + is_package: false, + expected_imports: vec![("foo.bar.Baz", "Baz")], + } }, + from_import_alias = { ImportCase { + name: "from import with alias", + python_code: "from foo import Bar as B", + module_path: "mymodule", + is_package: false, + expected_imports: vec![("foo.Bar", "B")], + } }, + from_import_multiple = { ImportCase { + name: "from import multiple", + python_code: "from foo import Bar, Baz", + module_path: "mymodule", + is_package: false, + expected_imports: vec![("foo.Bar", "Bar"), ("foo.Baz", "Baz")], + } }, + from_import_multiple_alias = { ImportCase { + name: "from import multiple with alias", + python_code: "from foo import Bar as B, Baz as Z", + module_path: "mymodule", + is_package: false, + expected_imports: vec![("foo.Bar", "B"), ("foo.Baz", "Z")], + } }, + relative_import_one_level_module = { ImportCase { + name: "relative import one level from module", + python_code: "from .sibling import Foo", + module_path: "pkg.mymodule", + is_package: false, + expected_imports: vec![("pkg.sibling.Foo", "Foo")], + } }, + relative_import_one_level_package = { ImportCase { + name: "relative import one level from package", + python_code: "from .sibling import Foo", + module_path: "pkg", + is_package: true, + expected_imports: vec![("pkg.sibling.Foo", "Foo")], + } }, + relative_import_two_levels_module = { ImportCase { + name: "relative import two levels from module", + python_code: "from ..other import Foo", + module_path: "pkg.sub.mymodule", + is_package: false, + expected_imports: vec![("pkg.other.Foo", "Foo")], + } }, + relative_import_two_levels_package = { ImportCase { + name: "relative import two levels from package", + python_code: "from ..other import Foo", + module_path: "pkg.sub", + is_package: true, + expected_imports: vec![("pkg.other.Foo", "Foo")], + } }, + relative_import_no_module = { ImportCase { + name: "relative import without from module", + python_code: "from . import Foo", + module_path: "pkg.mymodule", + is_package: false, + expected_imports: vec![("pkg.Foo", "Foo")], + } }, + relative_import_alias = { ImportCase { + name: "relative import with alias", + python_code: "from .sibling import Foo as F", + module_path: "pkg.mymodule", + is_package: false, + expected_imports: vec![("pkg.sibling.Foo", "F")], + } }, + multiple_imports = { ImportCase { + name: "multiple import statements", + python_code: "import foo\nfrom bar import Baz", + module_path: "mymodule", + is_package: false, + expected_imports: vec![("foo", "foo"), ("bar.Baz", "Baz")], + } }, + )] + fn test_import_parsing(case: ImportCase) { + let temp_dir = std::env::temp_dir(); + // Create a proper directory structure for package tests + let test_id = case.name.replace(' ', "_"); + let test_root = temp_dir.join(format!("test_imports_{test_id}")); + std::fs::create_dir_all(&test_root).unwrap(); + + let temp_file = if case.is_package { + test_root.join("__init__.py") + } else { + test_root.join("test.py") + }; + + std::fs::write(&temp_file, case.python_code).unwrap(); + + let parsed = parse_file(&temp_file, case.module_path).unwrap(); + + // Clean up + let _ = std::fs::remove_dir_all(&test_root); + + // Verify imports + assert_eq!( + parsed.imports.len(), + case.expected_imports.len(), + "Case '{}': expected {} imports, got {}", + case.name, + case.expected_imports.len(), + parsed.imports.len() + ); + + for (i, (expected_item, expected_as)) in case.expected_imports.iter().enumerate() { + assert_eq!( + parsed.imports[i].imported_item, *expected_item, + "Case '{}': import {} - expected imported_item '{}', got '{}'", + case.name, i, expected_item, parsed.imports[i].imported_item + ); + assert_eq!( + parsed.imports[i].imported_as, *expected_as, + "Case '{}': import {} - expected imported_as '{}', got '{}'", + case.name, i, expected_as, parsed.imports[i].imported_as + ); + } + } + + // Parametric tests for base class extraction + #[derive(Debug)] + struct BaseClassCase { + name: &'static str, + python_code: &'static str, + class_name: &'static str, + expected_bases: Vec<&'static str>, + } + + #[yare::parameterized( + simple_base = { BaseClassCase { + name: "simple base class", + python_code: "class Foo(Bar): pass", + class_name: "Foo", + expected_bases: vec!["Bar"], + } }, + attribute_base = { BaseClassCase { + name: "attribute base class", + python_code: "class Foo(module.Bar): pass", + class_name: "Foo", + expected_bases: vec!["module.Bar"], + } }, + nested_attribute_base = { BaseClassCase { + name: "nested attribute base class", + python_code: "class Foo(pkg.module.Bar): pass", + class_name: "Foo", + expected_bases: vec!["pkg.module.Bar"], + } }, + multiple_bases_simple = { BaseClassCase { + name: "multiple simple bases", + python_code: "class Foo(Bar, Baz): pass", + class_name: "Foo", + expected_bases: vec!["Bar", "Baz"], + } }, + multiple_bases_mixed = { BaseClassCase { + name: "multiple mixed bases", + python_code: "class Foo(Bar, pkg.Baz): pass", + class_name: "Foo", + expected_bases: vec!["Bar", "pkg.Baz"], + } }, + generic_base = { BaseClassCase { + name: "generic base class", + python_code: "class Foo(Generic[T]): pass", + class_name: "Foo", + expected_bases: vec!["Generic"], + } }, + generic_with_multiple = { BaseClassCase { + name: "generic with multiple type params", + python_code: "class Foo(Dict[str, int]): pass", + class_name: "Foo", + expected_bases: vec!["Dict"], + } }, + mixed_generic_and_simple = { BaseClassCase { + name: "mixed generic and simple", + python_code: "class Foo(Bar, Generic[T]): pass", + class_name: "Foo", + expected_bases: vec!["Bar", "Generic"], + } }, + no_bases = { BaseClassCase { + name: "no base classes", + python_code: "class Foo: pass", + class_name: "Foo", + expected_bases: vec![], + } }, + attribute_generic = { BaseClassCase { + name: "attribute with generic", + python_code: "class Foo(typing.Generic[T]): pass", + class_name: "Foo", + expected_bases: vec!["typing.Generic"], + } }, + )] + fn test_base_class_extraction(case: BaseClassCase) { + let temp_dir = std::env::temp_dir(); + let temp_file = temp_dir.join(format!("test_bases_{}.py", case.name)); + + std::fs::write(&temp_file, case.python_code).unwrap(); + + let parsed = parse_file(&temp_file, "test_module").unwrap(); + + // Clean up + let _ = std::fs::remove_file(&temp_file); + + // Find the class + let class = parsed + .classes + .iter() + .find(|c| c.name == case.class_name) + .unwrap_or_else(|| { + panic!( + "Case '{}': class '{}' not found", + case.name, case.class_name + ) + }); + + // Verify bases + assert_eq!( + class.bases.len(), + case.expected_bases.len(), + "Case '{}': expected {} bases, got {}", + case.name, + case.expected_bases.len(), + class.bases.len() + ); + + for (i, expected_base) in case.expected_bases.iter().enumerate() { + assert_eq!( + class.bases[i], *expected_base, + "Case '{}': base {} - expected '{}', got '{}'", + case.name, i, expected_base, class.bases[i] + ); + } + } }