Skip to content

Commit 378fa56

Browse files
committed
Moved sunfish_nnue.py to make it testable by quick_tests
1 parent da975d0 commit 378fa56

File tree

3 files changed

+63
-55
lines changed

3 files changed

+63
-55
lines changed

nnue/sunfish_nnue_color.py renamed to sunfish_nnue.py

Lines changed: 30 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from functools import partial
88
print = partial(print, flush=True)
99

10+
version = 'sunfish nnue'
11+
1012
###############################################################################
1113
# A small neural network to evaluate positions
1214
###############################################################################
@@ -105,11 +107,13 @@ def features(board):
105107
}
106108

107109
# Constants for tuning search
108-
#EVAL_ROUGHNESS = 13
109-
#QS_LIMIT = 200
110-
#QS_CAPTURE, QS_SINGLE, QS_DOUBLE = range(3)
111-
#QS_TYPE = QS_CAPTURE
112-
#debug = False
110+
EVAL_ROUGHNESS = 13
111+
112+
# minifier-hide start
113+
opt_ranges = dict(
114+
EVAL_ROUGHNESS = (0, 50),
115+
)
116+
# minifier-hide end
113117

114118

115119
###############################################################################
@@ -258,11 +262,7 @@ def is_capture(self, move):
258262
# to last forever (until python stackoverflows.) Thus we need to either
259263
# dampen the eval function, or like here, reduce QS search to captures
260264
# only. Well, captures plus promotions.
261-
return (
262-
self.board[move.j] != "."
263-
or abs(move.j - self.kp) < 2
264-
or self.board[move.i] == "P" and (A8 <= move.j <= H8 or move.j == self.ep)
265-
)
265+
return self.board[move.j] != "." or abs(move.j - self.kp) < 2 or move.prom
266266

267267
def compute_value(self):
268268
#relu6 = lambda x: np.minimum(np.maximum(x, 0), 6)
@@ -288,7 +288,6 @@ def hash(self):
288288
# return (self.wf + self.bf).sum()
289289
# return self._replace(wf=0, bf=0)
290290

291-
292291
###############################################################################
293292
# Search logic
294293
###############################################################################
@@ -359,11 +358,7 @@ def bound(self, pos, gamma, depth, root=True):
359358
def moves():
360359
# First try not moving at all. We only do this if there is at least one major
361360
# piece left on the board, since otherwise zugzwangs are too dangerous.
362-
# It doesn't make sense to use this function for depth 2, since it will take us
363-
# to depth max(0, d-2)=0, meaning reducing by two. So it's not actually the
364-
# opponents turn. This seems like it should be a major bug?
365-
#if (depth >= 3 or depth == 1) and not root and any(c in pos.board for c in "NBRQ"):
366-
if depth >= 3 and not root:
361+
if depth > 2 and not root and any(c in pos.board for c in "NBRQ"):
367362
yield None, -self.bound(pos.rotate(nullmove=True), 1-gamma, depth-3, False)
368363
# For QSearch we have a different kind of null-move, namely we can just stop
369364
# and not capture anything else.
@@ -387,9 +382,8 @@ def mvv_lva(move):
387382
return score
388383

389384
killer = self.tp_move.get(pos.hash())
390-
if killer:
391-
if depth > 0 or pos.is_capture(killer):
392-
yield killer, -self.bound(pos.move(killer), 1-gamma, depth-1, False)
385+
if killer and (depth > 0 or pos.is_capture(killer)):
386+
yield killer, -self.bound(pos.move(killer), 1-gamma, depth-1, False)
393387

394388
# Then all the other moves
395389
# moves = [(move, pos.move(move)) for move in pos.gen_moves()]
@@ -430,24 +424,12 @@ def mvv_lva(move):
430424
self.tp_move[pos.hash()] = move
431425
break
432426

433-
# Stalemate checking is a bit tricky: Say we failed low, because
434-
# we can't (legally) move and so the (real) score is -infty.
435-
# At the next depth we are allowed to just return r, -infty <= r < gamma,
436-
# which is normally fine.
437-
# However, what if gamma = -10 and we don't have any legal moves?
438-
# Then the score is actaully a draw and we should fail high!
439-
# Thus, if best < gamma and best < 0 we need to double check what we are doing.
440-
# This doesn't prevent sunfish from making a move that results in stalemate,
441-
# but only if depth == 1, so that's probably fair enough.
442-
# (Btw, at depth 1 we can also mate without realizing.)
443-
if best < gamma and best < 0 and depth > 0:
444-
# A position is dead if the curent player has a move that captures the king
445-
is_dead = lambda pos: any(
446-
pos.move(m).score <= -MATE_LOWER for m in pos.gen_moves()
447-
)
448-
if all(is_dead(pos.move(m)) for m in pos.gen_moves()):
449-
in_check = is_dead(pos.rotate(nullmove=True))
450-
best = -MATE_UPPER if in_check else 0
427+
# Stalemate checking
428+
if depth > 0 and best == -MATE_UPPER:
429+
flipped = pos.rotate(nullmove=True)
430+
# Hopefully this is already in the TT because of null-move
431+
in_check = self.bound(flipped, MATE_UPPER, 0) == MATE_UPPER
432+
best = -MATE_LOWER if in_check else 0
451433

452434
# Table part 2
453435
self.tp_score[pos.hash(), depth, root] = (
@@ -477,7 +459,7 @@ def search(self, history):
477459
# 'while lower != upper' would work, but play tests show a margin of 20 plays
478460
# better.
479461
lower, upper = -MATE_UPPER, MATE_UPPER
480-
while lower < upper - 1:
462+
while lower < upper - EVAL_ROUGHNESS:
481463
score = self.bound(pos, gamma, depth)
482464
if score >= gamma:
483465
lower = score
@@ -504,11 +486,20 @@ def render(i):
504486

505487
wf, bf = features(initial)
506488
hist = [Position(initial, 0, wf, bf, (True, True), (True, True), 0, 0)]
489+
490+
491+
# minifier-hide start
492+
import sys, tools.uci
493+
tools.uci.run(sys.modules[__name__], hist[-1])
494+
sys.exit()
495+
# minifier-hide end
496+
497+
507498
searcher = Searcher()
508499
while True:
509500
args = input().split()
510501
if args[0] == "uci":
511-
print(f"id name sunfish nnue")
502+
print(f"id name {version}")
512503
print("uciok")
513504

514505
elif args[0] == "isready":

tools/tester.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ async def run(cls, engine, args):
266266
continue
267267
score = info["score"]
268268
if score.is_mate() or score.relative.cp > 10000:
269-
if info["pv"]:
269+
if "pv" in info and info["pv"]:
270270
b = board.copy()
271271
for move in info["pv"]:
272272
b.push(move)

tools/uci.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,12 @@ def go_loop(searcher, hist, stop_event, max_movetime=0, max_depth=0, debug=False
5353
fields["score cp"] = f"{score} upperbound"
5454
print("info", " ".join(f"{k} {v}" for k, v in fields.items()))
5555

56-
if best_move and elapsed > max_movetime * 2 / 3:
57-
break
58-
if stop_event.is_set():
59-
break
56+
# We may not have a move yet at depth = 1
57+
if depth > 1:
58+
if elapsed > max_movetime * 2 / 3:
59+
break
60+
if stop_event.is_set():
61+
break
6062

6163
my_pv = pv(searcher, hist[-1], include_scores=False)
6264
print("bestmove", my_pv[0] if my_pv else "(none)")
@@ -134,12 +136,11 @@ def _perft_count(pos, depth):
134136
return res
135137

136138

137-
def run(sunfish_module):
139+
def run(sunfish_module, startpos):
138140
global sunfish
139141
sunfish = sunfish_module
140142

141143
debug = False
142-
startpos = sunfish.Position(sunfish.initial, 0, (True, True), (True, True), 0, 0)
143144
hist = [startpos]
144145
searcher = sunfish.Searcher()
145146

@@ -208,7 +209,11 @@ def run(sunfish_module):
208209
wc = ("Q" in castling, "K" in castling)
209210
bc = ("k" in castling, "q" in castling)
210211
ep = sunfish.parse(enpas) if enpas != "-" else 0
211-
pos = sunfish.Position(board, 0, wc, bc, ep, 0)
212+
if hasattr(sunfish, 'features'):
213+
wf, bf = sunfish.features(board)
214+
pos = sunfish.Position(board, 0, wf, bf, wc, bc, ep, 0)
215+
else:
216+
pos = sunfish.Position(board, 0, wc, bc, ep, 0)
212217
hist = [pos] if color == "w" else [pos, pos.rotate()]
213218
if len(args) > 8 and args[8] == "moves":
214219
for move in args[9:]:
@@ -288,8 +293,9 @@ def get_color(pos):
288293
def can_kill_king(pos):
289294
# If we just checked for opponent moves capturing the king, we would miss
290295
# captures in case of illegal castling.
291-
MATE_LOWER = 60_000 - 10 * 929
292-
return any(pos.value(m) >= MATE_LOWER for m in pos.gen_moves())
296+
#MATE_LOWER = 60_000 - 10 * 929
297+
#return any(pos.value(m) >= MATE_LOWER for m in pos.gen_moves())
298+
return any(pos.board[m.j] == 'k' or abs(m.j - pos.kp) < 2 for m in pos.gen_moves())
293299

294300

295301
def pv(searcher, pos, include_scores=True, include_loop=False):
@@ -300,7 +306,9 @@ def pv(searcher, pos, include_scores=True, include_loop=False):
300306
if include_scores:
301307
res.append(str(pos.score))
302308
while True:
303-
if hasattr(searcher, "tp_move"):
309+
if hasattr(pos, "wf"):
310+
move = searcher.tp_move.get(pos.hash())
311+
elif hasattr(searcher, "tp_move"):
304312
move = searcher.tp_move.get(pos)
305313
elif hasattr(searcher, "tt_new"):
306314
move = searcher.tt_new[0][pos, True].move
@@ -309,11 +317,20 @@ def pv(searcher, pos, include_scores=True, include_loop=False):
309317
break
310318
res.append(render_move(move, get_color(pos) == WHITE))
311319
pos, color = pos.move(move), 1 - color
312-
if pos in seen_pos:
313-
if include_loop:
314-
res.append("loop")
315-
break
316-
seen_pos.add(pos)
320+
321+
if hasattr(pos, "wf"):
322+
if pos.hash() in seen_pos:
323+
if include_loop:
324+
res.append("loop")
325+
break
326+
seen_pos.add(pos.hash())
327+
else:
328+
if pos in seen_pos:
329+
if include_loop:
330+
res.append("loop")
331+
break
332+
seen_pos.add(pos)
333+
317334
if include_scores:
318335
res.append(str(pos.score if color == origc else -pos.score))
319336
return res

0 commit comments

Comments
 (0)