-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtests.py
More file actions
49 lines (32 loc) · 1.28 KB
/
tests.py
File metadata and controls
49 lines (32 loc) · 1.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import numpy as np
import pytest
from main import matmul, matmul4r
def check_equal(np_result, matmul_result, matmul4r_result):
    """Assert that both custom products equal the NumPy reference exactly.

    ``matmul_result`` is checked first, then ``matmul4r_result``; the first
    mismatch raises ``AssertionError``.
    """
    for candidate in (matmul_result, matmul4r_result):
        assert np.array_equal(np_result, candidate)
def compute(mat_a, mat_b):
    """Multiply ``mat_a`` by ``mat_b`` three ways and return all results.

    Returns a tuple ``(reference, matmul_result, matmul4r_result)``:
    the NumPy product reduced mod 2, ``matmul(..., binary=True)``, and
    ``matmul4r(...)``, evaluated in that order.
    """
    # Reducing mod 2 gives the reference product over GF(2), which the
    # binary implementations are compared against.
    reference = np.matmul(mat_a, mat_b) % 2
    matmul_result = matmul(mat_a, mat_b, binary=True)
    matmul4r_result = matmul4r(mat_a, mat_b)
    return reference, matmul_result, matmul4r_result
def test_trivial_case():
    """All implementations must agree on the product of two 1x1 zero matrices."""
    lhs, rhs = np.array([[0]]), np.array([[0]])
    check_equal(*compute(lhs, rhs))
def test_random_case():
    """All implementations must agree on a random square binary matrix pair.

    The size is drawn from [1, 100] and entries from {0, 1}. A local
    fixed-seed generator is used instead of the global ``np.random`` state,
    so a failing case can be replayed exactly.
    """
    rng = np.random.default_rng(20240101)
    n = int(rng.integers(1, 101))  # upper bound is exclusive -> n in [1, 100]
    mat_a = rng.integers(0, 2, (n, n))
    mat_b = rng.integers(0, 2, (n, n))
    np_result, matmul_result, matmul4r_result = compute(mat_a, mat_b)
    check_equal(np_result, matmul_result, matmul4r_result)
def test_non_square():
    """``matmul4r`` must reject non-square operands, even when their product is defined.

    The operands are shaped (n, n+1) and (n+1, n), so the product itself is
    defined; the test only requires that ``matmul4r`` raises. A local
    fixed-seed generator keeps the drawn inputs reproducible.
    """
    rng = np.random.default_rng(20240102)
    n = int(rng.integers(1, 101))  # n in [1, 100]
    mat_a = rng.integers(0, 2, (n, n + 1))
    mat_b = rng.integers(0, 2, (n + 1, n))
    # NOTE(review): any Exception is accepted because the concrete error type
    # raised by matmul4r is not visible from this file -- narrow once confirmed.
    with pytest.raises(Exception):
        _ = matmul4r(mat_a, mat_b)
def test_non_binary():
    """``matmul4r`` must reject an input containing a value outside {0, 1}.

    Both operands are random square binary matrices, except that
    ``mat_a[0, 0]`` is overwritten with 2. A local fixed-seed generator keeps
    the drawn inputs reproducible.
    """
    rng = np.random.default_rng(20240103)
    n = int(rng.integers(1, 101))  # n in [1, 100]
    mat_a = rng.integers(0, 2, (n, n))
    mat_b = rng.integers(0, 2, (n, n))
    mat_a[0, 0] = 2  # the single non-binary entry
    # NOTE(review): any Exception is accepted because the concrete error type
    # raised by matmul4r is not visible from this file -- narrow once confirmed.
    with pytest.raises(Exception):
        _ = matmul4r(mat_a, mat_b)