Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 42 additions & 7 deletions gym_snake/envs/snakeGameGym.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from gym_snake.envs.snakeGame import SnakeGame
from gym_snake.envs.snakeGym import SnakeGym
import pygame
import random


class SnakeGameGym(SnakeGame):
Expand Down Expand Up @@ -42,6 +43,9 @@ def __init__(self, fps: int, use_pygame: bool = True):
self.snake = SnakeGym(self.rows,self.cols, self.make_snake_rand_pos())
self.fruit_pos = (0,0)
self.generate_fruit()
self.num_bombs = 6
self.bombs = []
self.generate_bombs(self.num_bombs)
self.score = 0

if self.use_pygame:
Expand Down Expand Up @@ -74,6 +78,7 @@ def get_board(self) -> np.ndarray:
1 is space with fruit in it
2 is space with snake body in it
3 is space with snake head in it
4 is space with bomb in it
"""
# Initializes empty board
board = np.zeros([self.rows, self.cols], dtype=int)
Expand All @@ -83,6 +88,12 @@ def get_board(self) -> np.ndarray:
fruit_col = self.fruit_pos[1]
board[fruit_row][fruit_col] = 1

# Add Bombs
for bomb in self.bombs:
bomb_row = bomb[0]
bomb_col = bomb[1]
board[bomb_row][bomb_col] = 4

# Add Snake to Board
for i in range(len(self.snake.body)):
pos = self.snake.body[i]
Expand Down Expand Up @@ -115,6 +126,26 @@ def move_snake(self, action: spaces.Discrete(4)) -> None:

self.snake.update_body_positions()

def generate_bombs(self, num_bombs):
"""Function to generate a new random position for the bomb."""

bombs = []

for i in range(num_bombs):

bomb_row = random.randrange(0,self.rows)
bomb_col = random.randrange(0,self.cols)

#Continually generate a location for the fruit until it is not in the snake's body
while (bomb_row, bomb_col) in self.snake.body or (bomb_row, bomb_col) == self.fruit_pos:

bomb_row = random.randrange(0,self.rows)
bomb_col = random.randrange(0,self.cols)

bombs.append((bomb_row, bomb_col))

self.bombs = bombs

def respond_to_fruit_consumption(self) -> int:
"""
Function that extends a snake, generates new snake tail block and fruit,
Expand All @@ -134,9 +165,16 @@ def check_fruit_collision(self) -> bool:
Function that detects and handles if the snake has collided with a fruit.
"""
#If we found a fruit
if self.snake.body[0] == self.fruit_pos:
return True

return self.snake.body[0] == self.fruit_pos

def check_bomb_collision(self) -> bool:
"""
Function that detects and handles if the snake has collided with a bomb.
"""
for bomb in self.bombs:
if(self.snake.body[0]==bomb):
return True

return False

def check_wall_collision(self) -> bool:
Expand All @@ -149,10 +187,7 @@ def check_wall_collision(self) -> bool:
head_x = head[1]

#If there is a wall collision, game over
if head_x == self.cols or head_y == self.rows or head_x < 0 or head_y < 0:
return True

return False
return head_x == self.cols or head_y == self.rows or head_x < 0 or head_y < 0

def check_body_collision(self) -> bool:
"""
Expand Down
6 changes: 5 additions & 1 deletion gym_snake/envs/snake_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self,

self.reward_func = reward_func
self.action_space = spaces.Discrete(4)
self.observation_space = spaces.Box(low=0, high=3, shape=(self.game.cols, self.game.rows), dtype=int)
self.observation_space = spaces.Box(low=0, high=4, shape=(self.game.cols, self.game.rows), dtype=int)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can the 4 be non-hard coded?




Expand Down Expand Up @@ -65,6 +65,9 @@ def step(self, action: spaces.Discrete(4)) -> tuple:
rewards = self.reward_func(reward_dict)

# Game is over if wall collision or body collision occurred. TODO: add end done for time limit
<<<<<<< HEAD
done = self.game.check_wall_collision() or self.game.check_body_collision() or self.game.check_bomb_collision()
=======
done = did_collide_wall or did_collide_body

# FIXME: Figure out what to do with info. stable_baseline3 seems to require episode object
Expand All @@ -73,6 +76,7 @@ def step(self, action: spaces.Discrete(4)) -> tuple:
# If there was a fruit collision during last frame, move the fruit.
if did_consume_fruit:
self.game.respond_to_fruit_consumption()
>>>>>>> e2d5ec34947fbd48b8eac39afae08a414df25e0c

if self.game.use_pygame:
self.game.clock.tick(self.game.fps)
Expand Down