| import gymnasium as gym | |
| from gymnasium import spaces | |
| import random | |
| import pygame | |
| import numpy as np | |
| import collections | |
| from collections import deque | |
| from Environment_Constants import ( | |
| GRID_SIZE, CELL_SIZE, SCREEN_WIDTH, SCREEN_HEIGHT, | |
| WHITE, BLACK, GREEN, RED, BLUE, | |
| UP, DOWN, LEFT, RIGHT, | |
| REWARD_FOOD, REWARD_COLLISION, REWARD_STEP, | |
| FPS | |
| ) | |
class SnakeGameEnv(gym.Env):
    """Grid-based Snake exposed as a Gymnasium environment.

    Actions are relative to the snake's current heading:

    * ``0`` -- keep going straight,
    * ``1`` -- turn right (cycling UP -> RIGHT -> DOWN -> LEFT -> UP),
    * ``2`` -- turn left (the same cycle, backwards).

    Observations are an 11-element ``float32`` vector of 0/1 flags:

    * ``[0:3]``  danger one cell ahead when going straight / right / left
      (off the board, or a cell occupied by the snake),
    * ``[3:7]``  food has smaller y / larger y / smaller x / larger x than
      the head (all zero when there is no food on the board),
    * ``[7:11]`` one-hot current heading in the order UP, DOWN, LEFT, RIGHT.

    Rewards: ``REWARD_FOOD`` for eating, ``REWARD_COLLISION`` for hitting a
    wall/the body or starving, ``REWARD_STEP`` otherwise.

    An episode terminates on a collision, after more than
    ``2 * grid_size ** 2`` consecutive moves without eating, or when the snake
    fills the whole board so no food can be placed.
    """

    metadata = {'render_modes': ['human', 'rgb_array'], 'render_fps': FPS}

    # Headings in the order a right turn (action 1) steps through them; a left
    # turn (action 2) walks the same cycle backwards.
    _TURN_CYCLE = (UP, RIGHT, DOWN, LEFT)
    # Offset into _TURN_CYCLE per action (straight, right, left). The same
    # order is used for observation slots 0..2.
    _ACTION_TURN = (0, 1, -1)

    def __init__(self, render_mode=None):
        super().__init__()
        self.grid_size = GRID_SIZE
        self.cell_size = CELL_SIZE
        self.screen_width = SCREEN_WIDTH
        self.screen_height = SCREEN_HEIGHT
        self.action_space = spaces.Discrete(3)
        self.observation_space = spaces.Box(low=0, high=1, shape=(11,), dtype=np.float32)
        self.render_mode = render_mode
        # pygame resources; created lazily on the first 'human' frame.
        self.window = None
        self.clock = None
        self._font = None
        self._init_game_state()

    def _init_game_state(self):
        """Reset to a length-3 snake in the board centre heading UP, with food placed."""
        cx = cy = self.grid_size // 2
        self.head = (cx, cy)
        # Head first (index 0), then two body cells at increasing y below it.
        self.snake = deque([self.head, (cx, cy + 1), (cx, cy + 2)])
        self.direction = UP
        self.score = 0
        self.food = self._place_food()
        self.game_over = False
        self.steps_since_food = 0

    def _place_food(self):
        """Return a uniformly random cell not covered by the snake.

        Returns ``None`` when the snake occupies every cell. The previous
        rejection-sampling loop never terminated in that case. The cell is
        drawn from ``self.np_random``, so ``reset(seed=...)`` also makes food
        placement reproducible.
        """
        occupied = set(self.snake)
        free_cells = [
            (x, y)
            for x in range(self.grid_size)
            for y in range(self.grid_size)
            if (x, y) not in occupied
        ]
        if not free_cells:
            return None
        return free_cells[int(self.np_random.integers(len(free_cells)))]

    def _turned(self, offset):
        """Return the heading `offset` steps along _TURN_CYCLE from the current one."""
        i = self._TURN_CYCLE.index(self.direction)
        return self._TURN_CYCLE[(i + offset) % len(self._TURN_CYCLE)]

    def _is_blocked(self, pos, occupied):
        """Return True if `pos` lies off the board or is a member of `occupied`."""
        x, y = pos
        on_board = 0 <= x < self.grid_size and 0 <= y < self.grid_size
        return not on_board or pos in occupied

    def _get_observation(self):
        """Build the 11-flag observation vector described in the class docstring."""
        obs = np.zeros(11, dtype=np.float32)
        hx, hy = self.head

        # Cells adjacent to the head can never be the head itself, so testing
        # against the whole snake equals testing against the body minus the
        # head. As in step(), the current tail cell counts as danger.
        occupied = set(self.snake)
        for slot, offset in enumerate(self._ACTION_TURN):
            dx, dy = self._turned(offset)
            if self._is_blocked((hx + dx, hy + dy), occupied):
                obs[slot] = 1

        if self.food is not None:
            fx, fy = self.food
            obs[3] = fy < hy
            obs[4] = fy > hy
            obs[5] = fx < hx
            obs[6] = fx > hx

        obs[7 + (UP, DOWN, LEFT, RIGHT).index(self.direction)] = 1
        return obs

    def reset(self, seed=None, options=None):
        """Start a new episode.

        Args:
            seed: Optional seed for ``self.np_random``, which drives food placement.
            options: Accepted for Gymnasium API compatibility; unused.

        Returns:
            ``(observation, info)``.
        """
        super().reset(seed=seed)
        self._init_game_state()
        observation = self._get_observation()
        info = self._get_info()
        if self.render_mode == 'human':
            self._render_frame()
        return observation, info

    def _get_info(self):
        """Return diagnostic info: current score and snake length."""
        return {"score": self.score, "snake_length": len(self.snake)}

    def step(self, action):
        """Advance the game by one move.

        Args:
            action: 0 = straight, 1 = turn right, 2 = turn left.

        Returns:
            ``(observation, reward, terminated, truncated, info)``;
            ``truncated`` is always ``False``.

        Raises:
            ValueError: If ``action`` is not 0, 1 or 2.
        """
        if action not in (0, 1, 2):
            raise ValueError(f"Received invalid action={action} which is not part of the action space.")
        self.direction = self._turned(self._ACTION_TURN[int(action)])

        hx, hy = self.head
        dx, dy = self.direction
        new_head = (hx + dx, hy + dy)
        reward = REWARD_STEP
        terminated = False

        # The collision test runs before the tail moves, so entering the
        # current tail cell is a collision.
        if self._is_blocked(new_head, self.snake):
            reward = REWARD_COLLISION
            terminated = True
        else:
            self.snake.appendleft(new_head)
            self.head = new_head
            if new_head == self.food:
                # Grow: keep the tail this turn.
                self.score += 1
                reward = REWARD_FOOD
                self.steps_since_food = 0
                self.food = self._place_food()
                if self.food is None:
                    # The snake covers the whole board; nothing is left to eat.
                    terminated = True
            else:
                self.snake.pop()
                self.steps_since_food += 1
                if self.steps_since_food > self.grid_size * self.grid_size * 2:
                    # Ends episodes where the snake goes too long without eating.
                    # NOTE(review): steps_since_food is not part of the
                    # observation, so Gymnasium conventions would arguably
                    # report this as `truncated`. Kept as `terminated` because
                    # callers may rely on it.
                    reward = REWARD_COLLISION
                    terminated = True

        if terminated:
            self.game_over = True

        observation = self._get_observation()
        info = self._get_info()
        truncated = False
        if self.render_mode == 'human':
            self._render_frame()
        return observation, reward, terminated, truncated, info

    def render(self):
        """Gymnasium render hook.

        In ``'rgb_array'`` mode, returns the current frame as a
        ``(screen_height, screen_width, 3)`` array. In ``'human'`` mode,
        frames are drawn by ``reset``/``step`` themselves and this returns
        ``None``.
        """
        if self.render_mode == 'rgb_array':
            return self._render_frame()
        return None

    def _draw_board(self, surface):
        """Clear `surface` to BLACK and draw the food (if any) and the snake on it."""
        surface.fill(BLACK)
        cs = self.cell_size
        if self.food is not None:
            fx, fy = self.food
            pygame.draw.rect(surface, RED, (fx * cs, fy * cs, cs, cs))
        for i, (sx, sy) in enumerate(self.snake):
            # Head (index 0) in BLUE, the rest of the body in GREEN.
            pygame.draw.rect(surface, BLUE if i == 0 else GREEN, (sx * cs, sy * cs, cs, cs))

    def _render_human(self):
        """Draw one frame to the on-screen window, creating pygame resources on first use."""
        if self.window is None:
            pygame.init()
            pygame.display.init()
            self.window = pygame.display.set_mode((self.screen_width, self.screen_height))
            pygame.display.set_caption("Snake AI Training")
        if self.clock is None:
            self.clock = pygame.time.Clock()
        if self._font is None:
            # Load the default font once rather than on every frame.
            self._font = pygame.font.Font(None, 25)

        self._draw_board(self.window)
        for x in range(0, self.screen_width, self.cell_size):
            pygame.draw.line(self.window, WHITE, (x, 0), (x, self.screen_height))
        for y in range(0, self.screen_height, self.cell_size):
            pygame.draw.line(self.window, WHITE, (0, y), (self.screen_width, y))
        text = self._font.render(f"Score: {self.score}", True, WHITE)
        self.window.blit(text, (5, 5))

        pygame.event.pump()  # process window events so the window stays responsive
        pygame.display.flip()
        self.clock.tick(self.metadata["render_fps"])

    def _render_frame(self):
        """Render the current state according to ``self.render_mode``.

        'human': draw to a pygame window with grid lines and the score,
        throttled to ``metadata['render_fps']``; returns ``None``.
        'rgb_array': draw food and snake off-screen and return the pixels as
        a ``(screen_height, screen_width, 3)`` array.
        Any other mode: do nothing and return ``None``.
        """
        if self.render_mode == 'human':
            self._render_human()
            return None
        if self.render_mode == "rgb_array":
            surf = pygame.Surface((self.screen_width, self.screen_height))
            self._draw_board(surf)
            # surfarray is indexed [x][y]; transpose to (row, column, channel).
            return np.transpose(np.array(pygame.surfarray.pixels3d(surf)), axes=(1, 0, 2))
        return None

    def close(self):
        """Shut pygame down if a window was opened; safe to call more than once."""
        if self.window is not None:
            pygame.display.quit()
            pygame.quit()
            self.window = None
            self.clock = None
            self._font = None