diff --git a/Dockerfile b/Dockerfile index e594259..e300033 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,7 +3,9 @@ FROM python:3.11-slim WORKDIR /app COPY requirements.txt . -RUN pip install --no-cache-dir -r requirements.txt + +RUN apt-get update && apt-get install -y build-essential \ + && pip install --no-cache-dir -r requirements.txt COPY ./src ./src diff --git a/requirements.txt b/requirements.txt index 8fdca55..02d6e7e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,5 @@ Flask==2.2.5 Flask-SQLAlchemy==2.5.1 SQLAlchemy==1.4.22 psycopg2-binary==2.9.10 +numpy==1.24.1 +tensorflow==2.17.0 diff --git a/src/api/color_api.py b/src/api/color_api.py index 42bf6d1..d156785 100644 --- a/src/api/color_api.py +++ b/src/api/color_api.py @@ -4,11 +4,14 @@ from flask import Blueprint from database.models.color import Color from database.database import db +from models.lstm_color_predictor import train_model, predict_next_color color_api_bp = Blueprint('color_api', __name__) socketio = SocketIO(cors_allowed_origins="*") logger = logging.getLogger(__name__) +color_buffer = [] + COLOR_MAP = { "rgb(0, 0, 255)": "blue", "rgb(0, 255, 255)": "blue-green", @@ -37,6 +40,10 @@ def get_color_name(color_str): return COLOR_MAP.get(color_str, "unknown") +def map_rgb_to_color_name(rgb_value): + closest_color = min(COLOR_MAP.keys(), key=lambda x: sum(abs(int(c) - v) for c, v in zip(x[4:-1].split(', '), rgb_value))) + return COLOR_MAP[closest_color] + @socketio.on('send_color') def handle_receive_color(data): color_str = data.get('color') @@ -55,5 +62,24 @@ def handle_receive_color(data): ) db.session.add(new_color) db.session.commit() + + color_buffer.append(new_color) + + print(f"Buffer len: {len(color_buffer)}") + + if len(color_buffer) == 9: + predicted_rgb = predict_next_color() + predicted_color_name = map_rgb_to_color_name(predicted_rgb) + + logger.info(f"Predicted next color (RGB): {predicted_color_name} {predicted_rgb}") + print(f"Predicted next color (RGB): {predicted_color_name} {predicted_rgb}") + + train_model(n=10) + + color_buffer.clear() + + if len(color_buffer) == 10: + color_buffer.clear() + except ValueError as e: logger.error(f"Error processing color: {e}") diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/models/lstm_color_predictor.py b/src/models/lstm_color_predictor.py new file mode 100644 index 0000000..bae389a --- /dev/null +++ b/src/models/lstm_color_predictor.py @@ -0,0 +1,39 @@ +import numpy as np +from tensorflow.keras.models import Sequential +from tensorflow.keras.layers import LSTM, Dense +from tensorflow.keras.optimizers import Adam +from database.models.color import Color + +def create_model(input_shape): + model = Sequential() + model.add(LSTM(50, activation='relu', input_shape=input_shape)) + model.add(Dense(3, activation='sigmoid')) + model.compile(optimizer=Adam(learning_rate=0.001), loss='mse') + return model + + +model = create_model((None, 3)) + + +def get_last_n_colors(n): + colors = Color.query.order_by(Color.timestamp.desc()).limit(n).all() + colors.reverse() + data = np.array([[c.red, c.green, c.blue] for c in colors]) / 255.0 + return data + + +def train_model(n): + data = get_last_n_colors(n) + if len(data) > 1: + X, y = data[:-1], data[1:] + X = X.reshape((1, X.shape[0], X.shape[1])) + y = y.reshape((1, y.shape[0], y.shape[1])) + model.fit(X, y, epochs=1, verbose=0) + + +def predict_next_color(): + last_colors = get_last_n_colors(1) + if last_colors.size > 0: + prediction = model.predict(last_colors.reshape((1, 1, 3))) + rgb_values = (prediction[0] * 255).astype(int).flatten() + return tuple(rgb_values)