Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions CLAUDE.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
See @README for a general project overview.

# Additional Instructions
## Additional Instructions

- Write idiomatic rust code that is simple as possible. Avoid excessive abstraction.
- All code should have tests (unit and/or integration tests).
- Look for opportunities to refactor, simplify and improve your code.
- Ensure that documentation is up to date with the code.
- Look for opportunities to refactor, simplify and improve the code.
- Ensure that code is documented, and that the documentation is up to date with the code.
- When you're done making changes test, lint and format the code to check everything.

## Testing Guidelines

- All code should have tests (unit and/or integration tests).
- Use the assert_fs and assert_cmd crates for integration testing.
- Where parametric tests are appropriate, use the yare crate.
Always create a struct (e.g. `Case`) to hold test case data.
296 changes: 73 additions & 223 deletions src/graph.rs
Original file line number Diff line number Diff line change
@@ -1,40 +1,57 @@
//! Inheritance graph construction and traversal.
//!
//! This module provides functionality to build and query a class inheritance graph.
//! The graph maps parent classes to their direct and transitive children, enabling
//! efficient subclass discovery.

use std::collections::{HashMap, HashSet, VecDeque};
use std::collections::{HashMap, HashSet};

use crate::registry::{ClassId, ClassRegistry};
use crate::registry::{ClassId, Registry};

/// An inheritance graph mapping classes to their children.
/// An inheritance graph representing parent-child relationships between classes.
///
/// This structure maps each class to the set of classes that directly inherit from it.
/// The graph can be used to efficiently find all descendants of a given class using
/// breadth-first search.
pub struct InheritanceGraph {
/// Map from parent ClassId to child ClassIds
children: HashMap<ClassId, Vec<ClassId>>,
/// Maps parent classes to their direct children.
pub children: HashMap<ClassId, HashSet<ClassId>>,
}

impl InheritanceGraph {
/// Builds an inheritance graph from a class registry.
/// Builds an inheritance graph from a registry.
///
/// This resolves all base class references and constructs parent-to-child
/// relationships. The registry's `resolve_class` method is used to handle
/// imports and re-exports correctly.
///
/// # Arguments
///
/// * `registry` - The class registry containing all classes
/// * `registry` - The registry containing all classes and their base class references
///
/// # Returns
///
/// An inheritance graph with parent-child relationships.
pub fn build(registry: &ClassRegistry) -> Self {
let mut children: HashMap<ClassId, Vec<ClassId>> = HashMap::new();

// For each class, resolve its base classes and add edges
for class_id in registry.all_class_ids() {
if let Some(info) = registry.get(&class_id) {
for base in &info.bases {
// Try to resolve the base class
if let Some(parent_id) = registry.resolve_base(base, &class_id.module_path) {
// Add edge from parent to child
children
.entry(parent_id)
.or_default()
.push(class_id.clone());
}
/// An inheritance graph ready for traversal.
///
/// # Algorithm
///
/// For each class in the registry:
/// 1. Resolve each of its base class names to a `ClassId`
/// 2. Add this class to the parent's children set
/// 3. Skip any base classes that cannot be resolved (e.g., external dependencies)
pub fn build(registry: &Registry) -> Self {
let mut children: HashMap<ClassId, HashSet<ClassId>> = HashMap::new();

// Build parent → children edges by examining each class's bases
for (child_id, metadata) in &registry.classes {
for base_name in &metadata.bases {
// Resolve the base class reference in this class's module context
if let Some(parent_id) = registry.resolve_class(&child_id.module, base_name) {
// Add this class as a child of its parent
children
.entry(parent_id)
.or_default()
.insert(child_id.clone());
}
}
}
Expand All @@ -44,40 +61,53 @@ impl InheritanceGraph {

/// Finds all transitive subclasses of a given class.
///
/// Uses BFS to traverse the inheritance graph and collect all descendants.
/// Performs a breadth-first search to discover all classes that directly or
/// indirectly inherit from the specified root class. This includes:
/// - Direct children (one level of inheritance)
/// - Grandchildren (two levels)
/// - Great-grandchildren, etc. (any depth)
///
/// # Arguments
///
/// * `root` - The root class to find subclasses of
/// * `root` - The class to find subclasses for
///
/// # Returns
///
/// A vector of all transitive subclasses (not including the root class itself).
/// A vector containing all transitive subclasses. The root class itself is not
/// included in the result. The order of classes in the vector is determined by
/// the BFS traversal order.
///
/// # Examples
///
/// ```text
/// Given:
/// class Animal: pass
/// class Mammal(Animal): pass
/// class Dog(Mammal): pass
/// class Cat(Mammal): pass
///
/// find_all_subclasses(Animal) → [Mammal, Dog, Cat]
/// ```
pub fn find_all_subclasses(&self, root: &ClassId) -> Vec<ClassId> {
use std::collections::VecDeque;

let mut result = Vec::new();
let mut visited = HashSet::new();
let mut queue = VecDeque::new();

// Start with the root's immediate children
if let Some(children) = self.children.get(root) {
for child in children {
queue.push_back(child.clone());
}
}
// Initialize BFS with the root class
queue.push_back(root.clone());
visited.insert((root.module.clone(), root.name.clone()));

// BFS traversal
while let Some(class_id) = queue.pop_front() {
// Skip if already visited (handles potential cycles)
if !visited.insert(class_id.clone()) {
continue;
}

result.push(class_id.clone());

// Add this class's children to the queue
if let Some(children) = self.children.get(&class_id) {
while let Some(current) = queue.pop_front() {
// Examine all direct children of the current class
if let Some(children) = self.children.get(&current) {
for child in children {
if !visited.contains(child) {
let key = (child.module.clone(), child.name.clone());
if !visited.contains(&key) {
visited.insert(key);
result.push(child.clone());
queue.push_back(child.clone());
}
}
Expand All @@ -87,183 +117,3 @@ impl InheritanceGraph {
result
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::parser::{BaseClass, ClassDefinition};
use crate::registry::ClassRegistry;
use std::path::PathBuf;

#[test]
fn test_simple_inheritance() {
let registry = ClassRegistry::new(vec![
// Animal (base class)
crate::parser::ParsedFile {
module_path: "animals".to_string(),
classes: vec![ClassDefinition {
name: "Animal".to_string(),
module_path: "animals".to_string(),
file_path: PathBuf::from("animals.py"),
bases: vec![],
}],
imports: vec![],
is_package: false,
},
// Dog(Animal)
crate::parser::ParsedFile {
module_path: "animals".to_string(),
classes: vec![ClassDefinition {
name: "Dog".to_string(),
module_path: "animals".to_string(),
file_path: PathBuf::from("animals.py"),
bases: vec![BaseClass::Simple("Animal".to_string())],
}],
imports: vec![],
is_package: false,
},
]);

let graph = InheritanceGraph::build(&registry);

let animal_id = ClassId::new("animals".to_string(), "Animal".to_string());
let subclasses = graph.find_all_subclasses(&animal_id);

assert_eq!(subclasses.len(), 1);
assert_eq!(subclasses[0].class_name, "Dog");
}

#[test]
fn test_transitive_inheritance() {
let registry = ClassRegistry::new(vec![
// Animal
crate::parser::ParsedFile {
module_path: "base".to_string(),
classes: vec![ClassDefinition {
name: "Animal".to_string(),
module_path: "base".to_string(),
file_path: PathBuf::from("base.py"),
bases: vec![],
}],
imports: vec![],
is_package: false,
},
// Mammal(Animal)
crate::parser::ParsedFile {
module_path: "base".to_string(),
classes: vec![ClassDefinition {
name: "Mammal".to_string(),
module_path: "base".to_string(),
file_path: PathBuf::from("base.py"),
bases: vec![BaseClass::Simple("Animal".to_string())],
}],
imports: vec![],
is_package: false,
},
// Dog(Mammal)
crate::parser::ParsedFile {
module_path: "base".to_string(),
classes: vec![ClassDefinition {
name: "Dog".to_string(),
module_path: "base".to_string(),
file_path: PathBuf::from("base.py"),
bases: vec![BaseClass::Simple("Mammal".to_string())],
}],
imports: vec![],
is_package: false,
},
]);

let graph = InheritanceGraph::build(&registry);

let animal_id = ClassId::new("base".to_string(), "Animal".to_string());
let subclasses = graph.find_all_subclasses(&animal_id);

// Should find both Mammal and Dog
assert_eq!(subclasses.len(), 2);

let names: HashSet<_> = subclasses.iter().map(|c| c.class_name.as_str()).collect();
assert!(names.contains("Mammal"));
assert!(names.contains("Dog"));
}

#[test]
fn test_multiple_inheritance() {
let registry = ClassRegistry::new(vec![
// Animal
crate::parser::ParsedFile {
module_path: "base".to_string(),
classes: vec![ClassDefinition {
name: "Animal".to_string(),
module_path: "base".to_string(),
file_path: PathBuf::from("base.py"),
bases: vec![],
}],
imports: vec![],
is_package: false,
},
// Pet
crate::parser::ParsedFile {
module_path: "base".to_string(),
classes: vec![ClassDefinition {
name: "Pet".to_string(),
module_path: "base".to_string(),
file_path: PathBuf::from("base.py"),
bases: vec![],
}],
imports: vec![],
is_package: false,
},
// Dog(Animal, Pet)
crate::parser::ParsedFile {
module_path: "base".to_string(),
classes: vec![ClassDefinition {
name: "Dog".to_string(),
module_path: "base".to_string(),
file_path: PathBuf::from("base.py"),
bases: vec![
BaseClass::Simple("Animal".to_string()),
BaseClass::Simple("Pet".to_string()),
],
}],
imports: vec![],
is_package: false,
},
]);

let graph = InheritanceGraph::build(&registry);

// Dog should be a subclass of both Animal and Pet
let animal_id = ClassId::new("base".to_string(), "Animal".to_string());
let animal_subclasses = graph.find_all_subclasses(&animal_id);
assert_eq!(animal_subclasses.len(), 1);
assert_eq!(animal_subclasses[0].class_name, "Dog");

let pet_id = ClassId::new("base".to_string(), "Pet".to_string());
let pet_subclasses = graph.find_all_subclasses(&pet_id);
assert_eq!(pet_subclasses.len(), 1);
assert_eq!(pet_subclasses[0].class_name, "Dog");
}

#[test]
fn test_no_subclasses() {
let registry = ClassRegistry::new(vec![crate::parser::ParsedFile {
module_path: "base".to_string(),
classes: vec![ClassDefinition {
name: "Animal".to_string(),
module_path: "base".to_string(),
file_path: PathBuf::from("base.py"),
bases: vec![],
}],
imports: vec![],
is_package: false,
}]);

let graph = InheritanceGraph::build(&registry);

let animal_id = ClassId::new("base".to_string(), "Animal".to_string());
let subclasses = graph.find_all_subclasses(&animal_id);

assert_eq!(subclasses.len(), 0);
}
}
Loading