# SnakeAI_TF_PPO_V0 / Snake_EnvAndAgent.py
# Uploaded by privateboss ("Upload 8 files", commit 2df2f26, verified).
import gymnasium as gym
from gymnasium import spaces
import random
import pygame
import numpy as np
import collections
from collections import deque
from Environment_Constants import (
GRID_SIZE, CELL_SIZE, SCREEN_WIDTH, SCREEN_HEIGHT,
WHITE, BLACK, GREEN, RED, BLUE,
UP, DOWN, LEFT, RIGHT,
REWARD_FOOD, REWARD_COLLISION, REWARD_STEP,
FPS, OBSERVATION_SPACE_SIZE
)
class SnakeGameEnv(gym.Env):
    """Gymnasium environment for Snake on a ``GRID_SIZE`` x ``GRID_SIZE`` grid.

    Actions are relative to the current heading:
    ``0`` = keep going straight, ``1`` = turn right, ``2`` = turn left.

    The observation is a float32 vector (indices 0-10 are written here):

    * ``0-2``  danger flags for the cell straight ahead / to the right /
      to the left of the head (1 = wall or snake body),
    * ``3-6``  food has smaller y / larger y / smaller x / larger x than
      the head,
    * ``7-10`` one-hot heading in the order UP, DOWN, LEFT, RIGHT.

    The ``info`` dict returned by ``reset`` and ``step`` contains ``score``,
    ``snake_length`` and ``action_mask`` (boolean array of length 3).

    NOTE(review): ``OBSERVATION_SPACE_SIZE`` must be at least 11 for the
    layout above; the direction constants are used as ``(dx, dy)`` pairs.
    Both come from ``Environment_Constants`` -- confirm there.
    """

    metadata = {'render_modes': ['human', 'rgb_array'], 'render_fps': FPS}

    # Each row: (heading, heading after a right turn, heading after a left turn).
    _TURNS = (
        (UP, RIGHT, LEFT),
        (DOWN, LEFT, RIGHT),
        (LEFT, UP, DOWN),
        (RIGHT, DOWN, UP),
    )

    def __init__(self, render_mode=None):
        """Create the environment and initialise a first game state.

        Args:
            render_mode: ``None``, ``'human'`` (pygame window, drawn on every
                ``reset``/``step``) or ``'rgb_array'`` (frame returned by
                ``render()``).
        """
        super().__init__()
        self.grid_size = GRID_SIZE
        self.cell_size = CELL_SIZE
        self.screen_width = SCREEN_WIDTH
        self.screen_height = SCREEN_HEIGHT
        self.action_space = spaces.Discrete(3)
        self.observation_space = spaces.Box(low=0, high=1,
                                            shape=(OBSERVATION_SPACE_SIZE,),
                                            dtype=np.float32)
        self.render_mode = render_mode
        # pygame resources are created lazily on the first human-mode frame.
        self.window = None
        self.clock = None
        self._init_game_state()

    def _init_game_state(self):
        """Reset snake, heading, score, food and counters for a new episode."""
        self.snake = deque()
        self.head = (self.grid_size // 2, self.grid_size // 2)
        # Three segments: head at the grid centre, body at y + 1 and y + 2.
        self.snake.append(self.head)
        self.snake.append((self.head[0], self.head[1] + 1))
        self.snake.append((self.head[0], self.head[1] + 2))
        self.direction = UP
        self.score = 0
        self.food = self._place_food()
        if self.food is None:
            raise ValueError("Grid is too small to place food beside the initial snake.")
        self.game_over = False
        self.steps_since_food = 0
        self.length = len(self.snake)

    def _place_food(self):
        """Return a uniformly random free cell, or ``None`` if the grid is full.

        Uses ``self.np_random`` so that ``reset(seed=...)`` makes food
        placement reproducible. Sampling from the list of free cells (instead
        of retrying random cells) cannot loop forever on a full board.
        """
        occupied = set(self.snake)
        free_cells = [
            (x, y)
            for x in range(self.grid_size)
            for y in range(self.grid_size)
            if (x, y) not in occupied
        ]
        if not free_cells:
            return None
        return free_cells[int(self.np_random.integers(len(free_cells)))]

    def _relative_directions(self):
        """Return ``(straight, right, left)`` headings for the current heading.

        Raises:
            ValueError: if ``self.direction`` is not UP, DOWN, LEFT or RIGHT.
        """
        for heading, right, left in self._TURNS:
            if self.direction == heading:
                return heading, right, left
        raise ValueError(f"Unknown heading {self.direction!r}.")

    def _is_collision(self, pos):
        """Return True if moving the head onto ``pos`` ends the game.

        A move is fatal when ``pos`` is outside the grid or on any snake
        segment. The tail cell counts as occupied too: the only exception in
        the earlier logic required food to sit on the tail, which cannot
        happen because food is only ever placed on free cells.
        """
        x, y = pos
        if not (0 <= x < self.grid_size and 0 <= y < self.grid_size):
            return True
        return pos in self.snake

    def _is_position_safe_for_observation(self, pos):
        """Return True if ``pos`` is inside the grid and not on the snake body.

        The head cell itself is not treated as an obstacle.
        """
        px, py = pos
        if not (0 <= px < self.grid_size and 0 <= py < self.grid_size):
            return False
        # Snake cells are unique, so "not in the body" means "not on the
        # snake, or exactly the head".
        return pos == self.snake[0] or pos not in self.snake

    def _get_observation(self):
        """Build the observation vector described in the class docstring."""
        obs = np.zeros(OBSERVATION_SPACE_SIZE, dtype=np.float32)
        hx, hy = self.head

        # Danger flags, in action order: straight, right, left.
        for idx, (dx, dy) in enumerate(self._relative_directions()):
            if not self._is_position_safe_for_observation((hx + dx, hy + dy)):
                obs[idx] = 1

        # Food position relative to the head.
        fx, fy = self.food
        if fy < hy:
            obs[3] = 1
        if fy > hy:
            obs[4] = 1
        if fx < hx:
            obs[5] = 1
        if fx > hx:
            obs[6] = 1

        # One-hot heading.
        if self.direction == UP:
            obs[7] = 1
        elif self.direction == DOWN:
            obs[8] = 1
        elif self.direction == LEFT:
            obs[9] = 1
        elif self.direction == RIGHT:
            obs[10] = 1
        return obs

    def _get_action_mask(self):
        """Return a boolean mask over ``[straight, right, left]``.

        An entry is True when that action does not lead to an immediate
        collision. If every action is fatal (the snake is trapped, or the
        episode has just ended), one entry is enabled at random so that
        consumers of the mask never receive an all-False vector.
        """
        hx, hy = self.head
        mask = np.array(
            [not self._is_collision((hx + dx, hy + dy))
             for dx, dy in self._relative_directions()],
            dtype=bool,
        )
        if not mask.any():
            print(f"Warning: All actions masked out at head {self.head}, "
                  f"direction {self.direction}, food {self.food}. "
                  "Enabling a random action to prevent deadlock.")
            mask[int(self.np_random.integers(3))] = True
        return mask

    def reset(self, seed=None, options=None):
        """Start a new episode.

        Args:
            seed: optional seed forwarded to ``gym.Env.reset``; it seeds
                ``self.np_random``, which drives food placement.
            options: accepted for API compatibility; unused.

        Returns:
            ``(observation, info)``.
        """
        super().reset(seed=seed)
        self._init_game_state()
        observation = self._get_observation()
        info = self._get_info()
        if self.render_mode == 'human':
            self._render_frame()
        return observation, info

    def _get_info(self):
        """Returns environment information, including the action mask."""
        return {
            "score": self.score,
            "snake_length": len(self.snake),
            "action_mask": self._get_action_mask()
        }

    def step(self, action):
        """Advance the game by one move.

        Args:
            action: ``0`` straight, ``1`` turn right, ``2`` turn left
                (relative to the current heading).

        Returns:
            ``(observation, reward, terminated, truncated, info)``.

        Raises:
            ValueError: if ``action`` is not 0, 1 or 2.
        """
        straight, right, left = self._relative_directions()
        if action == 0:
            new_direction = straight
        elif action == 1:
            new_direction = right
        elif action == 2:
            new_direction = left
        else:
            raise ValueError(f"Received invalid action={action} which is not part of the action space.")

        self.direction = new_direction
        hx, hy = self.head
        dx, dy = self.direction
        new_head = (hx + dx, hy + dy)

        reward = REWARD_STEP
        terminated = False
        truncated = False

        if self._is_collision(new_head):
            # Wall or body hit: the snake is left where it was.
            terminated = True
            reward = REWARD_COLLISION
            self.game_over = True
        else:
            self.snake.appendleft(new_head)
            self.head = new_head
            if new_head == self.food:
                # Eating: keep the tail so the snake grows by one cell.
                self.score += 1
                self.length += 1
                reward = REWARD_FOOD
                self.steps_since_food = 0
                new_food = self._place_food()
                if new_food is None:
                    # The snake fills the whole grid: nothing left to eat.
                    terminated = True
                    self.game_over = True
                else:
                    self.food = new_food
            else:
                self.snake.pop()
                self.steps_since_food += 1
                # Starvation limit: too many moves without eating ends the
                # episode with the collision penalty. Both flags are set, as
                # before, so existing training loops see the same signals.
                if self.steps_since_food >= self.grid_size * self.grid_size * 1.5:
                    terminated = True
                    truncated = True
                    reward = REWARD_COLLISION
                    self.game_over = True

        observation = self._get_observation()
        info = self._get_info()
        if self.render_mode == 'human':
            self._render_frame()
        return observation, reward, terminated, truncated, info

    def render(self):
        """Return an RGB frame in ``'rgb_array'`` mode, otherwise ``None``.

        In ``'human'`` mode the window is already refreshed by ``reset`` and
        ``step``, so there is nothing to do here.
        """
        if self.render_mode == "rgb_array":
            return self._render_frame()
        return None

    def _draw_board(self, surface):
        """Paint background, food and snake (head in BLUE, body in GREEN)."""
        cell = self.cell_size
        surface.fill(BLACK)
        pygame.draw.rect(surface, RED,
                         (self.food[0] * cell, self.food[1] * cell, cell, cell))
        for i, (sx, sy) in enumerate(self.snake):
            color = BLUE if i == 0 else GREEN
            pygame.draw.rect(surface, color, (sx * cell, sy * cell, cell, cell))

    def _render_frame(self):
        """Draw the current state according to ``self.render_mode``.

        ``'human'``: draws board, grid lines and score into the pygame window
        (created on first use) and throttles to ``metadata['render_fps']``;
        returns ``None``.
        ``'rgb_array'``: draws board only (no grid lines or score) into an
        off-screen surface and returns it as an array with the first two axes
        swapped to (row, column) order.
        """
        if self.render_mode == 'human':
            if self.window is None:
                pygame.init()
                pygame.display.init()
                self.window = pygame.display.set_mode((self.screen_width, self.screen_height))
                pygame.display.set_caption("Snake AI Training")
            if self.clock is None:
                self.clock = pygame.time.Clock()

            self._draw_board(self.window)
            for x in range(0, self.screen_width, self.cell_size):
                pygame.draw.line(self.window, WHITE, (x, 0), (x, self.screen_height))
            for y in range(0, self.screen_height, self.cell_size):
                pygame.draw.line(self.window, WHITE, (0, y), (self.screen_width, y))

            font = pygame.font.Font(None, 25)
            text = font.render(f"Score: {self.score}", True, WHITE)
            self.window.blit(text, (5, 5))

            # Keep the window responsive, then present the frame.
            pygame.event.pump()
            pygame.display.flip()
            self.clock.tick(self.metadata["render_fps"])
            return None

        if self.render_mode == "rgb_array":
            surf = pygame.Surface((self.screen_width, self.screen_height))
            self._draw_board(surf)
            return np.transpose(np.array(pygame.surfarray.pixels3d(surf)), axes=(1, 0, 2))

        return None

    def close(self):
        """Shut down pygame if a window was opened."""
        if self.window is not None:
            pygame.display.quit()
            pygame.quit()
            self.window = None
            self.clock = None