diff --git a/.gitignore b/.gitignore index 56c94cc..0842a22 100755 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,162 @@ -__pycache__ -.DS_Store -Data + +# Created by https://www.toptal.com/developers/gitignore/api/python +# Edit at https://www.toptal.com/developers/gitignore?templates=python + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. 
+#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# End of https://www.toptal.com/developers/gitignore/api/python + +.idea/ +*.png \ No newline at end of file diff --git a/README.md b/README.md index 36e728e..9d8b590 100644 --- a/README.md +++ b/README.md @@ -22,9 +22,10 @@ Deep Learning is a subfield of machine learning with neural networks inspired by 2. Run `python3 ai.py` command in terminal. ### Creating Training Dataset: +1. Choose the keys to record by editing the list returned by the `get_keys` function in `game_control.py`. 1. Run `python3 create_dataset.py` command in terminal. 2. Play your desired game. -3. Stop `create_dataset` program with `Cntrl-C` in terminal. +3. Stop `create_dataset` program with `Ctrl-C` in terminal. 
### Model Training: `python3 train.py` diff --git a/ai.py b/ai.py index c99cde1..9c3e02e 100755 --- a/ai.py +++ b/ai.py @@ -1,58 +1,83 @@ -# Arda Mavi -import os -import platform +""" +This is the main file for the AI. + +Author: Arda Mavi +""" +import pickle + +import time import numpy as np -from time import sleep -from PIL import ImageGrab -from game_control import * +from PIL import Image +from tensorflow.keras.models import model_from_json +from mss import mss + +from game_control import get_key, press, release, click from predict import predict -from game_control import * -from keras.models import model_from_json + def main(): - # Get Model: - model_file = open('Data/Model/model.json', 'r') - model = model_file.read() - model_file.close() + """ + Main function. + + :return: None + """ + with open("Data/Model/model.json", "r") as model_file: + model = model_file.read() model = model_from_json(model) model.load_weights("Data/Model/weights.h5") - print('AI start now!') - + print("AI starting now!") + with open("listfile.data", "rb") as filehandle: + # read the data as binary data stream + places_list = pickle.load(filehandle) while 1: # Get screenshot: - screen = ImageGrab.grab() - # Image to numpy array: - screen = np.array(screen) + with mss() as sct: + monitor = sct.monitors[1] + sct_img = sct.grab(monitor) + # Convert to PIL/Pillow Image + screen = Image.frombytes("RGB", sct_img.size, sct_img.bgra, "raw", + "BGRX") + screen = np.array(screen)[ + :, :, :3 + ] # Get first 3 channel from image as numpy array. # 4 channel(PNG) to 3 channel(JPG) - Y = predict(model, screen) - if Y == [0,0,0,0]: + y_ai = predict(model, screen) + print(y_ai) + y_ai = places_list[y_ai] + y_ai = [int(i) for i in y_ai] + print(y_ai) + if y_ai == [0, 0, 0, 0]: # Not action continue - elif Y[0] == -1 and Y[1] == -1: + if y_ai[0] == -1 and y_ai[1] == -1: # Only keyboard action. 
- key = get_key(Y[3]) - if Y[2] == 1: - # Press: - press(key) - else: - # Release: - release(key) - elif Y[2] == 0 and Y[3] == 0: - # Only mouse action. - click(Y[0], Y[1]) - else: - # Mouse and keyboard action. - # Mouse: - click(Y[0], Y[1]) - # Keyboard: - key = get_key(Y[3]) - if Y[2] == 1: + key = get_key(y_ai[3]) + if y_ai[2] == 1: # Press: press(key) else: # Release: release(key) + elif y_ai[2] == 0 and y_ai[3] == 0: + # Click action. + click(y_ai[0], y_ai[1]) + + # else: + # # Mouse and keyboard action. + # # Mouse: + # click(int(y_ai[0]), int(y_ai[1])) + # # Keyboard: + # key = get_key(int(y_ai[3])) + # if y_ai[2] == 1: + # # Press: + # press(key) + # else: + # # Release: + # release(key) + + time.sleep(0.005) + -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/create_dataset.py b/create_dataset.py old mode 100644 new mode 100755 index 1142edb..700a804 --- a/create_dataset.py +++ b/create_dataset.py @@ -1,57 +1,126 @@ -# Arda Mavi +""" +This file will create a dataset of images and labels for training. + +Author: Arda Mavi +""" import os -import sys -import platform +import time +from multiprocessing import Process + import numpy as np -from time import sleep -from PIL import ImageGrab -from game_control import * -from predict import predict -from scipy.misc import imresize +from PIL import Image +from mss import mss +from pynput.keyboard import Listener as key_listener +from pynput.mouse import Listener as mouse_listener + from game_control import get_id from get_dataset import save_img -from multiprocessing import Process -from keras.models import model_from_json -from pynput.mouse import Listener as mouse_listener -from pynput.keyboard import Listener as key_listener + def get_screenshot(): - img = ImageGrab.grab() - img = np.array(img)[:,:,:3] # Get first 3 channel from image as numpy array. - img = imresize(img, (150, 150, 3)).astype('float32')/255. - return img + """ + This function will get the screenshot of the game. 
+ :return: num_array of the screenshot + """ + with mss() as sct: + monitor = sct.monitors[1] + sct_img = sct.grab(monitor) + # Convert to PIL/Pillow Image + img = Image.frombytes('RGB', sct_img.size, sct_img.bgra, 'raw', 'BGRX') + img = np.array(img)[:, :, :3] # Get first 3 channel from image as numpy array. + # resize it with PIL, because scipy.misc.imresize is deprecated. + img = Image.fromarray(img) + img = img.resize((150, 150), resample=Image.NEAREST) + # img = imresize(img, (150, 150, 3)).astype('float32') / 255. + return img + def save_event_keyboard(data_path, event, key): + """ + This function will save the event of the keyboard. + :param data_path: path to save the event + :param event: down or up + :param key: which key is pressed + """ key = get_id(key) - data_path = data_path + '/-1,-1,{0},{1}'.format(event, key) - screenshot = get_screenshot() - save_img(data_path, screenshot) - return + if key != 1000: + data_path = data_path + '/-1,-1,{0},{1},{2}'.format(event, key, time.time()) + screenshot = get_screenshot() + save_img(data_path, screenshot) + + +def save_event_mouse(data_path, x_coordinate, y_coordinate, button, pressed): + """ + This function will save the event of the mouse. + :param data_path: path to save the event + :param x_coordinate: x coordinate + :param y_coordinate: y coordinate + """ + # 539,996,0,0,1643879876.0606766 + # 539 x coordinate + # 996 y coordinate + + # 1643879876.0606766 time since epoch + # button is an enum -def save_event_mouse(data_path, x, y): - data_path = data_path + '/{0},{1},0,0'.format(x, y) + # cut button at dot and keep the last part. + button = button.name.split('.')[-1] + data_path = data_path + '/{0},{1},{2},{3},{4}'.format(x_coordinate, y_coordinate, button, int(pressed), + time.time()) screenshot = get_screenshot() save_img(data_path, screenshot) - return + def listen_mouse(): + """ + This function will listen the mouse and save the event. 
+ :return: None + """ data_path = 'Data/Train_Data/Mouse' if not os.path.exists(data_path): os.makedirs(data_path) - def on_click(x, y, button, pressed): - save_event_mouse(data_path, x, y) + def on_click(x_coordinate, y_coordinate, button, pressed): + """ + This function will get the x and y coordinate of the mouse, when a click happens. + :param x_coordinate: int + :param y_coordinate: int + :TODO: fix the function. Help from: https://pynput.readthedocs.io/en/latest/mouse.html + """ + print(data_path, x_coordinate, y_coordinate, button, pressed) + save_event_mouse(data_path, x_coordinate, y_coordinate, button, pressed) - def on_scroll(x, y, dx, dy): + def on_scroll(x_cord, y_cord, dx, dy): + """ + This function will get the new x and y coordinate of the mouse, when a scroll happens. + dx and dy are the amount of scrolling. + :param x_cord: int + :param y_cord: int + :param dx: int + :param dy: int + :return: None + """ pass - - def on_move(x, y): + + def on_move(x_cord, y_cord): + """ + This function will get the new x and y coordinate of the mouse, when a move happens. + If this callback raises an exception, or returns False, the mouse tracking will be stopped. + :param x_cord: int + :param y_cord: int + :return: None + """ pass with mouse_listener(on_move=on_move, on_click=on_click, on_scroll=on_scroll) as listener: listener.join() + def listen_keyboard(): + """ + This function will listen the keyboard and save the event. + :return: None + """ data_path = 'Data/Train_Data/Keyboard' if not os.path.exists(data_path): os.makedirs(data_path) @@ -65,7 +134,11 @@ def on_release(key): with key_listener(on_press=on_press, on_release=on_release) as listener: listener.join() + def main(): + """ + This is the main function. 
+ """ dataset_path = 'Data/Train_Data/' if not os.path.exists(dataset_path): os.makedirs(dataset_path) @@ -73,7 +146,7 @@ def main(): # Start to listening mouse with new process: Process(target=listen_mouse, args=()).start() listen_keyboard() - return + if __name__ == '__main__': main() diff --git a/database_process.py b/database_process.py deleted file mode 100644 index 73ec22b..0000000 --- a/database_process.py +++ /dev/null @@ -1,35 +0,0 @@ -# Arda Mavi -import os -import sqlite3 - -def set_sql_connect(database_name): - return sqlite3.connect(database_name) -def set_sql_cursor(database_connect): - return database_connect.cursor() - -def close_connect(vt): - if vt: - vt.commit() - vt.close - -def set_connect_and_cursor(path='Data/database.sqlite'): - vt = set_sql_connect(path) - db = set_sql_cursor(vt) - return vt, db - -def create_table(table_name, columns): - vt, db = set_connect_and_cursor() - db.execute("CREATE TABLE IF NOT EXISTS {0} ({1})".format(table_name, columns)) - close_connect(vt) - -def get_data(sql_command): - vt, db = set_connect_and_cursor() - db.execute(sql_command) - gelen_veri = db.fetchall() - close_connect(vt) - return gelen_veri - -def add_data(table, adding): - vt, db = set_connect_and_cursor() - db.execute("INSERT INTO '{0}' VALUES ({1})".format(table, adding)) - close_connect(vt) diff --git a/game_control.py b/game_control.py index a37b8a9..056826e 100755 --- a/game_control.py +++ b/game_control.py @@ -1,38 +1,110 @@ -# Arda Mavi -from pynput.mouse import Button, Controller as Mouse -from pynput.keyboard import Controller as Keyboard +""" +This file contains the game control logic. 
+ +Author: Arda Mavi +""" +import pyautogui + +from pynput.mouse import Controller as Mouse +from pynput.keyboard import Key + # For encoding keyboard keys: def get_keys(): - return ['\t', '\n', '\r', ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', 'a', 'b', 'c', 'd', 'e','f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~', 'accept', 'add', 'alt', 'altleft', 'altright', 'apps', 'backspace', 'browserback', 'browserfavorites', 'browserforward', 'browserhome', 'browserrefresh', 'browsersearch', 'browserstop', 'capslock', 'clear', 'convert', 'ctrl', 'ctrlleft', 'ctrlright', 'decimal', 'del', 'delete', 'divide', 'down', 'end', 'enter', 'esc', 'escape', 'execute', 'f1', 'f10', 'f11', 'f12', 'f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f2', 'f20', 'f21', 'f22', 'f23', 'f24', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9', 'final', 'fn', 'hanguel', 'hangul', 'hanja', 'help', 'home', 'insert', 'junja', 'kana', 'kanji', 'launchapp1', 'launchapp2', 'launchmail', 'launchmediaselect', 'left', 'modechange', 'multiply', 'nexttrack', 'nonconvert', 'num0', 'num1', 'num2', 'num3', 'num4', 'num5', 'num6', 'num7', 'num8', 'num9', 'numlock', 'pagedown', 'pageup', 'pause', 'pgdn', 'pgup', 'playpause', 'prevtrack', 'print', 'printscreen', 'prntscrn', 'prtsc', 'prtscr', 'return', 'right', 'scrolllock', 'select', 'separator', 'shift', 'shiftleft', 'shiftright', 'sleep', 'space', 'stop', 'subtract', 'tab', 'up', 'volumedown', 'volumemute', 'volumeup', 'win', 'winleft', 'winright', 'yen', 'command', 'option', 'optionleft', 'optionright'] + """ + Returns a list of all the keys that can be pressed. + :return: The list of keys. 
+ """ + return ["q", "w", "e", "r", "t", "y", "u", "i", "o", "p", + "a", "s", "d", "f", "g", "h", "j", "k", "l", "z", "x", "c", "v", "b", "n", "m", ",", ".", + "Key.space", "Key.shift", "Key.shift_r", "Key.esc", "Key.enter", "Key.backspace", "Key.tab", "Key.ctrl", + "Key.ctrl_r", "Key.caps_lock", "Key.page_up", "Key.page_down", "Key.end", "Key.home", "Key.delete", + "Key.insert", "Key.left", "Key.up", "Key.right", "Key.down", "Key.num_lock", "Key.print_screen", + "Key.f1", "Key.f2", "Key.f3", "Key.f4", "Key.f5", "Key.f6", "Key.f7", "Key.f8", "Key.f9", "Key.f10", + "Key.f11", "Key.f12"] + + +def get_key(key_id): + """ + Returns the key that corresponds to the given key id. + :param key_id: Set the key id. + :return: the key that corresponds to the given key id. + """ + return get_keys()[key_id] -def get_key(id): - return get_keys()[id] def get_id(key): - return get_keys().index(key) + """ + Returns the id of the given key. + :param key: The key object passed in by the keyboard listener. + :return: The index of the key in get_keys(), or 1000 if it is not listed. + """ + try: + print("Key Pressed:", key.char, sep=" ") + return get_keys().index(key.char) + except (AttributeError, ValueError): + if (str(key) + "") not in get_keys(): + print((str(key) + ""), " is not in list") + return 1000 + print("Key Pressed:", (str(key) + ""), sep=" ") + return get_keys().index((str(key) + "")) + -keyboard = Keyboard() mouse = Mouse() + # Mouse: def move(x, y): - mouse.position = (x, y) - return + """ + Moves the mouse to the given coordinates. + :param x: x coordinate. + :param y: y coordinate. + :return: None + """ + pyautogui.moveTo(x, y) + def scroll(x, y): + """ + Scrolls the mouse to the given coordinates. + :param x: The horizontal scroll. + :param y: The vertical scroll. + """ mouse.scroll(x, y) - return + def click(x, y): - mouse.press(Button.left) - return + """ + Clicks the mouse at the given coordinates. + :param x: The x coordinate. + :param y: The y coordinate. 
+ """ + move(x, y) + pyautogui.click() + # Keyboard: def press(key): - keyboard.press(key) - return + """ + Presses the given key. + :param key: The key. + """ + if key in ["Key.shift", "shift"]: + pyautogui.keyDown("shift") + elif key in ["Key.space", "space"]: + pyautogui.keyDown("space") + else: + pyautogui.keyDown(key) + def release(key): - keyboard.release(key) - return + """ + Releases the given key. + :param key: the key name, as a string. + """ + if key in ["Key.shift", "shift"]: + pyautogui.keyUp("shift") + elif key in ["Key.space", "space"]: + pyautogui.keyUp("space") + else: + pyautogui.keyUp(key) diff --git a/get_dataset.py b/get_dataset.py index 0f1ca04..c5360fe 100755 --- a/get_dataset.py +++ b/get_dataset.py @@ -1,47 +1,88 @@ -# Arda Mavi +""" +This file will get the dataset. + +Author: Arda Mavi +""" import os +import pickle + import numpy as np -from keras.utils import to_categorical -from scipy.misc import imread, imresize, imsave +from PIL import Image +from tensorflow.keras.utils import to_categorical +from numpy import size from sklearn.model_selection import train_test_split +from imageio import imread, imsave + def get_img(data_path): - # Getting image array from path: + """ + Getting image array from path and read it. + :param data_path: The image_array path. + :return: the image + """ img = imread(data_path) - img = imresize(img, (150, 150, 3)) return img -def save_img(img, path): - imsave(path, img) - return + +def save_img(path, img): + """ + Saving image to path. + :param path: The path to save the image. + :param img: Which image to save. + """ + imsave(path + '.jpg', img) + + +def before(value, a): + """ + Find first part and slice it. + :param value: The value to slice. + :param a: The value to slice. + :return: value before a + """ + # Find first part and return slice before it. 
+ pos_a = value.find(a) + if pos_a == -1: + return "" + return value[:pos_a] + def get_dataset(dataset_path='Data/Train_Data'): - # Getting all data from data path: - try: - X = np.load('Data/npy_train_data/X.npy') - Y = np.load('Data/npy_train_data/Y.npy') - except: - labels = os.listdir(dataset_path) # Geting labels - X = [] - Y = [] - count_categori = [-1,''] # For encode labels - for label in labels: - datas_path = dataset_path+'/'+label - for data in os.listdir(datas_path): - img = get_img(datas_path+'/'+data) - X.append(img) - # For encode labels: - if data != count_categori[1]: - count_categori[0] += 1 - count_categori[1] = data.split(',') - Y.append(count_categori[0]) - # Create dateset: - X = np.array(X).astype('float32')/255. - Y = np.array(Y).astype('float32') - Y = to_categorical(Y, count_categori[0]+1) - if not os.path.exists('Data/npy_train_data/'): - os.makedirs('Data/npy_train_data/') - np.save('Data/npy_train_data/X.npy', X) - np.save('Data/npy_train_data/Y.npy', Y) - X, X_test, Y, Y_test = train_test_split(X, Y, test_size=0.1, random_state=42) - return X, X_test, Y, Y_test + """ + Getting all data from data path: + :param dataset_path: The path to the dataset. 
+ :return: x_dataset, y_dataset, x_test, y_test + """ + labels = os.listdir(dataset_path) # Getting labels + x_dataset = [] + y_dataset = [] + z_dataset = [] + count_category = [-1, ''] # For encode labels + for label in labels: + datas_path = dataset_path + '/' + label + for data in os.listdir(datas_path): + img = get_img(datas_path + '/' + data) + x_dataset.append(img) + # For encode labels: + current_choice = data.split(',') + del current_choice[4] + if current_choice != count_category[1]: + count_category[0] += 1 + count_category[1] = current_choice + if count_category[1] not in z_dataset: + z_dataset.append(count_category[1]) + y_dataset.append(count_category[0]) + with open('listfile.data', 'wb') as filehandle: + # store the data as binary data stream + pickle.dump(z_dataset, filehandle) + x_dataset = np.array(x_dataset).astype('float32') / 255. + y_dataset = np.array(y_dataset).astype('float32') + # print(y_dataset) + y_dataset = to_categorical(y_dataset, count_category[0] + 1) + # print(y_dataset) + if not os.path.exists('Data/npy_train_data/'): + os.makedirs('Data/npy_train_data/') + np.save('Data/npy_train_data/x_dataset.npy', x_dataset) + np.save('Data/npy_train_data/y_dataset.npy', y_dataset) + x_dataset, x_test, y_dataset, y_test = train_test_split(x_dataset, y_dataset, test_size=0.1) + return x_dataset, x_test, y_dataset, y_test, count_category[0] + 1 diff --git a/get_model.py b/get_model.py index e50271f..b84bea3 100755 --- a/get_model.py +++ b/get_model.py @@ -1,10 +1,19 @@ -# Arda Mavi +""" +This file will get the model from the database and return it to the user. 
+Author: Arda Mavi +""" import os -from keras.models import Model -from keras.optimizers import Adadelta -from keras.layers import Input, Conv2D, Activation, MaxPooling2D, Flatten, Dense, Dropout + +from tensorflow.keras.layers import Input, Conv2D, Activation, MaxPooling2D, Flatten, Dense, Dropout +from tensorflow.keras.models import Model + def save_model(model): + """ + This function will save the model to the database. + :param model: Which model to use. + :return: None + """ if not os.path.exists('Data/Model/'): os.makedirs('Data/Model/') model_json = model.to_json() @@ -13,42 +22,47 @@ def save_model(model): # serialize weights to HDF5 model.save_weights("Data/Model/weights.h5") print('Model and weights saved') - return -def get_model(): +def get_model(action_total): + """ + This function will get the model from the database. + :param action_total: Total number of actions. + :return: model + """ inputs = Input(shape=(150, 150, 3)) - conv_1 = Conv2D(32, (3,3), strides=(1,1))(inputs) - act_1 = Activation('relu')(conv_1) + conv_1 = Conv2D(32, (3, 3), strides=(1, 1))(inputs) + # act_1 = Activation('relu')(conv_1) - conv_2 = Conv2D(64, (3,3), strides=(1,1))(act_1) - act_2 = Activation('relu')(conv_2) + conv_2 = Conv2D(64, (3, 3), strides=(1, 1))(conv_1) + # act_2 = Activation('relu')(conv_2) - conv_3 = Conv2D(64, (3,3), strides=(1,1))(act_2) - act_3 = Activation('relu')(conv_3) + conv_3 = Conv2D(64, (3, 3), strides=(1, 1))(conv_2) + # act_3 = Activation('relu')(conv_3) - pooling_1 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(act_3) + pooling_1 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(conv_3) - conv_4 = Conv2D(128, (3,3), strides=(1,1))(pooling_1) - act_4 = Activation('relu')(conv_4) + conv_4 = Conv2D(128, (3, 3), strides=(1, 1))(pooling_1) + # act_4 = Activation('relu')(conv_4) - pooling_2 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(act_4) + pooling_2 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(conv_4) flat_1 = Flatten()(pooling_2) - fc = 
Dense(1280)(flat_1) + fc = Dense(1024)(flat_1) fc = Activation('relu')(fc) fc = Dropout(0.5)(fc) - fc = Dense(4)(fc) + fc = Dense(action_total)(fc) outputs = Activation('sigmoid')(fc) model = Model(inputs=inputs, outputs=outputs) - model.compile(loss='categorical_crossentropy', optimizer='adadelta', metrics=['accuracy']) + model.compile(loss='binary_crossentropy', optimizer='adadelta', metrics=['accuracy']) return model + if __name__ == '__main__': - save_model(get_model()) + save_model(get_model(10)) diff --git a/predict.py b/predict.py index 1e2f4a4..ffea7bd 100755 --- a/predict.py +++ b/predict.py @@ -1,8 +1,21 @@ -# Arda Mavi -import numpy as np -from scipy.misc import imresize +""" +This file will make a prediction based on the input data. +Author: Arda Mavi +""" +from PIL import Image +import numpy as np + def predict(model, X): - X = imresize(X, (150, 150, 3)).astype('float32')/255. - Y = model.predict(X.reshape(1,150,150,3)) - return Y + """ + This function will make a prediction based on the input data. + :param model: Which model to use. + :param X: input image as an array that PIL's Image.fromarray accepts. + :return: index of the highest-scoring output of the model. + """ + # resize to the 150x150 model input with PIL and scale to [0, 1], as the training data is. + x = np.asarray(Image.fromarray(X).resize((150, 150), + resample=Image.BICUBIC)).astype('float32') / 255. + y = model.predict(x.reshape(1, 150, 150, 3)) + y = y.argmax() + return y diff --git a/requirements.txt b/requirements.txt index 508a7f3..c849cbc 100755 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,10 @@ numpy scikit-learn scikit-image pillow -tensorflow -keras +tensorflow==1.9.0 +keras==2.2.0 +scipy==1.2.3 pynput h5py +pyautogui +imageio diff --git a/train.py b/train.py index f97d886..5f0e6d3 100755 --- a/train.py +++ b/train.py @@ -1,31 +1,60 @@ -# Arda Mavi +""" +This will train the model. 
+Author: Arda Mavi +""" import os -import numpy + +from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard + from get_dataset import get_dataset from get_model import get_model, save_model -from keras.callbacks import ModelCheckpoint, TensorBoard -epochs = 100 -batch_size = 5 +epochs = 15 +batch_size = 32 + -def train_model(model, X, X_test, Y, Y_test): - checkpoints = [] +def train_model(model, x, x_test, y, y_test): + """ + This will train the model. + :param model: Which model to train + :param x: x + :param x_test: x_test + :param y: y + :param y_test: y_test + :return: model + """ if not os.path.exists('Data/Checkpoints/'): os.makedirs('Data/Checkpoints/') - checkpoints.append(ModelCheckpoint('Data/Checkpoints/best_weights.h5', monitor='val_loss', verbose=0, save_best_only=True, save_weights_only=True, mode='auto', period=1)) - checkpoints.append(TensorBoard(log_dir='Data/Checkpoints/./logs', histogram_freq=0, write_graph=True, write_images=False, embeddings_freq=0, embeddings_layer_names=None, embeddings_metadata=None)) + checkpoints = [ModelCheckpoint( + 'Data/Checkpoints/best_weights.h5', + monitor='val_loss', + verbose=0, + save_best_only=True, + save_weights_only=True, + mode='auto', + period=1, + ), TensorBoard(log_dir='Data/Checkpoints/./logs', histogram_freq=0, write_graph=True, write_images=False, + embeddings_freq=0, embeddings_layer_names=None, embeddings_metadata=None)] - model.fit(X, Y, batch_size=batch_size, epochs=epochs, validation_data=(X_test, Y_test), shuffle=True, callbacks=checkpoints) + model.fit(x, y, batch_size=batch_size, epochs=epochs, validation_data=(x_test, y_test), shuffle=True, + callbacks=checkpoints) return model + def main(): - X, X_test, Y, Y_test = get_dataset() - model = get_model() - model = train_model(model, X, X_test, Y, Y_test) + """ + The main function. 
+ :return: model + """ + x, x_test, y, y_test, action_total = get_dataset() + print(action_total) + model = get_model(action_total) + model = train_model(model, x, x_test, y, y_test) save_model(model) return model + if __name__ == '__main__': main()