|
|
import gymnasium as gym
|
|
|
from gymnasium import spaces
|
|
|
import random
|
|
|
import pygame
|
|
|
import numpy as np
|
|
|
import collections
|
|
|
from collections import deque
|
|
|
from Environment_Constants import (
|
|
|
GRID_SIZE, CELL_SIZE, SCREEN_WIDTH, SCREEN_HEIGHT,
|
|
|
WHITE, BLACK, GREEN, RED, BLUE,
|
|
|
UP, DOWN, LEFT, RIGHT,
|
|
|
REWARD_FOOD, REWARD_COLLISION, REWARD_STEP,
|
|
|
FPS, OBSERVATION_SPACE_SIZE
|
|
|
)
|
|
|
|
|
|
class SnakeGameEnv(gym.Env):
    """Grid-based Snake environment implementing the Gymnasium ``Env`` API.

    Actions are relative to the snake's current heading:
    ``0`` = keep going straight, ``1`` = turn right, ``2`` = turn left.

    Observations are float32 vectors (layout documented in
    ``_get_observation``). Every ``info`` dict carries an ``"action_mask"``
    flagging which of the three actions do not immediately collide.
    """

    metadata = {'render_modes': ['human', 'rgb_array'], 'render_fps': FPS}

    def __init__(self, render_mode=None):
        """Create the environment.

        Args:
            render_mode: ``None``; ``"human"`` (pygame window redrawn on every
                ``reset``/``step``); or ``"rgb_array"`` (frames returned by
                ``render()``).

        Raises:
            ValueError: if ``render_mode`` is not ``None`` and not listed in
                ``metadata["render_modes"]``.
        """
        super().__init__()
        if render_mode is not None and render_mode not in self.metadata["render_modes"]:
            raise ValueError(
                f"Unsupported render_mode={render_mode!r}; "
                f"expected None or one of {self.metadata['render_modes']}."
            )

        self.grid_size = GRID_SIZE
        self.cell_size = CELL_SIZE
        self.screen_width = SCREEN_WIDTH
        self.screen_height = SCREEN_HEIGHT
        # After this many consecutive steps without eating, the episode ends
        # with the collision penalty, so an agent cannot loop forever.
        self.max_steps_without_food = self.grid_size * self.grid_size * 1.5

        self.action_space = spaces.Discrete(3)
        self.observation_space = spaces.Box(low=0, high=1,
                                            shape=(OBSERVATION_SPACE_SIZE,),
                                            dtype=np.float32)

        self.render_mode = render_mode
        self.window = None
        self.clock = None
        self.font = None  # created lazily, once pygame has been initialised

        self._init_game_state()

    # ------------------------------------------------------------------
    # Game state helpers
    # ------------------------------------------------------------------

    def _init_game_state(self):
        """Put the snake, score, counters and food back to a fresh game."""
        center = self.grid_size // 2
        self.head = (center, center)
        # Head first; the two body segments trail behind it at increasing y.
        self.snake = deque([self.head, (center, center + 1), (center, center + 2)])
        self.direction = UP
        self.score = 0
        self.game_over = False
        self.steps_since_food = 0
        self.length = len(self.snake)
        self.food = self._place_food()

    def _place_food(self):
        """Return a uniformly random free cell, or ``None`` if the board is full.

        Uses ``self.np_random`` so that ``reset(seed=...)`` makes food
        placement reproducible.
        """
        occupied = set(self.snake)
        free_cells = [
            (x, y)
            for x in range(self.grid_size)
            for y in range(self.grid_size)
            if (x, y) not in occupied
        ]
        if not free_cells:
            # The snake covers every cell. Returning None avoids the infinite
            # rejection-sampling loop this would otherwise cause.
            return None
        return free_cells[int(self.np_random.integers(len(free_cells)))]

    def _relative_directions(self):
        """Return the absolute ``(straight, right, left)`` directions for the current heading.

        The tuple index equals the action number, so ``step``, the action mask
        and the observation all share one mapping.

        Raises:
            ValueError: if ``self.direction`` is not one of UP/DOWN/LEFT/RIGHT.
        """
        heading = self.direction
        if heading == UP:
            return UP, RIGHT, LEFT
        if heading == DOWN:
            return DOWN, LEFT, RIGHT
        if heading == LEFT:
            return LEFT, UP, DOWN
        if heading == RIGHT:
            return RIGHT, DOWN, UP
        raise ValueError(f"Unknown direction {heading!r}")

    def _is_collision(self, pos):
        """Return True if moving the head to ``pos`` would end the game.

        ``pos`` is deadly when it is off the board or on any snake segment,
        including the current tail. The original code exempted the tail when
        it held food, but food is never placed on the snake, so that exemption
        could never apply. The effective rule is unchanged.
        """
        px, py = pos
        if not (0 <= px < self.grid_size and 0 <= py < self.grid_size):
            return True
        return pos in self.snake

    def _is_position_safe_for_observation(self, pos):
        """Return True if ``pos`` is on the board and not on a body segment.

        The head's own cell counts as safe. For every other cell this is the
        negation of ``_is_collision``.
        """
        return pos == self.head or not self._is_collision(pos)

    # ------------------------------------------------------------------
    # Observation / info
    # ------------------------------------------------------------------

    def _get_observation(self):
        """Build the float32 feature vector for the current state.

        Layout (assumes OBSERVATION_SPACE_SIZE >= 11; TODO confirm in
        Environment_Constants):
            [0:3]  danger one cell straight / right / left of the head
            [3:7]  food at smaller y / larger y / smaller x / larger x than the head
            [7:11] one-hot heading: UP, DOWN, LEFT, RIGHT
        The food features stay 0 when no food exists (board full).
        """
        obs = np.zeros(OBSERVATION_SPACE_SIZE, dtype=np.float32)
        hx, hy = self.head

        for idx, (dx, dy) in enumerate(self._relative_directions()):
            obs[idx] = 0.0 if self._is_position_safe_for_observation((hx + dx, hy + dy)) else 1.0

        if self.food is not None:
            fx, fy = self.food
            obs[3] = float(fy < hy)
            obs[4] = float(fy > hy)
            obs[5] = float(fx < hx)
            obs[6] = float(fx > hx)

        if self.direction == UP:
            obs[7] = 1
        elif self.direction == DOWN:
            obs[8] = 1
        elif self.direction == LEFT:
            obs[9] = 1
        elif self.direction == RIGHT:
            obs[10] = 1

        return obs

    def _get_action_mask(self):
        """Return a bool array of length 3; True marks an action that does not collide.

        If every action collides (the snake is boxed in), one action is
        enabled at random via ``self.np_random``. Masked-action learners need
        at least one valid choice, and any choice ends the episode.
        """
        hx, hy = self.head
        mask = np.array(
            [not self._is_collision((hx + dx, hy + dy))
             for dx, dy in self._relative_directions()],
            dtype=bool,
        )
        if not mask.any():
            mask[int(self.np_random.integers(len(mask)))] = True
        return mask

    def _get_info(self):
        """Returns environment information, including the action mask."""
        return {
            "score": self.score,
            "snake_length": len(self.snake),
            "action_mask": self._get_action_mask(),
        }

    # ------------------------------------------------------------------
    # Gymnasium API
    # ------------------------------------------------------------------

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(observation, info)``.

        ``seed`` reseeds ``self.np_random``, which drives food placement and
        the trapped-snake action fallback. ``options`` is accepted for API
        compatibility and ignored.
        """
        super().reset(seed=seed)
        self._init_game_state()
        observation = self._get_observation()
        info = self._get_info()

        if self.render_mode == 'human':
            self._render_frame()

        return observation, info

    def step(self, action):
        """Advance the game by one move.

        Args:
            action: 0 (straight), 1 (turn right) or 2 (turn left).

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
            ``terminated`` is True on a collision, on starvation (too many
            steps without food, penalised like a collision), or when the snake
            fills the whole board. ``truncated`` is always False; any time
            limit is left to wrappers.

        Raises:
            ValueError: if ``action`` is not 0, 1 or 2.
        """
        if action not in (0, 1, 2):
            raise ValueError(f"Received invalid action={action} which is not part of the action space.")

        self.direction = self._relative_directions()[int(action)]

        hx, hy = self.head
        dx, dy = self.direction
        new_head = (hx + dx, hy + dy)

        reward = REWARD_STEP
        terminated = False
        truncated = False

        if self._is_collision(new_head):
            terminated = True
            reward = REWARD_COLLISION
        else:
            self.snake.appendleft(new_head)
            self.head = new_head

            if new_head == self.food:
                self.score += 1
                self.length += 1
                reward = REWARD_FOOD
                self.steps_since_food = 0
                self.food = self._place_food()
                if self.food is None:
                    # No free cell left: the snake has filled the board.
                    terminated = True
            else:
                # No food eaten, so the tail advances and the length is kept.
                self.snake.pop()
                self.steps_since_food += 1
                if self.steps_since_food >= self.max_steps_without_food:
                    # Starvation is treated as a death (same penalty as a crash).
                    terminated = True
                    reward = REWARD_COLLISION

        if terminated:
            self.game_over = True

        observation = self._get_observation()
        info = self._get_info()

        if self.render_mode == 'human':
            self._render_frame()

        return observation, reward, terminated, truncated, info

    def render(self):
        """Return an RGB frame of shape (height, width, 3) in ``"rgb_array"`` mode.

        In ``"human"`` mode, frames are already drawn by ``reset``/``step``,
        so this returns None.
        """
        if self.render_mode == "rgb_array":
            return self._render_frame()
        return None

    # ------------------------------------------------------------------
    # Rendering
    # ------------------------------------------------------------------

    def _draw_cell(self, canvas, color, cell):
        """Fill grid cell ``cell`` (x, y) on ``canvas`` with ``color``."""
        x, y = cell
        pygame.draw.rect(canvas, color, (x * self.cell_size,
                                         y * self.cell_size,
                                         self.cell_size, self.cell_size))

    def _draw_board(self, canvas):
        """Clear ``canvas`` and draw the food, then the snake (head in BLUE, body in GREEN)."""
        canvas.fill(BLACK)
        if self.food is not None:
            self._draw_cell(canvas, RED, self.food)
        for i, segment in enumerate(self.snake):
            self._draw_cell(canvas, BLUE if i == 0 else GREEN, segment)

    def _render_frame(self):
        """Draw the current state.

        In ``"human"`` mode, draws to the pygame window, with grid lines and
        score, and returns None. In ``"rgb_array"`` mode, returns a
        (height, width, 3) array of the board (food and snake only).
        """
        if self.render_mode == 'human':
            if self.window is None:
                pygame.init()
                pygame.display.init()
                self.window = pygame.display.set_mode((self.screen_width, self.screen_height))
                pygame.display.set_caption("Snake AI Training")
            if self.clock is None:
                self.clock = pygame.time.Clock()
            if self.font is None:
                # Built once; recreating a Font object on every frame is wasted work.
                self.font = pygame.font.Font(None, 25)

            self._draw_board(self.window)

            for x in range(0, self.screen_width, self.cell_size):
                pygame.draw.line(self.window, WHITE, (x, 0), (x, self.screen_height))
            for y in range(0, self.screen_height, self.cell_size):
                pygame.draw.line(self.window, WHITE, (0, y), (self.screen_width, y))

            text = self.font.render(f"Score: {self.score}", True, WHITE)
            self.window.blit(text, (5, 5))

            # Keep the OS event queue serviced so the window stays responsive.
            pygame.event.pump()
            pygame.display.flip()
            self.clock.tick(self.metadata["render_fps"])
            return None

        if self.render_mode == "rgb_array":
            canvas = pygame.Surface((self.screen_width, self.screen_height))
            self._draw_board(canvas)
            # array3d copies pixels as (width, height, 3); transpose to (height, width, 3).
            # Unlike pixels3d, it does not leave the surface locked.
            return np.transpose(pygame.surfarray.array3d(canvas), axes=(1, 0, 2))

        return None

    def close(self):
        """Shut down pygame if a window was opened and drop cached rendering objects."""
        if self.window is not None:
            pygame.display.quit()
            pygame.quit()
            self.window = None
            self.clock = None
            self.font = None