Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions bank_account.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
class BankAccount:
"""A simple BankAccount class."""

def __init__(self, account_holder: str, initial_balance: float = 0.0):
"""Initialize the account with the account holder's name and initial balance."""
self.account_holder = account_holder
self.balance = initial_balance

def deposit(self, amount: float):
"""Deposit an amount into the account."""
if amount <= 0:
raise ValueError("Deposit amount must be positive.")
self.balance += amount
return self.balance

def withdraw(self, amount: float):
"""Withdraw an amount from the account."""
if amount <= 0:
raise ValueError("Withdrawal amount must be positive.")
if amount > self.balance:
raise ValueError("Insufficient balance.")
self.balance -= amount
return self.balance

def get_balance(self):
"""Return the current balance of the account."""
return self.balance

def __str__(self):
"""Return a string representation of the account."""
return f"BankAccount({self.account_holder}, Balance: {self.balance})"

30 changes: 30 additions & 0 deletions checking_account.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from bank_account import BankAccount

class CheckingAccount(BankAccount):
"""A CheckingAccount class that extends BankAccount with overdraft protection."""

def __init__(self, account_holder: str, initial_balance: float = 0.0, overdraft_limit: float = 500.0):
"""
Initialize the checking account with an overdraft limit.
:param overdraft_limit: Maximum overdraft amount allowed (default is $500.0).
"""
super().__init__(account_holder, initial_balance)
self.overdraft_limit = overdraft_limit

def withdraw(self, amount: float):
"""
Withdraw an amount from the account, considering overdraft protection.
:param amount: The amount to withdraw.
:raises ValueError: If withdrawal exceeds overdraft limit.
"""
if amount <= 0:
raise ValueError("Withdrawal amount must be positive.")
if amount > self.balance + self.overdraft_limit:
raise ValueError("Withdrawal exceeds overdraft limit.")
self.balance -= amount
return self.balance

def __str__(self):
"""Return a string representation of the checking account."""
return f"CheckingAccount({self.account_holder}, Balance: {self.balance}, Overdraft Limit: {self.overdraft_limit})"

18 changes: 18 additions & 0 deletions generated_test_cases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest
from combined_test_cases import *

# Initialize instances
Person_1 = Person('woARN', -85)
Company_1 = Company('dDveR')

# Test cases
def test_Company_add_employee_0():
instance = Company_1
result = instance.add_employee(Person())
assert isinstance(result, NoneType)

def test_Company_add_employee_1():
instance = Company_1
result = instance.add_employee(Person())
assert isinstance(result, NoneType)

202 changes: 177 additions & 25 deletions randoop_cli/cli.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import click
from pathlib import Path
import requests
import zipfile
import os
import shutil
from io import BytesIO
from .module_loader import load_module
from .class_inspection import get_classes
from .test_generator import randoop_test_generator, write_regression_tests
import time
from rich.console import Console
from rich.progress import Progress
import time

console = Console()

Expand All @@ -17,6 +23,92 @@ def simulate_loading(task_name, steps=5, delay=0.5):
time.sleep(delay) # Simulating a delay
progress.update(task, advance=1)

def download_and_extract_repo(repo_url, temp_dir):
"""
Downloads a GitHub repository as a zip file and extracts it to a temporary directory.
"""
if not repo_url.endswith(".git"):
repo_url = repo_url.rstrip("/") + ".git"

zip_url = repo_url.replace(".git", "/archive/refs/heads/master.zip")
console.print(f"Downloading repository from: {zip_url}")

response = requests.get(zip_url, stream=True)
if response.status_code == 200:
with zipfile.ZipFile(BytesIO(response.content)) as zf:
zf.extractall(temp_dir)
else:
console.print(f"[bold red]Failed to download the repository: {response.status_code}[/bold red]")
exit(1)

def identify_source_files(repo_path):
"""
Identifies source code files in the repository by excluding non-code files like docs, tests, and examples.
"""
source_files = []
for root, dirs, files in os.walk(repo_path):
# Ignore non-source directories
ignored_dirs = {"docs", "examples", "tests", ".github", "__pycache__"}
dirs[:] = [d for d in dirs if d not in ignored_dirs]

for file in files:
if file.endswith(".py") and not file.startswith("__init__"):
source_files.append(Path(root) / file)

return source_files

import re
from collections import defaultdict, deque

def parse_imports(file_path):
"""
Parse a Python file to extract imported modules or files.
"""
imports = []
with open(file_path, "r", encoding="utf-8") as f:
for line in f:
match = re.match(r"^\s*(?:from|import) (\S+)", line)
if match:
imports.append(match.group(1).split('.')[0]) # Extract the module name
return imports

def resolve_dependencies(source_files):
"""
Resolve file loading order based on import dependencies using topological sorting.
"""
dependency_graph = defaultdict(set)
file_map = {file.stem: file for file in source_files}

# Build the dependency graph
for file in source_files:
imports = parse_imports(file)
for module in imports:
if module in file_map: # Only consider local modules
dependency_graph[file.stem].add(module)

# Perform topological sorting
visited = set()
resolved = []
temp_stack = set()

def visit(node):
if node in temp_stack:
raise ValueError(f"Circular dependency detected: {node}")
if node not in visited:
temp_stack.add(node)
for dep in dependency_graph[node]:
visit(dep)
temp_stack.remove(node)
visited.add(node)
resolved.append(node)

for file in source_files:
if file.stem not in visited:
visit(file.stem)

# Return files in the resolved order
return [file_map[stem] for stem in resolved]

@click.command()
@click.option(
"-k",
Expand All @@ -26,43 +118,103 @@ def simulate_loading(task_name, steps=5, delay=0.5):
help="Number of times to extend the sequence (default: 2)",
show_default=True,
)
@click.option(
"--repo-url",
type=str,
default=None,
help="GitHub repository URL to process.",
)
@click.option(
"-f",
"--file-path",
"--file",
"file_paths",
type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path),
required=True,
help="Path to the Python file with class definitions",
multiple=True,
help="Path to individual Python files to process (use -f multiple times for multiple files).",
)
def main(sequence_length, file_path):
"""Python Randoop test generator for Python classes."""
def main(sequence_length, repo_url, file_paths):
"""
Process a GitHub repository URL or multiple Python source files and generate test cases for their classes.
"""
console.print("[bold blue]Randoop-Python Test Generator[/bold blue]\n")

# Simulate loading for module loading
simulate_loading("Loading module")
module = load_module(file_path)
# Temporary directory for repository
temp_dir = Path("temp_repo")

# Simulate loading for class inspection
simulate_loading("Inspecting classes")
classes = get_classes(module)
try:
shared_namespace = {}

if repo_url:
# If a GitHub repository URL is provided, process the repository
console.print("[bold green]Processing GitHub repository...[/bold green]")
simulate_loading("Downloading repository")
download_and_extract_repo(repo_url, temp_dir)

# Identify source files in the repository
repo_root = next(temp_dir.iterdir()) # First directory inside the extracted repo
simulate_loading("Identifying source files")
source_files = identify_source_files(repo_root)
elif file_paths:
# If files are provided via -f, process them
console.print("[bold green]Processing provided files...[/bold green]")
source_files = list(file_paths)
else:
console.print("[bold red]No repository URL or files provided. Please specify one.[/bold red]")
exit(1)

# Resolve dependencies and sort files
source_files = resolve_dependencies(source_files)

# Load all source files into the shared namespace
for file_path in source_files:
console.print(f"\n[bold yellow]Processing file: {file_path}[/bold yellow]\n")
simulate_loading("Loading module")
load_module(file_path, shared_namespace)

# Inspect classes
simulate_loading("Inspecting classes")
all_classes = [
(name, obj)
for name, obj in shared_namespace.items()
if isinstance(obj, type) # Only consider class types
]
if not all_classes:
console.print("[bold red]No classes found in the source files.[/bold red]")
exit(1)

# Generate tests
simulate_loading("Generating Random tests")
sequences, error_prone_cases, storage, instance_creation_data = randoop_test_generator(
all_classes, sequence_length
)

# Display Successful Sequences
console.print("\n[bold green]-----> Generated Instances and Sequences:[/bold green]")
for seq in sequences:
print(seq)

# Display Error-Prone Test Cases
console.print("\n[bold yellow]-----> Error-Prone Test Cases:[/bold yellow]")
for error in error_prone_cases:
print(error)

if not classes:
console.print(f"[bold red]No classes found in '{file_path}'.[/bold red]")
exit(1)

# Simulate loading for test generation
simulate_loading("Generating Random tests")
test_results = randoop_test_generator(classes, sequence_length)
# Write test cases
write_test_cases(
sequences=sequences,
storage=storage,
module_name="combined_namespace",
file_path="combined_test_cases.py",
instance_creation_data=instance_creation_data,
)

# Display Successful Sequences
print("\n-----> Generated Instances and Sequences:")
for seq in test_results["sequences"]:
print(seq)
console.print("[bold green]All tasks completed successfully![/bold green]")

print("\n-----> Error-Prone Test Cases:")
for error in test_results["error_cases"]:
print(error)
finally:
# Cleanup temporary directory if used
if repo_url and temp_dir.exists():
shutil.rmtree(temp_dir)

console.print("[bold green]All tasks completed successfully![/bold green]")

if __name__ == "__main__":
main()
24 changes: 21 additions & 3 deletions randoop_cli/module_loader.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,27 @@
import importlib.util
import sys
import os

def load_module(file_path):
spec = importlib.util.spec_from_file_location("module.name", file_path)
def load_module(file_path, namespace):
"""
Load a Python module from a file path into a shared namespace.
Skip setup.py or files that shouldn't be loaded.
"""
# Skip setup.py as it causes issues during package loading
if file_path.name == "setup.py":
print(f"Skipping {file_path} due to potential issues with loading")
return

module_name = file_path.stem
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
sys.modules["module.name"] = module
sys.modules[module_name] = module

spec.loader.exec_module(module)

# Merge module attributes into the shared namespace
for name, obj in vars(module).items():
if not name.startswith("_"): # Ignore private attributes
namespace[name] = obj

return module
23 changes: 23 additions & 0 deletions savings_account.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from bank_account import BankAccount

class SavingsAccount(BankAccount):
"""A SavingsAccount class that extends BankAccount with interest calculation."""

def __init__(self, account_holder: str, initial_balance: float = 0.0, interest_rate: float = 0.02):
"""
Initialize the savings account with an interest rate.
:param interest_rate: Annual interest rate as a decimal (default is 0.02 for 2%).
"""
super().__init__(account_holder, initial_balance)
self.interest_rate = interest_rate

def add_interest(self):
"""Add interest to the current balance."""
interest = self.balance * self.interest_rate
self.deposit(interest)
return self.balance

def __str__(self):
"""Return a string representation of the savings account."""
return f"SavingsAccount({self.account_holder}, Balance: {self.balance}, Interest Rate: {self.interest_rate})"