diff --git a/backend/Cargo.lock b/backend/Cargo.lock index de35fc5..e6df04c 100644 --- a/backend/Cargo.lock +++ b/backend/Cargo.lock @@ -117,6 +117,7 @@ dependencies = [ "argon2 0.4.1", "axum", "chrono", + "csv", "dotenvy", "env_logger", "log", @@ -430,6 +431,27 @@ dependencies = [ "typenum", ] +[[package]] +name = "csv" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52cd9d68cf7efc6ddfaaee42e7288d3a99d613d4b50f76ce9827ae0c6e14f938" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde_core", +] + +[[package]] +name = "csv-core" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "704a3c26996a80471189265814dbc2c257598b96b8a7feae2d31ace646bb9782" +dependencies = [ + "memchr", +] + [[package]] name = "data-encoding" version = "2.9.0" @@ -1850,6 +1872,7 @@ version = "0.1.0" dependencies = [ "argon2 0.5.3", "chrono", + "csv", "dotenvy", "futures-util", "log", @@ -2667,6 +2690,7 @@ dependencies = [ name = "websocket-server" version = "0.1.0" dependencies = [ + "csv", "dotenvy", "env_logger", "futures-util", diff --git a/backend/api-server/Cargo.toml b/backend/api-server/Cargo.toml index 0bdef5e..f505957 100644 --- a/backend/api-server/Cargo.toml +++ b/backend/api-server/Cargo.toml @@ -33,4 +33,7 @@ pyo3 = { version = "0.18.0", features = ["auto-initialize"] } env_logger = "0.11" # Shared logic crate -shared-logic = { path = "../shared-logic" } \ No newline at end of file +shared-logic = { path = "../shared-logic" } + +# CSV serialization/deserialization +csv = "1.4" \ No newline at end of file diff --git a/backend/api-server/src/main.rs b/backend/api-server/src/main.rs index 8e98605..d2a9d0b 100644 --- a/backend/api-server/src/main.rs +++ b/backend/api-server/src/main.rs @@ -5,6 +5,7 @@ use axum::{ routing::{get, post}, Json, Router, + body::Bytes, }; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; @@ -17,10 +18,13 @@ use pyo3::Python; use 
pyo3::types::{PyList, PyModule, PyTuple}; use pyo3::PyResult; use pyo3::{IntoPy, ToPyObject}; +use chrono::{DateTime, Utc}; +use axum::http::{HeaderMap, HeaderValue, header}; +use axum::response::IntoResponse; use rand_core::OsRng; // shared logic library -use shared_logic::db::{initialize_connection, DbClient}; +use shared_logic::db::{DbClient, get_eeg_time_range, initialize_connection, export_eeg_data_as_csv}; use shared_logic::models::{User, NewUser, UpdateUser, Session, FrontendState}; // Argon2 imports @@ -36,6 +40,21 @@ struct AppState { db_client: DbClient, } +// define request struct for exporting EEG data +#[derive(Deserialize)] +struct ExportEEGRequest { + filename: String, + options: ExportOptions, +} + +#[derive(Deserialize)] +struct ExportOptions { + format: String, + #[serde(rename = "includeHeader")] // frontend sends camelCase; keep the wire name, snake_case in Rust + include_header: bool, + start_time: Option<DateTime<Utc>>, + end_time: Option<DateTime<Utc>>, +} + #[derive(Debug, Clone, Deserialize)] pub struct LoginRequest { @@ -247,6 +266,68 @@ async fn get_frontend_state( } } +// Handler for POST /api/sessions/{session_id}/eeg_data/export +async fn export_eeg_data( + State(app_state): State<AppState>, + Path(session_id): Path<i32>, + Json(request): Json<ExportEEGRequest>, +) -> Result<impl IntoResponse, (StatusCode, String)> { + info!("Received request to export EEG data for session {}", session_id); + + // right now the only export format supported is CSV, so we just check for that + if request.options.format.to_lowercase() != "csv" { + return Err((StatusCode::BAD_REQUEST, format!("Unsupported export format: {}", request.options.format))); + } + + // pass the optional bounds down; the db layer fills in defaults (earliest row / now) + let (start_time, end_time) = get_eeg_time_range(&app_state.db_client, session_id, request.options.start_time, request.options.end_time) + .await.map_err(|e| { + error!("Failed to get EEG time range: {}", e); + (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to get EEG time range: {}", e)) + })?; + + let header_included = request.options.include_header; + + // finally call the export function in db.rs + let return_csv = match export_eeg_data_as_csv(&app_state.db_client, session_id, start_time, end_time, header_included).await { 
Ok(csv_data) => csv_data, + Err(e) => { + error!("Failed to export EEG data: {}", e); + return Err((StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to export EEG data: {}", e))); + } + }; + + // small safety: avoid quotes breaking header + let filename = request.filename.replace('"', ""); + + let mut headers = HeaderMap::new(); + headers.insert(header::CONTENT_TYPE, HeaderValue::from_static("text/csv; charset=utf-8")); + + let content_disp = format!("attachment; filename=\"{}\"", filename); + headers.insert( + header::CONTENT_DISPOSITION, + HeaderValue::from_str(&content_disp).map_err(|e| { + (StatusCode::BAD_REQUEST, format!("Invalid filename for header: {}", e)) + })?, + ); + + // return CSV directly as the body + Ok((headers, return_csv)) + +} + +// Handler for POST /api/sessions/{session_id}/eeg_data/import +async fn import_eeg_data( + State(app_state): State, + Path(session_id): Path, + // we expect the CSV data to be sent as raw text in the body of the request + body: Bytes, +) -> Result, (StatusCode, String)> { + shared_logic::db::import_eeg_data_from_csv(&app_state.db_client, session_id, &body) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to import EEG data: {}", e)))?; + + Ok(Json(json!({"status": "success"}))) +} @@ -353,6 +434,9 @@ async fn main() { .route("/api/sessions/:session_id/frontend-state", post(set_frontend_state)) .route("/api/sessions/:session_id/frontend-state", get(get_frontend_state)) + .route("/api/sessions/:session_id/eeg_data/export", post(export_eeg_data)) + .route("/api/sessions/:session_id/eeg_data/import", post(import_eeg_data)) + // Share application state with all handlers .with_state(app_state); diff --git a/backend/migrations/20263101120000_sessions_on_eeg_data.sql b/backend/migrations/20263101120000_sessions_on_eeg_data.sql new file mode 100644 index 0000000..95083df --- /dev/null +++ b/backend/migrations/20263101120000_sessions_on_eeg_data.sql @@ -0,0 +1,8 @@ +-- note that this assumes that 
eeg_data has no rows (since as of this migration there should be no real data yet) +-- to do so just run TRUNCATE TABLE eeg_data before applying this migration +ALTER TABLE eeg_data +ADD COLUMN session_id INTEGER NOT NULL, +ADD CONSTRAINT fk_session FOREIGN KEY (session_id) REFERENCES sessions(id) ON DELETE CASCADE; + +-- unique index on (session_id, time): serves our main filter pattern AND supplies the unique constraint that the ON CONFLICT (session_id, time) clause in insert_batch_eeg requires (a plain index would make that INSERT fail at runtime) +CREATE UNIQUE INDEX eeg_data_session_time_idx ON eeg_data (session_id, time DESC); -- using DESC since i'm expecting recent data to be more relevant \ No newline at end of file diff --git a/backend/shared-logic/Cargo.toml b/backend/shared-logic/Cargo.toml index 6003901..4acf866 100644 --- a/backend/shared-logic/Cargo.toml +++ b/backend/shared-logic/Cargo.toml @@ -48,4 +48,7 @@ rand_core = "0.6" # working with python pyo3 = { version = "0.18.0", features = ["auto-initialize"] } -numpy = "0.18" \ No newline at end of file +numpy = "0.18" + +# CSV serialization/deserialization +csv = "1.4" \ No newline at end of file diff --git a/backend/shared-logic/src/bc.rs b/backend/shared-logic/src/bc.rs index 6b2baa2..fbc4273 100644 --- a/backend/shared-logic/src/bc.rs +++ b/backend/shared-logic/src/bc.rs @@ -24,6 +24,7 @@ pub async fn start_broadcast( write: Arc, Message>>>, cancel_token: CancellationToken, processing_config: ProcessingConfig, // takes in signal processing configuration from frontend + session_id: i32, // takes in session id to tag incoming data with the correct session windowing_rx: watch::Receiver // takes in windowing configuration from frontend ) { let (tx, _rx) = broadcast::channel::>(1000); // size of the broadcast buffer, not recommand below 500, websocket will miss messages @@ -55,7 +56,7 @@ pub async fn start_broadcast( // Subscribe for database Receiver tokio::spawn(async move { - db_receiver( rx_db).await; + db_receiver( rx_db, session_id).await; }); //waits for sender to complete. 
@@ -111,7 +112,10 @@ pub async fn ws_receiver(write: &Arc, //db_broadcast_receiver takes EEGDataPacket struct from the broadcast sender and inserts it into the database // it inserts as a batch of 100. -pub async fn db_receiver(mut rx_db: Receiver>){ +pub async fn db_receiver( + mut rx_db: Receiver>, + session_id: i32, +){ let db_client = get_db_client(); let mut packet_count = 0; // for debug purposes @@ -131,7 +135,7 @@ pub async fn db_receiver(mut rx_db: Receiver>){ // Insert the packet directly tokio::spawn(async move { let now = Instant::now(); // for debug purposes - if let Err(e) = insert_batch_eeg(&db_client_clone, &eeg_packet).await { + if let Err(e) = insert_batch_eeg(&db_client_clone, session_id, &eeg_packet).await { error!("Packet insert failed: {:?}", e); } info!("Packet insert took {:?}", now.elapsed()); // for debug purposes diff --git a/backend/shared-logic/src/db.rs b/backend/shared-logic/src/db.rs index 8ffafa7..1ff3825 100644 --- a/backend/shared-logic/src/db.rs +++ b/backend/shared-logic/src/db.rs @@ -5,12 +5,13 @@ use sqlx::{ }; use tokio::time::{self, Duration}; use log::{info, error, warn}; -use chrono::{DateTime, Utc}; +use chrono::{Date, DateTime, Utc}; use dotenvy::dotenv; use super::models::{User, NewUser, TimeSeriesData, UpdateUser, Session, FrontendState}; use crate::{lsl::EEGDataPacket}; use once_cell::sync::OnceCell; use std::sync::Arc; +use serde::{Serialize, Deserialize}; use argon2::password_hash::SaltString; use rand_core::OsRng; use argon2::{Argon2, password_hash::{PasswordHasher, PasswordHash, PasswordVerifier}}; @@ -21,6 +22,16 @@ pub static DB_POOL: OnceCell> = OnceCell::new(); pub type DbClient = Arc; +// struct for EEG rows to convert to CSV +#[derive(serde::Serialize)] +struct EEGCsvRow { + time: String, + channel1: i32, + channel2: i32, + channel3: i32, + channel4: i32, +} + pub async fn initialize_connection() -> Result { dotenv().ok(); let database_url = std::env::var("DATABASE_URL") @@ -135,7 +146,7 @@ pub async fn 
get_testtime_series_data(client: &DbClient) -> Result -pub async fn insert_batch_eeg(client: &DbClient, packet: &EEGDataPacket) -> Result<(), sqlx::Error> { +pub async fn insert_batch_eeg(client: &DbClient, session_id: i32, packet: &EEGDataPacket) -> Result<(), sqlx::Error> { let n_samples = packet.timestamps.len(); @@ -147,7 +158,7 @@ pub async fn insert_batch_eeg(client: &DbClient, packet: &EEGDataPacket) -> Resu // Construct a single SQL insert statement let mut query_builder = sqlx::QueryBuilder::new( - "INSERT INTO eeg_data (time, channel1, channel2, channel3, channel4) " + "INSERT INTO eeg_data (session_id, time, channel1, channel2, channel3, channel4) " ); // Iterate through all data in the packet, pairing timestamp to the signal, and insert them @@ -155,6 +166,7 @@ pub async fn insert_batch_eeg(client: &DbClient, packet: &EEGDataPacket) -> Resu query_builder.push_values( (0..n_samples).map(|sample_idx| { ( + session_id, &packet.timestamps[sample_idx], packet.signals[0][sample_idx], // Channel 0 packet.signals[1][sample_idx], // Channel 1 @@ -162,8 +174,9 @@ pub async fn insert_batch_eeg(client: &DbClient, packet: &EEGDataPacket) -> Resu packet.signals[3][sample_idx], // Channel 3 ) }), - |mut b, (timestamp, ch0, ch1, ch2, ch3)| { - b.push_bind(timestamp) + |mut b, (session_id, timestamp, ch0, ch1, ch2, ch3)| { + b.push_bind(session_id) + .push_bind(timestamp) .push_bind(ch0) .push_bind(ch1) .push_bind(ch2) @@ -171,6 +184,7 @@ pub async fn insert_batch_eeg(client: &DbClient, packet: &EEGDataPacket) -> Resu } ); + query_builder.push(" ON CONFLICT (session_id, time) DO NOTHING"); query_builder.build().execute(&**client).await?; info!("EEG packet inserted successfully - {} data", packet.timestamps.len()); Ok(()) @@ -404,4 +418,170 @@ pub async fn get_frontend_state(client: &DbClient, session_id: i32) -> Result +/// Export EEG data for a session between start_time and end_time as a CSV string. +pub async fn export_eeg_data_as_csv(client: &DbClient, session_id: i32, start_time: DateTime<Utc>, end_time: DateTime<Utc>, include_header: bool) -> Result<String, Error> { + info!("Exporting EEG data for session id {} from {} to {}", session_id, start_time, end_time); + + // get the data from the database + let data = sqlx::query!( + "SELECT 
time, channel1, channel2, channel3, channel4 FROM eeg_data + WHERE session_id = $1 AND time >= $2 AND time <= $3 + ORDER BY time ASC", + session_id, + start_time, + end_time + ) + .fetch_all(&**client) + .await?; + + // build the CSV using the csv crate + + let mut writer = csv::WriterBuilder::new() + .has_headers(false) + .from_writer(vec![]); + + // write the header based on include_header flag + if include_header { + writer.write_record(&["time", "channel1", "channel2", "channel3", "channel4"]) + .map_err(|e| Error::Protocol(e.to_string()))?; + } + + // now, iterate through the data and write each row + for row in data { + writer.serialize(EEGCsvRow { + time: row.time.to_rfc3339(), + channel1: row.channel1, + channel2: row.channel2, + channel3: row.channel3, + channel4: row.channel4, + }) + .map_err(|e| Error::Protocol(e.to_string()))?; + } + + let byte_stream = writer.into_inner() + .map_err(|e| Error::Protocol(e.to_string()))?; + + // now, we convert the CSV data to a string and return it + let csv_data = String::from_utf8(byte_stream) + .map_err(|e| Error::Protocol(e.to_string()))?; + + Ok(csv_data) +} + +/// Import EEG data from a CSV byte stream for a given session ID. The CSV is expected +/// to have columns: "time", "channel1", "channel2", "channel3", "channel4". +/// +/// Returns Ok(()) on success. 
+pub async fn import_eeg_data_from_csv(client: &DbClient, session_id: i32, csv_bytes: &[u8]) -> Result<(), Error> { + info!("Importing EEG data for session id {} from CSV", session_id); + + // we use the csv crate to read the CSV data, converting them to the struct we made for CSV rows + let mut reader = csv::ReaderBuilder::new() + .has_headers(true) // we expect the CSV to have headers, should probably make this clear somewhere + .from_reader(csv_bytes); + + // set up our vectors to hold the parsed EEG data rows, so we can batch insert them later + let mut timestamps: Vec<DateTime<Utc>> = Vec::new(); + let mut channel1_data: Vec<i32> = Vec::new(); + let mut channel2_data: Vec<i32> = Vec::new(); + let mut channel3_data: Vec<i32> = Vec::new(); + let mut channel4_data: Vec<i32> = Vec::new(); + + + // we iterate through the CSV records, parsing each row and converting it to the format we need for insertion + for result in reader.records() { + // unwrap the record, if there's an error we return it + let record = result.map_err(|e| Error::Protocol(e.to_string()))?; + + // now we parse the fields, converting time to DateTime and channels to i32 + let time_str = record.get(0).ok_or_else(|| Error::Protocol("Missing time field".to_string()))?; + + // we assume the time is in RFC3339 format + let time = DateTime::parse_from_rfc3339(time_str) + .map_err(|e| Error::Protocol(format!("Invalid time format: {}", e)))? + .with_timezone(&Utc); + + let channel1 = record.get(1) + .ok_or_else(|| Error::Protocol("Missing channel1 field".to_string()))? + .parse::<i32>() + .map_err(|e| Error::Protocol(format!("Invalid channel1 value: {}", e)))?; + let channel2 = record.get(2) + .ok_or_else(|| Error::Protocol("Missing channel2 field".to_string()))? + .parse::<i32>() + .map_err(|e| Error::Protocol(format!("Invalid channel2 value: {}", e)))?; + let channel3 = record.get(3) + .ok_or_else(|| Error::Protocol("Missing channel3 field".to_string()))? 
+ .parse::<i32>() + .map_err(|e| Error::Protocol(format!("Invalid channel3 value: {}", e)))?; + let channel4 = record.get(4) + .ok_or_else(|| Error::Protocol("Missing channel4 field".to_string()))? + .parse::<i32>() + .map_err(|e| Error::Protocol(format!("Invalid channel4 value: {}", e)))?; + + // now we construct the tuple and add it to our vectors + timestamps.push(time); + channel1_data.push(channel1); + channel2_data.push(channel2); + channel3_data.push(channel3); + channel4_data.push(channel4); + } + + // now we use our existing batch insert function to insert the data into the database + let eeg_rows = EEGDataPacket { + timestamps, + signals: vec![channel1_data, channel2_data, channel3_data, channel4_data], + }; + + insert_batch_eeg(client, session_id, &eeg_rows).await?; + + Ok(()) +} + +/// Helper function for eeg data to find the earliest timestamp for a given session +/// +/// Returns the earliest timestamp on success. +pub async fn get_earliest_eeg_timestamp(client: &DbClient, session_id: i32) -> Result<Option<DateTime<Utc>>, Error> { + let row = sqlx::query!( + "SELECT MIN(time) as earliest_time FROM eeg_data WHERE session_id = $1", + session_id + ) + .fetch_one(&**client) + .await?; + + Ok(row.earliest_time) +} + +/// Helper function for eeg data to get the start and end timestamps for a given session +/// +/// Returns the start and end timestamps on success. 
+pub async fn get_eeg_time_range(client: &DbClient, session_id: i32, start_time: Option<DateTime<Utc>>, end_time: Option<DateTime<Utc>>) -> Result<(DateTime<Utc>, DateTime<Utc>), Error> { + // check for time range, else use defaults + // for end time, we default to the current time + // for start time, we default to the earliest timestamp for the session + let end_time = end_time.unwrap_or_else(Utc::now); + + let start_time = match start_time { + Some(t) => t, + None => { + // we call the helper function above to get the earliest timestamp; + // a session with no rows surfaces as RowNotFound so the HTTP layer can answer 404 + get_earliest_eeg_timestamp(client, session_id).await?.ok_or(Error::RowNotFound)? + } + }; + + if start_time > end_time { + return Err(Error::Protocol("start_time cannot be after end_time".to_string())); + } + + Ok((start_time, end_time)) } \ No newline at end of file diff --git a/backend/websocket-server/Cargo.toml b/backend/websocket-server/Cargo.toml index bad7c5e..f6452bd 100644 --- a/backend/websocket-server/Cargo.toml +++ b/backend/websocket-server/Cargo.toml @@ -14,5 +14,6 @@ tokio-util = "0.7.15" # Shared logic crate shared-logic = { path = "../shared-logic" } -serde = "1.0.228" -serde_json = "1" \ No newline at end of file +serde = { version = "1", features = ["derive"] } +serde_json = "1" +csv = "1.4" \ No newline at end of file diff --git a/backend/websocket-server/src/main.rs b/backend/websocket-server/src/main.rs index 76042ec..4f846d8 100644 --- a/backend/websocket-server/src/main.rs +++ b/backend/websocket-server/src/main.rs @@ -16,10 +16,15 @@ use shared_logic::db::{initialize_connection}; use shared_logic::lsl::{ProcessingConfig, WindowingConfig}; // get ProcessingConfig from lsl.rs use dotenvy::dotenv; use log::{info, error}; -use serde_json; // used to parse ProcessingConfig 
from JSON sent by frontend -use serde_json::Value; // used to parse ProcessingConfig from JSON sent by frontend +use serde::Deserialize; // derive macro lives in serde, not serde_json +use serde_json::Value; // used to parse ProcessingConfig from JSON sent by frontend +#[derive(Deserialize)] +struct WebSocketInitMessage { + session_id: i32, + processing_config: ProcessingConfig, +} + #[tokio::main] async fn main() { env_logger::init(); @@ -87,6 +92,42 @@ async fn handle_connection(ws_stream: WebSocketStream) { let mut processing_config = ProcessingConfig::default(); let mut initial_windowing = WindowingConfig::default(); + // we have the WebSocketInitMessage struct, with a session id and processing config + // check if we received a message (some unwrapping needed) + let init_message: WebSocketInitMessage = match signal_config { + Some(Ok(msg)) => { + let text = match msg.to_text() { + Ok(t) => t, + Err(e) => { + error!("Failed to convert init message to text: {}", e); + return; + } + }; + + match serde_json::from_str::<WebSocketInitMessage>(text) { + Ok(init_msg) => init_msg, + Err(e) => { + error!("Failed to parse init message JSON: {}", e); + return; + } + } + } + + Some(Err(e)) => { + error!("Error receiving initialization message: {}", e); + return; + } + + None => { + error!("No initialization message received from client. Closing connection."); + return; + } + }; + + + let session_id = init_message.session_id; + processing_config = init_message.processing_config; 
+ // Give the frontend a short window to send configs before we start // Use a timeout so we don't block forever if only one config arrives let config_timeout = tokio::time::Duration::from_millis(500); @@ -132,14 +173,18 @@ async fn handle_connection(ws_stream: WebSocketStream) { } } + info!("Starting broadcast with chunk={}, overlap={}", initial_windowing.chunk_size, initial_windowing.overlap_size); let (windowing_tx, windowing_rx) = watch::channel(initial_windowing); + // spawns the broadcast task let mut broadcast = Some(tokio::spawn(async move { - start_broadcast(write_clone, cancel_clone, processing_config, windowing_rx).await; + // pass ProcessingConfig, WindowingConfig receiver, and session_id into broadcast + start_broadcast(write_clone, cancel_clone, processing_config, session_id, windowing_rx).await; })); + while let Some(msg) = read.next().await { match msg { Ok(msg) if msg.is_text() => {