Skip to content

Commit f01075d

Browse files
Rangeet PanRangeet Pan
authored andcommitted
update notebook for generating unit tests
1 parent 2cbc21f commit f01075d

2 files changed

Lines changed: 181 additions & 43 deletions

File tree

docs/examples/java/code_summarization.ipynb

Lines changed: 37 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
"execution_count": null,
3030
"outputs": [],
3131
"source": [
32-
"import os\n",
3332
"from pathlib import Path\n",
3433
"import ollama\n",
3534
"from cldk import CLDK\n",
@@ -122,9 +121,8 @@
122121
"execution_count": null,
123122
"outputs": [],
124123
"source": [
125-
"if __name__ == \"__main__\":\n",
126-
" # Create a new instance of the CLDK class\n",
127-
" cldk = CLDK(language=\"java\")"
124+
"# Create a new instance of the CLDK class\n",
125+
"cldk = CLDK(language=\"java\")"
128126
],
129127
"metadata": {
130128
"collapsed": false
@@ -149,8 +147,8 @@
149147
"execution_count": null,
150148
"outputs": [],
151149
"source": [
152-
" # Create an analysis object over the java application\n",
153-
" analysis = cldk.analysis(project_path=\"JAVA_APP_PATH\", analysis_level=AnalysisLevel.symbol_table)"
150+
"# Create an analysis object over the java application\n",
151+
"analysis = cldk.analysis(project_path=\"JAVA_APP_PATH\", analysis_level=AnalysisLevel.symbol_table)"
154152
],
155153
"metadata": {
156154
"collapsed": false
@@ -194,40 +192,39 @@
194192
"execution_count": null,
195193
"outputs": [],
196194
"source": [
197-
"\n",
198-
" # Iterate over all the files in the project\n",
199-
" for file_path, class_file in analysis.get_symbol_table().items():\n",
200-
" class_file_path = Path(file_path).absolute().resolve()\n",
201-
" # Iterate over all the classes in the file\n",
202-
" for type_name, type_declaration in class_file.type_declarations.items():\n",
203-
" # Iterate over all the methods in the class\n",
204-
" for method in type_declaration.callable_declarations.values():\n",
205-
" # Get code body of the method\n",
206-
" code_body = class_file_path.read_text()\n",
207-
"\n",
208-
" # Initialize the treesitter utils for the class file content\n",
209-
" tree_sitter_utils = cldk.tree_sitter_utils(source_code=code_body)\n",
210-
"\n",
211-
" # Sanitize the class for analysis\n",
212-
" sanitized_class = tree_sitter_utils.sanitize_focal_class(method.declaration)\n",
213-
"\n",
214-
" # Format the instruction for the given focal method and class\n",
215-
" instruction = format_inst(\n",
216-
" code=sanitized_class,\n",
217-
" focal_method=method.declaration,\n",
218-
" focal_class=type_name,\n",
219-
" language=\"java\"\n",
220-
" )\n",
221-
"\n",
222-
" # Prompt the local model on Ollama\n",
223-
" llm_output = prompt_ollama(\n",
224-
" message=instruction,\n",
225-
" model_id=\"granite-code:20b-instruct\",\n",
226-
" )\n",
227-
"\n",
228-
" # Print the instruction and LLM output\n",
229-
" print(f\"Instruction:\\n{instruction}\")\n",
230-
" print(f\"LLM Output:\\n{llm_output}\")"
195+
"# Iterate over all the files in the project\n",
196+
"for file_path, class_file in analysis.get_symbol_table().items():\n",
197+
" class_file_path = Path(file_path).absolute().resolve()\n",
198+
" # Iterate over all the classes in the file\n",
199+
" for type_name, type_declaration in class_file.type_declarations.items():\n",
200+
" # Iterate over all the methods in the class\n",
201+
" for method in type_declaration.callable_declarations.values():\n",
202+
" # Get code body of the method\n",
203+
" code_body = class_file_path.read_text()\n",
204+
" \n",
205+
" # Initialize the treesitter utils for the class file content\n",
206+
" tree_sitter_utils = cldk.tree_sitter_utils(source_code=code_body)\n",
207+
" \n",
208+
" # Sanitize the class for analysis\n",
209+
" sanitized_class = tree_sitter_utils.sanitize_focal_class(method.declaration)\n",
210+
" \n",
211+
" # Format the instruction for the given focal method and class\n",
212+
" instruction = format_inst(\n",
213+
" code=sanitized_class,\n",
214+
" focal_method=method.declaration,\n",
215+
" focal_class=type_name,\n",
216+
" language=\"java\"\n",
217+
" )\n",
218+
" \n",
219+
" # Prompt the local model on Ollama\n",
220+
" llm_output = prompt_ollama(\n",
221+
" message=instruction,\n",
222+
" model_id=\"granite-code:20b-instruct\",\n",
223+
" )\n",
224+
" \n",
225+
" # Print the instruction and LLM output\n",
226+
" print(f\"Instruction:\\n{instruction}\")\n",
227+
" print(f\"LLM Output:\\n{llm_output}\")"
231228
],
232229
"metadata": {
233230
"collapsed": false

docs/examples/java/generate_unit_tests.ipynb

Lines changed: 144 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,155 @@
11
{
22
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"source": [
6+
"Generating unit tests for code is a very tedious task and often takes a significant effort from the developers to write good test cases. There are various tools that are available for automated test generation, such as, EvoSuite, which uses evolutionary algorithm to generate test cases. However, the test cases that are being generated are not natural and often developers do not prefer to add them to their test suite. Whereas, Large Language Models (LLM) being trained with developer-written code, it has better affinity towards generating more natural code--more readable, maintainable code. In this excercise, we will show we can leverage LLMs to generate test cases with the help of CLDK. \n",
7+
"\n",
8+
"For simplicity, we will cover certain aspects of test generation and provide some context information to LLM for better quality of test cases. In this excercise, we will generate unit test for non-private method from a Java class and provide the focal method body and the signature of all the constructors of the class so that LLM can understand how to create object of the focal class during the setup phase of the tests. Also, we will ask LLMs to generate ```N``` number of test cases, where ```N``` is the cyclomatic complexity of the focal method. The intuition is that one test may not be sufficient for covering fairly complex method and cyclomatic complexity score can provide some guidance towards that. \n",
9+
"\n",
10+
"(Step 1) First, we will import all the neccessary libraries"
11+
],
12+
"metadata": {
13+
"collapsed": false
14+
},
15+
"id": "5856baff4aa64ed7"
16+
},
317
{
418
"cell_type": "code",
519
"execution_count": null,
6-
"id": "initial_id",
20+
"outputs": [],
21+
"source": [
22+
"from pathlib import Path\n",
23+
"import ollama\n",
24+
"from cldk import CLDK\n",
25+
"from cldk.analysis import AnalysisLevel"
26+
],
727
"metadata": {
8-
"collapsed": true
28+
"collapsed": false
929
},
30+
"id": "b3d2498ae092fcc"
31+
},
32+
{
33+
"cell_type": "markdown",
34+
"source": [
35+
"(Step 2) Second, we will form the prompt for the model, which will include all the constructor signarures, and the body of the focal method."
36+
],
37+
"metadata": {
38+
"collapsed": false
39+
},
40+
"id": "67eb24b29826d730"
41+
},
42+
{
43+
"cell_type": "code",
44+
"execution_count": null,
1045
"outputs": [],
11-
"source": []
46+
"source": [
47+
"def format_inst(focal_method_body, focal_method, focal_class, constructor_signatures, cyclomatic_complexity, language):\n",
48+
" \"\"\"\n",
49+
" Format the instruction for the given focal method and class.\n",
50+
" \"\"\"\n",
51+
" inst = f\"Question: Can you generate {cyclomatic_complexity} unit tests for the method `{focal_method}` in the class `{focal_class}` below?\\n\"\n",
52+
"\n",
53+
" inst += \"\\n\"\n",
54+
" inst += f\"```{language}\\n\"\n",
55+
" inst += \"```\\n\"\n",
56+
" inst += \"public class {focal_class} {\"\n",
57+
" inst += f\"<|constructors|>\\n{constructor_signatures}\\n<|constructors|>\\n\"\n",
58+
" inst += f\"<|focal method|>\\n {focal_method_body} \\n <|focal method|>\\n\" \n",
59+
" inst += \"}\"\n",
60+
" inst += \"```\\n\"\n",
61+
" inst += \"Answer:\\n\"\n",
62+
" return inst"
63+
],
64+
"metadata": {
65+
"collapsed": false
66+
},
67+
"id": "d7bc9bbaa917df24"
68+
},
69+
{
70+
"cell_type": "markdown",
71+
"source": [
72+
"(Step 3) Third, use ollama to call LLM (in case Granite 8b)."
73+
],
74+
"metadata": {
75+
"collapsed": false
76+
},
77+
"id": "ae9ceb150f5efa92"
78+
},
79+
{
80+
"cell_type": "code",
81+
"execution_count": null,
82+
"outputs": [],
83+
"source": [
84+
"def prompt_ollama(message: str, model_id: str = \"granite-code:8b-instruct\") -> str:\n",
85+
" \"\"\"Prompt local model on Ollama\"\"\"\n",
86+
" response_object = ollama.generate(model=model_id, prompt=message)\n",
87+
" return response_object[\"response\"]"
88+
],
89+
"metadata": {
90+
"collapsed": false
91+
},
92+
"id": "52634feae7374599"
93+
},
94+
{
95+
"cell_type": "markdown",
96+
"source": [
97+
"(Step 3) Third, collect all the information needed for each method. "
98+
],
99+
"metadata": {
100+
"collapsed": false
101+
},
102+
"id": "308c3325116b87d4"
103+
},
104+
{
105+
"cell_type": "code",
106+
"execution_count": null,
107+
"outputs": [],
108+
"source": [
109+
"# Create a new instance of the CLDK class\n",
110+
"cldk = CLDK(language=\"java\")\n",
111+
"# Create an analysis object over the java application. Provide the application path using JAVA_APP_PATH\n",
112+
"analysis = cldk.analysis(project_path=\"JAVA_APP_PATH\", analysis_level=AnalysisLevel.symbol_table)\n",
113+
"# Go through all the classes in the application\n",
114+
"for class_name in analysis.get_classes():\n",
115+
" class_details = analysis.get_class(qualified_class_name=class_name)\n",
116+
" # Generate test cases for non-interface and non-abstract classes\n",
117+
" if not class_details.is_interface and 'abstract' not in class_details.modifiers:\n",
118+
" # Get all constructor signatures\n",
119+
" constructor_signatures = ''\n",
120+
" for method in analysis.get_methods_in_class(qualified_class_name=class_name):\n",
121+
" method_details = analysis.get_method(qualified_class_name=class_name, qualified_method_name=method)\n",
122+
" if method_details.is_constructor:\n",
123+
" constructor_signatures += method_details.signature + '\\n'\n",
124+
" # If no constructor present, then add the signature of the default constructor\n",
125+
" if constructor_signatures=='':\n",
126+
" constructor_signatures = f'public {class_name} ()'\n",
127+
" # Go through all the methods in the class\n",
128+
" for method in analysis.get_methods_in_class(qualified_class_name=class_name):\n",
129+
" # Get the method details\n",
130+
" method_details = analysis.get_method(qualified_class_name=class_name, qualified_method_name=method)\n",
131+
" # Generate test cases for non-private methods\n",
132+
" if 'private' not in method_details.modifiers and not method_details.is_constructor:\n",
133+
" # Gather all the information needed for the prompt, which are focal method body, focal method name, focal class name, constructor signature, and cyclomatic complexity\n",
134+
" prompt = format_inst(focal_method_body=method_details.code,\n",
135+
" focal_method=method,\n",
136+
" focal_class=class_name,\n",
137+
" constructor_signatures=constructor_signatures,\n",
138+
" cyclomatic_complexity=method_details.cyclomatic_complexity)\n",
139+
" # Prompt the local model on Ollama\n",
140+
" llm_output = prompt_ollama(\n",
141+
" message=prompt,\n",
142+
" model_id=\"granite-code:20b-instruct\",\n",
143+
" )\n",
144+
" \n",
145+
" # Print the instruction and LLM output\n",
146+
" print(f\"Instruction:\\n{prompt}\")\n",
147+
" print(f\"LLM Output:\\n{llm_output}\")"
148+
],
149+
"metadata": {
150+
"collapsed": false
151+
},
152+
"id": "65c9558e4de65a52"
12153
}
13154
],
14155
"metadata": {

0 commit comments

Comments
 (0)