From 99ff493f18dadc8398cbb7dd8e81141b3422b700 Mon Sep 17 00:00:00 2001 From: rafittu Date: Thu, 17 Oct 2024 09:16:28 -0300 Subject: [PATCH 1/8] build: add new app requirements --- requirements.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements.txt b/requirements.txt index 8fdca55..02d6e7e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,5 @@ Flask==2.2.5 Flask-SQLAlchemy==2.5.1 SQLAlchemy==1.4.22 psycopg2-binary==2.9.10 +numpy==1.24.1 +tensorflow==2.17.0 From c4a65884ee6a3cb5bea64746096b2609ed3634fe Mon Sep 17 00:00:00 2001 From: rafittu Date: Thu, 17 Oct 2024 10:16:42 -0300 Subject: [PATCH 2/8] chore: optimize base image and install build dependencies --- Dockerfile | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index e594259..e300033 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,7 +3,9 @@ FROM python:3.11-slim WORKDIR /app COPY requirements.txt . -RUN pip install --no-cache-dir -r requirements.txt + +RUN apt-get update && apt-get install -y build-essential \ + && pip install --no-cache-dir -r requirements.txt COPY ./src ./src From 3077c0b13eafbc8bb3665460cea532b35a8edca1 Mon Sep 17 00:00:00 2001 From: rafittu Date: Thu, 17 Oct 2024 13:53:31 -0300 Subject: [PATCH 3/8] feat: initialize lstm model --- src/models/__init__.py | 0 src/models/lstm_color_predictor.py | 13 +++++++++++++ 2 files changed, 13 insertions(+) create mode 100644 src/models/__init__.py create mode 100644 src/models/lstm_color_predictor.py diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/models/lstm_color_predictor.py b/src/models/lstm_color_predictor.py new file mode 100644 index 0000000..e2ac750 --- /dev/null +++ b/src/models/lstm_color_predictor.py @@ -0,0 +1,13 @@ +from tensorflow.keras.models import Sequential +from tensorflow.keras.layers import LSTM, Dense +from tensorflow.keras.optimizers import Adam + +def create_model(input_shape): + model = Sequential() + model.add(LSTM(50, activation='relu', input_shape=input_shape)) + model.add(Dense(3, activation='sigmoid')) + model.compile(optimizer=Adam(learning_rate=0.001), loss='mse') + return model + + +model = create_model((None, 3)) From 977b2eef5df74e4a4200df8607754609e483af48 Mon Sep 17 00:00:00 2001 From: rafittu Date: Thu, 17 Oct 2024 13:57:11 -0300 Subject: [PATCH 4/8] feat: function to get last N colors --- src/models/lstm_color_predictor.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/models/lstm_color_predictor.py b/src/models/lstm_color_predictor.py index e2ac750..6fbfc7a 100644 --- a/src/models/lstm_color_predictor.py +++ b/src/models/lstm_color_predictor.py @@ -1,6 +1,8 @@ +import numpy as np from tensorflow.keras.models import Sequential from tensorflow.keras.layers import LSTM, Dense from tensorflow.keras.optimizers import Adam +from database.models.color import Color def create_model(input_shape): model = Sequential() @@ -11,3 +13,10 @@ def create_model(input_shape): model = create_model((None, 3)) + + +def get_last_n_colors(n): + colors = Color.query.order_by(Color.timestamp.desc()).limit(n).all() + colors.reverse() + data = np.array([[c.red, c.green, c.blue] for c in colors]) / 255.0 + return data From 41a0ddd51162b72340fc7c1c67d54b9a96ddeafe Mon Sep 17 00:00:00 2001 From: rafittu Date: Thu, 17 Oct 2024 14:09:30 -0300 Subject: [PATCH 5/8] feat: function to train model --- src/models/lstm_color_predictor.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/models/lstm_color_predictor.py b/src/models/lstm_color_predictor.py index 6fbfc7a..e41a715 100644 --- a/src/models/lstm_color_predictor.py +++ b/src/models/lstm_color_predictor.py @@ -20,3 +20,12 @@ def get_last_n_colors(n): colors.reverse() data = np.array([[c.red, c.green, c.blue] for c in colors]) / 255.0 return data + + +def train_model(n): + data = get_last_n_colors(n) + if len(data) > 1: + X, y = data[:-1], data[1:] + X = X.reshape((1, X.shape[0], X.shape[1])) + y = y.reshape((1, y.shape[0], y.shape[1])) + model.fit(X, y, epochs=1, verbose=0) From 271debe588c1f795cfac54869d352bc2f9cc06d5 Mon Sep 17 00:00:00 2001 From: rafittu Date: Thu, 17 Oct 2024 14:30:43 -0300 Subject: [PATCH 6/8] feat: function to predict next colors --- src/models/lstm_color_predictor.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/models/lstm_color_predictor.py b/src/models/lstm_color_predictor.py index e41a715..3fc19cc 100644 --- a/src/models/lstm_color_predictor.py +++ b/src/models/lstm_color_predictor.py @@ -29,3 +29,10 @@ def train_model(n): X = X.reshape((1, X.shape[0], X.shape[1])) y = y.reshape((1, y.shape[0], y.shape[1])) model.fit(X, y, epochs=1, verbose=0) + + +def predict_next_color(): + last_colors = get_last_n_colors(1) + if last_colors.size > 0: + prediction = model.predict(last_colors.reshape((1, 1, 3))) + return (prediction[0][0] * 255).astype(int) From 77a905c8b175d83a40edbd9353a853d9d2ff290c Mon Sep 17 00:00:00 2001 From: rafittu Date: Thu, 17 Oct 2024 15:30:13 -0300 Subject: [PATCH 7/8] feat: train model and predict next color --- src/api/color_api.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/api/color_api.py b/src/api/color_api.py index 42bf6d1..0adefe7 100644 --- a/src/api/color_api.py +++ b/src/api/color_api.py @@ -4,11 +4,14 @@ from flask import Blueprint from database.models.color import Color from database.database import db +from models.lstm_color_predictor import train_model, predict_next_color color_api_bp = Blueprint('color_api', __name__) socketio = SocketIO(cors_allowed_origins="*") logger = logging.getLogger(__name__) +color_buffer = [] + COLOR_MAP = { "rgb(0, 0, 255)": "blue", "rgb(0, 255, 255)": "blue-green", @@ -55,5 +58,18 @@ def handle_receive_color(data): ) db.session.add(new_color) db.session.commit() + + color_buffer.append(new_color) + + if len(color_buffer) == 9: + predicted_color = predict_next_color() + + logger.info(f"Predicted next color (RGB): {predicted_color}") + print(f"Predicted next color (RGB): {predicted_color}") + + train_model(n=10) + + color_buffer.clear() + except ValueError as e: logger.error(f"Error processing color: {e}") From 68cce660dbf1ba0a935b43deb2bffc403b4a72fa Mon Sep 17 00:00:00 2001 From: rafittu Date: Thu, 17 Oct 2024 15:50:40 -0300 Subject: [PATCH 8/8] feat: include color_name in prediction return --- src/api/color_api.py | 16 +++++++++++++--- src/models/lstm_color_predictor.py | 3 ++- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/api/color_api.py b/src/api/color_api.py index 0adefe7..d156785 100644 --- a/src/api/color_api.py +++ b/src/api/color_api.py @@ -40,6 +40,10 @@ def get_color_name(color_str): return COLOR_MAP.get(color_str, "unknown") +def map_rgb_to_color_name(rgb_value): + closest_color = min(COLOR_MAP.keys(), key=lambda x: sum(abs(int(c) - v) for c, v in zip(x[4:-1].split(', '), rgb_value))) + return COLOR_MAP[closest_color] + @socketio.on('send_color') def handle_receive_color(data): color_str = data.get('color') @@ -61,15 +65,21 @@ def handle_receive_color(data): color_buffer.append(new_color) + print(f"Buffer len: {len(color_buffer)}") + if len(color_buffer) == 9: - predicted_color = predict_next_color() + predicted_rgb = predict_next_color() + predicted_color_name = map_rgb_to_color_name(predicted_rgb) - logger.info(f"Predicted next color (RGB): {predicted_color}") - print(f"Predicted next color (RGB): {predicted_color}") + logger.info(f"Predicted next color (RGB): {predicted_color_name} {predicted_rgb}") + print(f"Predicted next color (RGB): {predicted_color_name} {predicted_rgb}") train_model(n=10) color_buffer.clear() + if len(color_buffer) == 10: + color_buffer.clear() + except ValueError as e: logger.error(f"Error processing color: {e}") diff --git a/src/models/lstm_color_predictor.py b/src/models/lstm_color_predictor.py index 3fc19cc..bae389a 100644 --- a/src/models/lstm_color_predictor.py +++ b/src/models/lstm_color_predictor.py @@ -35,4 +35,5 @@ def predict_next_color(): last_colors = get_last_n_colors(1) if last_colors.size > 0: prediction = model.predict(last_colors.reshape((1, 1, 3))) - return (prediction[0][0] * 255).astype(int) + rgb_values = (prediction[0] * 255).astype(int).flatten() + return tuple(rgb_values)