diff --git a/gym_snake/envs/snakeGameGym.py b/gym_snake/envs/snakeGameGym.py index f634ae5..4fc5308 100644 --- a/gym_snake/envs/snakeGameGym.py +++ b/gym_snake/envs/snakeGameGym.py @@ -10,6 +10,7 @@ from gym_snake.envs.snakeGame import SnakeGame from gym_snake.envs.snakeGym import SnakeGym import pygame +import random class SnakeGameGym(SnakeGame): @@ -42,6 +43,9 @@ def __init__(self, fps: int, use_pygame: bool = True): self.snake = SnakeGym(self.rows,self.cols, self.make_snake_rand_pos()) self.fruit_pos = (0,0) self.generate_fruit() + self.num_bombs = 6 + self.bombs = [] + self.generate_bombs(self.num_bombs) self.score = 0 if self.use_pygame: @@ -74,6 +78,7 @@ def get_board(self) -> np.ndarray: 1 is space with fruit in it 2 is space with snake body in it 3 is space with snake head in it + 4 is space with bomb in it """ # Initializes empty board board = np.zeros([self.rows, self.cols], dtype=int) @@ -83,6 +88,12 @@ def get_board(self) -> np.ndarray: fruit_col = self.fruit_pos[1] board[fruit_row][fruit_col] = 1 + # Add Bombs + for bomb in self.bombs: + bomb_row = bomb[0] + bomb_col = bomb[1] + board[bomb_row][bomb_col] = 4 + # Add Snake to Board for i in range(len(self.snake.body)): pos = self.snake.body[i] @@ -115,6 +126,26 @@ def move_snake(self, action: spaces.Discrete(4)) -> None: self.snake.update_body_positions() + def generate_bombs(self, num_bombs): + """Function to generate a new random position for the bomb.""" + + bombs = [] + + for i in range(num_bombs): + + bomb_row = random.randrange(0,self.rows) + bomb_col = random.randrange(0,self.cols) + + #Continually generate a location for the fruit until it is not in the snake's body + while (bomb_row, bomb_col) in self.snake.body or (bomb_row, bomb_col) == self.fruit_pos: + + bomb_row = random.randrange(0,self.rows) + bomb_col = random.randrange(0,self.cols) + + bombs.append((bomb_row, bomb_col)) + + self.bombs = bombs + def respond_to_fruit_consumption(self) -> int: """ Function that extends a snake, generates new snake tail block and fruit, @@ -134,9 +165,16 @@ def check_fruit_collision(self) -> bool: Function that detects and handles if the snake has collided with a fruit. """ #If we found a fruit - if self.snake.body[0] == self.fruit_pos: - return True - + return self.snake.body[0] == self.fruit_pos + + def check_bomb_collision(self) -> bool: + """ + Function that detects and handles if the snake has collided with a bomb. + """ + for bomb in self.bombs: + if(self.snake.body[0]==bomb): + return True + return False def check_wall_collision(self) -> bool: @@ -149,10 +187,7 @@ def check_wall_collision(self) -> bool: head_x = head[1] #If there is a wall collision, game over - if head_x == self.cols or head_y == self.rows or head_x < 0 or head_y < 0: - return True - - return False + return head_x == self.cols or head_y == self.rows or head_x < 0 or head_y < 0 def check_body_collision(self) -> bool: """ diff --git a/gym_snake/envs/snake_env.py b/gym_snake/envs/snake_env.py index bcb4d61..2a9f001 100644 --- a/gym_snake/envs/snake_env.py +++ b/gym_snake/envs/snake_env.py @@ -31,7 +31,7 @@ def __init__(self, self.reward_func = reward_func self.action_space = spaces.Discrete(4) - self.observation_space = spaces.Box(low=0, high=3, shape=(self.game.cols, self.game.rows), dtype=int) + self.observation_space = spaces.Box(low=0, high=4, shape=(self.game.cols, self.game.rows), dtype=int) @@ -65,6 +65,9 @@ def step(self, action: spaces.Discrete(4)) -> tuple: rewards = self.reward_func(reward_dict) # Game is over if wall collision or body collision occurred. TODO: add end done for time limit +<<<<<<< HEAD + done = self.game.check_wall_collision() or self.game.check_body_collision() or self.game.check_bomb_collision() +======= done = did_collide_wall or did_collide_body # FIXME: Figure out what to do with info. stable_baseline3 seems to require episode object @@ -73,6 +76,7 @@ def step(self, action: spaces.Discrete(4)) -> tuple: # If there was a fruit collision during last frame, move the fruit. if did_consume_fruit: self.game.respond_to_fruit_consumption() +>>>>>>> e2d5ec34947fbd48b8eac39afae08a414df25e0c if self.game.use_pygame: self.game.clock.tick(self.game.fps)