From 8b90d9e9f153cf4009a8b00027503557c8dc1c2e Mon Sep 17 00:00:00 2001 From: Dave Carroll Date: Tue, 9 Nov 2021 00:38:27 -0800 Subject: [PATCH 1/2] implemented regenerating bomb and connected its effects to RL --- gym_snake/envs/snakeGameGym.py | 54 ++++++++++++++++++++++++++++------ gym_snake/envs/snake_env.py | 4 +-- 2 files changed, 47 insertions(+), 11 deletions(-) diff --git a/gym_snake/envs/snakeGameGym.py b/gym_snake/envs/snakeGameGym.py index 13b80e2..8800c9f 100644 --- a/gym_snake/envs/snakeGameGym.py +++ b/gym_snake/envs/snakeGameGym.py @@ -9,6 +9,7 @@ from gym_snake.envs.snakeGame import SnakeGame from gym_snake.envs.snake import Snake import pygame +import random class SnakeGameGym(SnakeGame): @@ -41,6 +42,8 @@ def __init__(self, fps: int, use_pygame: bool = True): self.snake = Snake(self.rows,self.cols) self.fruit_pos = (0,0) self.generate_fruit() + self.bomb_pos = (0,0) + self.generate_bomb() self.score = 0 self.high_score = 0 @@ -69,6 +72,7 @@ def get_board(self) -> np.ndarray: 1 is space with fruit in it 2 is space with snake body in it 3 is space with snake head in it + 4 is space with bomb in it """ # Initializes empty board board = np.zeros([self.rows, self.cols], dtype=int) @@ -78,6 +82,11 @@ def get_board(self) -> np.ndarray: fruit_col = self.fruit_pos[1] board[fruit_row][fruit_col] = 1 + # Add Bomb + bomb_row = self.bomb_pos[0] + bomb_col = self.bomb_pos[1] + board[bomb_row][bomb_col] = 4 + # Add Snake to Board for i in range(len(self.snake.body)): pos = self.snake.body[i] @@ -110,6 +119,21 @@ def move_snake(self, action: spaces.Discrete(4)) -> None: self.snake.update_body_positions() + def generate_bomb(self): + """Function to generate a new random position for the bomb.""" + + bomb_row = random.randrange(0,self.rows) + bomb_col = random.randrange(0,self.cols) + + #Continually generate a location for the fruit until it is not in the snake's body + while (bomb_row, bomb_col) in self.snake.body or (bomb_row, bomb_col) == self.fruit_pos: + + 
bomb_row = random.randrange(0,self.rows) + bomb_col = random.randrange(0,self.cols) + + + self.bomb_pos = (bomb_row,bomb_col) + def respond_to_fruit_consumption(self) -> int: """ Function that extends a snake, generates new snake tail block and fruit, @@ -130,12 +154,13 @@ def check_collisions(self) -> int: Returns a reward based on these collisions """ fruit_collision = self.check_fruit_collision() + bomb_collision = self.check_bomb_collision() wall_collision = self.check_wall_collision() body_collision = self.check_body_collision() if fruit_collision: return 1 - elif wall_collision or body_collision: + elif wall_collision or body_collision or bomb_collision: return -1 else: return 0 @@ -145,10 +170,13 @@ def check_fruit_collision(self) -> bool: Function that detects and handles if the snake has collided with a fruit. """ #If we found a fruit - if self.snake.body[0] == self.fruit_pos: - return True - - return False + return self.snake.body[0] == self.fruit_pos + + def check_bomb_collision(self) -> bool: + """ + Function that detects and handles if the snake has collided with a bomb. 
+ """ + return self.snake.body[0] == self.bomb_pos def check_wall_collision(self) -> bool: """ @@ -160,10 +188,7 @@ def check_wall_collision(self) -> bool: head_x = head[1] #If there is a wall collision, game over - if head_x == self.cols or head_y == self.rows or head_x < 0 or head_y < 0: - return True - - return False + return head_x == self.cols or head_y == self.rows or head_x < 0 or head_y < 0 def check_body_collision(self) -> bool: """ @@ -179,3 +204,14 @@ def check_body_collision(self) -> bool: return True return False + + def game_over(self): + """Function that restarts the game upon game over.""" + + self.snake = Snake(self.rows,self.cols) + self.generate_fruit() + self.generate_bomb() + self.restart = True + if self.score > self.high_score: + self.high_score = self.score + self.score = 0 diff --git a/gym_snake/envs/snake_env.py b/gym_snake/envs/snake_env.py index 4a23963..8a7fe0b 100644 --- a/gym_snake/envs/snake_env.py +++ b/gym_snake/envs/snake_env.py @@ -18,7 +18,7 @@ def __init__(self, use_pygame: bool = True): self.game = SnakeGameGym(fps, use_pygame=use_pygame) self.action_space = spaces.Discrete(4) - self.observation_space = spaces.Box(low=0, high=3, shape=(self.game.cols, self.game.rows), dtype=int) + self.observation_space = spaces.Box(low=0, high=4, shape=(self.game.cols, self.game.rows), dtype=int) @@ -41,7 +41,7 @@ def step(self, action: spaces.Discrete(4)) -> tuple: rewards = self.game.check_collisions() # Game is over if wall collision or body collision occurred. 
TODO: add end done for time limit - done = self.game.check_wall_collision() or self.game.check_body_collision() + done = self.game.check_wall_collision() or self.game.check_body_collision() or self.game.check_bomb_collision() if self.game.use_pygame: self.game.clock.tick(self.game.fps) From 13124afae41ddb29547dc63428db51d7783a59eb Mon Sep 17 00:00:00 2001 From: Dave Carroll Date: Tue, 9 Nov 2021 21:53:14 -0800 Subject: [PATCH 2/2] added multiple bombs functionality --- gym_snake/envs/snakeGameGym.py | 39 ++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/gym_snake/envs/snakeGameGym.py b/gym_snake/envs/snakeGameGym.py index 8800c9f..2ca6ca3 100644 --- a/gym_snake/envs/snakeGameGym.py +++ b/gym_snake/envs/snakeGameGym.py @@ -42,8 +42,9 @@ def __init__(self, fps: int, use_pygame: bool = True): self.snake = Snake(self.rows,self.cols) self.fruit_pos = (0,0) self.generate_fruit() - self.bomb_pos = (0,0) - self.generate_bomb() + self.num_bombs = 6 + self.bombs = [] + self.generate_bombs(self.num_bombs) self.score = 0 self.high_score = 0 @@ -82,10 +83,11 @@ def get_board(self) -> np.ndarray: fruit_col = self.fruit_pos[1] board[fruit_row][fruit_col] = 1 - # Add Bomb - bomb_row = self.bomb_pos[0] - bomb_col = self.bomb_pos[1] - board[bomb_row][bomb_col] = 4 + # Add Bombs + for bomb in self.bombs: + bomb_row = bomb[0] + bomb_col = bomb[1] + board[bomb_row][bomb_col] = 4 # Add Snake to Board for i in range(len(self.snake.body)): @@ -119,20 +121,25 @@ def move_snake(self, action: spaces.Discrete(4)) -> None: self.snake.update_body_positions() - def generate_bomb(self): + def generate_bombs(self, num_bombs): """Function to generate a new random position for the bomb.""" - bomb_row = random.randrange(0,self.rows) - bomb_col = random.randrange(0,self.cols) + bombs = [] - #Continually generate a location for the fruit until it is not in the snake's body - while (bomb_row, bomb_col) in self.snake.body or (bomb_row, bomb_col) == 
self.fruit_pos: + for i in range(num_bombs): bomb_row = random.randrange(0,self.rows) bomb_col = random.randrange(0,self.cols) + #Continually generate a location for the bomb until it is not in the snake's body or on the fruit + while (bomb_row, bomb_col) in self.snake.body or (bomb_row, bomb_col) == self.fruit_pos: - self.bomb_pos = (bomb_row,bomb_col) + bomb_row = random.randrange(0,self.rows) + bomb_col = random.randrange(0,self.cols) + + bombs.append((bomb_row, bomb_col)) + + self.bombs = bombs def respond_to_fruit_consumption(self) -> int: """ @@ -176,7 +183,11 @@ def check_bomb_collision(self) -> bool: """ Function that detects and handles if the snake has collided with a bomb. """ - return self.snake.body[0] == self.bomb_pos + for bomb in self.bombs: + if(self.snake.body[0]==bomb): + return True + + return False def check_wall_collision(self) -> bool: """ @@ -210,7 +221,7 @@ def game_over(self): self.snake = Snake(self.rows,self.cols) self.generate_fruit() - self.generate_bomb() + self.generate_bombs(self.num_bombs) self.restart = True if self.score > self.high_score: self.high_score = self.score