-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpatch.py
More file actions
251 lines (210 loc) · 9.87 KB
/
patch.py
File metadata and controls
251 lines (210 loc) · 9.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
import re
import Levenshtein
# Hunk header of a regular unified diff, e.g. "@@ -12,5 +12,7 @@".
# Groups 1 and 3 capture the start lines of the original and new ranges;
# groups 2 and 4 capture the line counts, which may be empty.
UNIFIED_DIFF_HUNK_HEADER_REGEX = r'@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@'
# Hunk header of a unified diff that carries no line numbers at all,
# written literally as "@@ ... @@".
UNIFIED_DIFF_HUNK_HEADER_NO_COUNTS_REGEX = r'@@ \.\.\. @@'
def is_unified_diff(patch: list[str]) -> bool:
    """Return True if any line of *patch* starts with a unified-diff hunk header.

    Both header flavours are accepted: the regular one with line numbers and
    the count-less "@@ ... @@" form. An empty patch yields False.
    """
    header_patterns = (
        UNIFIED_DIFF_HUNK_HEADER_REGEX,
        UNIFIED_DIFF_HUNK_HEADER_NO_COUNTS_REGEX,
    )
    return any(
        re.match(pattern, text) is not None
        for text in patch
        for pattern in header_patterns
    )
def is_unified_diff_no_counts(patch: list[str]) -> bool:
    """Return True if *patch* contains at least one count-less "@@ ... @@" hunk header."""
    count_less_headers = (
        text
        for text in patch
        if re.match(UNIFIED_DIFF_HUNK_HEADER_NO_COUNTS_REGEX, text)
    )
    # Pulling a single item is enough: we only care whether one exists.
    return next(count_less_headers, None) is not None
class Hunk:
    """One hunk of a unified diff, reduced to "find these lines, put those in their place".

    Attributes:
        start_original: start line of the original range taken from the hunk
            header, or 0 when the header carries no line numbers.
        start_new: start line of the new range taken from the hunk header,
            or 0 when the header carries no line numbers.
        match: lines expected in the code (context lines and removed lines).
        replace: lines that take their place (context lines and added lines).
    """

    # Upper bounds on the number of unchanged lines kept at either end of a hunk.
    MAX_STARTING_CONTEXT = 3
    MAX_TRAILING_CONTEXT = 3

    def __init__(self, header: str, lines: list[str]):
        """Parse a hunk header and its body lines.

        Args:
            header: the "@@ ..." line that opens the hunk.
            lines: the body lines that follow the header, without line endings.

        The body is read until its end or until the first empty line, whichever
        comes first. Leading and trailing runs of unchanged lines are then cut
        down to MAX_STARTING_CONTEXT / MAX_TRAILING_CONTEXT lines.
        """
        header_match = re.match(UNIFIED_DIFF_HUNK_HEADER_REGEX, header)
        if header_match:
            self.start_original = int(header_match.group(1))
            self.start_new = int(header_match.group(3))
        else:
            # Header without usable line numbers (e.g. "@@ ... @@").
            self.start_original = 0
            self.start_new = 0

        self.match = []
        self.replace = []
        for line in lines:
            if not line:
                print("Empty line in hunk, truncating hunk context")
                break
            if line.startswith('\\ No newline'):
                # "\ No newline at end of file" is diff metadata, not file
                # content; treating it as context would make the hunk unmatchable.
                continue
            if line.startswith('+'):
                self.replace.append(line[1:])
            elif line.startswith('-'):
                self.match.append(line[1:])
            else:
                # Context line. A well-formed one starts with a single space;
                # faulty LLM patches sometimes omit it, in which case the line
                # is taken verbatim.
                content = line[1:] if line[0].isspace() else line
                self.match.append(content)
                self.replace.append(content)

        # LLMs may add too much context, and may also emit -/+ pairs that do
        # not actually differ. Both show up as identical entries at the same
        # position of match and replace, so count those from each end.
        leading = 0
        for old, new in zip(self.match, self.replace):
            if old != new:
                break
            leading += 1
        if leading > self.MAX_STARTING_CONTEXT:
            trim_amount = leading - self.MAX_STARTING_CONTEXT
            print(f"Trimming {trim_amount} starting context lines")
            self.match = self.match[trim_amount:]
            self.replace = self.replace[trim_amount:]

        if not self.match or not self.replace:
            return

        trailing = 0
        for old, new in zip(reversed(self.match), reversed(self.replace)):
            if old != new:
                break
            trailing += 1
        if trailing > self.MAX_TRAILING_CONTEXT:
            trim_amount = trailing - self.MAX_TRAILING_CONTEXT
            print(f"Trimming {trim_amount} trailing context lines")
            self.match = self.match[:-trim_amount]
            self.replace = self.replace[:-trim_amount]

    def empty(self) -> bool:
        """Return True when the hunk has no lines to look for in the code."""
        return self.match_count() == 0

    def match_count(self) -> int:
        """Return the number of lines the hunk expects to find in the code."""
        return len(self.match)

    def replace_count(self) -> int:
        """Return the number of lines the hunk puts in place of the matched ones."""
        return len(self.replace)

    def trim_comment(self, line):
        """Return *line* without trailing whitespace and without a trailing Python comment.

        A '#' inside a single- or double-quoted string is not a comment, so
        " code  # note" becomes " code" while " print('#')" is left alone.
        Backslash escapes inside the quotes are honoured, so an escaped quote
        does not end the string early. If the quotes on the line cannot be
        paired up, the line is returned with only trailing whitespace removed.
        """
        line = line.rstrip()
        # Lazily consume plain characters and complete quoted strings until
        # the next character is '#' or the end of the line is reached.
        pattern = r'''(?:[^'"#]|"(?:\\.|[^"\\])*"|'(?:\\.|[^'\\])*')*?(?=#|$)'''
        code_part = re.match(pattern, line)
        if code_part:
            return code_part.group(0).rstrip()
        return line

    def matches_code(self, code_lines: list[str], start_line: int, fuzziness: int) -> bool:
        """Check whether the hunk's match lines fit *code_lines* at *start_line* (0-based).

        Fuzziness levels:
            0 or less: lines must be identical.
            below 2:   trailing whitespace and trailing comments are stripped
                       from both sides first; leading whitespace still counts.
            2 or more: as above, and each line may additionally differ by a
                       Levenshtein distance of up to 3.
        """
        for offset, patch_line in enumerate(self.match):
            code_index = start_line + offset
            if code_index >= len(code_lines):
                # The hunk would run past the end of the code.
                return False
            code_line = code_lines[code_index]

            if fuzziness <= 0:
                if code_line != patch_line:
                    return False
                continue

            code_line = self.trim_comment(code_line)
            patch_line = self.trim_comment(patch_line)
            if fuzziness < 2:
                if code_line != patch_line:
                    return False
            elif Levenshtein.distance(code_line, patch_line) > 3:
                return False
        return True

    def match_code(self, code_lines: list[str], fuzziness: int) -> "int | None":
        """Return the first 0-based index where the hunk matches *code_lines*, or None.

        Candidate positions are tried from the top of the code downwards.
        """
        last_start = len(code_lines) - self.match_count()
        for start in range(last_start + 1):
            if self.matches_code(code_lines, start, fuzziness):
                return start
        return None

    def __repr__(self) -> str:
        return f"(start_original={self.start_original}, start_new={self.start_new}, match_count={self.match_count()}, replace_count={self.replace_count()})"
def extract_hunks(patch: list[str]) -> list[Hunk]:
    """Split the lines of a unified diff into Hunk objects.

    A hunk opens at a line starting with "@@" and runs until the next line
    starting with "@@", "+++" or "---", or until the end of the patch. The
    opening line is handed to Hunk as the header, the lines after it as the
    body. Lines outside any hunk are ignored.
    """
    boundary_prefixes = ('@@', '+++', '---')
    collected = []
    header_index = None  # index of the header of the hunk currently open, if any
    for index, text in enumerate(patch):
        if not text.startswith(boundary_prefixes):
            continue
        # Any boundary line closes the hunk that is currently open.
        if header_index is not None:
            collected.append(Hunk(patch[header_index], patch[header_index + 1:index]))
        # Only an "@@" line opens a new one.
        header_index = index if text.startswith('@@') else None
    if header_index is not None:
        collected.append(Hunk(patch[header_index], patch[header_index + 1:]))
    return collected
def patch_code(code_lines: list[str], patch_lines: list[str], fuzziness: int = 0) -> bool:
    """Apply a unified diff to *code_lines* in place.

    Args:
        code_lines: lines of the file to patch; modified in place.
        patch_lines: lines of the unified diff.
        fuzziness: highest fuzziness level to fall back to. Each hunk is first
            tried at level 0 and the level is raised one step at a time, up to
            this value, until the hunk matches.

    Returns:
        True when no hunk failed to apply, False otherwise. Hunks that did
        match are applied even when others fail.
    """
    hunk_list = extract_hunks(patch_lines)
    failed_hunks = 0
    print(f"Extracted {len(hunk_list)} hunks:")

    # First locate every hunk in the unmodified code, then apply them all.
    application_list = []
    for hunk in hunk_list:
        if hunk.empty():
            print("[SKIP] Useless hunk")
            continue

        hunk_start = None
        for fuzziness_level in range(fuzziness + 1):
            # Try the strictest level first so that an exact match always
            # wins over a fuzzy one found earlier in the file.
            hunk_start = hunk.match_code(code_lines, fuzziness_level)
            # Index 0 is a valid position, so compare against None explicitly.
            if hunk_start is not None:
                if fuzziness_level > 0:
                    print(f"[WARNING] Hunk {hunk} applied with fuzziness {fuzziness_level}")
                break

        if hunk_start is None:
            print(f"[FAIL] Can't apply hunk {hunk}")
            failed_hunks += 1
        else:
            application_list.append((hunk_start, hunk))

    # Apply from top to bottom, tracking how much earlier hunks have shifted
    # the lines below them.
    application_list.sort(key=lambda entry: entry[0])
    source_offset = 0
    for hunk_start, hunk in application_list:
        start = hunk_start + source_offset
        code_lines[start:start + hunk.match_count()] = hunk.replace
        source_offset += hunk.replace_count() - hunk.match_count()

    if failed_hunks > 0:
        print(f"Patch application failed. {failed_hunks}/{len(hunk_list)} hunks failed to apply.")
    else:
        print(f"Patch application complete. All {len(hunk_list)} hunks applied successfully.")
    return failed_hunks == 0
if __name__ == "__main__":
    # Manual driver: patch each numbered sample file with its matching patch
    # and write the result into the solutions directory.
    for n in range(1, 7):
        original_file_name = f"test_sets/patch/test{n}.py"
        patch_file_name = f"test_sets/patch/test{n}.patch"
        print(f"--- Testing patching {original_file_name} with {patch_file_name} ---")

        with open(original_file_name, 'r') as source_handle:
            code_lines = source_handle.read().splitlines()
        with open(patch_file_name, 'r') as patch_handle:
            patch_lines = patch_handle.read().splitlines()

        # patch_code edits code_lines in place; its return value is not used here.
        patch_code(code_lines, patch_lines, fuzziness=2)

        # The output is written even when some hunks failed to apply.
        with open(f"solutions/patched_file_v{n+1}.py", "w") as output_handle:
            output_handle.write("\n".join(code_lines))
        print("-" * 40)