diff --git a/bank_account.py b/bank_account.py new file mode 100644 index 0000000..f02514c --- /dev/null +++ b/bank_account.py @@ -0,0 +1,32 @@ +class BankAccount: + """A simple BankAccount class.""" + + def __init__(self, account_holder: str, initial_balance: float = 0.0): + """Initialize the account with the account holder's name and initial balance.""" + self.account_holder = account_holder + self.balance = initial_balance + + def deposit(self, amount: float): + """Deposit an amount into the account.""" + if amount <= 0: + raise ValueError("Deposit amount must be positive.") + self.balance += amount + return self.balance + + def withdraw(self, amount: float): + """Withdraw an amount from the account.""" + if amount <= 0: + raise ValueError("Withdrawal amount must be positive.") + if amount > self.balance: + raise ValueError("Insufficient balance.") + self.balance -= amount + return self.balance + + def get_balance(self): + """Return the current balance of the account.""" + return self.balance + + def __str__(self): + """Return a string representation of the account.""" + return f"BankAccount({self.account_holder}, Balance: {self.balance})" + diff --git a/checking_account.py b/checking_account.py new file mode 100644 index 0000000..0548cfa --- /dev/null +++ b/checking_account.py @@ -0,0 +1,30 @@ +from bank_account import BankAccount + +class CheckingAccount(BankAccount): + """A CheckingAccount class that extends BankAccount with overdraft protection.""" + + def __init__(self, account_holder: str, initial_balance: float = 0.0, overdraft_limit: float = 500.0): + """ + Initialize the checking account with an overdraft limit. + :param overdraft_limit: Maximum overdraft amount allowed (default is $500.0). + """ + super().__init__(account_holder, initial_balance) + self.overdraft_limit = overdraft_limit + + def withdraw(self, amount: float): + """ + Withdraw an amount from the account, considering overdraft protection. + :param amount: The amount to withdraw. + :raises ValueError: If withdrawal exceeds overdraft limit. + """ + if amount <= 0: + raise ValueError("Withdrawal amount must be positive.") + if amount > self.balance + self.overdraft_limit: + raise ValueError("Withdrawal exceeds overdraft limit.") + self.balance -= amount + return self.balance + + def __str__(self): + """Return a string representation of the checking account.""" + return f"CheckingAccount({self.account_holder}, Balance: {self.balance}, Overdraft Limit: {self.overdraft_limit})" + diff --git a/generated_test_cases.py b/generated_test_cases.py new file mode 100644 index 0000000..23bdfbe --- /dev/null +++ b/generated_test_cases.py @@ -0,0 +1,18 @@ +import pytest +from combined_test_cases import * + +# Initialize instances +Person_1 = Person('woARN', -85) +Company_1 = Company('dDveR') + +# Test cases +def test_Company_add_employee_0(): + instance = Company_1 + result = instance.add_employee(Person()) + assert isinstance(result, NoneType) + +def test_Company_add_employee_1(): + instance = Company_1 + result = instance.add_employee(Person()) + assert isinstance(result, NoneType) + diff --git a/randoop_cli/cli.py b/randoop_cli/cli.py index cdd3424..f410d1d 100644 --- a/randoop_cli/cli.py +++ b/randoop_cli/cli.py @@ -1,11 +1,17 @@ import click from pathlib import Path +import requests +import zipfile +import os +import shutil +from io import BytesIO from .module_loader import load_module from .class_inspection import get_classes from .test_generator import randoop_test_generator, write_regression_tests import time from rich.console import Console from rich.progress import Progress +import time console = Console() @@ -17,6 +23,92 @@ def simulate_loading(task_name, steps=5, delay=0.5): time.sleep(delay) # Simulating a delay progress.update(task, advance=1) +def download_and_extract_repo(repo_url, temp_dir): + """ + Downloads a GitHub repository as a zip file and extracts it to a temporary directory. + """ + if not repo_url.endswith(".git"): + repo_url = repo_url.rstrip("/") + ".git" + + zip_url = repo_url.replace(".git", "/archive/refs/heads/master.zip") + console.print(f"Downloading repository from: {zip_url}") + + response = requests.get(zip_url, stream=True) + if response.status_code == 200: + with zipfile.ZipFile(BytesIO(response.content)) as zf: + zf.extractall(temp_dir) + else: + console.print(f"[bold red]Failed to download the repository: {response.status_code}[/bold red]") + exit(1) + +def identify_source_files(repo_path): + """ + Identifies source code files in the repository by excluding non-code files like docs, tests, and examples. + """ + source_files = [] + for root, dirs, files in os.walk(repo_path): + # Ignore non-source directories + ignored_dirs = {"docs", "examples", "tests", ".github", "__pycache__"} + dirs[:] = [d for d in dirs if d not in ignored_dirs] + + for file in files: + if file.endswith(".py") and not file.startswith("__init__"): + source_files.append(Path(root) / file) + + return source_files + +import re +from collections import defaultdict, deque + +def parse_imports(file_path): + """ + Parse a Python file to extract imported modules or files. + """ + imports = [] + with open(file_path, "r", encoding="utf-8") as f: + for line in f: + match = re.match(r"^\s*(?:from|import) (\S+)", line) + if match: + imports.append(match.group(1).split('.')[0]) # Extract the module name + return imports + +def resolve_dependencies(source_files): + """ + Resolve file loading order based on import dependencies using topological sorting. + """ + dependency_graph = defaultdict(set) + file_map = {file.stem: file for file in source_files} + + # Build the dependency graph + for file in source_files: + imports = parse_imports(file) + for module in imports: + if module in file_map: # Only consider local modules + dependency_graph[file.stem].add(module) + + # Perform topological sorting + visited = set() + resolved = [] + temp_stack = set() + + def visit(node): + if node in temp_stack: + raise ValueError(f"Circular dependency detected: {node}") + if node not in visited: + temp_stack.add(node) + for dep in dependency_graph[node]: + visit(dep) + temp_stack.remove(node) + visited.add(node) + resolved.append(node) + + for file in source_files: + if file.stem not in visited: + visit(file.stem) + + # Return files in the resolved order + return [file_map[stem] for stem in resolved] + @click.command() @click.option( "-k", @@ -26,43 +118,103 @@ def simulate_loading(task_name, steps=5, delay=0.5): help="Number of times to extend the sequence (default: 2)", show_default=True, ) +@click.option( + "--repo-url", + type=str, + default=None, + help="GitHub repository URL to process.", +) @click.option( "-f", - "--file-path", + "--file", + "file_paths", type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path), - required=True, - help="Path to the Python file with class definitions", + multiple=True, + help="Path to individual Python files to process (use -f multiple times for multiple files).", ) -def main(sequence_length, file_path): - """Python Randoop test generator for Python classes.""" +def main(sequence_length, repo_url, file_paths): + """ + Process a GitHub repository URL or multiple Python source files and generate test cases for their classes. + """ console.print("[bold blue]Randoop-Python Test Generator[/bold blue]\n") - # Simulate loading for module loading - simulate_loading("Loading module") - module = load_module(file_path) + # Temporary directory for repository + temp_dir = Path("temp_repo") - # Simulate loading for class inspection - simulate_loading("Inspecting classes") - classes = get_classes(module) + try: + shared_namespace = {} + + if repo_url: + # If a GitHub repository URL is provided, process the repository + console.print("[bold green]Processing GitHub repository...[/bold green]") + simulate_loading("Downloading repository") + download_and_extract_repo(repo_url, temp_dir) + + # Identify source files in the repository + repo_root = next(temp_dir.iterdir()) # First directory inside the extracted repo + simulate_loading("Identifying source files") + source_files = identify_source_files(repo_root) + elif file_paths: + # If files are provided via -f, process them + console.print("[bold green]Processing provided files...[/bold green]") + source_files = list(file_paths) + else: + console.print("[bold red]No repository URL or files provided. Please specify one.[/bold red]") + exit(1) + + # Resolve dependencies and sort files + source_files = resolve_dependencies(source_files) + + # Load all source files into the shared namespace + for file_path in source_files: + console.print(f"\n[bold yellow]Processing file: {file_path}[/bold yellow]\n") + simulate_loading("Loading module") + load_module(file_path, shared_namespace) + + # Inspect classes + simulate_loading("Inspecting classes") + all_classes = [ + (name, obj) + for name, obj in shared_namespace.items() + if isinstance(obj, type) # Only consider class types + ] + if not all_classes: + console.print("[bold red]No classes found in the source files.[/bold red]") + exit(1) + + # Generate tests + simulate_loading("Generating Random tests") + sequences, error_prone_cases, storage, instance_creation_data = randoop_test_generator( + all_classes, sequence_length + ) + + # Display Successful Sequences + console.print("\n[bold green]-----> Generated Instances and Sequences:[/bold green]") + for seq in sequences: + print(seq) + + # Display Error-Prone Test Cases + console.print("\n[bold yellow]-----> Error-Prone Test Cases:[/bold yellow]") + for error in error_prone_cases: + print(error) - if not classes: - console.print(f"[bold red]No classes found in '{file_path}'.[/bold red]") - exit(1) - # Simulate loading for test generation - simulate_loading("Generating Random tests") - test_results = randoop_test_generator(classes, sequence_length) + # Write test cases + write_test_cases( + sequences=sequences, + storage=storage, + module_name="combined_namespace", + file_path="combined_test_cases.py", + instance_creation_data=instance_creation_data, + ) - # Display Successful Sequences - print("\n-----> Generated Instances and Sequences:") - for seq in test_results["sequences"]: - print(seq) + console.print("[bold green]All tasks completed successfully![/bold green]") - print("\n-----> Error-Prone Test Cases:") - for error in test_results["error_cases"]: - print(error) + finally: + # Cleanup temporary directory if used + if repo_url and temp_dir.exists(): + shutil.rmtree(temp_dir) - console.print("[bold green]All tasks completed successfully![/bold green]") if __name__ == "__main__": main() diff --git a/randoop_cli/module_loader.py b/randoop_cli/module_loader.py index 2b6e5f5..b61087e 100644 --- a/randoop_cli/module_loader.py +++ b/randoop_cli/module_loader.py @@ -1,9 +1,27 @@ import importlib.util import sys +import os -def load_module(file_path): - spec = importlib.util.spec_from_file_location("module.name", file_path) +def load_module(file_path, namespace): + """ + Load a Python module from a file path into a shared namespace. + Skip setup.py or files that shouldn't be loaded. + """ + # Skip setup.py as it causes issues during package loading + if file_path.name == "setup.py": + print(f"Skipping {file_path} due to potential issues with loading") + return + + module_name = file_path.stem + spec = importlib.util.spec_from_file_location(module_name, file_path) module = importlib.util.module_from_spec(spec) - sys.modules["module.name"] = module + sys.modules[module_name] = module + spec.loader.exec_module(module) + + # Merge module attributes into the shared namespace + for name, obj in vars(module).items(): + if not name.startswith("_"): # Ignore private attributes + namespace[name] = obj + return module diff --git a/savings_account.py b/savings_account.py new file mode 100644 index 0000000..fe0a05f --- /dev/null +++ b/savings_account.py @@ -0,0 +1,23 @@ +from bank_account import BankAccount + +class SavingsAccount(BankAccount): + """A SavingsAccount class that extends BankAccount with interest calculation.""" + + def __init__(self, account_holder: str, initial_balance: float = 0.0, interest_rate: float = 0.02): + """ + Initialize the savings account with an interest rate. + :param interest_rate: Annual interest rate as a decimal (default is 0.02 for 2%). + """ + super().__init__(account_holder, initial_balance) + self.interest_rate = interest_rate + + def add_interest(self): + """Add interest to the current balance.""" + interest = self.balance * self.interest_rate + self.deposit(interest) + return self.balance + + def __str__(self): + """Return a string representation of the savings account.""" + return f"SavingsAccount({self.account_holder}, Balance: {self.balance}, Interest Rate: {self.interest_rate})" +