diff --git a/server/Cargo.toml b/server/Cargo.toml index 26ee5d5..f119ad1 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2024" [dependencies] -axum = { version = "0.7", features = ["ws"] } +axum = { version = "0.8.8", features = ["ws"] } tokio = { version = "1", features = ["full"] } bollard = "0.17" # The Docker Client futures = "0.3" @@ -15,11 +15,11 @@ tracing = "0.1" tracing-subscriber = "0.3" sqlx = { version = "0.7", features = ["runtime-tokio-rustls", "postgres", "uuid", "time"] } uuid = { version = "1", features = ["serde", "v4"] } -oauth2 = "4.4" -tower-sessions = "0.10" -axum-extra = { version = "0.9", features = ["cookie"] } +oauth2 = "5.0" +tower-sessions = "0.15" +axum-extra = { version = "0.12", features = ["cookie"] } dotenvy = "0.15" # To load client secrets -reqwest = { version = "0.11", features = ["json", "rustls-tls"] } +reqwest = { version = "0.12", features = ["json", "rustls-tls"] } time = "0.3.46" openssl = { version = "0.10", features = ["vendored"] } openssl-sys = { version = "0.9", features = ["vendored"] } diff --git a/server/src/auth.rs b/server/src/auth.rs deleted file mode 100644 index 40bd7c7..0000000 --- a/server/src/auth.rs +++ /dev/null @@ -1,118 +0,0 @@ -use axum::{ - extract::{Query, State}, - response::{Redirect, IntoResponse}, - http::StatusCode, - routing::get, - Router, -}; -use oauth2::{ - basic::BasicClient, AuthUrl, ClientId, ClientSecret, CsrfToken, - RedirectUrl, Scope, TokenUrl, TokenResponse, -}; -use serde::Deserialize; -use tower_sessions::Session; -use crate::state::AppState; -use crate::models::User; - -pub const AUTH_URL: &str = "https://github.com/login/oauth/authorize"; -pub const TOKEN_URL: &str = "https://github.com/login/oauth/access_token"; - -// Routes for Auth -pub fn routes() -> Router { - Router::new() - .route("/auth/github", get(github_login)) - .route("/auth/callback", get(github_callback)) - .route("/auth/logout", get(logout)) - .route("/api/me", get(get_me)) -} - -// 1. Redirect user to GitHub -async fn github_login(State(state): State) -> Result { - // FIX: Handle config errors gracefully - let client = make_client(&state).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?; - - let (auth_url, _csrf_token) = client - .authorize_url(CsrfToken::new_random) - .add_scope(Scope::new("read:user".to_string())) - .url(); - Ok(Redirect::to(auth_url.as_str())) -} - -// 2. Handle callback from GitHub -#[derive(Deserialize)] -struct AuthRequest { code: String } - -async fn github_callback( - Query(query): Query, - State(state): State, - session: Session, -) -> Result { - // FIX: Handle config errors gracefully - let client = make_client(&state).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?; - - // 1. Exchange Code - let token = client - .exchange_code(oauth2::AuthorizationCode::new(query.code)) - .request_async(oauth2::reqwest::async_http_client) - .await - .map_err(|e| { - (StatusCode::INTERNAL_SERVER_ERROR, format!("Token Error: {}", e)) - })?; - - // 2. Fetch Profile - let http_client = reqwest::Client::new(); - let user_data: User = http_client - .get("https://api.github.com/user") - .header("User-Agent", "TryCli Studio") - .bearer_auth(token.access_token().secret()) - .send() - .await - .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Reqwest Error".into()))? - .json() - .await - .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "JSON Error".into()))?; - - // 3. Save Session - session.insert("user", &user_data) - .await - .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Session Insert Error".into()))?; - - // 4. Redirect - Ok(Redirect::to("http://localhost:8080/dashboard")) -} - -// 3. Helper to check session -async fn get_me(session: Session) -> Result { - let user: Option = session.get("user") - .await - .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Session Read Error: {}", e)))?; - - Ok(axum::Json(user)) -} - -// 4. Helper to create OAuth client (Now returns Result) -fn make_client(state: &AppState) -> Result { - let auth_url = AuthUrl::new(AUTH_URL.to_string()) - .map_err(|e| format!("Invalid Auth URL: {}", e))?; - - let token_url = TokenUrl::new(TOKEN_URL.to_string()) - .map_err(|e| format!("Invalid Token URL: {}", e))?; - - let api_url = std::env::var("API_URL").unwrap_or_else(|_| "http://localhost:3000".to_string()); - - let redirect_url = RedirectUrl::new(format!("{}/auth/callback", api_url)) - .map_err(|e| format!("Invalid Redirect URL: {}", e))?; - - Ok(BasicClient::new( - ClientId::new(state.github_id.clone()), - Some(ClientSecret::new(state.github_secret.clone())), - auth_url, - Some(token_url), - ) - .set_redirect_uri(redirect_url)) -} - -async fn logout(session: Session) -> Result { - session.delete().await.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; - Ok(Redirect::to("http://localhost:8080/")) -} \ No newline at end of file diff --git a/server/src/handlers/admin.rs b/server/src/handlers/admin.rs index c804c00..b3a3c6d 100644 --- a/server/src/handlers/admin.rs +++ b/server/src/handlers/admin.rs @@ -19,8 +19,8 @@ pub fn routes() -> Router { Router::new() .route("/api/admin/stats", get(get_system_stats)) .route("/api/admin/projects", get(get_all_projects)) - .route("/api/admin/container/:id", delete(kill_container)) - .route("/api/admin/project/:slug", delete(delete_project_admin)) + .route("/api/admin/container/{id}", delete(kill_container)) + .route("/api/admin/project/{slug}", delete(delete_project_admin)) } // Updated middleware to check the list diff --git a/server/src/handlers/auth.rs b/server/src/handlers/auth.rs index 998051b..8cc2dff 100644 --- a/server/src/handlers/auth.rs +++ b/server/src/handlers/auth.rs @@ -7,7 +7,7 @@ use axum::{ }; use oauth2::{ basic::BasicClient, AuthUrl, ClientId, ClientSecret, CsrfToken, - RedirectUrl, Scope, TokenUrl, TokenResponse, + RedirectUrl, Scope, TokenUrl, TokenResponse, EndpointSet, EndpointNotSet, }; use serde::Deserialize; use tower_sessions::Session; @@ -17,6 +17,9 @@ use crate::models::User; pub const AUTH_URL: &str = "https://github.com/login/oauth/authorize"; pub const TOKEN_URL: &str = "https://github.com/login/oauth/access_token"; +// Type alias for a fully configured OAuth client with both auth and token endpoints set +type ConfiguredClient = BasicClient; + // Routes for Auth pub fn routes() -> Router { Router::new() @@ -50,17 +53,22 @@ async fn github_callback( // FIX: Handle config errors gracefully let client = make_client(&state).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?; + // Create a stateful HTTP client with no redirects (for SSRF protection) + let http_client = reqwest::Client::builder() + .redirect(reqwest::redirect::Policy::none()) + .build() + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("HTTP Client Error: {}", e)))?; + // 1. Exchange Code let token = client .exchange_code(oauth2::AuthorizationCode::new(query.code)) - .request_async(oauth2::reqwest::async_http_client) + .request_async(&http_client) .await .map_err(|e| { (StatusCode::INTERNAL_SERVER_ERROR, format!("Token Error: {}", e)) })?; // 2. Fetch Profile - let http_client = reqwest::Client::new(); let user_data: User = http_client .get("https://api.github.com/user") .header("User-Agent", "TryCli Studio") @@ -99,7 +107,7 @@ async fn get_me(session: Session) -> Result Result { +fn make_client(state: &AppState) -> Result { let auth_url = AuthUrl::new(AUTH_URL.to_string()) .map_err(|e| format!("Invalid Auth URL: {}", e))?; @@ -112,13 +120,11 @@ fn make_client(state: &AppState) -> Result { let redirect_url = RedirectUrl::new(format!("{}/auth/callback", api_url)) .map_err(|e| format!("Invalid Redirect URL: {}", e))?; - Ok(BasicClient::new( - ClientId::new(state.github_id.clone()), - Some(ClientSecret::new(state.github_secret.clone())), - auth_url, - Some(token_url), - ) - .set_redirect_uri(redirect_url)) + Ok(BasicClient::new(ClientId::new(state.github_id.clone())) + .set_client_secret(ClientSecret::new(state.github_secret.clone())) + .set_auth_uri(auth_url) + .set_token_uri(token_url) + .set_redirect_uri(redirect_url)) } async fn logout(session: Session) -> Result { diff --git a/server/src/handlers/oembed.rs b/server/src/handlers/oembed.rs index 4ccc569..77d269c 100644 --- a/server/src/handlers/oembed.rs +++ b/server/src/handlers/oembed.rs @@ -27,24 +27,23 @@ pub async fn oembed_handler( if parts.len() >= 2 && parts[0] == "e" { let token = parts[1]; - let project = sqlx::query!( - "SELECT slug, owner_username FROM projects WHERE embed_token = $1", - token + let project: Option<(String, String)> = sqlx::query_as( + "SELECT slug, owner_username FROM projects WHERE embed_token = $1", ) + .bind(token) .fetch_optional(&state.db) .await .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; - if let Some(p) = project { + if let Some((slug, owner_username)) = project { let origin = std::env::var("FRONTEND_URL").unwrap_or("https://trycli.com".to_string()); - let embed_src = format!("{}/embed/{}/{}", origin, p.owner_username, p.slug); + let embed_src = format!("{}/embed/{}/{}", origin, owner_username, slug); return Ok(Json(OEmbedResponse::Rich { version: "1.0".to_string(), - title: format!("Interactive Demo: {}", p.slug), - // FIX: Clone the username here so it doesn't get moved - author_name: p.owner_username.clone(), - author_url: format!("{}/{}", origin, p.owner_username), + title: format!("Interactive Demo: {}", slug), + author_name: owner_username.clone(), + author_url: format!("{}/{}", origin, owner_username), provider_name: "TryCLI Studio".to_string(), provider_url: origin.clone(), html: format!( @@ -62,13 +61,14 @@ pub async fn oembed_handler( let username = parts[0]; let slug = parts[1]; - let exists = sqlx::query!( - "SELECT 1 as exists FROM projects WHERE owner_username = $1 AND slug = $2", - username, slug + let exists: Option<(i32,)> = sqlx::query_as( + "SELECT 1 FROM projects WHERE owner_username = $1 AND slug = $2", ) + .bind(username) + .bind(slug) .fetch_optional(&state.db) .await - .unwrap_or(None); + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; if exists.is_some() { let origin = std::env::var("FRONTEND_URL").unwrap_or("https://trycli.com".to_string()); diff --git a/server/src/handlers/project.rs b/server/src/handlers/project.rs index 508c762..581fdf8 100644 --- a/server/src/handlers/project.rs +++ b/server/src/handlers/project.rs @@ -7,7 +7,8 @@ use axum::{ Json, Router, }; use bollard::exec::{CreateExecOptions, StartExecResults}; -use bollard::image::{CreateImageOptions, RemoveImageOptions}; +use bollard::image::{CreateImageOptions, RemoveImageOptions}; +use bollard::container::ListContainersOptions; use tower_sessions::Session; use uuid::Uuid; use serde::Deserialize; @@ -147,16 +148,16 @@ fn validate_csrf_protection(headers: &HeaderMap) -> Result<(), (StatusCode, Stri pub fn routes() -> Router { Router::new() .route("/api/my-projects", get(list_user_projects)) - .route("/api/project/:username/:slug", get(get_project)) - .route("/api/project/:slug", delete(delete_project)) - .route("/api/project/:slug/embed-key", get(get_embed_key)) + .route("/api/project/{username}/{slug}", get(get_project)) + .route("/api/project/{slug}", delete(delete_project)) + .route("/api/project/{slug}/embed-key", get(get_embed_key)) .route( - "/api/project/:slug/whitelist", + "/api/project/{slug}/whitelist", get(get_whitelist).post(add_to_whitelist).delete(remove_from_whitelist), ) .route("/api/search-projects", get(search_projects)) .route("/api/publish", post(publish_handler)) - .route("/e/:token", get(resolve_secret_embed)) // Secret Embed Route + .route("/e/{token}", get(resolve_secret_embed)) // Secret Embed Route } pub async fn list_user_projects( @@ -335,6 +336,53 @@ pub async fn publish_handler( Ok(Json("Published!".to_string())) } +/// Helper function to find an existing viewer container for a project +/// This is used to reuse containers for old projects that don't have session_id labels +async fn find_existing_viewer_container(state: &AppState, project_slug: &str, owner_id: i64) -> Option { + // First, check in-memory sessions for a match + { + let sessions = state.lock_sessions(); + for (session_id, ctx) in sessions.iter() { + if ctx.project_owner_id == Some(owner_id) + && ctx.project_slug.as_deref() == Some(project_slug) + && !ctx.container_name.is_empty() + && ctx.container_name != "INITIALIZING" { + return Some(session_id.clone()); + } + } + } + + // If not in memory, check Docker containers + // First try with project_slug label (new containers) + let filters_with_slug = HashMap::from([ + ("label".to_string(), vec![ + "managed_by=TryCli Studio".to_string(), + "container_type=viewer".to_string(), + format!("project_owner_id={}", owner_id), + format!("project_slug={}", project_slug), + ]) + ]); + + let opts = ListContainersOptions { + all: false, + filters: filters_with_slug, + ..Default::default() + }; + + if let Ok(containers) = state.docker.list_containers(Some(opts)).await { + if let Some(container) = containers.first() { + // Extract session_id from labels + if let Some(labels) = &container.labels { + if let Some(session_id) = labels.get("session_id") { + return Some(session_id.clone()); + } + } + } + } + + None +} + pub async fn get_project( Path((username, slug)): Path<(String, String)>, State(state): State, @@ -477,23 +525,32 @@ pub async fn get_project( } // 7. Construct JSON response - // Generate the session ID here, but DO NOT spawn Docker yet. - let session_id = Uuid::new_v4().to_string(); - - { - let mut map = state.lock_sessions(); - map.insert(session_id.clone(), SessionContext { - container_name: String::new(), // Empty indicates "Not Started" - pending_image_tag: Some(image_tag), // Store tag for WS handler - shell, - owner_id: None, - project_owner_id: Some(owner_id), - is_publishing: false, - project_slug: Some(slug), - created_at: std::time::Instant::now(), - is_ws_connected: false, - }); - } + // Check if there's an existing viewer container for this project + let session_id = match find_existing_viewer_container(&state, &slug, owner_id).await { + Some(existing_session_id) => { + tracing::info!("Reusing existing session {} for project {}/{}", existing_session_id, username, slug); + existing_session_id + } + None => { + // Generate a new session ID and prepare for lazy container spawn + let new_session_id = Uuid::new_v4().to_string(); + + let mut map = state.lock_sessions(); + map.insert(new_session_id.clone(), SessionContext { + container_name: String::new(), // Empty indicates "Not Started" + pending_image_tag: Some(image_tag.clone()), // Store tag for WS handler + shell: shell.clone(), + owner_id: None, + project_owner_id: Some(owner_id), + is_publishing: false, + project_slug: Some(slug.clone()), + created_at: std::time::Instant::now(), + is_ws_connected: false, + }); + + new_session_id + } + }; let mut response_json = serde_json::json!({ "markdown": markdown, diff --git a/server/src/main.rs b/server/src/main.rs index a596739..07439a5 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -16,6 +16,7 @@ mod handlers { } use services::docker::start_background_reaper; +use services::websocket::restore_sessions_from_containers; #[tokio::main] async fn main() -> Result<(), Box> { @@ -24,6 +25,10 @@ async fn main() -> Result<(), Box> { // Setup database and Docker let state = config::setup_database_and_docker().await?; + // Restore sessions from existing Docker containers (critical for reconnecting to pre-existing containers) + tracing::info!("Restoring sessions from existing containers..."); + restore_sessions_from_containers(&state).await; + // Spawn background reaper let docker_reaper = state.docker.clone(); let sessions_reaper = state.sessions.clone(); diff --git a/server/src/router.rs b/server/src/router.rs index 95052b3..1731346 100644 --- a/server/src/router.rs +++ b/server/src/router.rs @@ -25,7 +25,7 @@ pub fn create_router(state: AppState) -> Result, + project_owner_id: Option, + project_slug: Option<&str>, + shell: &str, + container_type: &str, // "builder" or "viewer" +) -> HashMap { + let mut labels = HashMap::from([ + ("managed_by".to_string(), "TryCli Studio".to_string()), + ("session_id".to_string(), session_id.to_string()), + ("shell".to_string(), shell.to_string()), + ("container_type".to_string(), container_type.to_string()), + ]); + + if let Some(id) = owner_id { + labels.insert("owner_id".to_string(), id.to_string()); + } + + if let Some(id) = project_owner_id { + labels.insert("project_owner_id".to_string(), id.to_string()); + } + + if let Some(slug) = project_slug { + labels.insert("project_slug".to_string(), slug.to_string()); + } + + labels +} + +/// Calculate an approximate creation time based on container's created timestamp +/// This converts a Unix timestamp to an Instant for session tracking +fn calculate_session_created_at(created_ts: Option) -> std::time::Instant { + if let Some(created_ts) = created_ts { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() as i64; + let age_secs = now - created_ts; + + std::time::Instant::now() + .checked_sub(std::time::Duration::from_secs(age_secs.max(0) as u64)) + .unwrap_or_else(|| std::time::Instant::now()) + } else { + std::time::Instant::now() + } +} + +/// Restore sessions from existing Docker containers on server startup +/// This allows pre-existing containers to be reconnected after server restart +pub async fn restore_sessions_from_containers(state: &AppState) { + let filters = HashMap::from([ + ("label".to_string(), vec!["managed_by=TryCli Studio".to_string()]) + ]); + + let opts = ListContainersOptions { + all: false, // Only running containers + filters, + ..Default::default() + }; + + match state.docker.list_containers(Some(opts)).await { + Ok(containers) => { + let mut restored = 0; + for container in containers { + if let Some(labels) = container.labels { + // Extract session metadata from labels + let session_id = labels.get("session_id").map(|s| s.clone()); + let shell = labels.get("shell").map(|s| s.clone()).unwrap_or_else(|| "/bin/bash".to_string()); + let owner_id = labels.get("owner_id").and_then(|s| s.parse::().ok()); + let project_owner_id = labels.get("project_owner_id").and_then(|s| s.parse::().ok()); + let project_slug = labels.get("project_slug").map(|s| s.clone()); + + let container_name = container.names + .as_ref() + .and_then(|names| names.first()) + .map(|n| n.trim_start_matches('/').to_string()) + .unwrap_or_default(); + + if container_name.is_empty() { + continue; + } + + // Handle new containers (with session_id label) + if let Some(session_id) = session_id { + let mut map = state.lock_sessions(); + + if !map.contains_key(&session_id) { + let created_at = calculate_session_created_at(container.created); + + map.insert(session_id.clone(), SessionContext { + container_name: container_name.clone(), + shell, + pending_image_tag: None, + owner_id, + project_owner_id, + is_publishing: false, + project_slug, + created_at, + is_ws_connected: false, + }); + restored += 1; + tracing::info!("Restored session {} with container {}", session_id, container_name); + } + } else { + // Handle legacy containers (without session_id label) + // Extract UUID from container name for use as session_id + let legacy_session_id = if container_name.starts_with("trycli-studio-viewer-") { + container_name.strip_prefix("trycli-studio-viewer-").map(|s| s.to_string()) + } else if container_name.starts_with("trycli-studio-session-") { + container_name.strip_prefix("trycli-studio-session-").map(|s| s.to_string()) + } else { + None + }; + + if let Some(legacy_session_id) = legacy_session_id { + let mut map = state.lock_sessions(); + + if !map.contains_key(&legacy_session_id) { + let created_at = calculate_session_created_at(container.created); + + // For legacy containers, we don't know the exact metadata + // Set reasonable defaults based on container type + + map.insert(legacy_session_id.clone(), SessionContext { + container_name: container_name.clone(), + shell, + pending_image_tag: None, + owner_id: None, // Unknown for legacy containers + project_owner_id: None, // Unknown for legacy containers + is_publishing: false, + project_slug: None, // Unknown for legacy containers + created_at, + is_ws_connected: false, + }); + restored += 1; + tracing::info!("Restored legacy session {} from container {}", legacy_session_id, container_name); + } + } + } + } + } + tracing::info!("Session restoration complete: {} sessions restored", restored); + } + Err(e) => { + tracing::error!("Failed to restore sessions from containers: {}", e); + } + } +} + pub async fn ws_handler( ws: WebSocketUpgrade, Path(session_id): Path, @@ -25,6 +176,17 @@ pub async fn ws_handler( let user: Option = session.get("user").await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; let user_id = user.map(|u| u.id); + // Check if session exists in memory, if not, try to restore from Docker + let session_exists = { + let map = state.lock_sessions(); + map.contains_key(&session_id) + }; + + if !session_exists { + // Try to find and restore this specific session from Docker containers + restore_specific_session(&state, &session_id).await; + } + { let map = state.lock_sessions(); @@ -47,6 +209,113 @@ pub async fn ws_handler( Ok(ws.on_upgrade(move |socket| handle_socket(socket, state, session_id, user_id))) } +/// Attempt to restore a specific session from Docker containers +/// This is called when a client tries to connect to a session that isn't in memory +async fn restore_specific_session(state: &AppState, session_id: &str) { + // First, try to find a container with the session_id label (new containers) + let filters_with_label = HashMap::from([ + ("label".to_string(), vec![ + "managed_by=TryCli Studio".to_string(), + format!("session_id={}", session_id) + ]) + ]); + + let opts = ListContainersOptions { + all: false, // Only running containers + filters: filters_with_label, + ..Default::default() + }; + + if let Ok(containers) = state.docker.list_containers(Some(opts)).await { + for container in containers { + if let Some(labels) = container.labels { + let shell = labels.get("shell").map(|s| s.clone()).unwrap_or_else(|| "/bin/bash".to_string()); + let owner_id = labels.get("owner_id").and_then(|s| s.parse::().ok()); + let project_owner_id = labels.get("project_owner_id").and_then(|s| s.parse::().ok()); + let project_slug = labels.get("project_slug").map(|s| s.clone()); + + if let Some(names) = container.names { + let container_name = names.first() + .map(|n| n.trim_start_matches('/').to_string()) + .unwrap_or_default(); + + if !container_name.is_empty() { + let created_at = calculate_session_created_at(container.created); + + let mut map = state.lock_sessions(); + map.insert(session_id.to_string(), SessionContext { + container_name: container_name.clone(), + shell, + pending_image_tag: None, + owner_id, + project_owner_id, + is_publishing: false, + project_slug, + created_at, + is_ws_connected: false, + }); + tracing::info!("Restored session {} from container {}", session_id, container_name); + return; + } + } + } + } + } + + // If not found, try to find a legacy container by matching container name + // Legacy containers have UUID in name: trycli-studio-viewer-{uuid} or trycli-studio-session-{uuid} + let legacy_container_names = vec![ + format!("trycli-studio-viewer-{}", session_id), + format!("trycli-studio-session-{}", session_id), + ]; + + for legacy_name in legacy_container_names { + // Try to find container by exact name match + let filters_by_name = HashMap::from([ + ("label".to_string(), vec!["managed_by=TryCli Studio".to_string()]), + ("name".to_string(), vec![legacy_name.clone()]) + ]); + + let opts = ListContainersOptions { + all: false, + filters: filters_by_name, + ..Default::default() + }; + + if let Ok(containers) = state.docker.list_containers(Some(opts)).await { + if let Some(container) = containers.first() { + let labels = container.labels.as_ref(); + let shell = labels + .and_then(|l| l.get("shell")) + .map(|s| s.clone()) + .unwrap_or_else(|| "/bin/bash".to_string()); + + let container_name = container.names.as_ref() + .and_then(|names| names.first()) + .map(|n| n.trim_start_matches('/').to_string()) + .unwrap_or(legacy_name.clone()); + + let created_at = calculate_session_created_at(container.created); + + let mut map = state.lock_sessions(); + map.insert(session_id.to_string(), SessionContext { + container_name: container_name.clone(), + shell, + pending_image_tag: None, + owner_id: None, // Unknown for legacy containers + project_owner_id: None, + is_publishing: false, + project_slug: None, + created_at, + is_ws_connected: false, + }); + tracing::info!("Restored legacy session {} from container {}", session_id, container_name); + return; + } + } + } +} + async fn handle_socket(mut socket: WebSocket, state: AppState, session_id: String, user_id: Option) { // Track if this is a first-time connection for view counting @@ -103,7 +372,13 @@ async fn handle_socket(mut socket: WebSocket, state: AppState, session_id: Strin if let Some(ctx) = map.get(&session_id) { // If we have an image tag but no container name, it's a viewer waiting to start if ctx.container_name.is_empty() && ctx.pending_image_tag.is_some() { - Some((ctx.pending_image_tag.clone().unwrap(), ctx.shell.clone())) + Some(( + ctx.pending_image_tag.clone().unwrap(), + ctx.shell.clone(), + ctx.owner_id, + ctx.project_owner_id, + ctx.project_slug.clone(), + )) } else { None } @@ -112,15 +387,22 @@ async fn handle_socket(mut socket: WebSocket, state: AppState, session_id: Strin } }; - if let Some((image_tag, shell)) = pending_spawn { + if let Some((image_tag, shell, owner_id, project_owner_id, project_slug)) = pending_spawn { // Perform the spawn that used to be in get_project let container_name = format!("trycli-studio-viewer-{}", Uuid::new_v4()); + let labels = create_container_labels( + &session_id, + owner_id, + project_owner_id, + project_slug.as_deref(), + &shell, + "viewer", + ); + let config = Config { image: Some(image_tag), - labels: Some(HashMap::from([ - ("managed_by".to_string(), "TryCli Studio".to_string()) - ])), + labels: Some(labels), tty: Some(true), user: Some("root".to_string()), // FIX: Run sleep infinity as PID 1. This uses almost 0 CPU/RAM. @@ -174,27 +456,45 @@ async fn handle_socket(mut socket: WebSocket, state: AppState, session_id: Strin ..Default::default() }; - // Create & Start - let create_res = state.docker.create_container( - Some(CreateContainerOptions { name: container_name.clone(), platform: None }), - config - ).await; + // Create & Start (viewer container) + match state + .docker + .create_container( + Some(CreateContainerOptions { + name: container_name.clone(), + platform: None, + }), + config, + ) + .await + { + Ok(_) => { + if let Err(e) = + state.docker.start_container::(&container_name, None).await + { + // Log detailed error server-side only + tracing::error!("Viewer start error for session {}: {}", session_id, e); + // Send generic error message to client + let msg = "\r\n\x1b[31m[!] Failed to start viewer container. Please try again later.\x1b[0m\r\n"; + let _ = socket.send(Message::Text(msg.into())).await; + return; + } - if create_res.is_ok() { - if state.docker.start_container::(&container_name, None).await.is_ok() { // Update SessionContext with the real container name let mut map = state.lock_sessions(); if let Some(ctx) = map.get_mut(&session_id) { ctx.container_name = container_name.clone(); ctx.pending_image_tag = None; // clear pending } - } else { - let _ = socket.send(Message::Text("\r\n\x1b[31m[!] Failed to start container.\x1b[0m\r\n".to_string())).await; + } + Err(e) => { + // Log detailed error server-side only + tracing::error!("Viewer create error for session {}: {}", session_id, e); + // Send generic error message to client + let msg = "\r\n\x1b[31m[!] Failed to create viewer container. Please try again later.\x1b[0m\r\n"; + let _ = socket.send(Message::Text(msg.into())).await; return; } - } else { - let _ = socket.send(Message::Text("\r\n\x1b[31m[!] Failed to create container.\x1b[0m\r\n".to_string())).await; - return; } } @@ -228,7 +528,7 @@ async fn handle_socket(mut socket: WebSocket, state: AppState, session_id: Strin if let Some(ctx) = existing_session { if ctx.container_name == "INITIALIZING" { - let _ = socket.close().await; + let _ = socket.send(axum::extract::ws::Message::Close(None)).await; return; } attach_to_container(socket, state, session_id, ctx.container_name, ctx.shell, None).await; @@ -238,7 +538,7 @@ async fn handle_socket(mut socket: WebSocket, state: AppState, session_id: Strin async fn run_setup_wizard(mut socket: WebSocket, state: AppState, session_id: String, _user_id: Option) { async fn send_txt(ws: &mut WebSocket, txt: &str) { - let _ = ws.send(Message::Text(txt.to_string())).await; + let _ = ws.send(Message::Text(txt.to_string().into())).await; } let green = "\x1b[32m"; @@ -258,7 +558,7 @@ async fn run_setup_wizard(mut socket: WebSocket, state: AppState, session_id: St let mut distro_choice = 0; while let Some(Ok(Message::Text(txt))) = socket.recv().await { - let input = txt.trim(); + let input = txt.as_str().trim(); if input == "1" { distro_choice = 1; break; } if input == "2" { distro_choice = 2; break; } if input == "3" { distro_choice = 3; break; } @@ -273,7 +573,7 @@ async fn run_setup_wizard(mut socket: WebSocket, state: AppState, session_id: St let mut shell_choice = 0; while let Some(Ok(Message::Text(txt))) = socket.recv().await { - let input = txt.trim(); + let input = txt.as_str().trim(); if input == "1" { shell_choice = 1; break; } if input == "2" { shell_choice = 2; break; } if input == "3" { shell_choice = 3; break; } @@ -311,6 +611,16 @@ async fn run_setup_wizard(mut socket: WebSocket, state: AppState, session_id: St } let container_name = format!("trycli-studio-session-{}", Uuid::new_v4()); + + let labels = create_container_labels( + &session_id, + _user_id, + None, // Builder sessions don't have project context yet + None, + final_shell, + "builder", + ); + let config = Config { image: Some(image.to_string()), tty: Some(true), @@ -322,9 +632,7 @@ async fn run_setup_wizard(mut socket: WebSocket, state: AppState, session_id: St "LC_ALL=C.UTF-8".to_string(), "TERM=xterm-256color".to_string() ]), - labels: Some(HashMap::from([ - ("managed_by".to_string(), "TryCli Studio".to_string()) - ])), + labels: Some(labels), host_config: Some(HostConfig { runtime: Some("runsc".to_string()), memory: Some(512 * 1024 * 1024), @@ -509,12 +817,12 @@ async fn attach_to_container( Ok(Some(Ok(msg))) => { match msg { Message::Text(text) => { - if input.write_all(text.as_bytes()).await.is_err() { + if input.write_all(text.as_str().as_bytes()).await.is_err() { break; // Container stdin closed } }, Message::Binary(bin) => { - if input.write_all(&bin).await.is_err() { + if input.write_all(bin.as_ref()).await.is_err() { break; } },