From fe4da56e78bc143a349f52114d5f7db5fa7d50db Mon Sep 17 00:00:00 2001 From: Sebastian Hungerecker Date: Sat, 3 Jun 2023 20:27:54 +0200 Subject: [PATCH 1/9] Remove IO from functions and add tests --- main.py | 26 ++++++++++++++++--- requirements.txt | 1 + superoptimizer.py | 17 ++++++------- tests/test_cpu.py | 36 ++++++++++++++++++++++++++ tests/test_supperoptimizer.py | 48 +++++++++++++++++++++++++++++++++++ 5 files changed, 115 insertions(+), 13 deletions(-) create mode 100644 requirements.txt create mode 100644 tests/test_cpu.py create mode 100644 tests/test_supperoptimizer.py diff --git a/main.py b/main.py index 3a1324a..99e4863 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,22 @@ from superoptimizer import * + +def print_optimal_from_code(assembly, max_length, max_mem, max_val, debug=False): + print(f"***Source***{assembly}") + state = run(assembly, max_mem) + print_optimal_from_state(state, max_length, max_val, debug) + + +def print_optimal_from_state(state, max_length, max_val, debug=False): + print("***State***") + print(state) + print() + print("***Optimal***") + print(optimal_from_state(state, max_length, max_val, debug)) + print("=" * 20) + print() + + def main(): # Test 1 assembly = """ @@ -11,14 +28,15 @@ def main(): SWAP 0, 3 LOAD 3 """ - optimal_from_code(assembly, 4, 4, 5) + print_optimal_from_code(assembly, 4, 4, 5) # Test 2 state = [0, 2, 1] - optimal_from_state(state, 3, 5) + print_optimal_from_state(state, 3, 5) ## Test 3 - Careful, I don't think this will finish for days. # state = [2, 4, 6, 8, 10, 12] - # optimal_from_state(state, 10, 15, True) + # print_optimal_from_state(state, 10, 15, True) + -main() \ No newline at end of file +main() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..e079f8a --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +pytest diff --git a/superoptimizer.py b/superoptimizer.py index b098db1..9bdde61 100644 --- a/superoptimizer.py +++ b/superoptimizer.py @@ -2,22 +2,21 @@ from cpu import CPU import assembler -# Helper function that finds the optimal code given the assembly code. -def optimal_from_code(assembly, max_length, max_mem, max_val, debug=False): + +# Helper function that runs a piece of assembly code. +def run(assembly, max_mem): cpu = CPU(max_mem) program = assembler.parse(assembly) - state = cpu.execute(program) - print(f"***Source***{assembly}") - optimal_from_state(state, max_length, max_val, debug) + return cpu.execute(program) + # Helper function that finds the optimal code given the goal state. def optimal_from_state(state, max_length, max_val, debug=False): max_mem = len(state) - print(f"***State***\n{state}\n") opt = Superoptimizer() - shortest_program = opt.search(max_length, max_mem, max_val, state, debug) - disassembly = assembler.output(shortest_program) - print(f"***Optimal***\n{disassembly}\n{'='*20}\n") + shortest_program = opt.search(max_length, max_mem, max_val, state, debug) + return assembler.output(shortest_program) + class Superoptimizer: def __init__(self): diff --git a/tests/test_cpu.py b/tests/test_cpu.py new file mode 100644 index 0000000..d24746b --- /dev/null +++ b/tests/test_cpu.py @@ -0,0 +1,36 @@ +from superoptimizer import run + + +def test_load_and_swap(): + assembly = """ +LOAD 3 +SWAP 0, 1 +LOAD 3 +SWAP 0, 2 +LOAD 3 +SWAP 0, 3 +LOAD 3 + """ + assert run(assembly, 4) == [3, 3, 3, 3] + assert run(assembly, 10) == [3, 3, 3, 3, 0, 0, 0, 0, 0, 0] + + +def test_load_and_xor(): + assembly = """ +LOAD 42 +XOR 1, 0 +LOAD 23 +XOR 1, 0 + """ + assert run(assembly, 2) == [23, 42 ^ 23] + + +def test_load_and_inc(): + assembly = """ +LOAD 41 +INC 0 +INC 1 +INC 1 +INC 1 + """ + assert run(assembly, 2) == [42, 3] diff --git a/tests/test_supperoptimizer.py b/tests/test_supperoptimizer.py new file mode 100644 index 0000000..5bc5638 --- /dev/null +++ b/tests/test_supperoptimizer.py @@ -0,0 +1,48 @@ +from superoptimizer import optimal_from_state, run + + +def count_instructions(assembly): + if assembly == "\n": + return 0 + else: + return assembly.count("\n") + + +def assert_optimal_length(length, state, max_length, max_val): + """ + Asserts that, given the arguments `state`, `max_length` and `max_val`, the + superoptimizer finds a program of length `length` and that this program produces + the state `state` when run. + + This method does not assert a specific output program, so that tests don't + fail due to changes in the optimizer that make it output different instructions + as long as the result is still correct and optimal. + """ + optimal = optimal_from_state(state, max_length, max_val) + assert count_instructions(optimal) == length + assert run(optimal, len(state)) == state + + +def test_four_threes(): + # Optimal program is a load and three xors + assert_optimal_length(4, [3, 3, 3, 3], 4, 5) + + +def test_0_2_1(): + # Optimal program is load, swap and inc + assert_optimal_length(3, [0, 2, 1], 3, 5) + assert_optimal_length(3, [0, 2, 1, 0], 4, 5) + + +def test_zeros(): + # Optimal program is empty + # Test currently fails because the empty program isn't included in the search space + assert_optimal_length(0, [0], 1, 3) + assert_optimal_length(0, [0, 0], 1, 3) + + +def test_increasing_sequence(): + # Optimal program is inc for first memory slot and then xor+inc for every one after that + assert_optimal_length(1, [1], 2, 3) + assert_optimal_length(3, [1, 2], 4, 3) + assert_optimal_length(5, [1, 2, 3], 5, 3) From a2297b2520f0f070545ca844b4dd20cf33893b2e Mon Sep 17 00:00:00 2001 From: Sebastian Hungerecker Date: Sat, 3 Jun 2023 20:52:41 +0200 Subject: [PATCH 2/9] Fix wrong results when the shortest program is empty For a state that's all zeros, the superoptimizer used to output "LOAD 0" as the shortest program when in fact "LOAD 0" is a no-op and the shortest program to get all zeros is the empty program. --- superoptimizer.py | 5 +++-- tests/test_supperoptimizer.py | 1 - 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/superoptimizer.py b/superoptimizer.py index 9bdde61..3b6d2f3 100644 --- a/superoptimizer.py +++ b/superoptimizer.py @@ -24,18 +24,19 @@ def __init__(self): # Generates all possible programs. def generate_programs(self, cpu, max_length, max_mem, max_val): + yield [] for length in range(1, max_length + 1): for prog in product(cpu.ops.values(), repeat=length): arg_sets = [] for op in prog: if op == cpu.load: arg_sets.append([tuple([val]) for val in range(max_val + 1)]) - elif op == cpu.swap or op == cpu.xor: + elif op == cpu.swap or op == cpu.xor: arg_sets.append(product(range(max_mem), repeat=2)) elif op == cpu.inc: arg_sets.append([tuple([val]) for val in range(max_mem)]) for arg_set in product(*arg_sets): - program = [(op, *args) for op, args in zip(prog, arg_set)] + program = [(op, *args) for op, args in zip(prog, arg_set)] yield program # Tests all of the generated programs and returns the shortest. diff --git a/tests/test_supperoptimizer.py b/tests/test_supperoptimizer.py index 5bc5638..9d99183 100644 --- a/tests/test_supperoptimizer.py +++ b/tests/test_supperoptimizer.py @@ -36,7 +36,6 @@ def test_0_2_1(): def test_zeros(): # Optimal program is empty - # Test currently fails because the empty program isn't included in the search space assert_optimal_length(0, [0], 1, 3) assert_optimal_length(0, [0, 0], 1, 3) From 835827d741c5361c24df6fd76eef401a1e9fc39f Mon Sep 17 00:00:00 2001 From: Sebastian Hungerecker Date: Sat, 3 Jun 2023 21:35:41 +0200 Subject: [PATCH 3/9] Stop searching when the optimal program is found Since programs are generated in ascending order of length, the first program to be found that produces the right state will be the optimal one. Thus there's no reason to keep searching once a program has been found. This improves performance quite a bit in some cases and removes the need to guess the optimal length to provide a short maximum length. Additionally, this removes the program cache as it would only help when reusing results from previous searches, which is currently being done and wouldn't be correct when using different options (like max_value) anyway. --- superoptimizer.py | 17 +++++------------ tests/test_supperoptimizer.py | 26 +++++++++++++++----------- 2 files changed, 20 insertions(+), 23 deletions(-) diff --git a/superoptimizer.py b/superoptimizer.py index 3b6d2f3..579c7e6 100644 --- a/superoptimizer.py +++ b/superoptimizer.py @@ -19,9 +19,6 @@ def optimal_from_state(state, max_length, max_val, debug=False): class Superoptimizer: - def __init__(self): - self.program_cache = {} - # Generates all possible programs. def generate_programs(self, cpu, max_length, max_mem, max_val): yield [] @@ -46,16 +43,12 @@ def search(self, max_length, max_mem, max_val, target_state, debug=False): for program in self.generate_programs(cpu, max_length, max_mem, max_val): state = cpu.execute(program) if state == target_state: - state = tuple(state) - if state not in self.program_cache or len(program) < len(self.program_cache[state]): - self.program_cache[state] = program - + return program + # Debugging. if debug: count += 1 - if count % 1000000 == 0: print(f"Programs searched: {count:,}") - if count % 10000000 == 0: - solution = self.program_cache.get(tuple(target_state), None) - print(f"Best solution: {solution}") + if count % 1000000 == 0: + print(f"Programs searched: {count:,}") - return self.program_cache.get(tuple(target_state), None) + return None diff --git a/tests/test_supperoptimizer.py b/tests/test_supperoptimizer.py index 9d99183..b0ff472 100644 --- a/tests/test_supperoptimizer.py +++ b/tests/test_supperoptimizer.py @@ -1,6 +1,9 @@ from superoptimizer import optimal_from_state, run +MAX_LENGTH = 1000000 + + def count_instructions(assembly): if assembly == "\n": return 0 @@ -8,7 +11,7 @@ def count_instructions(assembly): return assembly.count("\n") -def assert_optimal_length(length, state, max_length, max_val): +def assert_optimal_length(optimal_length, state, max_val): """ Asserts that, given the arguments `state`, `max_length` and `max_val`, the superoptimizer finds a program of length `length` and that this program produces @@ -18,30 +21,31 @@ def assert_optimal_length(length, state, max_length, max_val): fail due to changes in the optimizer that make it output different instructions as long as the result is still correct and optimal. """ - optimal = optimal_from_state(state, max_length, max_val) - assert count_instructions(optimal) == length + optimal = optimal_from_state(state, MAX_LENGTH, max_val) + assert count_instructions(optimal) == optimal_length assert run(optimal, len(state)) == state + assert optimal_from_state(state, optimal_length, max_val) == optimal def test_four_threes(): # Optimal program is a load and three xors - assert_optimal_length(4, [3, 3, 3, 3], 4, 5) + assert_optimal_length(4, [3, 3, 3, 3], 5) def test_0_2_1(): # Optimal program is load, swap and inc - assert_optimal_length(3, [0, 2, 1], 3, 5) - assert_optimal_length(3, [0, 2, 1, 0], 4, 5) + assert_optimal_length(3, [0, 2, 1], 5) + assert_optimal_length(3, [0, 2, 1, 0], 5) def test_zeros(): # Optimal program is empty - assert_optimal_length(0, [0], 1, 3) - assert_optimal_length(0, [0, 0], 1, 3) + assert_optimal_length(0, [0], 3) + assert_optimal_length(0, [0, 0], 3) def test_increasing_sequence(): # Optimal program is inc for first memory slot and then xor+inc for every one after that - assert_optimal_length(1, [1], 2, 3) - assert_optimal_length(3, [1, 2], 4, 3) - assert_optimal_length(5, [1, 2, 3], 5, 3) + assert_optimal_length(1, [1], 3) + assert_optimal_length(3, [1, 2], 3) + assert_optimal_length(5, [1, 2, 3], 3) From 79a31379c9b18c7c6bc61ab562b397654c3d3337 Mon Sep 17 00:00:00 2001 From: Sebastian Hungerecker Date: Sat, 3 Jun 2023 21:44:52 +0200 Subject: [PATCH 4/9] Fix exception when no program could be found --- assembler.py | 6 +++++- tests/test_supperoptimizer.py | 2 ++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/assembler.py b/assembler.py index 01fae8f..8580464 100644 --- a/assembler.py +++ b/assembler.py @@ -15,9 +15,13 @@ def parse(assembly): program.append((op, *args)) return program + # Turns a program into a string. def output(program): - if len(program) == 0: return "\n" + if program == None: + return None + if len(program) == 0: + return "\n" cpu = CPU(1) assembly = "" for instruction in program: diff --git a/tests/test_supperoptimizer.py b/tests/test_supperoptimizer.py index b0ff472..b569fb6 100644 --- a/tests/test_supperoptimizer.py +++ b/tests/test_supperoptimizer.py @@ -25,6 +25,8 @@ def assert_optimal_length(optimal_length, state, max_val): assert count_instructions(optimal) == optimal_length assert run(optimal, len(state)) == state assert optimal_from_state(state, optimal_length, max_val) == optimal + if optimal_length > 0: + assert optimal_from_state(state, optimal_length - 1, max_val) == None def test_four_threes(): From 8483c9ad90e90a16bf0602b482ba22490c76a852 Mon Sep 17 00:00:00 2001 From: Sebastian Hungerecker Date: Sat, 10 Jun 2023 20:13:54 +0200 Subject: [PATCH 5/9] Implement specifying the first n bytes of memory as input Also implement an equivalence checker that can determine program equivalence in the presence of input (by trying all possible input values). --- README.md | 2 - brute_force_equivialence_checker.py | 18 ++++ cpu.py | 3 +- main.py | 19 ++-- superoptimizer.py | 35 ++++--- tests/test_cpu.py | 10 ++ tests/test_supperoptimizer.py | 150 +++++++++++++++++++++------- 7 files changed, 175 insertions(+), 62 deletions(-) create mode 100644 brute_force_equivialence_checker.py diff --git a/README.md b/README.md index 48521a1..8aa5396 100644 --- a/README.md +++ b/README.md @@ -19,8 +19,6 @@ To focus on the superoptimizer and not making a comprehensive, realistic assembl There are many possible improvements: -- **Start state.** Right now it assumes the start state is always the same, which means there is no concept of program input. -- **Program equivalence.** A set of inputs and outputs should be specified such that two programs can actually be tested for equivalence. - **Pruning.** Many nonsensical programs are generated, which significantly slows it down. - **More instructions.** There need to be more instructions, especially a conditional instruction, to give the superoptimizer more opportunities to make improvements. diff --git a/brute_force_equivialence_checker.py b/brute_force_equivialence_checker.py new file mode 100644 index 0000000..ff83f61 --- /dev/null +++ b/brute_force_equivialence_checker.py @@ -0,0 +1,18 @@ +from cpu import CPU + + +def generate_inputs(input_size, max_val): + if input_size == 0: + yield () + else: + for x in range(max_val + 1): + for rest in generate_inputs(input_size - 1, max_val): + yield (x, *rest) + + +def are_equivalent(program1, program2, max_mem, max_val, input_size): + cpu = CPU(max_mem) + for input in generate_inputs(input_size, max_val): + if cpu.execute(program1, input) != cpu.execute(program2, input): + return False + return True diff --git a/cpu.py b/cpu.py index d7d1897..857c0d2 100644 --- a/cpu.py +++ b/cpu.py @@ -4,8 +4,9 @@ def __init__(self, max_mem_cells): self.state = [0] * max_mem_cells self.ops = {'LOAD': self.load, 'SWAP': self.swap, 'XOR': self.xor, 'INC': self.inc} - def execute(self, program): + def execute(self, program, input=()): state = self.state.copy() + state[0 : len(input)] = input for instruction in program: op = instruction[0] args = list(instruction[1:]) diff --git a/main.py b/main.py index 99e4863..ed565b7 100644 --- a/main.py +++ b/main.py @@ -4,15 +4,11 @@ def print_optimal_from_code(assembly, max_length, max_mem, max_val, debug=False): print(f"***Source***{assembly}") state = run(assembly, max_mem) - print_optimal_from_state(state, max_length, max_val, debug) - - -def print_optimal_from_state(state, max_length, max_val, debug=False): print("***State***") print(state) print() print("***Optimal***") - print(optimal_from_state(state, max_length, max_val, debug)) + print(optimize(assembly, max_length, max_mem, max_val, debug)) print("=" * 20) print() @@ -31,12 +27,13 @@ def main(): print_optimal_from_code(assembly, 4, 4, 5) # Test 2 - state = [0, 2, 1] - print_optimal_from_state(state, 3, 5) - - ## Test 3 - Careful, I don't think this will finish for days. - # state = [2, 4, 6, 8, 10, 12] - # print_optimal_from_state(state, 10, 15, True) + assembly = """ +LOAD 2 +SWAP 0, 1 +LOAD 1 +SWAP 0, 2 + """ + print_optimal_from_code(assembly, 3, 3, 5) main() diff --git a/superoptimizer.py b/superoptimizer.py index 579c7e6..c8f4bf8 100644 --- a/superoptimizer.py +++ b/superoptimizer.py @@ -1,24 +1,32 @@ from itertools import product from cpu import CPU import assembler +import brute_force_equivialence_checker -# Helper function that runs a piece of assembly code. -def run(assembly, max_mem): +def run(assembly, max_mem, input=()): + """ + Helper function that runs a piece of assembly code. + """ cpu = CPU(max_mem) program = assembler.parse(assembly) - return cpu.execute(program) + return cpu.execute(program, input) -# Helper function that finds the optimal code given the goal state. -def optimal_from_state(state, max_length, max_val, debug=False): - max_mem = len(state) - opt = Superoptimizer() - shortest_program = opt.search(max_length, max_mem, max_val, state, debug) - return assembler.output(shortest_program) +def optimize(assembly, max_length, max_mem, max_val, input_size=0, debug=False): + """ + Helper function that finds the optimal code given the assembly code. + """ + program = assembler.parse(assembly) + opt = Superoptimizer(brute_force_equivialence_checker.are_equivalent) + shortest = opt.search(max_length, max_mem, max_val, program, input_size, debug) + return assembler.output(shortest) class Superoptimizer: + def __init__(self, are_equivalent): + self.are_equivalent = are_equivalent + # Generates all possible programs. def generate_programs(self, cpu, max_length, max_mem, max_val): yield [] @@ -37,13 +45,12 @@ def generate_programs(self, cpu, max_length, max_mem, max_val): yield program # Tests all of the generated programs and returns the shortest. - def search(self, max_length, max_mem, max_val, target_state, debug=False): + def search(self, max_length, max_mem, max_val, program, input_size=0, debug=False): count = 0 cpu = CPU(max_mem) - for program in self.generate_programs(cpu, max_length, max_mem, max_val): - state = cpu.execute(program) - if state == target_state: - return program + for optimal in self.generate_programs(cpu, max_length, max_mem, max_val): + if self.are_equivalent(optimal, program, max_mem, max_val, input_size): + return optimal # Debugging. if debug: diff --git a/tests/test_cpu.py b/tests/test_cpu.py index d24746b..16fd023 100644 --- a/tests/test_cpu.py +++ b/tests/test_cpu.py @@ -34,3 +34,13 @@ def test_load_and_inc(): INC 1 """ assert run(assembly, 2) == [42, 3] + + +def test_input(): + assembly = """ +XOR 1, 0 +INC 1 + """ + assert run(assembly, 2) == [0, 1] + assert run(assembly, 2, [2]) == [2, 3] + assert run(assembly, 2, [1, 2]) == [1, 4] diff --git a/tests/test_supperoptimizer.py b/tests/test_supperoptimizer.py index b569fb6..f604ca7 100644 --- a/tests/test_supperoptimizer.py +++ b/tests/test_supperoptimizer.py @@ -1,53 +1,135 @@ -from superoptimizer import optimal_from_state, run +from superoptimizer import optimize MAX_LENGTH = 1000000 -def count_instructions(assembly): - if assembly == "\n": - return 0 - else: - return assembly.count("\n") +def test_four_threes(): + assembly = """ +LOAD 3 +SWAP 0, 1 +LOAD 3 +SWAP 0, 2 +LOAD 3 +SWAP 0, 3 +LOAD 3 + """ + optimal = """ +LOAD 3 +XOR 1, 0 +XOR 2, 0 +XOR 3, 0 + """.strip() + "\n" + assert optimize(assembly, MAX_LENGTH, 4, 3, 0) == optimal -def assert_optimal_length(optimal_length, state, max_val): +def test_three_threes(): + assembly = """ +LOAD 3 +SWAP 0, 1 +LOAD 3 +SWAP 0, 2 +LOAD 3 """ - Asserts that, given the arguments `state`, `max_length` and `max_val`, the - superoptimizer finds a program of length `length` and that this program produces - the state `state` when run. + optimal = """ +LOAD 3 +XOR 1, 0 +XOR 2, 0 + """.strip() + "\n" + assert optimize(assembly, MAX_LENGTH, 3, 3, 0) == optimal + assert optimize(assembly, 3, 3, 3, 0) == optimal + assert optimize(assembly, 2, 3, 3, 0) == None + assert optimize(assembly, 3, 3, 2, 0) == None - This method does not assert a specific output program, so that tests don't - fail due to changes in the optimizer that make it output different instructions - as long as the result is still correct and optimal. - """ - optimal = optimal_from_state(state, MAX_LENGTH, max_val) - assert count_instructions(optimal) == optimal_length - assert run(optimal, len(state)) == state - assert optimal_from_state(state, optimal_length, max_val) == optimal - if optimal_length > 0: - assert optimal_from_state(state, optimal_length - 1, max_val) == None + # Changing the input size to 1 doesn't change anything as the first input will be overridden by the load + assert optimize(assembly, MAX_LENGTH, 3, 3, 1) == optimal - -def test_four_threes(): - # Optimal program is a load and three xors - assert_optimal_length(4, [3, 3, 3, 3], 5) + # For input size 2, we'll need to clear the second input using swap and another load + optimal = """ +LOAD 3 +SWAP 0, 1 +LOAD 3 +XOR 2, 0 + """.strip() + "\n" + assert optimize(assembly, MAX_LENGTH, 3, 3, 2) == optimal def test_0_2_1(): + assembly = """ +LOAD 2 +SWAP 0, 1 +LOAD 1 +SWAP 0, 2 + """ + optimal = """ +LOAD 2 +SWAP 0, 1 +INC 2 + """.strip() + "\n" # Optimal program is load, swap and inc - assert_optimal_length(3, [0, 2, 1], 5) - assert_optimal_length(3, [0, 2, 1, 0], 5) + assert optimize(assembly, MAX_LENGTH, 3, 2, 0) == optimal + assert optimize(assembly, MAX_LENGTH, 4, 2, 0) == optimal -def test_zeros(): - # Optimal program is empty - assert_optimal_length(0, [0], 3) - assert_optimal_length(0, [0, 0], 3) +def test_no_op(): + assembly = """ +SWAP 0,0 + """ + # Program results in the memory being unchanged, so optimal program is empty + assert optimize(assembly, MAX_LENGTH, 1, 3, 0) == "\n" + assert optimize(assembly, MAX_LENGTH, 2, 3, 0) == "\n" + assert optimize(assembly, MAX_LENGTH, 2, 2, 1) == "\n" + assert optimize(assembly, MAX_LENGTH, 2, 2, 2) == "\n" def test_increasing_sequence(): - # Optimal program is inc for first memory slot and then xor+inc for every one after that - assert_optimal_length(1, [1], 3) - assert_optimal_length(3, [1, 2], 3) - assert_optimal_length(5, [1, 2, 3], 3) + assembly = """ +INC 0 +INC 1 +INC 1 + """ + optimal = """ +LOAD 2 +SWAP 0, 1 +LOAD 1 + """.strip() + "\n" + assert optimize(assembly, MAX_LENGTH, 2, 3, 0) == optimal + + assembly = """ +INC 0 +INC 1 +INC 1 +INC 2 +INC 2 +INC 2 + """ + optimal = """ +LOAD 2 +SWAP 0, 1 +LOAD 3 +SWAP 0, 2 +LOAD 1 + """.strip() + "\n" + assert optimize(assembly, MAX_LENGTH, 3, 3, 0) == optimal + + +def test_increasing_from_input(): + # Given the input x, the following program should produce the sequence x+1, x+2, x+3 + assembly = """ +XOR 1, 0 +XOR 2, 0 +INC 0 +INC 1 +INC 1 +INC 2 +INC 2 +INC 2 + """ + optimal = """ +INC 0 +XOR 1, 0 +INC 1 +XOR 2, 1 +INC 2 + """.strip() + "\n" + assert optimize(assembly, MAX_LENGTH, 3, 2, 1) == optimal From 0cff706508415f4c06843b04e3089faba1af3f7c Mon Sep 17 00:00:00 2001 From: Sebastian Hungerecker Date: Sun, 11 Jun 2023 19:58:17 +0200 Subject: [PATCH 6/9] Rework how programs are represented, parsed, run and generated The program representation now consists of classes with a `__str__` method, so `assembler.output` is no longer needed. More importantly, a program now contains the information how much memory it needs, removing the need to manually set this parameter and removing the possibility of accidentally using a wrong value. The opcodes are now represented as strings rather than functions from the CPU, so that the program representation is no longer tied to the CPU. This removes the need to instantiate dummy CPU objects when you don't actually want to run a program and more importantly makes this representation more usable when working with the program independently of the CPU. The program representation is now defined in the instruction_set module, which removes the dependency from the assembler module to the CPU module. This made it possible to move the `run` helper method into the CPU module (where it belongs) without introducing a cyclical dependency). The CPU, assembler and program generator have been updated according to these changes. The assembler is now also a bit more robust. It now allows leading white space and produces errors on invalid lines, including non-existant opcodes, wrong numbers of arguments or invalid memory addresses. The arity checks and the tracking of the memory size happens based on the definitions in instruction_set and when new instructions are added there, the assembler should be able to parse and check those new instructions without needing to be modified. Similarly the code generator now works solely from the defintions in `instruction_set` and should likewise work without modification when new instructions are added. --- assembler.py | 60 +++++---- brute_force_equivialence_checker.py | 5 +- cpu.py | 51 ++++---- instruction_set.py | 35 ++++++ main.py | 25 ++-- superoptimizer.py | 60 ++++----- tests/test_assembler.py | 50 ++++++++ tests/test_cpu.py | 56 +++++---- tests/test_supperoptimizer.py | 186 ++++++++++++++-------------- 9 files changed, 306 insertions(+), 222 deletions(-) create mode 100644 instruction_set.py create mode 100644 tests/test_assembler.py diff --git a/assembler.py b/assembler.py index 8580464..7f073c3 100644 --- a/assembler.py +++ b/assembler.py @@ -1,38 +1,34 @@ import re -from cpu import CPU +from instruction_set import * + + +INSTRUCTION_REGEX = re.compile(r'(\w+)\s+([-\d]+(?:\s*,\s*[-\d]+)*)') + -# Turns a string into a program. def parse(assembly): + """ + Turns a string into a program + """ lines = assembly.split('\n') - program = [] - cpu = CPU(1) + instructions = [] + mem_size = 1 for line in lines: - match = re.match(r'(\w+)\s+([-\d]+)(?:,\s*([-\d]+)(?:,\s*([-\d]+))?)?', line) + line = line.strip() + if line == '': + continue + match = INSTRUCTION_REGEX.fullmatch(line) if match: - op_str, *args_str = match.groups() - op = cpu.ops[op_str] - args = [int(arg) for arg in args_str if arg is not None] - program.append((op, *args)) - return program - - -# Turns a program into a string. -def output(program): - if program == None: - return None - if len(program) == 0: - return "\n" - cpu = CPU(1) - assembly = "" - for instruction in program: - op = instruction[0] - args = instruction[1:] - if op.__name__ == cpu.load.__name__: - assembly += f"LOAD {args[0]}\n" - elif op.__name__ == cpu.swap.__name__: - assembly += f"SWAP {args[0]}, {args[1]}\n" - elif op.__name__ == cpu.xor.__name__: - assembly += f"XOR {args[0]}, {args[1]}\n" - elif op.__name__ == cpu.inc.__name__: - assembly += f"INC {args[0]}\n" - return assembly \ No newline at end of file + op, args_str = match.groups() + args = tuple(int(arg) for arg in args_str.split(",")) + operand_types = OPS[op] + if len(args) != len(operand_types): + raise ValueError(f'Wrong number of operands: {line}') + for arg, arg_type in zip(args, operand_types): + if arg_type == 'mem': + if arg < 0: + raise ValueError(f'Negative memory address: {line}') + mem_size = max(arg + 1, mem_size) + instructions.append(Instruction(op, args)) + else: + raise ValueError(f'Invalid syntax: {line}') + return Program(tuple(instructions), mem_size) diff --git a/brute_force_equivialence_checker.py b/brute_force_equivialence_checker.py index ff83f61..35f434f 100644 --- a/brute_force_equivialence_checker.py +++ b/brute_force_equivialence_checker.py @@ -10,8 +10,9 @@ def generate_inputs(input_size, max_val): yield (x, *rest) -def are_equivalent(program1, program2, max_mem, max_val, input_size): - cpu = CPU(max_mem) +def are_equivalent(program1, program2, max_val, input_size): + mem_size = max(program1.mem_size, program2.mem_size) + cpu = CPU(mem_size) for input in generate_inputs(input_size, max_val): if cpu.execute(program1, input) != cpu.execute(program2, input): return False diff --git a/cpu.py b/cpu.py index 857c0d2..6eba046 100644 --- a/cpu.py +++ b/cpu.py @@ -1,31 +1,32 @@ +import assembler + + +def run(assembly, input=()): + """ + Helper function that runs a piece of assembly code. + """ + program = assembler.parse(assembly) + cpu = CPU(program.mem_size) + return cpu.execute(program, input) + + class CPU: def __init__(self, max_mem_cells): self.max_mem_cells = max_mem_cells - self.state = [0] * max_mem_cells - self.ops = {'LOAD': self.load, 'SWAP': self.swap, 'XOR': self.xor, 'INC': self.inc} def execute(self, program, input=()): - state = self.state.copy() - state[0 : len(input)] = input - for instruction in program: - op = instruction[0] - args = list(instruction[1:]) - args.insert(0, state) - state = op(*args) - return state - - def load(self, state, val): - state[0] = val - return state - - def swap(self, state, mem1, mem2): - state[mem1], state[mem2] = state[mem2], state[mem1] - return state - - def xor(self, state, mem1, mem2): - state[mem1] = state[mem1] ^ state[mem2] + state = [0] * self.max_mem_cells + state[0: len(input)] = input + for instruction in program.instructions: + match instruction.opcode: + case 'LOAD': + state[0] = instruction.args[0] + case 'SWAP': + mem1, mem2 = instruction.args + state[mem1], state[mem2] = state[mem2], state[mem1] + case 'XOR': + mem1, mem2 = instruction.args + state[mem1] ^= state[mem2] + case 'INC': + state[instruction.args[0]] += 1 return state - - def inc(self, state, mem): - state[mem] += 1 - return state \ No newline at end of file diff --git a/instruction_set.py b/instruction_set.py new file mode 100644 index 0000000..4af93a8 --- /dev/null +++ b/instruction_set.py @@ -0,0 +1,35 @@ +from dataclasses import dataclass + + +@dataclass +class Instruction: + opcode: str + args: tuple[int, ...] + + def __str__(self): + args = ", ".join(str(arg) for arg in self.args) + return f"{self.opcode} {args}" + + +@dataclass +class Program: + instructions: tuple[Instruction, ...] + """ + The instructions that make up this program + """ + + mem_size: int + """ + The amount of memory needed to run this program + """ + + def __str__(self): + return "\n".join(str(instr) for instr in self.instructions) + "\n" + + +OPS = { + "LOAD": ("const",), + "SWAP": ("mem", "mem"), + "XOR": ("mem", "mem"), + "INC": ("mem",) +} diff --git a/main.py b/main.py index ed565b7..acdaf25 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,5 @@ -from superoptimizer import * +from superoptimizer import optimize +from cpu import run def print_optimal_from_code(assembly, max_length, max_mem, max_val, debug=False): @@ -16,22 +17,22 @@ def print_optimal_from_code(assembly, max_length, max_mem, max_val, debug=False) def main(): # Test 1 assembly = """ -LOAD 3 -SWAP 0, 1 -LOAD 3 -SWAP 0, 2 -LOAD 3 -SWAP 0, 3 -LOAD 3 + LOAD 3 + SWAP 0, 1 + LOAD 3 + SWAP 0, 2 + LOAD 3 + SWAP 0, 3 + LOAD 3 """ print_optimal_from_code(assembly, 4, 4, 5) # Test 2 assembly = """ -LOAD 2 -SWAP 0, 1 -LOAD 1 -SWAP 0, 2 + LOAD 2 + SWAP 0, 1 + LOAD 1 + SWAP 0, 2 """ print_optimal_from_code(assembly, 3, 3, 5) diff --git a/superoptimizer.py b/superoptimizer.py index c8f4bf8..7f018b9 100644 --- a/superoptimizer.py +++ b/superoptimizer.py @@ -1,55 +1,49 @@ from itertools import product -from cpu import CPU import assembler import brute_force_equivialence_checker +from instruction_set import * -def run(assembly, max_mem, input=()): - """ - Helper function that runs a piece of assembly code. - """ - cpu = CPU(max_mem) - program = assembler.parse(assembly) - return cpu.execute(program, input) - - -def optimize(assembly, max_length, max_mem, max_val, input_size=0, debug=False): +def optimize(assembly, max_length, max_val, input_size=0, debug=False): """ Helper function that finds the optimal code given the assembly code. """ program = assembler.parse(assembly) opt = Superoptimizer(brute_force_equivialence_checker.are_equivalent) - shortest = opt.search(max_length, max_mem, max_val, program, input_size, debug) - return assembler.output(shortest) + return opt.search(max_length, max_val, program, input_size, debug) class Superoptimizer: def __init__(self, are_equivalent): self.are_equivalent = are_equivalent + @staticmethod + def generate_operands(operand_type, max_mem, max_val): + if operand_type == "const": + return range(max_val+1) + elif operand_type == "mem": + return range(max_mem) + else: + raise ValueError(f"Illegal operand type: {operand_type}") + # Generates all possible programs. - def generate_programs(self, cpu, max_length, max_mem, max_val): - yield [] + + @staticmethod + def generate_programs(max_length, max_mem, max_val): + yield Program((), 0) for length in range(1, max_length + 1): - for prog in product(cpu.ops.values(), repeat=length): - arg_sets = [] - for op in prog: - if op == cpu.load: - arg_sets.append([tuple([val]) for val in range(max_val + 1)]) - elif op == cpu.swap or op == cpu.xor: - arg_sets.append(product(range(max_mem), repeat=2)) - elif op == cpu.inc: - arg_sets.append([tuple([val]) for val in range(max_mem)]) - for arg_set in product(*arg_sets): - program = [(op, *args) for op, args in zip(prog, arg_set)] - yield program - - # Tests all of the generated programs and returns the shortest. - def search(self, max_length, max_mem, max_val, program, input_size=0, debug=False): + instructions = [] + for op, operand_types in OPS.items(): + arg_sets = (Superoptimizer.generate_operands(ot, max_mem, max_val) for ot in operand_types) + instructions.extend(assembler.Instruction(op, args) for args in product(*arg_sets)) + for prog in product(instructions, repeat=length): + yield Program(prog, max_mem) + + # Tests all the generated programs and returns the shortest. + def search(self, max_length, max_val, program, input_size=0, debug=False): count = 0 - cpu = CPU(max_mem) - for optimal in self.generate_programs(cpu, max_length, max_mem, max_val): - if self.are_equivalent(optimal, program, max_mem, max_val, input_size): + for optimal in self.generate_programs(max_length, program.mem_size, max_val): + if self.are_equivalent(optimal, program, max_val, input_size): return optimal # Debugging. diff --git a/tests/test_assembler.py b/tests/test_assembler.py new file mode 100644 index 0000000..2c730ae --- /dev/null +++ b/tests/test_assembler.py @@ -0,0 +1,50 @@ +import pytest + +from assembler import parse +from instruction_set import * + + +def test_empty_program(): + assert parse('') == Program((), 1) + + +def test_that_all_instructions_and_mem_size(): + assembly = """ + LOAD 42 + XOR 2, 3 + SWAP 42, 23 + INC 13 + """ + instructions = ( + Instruction('LOAD', (42,)), + Instruction('XOR', (2, 3)), + Instruction('SWAP', (42, 23)), + Instruction('INC', (13,)) + ) + assert parse(assembly) == Program(instructions, 43) + + +def test_syntax_errors(): + with pytest.raises(ValueError): + parse("LOAD !&%*") + + with pytest.raises(ValueError): + parse("LOAD") + + with pytest.raises(ValueError): + parse("LOAD 23, 42") + + with pytest.raises(ValueError): + parse("XOR 23") + + with pytest.raises(ValueError): + parse("XOR 23, -42") + + with pytest.raises(ValueError): + parse("INC 1, 2") + + with pytest.raises(ValueError): + parse("SWAP 23") + + with pytest.raises(ValueError): + parse("SWAP 23, -42") diff --git a/tests/test_cpu.py b/tests/test_cpu.py index 16fd023..dd4c599 100644 --- a/tests/test_cpu.py +++ b/tests/test_cpu.py @@ -1,46 +1,50 @@ -from superoptimizer import run +from cpu import run def test_load_and_swap(): assembly = """ -LOAD 3 -SWAP 0, 1 -LOAD 3 -SWAP 0, 2 -LOAD 3 -SWAP 0, 3 -LOAD 3 + LOAD 3 + SWAP 0, 1 + LOAD 3 + SWAP 0, 2 + LOAD 3 + SWAP 0, 3 + LOAD 3 """ - assert run(assembly, 4) == [3, 3, 3, 3] - assert run(assembly, 10) == [3, 3, 3, 3, 0, 0, 0, 0, 0, 0] + assert run(assembly) == [3, 3, 3, 3] def test_load_and_xor(): assembly = """ -LOAD 42 -XOR 1, 0 -LOAD 23 -XOR 1, 0 + LOAD 42 + XOR 1, 0 + LOAD 23 + XOR 1, 0 """ - assert run(assembly, 2) == [23, 42 ^ 23] + assert run(assembly) == [23, 42 ^ 23] def test_load_and_inc(): assembly = """ -LOAD 41 -INC 0 -INC 1 -INC 1 -INC 1 + LOAD 41 + INC 0 + INC 1 + INC 1 + INC 1 """ - assert run(assembly, 2) == [42, 3] + assert run(assembly) == [42, 3] def test_input(): assembly = """ -XOR 1, 0 -INC 1 + XOR 1, 0 + INC 1 """ - assert run(assembly, 2) == [0, 1] - assert run(assembly, 2, [2]) == [2, 3] - assert run(assembly, 2, [1, 2]) == [1, 4] + assert run(assembly) == [0, 1] + assert run(assembly, [2]) == [2, 3] + assert run(assembly, [1, 2]) == [1, 4] + + +def test_load_only(): + assembly = 'LOAD 42' + assert run(assembly) == [42] diff --git a/tests/test_supperoptimizer.py b/tests/test_supperoptimizer.py index f604ca7..e7f84cc 100644 --- a/tests/test_supperoptimizer.py +++ b/tests/test_supperoptimizer.py @@ -1,4 +1,5 @@ from superoptimizer import optimize +from assembler import parse, Program MAX_LENGTH = 1000000 @@ -6,130 +7,131 @@ def test_four_threes(): assembly = """ -LOAD 3 -SWAP 0, 1 -LOAD 3 -SWAP 0, 2 -LOAD 3 -SWAP 0, 3 -LOAD 3 + LOAD 3 + SWAP 0, 1 + LOAD 3 + SWAP 0, 2 + LOAD 3 + SWAP 0, 3 + LOAD 3 """ - optimal = """ -LOAD 3 -XOR 1, 0 -XOR 2, 0 -XOR 3, 0 - """.strip() + "\n" - assert optimize(assembly, MAX_LENGTH, 4, 3, 0) == optimal + optimal = parse(""" + LOAD 3 + XOR 1, 0 + XOR 2, 0 + XOR 3, 0 + """) + assert optimize(assembly, MAX_LENGTH, 3, 0) == optimal def test_three_threes(): assembly = """ -LOAD 3 -SWAP 0, 1 -LOAD 3 -SWAP 0, 2 -LOAD 3 + LOAD 3 + SWAP 0, 1 + LOAD 3 + SWAP 0, 2 + LOAD 3 """ - optimal = """ -LOAD 3 -XOR 1, 0 -XOR 2, 0 - """.strip() + "\n" - assert optimize(assembly, MAX_LENGTH, 3, 3, 0) == optimal - assert optimize(assembly, 3, 3, 3, 0) == optimal - assert optimize(assembly, 2, 3, 3, 0) == None - assert optimize(assembly, 3, 3, 2, 0) == None + optimal = parse(""" + LOAD 3 + XOR 1, 0 + XOR 2, 0 + """) + assert optimize(assembly, MAX_LENGTH, 3, 0) == optimal + # Assert that the program is still found with a tight max_length + assert optimize(assembly, 3, 3, 0) == optimal + # Assert that the program is not found with a max_length that's below the optimal length + assert optimize(assembly, 2, 3, 0) is None + # Assert that the program is not found when max_val is too low + assert optimize(assembly, 3, 2, 0) is None # Changing the input size to 1 doesn't change anything as the first input will be overridden by the load - assert optimize(assembly, MAX_LENGTH, 3, 3, 1) == optimal + assert optimize(assembly, MAX_LENGTH, 3, 1) == optimal # For input size 2, we'll need to clear the second input using swap and another load - optimal = """ -LOAD 3 -SWAP 0, 1 -LOAD 3 -XOR 2, 0 - """.strip() + "\n" - assert optimize(assembly, MAX_LENGTH, 3, 3, 2) == optimal + optimal = parse(""" + LOAD 3 + SWAP 0, 1 + LOAD 3 + XOR 2, 0 + """) + assert optimize(assembly, MAX_LENGTH, 3, 2) == optimal def test_0_2_1(): assembly = """ -LOAD 2 -SWAP 0, 1 -LOAD 1 -SWAP 0, 2 + LOAD 2 + SWAP 0, 1 + LOAD 1 + SWAP 0, 2 """ - optimal = """ -LOAD 2 -SWAP 0, 1 -INC 2 - """.strip() + "\n" - # Optimal program is load, swap and inc - assert optimize(assembly, MAX_LENGTH, 3, 2, 0) == optimal - assert optimize(assembly, MAX_LENGTH, 4, 2, 0) == optimal + optimal = parse(""" + LOAD 2 + SWAP 0, 1 + INC 2 + """) + assert optimize(assembly, MAX_LENGTH, 2, 0) == optimal def test_no_op(): assembly = """ -SWAP 0,0 + SWAP 0,0 """ + empty_program = Program((), 0) # Program results in the memory being unchanged, so optimal program is empty - assert optimize(assembly, MAX_LENGTH, 1, 3, 0) == "\n" - assert optimize(assembly, MAX_LENGTH, 2, 3, 0) == "\n" - assert optimize(assembly, MAX_LENGTH, 2, 2, 1) == "\n" - assert optimize(assembly, MAX_LENGTH, 2, 2, 2) == "\n" + assert optimize(assembly, MAX_LENGTH, 3, 0) == empty_program + assert optimize(assembly, MAX_LENGTH, 2, 1) == empty_program + assert optimize(assembly, MAX_LENGTH, 2, 2) == empty_program def test_increasing_sequence(): assembly = """ -INC 0 -INC 1 -INC 1 + INC 0 + INC 1 + INC 1 """ - optimal = """ -LOAD 2 -SWAP 0, 1 -LOAD 1 - """.strip() + "\n" - assert optimize(assembly, MAX_LENGTH, 2, 3, 0) == optimal + optimal = parse(""" + LOAD 1 + XOR 1, 0 + INC 1 + """) + assert optimize(assembly, MAX_LENGTH, 3, 0) == optimal assembly = """ -INC 0 -INC 1 -INC 1 -INC 2 -INC 2 -INC 2 + INC 0 + INC 1 + INC 1 + INC 2 + INC 2 + INC 2 """ - optimal = """ -LOAD 2 -SWAP 0, 1 -LOAD 3 -SWAP 0, 2 -LOAD 1 - """.strip() + "\n" - assert optimize(assembly, MAX_LENGTH, 3, 3, 0) == optimal + optimal = parse(""" + LOAD 1 + XOR 1, 0 + XOR 2, 0 + INC 1 + XOR 2, 1 + """) + assert optimize(assembly, MAX_LENGTH, 3, 0) == optimal def test_increasing_from_input(): # Given the input x, the following program should produce the sequence x+1, x+2, x+3 assembly = """ -XOR 1, 0 -XOR 2, 0 -INC 0 -INC 1 -INC 1 -INC 2 -INC 2 -INC 2 + XOR 1, 0 + XOR 2, 0 + INC 0 + INC 1 + INC 1 + INC 2 + INC 2 + INC 2 """ - optimal = """ -INC 0 -XOR 1, 0 -INC 1 -XOR 2, 1 -INC 2 - """.strip() + "\n" - assert optimize(assembly, MAX_LENGTH, 3, 2, 1) == optimal + optimal = parse(""" + INC 0 + XOR 1, 0 + INC 1 + XOR 2, 1 + INC 2 + """) + assert optimize(assembly, MAX_LENGTH, 2, 1) == optimal From c112dd3ba19cf815d4d64e522227fdc843f02862 Mon Sep 17 00:00:00 2001 From: Sebastian Hungerecker Date: Sun, 11 Jun 2023 20:55:56 +0200 Subject: [PATCH 7/9] Use fixed-size integers with a configurable bit width This makes the instruction set more realistic as real CPUs don't support arbitrary precision integers either. More impoortantly, having a fixed width means there's only a finite number of possible input, allowing the equivalence checker to actually try all inputs representable using the current bit width rather than unsoundly stopping at an arbitrary limit. Even more importantly, this paves the way for the introduction of an SMT solver as those don't tend to support XOR on unbounded integers, only on bitvectors. --- brute_force_equivialence_checker.py | 14 +++++++++----- cpu.py | 12 +++++++----- main.py | 10 +++++----- superoptimizer.py | 23 ++++++++++++----------- tests/test_cpu.py | 23 ++++++++++++++++------- tests/test_supperoptimizer.py | 24 +++++++++++------------- 6 files changed, 60 insertions(+), 46 deletions(-) diff --git a/brute_force_equivialence_checker.py b/brute_force_equivialence_checker.py index 35f434f..fc5698c 100644 --- a/brute_force_equivialence_checker.py +++ b/brute_force_equivialence_checker.py @@ -2,18 +2,22 @@ def generate_inputs(input_size, max_val): + """ + Generates all possible tuples of the given size with values ranging from 0 (inclusive) + to `max_val` (exclusive). + """ if input_size == 0: yield () else: - for x in range(max_val + 1): + for x in range(max_val): for rest in generate_inputs(input_size - 1, max_val): - yield (x, *rest) + yield x, *rest -def are_equivalent(program1, program2, max_val, input_size): +def are_equivalent(program1, program2, bit_width, input_size): mem_size = max(program1.mem_size, program2.mem_size) - cpu = CPU(mem_size) - for input in generate_inputs(input_size, max_val): + cpu = CPU(mem_size, bit_width) + for input in generate_inputs(input_size, 2 ** bit_width): if cpu.execute(program1, input) != cpu.execute(program2, input): return False return True diff --git a/cpu.py b/cpu.py index 6eba046..0725bf4 100644 --- a/cpu.py +++ b/cpu.py @@ -1,18 +1,19 @@ import assembler -def run(assembly, input=()): +def run(assembly, bit_width, input=()): """ Helper function that runs a piece of assembly code. """ program = assembler.parse(assembly) - cpu = CPU(program.mem_size) + cpu = CPU(program.mem_size, bit_width) return cpu.execute(program, input) class CPU: - def __init__(self, max_mem_cells): + def __init__(self, max_mem_cells, bit_width): self.max_mem_cells = max_mem_cells + self.limit = 2 ** bit_width def execute(self, program, input=()): state = [0] * self.max_mem_cells @@ -20,7 +21,7 @@ def execute(self, program, input=()): for instruction in program.instructions: match instruction.opcode: case 'LOAD': - state[0] = instruction.args[0] + state[0] = instruction.args[0] % self.limit case 'SWAP': mem1, mem2 = instruction.args state[mem1], state[mem2] = state[mem2], state[mem1] @@ -28,5 +29,6 @@ def execute(self, program, input=()): mem1, mem2 = instruction.args state[mem1] ^= state[mem2] case 'INC': - state[instruction.args[0]] += 1 + mem = instruction.args[0] + state[mem] = (state[mem] + 1) % self.limit return state diff --git a/main.py b/main.py index acdaf25..589c401 100644 --- a/main.py +++ b/main.py @@ -2,14 +2,14 @@ from cpu import run -def print_optimal_from_code(assembly, max_length, max_mem, max_val, debug=False): +def print_optimal_from_code(assembly, max_length, bit_width, debug=False): print(f"***Source***{assembly}") - state = run(assembly, max_mem) + state = run(assembly, bit_width) print("***State***") print(state) print() print("***Optimal***") - print(optimize(assembly, max_length, max_mem, max_val, debug)) + print(optimize(assembly, max_length, bit_width, debug)) print("=" * 20) print() @@ -25,7 +25,7 @@ def main(): SWAP 0, 3 LOAD 3 """ - print_optimal_from_code(assembly, 4, 4, 5) + print_optimal_from_code(assembly, 4, 2) # Test 2 assembly = """ @@ -34,7 +34,7 @@ def main(): LOAD 1 SWAP 0, 2 """ - print_optimal_from_code(assembly, 3, 3, 5) + print_optimal_from_code(assembly, 3, 2) main() diff --git a/superoptimizer.py b/superoptimizer.py index 7f018b9..24d326b 100644 --- a/superoptimizer.py +++ b/superoptimizer.py @@ -4,13 +4,13 @@ from instruction_set import * -def optimize(assembly, max_length, max_val, input_size=0, debug=False): +def optimize(assembly, max_length, bit_width, input_size=0, debug=False): """ Helper function that finds the optimal code given the assembly code. """ program = assembler.parse(assembly) opt = Superoptimizer(brute_force_equivialence_checker.are_equivalent) - return opt.search(max_length, max_val, program, input_size, debug) + return opt.search(max_length, bit_width, program, input_size, debug) class Superoptimizer: @@ -18,32 +18,33 @@ def __init__(self, are_equivalent): self.are_equivalent = are_equivalent @staticmethod - def generate_operands(operand_type, max_mem, max_val): + def generate_operands(operand_type, max_mem, bit_width): if operand_type == "const": - return range(max_val+1) + return range(2 ** bit_width) elif operand_type == "mem": return range(max_mem) else: raise ValueError(f"Illegal operand type: {operand_type}") - # Generates all possible programs. - @staticmethod - def generate_programs(max_length, max_mem, max_val): + def generate_programs(max_length, max_mem, bit_width): + """ + Generates all possible programs + """ yield Program((), 0) for length in range(1, max_length + 1): instructions = [] for op, operand_types in OPS.items(): - arg_sets = (Superoptimizer.generate_operands(ot, max_mem, max_val) for ot in operand_types) + arg_sets = (Superoptimizer.generate_operands(ot, max_mem, bit_width) for ot in operand_types) instructions.extend(assembler.Instruction(op, args) for args in product(*arg_sets)) for prog in product(instructions, repeat=length): yield Program(prog, max_mem) # Tests all the generated programs and returns the shortest. - def search(self, max_length, max_val, program, input_size=0, debug=False): + def search(self, max_length, bit_width, program, input_size=0, debug=False): count = 0 - for optimal in self.generate_programs(max_length, program.mem_size, max_val): - if self.are_equivalent(optimal, program, max_val, input_size): + for optimal in self.generate_programs(max_length, program.mem_size, bit_width): + if self.are_equivalent(optimal, program, bit_width, input_size): return optimal # Debugging. diff --git a/tests/test_cpu.py b/tests/test_cpu.py index dd4c599..778f744 100644 --- a/tests/test_cpu.py +++ b/tests/test_cpu.py @@ -11,7 +11,8 @@ def test_load_and_swap(): SWAP 0, 3 LOAD 3 """ - assert run(assembly) == [3, 3, 3, 3] + assert run(assembly, 8) == [3, 3, 3, 3] + assert run(assembly, 1) == [1, 1, 1, 1] def test_load_and_xor(): @@ -21,7 +22,7 @@ def test_load_and_xor(): LOAD 23 XOR 1, 0 """ - assert run(assembly) == [23, 42 ^ 23] + assert run(assembly, 8) == [23, 42 ^ 23] def test_load_and_inc(): @@ -32,7 +33,7 @@ def test_load_and_inc(): INC 1 INC 1 """ - assert run(assembly) == [42, 3] + assert run(assembly, 8) == [42, 3] def test_input(): @@ -40,11 +41,19 @@ def test_input(): XOR 1, 0 INC 1 """ - assert run(assembly) == [0, 1] - assert run(assembly, [2]) == [2, 3] - assert run(assembly, [1, 2]) == [1, 4] + assert run(assembly, 8) == [0, 1] + assert run(assembly, 8, [2]) == [2, 3] + assert run(assembly, 8, [1, 2]) == [1, 4] def test_load_only(): assembly = 'LOAD 42' - assert run(assembly) == [42] + assert run(assembly, 8) == [42] + + +def test_wrap_around(): + assembly = """ + LOAD 255 + INC 0 + """ + assert run(assembly, 8) == [0] diff --git a/tests/test_supperoptimizer.py b/tests/test_supperoptimizer.py index e7f84cc..d4e16d8 100644 --- a/tests/test_supperoptimizer.py +++ b/tests/test_supperoptimizer.py @@ -21,7 +21,7 @@ def test_four_threes(): XOR 2, 0 XOR 3, 0 """) - assert optimize(assembly, MAX_LENGTH, 3, 0) == optimal + assert optimize(assembly, MAX_LENGTH, 2, 0) == optimal def test_three_threes(): @@ -37,16 +37,14 @@ def test_three_threes(): XOR 1, 0 XOR 2, 0 """) - assert optimize(assembly, MAX_LENGTH, 3, 0) == optimal + assert optimize(assembly, MAX_LENGTH, 2, 0) == optimal # Assert that the program is still found with a tight max_length - assert optimize(assembly, 3, 3, 0) == optimal + assert optimize(assembly, 3, 2, 0) == optimal # Assert that the program is not found with a max_length that's below the optimal length - assert optimize(assembly, 2, 3, 0) is None - # Assert that the program is not found when max_val is too low - assert optimize(assembly, 3, 2, 0) is None + assert optimize(assembly, 2, 2, 0) is None # Changing the input size to 1 doesn't change anything as the first input will be overridden by the load - assert optimize(assembly, MAX_LENGTH, 3, 1) == optimal + assert optimize(assembly, MAX_LENGTH, 2, 1) == optimal # For input size 2, we'll need to clear the second input using swap and another load optimal = parse(""" @@ -55,7 +53,7 @@ def test_three_threes(): LOAD 3 XOR 2, 0 """) - assert optimize(assembly, MAX_LENGTH, 3, 2) == optimal + assert optimize(assembly, MAX_LENGTH, 2, 2) == optimal def test_0_2_1(): @@ -79,9 +77,9 @@ def test_no_op(): """ empty_program = Program((), 0) # Program results in the memory being unchanged, so optimal program is empty - assert optimize(assembly, MAX_LENGTH, 3, 0) == empty_program - assert optimize(assembly, MAX_LENGTH, 2, 1) == empty_program - assert optimize(assembly, MAX_LENGTH, 2, 2) == empty_program + assert optimize(assembly, MAX_LENGTH, 2, 0) == empty_program + assert optimize(assembly, MAX_LENGTH, 1, 1) == empty_program + assert optimize(assembly, MAX_LENGTH, 1, 2) == empty_program def test_increasing_sequence(): @@ -95,7 +93,7 @@ def test_increasing_sequence(): XOR 1, 0 INC 1 """) - assert optimize(assembly, MAX_LENGTH, 3, 0) == optimal + assert optimize(assembly, MAX_LENGTH, 2, 0) == optimal assembly = """ INC 0 @@ -112,7 +110,7 @@ def test_increasing_sequence(): INC 1 XOR 2, 1 """) - assert optimize(assembly, MAX_LENGTH, 3, 0) == optimal + assert optimize(assembly, MAX_LENGTH, 2, 0) == optimal def test_increasing_from_input(): From 5ee4904e2b0bd35c342d32ba59718dcf2cf9c933 Mon Sep 17 00:00:00 2001 From: Sebastian Hungerecker Date: Mon, 12 Jun 2023 21:12:57 +0200 Subject: [PATCH 8/9] Add an SMT-based equivalence checker The SMT-based equivalence checker is significantly more expensive in cases where there are no or few inputs. However, it easily outperforms the brute force one in cases where there's a large bit width and/or number of inputs. Sadly, in those cases where the SMT solver would be better, the superoptimizer will already be unbearably slow regardless of the equivalence checker just from generating all the possible candidate programs. So until the program generation is optimized, the SMT based equivalence checker won't do much good for performance. --- main.py | 2 +- requirements.txt | 1 + smt_based_equivalence_checker.py | 12 ++++ smt_program_simulator.py | 32 ++++++++++ superoptimizer.py | 6 +- tests/test_smt_baseed_equivalence_checker.py | 62 +++++++++++++++++++ ...peroptimizer.py => test_superoptimizer.py} | 28 +++++++-- 7 files changed, 135 insertions(+), 8 deletions(-) create mode 100644 smt_based_equivalence_checker.py create mode 100644 smt_program_simulator.py create mode 100644 tests/test_smt_baseed_equivalence_checker.py rename tests/{test_supperoptimizer.py => test_superoptimizer.py} (76%) diff --git a/main.py b/main.py index 589c401..1a1740f 100644 --- a/main.py +++ b/main.py @@ -9,7 +9,7 @@ def print_optimal_from_code(assembly, max_length, bit_width, debug=False): print(state) print() print("***Optimal***") - print(optimize(assembly, max_length, bit_width, debug)) + print(optimize(assembly, max_length, bit_width, debug=debug)) print("=" * 20) print() diff --git a/requirements.txt b/requirements.txt index e079f8a..c063b90 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ pytest +z3-solver diff --git a/smt_based_equivalence_checker.py b/smt_based_equivalence_checker.py new file mode 100644 index 0000000..ee3e1ea --- /dev/null +++ b/smt_based_equivalence_checker.py @@ -0,0 +1,12 @@ +import z3 +from smt_program_simulator import simulate + + +def are_equivalent(program1, program2, bit_width, input_size): + solver = z3.Solver() + mem_size = max(program1.mem_size, program2.mem_size) + state1 = simulate(program1, mem_size, bit_width, input_size) + state2 = simulate(program2, mem_size, bit_width, input_size) + programs_are_different = z3.Or(*(value1 != value2 for value1, value2 in zip(state1, state2))) + print(programs_are_different) + return solver.check(programs_are_different) == z3.unsat diff --git a/smt_program_simulator.py b/smt_program_simulator.py new file mode 100644 index 0000000..258f892 --- /dev/null +++ b/smt_program_simulator.py @@ -0,0 +1,32 @@ +import z3 + + +def simulate(program, mem_size, bit_width, input_size): + """ + Simulate the behavior of the program using an SMT solver. + + The result will be a list containing, for each memory cell, an SMT value representing the + value that will reside in that memory location after running the program. + """ + + def mem_cell(i): + if i < input_size: + return z3.BitVec(f'input{i}', bit_width) + else: + return z3.BitVecVal(0, bit_width) + + state = [mem_cell(i) for i in range(mem_size)] + for instruction in program.instructions: + match instruction.opcode: + case 'LOAD': + state[0] = z3.BitVecVal(instruction.args[0], bit_width) + case 'SWAP': + mem1, mem2 = instruction.args + state[mem1], state[mem2] = state[mem2], state[mem1] + case 'XOR': + mem1, mem2 = instruction.args + state[mem1] ^= state[mem2] + case 'INC': + mem = instruction.args[0] + state[mem] += 1 + return state diff --git a/superoptimizer.py b/superoptimizer.py index 24d326b..2854ef8 100644 --- a/superoptimizer.py +++ b/superoptimizer.py @@ -4,12 +4,14 @@ from instruction_set import * -def optimize(assembly, max_length, bit_width, input_size=0, debug=False): +def optimize(assembly, max_length, bit_width, input_size=0, *, + equivalence_checker=brute_force_equivialence_checker.are_equivalent, + debug=False): """ Helper function that finds the optimal code given the assembly code. """ program = assembler.parse(assembly) - opt = Superoptimizer(brute_force_equivialence_checker.are_equivalent) + opt = Superoptimizer(equivalence_checker) return opt.search(max_length, bit_width, program, input_size, debug) diff --git a/tests/test_smt_baseed_equivalence_checker.py b/tests/test_smt_baseed_equivalence_checker.py new file mode 100644 index 0000000..0497774 --- /dev/null +++ b/tests/test_smt_baseed_equivalence_checker.py @@ -0,0 +1,62 @@ +import assembler +import smt_based_equivalence_checker + + +def test_add_three(): + three_incs = assembler.parse(""" + INC 0 + INC 0 + INC 0 + """) + load_three = assembler.parse('LOAD 3') + one_inc = assembler.parse('INC 0') + + assert smt_based_equivalence_checker.are_equivalent(three_incs, load_three, 8, 0) + assert not smt_based_equivalence_checker.are_equivalent(three_incs, one_inc, 8, 0) + # With a single bit, += 1 and += 3 are equivalent + assert smt_based_equivalence_checker.are_equivalent(three_incs, one_inc, 1, 0) + + # If there's user input, setting to three and increasing by three are no longer equivalent + assert not smt_based_equivalence_checker.are_equivalent(three_incs, load_three, 8, 1) + # However, += 1 and += 3 are still equivalent for single bits + assert smt_based_equivalence_checker.are_equivalent(three_incs, one_inc, 1, 1) + + +def test_swap_vs_xor(): + swap_with_xor = assembler.parse(""" + XOR 0, 1 + XOR 1, 0 + XOR 0, 1 + """) + swap_with_swap = assembler.parse('SWAP 0, 1') + just_xor = assembler.parse('XOR 0, 1') + + assert smt_based_equivalence_checker.are_equivalent(swap_with_xor, swap_with_swap, 8, 2) + assert not smt_based_equivalence_checker.are_equivalent(swap_with_xor, just_xor, 8, 2) + assert not smt_based_equivalence_checker.are_equivalent(swap_with_swap, just_xor, 8, 2) + + +def test_large_program_with_lots_of_inputs(): + program1 = assembler.parse(""" + INC 0 + XOR 0, 1 + XOR 1, 0 + XOR 0, 1 + INC 1 + SWAP 1, 2 + INC 3 + XOR 3, 2 + INC 4 + INC 5 + """) + program2 = assembler.parse(""" + SWAP 0, 2 + SWAP 0, 1 + INC 2 + INC 2 + INC 3 + XOR 3, 2 + INC 4 + INC 5 + """) + assert smt_based_equivalence_checker.are_equivalent(program1, program2, 8, 6) diff --git a/tests/test_supperoptimizer.py b/tests/test_superoptimizer.py similarity index 76% rename from tests/test_supperoptimizer.py rename to tests/test_superoptimizer.py index d4e16d8..583b705 100644 --- a/tests/test_supperoptimizer.py +++ b/tests/test_superoptimizer.py @@ -1,10 +1,18 @@ from superoptimizer import optimize from assembler import parse, Program +import smt_based_equivalence_checker MAX_LENGTH = 1000000 +def optimize_with_both(*args): + result1 = optimize(*args) + result2 = optimize(*args, equivalence_checker=smt_based_equivalence_checker.are_equivalent) + assert result1 == result2 + return result1 + + def test_four_threes(): assembly = """ LOAD 3 @@ -68,7 +76,7 @@ def test_0_2_1(): SWAP 0, 1 INC 2 """) - assert optimize(assembly, MAX_LENGTH, 2, 0) == optimal + assert optimize_with_both(assembly, MAX_LENGTH, 2, 0) == optimal def test_no_op(): @@ -77,9 +85,9 @@ def test_no_op(): """ empty_program = Program((), 0) # Program results in the memory being unchanged, so optimal program is empty - assert optimize(assembly, MAX_LENGTH, 2, 0) == empty_program - assert optimize(assembly, MAX_LENGTH, 1, 1) == empty_program - assert optimize(assembly, MAX_LENGTH, 1, 2) == empty_program + assert optimize_with_both(assembly, MAX_LENGTH, 2, 0) == empty_program + assert optimize_with_both(assembly, MAX_LENGTH, 1, 1) == empty_program + assert optimize_with_both(assembly, MAX_LENGTH, 1, 2) == empty_program def test_increasing_sequence(): @@ -93,7 +101,7 @@ def test_increasing_sequence(): XOR 1, 0 INC 1 """) - assert optimize(assembly, MAX_LENGTH, 2, 0) == optimal + assert optimize_with_both(assembly, MAX_LENGTH, 2, 0) == optimal assembly = """ INC 0 @@ -133,3 +141,13 @@ def test_increasing_from_input(): INC 2 """) assert optimize(assembly, MAX_LENGTH, 2, 1) == optimal + + +def test_add_to_three_mem_cells(): + assembly = """ + INC 0 + INC 1 + INC 2 + """ + optimal = parse(assembly) + assert optimize_with_both(assembly, MAX_LENGTH, 2, 2) == optimal From f2e4e474b9509de0836b33b316270c9cf395e63a Mon Sep 17 00:00:00 2001 From: Sebastian Hungerecker Date: Tue, 13 Jun 2023 22:12:04 +0200 Subject: [PATCH 9/9] Reuse SMT state in SMT-based equivalence checker This improves the performance of the superoptimizer using the SMT-based equivalence checker quite a bit, but it's still noticably slower than the brute force one for cases with little input. --- brute_force_equivialence_checker.py | 42 +++++++++++--------- smt_based_equivalence_checker.py | 20 ++++++---- superoptimizer.py | 11 ++--- tests/test_smt_baseed_equivalence_checker.py | 24 ++++++----- tests/test_superoptimizer.py | 18 +++++---- 5 files changed, 66 insertions(+), 49 deletions(-) diff --git a/brute_force_equivialence_checker.py b/brute_force_equivialence_checker.py index fc5698c..ba7dc8c 100644 --- a/brute_force_equivialence_checker.py +++ b/brute_force_equivialence_checker.py @@ -1,23 +1,29 @@ from cpu import CPU -def generate_inputs(input_size, max_val): - """ - Generates all possible tuples of the given size with values ranging from 0 (inclusive) - to `max_val` (exclusive). - """ - if input_size == 0: - yield () - else: - for x in range(max_val): - for rest in generate_inputs(input_size - 1, max_val): - yield x, *rest +class BruteForceEquivalenceChecker: + def __init__(self, program1, bit_width, input_size): + self.program1 = program1 + self.bit_width = bit_width + self.max_val = 2 ** bit_width + self.input_size = input_size + def generate_inputs(self, input_size): + """ + Generates all possible tuples of the given size with values ranging from 0 (inclusive) + to `max_val` (exclusive). + """ + if input_size == 0: + yield () + else: + for x in range(self.max_val): + for rest in self.generate_inputs(input_size - 1): + yield x, *rest -def are_equivalent(program1, program2, bit_width, input_size): - mem_size = max(program1.mem_size, program2.mem_size) - cpu = CPU(mem_size, bit_width) - for input in generate_inputs(input_size, 2 ** bit_width): - if cpu.execute(program1, input) != cpu.execute(program2, input): - return False - return True + def is_equivalent_to(self, program2): + mem_size = max(self.program1.mem_size, program2.mem_size) + cpu = CPU(mem_size, self.bit_width) + for input in self.generate_inputs(self.input_size): + if cpu.execute(self.program1, input) != cpu.execute(program2, input): + return False + return True diff --git a/smt_based_equivalence_checker.py b/smt_based_equivalence_checker.py index ee3e1ea..5c71f58 100644 --- a/smt_based_equivalence_checker.py +++ b/smt_based_equivalence_checker.py @@ -2,11 +2,15 @@ from smt_program_simulator import simulate -def are_equivalent(program1, program2, bit_width, input_size): - solver = z3.Solver() - mem_size = max(program1.mem_size, program2.mem_size) - state1 = simulate(program1, mem_size, bit_width, input_size) - state2 = simulate(program2, mem_size, bit_width, input_size) - programs_are_different = z3.Or(*(value1 != value2 for value1, value2 in zip(state1, state2))) - print(programs_are_different) - return solver.check(programs_are_different) == z3.unsat +class SmtBasedEquivalenceChecker: + def __init__(self, program1, bit_width, input_size): + self.solver = z3.Solver() + self.bit_width = bit_width + self.input_size = input_size + self.mem_size = program1.mem_size + self.state1 = simulate(program1, self.mem_size, bit_width, input_size) + + def is_equivalent_to(self, program2): + state2 = simulate(program2, self.mem_size, self.bit_width, self.input_size) + programs_are_different = z3.Or(*(value1 != value2 for value1, value2 in zip(self.state1, state2))) + return self.solver.check(programs_are_different) == z3.unsat diff --git a/superoptimizer.py b/superoptimizer.py index 2854ef8..b4dd501 100644 --- a/superoptimizer.py +++ b/superoptimizer.py @@ -1,11 +1,11 @@ from itertools import product import assembler -import brute_force_equivialence_checker +from brute_force_equivialence_checker import BruteForceEquivalenceChecker from instruction_set import * def optimize(assembly, max_length, bit_width, input_size=0, *, - equivalence_checker=brute_force_equivialence_checker.are_equivalent, + equivalence_checker=BruteForceEquivalenceChecker, debug=False): """ Helper function that finds the optimal code given the assembly code. @@ -16,8 +16,8 @@ def optimize(assembly, max_length, bit_width, input_size=0, *, class Superoptimizer: - def __init__(self, are_equivalent): - self.are_equivalent = are_equivalent + def __init__(self, equivalence_checker_class): + self.equivalence_checker_class = equivalence_checker_class @staticmethod def generate_operands(operand_type, max_mem, bit_width): @@ -45,8 +45,9 @@ def generate_programs(max_length, max_mem, bit_width): # Tests all the generated programs and returns the shortest. def search(self, max_length, bit_width, program, input_size=0, debug=False): count = 0 + equivalence_checker = self.equivalence_checker_class(program, bit_width, input_size) for optimal in self.generate_programs(max_length, program.mem_size, bit_width): - if self.are_equivalent(optimal, program, bit_width, input_size): + if equivalence_checker.is_equivalent_to(optimal): return optimal # Debugging. diff --git a/tests/test_smt_baseed_equivalence_checker.py b/tests/test_smt_baseed_equivalence_checker.py index 0497774..5e3b4c2 100644 --- a/tests/test_smt_baseed_equivalence_checker.py +++ b/tests/test_smt_baseed_equivalence_checker.py @@ -1,5 +1,9 @@ import assembler -import smt_based_equivalence_checker +from smt_based_equivalence_checker import SmtBasedEquivalenceChecker + + +def are_equivalent(program1, program2, bit_width, input_size): + return SmtBasedEquivalenceChecker(program1, bit_width, input_size).is_equivalent_to(program2) def test_add_three(): @@ -11,15 +15,15 @@ def test_add_three(): load_three = assembler.parse('LOAD 3') one_inc = assembler.parse('INC 0') - assert smt_based_equivalence_checker.are_equivalent(three_incs, load_three, 8, 0) - assert not smt_based_equivalence_checker.are_equivalent(three_incs, one_inc, 8, 0) + assert are_equivalent(three_incs, load_three, 8, 0) + assert not are_equivalent(three_incs, one_inc, 8, 0) # With a single bit, += 1 and += 3 are equivalent - assert smt_based_equivalence_checker.are_equivalent(three_incs, one_inc, 1, 0) + assert are_equivalent(three_incs, one_inc, 1, 0) # If there's user input, setting to three and increasing by three are no longer equivalent - assert not smt_based_equivalence_checker.are_equivalent(three_incs, load_three, 8, 1) + assert not are_equivalent(three_incs, load_three, 8, 1) # However, += 1 and += 3 are still equivalent for single bits - assert smt_based_equivalence_checker.are_equivalent(three_incs, one_inc, 1, 1) + assert are_equivalent(three_incs, one_inc, 1, 1) def test_swap_vs_xor(): @@ -31,9 +35,9 @@ def test_swap_vs_xor(): swap_with_swap = assembler.parse('SWAP 0, 1') just_xor = assembler.parse('XOR 0, 1') - assert smt_based_equivalence_checker.are_equivalent(swap_with_xor, swap_with_swap, 8, 2) - assert not smt_based_equivalence_checker.are_equivalent(swap_with_xor, just_xor, 8, 2) - assert not smt_based_equivalence_checker.are_equivalent(swap_with_swap, just_xor, 8, 2) + assert are_equivalent(swap_with_xor, swap_with_swap, 8, 2) + assert not are_equivalent(swap_with_xor, just_xor, 8, 2) + assert not are_equivalent(swap_with_swap, just_xor, 8, 2) def test_large_program_with_lots_of_inputs(): @@ -59,4 +63,4 @@ def test_large_program_with_lots_of_inputs(): INC 4 INC 5 """) - assert smt_based_equivalence_checker.are_equivalent(program1, program2, 8, 6) + assert are_equivalent(program1, program2, 8, 6) diff --git a/tests/test_superoptimizer.py b/tests/test_superoptimizer.py index 583b705..60f0ed5 100644 --- a/tests/test_superoptimizer.py +++ b/tests/test_superoptimizer.py @@ -1,6 +1,6 @@ from superoptimizer import optimize from assembler import parse, Program -import smt_based_equivalence_checker +from smt_based_equivalence_checker import SmtBasedEquivalenceChecker MAX_LENGTH = 1000000 @@ -8,7 +8,7 @@ def optimize_with_both(*args): result1 = optimize(*args) - result2 = optimize(*args, equivalence_checker=smt_based_equivalence_checker.are_equivalent) + result2 = optimize(*args, equivalence_checker=SmtBasedEquivalenceChecker) assert result1 == result2 return result1 @@ -29,6 +29,8 @@ def test_four_threes(): XOR 2, 0 XOR 3, 0 """) + # This test case takes too long with the SMT-based equivalence checker, so we only test with the brute force one + # (which doesn't actually need much force when the input size is 0) assert optimize(assembly, MAX_LENGTH, 2, 0) == optimal @@ -45,14 +47,14 @@ def test_three_threes(): XOR 1, 0 XOR 2, 0 """) - assert optimize(assembly, MAX_LENGTH, 2, 0) == optimal + assert optimize_with_both(assembly, MAX_LENGTH, 2, 0) == optimal # Assert that the program is still found with a tight max_length - assert optimize(assembly, 3, 2, 0) == optimal + assert optimize_with_both(assembly, 3, 2, 0) == optimal # Assert that the program is not found with a max_length that's below the optimal length - assert optimize(assembly, 2, 2, 0) is None + assert optimize_with_both(assembly, 2, 2, 0) is None # Changing the input size to 1 doesn't change anything as the first input will be overridden by the load - assert optimize(assembly, MAX_LENGTH, 2, 1) == optimal + assert optimize_with_both(assembly, MAX_LENGTH, 2, 1) == optimal # For input size 2, we'll need to clear the second input using swap and another load optimal = parse(""" @@ -61,7 +63,7 @@ def test_three_threes(): LOAD 3 XOR 2, 0 """) - assert optimize(assembly, MAX_LENGTH, 2, 2) == optimal + assert optimize_with_both(assembly, MAX_LENGTH, 2, 2) == optimal def test_0_2_1(): @@ -118,7 +120,7 @@ def test_increasing_sequence(): INC 1 XOR 2, 1 """) - assert optimize(assembly, MAX_LENGTH, 2, 0) == optimal + assert optimize_with_both(assembly, MAX_LENGTH, 2, 0) == optimal def test_increasing_from_input():