Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added P2_ONLY_random_state_agent.pth
Binary file not shown.
Binary file added STABLE_random_state_agent.pth
Binary file not shown.
Binary file added __pycache__/hsm.cpython-39.pyc
Binary file not shown.
Binary file added __pycache__/hsmv2.cpython-39.pyc
Binary file not shown.
Binary file added __pycache__/mcts.cpython-39.pyc
Binary file not shown.
154 changes: 85 additions & 69 deletions ataxx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@
FEN_4SIDES = "x5o/7/3-3/2-1-2/3-3/7/o5x x 0 1"
FEN_EMPTY = "7/7/7/7/7/7/7 x 0 1"

# Additional Parameters
FEN_CENTERX = "x-3-o/1-3-1/2-1-2/3-3/2-1-2/1-3-1/o-3-x x 0 1"
FEN_ISLAND = "x5o/3-3/2-1-2/1-3-1/2-1-2/3-3/o5x x 0 1"

BOARD_DIM = 7

class Move:
def __init__(self, fr_x, fr_y, to_x, to_y):
self.fr_x = fr_x
Expand All @@ -27,23 +33,23 @@ def from_san(cls, san):
if san == "0000":
return cls.null()
elif len(san) == 2:
if san[0] not in "abcdefg":
raise Exception(F"ValueError {san}")
elif san[1] not in "1234567":
raise Exception(F"ValueError {san}")
# if san[0] not in "abcdefg":
# raise Exception(F"ValueError {san}")
# elif san[1] not in "1234567":
# raise Exception(F"ValueError {san}")

to_x = ord(san[0]) - ord('a')
to_y = ord(san[1]) - ord('1')
return cls(to_x, to_y, to_x, to_y)
elif len(san) == 4:
if san[0] not in "abcdefg":
raise Exception(F"ValueError {san}")
elif san[1] not in "1234567":
raise Exception(F"ValueError {san}")
elif san[2] not in "abcdefg":
raise Exception(F"ValueError {san}")
elif san[3] not in "1234567":
raise Exception(F"ValueError {san}")
# if san[0] not in "abcdefg":
# raise Exception(F"ValueError {san}")
# elif san[1] not in "1234567":
# raise Exception(F"ValueError {san}")
# elif san[2] not in "abcdefg":
# raise Exception(F"ValueError {san}")
# elif san[3] not in "1234567":
# raise Exception(F"ValueError {san}")

fr_x = ord(san[0]) - ord('a')
fr_y = ord(san[1]) - ord('1')
Expand Down Expand Up @@ -93,9 +99,10 @@ def __str__(self):
return F"{chr(ord('a')+self.fr_x)}{self.fr_y+1}{chr(ord('a')+self.to_x)}{self.to_y+1}"

class Board:
def __init__(self, fen=FEN_STARTPOS):
self._board = [[GAP for x in range(7+4)] for y in range(7+4)]
self._counts = [0, 0, 0, 49]
def __init__(self, fen=FEN_STARTPOS, board_dim=BOARD_DIM):
self.board_dim = board_dim
self._board = [[GAP for x in range(self.board_dim+4)] for y in range(self.board_dim+4)]
self._counts = [0, 0, 0, self.board_dim ** 2]
self.set_fen(fen)

def get(self, x, y):
Expand Down Expand Up @@ -140,11 +147,12 @@ def count(self):
return self.num_black(), self.num_white(), self.num_gaps(), self.num_empty()

def __str__(self):
board = " a b c d e f g\n"
board += " ╔═╦═╦═╦═╦═╦═╦═╗\n"
for y in range(6, -1, -1):
board += chr(y+49) + '║'
for x in range(0, 7):
board = " " + " ".join(chr(ord('a') + x) for x in range(self.board_dim)) + "\n"
board += " ╔" + "═╦" * (self.board_dim - 1) + "═╗\n"

for y in range(self.board_dim - 1, -1, -1):
board += chr(y + 49) + '║'
for x in range(self.board_dim):
if self.get(x, y) == BLACK:
board += 'X'
elif self.get(x, y) == WHITE:
Expand All @@ -156,26 +164,29 @@ def __str__(self):
else:
board += "?"
board += '║'
board += chr(y+49) + '\n'
board += chr(y + 49) + '\n'

if y > 0:
board += ' ╠═╬═╬═╬═╬═╬═╬═╣\n'
board += " ╚═╩═╩═╩═╩═╩═╩═╝\n"
board += " a b c d e f g\n"
board += ' ╠' + "═╬" * (self.board_dim - 1) + "═╣\n"

board += " ╚" + "═╩" * (self.board_dim - 1) + "═╝\n"
board += " " + " ".join(chr(ord('a') + x) for x in range(self.board_dim)) + "\n"

if self.turn == BLACK:
board += "Turn: X"
elif self.turn == WHITE:
board += "Turn: O"
else:
board += "Turn: ?"

return board

def get_fen(self):
"""Return a fen string for the current position"""

fen = ''
for y in range(6, -1, -1):
for y in range(self.board_dim - 1, -1, -1):
empty = 0
for x in range(7):
for x in range(self.board_dim):
if self.get(x, y) != EMPTY and empty > 0:
fen += str(empty)
empty = 0
Expand Down Expand Up @@ -218,28 +229,34 @@ def set_fen(self, fen):
fen = FEN_STARTPOS
elif fen == "empty":
fen = FEN_EMPTY
elif fen == "island":
fen = FEN_ISLAND
elif fen == "foursides":
fen = FEN_4SIDES
elif fen == "centerx":
fen = FEN_CENTERX

parts = fen.split()

if len(parts) < 1 or len(parts) > 4:
return False
if parts[0].count('/') != 6:
return False
if len(parts[0]) < len("7/7/7/7/7/7/7"):
return False
if len(parts[0]) > len("xxxxxxx/xxxxxxx/xxxxxxx/xxxxxxx/xxxxxxx/xxxxxxx/xxxxxxx"):
return False
# if len(parts) < 1 or len(parts) > 4:
# return False
# if parts[0].count('/') != 6:
# return False
# if len(parts[0]) < len("7/7/7/7/7/7/7"):
# return False
# if len(parts[0]) > len("xxxxxxx/xxxxxxx/xxxxxxx/xxxxxxx/xxxxxxx/xxxxxxx/xxxxxxx"):
# return False

# Clear board
for x in range(7):
for y in range(7):
for x in range(self.board_dim):
for y in range(self.board_dim):
self.set(x, y, EMPTY)
self.turn = BLACK
self.halfmove_clock = 0
self.fullmove_clock = 1
self.history = []
self._halfmove_stack = []
self._counts = [0, 0, 0, 49]
self._counts = [0, 0, 0, self.board_dim ** 2]

# Add side to move
if len(parts) < 2:
Expand All @@ -256,7 +273,7 @@ def set_fen(self, fen):
# Set board
sq = 0
for c in parts[0]:
x, y = sq%7, 7 - sq//7 - 1
x, y = sq%self.board_dim, self.board_dim - sq//self.board_dim - 1

if c in "1234567":
sq = sq + int(c)
Expand All @@ -279,9 +296,8 @@ def set_fen(self, fen):
return False

# We need to have parsed the right number of squares
if sq != 7 * 7:
return False

# if sq != self.board_dim ** 2:
# return False
# Set turn
if parts[1] in "bBxX":
self.turn = BLACK
Expand All @@ -305,7 +321,7 @@ def set_fen(self, fen):
# Save fen
self._start_fen = ' '.join(parts)

self.hash = calculate_hash(self)
self.hash = calculate_hash(self, self.board_dim)

return True

Expand All @@ -331,10 +347,10 @@ def makemove(self, move):
return

self.set(move.to_x, move.to_y, self.turn)
self.hash ^= get_sq_hash(move.to_x, move.to_y, self.turn)
self.hash ^= get_sq_hash(move.to_x, move.to_y, self.turn, self.board_dim)

if move.is_double():
self.hash ^= get_sq_hash(move.fr_x, move.fr_y, self.turn)
self.hash ^= get_sq_hash(move.fr_x, move.fr_y, self.turn, self.board_dim)
self.set(move.fr_x, move.fr_y, EMPTY)

for idx, (dx, dy) in enumerate(SINGLES):
Expand All @@ -344,8 +360,8 @@ def makemove(self, move):
self.set(x, y, self.turn)
self._counts[self.turn] += 1
self._counts[opponent] -= 1
self.hash ^= get_sq_hash(x, y, opponent)
self.hash ^= get_sq_hash(x, y, self.turn)
self.hash ^= get_sq_hash(x, y, opponent, self.board_dim)
self.hash ^= get_sq_hash(x, y, self.turn, self.board_dim)
else:
move.flipped[idx] = False

Expand Down Expand Up @@ -379,13 +395,13 @@ def undo(self):
return

# Remove the piece we placed
self.hash ^= get_sq_hash(move.to_x, move.to_y, self.get(move.to_x, move.to_y))
self.hash ^= get_sq_hash(move.to_x, move.to_y, self.get(move.to_x, move.to_y), self.board_dim)
self._counts[us] -= 1
self.set(move.to_x, move.to_y, EMPTY)

# Restore the piece we removed
if move.is_double():
self.hash ^= get_sq_hash(move.fr_x, move.fr_y, us)
self.hash ^= get_sq_hash(move.fr_x, move.fr_y, us, self.board_dim)
self._counts[us] += 1
self.set(move.fr_x, move.fr_y, us)

Expand All @@ -396,8 +412,8 @@ def undo(self):
self.set(move.to_x + dx, move.to_y + dy, them)
self._counts[us] -= 1
self._counts[them] += 1
self.hash ^= get_sq_hash(move.to_x + dx, move.to_y + dy, us)
self.hash ^= get_sq_hash(move.to_x + dx, move.to_y + dy, them)
self.hash ^= get_sq_hash(move.to_x + dx, move.to_y + dy, us, self.board_dim)
self.hash ^= get_sq_hash(move.to_x + dx, move.to_y + dy, them, self.board_dim)

def predict_hash(self, move):
"""Calculate the hash after a move is played without applying the move"""
Expand All @@ -415,16 +431,16 @@ def predict_hash(self, move):
if move == Move.null():
return hash

hash ^= get_sq_hash(move.to_x, move.to_y, self.turn)
hash ^= get_sq_hash(move.to_x, move.to_y, self.turn, self.board_dim)

if move.is_double():
hash ^= get_sq_hash(move.fr_x, move.fr_y, self.turn)
hash ^= get_sq_hash(move.fr_x, move.fr_y, self.turn, self.board_dim)

for dx, dy in SINGLES:
x, y = move.to_x + dx, move.to_y + dy
if self.get(x, y) == opponent:
hash ^= get_sq_hash(x, y, opponent)
hash ^= get_sq_hash(x, y, self.turn)
hash ^= get_sq_hash(x, y, opponent, self.board_dim)
hash ^= get_sq_hash(x, y, self.turn, self.board_dim)

return hash

Expand Down Expand Up @@ -453,8 +469,8 @@ def legal_moves(self):
return []

movelist = []
for x in range(7):
for y in range(7):
for x in range(self.board_dim):
for y in range(self.board_dim):
# Singles
if self.get(x, y) == EMPTY:
for dx, dy in SINGLES:
Expand All @@ -471,8 +487,8 @@ def legal_moves(self):
if movelist == []:
# If the opponent can move, we have to pass
opponent = WHITE if self.turn == BLACK else BLACK
for x in range(7):
for y in range(7):
for x in range(self.board_dim):
for y in range(self.board_dim):
if self.get(x, y) == opponent:
# Singles
for dx, dy in SINGLES:
Expand All @@ -495,8 +511,8 @@ def must_pass(self):
bool:Whether the side to move must pass
"""

for x in range(7):
for y in range(7):
for x in range(self.board_dim):
for y in range(self.board_dim):
if self.get(x, y) == self.turn:
# Singles
for dx, dy in SINGLES:
Expand Down Expand Up @@ -526,10 +542,10 @@ def is_legal(self, move):
if self.must_pass():
return move == Move.null()

if move.fr_x < 0 or move.fr_y > 6:
return False
if move.to_x < 0 or move.to_y > 6:
return False
# if move.fr_x < 0 or move.fr_y > 6:
# return False
# if move.to_x < 0 or move.to_y > 6:
# return False

if move.is_single():
# To square must be empty
Expand Down Expand Up @@ -598,8 +614,8 @@ def gameover(self):
return True

# No moves left
for x in range(7):
for y in range(7):
for x in range(self.board_dim):
for y in range(self.board_dim):
if self.get(x, y) in [BLACK, WHITE]:
# Singles
for dx, dy in SINGLES:
Expand All @@ -620,8 +636,8 @@ def result(self):
has_moves: bool = False

# No moves left
for x in range(7):
for y in range(7):
for x in range(self.board_dim):
for y in range(self.board_dim):
piece = self.get(x, y)

if piece == BLACK:
Expand Down
Loading