# SnakeAI_TF_PPO_V1 / Snake_EnvAndAgent.py
# Author: privateboss
# Commit: "Create Snake_EnvAndAgent.py"
# Revision: 93dd337 (verified)
import gymnasium as gym
from gymnasium import spaces
import random
import pygame
import numpy as np
import collections
from collections import deque
from Environment_Constants import (
    GRID_SIZE, CELL_SIZE, SCREEN_WIDTH, SCREEN_HEIGHT,
    WHITE, BLACK, GREEN, RED, BLUE,
    UP, DOWN, LEFT, RIGHT,
    REWARD_FOOD, REWARD_COLLISION, REWARD_STEP,
    FPS
)
class SnakeGameEnv(gym.Env):
    """Gymnasium environment for Snake on a square ``grid_size`` x ``grid_size`` grid.

    Actions (``Discrete(3)``), all relative to the current heading:
        0 -- keep going straight
        1 -- turn right
        2 -- turn left

    Observation (``Box(0, 1, (11,), float32)``):
        [0..2]  danger (wall or body) in the cell straight / right / left of the head
        [3..6]  food has smaller y / larger y / smaller x / larger x than the head
        [7..10] one-hot of the current heading in the order UP, DOWN, LEFT, RIGHT

    Rewards come from ``Environment_Constants``: ``REWARD_STEP`` for a plain
    move, ``REWARD_FOOD`` when food is eaten, and ``REWARD_COLLISION`` when
    the snake hits a wall or itself, or goes too long without eating.
    """

    metadata = {'render_modes': ['human', 'rgb_array'], 'render_fps': FPS}

    # Headings in clockwise order: the next entry is a right turn, the
    # previous entry is a left turn.
    _CLOCKWISE = (UP, RIGHT, DOWN, LEFT)

    def __init__(self, render_mode=None):
        """Create the environment.

        Args:
            render_mode: ``None``, ``'human'`` (pygame window, drawn on every
                ``reset``/``step``) or ``'rgb_array'`` (frames via ``render()``).
        """
        super().__init__()
        self.grid_size = GRID_SIZE
        self.cell_size = CELL_SIZE
        self.screen_width = SCREEN_WIDTH
        self.screen_height = SCREEN_HEIGHT
        self.action_space = spaces.Discrete(3)
        self.observation_space = spaces.Box(low=0, high=1, shape=(11,), dtype=np.float32)
        self.render_mode = render_mode
        # pygame resources are created lazily on the first human-mode frame.
        self.window = None
        self.clock = None
        self._font = None
        self._init_game_state()

    def _init_game_state(self):
        """Reset the game to a 3-cell snake in the grid centre, heading UP."""
        centre = self.grid_size // 2
        self.head = (centre, centre)
        # Head first, then two body cells at increasing y behind it.
        self.snake = deque([self.head])
        for offset in (1, 2):
            self.snake.append((centre, centre + offset))
        self.direction = UP
        self.score = 0
        self.food = self._place_food()
        self.game_over = False
        self.steps_since_food = 0

    def _place_food(self):
        """Return a uniformly random grid cell not occupied by the snake.

        Uses ``self.np_random`` so that ``reset(seed=...)`` makes food
        placement reproducible.

        Raises:
            RuntimeError: if the snake occupies every cell of the grid.
        """
        occupied = set(self.snake)
        free_cells = [
            (x, y)
            for x in range(self.grid_size)
            for y in range(self.grid_size)
            if (x, y) not in occupied
        ]
        if not free_cells:
            raise RuntimeError("No free cell left to place food.")
        return free_cells[int(self.np_random.integers(len(free_cells)))]

    def _relative_directions(self):
        """Return the (straight, right, left) headings for the current direction.

        Raises:
            ValueError: if ``self.direction`` is not one of UP/RIGHT/DOWN/LEFT.
        """
        idx = self._CLOCKWISE.index(self.direction)
        turn_right = self._CLOCKWISE[(idx + 1) % 4]
        turn_left = self._CLOCKWISE[(idx - 1) % 4]
        return self.direction, turn_right, turn_left

    def _is_danger(self, pos):
        """Return True if ``pos`` is outside the grid or on a non-head snake cell."""
        px, py = pos
        if not (0 <= px < self.grid_size and 0 <= py < self.grid_size):
            return True
        # Snake cells are unique, so "in the body" == "in the snake but not the head".
        return pos != self.snake[0] and pos in self.snake

    def _get_observation(self):
        """Build the 11-element float32 observation described in the class docstring."""
        obs = np.zeros(11, dtype=np.float32)
        hx, hy = self.head

        # Danger flags for the three cells reachable by the next action.
        for slot, (dx, dy) in enumerate(self._relative_directions()):
            if self._is_danger((hx + dx, hy + dy)):
                obs[slot] = 1

        # Food position relative to the head.
        fx, fy = self.food
        if fy < hy:
            obs[3] = 1
        if fy > hy:
            obs[4] = 1
        if fx < hx:
            obs[5] = 1
        if fx > hx:
            obs[6] = 1

        # One-hot heading.
        if self.direction == UP:
            obs[7] = 1
        elif self.direction == DOWN:
            obs[8] = 1
        elif self.direction == LEFT:
            obs[9] = 1
        elif self.direction == RIGHT:
            obs[10] = 1
        return obs

    def reset(self, seed=None, options=None):
        """Start a new episode.

        Args:
            seed: optional seed forwarded to ``gym.Env.reset`` (seeds ``self.np_random``).
            options: unused.

        Returns:
            ``(observation, info)`` for the initial state.
        """
        super().reset(seed=seed)
        self._init_game_state()
        observation = self._get_observation()
        info = self._get_info()

        if self.render_mode == 'human':
            self._render_frame()
        return observation, info

    def _get_info(self):
        """Return the auxiliary info dict (current score and snake length)."""
        return {"score": self.score, "snake_length": len(self.snake)}

    def step(self, action):
        """Advance the game by one move.

        Args:
            action: 0 (straight), 1 (turn right) or 2 (turn left).

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
            ``truncated`` is always False.

        Raises:
            ValueError: if ``action`` is not 0, 1 or 2.
        """
        straight, turn_right, turn_left = self._relative_directions()
        if action == 0:
            new_direction = straight
        elif action == 1:
            new_direction = turn_right
        elif action == 2:
            new_direction = turn_left
        else:
            raise ValueError(f"Received invalid action={action} which is not part of the action space.")
        self.direction = new_direction

        hx, hy = self.head
        dx, dy = self.direction
        new_head = (hx + dx, hy + dy)
        reward = REWARD_STEP
        terminated = False

        hit_wall = not (0 <= new_head[0] < self.grid_size and 0 <= new_head[1] < self.grid_size)
        # The whole snake (tail included) counts as an obstacle.
        if hit_wall or new_head in self.snake:
            terminated = True
            reward = REWARD_COLLISION
            self.game_over = True
        else:
            self.snake.appendleft(new_head)
            self.head = new_head
            if new_head == self.food:
                self.score += 1
                reward = REWARD_FOOD
                self.steps_since_food = 0
                if len(self.snake) >= self.grid_size * self.grid_size:
                    # Board is full: nowhere to put food, the game is won.
                    terminated = True
                    self.game_over = True
                else:
                    self.food = self._place_food()
            else:
                # No food eaten: drop the tail so the length stays constant.
                self.snake.pop()
                self.steps_since_food += 1
            # Starvation guard: end episodes where the snake wanders without eating.
            if self.steps_since_food > self.grid_size * self.grid_size * 2:
                terminated = True
                reward = REWARD_COLLISION
                self.game_over = True

        observation = self._get_observation()
        info = self._get_info()
        truncated = False
        if self.render_mode == 'human':
            self._render_frame()
        return observation, reward, terminated, truncated, info

    def render(self):
        """Return an RGB frame in ``'rgb_array'`` mode, otherwise None.

        In ``'human'`` mode frames are already drawn by ``reset``/``step``.
        """
        if self.render_mode == 'rgb_array':
            return self._render_frame()
        return None

    def _draw_board(self, surface):
        """Draw background, food and snake (head in BLUE, body in GREEN) on ``surface``."""
        cell = self.cell_size
        surface.fill(BLACK)
        pygame.draw.rect(surface, RED, (self.food[0] * cell, self.food[1] * cell, cell, cell))
        for i, (sx, sy) in enumerate(self.snake):
            color = BLUE if i == 0 else GREEN
            pygame.draw.rect(surface, color, (sx * cell, sy * cell, cell, cell))

    def _render_frame(self):
        """Render the current state according to ``self.render_mode``.

        ``'human'``: draws board, grid lines and score to the pygame window
        and throttles to ``metadata['render_fps']``; returns None.
        ``'rgb_array'``: returns the board (no grid lines or score) as an
        array with the surface's first two axes swapped.
        """
        if self.render_mode == 'human':
            if self.window is None:
                pygame.init()
                pygame.display.init()
                self.window = pygame.display.set_mode((self.screen_width, self.screen_height))
                pygame.display.set_caption("Snake AI Training")
            if self.clock is None:
                self.clock = pygame.time.Clock()
            if self._font is None:
                # Built once per window instead of once per frame.
                self._font = pygame.font.Font(None, 25)

            self._draw_board(self.window)
            for x in range(0, self.screen_width, self.cell_size):
                pygame.draw.line(self.window, WHITE, (x, 0), (x, self.screen_height))
            for y in range(0, self.screen_height, self.cell_size):
                pygame.draw.line(self.window, WHITE, (0, y), (self.screen_width, y))
            text = self._font.render(f"Score: {self.score}", True, WHITE)
            self.window.blit(text, (5, 5))
            # Keep the window responsive without handling individual events.
            pygame.event.pump()
            pygame.display.flip()
            self.clock.tick(self.metadata["render_fps"])
        elif self.render_mode == "rgb_array":
            surf = pygame.Surface((self.screen_width, self.screen_height))
            self._draw_board(surf)
            # np.array() copies the pixel view; the transpose swaps its first two axes.
            return np.transpose(np.array(pygame.surfarray.pixels3d(surf)), axes=(1, 0, 2))

    def close(self):
        """Shut down pygame and drop the window, clock and cached font."""
        if self.window is not None:
            pygame.display.quit()
            pygame.quit()
            self.window = None
            self.clock = None
            # Font objects do not survive pygame.quit(); rebuild on next render.
            self._font = None