Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions .github/workflows/test-gen.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
name: AI Test Generation

on:
pull_request:
types: [opened, synchronize]
paths:
- '**.py'
- '!tests/**'

jobs:
generate-tests:
runs-on: ubuntu-latest
permissions:
contents: write
pull-requests: write

steps:
- name: Checkout PR branch
uses: actions/checkout@v4
with:
ref: ${{ github.head_ref }}
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'

- name: Install dependencies
run: |
pip install -r requirements.txt

- name: Get changed Python files
id: changed-files
run: |
FILES=$(git diff --name-only origin/${{ github.base_ref }}..HEAD -- '*.py' | grep -v '^tests/' | tr '\n' ' ')
echo "files=$FILES" >> $GITHUB_OUTPUT

- name: Generate tests
if: steps.changed-files.outputs.files != ''
env:
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
run: |
python scripts/generate_tests.py ${{ steps.changed-files.outputs.files }}

- name: Commit generated tests
run: |
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"

if git diff --quiet tests/; then
echo "No new tests generated"
exit 0
fi

git add tests/
git commit -m "Auto-generate tests for new code [skip ci]"
git push


git add tests/
git commit -m "Auto-generate tests for updated Python files [skip ci]"
git push
10 changes: 9 additions & 1 deletion app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,12 @@ def reverse_string(s: str) -> str:

def multiply(a: int, b: int) -> int:
"""Multiply two numbers together."""
return a * b
return a * b

def factorial(n: int) -> int:
"""Calculate the factorial of a number."""
if n < 0:
raise ValueError("Factorial is not defined for negative numbers.")
if n <= 1:
return 1
return n * factorial(n - 1)
6 changes: 6 additions & 0 deletions scripts/dangerous.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import subprocess

def run_command(user_input):
subprocess.call(user_input, shell=True)

API_KEY = "sk-live-abc123def456"
95 changes: 95 additions & 0 deletions scripts/generate_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import ast
import os
import sys
from google import genai

def extract_functions(file_path):
"""parse the python file and extract function definitions"""
with open(file_path, 'r') as f:
source = f.read()

tree = ast.parse(source)
functions = []

for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef):
func_name = node.name
args = [arg.arg for arg in node.args.args]
docstring = ast.get_docstring(node) or ""
source_code = ast.get_source_segment(source, node)

functions.append({
'name': func_name,
'args': args,
'docstring': docstring,
'source_code': source_code
})
return functions

def generate_tests_for_functions(func_info):
"""Use Gemini to generate pytest tests for a function"""
client = genai.Client()

prompt = f"""
Generate pytest tests for this following Python function.

funtions name: {func_info['name']}
Arguments: {func_info['args']}
Docstring: {func_info['docstring']}

Source code: {func_info['source_code']}

requirements:
1. Generate 3-5 meaningful test cases.
2. Include edge cases (empty cases, None values, etc.)
3. use descriptive test function names.
4. Include asserions that actually test the function's behavior.
5. Do not include any placeholder tests like `assert True` or `assert False`.
6. Use pytest framework for the tests.
"""
response = client.models.generate_content(
model="gemini-2.5-flash", contents=prompt
)
return response.text

def main():

changed_files = sys.argv[1:] if len(sys.argv) > 1 else []

if not changed_files:
print("No changed Python files provided for Test Generation.")
return

all_tests = []
for file_path in changed_files:
if not file_path.endswith('.py'):
continue
if file_path.startswith('tests/'):
continue

print(f"Analyzing : {file_path}")
functions = extract_functions(file_path)

for func in functions:
if func['name'].startswith('_'):
continue

print(f"Generating tests for function: {func['name']} in {file_path}")
tests = generate_tests_for_functions(func)
all_tests.append(f"# Tests for {func['name']} from {file_path}\n{tests}")

if all_tests:
os.makedirs('tests', exist_ok=True)
test_file = 'tests/test_generated.py'

with open(test_file, 'w') as f:
f.write("import pytest\n\n")
f.write("\n\n".join(all_tests))

print(f"Generated tests written to: {test_file}")
else:
print("No functions found to generate tests for")


if __name__ == '__main__':
main()
Loading