diff --git a/http_framework/rgl_framework.lua b/http_framework/rgl_framework.lua index 4be6434..2c1dc46 100644 --- a/http_framework/rgl_framework.lua +++ b/http_framework/rgl_framework.lua @@ -27,6 +27,7 @@ ffi.cdef[[ void rgl_auth_init(const char* secret_dir); void rgl_auth_set(const char* login, const char* password); void rgl_auth_clear(); + void rgl_auth_reset(); int rgl_auth_enabled(); ]] @@ -40,6 +41,7 @@ local command_handlers = {} local framework = {} local _current_module = nil local admin_visible = false +local admin_state = {} local recent_logs = {} local MAX_LOGS = 100 local notifications = {} -- {text, level, time, start} @@ -767,12 +769,18 @@ function admin_render(ui, state) end if ui.tab_item("Auth") then - if rust.rgl_auth_enabled() == 1 then + local auth_on = rust.rgl_auth_enabled() == 1 + if auth_on then ui.text_colored(0.3, 0.8, 0.3, 1, "Auth: ON") ui.spacing() - if ui.button("Disable Auth") then + if ui.button("Disable") then rust.rgl_auth_clear() - log("INFO", "AUTH", "Credentials cleared from admin UI") + log("INFO", "AUTH", "Auth disabled") + end + ui.sameline() + if ui.button("Reset") then + rust.rgl_auth_reset() + log("INFO", "AUTH", "Auth fully reset") end else ui.text_colored(0.5, 0.5, 0.5, 1, "Auth: OFF") @@ -781,10 +789,10 @@ function admin_render(ui, state) ui.text("Set credentials:") state.auth_login = ui.input("Login", state.auth_login or "") state.auth_pass = ui.input("Password", state.auth_pass or "") - if ui.button("Save Credentials") then + if ui.button("Save") then if #(state.auth_login or "") > 0 and #(state.auth_pass or "") > 0 then rust.rgl_auth_set(state.auth_login, state.auth_pass) - log("INFO", "AUTH", "Credentials set from admin UI") + log("INFO", "AUTH", "Credentials saved") state.auth_login = "" state.auth_pass = "" end @@ -803,7 +811,7 @@ function register_admin() pcall(function() interactions = framework.json_decode(body) end) end local ui = create_ui_builder(interactions) - admin_render(ui, {}) + admin_render(ui, admin_state) return framework.json_encode({widgets = ui._get_widgets()}) end, owner = "__admin"} @@ -1023,7 +1031,7 @@ if imgui_loaded and imgui then local ui = create_ui_imgui() if ui then - local rok, rerr = pcall(admin_render, ui, {}) + local rok, rerr = pcall(admin_render, ui, admin_state) if not rok then imgui.TextColored(imgui.ImVec4(1, 0.3, 0.3, 1), "Render error:") imgui.TextWrapped(tostring(rerr)) diff --git a/rust_core/src/auth.rs b/rust_core/src/auth.rs index 8fefb8b..1f1d6be 100644 --- a/rust_core/src/auth.rs +++ b/rust_core/src/auth.rs @@ -4,17 +4,30 @@ //! Credentials (login/password) are XOR-encrypted with the secret and stored in //! a separate `auth` DB table that modules cannot access through the kv API. -use std::sync::OnceLock; +use std::sync::{Mutex, OnceLock}; use rand::Rng; use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64}; use crate::logging; -static SECRET: OnceLock = OnceLock::new(); -static CREDENTIALS: OnceLock>> = OnceLock::new(); +struct AuthState { + secret: String, + credentials: Option<(String, String)>, + secret_paths: Vec, +} + +static STATE: OnceLock> = OnceLock::new(); + +fn state() -> &'static Mutex { + STATE.get_or_init(|| Mutex::new(AuthState { + secret: String::new(), + credentials: None, + secret_paths: Vec::new(), + })) +} -fn credentials() -> &'static std::sync::Mutex> { - CREDENTIALS.get_or_init(|| std::sync::Mutex::new(None)) +fn lock_state() -> std::sync::MutexGuard<'static, AuthState> { + state().lock().unwrap_or_else(|e| e.into_inner()) } /// Generate a 32-byte hex secret. @@ -43,8 +56,6 @@ fn decrypt_with_secret(encoded: &str, secret: &str) -> Option { } /// Initialize auth: load or generate secret, load credentials from DB. -/// Must be called after DB is initialized. -/// `secret_paths`: list of paths to try for secret file (first writable wins). pub async fn init(secret_paths: &[String], db_conn: &tokio_rusqlite::Connection) { // Create auth table if let Err(e) = db_conn.call(|conn| { @@ -72,23 +83,25 @@ pub async fn init(secret_paths: &[String], db_conn: &tokio_rusqlite::Connection) Ok::<_, tokio_rusqlite::rusqlite::Error>((login, password)) }).await.unwrap_or((None, None)); + let mut s = lock_state(); + s.secret_paths = secret_paths.to_vec(); + if let (Some(enc_login), Some(enc_pass)) = creds { if let (Some(login), Some(password)) = ( decrypt_with_secret(&enc_login, &secret), decrypt_with_secret(&enc_pass, &secret), ) { - *credentials().lock().unwrap_or_else(|e| e.into_inner()) = Some((login.clone(), password)); logging::log("INFO", "AUTH", &format!("credentials loaded for user '{login}'")); + s.credentials = Some((login, password)); } else { logging::log("WARN", "AUTH", "credentials in DB couldn't be decrypted (secret changed?), auth disabled"); } } - SECRET.set(secret).ok(); + s.secret = secret; } async fn load_or_generate_secret(paths: &[String]) -> String { - // Try to read existing secret from any path for path in paths { if let Ok(content) = tokio::fs::read_to_string(path).await { let trimmed = content.trim().to_string(); @@ -99,7 +112,6 @@ async fn load_or_generate_secret(paths: &[String]) -> String { } } - // Generate new secret and try to write it let secret = generate_secret(); for path in paths { if let Some(parent) = std::path::Path::new(path).parent() { @@ -115,31 +127,23 @@ async fn load_or_generate_secret(paths: &[String]) -> String { secret } -/// Set credentials (called from Lua FFI). Encrypts and stores in DB. +/// Set credentials. Encrypts and stores in DB. pub fn set_credentials(login: &str, password: &str) { - let Some(secret) = SECRET.get() else { return }; - - let enc_login = xor_with_secret(login, secret); - let enc_pass = xor_with_secret(password, secret); - - *credentials().lock().unwrap_or_else(|e| e.into_inner()) = - Some((login.to_string(), password.to_string())); + let (enc_login, enc_pass) = { + let mut s = lock_state(); + if s.secret.is_empty() { return; } + let enc_login = xor_with_secret(login, &s.secret); + let enc_pass = xor_with_secret(password, &s.secret); + s.credentials = Some((login.to_string(), password.to_string())); + (enc_login, enc_pass) + }; - // Store in DB async - let enc_login = enc_login.clone(); - let enc_pass = enc_pass.clone(); if let Some(handle) = crate::server::runtime_handle() { handle.spawn(async move { if let Some(conn) = crate::db::get_connection() { let _ = conn.call(move |conn| { - conn.execute( - "INSERT OR REPLACE INTO auth (key, value) VALUES ('login', ?1)", - [&enc_login], - )?; - conn.execute( - "INSERT OR REPLACE INTO auth (key, value) VALUES ('password', ?1)", - [&enc_pass], - )?; + conn.execute("INSERT OR REPLACE INTO auth (key, value) VALUES ('login', ?1)", [&enc_login])?; + conn.execute("INSERT OR REPLACE INTO auth (key, value) VALUES ('password', ?1)", [&enc_pass])?; Ok::<_, tokio_rusqlite::rusqlite::Error>(()) }).await; logging::log("INFO", "AUTH", "credentials saved"); @@ -148,9 +152,9 @@ pub fn set_credentials(login: &str, password: &str) { } } -/// Clear credentials. +/// Clear credentials (disable auth, keep secret). pub fn clear_credentials() { - *credentials().lock().unwrap_or_else(|e| e.into_inner()) = None; + lock_state().credentials = None; if let Some(handle) = crate::server::runtime_handle() { handle.spawn(async move { @@ -165,18 +169,55 @@ pub fn clear_credentials() { } } +/// Full reset: clear credentials + regenerate secret. +pub fn reset() { + let paths = { + let mut s = lock_state(); + s.credentials = None; + s.secret = generate_secret(); + let new_secret = s.secret.clone(); + let paths = s.secret_paths.clone(); + drop(s); + + // Write new secret to file and clear DB + if let Some(handle) = crate::server::runtime_handle() { + let paths_clone = paths.clone(); + handle.spawn(async move { + // Write new secret + for path in &paths_clone { + if let Some(parent) = std::path::Path::new(path).parent() { + let _ = tokio::fs::create_dir_all(parent).await; + } + if tokio::fs::write(path, &new_secret).await.is_ok() { + logging::log("INFO", "AUTH", &format!("new secret written to {path}")); + break; + } + } + // Clear credentials from DB + if let Some(conn) = crate::db::get_connection() { + let _ = conn.call(|conn| { + conn.execute("DELETE FROM auth WHERE key IN ('login', 'password')", [])?; + Ok::<_, tokio_rusqlite::rusqlite::Error>(()) + }).await; + } + logging::log("INFO", "AUTH", "auth fully reset with new secret"); + }); + } + paths + }; + let _ = paths; // suppress unused warning +} + /// Check if auth is enabled (credentials are set). pub fn has_auth() -> bool { - credentials().lock().unwrap_or_else(|e| e.into_inner()).is_some() + lock_state().credentials.is_some() } /// Check an HTTP request's authorization. -/// Returns true if authorized (no auth configured, or valid credentials/token). pub fn check_auth(auth_header: Option<&str>) -> bool { - // No auth configured → allow all - let creds = credentials().lock().unwrap_or_else(|e| e.into_inner()); - let Some((ref login, ref password)) = *creds else { - return true; + let s = lock_state(); + let Some((ref login, ref password)) = s.credentials else { + return true; // no auth configured }; let Some(header) = auth_header else { @@ -185,9 +226,7 @@ pub fn check_auth(auth_header: Option<&str>) -> bool { // Bearer token (secret) if let Some(token) = header.strip_prefix("Bearer ") { - if let Some(secret) = SECRET.get() { - return token == secret; - } + return token == s.secret; } // Basic auth @@ -204,9 +243,9 @@ pub fn check_auth(auth_header: Option<&str>) -> bool { false } -/// Get the secret token (for WebSocket query param auth). -pub fn get_secret() -> Option<&'static String> { - SECRET.get() +/// Get the secret token (for external integrations). +pub fn get_secret() -> String { + lock_state().secret.clone() } #[cfg(test)] @@ -222,27 +261,21 @@ mod tests { assert_eq!(decrypted, data); } - #[test] - fn test_check_auth_no_config() { - // When no credentials set, everything passes - // (credentials() defaults to None) - // Note: in test environment, CREDENTIALS may have state from other tests - // This test verifies the logic path - assert!(check_auth(None) || has_auth()); - } - #[test] fn test_generate_secret_length() { let secret = generate_secret(); - assert_eq!(secret.len(), 64); // 32 bytes * 2 hex chars + assert_eq!(secret.len(), 64); assert!(secret.chars().all(|c| c.is_ascii_hexdigit())); } #[test] - fn test_basic_auth_parse() { - // Set up test state - let _ = SECRET.set("a".repeat(64)); - *credentials().lock().unwrap() = Some(("admin".to_string(), "pass123".to_string())); + fn test_auth_flow() { + // Setup + { + let mut s = lock_state(); + s.secret = "a".repeat(64); + s.credentials = Some(("admin".to_string(), "pass123".to_string())); + } // Valid basic auth let encoded = BASE64.encode("admin:pass123"); @@ -256,10 +289,17 @@ mod tests { assert!(!check_auth(None)); // Bearer with secret - let secret = SECRET.get().unwrap(); + let secret = get_secret(); assert!(check_auth(Some(&format!("Bearer {secret}")))); // Clean up - *credentials().lock().unwrap() = None; + lock_state().credentials = None; + } + + #[test] + fn test_no_auth_allows_all() { + lock_state().credentials = None; + assert!(check_auth(None)); + assert!(check_auth(Some("garbage"))); } } diff --git a/rust_core/src/lib.rs b/rust_core/src/lib.rs index ca472ad..dd7db6d 100644 --- a/rust_core/src/lib.rs +++ b/rust_core/src/lib.rs @@ -198,6 +198,11 @@ pub extern "C" fn rgl_auth_clear() { auth::clear_credentials(); } +#[unsafe(no_mangle)] +pub extern "C" fn rgl_auth_reset() { + auth::reset(); +} + #[unsafe(no_mangle)] pub extern "C" fn rgl_auth_enabled() -> c_int { if auth::has_auth() { 1 } else { 0 } diff --git a/rust_core/src/server.rs b/rust_core/src/server.rs index 4ae9468..4c61170 100644 --- a/rust_core/src/server.rs +++ b/rust_core/src/server.rs @@ -1,8 +1,9 @@ //! Axum HTTP/WS server — admin UI is built-in, modules are Lua-side. use axum::{ - extract::{Path, ws::{Message, WebSocket, WebSocketUpgrade}}, + extract::{Path, Request, ws::{Message, WebSocket, WebSocketUpgrade}}, http::{header, StatusCode}, + middleware::{self, Next}, response::{Html, IntoResponse, Response}, routing::{get, post}, Router, @@ -81,6 +82,48 @@ pub fn get_commands_json() -> String { format!("[{}]", items.join(",")) } +// --- Auth middleware --- + +fn unauthorized_response() -> Response { + Response::builder() + .status(StatusCode::UNAUTHORIZED) + .header("WWW-Authenticate", "Basic realm=\"ARZ Web Helper\"") + .body(axum::body::Body::from("Unauthorized")) + .unwrap_or_else(|_| (StatusCode::UNAUTHORIZED, "Unauthorized").into_response()) +} + +async fn auth_middleware(request: Request, next: Next) -> Response { + if !auth::has_auth() { + return next.run(request).await; + } + + // WebSocket upgrade: check ?token= query param + let uri = request.uri().clone(); + if uri.path() == "/ws" { + let query = uri.query().unwrap_or(""); + let token = query.split('&') + .find_map(|p| p.strip_prefix("token=")); + if let Some(t) = token { + if auth::check_auth(Some(&format!("Bearer {t}"))) { + return next.run(request).await; + } + } + } + + // Check Authorization header + let auth_header = request.headers() + .get(header::AUTHORIZATION) + .and_then(|v| v.to_str().ok()); + + if auth::check_auth(auth_header) { + next.run(request).await + } else { + unauthorized_response() + } +} + +// --- Server start/stop --- + pub fn start(port: u16) -> Result<(), String> { let _initial_rx = bridge::init_event_channel(); @@ -99,7 +142,7 @@ pub fn start(port: u16) -> Result<(), String> { }; RT_HANDLE.set(rt.handle().clone()).ok(); - rt_tx.send(true).ok(); // Signal that runtime is ready + rt_tx.send(true).ok(); rt.block_on(async move { let app = Router::new() @@ -109,7 +152,8 @@ pub fn start(port: u16) -> Result<(), String> { .route("/api/modules", get(modules_list_handler)) .route("/api/commands", get(commands_list_handler)) .route("/api/{module}/{action}", post(api_handler)) - .fallback(static_file_handler); + .fallback(static_file_handler) + .layer(middleware::from_fn(auth_middleware)); let addr = format!("0.0.0.0:{port}"); let socket = match tokio::net::TcpSocket::new_v4() { @@ -139,7 +183,6 @@ pub fn start(port: u16) -> Result<(), String> { }) .map_err(|e| format!("thread spawn error: {e}"))?; - // Wait for tokio runtime to be ready before returning match rt_rx.recv() { Ok(true) => Ok(()), _ => Err("runtime init failed".to_string()), @@ -180,22 +223,14 @@ async fn modules_list_handler() -> impl IntoResponse { } async fn api_handler( - headers: axum::http::HeaderMap, Path((module, action)): Path<(String, String)>, body: String, ) -> impl IntoResponse { - let auth_header = headers.get(header::AUTHORIZATION) - .and_then(|v| v.to_str().ok()); - if !auth::check_auth(auth_header) { - return (StatusCode::UNAUTHORIZED, r#"{"error":"unauthorized"}"#).into_response(); - } - let code = format!( "return __arz_handle_api([=[{module}]=], [=[{action}]=], [=[{body}]=])" ); let id = bridge::request_lua_exec(code); - // Poll for result with async timeout — never blocks tokio workers let result = tokio::time::timeout( std::time::Duration::from_secs(2), async { @@ -234,7 +269,6 @@ async fn static_file_handler(uri: axum::http::Uri) -> Response { let dirs = lock_or_recover(module_dirs()); match dirs.get(module) { Some(d) if d == "__render__" => { - // Module uses render() API — serve auto-UI page let html = include_str!("../static/ui_page.html") .replace("{{MODULE}}", module) .replace("{{TITLE}}", module); @@ -277,21 +311,10 @@ async fn static_file_handler(uri: axum::http::Uri) -> Response { // --- WebSocket --- async fn ws_handler( - headers: axum::http::HeaderMap, - query: axum::extract::Query>, ws: WebSocketUpgrade, ) -> impl IntoResponse { - // Check auth via Authorization header or ?token= query param - let auth_header = headers.get(header::AUTHORIZATION) - .and_then(|v| v.to_str().ok()); - let token_param = query.get("token").map(|s| format!("Bearer {s}")); - let effective_auth = auth_header.map(|s| s.to_string()).or(token_param); - - if !auth::check_auth(effective_auth.as_deref()) { - return (StatusCode::UNAUTHORIZED, "unauthorized").into_response(); - } - - ws.on_upgrade(handle_ws).into_response() + // Auth already checked by middleware (including ?token= for WS) + ws.on_upgrade(handle_ws) } async fn handle_ws(mut socket: WebSocket) { @@ -321,7 +344,7 @@ async fn handle_ws(mut socket: WebSocket) { logging::log("WARN", "WS", &format!("client lagged, {n} events dropped")); continue; } - Err(_) => break, // channel closed + Err(_) => break, } } msg = socket.recv() => { diff --git a/rust_core/static/ui_page.html b/rust_core/static/ui_page.html index 5cf51bd..9793c59 100644 --- a/rust_core/static/ui_page.html +++ b/rust_core/static/ui_page.html @@ -74,6 +74,10 @@ let interactions = {}; async function render() { try { + // Collect current input values before sending + uiEl.querySelectorAll('input[data-wid]').forEach(inp => { + if (inp.type === 'text') interactions[inp.dataset.wid] = inp.value; + }); const res = await fetch('/api/' + MODULE + '/__render', { method: 'POST', headers: {'Content-Type': 'application/json'}, @@ -162,7 +166,8 @@ function createWidget(w) { const inp = document.createElement('input'); inp.type = 'text'; inp.value = w.value || ''; - inp.onchange = () => { interactions[w.id] = inp.value; render(); }; + inp.dataset.wid = w.id; + inp.oninput = () => { interactions[w.id] = inp.value; }; d.appendChild(lbl); d.appendChild(inp); return d;