|
|
import re |
|
|
import chess |
|
|
import gradio as gr |
|
|
import spaces |
|
|
from jinja2 import Template |
|
|
from gradio_chessboard import Chessboard |
|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face Hub repository that provides both the tokenizer and the weights.
MODEL_ID = "nuriyev/chess-reasoner-grpo"

# The tokenizer and model are created once, at import time, and are then
# used as module-level globals by get_model_move().
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Keyword arguments for the model load, kept together so they are easy to audit.
# The weights are pinned to an exact commit of the repository.
# NOTE(review): the tokenizer above is loaded without this revision pin —
# confirm whether it should be pinned to the same commit.
_MODEL_LOAD_OPTIONS = dict(
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
    revision="b7e531a630fd35065f9c8287f4bd21dff42f871b",
)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_MODEL_LOAD_OPTIONS)

# The app only generates text; put the module in evaluation mode.
model.eval()
print("Model loaded!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Jinja2 source of the single user message sent to the model.
#
# Variables expected by render():
#   FEN             - full FEN string; only its first field (piece placement)
#                     is used to draw the board.
#   side_to_move    - text inserted after "It is your turn".
#   legal_moves_uci - text inserted after "Legal Moves:".
#
# The board section expands the FEN placement field into a flat string with
# one character per square ('.' for an empty square), then prints one line
# per file (a-h) listing ranks 1-8 as "<file><rank>:<piece>".
_USER_PROMPT_SOURCE = """You are an expert chess player.

Given a current game state, you must select the best legal next move. Think in 1-2 sentences, then output your chosen move.

## State

Board:
{% set fen_board = FEN.split()[0] %}
{%- set ns = namespace(board='') -%}
{%- for char in fen_board -%}
{%- if char in '12345678' -%}
{%- set ns.board = ns.board ~ '.' * (char|int) -%}
{%- elif char != '/' -%}
{%- set ns.board = ns.board ~ char -%}
{%- endif -%}
{%- endfor -%}
{#- Output coordinate grid by file -#}
{%- set files = 'abcdefgh' -%}
{% for f in range(8) %}
{%- for r in range(1, 9) -%}
{{ files[f] }}{{ r }}:{{ ns.board[(8-r)*8 + f] }}{% if r < 8 %} {% endif -%}
{%- endfor %}
{% endfor %}
Turn: It is your turn ({{ side_to_move }})
Legal Moves: {{ legal_moves_uci }}

## Output format

<reason>...brief thinking (1-2 first-person very short concise sentences, identifying threat or opportunity, then deciding on the best move to play next)...</reason>
<uci_move>...your_move...</uci_move>

NOTE: capital letters are white, lowercase are black."""

# Compiled once at import time; get_model_move() calls USER_PROMPT.render(...).
USER_PROMPT = Template(_USER_PROMPT_SOURCE)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _parse_model_output(generated: str) -> "tuple[str | None, str, str]":
    """Split decoded model text into (uci_move, reasoning, raw_output)."""
    # re.DOTALL on both patterns: the model may put the tag content on its
    # own line, and '.' does not match a newline without it.
    reason_match = re.search(r'<reason>(.*?)</reason>', generated, re.DOTALL)
    move_match = re.search(r'<uci_move>(.*?)</uci_move>', generated, re.DOTALL)

    reasoning = (
        reason_match.group(1).strip() if reason_match
        else "No reasoning provided"
    )
    # An empty tag pair is treated the same as a missing one.
    uci_move = (move_match.group(1).strip() or None) if move_match else None

    # Special tokens are kept when decoding, so cut the text shown to the
    # user at the first end-of-turn marker.
    raw_output = generated.split('<|im_end|>')[0].strip()
    return uci_move, reasoning, raw_output


@spaces.GPU
def get_model_move(fen: str) -> "tuple[str | None, str, str]":
    """Ask the model for a move in the position described by *fen*.

    The position is rendered into ``USER_PROMPT`` together with the side to
    move and the comma-separated legal moves in UCI notation, wrapped with
    the tokenizer's chat template, and sampled from the model.

    Args:
        fen: FEN string of the position to move in.

    Returns:
        A ``(uci_move, reasoning, raw_output)`` tuple. ``uci_move`` is the
        stripped content of the ``<uci_move>`` tag, or ``None`` when the tag
        is missing or empty; it is NOT checked for legality here.
        ``reasoning`` is the content of the ``<reason>`` tag, or
        ``"No reasoning provided"``. ``raw_output`` is the generated text up
        to the first ``<|im_end|>`` marker.
    """
    board = chess.Board(fen)
    turn = "white" if board.turn else "black"

    prompt = USER_PROMPT.render(
        FEN=fen,
        side_to_move=turn,
        legal_moves_uci=", ".join(move.uci() for move in board.legal_moves),
    )
    messages = [{"role": "user", "content": prompt}]

    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    # Some tokenizers define no pad token; fall back to EOS so generate()
    # always receives a concrete id.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            temperature=0.7,
            top_p=0.8,
            top_k=20,
            do_sample=True,
            pad_token_id=pad_token_id,
        )

    # Drop the prompt tokens; decode only what the model produced.
    prompt_length = inputs['input_ids'].shape[1]
    generated = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=False)

    return _parse_model_output(generated)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _turn_status(board: chess.Board) -> str:
    """Build the markdown status line for the side to move on *board*."""
    turn_str = "White (You)" if board.turn else "Black (AI)"
    status = f"**Turn:** {turn_str}"
    if board.is_check():
        status += " ⚠️ CHECK!"
    return status


def play_move(fen: str) -> tuple[str, str, str, str]:
    """Process the position after a board change and, if due, get the AI reply.

    The model is only consulted when Black is to move. Its move is applied
    only if it parses as UCI and is legal in the current position; otherwise
    the board is left unchanged and a warning is prepended to the reasoning.

    Args:
        fen: FEN string of the current board position.

    Returns:
        ``(new_fen, status, reasoning, raw_output)``. When the game is
        already over on entry, the last two items are empty strings. When it
        is White's turn, the last two items are ``gr.update()`` objects
        rather than strings.
    """
    board = chess.Board(fen)

    # The incoming position may already have ended the game.
    if board.is_game_over():
        return fen, f"🏁 {get_game_result(board)}", "", ""

    # White to move: nothing for the model to do. Bare gr.update() values are
    # returned for the reasoning/raw-output boxes instead of new text.
    if board.turn:
        return fen, _turn_status(board), gr.update(), gr.update()

    uci_move, reasoning, raw_output = get_model_move(fen)

    if not uci_move:
        # No <uci_move> tag (or an empty one): say so instead of silently
        # leaving the position untouched.
        reasoning = "⚠️ Model did not output a move. " + reasoning
    else:
        try:
            move = chess.Move.from_uci(uci_move)
        except ValueError:
            # Malformed UCI text. Only parse errors are handled here so that
            # unrelated failures are not hidden.
            reasoning = f"⚠️ Model output invalid move: {uci_move}. " + reasoning
        else:
            if move in board.legal_moves:
                board.push(move)
            else:
                reasoning = f"⚠️ Model suggested illegal move: {uci_move}. " + reasoning

    # The AI's move may itself have ended the game.
    if board.is_game_over():
        return board.fen(), f"🏁 {get_game_result(board)}", reasoning, raw_output

    return board.fen(), _turn_status(board), reasoning, raw_output
|
|
|
|
|
|
|
|
def get_game_result(board: chess.Board) -> str:
    """Describe how the game on *board* ended.

    Checkmate is checked first; the draw conditions are then tried in a
    fixed order and the first one that holds decides the message. If none
    applies, the generic ``"Game Over"`` is returned.
    """
    if board.is_checkmate():
        # The side to move is the one that has been mated.
        return f"Checkmate! {'Black' if board.turn else 'White'} wins!"

    # (predicate, message) pairs, evaluated lazily and in this order.
    draw_checks = (
        (board.is_stalemate, "Stalemate - Draw"),
        (board.is_insufficient_material, "Draw - Insufficient material"),
        (board.is_fifty_moves, "Draw - 50 move rule"),
        (board.is_repetition, "Draw - Repetition"),
    )
    for applies, message in draw_checks:
        if applies():
            return message

    return "Game Over"
|
|
|
|
|
|
|
|
def reset_game() -> tuple[str, str, str, str]:
    """Return the values for a fresh game.

    The tuple is ``(fen, status, reasoning, raw_output)``: the standard
    starting position, a status line saying it is White's turn, and two
    empty strings.
    """
    start_fen = chess.STARTING_FEN
    status_text = "**Turn:** White (You)"
    cleared = ""
    return start_fen, status_text, cleared, cleared
|
|
|
|
|
|
|
|
def ai_plays_first() -> tuple[str, str, str, str]:
    """Run ``play_move`` on the standard starting position.

    Returns whatever ``play_move`` returns for ``chess.STARTING_FEN``.

    NOTE(review): ``play_move`` only asks the model for a move when Black is
    to move, and White moves first from the starting position, so this looks
    like it resets the board without the AI actually moving — confirm the
    intended behaviour of the "AI First" button.
    """
    opening_fen = chess.STARTING_FEN
    return play_move(opening_fen)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# UI definition: board on the left, status / controls / model output on the
# right, followed by the event wiring.
# NOTE(review): the nesting of the controls inside the right-hand column is
# assumed — confirm against the intended layout.
# ---------------------------------------------------------------------------
with gr.Blocks(title="♟️ Chess Reasoner") as demo:
    gr.Markdown("""
    # ♟️ Chess Reasoner
    Play chess against a reasoning AI! You play as **White** - click on pieces to move them.
    """)

    with gr.Row():
        # Left: the interactive board, starting from the standard position.
        with gr.Column(scale=2):
            board = Chessboard(
                value=chess.STARTING_FEN,
                label="",
                game_mode=True,
            )

        # Right: status line, game controls and the model's output.
        with gr.Column(scale=1):
            status = gr.Markdown(value="**Turn:** White (You)")

            with gr.Row():
                reset_btn = gr.Button("🔄 New Game", variant="primary")
                ai_first_btn = gr.Button("🤖 AI First")

            with gr.Accordion("🧠 AI Reasoning", open=True):
                reasoning = gr.Textbox(
                    label="Thinking",
                    lines=3,
                    interactive=False,
                )

            with gr.Accordion("📝 Raw Output", open=False):
                raw_output = gr.Textbox(
                    label="Model Output",
                    lines=5,
                    interactive=False,
                )

    gr.Markdown("""
    ---
    **Model:** [nuriyev/chess-reasoner-grpo](https://huggingface.co/nuriyev/chess-reasoner-grpo) • Fine-tuned from Qwen3-4B-Instruct
    """)

    # Every handler returns (fen, status, reasoning, raw_output), so all three
    # listeners write to the same four components.
    game_outputs = [board, status, reasoning, raw_output]

    # Any change of the board value is passed to play_move as a FEN string.
    board.change(fn=play_move, inputs=[board], outputs=game_outputs)
    reset_btn.click(fn=reset_game, outputs=game_outputs)
    ai_first_btn.click(fn=ai_plays_first, outputs=game_outputs)

if __name__ == "__main__":
    demo.launch(ssr_mode=False)
|
|
|