diff --git a/jdbc-bridge/src/main/java/sqlkit/bridge/BridgeMain.java b/jdbc-bridge/src/main/java/sqlkit/bridge/BridgeMain.java index 3aeda772..20037711 100644 --- a/jdbc-bridge/src/main/java/sqlkit/bridge/BridgeMain.java +++ b/jdbc-bridge/src/main/java/sqlkit/bridge/BridgeMain.java @@ -15,7 +15,10 @@ public class BridgeMain { private static final ObjectMapper MAPPER = new ObjectMapper(); - private static final ProtocolHandler HANDLER = new ProtocolHandler(new ConnectionManager()); + private static final ProtocolHandler HANDLER = new ProtocolHandler( + new ConnectionManager(), + progressJson -> System.out.println(progressJson) + ); public static void main(String[] args) throws Exception { if (args.length > 0 && "--version".equals(args[0])) { diff --git a/jdbc-bridge/src/main/java/sqlkit/bridge/DriverResolver.java b/jdbc-bridge/src/main/java/sqlkit/bridge/DriverResolver.java index dacd9e08..2cd3360d 100644 --- a/jdbc-bridge/src/main/java/sqlkit/bridge/DriverResolver.java +++ b/jdbc-bridge/src/main/java/sqlkit/bridge/DriverResolver.java @@ -41,18 +41,20 @@ private static String getDriversCacheDir() { * @param versionCap Optional max version to cap against. Null means resolve LATEST. * @param classifier Optional Maven classifier (e.g. "standalone"). Null means no classifier. * @param downloadUrl Direct download URL for drivers NOT on Maven Central. Null means use Maven. + * @param progressCb Callback invoked with (downloaded, total) during JAR download, or null. * @return DriverResult with path to the cached JAR and resolved version */ public static DriverResult resolve(String mavenGroup, String mavenArtifact, String versionCap, String classifier, - String downloadUrl) throws Exception { + String downloadUrl, ProgressCallback progressCb) throws Exception { if (downloadUrl != null && !downloadUrl.isEmpty()) { - return resolveDirect(mavenArtifact, versionCap, downloadUrl); + return resolveDirect(mavenArtifact, versionCap, downloadUrl, progressCb); } - return resolveFromMaven(mavenGroup, mavenArtifact, versionCap, classifier); + return resolveFromMaven(mavenGroup, mavenArtifact, versionCap, classifier, progressCb); } - private static DriverResult resolveDirect(String mavenArtifact, String versionCap, String downloadUrl) throws Exception { + private static DriverResult resolveDirect(String mavenArtifact, String versionCap, + String downloadUrl, ProgressCallback progressCb) throws Exception { String version = (versionCap != null && !versionCap.isEmpty()) ? versionCap : "1.0.0"; String jarFilename = mavenArtifact + "-" + version + ".jar"; Path destPath = Paths.get(DRIVERS_CACHE, mavenArtifact, jarFilename); @@ -67,16 +69,30 @@ private static DriverResult resolveDirect(String mavenArtifact, String versionCa if (!response.isSuccessful()) { throw new Exception("Failed to download JAR: HTTP " + response.code() + " for " + downloadUrl); } - byte[] jarBytes = response.body() != null ? response.body().bytes() : new byte[0]; - response.close(); - Files.write(destPath, jarBytes); + if (response.body() != null) { + long contentLength = response.body().contentLength(); + long totalBytes = contentLength > 0 ? contentLength : 5_000_000L; + long downloadedBytes = 0L; + try (InputStream in = new BufferedInputStream(response.body().byteStream()); + OutputStream out = Files.newOutputStream(destPath)) { + byte[] buffer = new byte[4096]; + int bytesRead; + while ((bytesRead = in.read(buffer)) != -1) { + out.write(buffer, 0, bytesRead); + downloadedBytes += bytesRead; + if (progressCb != null) + progressCb.onProgress(downloadedBytes, totalBytes); + } + } + } } return new DriverResult(destPath.toAbsolutePath().toString(), version); } private static DriverResult resolveFromMaven(String mavenGroup, String mavenArtifact, - String versionCap, String classifier) throws Exception { + String versionCap, String classifier, + ProgressCallback progressCb) throws Exception { // 1. Fetch maven-metadata.xml String metadataUrl = String.format("%s/%s/%s/maven-metadata.xml", MAVEN_CENTRAL, mavenGroup.replace('.', '/'), mavenArtifact); @@ -135,11 +151,26 @@ private static DriverResult resolveFromMaven(String mavenGroup, String mavenArti if (!response.isSuccessful()) { throw new Exception("Failed to download JAR: HTTP " + response.code() + " for " + downloadUrl); } + if (response.body() == null) { + throw new Exception("Empty response body when downloading JAR from " + downloadUrl); + } - byte[] jarBytes = response.body() != null ? response.body().bytes() : new byte[0]; - response.close(); + long contentLength = response.body().contentLength(); + long totalBytes = contentLength > 0 ? contentLength : 5_000_000L; + long downloadedBytes = 0L; - Files.write(destPath, jarBytes); + try (InputStream in = new BufferedInputStream(response.body().byteStream()); + OutputStream out = Files.newOutputStream(destPath)) { + byte[] buffer = new byte[4096]; + int bytesRead; + while ((bytesRead = in.read(buffer)) != -1) { + out.write(buffer, 0, bytesRead); + downloadedBytes += bytesRead; + if (progressCb != null) { + progressCb.onProgress(downloadedBytes, totalBytes); + } + } + } } return new DriverResult(destPath.toAbsolutePath().toString(), latestVersion); @@ -194,6 +225,13 @@ private static int tryParseInt(String s) { } } + /** + * Callback interface for tracking JAR download progress. + */ + public interface ProgressCallback { + void onProgress(long downloaded, long total); + } + /** * Result of a driver resolution. */ diff --git a/jdbc-bridge/src/main/java/sqlkit/bridge/ProtocolHandler.java b/jdbc-bridge/src/main/java/sqlkit/bridge/ProtocolHandler.java index b5ae6301..85447e8f 100644 --- a/jdbc-bridge/src/main/java/sqlkit/bridge/ProtocolHandler.java +++ b/jdbc-bridge/src/main/java/sqlkit/bridge/ProtocolHandler.java @@ -10,6 +10,7 @@ import java.util.List; import java.util.Map; import java.util.UUID; +import java.util.function.Consumer; /** * Dispatches JSON-RPC requests to the appropriate handler. @@ -19,9 +20,11 @@ public class ProtocolHandler { private static final ObjectMapper MAPPER = new ObjectMapper(); private final ConnectionManager connectionManager; + private final Consumer progressConsumer; - public ProtocolHandler(ConnectionManager connectionManager) { + public ProtocolHandler(ConnectionManager connectionManager, Consumer progressConsumer) { this.connectionManager = connectionManager; + this.progressConsumer = progressConsumer; } /** @@ -222,7 +225,20 @@ private void handleResolveDriver(JsonNode params, ObjectNode response) throws Ex String downloadUrl = params.has("download_url") && !params.get("download_url").isNull() ? params.get("download_url").asText() : null; - DriverResolver.DriverResult result = DriverResolver.resolve(mavenGroup, mavenArtifact, versionCap, classifier, downloadUrl); + DriverResolver.ProgressCallback progressCb = (downloaded, total) -> { + try { + ObjectNode progressNode = MAPPER.createObjectNode(); + progressNode.put("phase", "progress"); + progressNode.put("downloaded", downloaded); + progressNode.put("total", total); + progressConsumer.accept(MAPPER.writeValueAsString(progressNode)); + } catch (Exception e) { + // Ignore progress reporting errors + } + }; + + DriverResolver.DriverResult result = DriverResolver.resolve( + mavenGroup, mavenArtifact, versionCap, classifier, downloadUrl, progressCb); ObjectNode resultNode = MAPPER.createObjectNode(); resultNode.put("jar_path", result.getJarPath()); diff --git a/package-lock.json b/package-lock.json index 7ae73532..36fe2e32 100644 --- a/package-lock.json +++ b/package-lock.json @@ -227,7 +227,6 @@ "integrity": "sha512-H3mcG6ZDLTlYfaSNi0iOKkigqMFvkTKlGUYlD8GW7nNOYRrevuA46iTypPyv+06V3fEmvvazfntkBU34L0azAw==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@babel/code-frame": "^7.28.6", "@babel/generator": "^7.28.6", @@ -2710,7 +2709,6 @@ "integrity": "sha512-PsSugIf9ip1H/mWKj4bi/BlEoerxXAda9ByRFsYuwsmr6af9NxJL0AaiNXs8Le7R21QR5KMiD/KdxZZ71LjAxQ==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@eslint-community/eslint-utils": "^4.9.1", "@typescript-eslint/types": "^8.52.0", @@ -3277,7 +3275,6 @@ "integrity": "sha512-eEXsVvLPu8Z4PkFibtuFJLJOTAV/nPdgtSjkGoPpddpFk3/ym2oy97jynY6ic2m6+nc5M8SE1e9v/mHKsulcJg==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@eslint-community/regexpp": "^4.12.2", "@typescript-eslint/scope-manager": "8.53.0", @@ -3307,7 +3304,6 @@ "integrity": "sha512-npiaib8XzbjtzS2N4HlqPvlpxpmZ14FjSJrteZpPxGUaYPlvhzlzUZ4mZyABo0EFrOWnvyd0Xxroq//hKhtAWg==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@typescript-eslint/scope-manager": "8.53.0", "@typescript-eslint/types": "8.53.0", @@ -3598,7 +3594,6 @@ "integrity": "sha512-6i+kNVVGb5I+XOCUgLDHol6sxb28ahbJrmLL/eiAWflT850MUqtAlFeyHDYHzScnOFcPy1AcmeY0yM77GgERmg==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@typescript-eslint/utils": "^8.53.0", "@unocss/config": "66.6.0", @@ -4301,7 +4296,6 @@ "resolved": "https://registry.npmjs.org/@vue/compiler-sfc/-/compiler-sfc-3.5.26.tgz", "integrity": "sha512-egp69qDTSEZcf4bGOSsprUr4xI73wfrY5oRs6GSgXFTiHrWj4Y3X5Ydtip9QMqiCMCPVwLglB9GBxXtTadJ3mA==", "license": "MIT", - "peer": true, "dependencies": { "@babel/parser": "^7.28.5", "@vue/compiler-core": "3.5.26", @@ -4492,7 +4486,6 @@ "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz", "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==", "license": "MIT", - "peer": true, "bin": { "acorn": "bin/acorn" }, @@ -4819,7 +4812,6 @@ } ], "license": "MIT", - "peer": true, "dependencies": { "baseline-browser-mapping": "^2.9.0", "caniuse-lite": "^1.0.30001759", @@ -5617,7 +5609,6 @@ "integrity": "sha512-LEyamqS7W5HB3ujJyvi0HQK/dtVINZvd5mAAp9eT5S/ujByGjiZLCzPcHVzuXbpJDJF/cxwHlfceVUDZ2lnSTw==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@eslint-community/eslint-utils": "^4.8.0", "@eslint-community/regexpp": "^4.12.1", @@ -5848,7 +5839,6 @@ "integrity": "sha512-nK96Gnt6/9wj8KhTFg+D80Mc01cffrcB15NO6pkTJmPpO0vHV+9yxegr+wVry4O3uGbu83HN86inCO3IsML9Rw==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@dprint/formatter": "^0.5.1", "@dprint/markdown": "^0.20.0", @@ -7304,7 +7294,6 @@ "integrity": "sha512-F26gjC0yWN8uAA5m5Ss8ZQf5nDHWGlN/xWZIh8S5SRbsEKBovwZhxGd6LJlbZYxBgCYOtreSUyb8hpXyGC5O4A==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@jest/core": "30.2.0", "@jest/types": "30.2.0", @@ -7884,7 +7873,6 @@ "resolved": "https://registry.npmjs.org/jiti/-/jiti-2.6.1.tgz", "integrity": "sha512-ekilCSN1jwRvIbgeg/57YFh8qQDNbwDb9xT/qu2DAHbFFZUicIl4ygVaAvzveMhMVr3LnpSKTNnwt8PoOfmKhQ==", "license": "MIT", - "peer": true, "bin": { "jiti": "lib/jiti-cli.mjs" } @@ -7982,7 +7970,6 @@ "integrity": "sha512-1e4qoRgnn448pRuMvKGsFFymUCquZV0mpGgOyIKNgD3JVDTsVJyRBGH/Fm0tBb8WsWGgmB1mDe6/yJMQM37DUA==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "acorn": "^8.5.0", "eslint-visitor-keys": "^3.0.0", @@ -9739,7 +9726,6 @@ "resolved": "https://registry.npmjs.org/pinia/-/pinia-3.0.4.tgz", "integrity": "sha512-l7pqLUFTI/+ESXn6k3nu30ZIzW5E2WZF/LaHJEpoq6ElcLD+wduZoB2kBN19du6K/4FDpPMazY2wJr+IndBtQw==", "license": "MIT", - "peer": true, "dependencies": { "@vue/devtools-api": "^7.7.7" }, @@ -11158,7 +11144,6 @@ "integrity": "sha512-hjcS1mhfuyi4WW8IWtjP7brDrG2cuDZukyrYrSauoXGNgx0S7zceP07adYkJycEr56BOUTNPzbInooiN3fn1qw==", "devOptional": true, "license": "Apache-2.0", - "peer": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" @@ -11305,7 +11290,6 @@ "resolved": "https://registry.npmjs.org/unocss/-/unocss-66.6.0.tgz", "integrity": "sha512-B5QsMJzFKeTHPzF5Ehr8tSMuhxzbCR9n+XP0GyhK9/2jTcBdI0/T+rCDDr9m6vUz+lku/coCVz7VAQ2BRAbZJw==", "license": "MIT", - "peer": true, "dependencies": { "@unocss/astro": "66.6.0", "@unocss/cli": "66.6.0", @@ -11351,7 +11335,6 @@ "resolved": "https://registry.npmjs.org/unocss-preset-animations/-/unocss-preset-animations-1.3.0.tgz", "integrity": "sha512-NLsBzPB98Jc8b6+t8nwLItI12ZE/48IZVccZ2uA2owLoeAvhRaDoRvs6eLcJfIUZfRycL25xxOeQoOBMo1p/aw==", "license": "MIT", - "peer": true, "peerDependencies": { "@unocss/preset-wind3": ">=0.56.0 < 101", "unocss": ">=0.56.0 < 101" @@ -11492,7 +11475,6 @@ "resolved": "https://registry.npmjs.org/vite/-/vite-6.4.3.tgz", "integrity": "sha512-NTKlcQjlAK7MlQoyb6LgaqHc8sso/pVyUJYWMws3jg21uTJw/LddqIFPcPqP6PzpgbIcZyKI85sFE4HBrQDA8A==", "license": "MIT", - "peer": true, "dependencies": { "esbuild": "^0.25.0", "fdir": "^6.4.4", @@ -11574,7 +11556,6 @@ "resolved": "https://registry.npmjs.org/vue/-/vue-3.5.26.tgz", "integrity": "sha512-SJ/NTccVyAoNUJmkM9KUqPcYlY+u8OVL1X5EW9RIs3ch5H2uERxyyIUI4MRxVCSOiEcupX9xNGde1tL9ZKpimA==", "license": "MIT", - "peer": true, "dependencies": { "@vue/compiler-dom": "3.5.26", "@vue/compiler-sfc": "3.5.26", @@ -11597,7 +11578,6 @@ "integrity": "sha512-CydUvFOQKD928UzZhTp4pr2vWz1L+H99t7Pkln2QSPdvmURT0MoC4wUccfCnuEaihNsu9aYYyk+bep8rlfkUXw==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "debug": "^4.4.0", "eslint-scope": "^8.2.0", diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 7767e347..0cb1fdd4 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -88,7 +88,7 @@ http = "1" log = "0.4" futures = "0.3" rand = "0.8" -data-studio-agent = { path = "/Users/blankll/Documents/devs/geekfun/data-studio-agent" } +data-studio-agent = { path = "../../data-studio-agent" } # Archive extraction (JRE downloads) flate2 = "1.0" diff --git a/src-tauri/src/agent_adapters.rs b/src-tauri/src/agent_adapters.rs index 5c8dcb2e..197924ae 100644 --- a/src-tauri/src/agent_adapters.rs +++ b/src-tauri/src/agent_adapters.rs @@ -8,8 +8,8 @@ use std::collections::HashMap; use std::sync::Arc; use data_studio_agent as lib; +use data_studio_agent::storage; use data_studio_agent::traits::{CancelMap, ConfirmMap, EventEmitter}; -use data_studio_agent::storage as storage; use serde_json::Value; use tauri::{AppHandle, Emitter, Manager, State}; @@ -44,7 +44,8 @@ pub async fn run_agent_loop( let confirm_map: ConfirmMap = confirm_state.inner().clone(); let cancel_state: State = app.state::(); let cancel_map: CancelMap = cancel_state.inner().clone(); - let executor_state: State> = app.state::>(); + let executor_state: State> = + app.state::>(); let executor: Arc = executor_state.inner().clone(); let connections: HashMap = settings @@ -158,8 +159,7 @@ pub async fn run_agent_step( base_url: Option, ) -> Result { let result = lib::harness::run_agent_step( - provider, model, messages, tools, - http_proxy, proxy_mode, api_key, base_url, + provider, model, messages, tools, http_proxy, proxy_mode, api_key, base_url, ) .await?; @@ -178,7 +178,8 @@ pub async fn validate_llm_config( proxy_mode: Option, base_url: Option, ) -> Result { - lib::harness::validate_llm_config(provider, api_key, model, http_proxy, proxy_mode, base_url).await + lib::harness::validate_llm_config(provider, api_key, model, http_proxy, proxy_mode, base_url) + .await } #[tauri::command] diff --git a/src-tauri/src/capabilities/sql.rs b/src-tauri/src/capabilities/sql.rs index c159a686..96b70156 100644 --- a/src-tauri/src/capabilities/sql.rs +++ b/src-tauri/src/capabilities/sql.rs @@ -41,22 +41,22 @@ async fn resolve_adapter(connection_id: &str) -> Result { - let qualified = - build_qualified_table(query.schema.as_deref(), &query.table, db_type); + let qualified = build_qualified_table(query.schema.as_deref(), &query.table, db_type); let sql = build_paginated_select(&qualified, filter_ref, limit_val, offset_val, db_type); if let Some(ref db) = query.database { @@ -385,8 +384,7 @@ pub async fn get_table_data( connection.execute_query(&sql).await } ActiveConnection::SQLServer(adapter) => { - let qualified = - build_qualified_table(query.schema.as_deref(), &query.table, db_type); + let qualified = build_qualified_table(query.schema.as_deref(), &query.table, db_type); let sql = build_paginated_select(&qualified, filter_ref, limit_val, offset_val, db_type); if let Some(ref db) = query.database { @@ -408,8 +406,7 @@ pub async fn get_table_data( connection.execute_query(&sql).await } _ => { - let qualified = - build_qualified_table(query.schema.as_deref(), &query.table, db_type); + let qualified = build_qualified_table(query.schema.as_deref(), &query.table, db_type); let sql = build_paginated_select(&qualified, filter_ref, limit_val, offset_val, db_type); connection.execute_query(&sql).await @@ -967,7 +964,12 @@ pub async fn get_object_ddl( }; connection - .get_object_ddl(Some(&database), schema.as_deref(), &object_name, &object_type) + .get_object_ddl( + Some(&database), + schema.as_deref(), + &object_name, + &object_type, + ) .await .map_err(|e| format!("Failed to get object DDL: {}", e)) } @@ -991,7 +993,12 @@ pub async fn drop_object( }; connection - .drop_object(Some(&database), schema.as_deref(), &object_name, &object_type) + .drop_object( + Some(&database), + schema.as_deref(), + &object_name, + &object_type, + ) .await .map_err(|e| format!("Failed to drop object: {}", e)) } @@ -1016,7 +1023,13 @@ pub async fn rename_object( }; connection - .rename_object(Some(&database), schema.as_deref(), &object_name, &object_type, &new_name) + .rename_object( + Some(&database), + schema.as_deref(), + &object_name, + &object_type, + &new_name, + ) .await .map_err(|e| format!("Failed to rename object: {}", e)) } diff --git a/src-tauri/src/commands/connection.rs b/src-tauri/src/commands/connection.rs index e14d0fcf..64da7b50 100644 --- a/src-tauri/src/commands/connection.rs +++ b/src-tauri/src/commands/connection.rs @@ -82,7 +82,9 @@ pub async fn get_connection_status( /// Returns quality data including latency, error count, and a composite score (0-100). /// Returns an error if no health data exists for the given connection. #[tauri::command] -pub async fn get_connection_quality(connection_id: String) -> Result { +pub async fn get_connection_quality( + connection_id: String, +) -> Result { let guardian = crate::GUARDIAN .get() .ok_or_else(|| "Guardian not initialized".to_string())?; diff --git a/src-tauri/src/commands/helpers.rs b/src-tauri/src/commands/helpers.rs index 6a59646d..85de96c1 100644 --- a/src-tauri/src/commands/helpers.rs +++ b/src-tauri/src/commands/helpers.rs @@ -10,8 +10,8 @@ use crate::database::{ mysql::MySQLAdapter, postgres::PostgresAdapter, sqlite::SQLiteAdapter, sqlserver::SqlServerAdapter, }; -use crate::ssh::TunnelManager; use crate::ssh::start_transport_layers; +use crate::ssh::TunnelManager; use crate::state::ActiveConnection; use std::sync::Arc; use tokio::sync::Mutex; @@ -26,36 +26,34 @@ pub async fn create_and_connect_adapter( let strategy = resolve_effective_type(db_type_enum); match strategy { - ConnectionStrategy::Native(core) => { - match core { - CoreDatabaseType::PostgreSQL => { - let mut adapter = PostgresAdapter::new(conn_config); - adapter.connect().await.map_err(|e| e.to_string())?; - Ok(ActiveConnection::Postgres(Arc::new(Mutex::new(adapter)))) - } - CoreDatabaseType::MySQL => { - let mut adapter = MySQLAdapter::new(conn_config); - adapter.connect().await.map_err(|e| e.to_string())?; - Ok(ActiveConnection::MySQL(Arc::new(Mutex::new(adapter)))) - } - CoreDatabaseType::SqlServer => { - let mut adapter = SqlServerAdapter::new(conn_config); - adapter.connect().await.map_err(|e| e.to_string())?; - Ok(ActiveConnection::SQLServer(Arc::new(Mutex::new(adapter)))) - } - CoreDatabaseType::SQLite => { - let mut adapter = SQLiteAdapter::new(conn_config); - adapter.connect().await.map_err(|e| e.to_string())?; - Ok(ActiveConnection::SQLite(Arc::new(Mutex::new(adapter)))) - } - CoreDatabaseType::ClickHouse => { - let mut adapter = ClickHouseAdapter::new(conn_config); - adapter.connect().await.map_err(|e| e.to_string())?; - Ok(ActiveConnection::ClickHouse(Arc::new(Mutex::new(adapter)))) - } - _ => Err(format!("Native adapter not yet implemented for {:?}", core)), + ConnectionStrategy::Native(core) => match core { + CoreDatabaseType::PostgreSQL => { + let mut adapter = PostgresAdapter::new(conn_config); + adapter.connect().await.map_err(|e| e.to_string())?; + Ok(ActiveConnection::Postgres(Arc::new(Mutex::new(adapter)))) } - } + CoreDatabaseType::MySQL => { + let mut adapter = MySQLAdapter::new(conn_config); + adapter.connect().await.map_err(|e| e.to_string())?; + Ok(ActiveConnection::MySQL(Arc::new(Mutex::new(adapter)))) + } + CoreDatabaseType::SqlServer => { + let mut adapter = SqlServerAdapter::new(conn_config); + adapter.connect().await.map_err(|e| e.to_string())?; + Ok(ActiveConnection::SQLServer(Arc::new(Mutex::new(adapter)))) + } + CoreDatabaseType::SQLite => { + let mut adapter = SQLiteAdapter::new(conn_config); + adapter.connect().await.map_err(|e| e.to_string())?; + Ok(ActiveConnection::SQLite(Arc::new(Mutex::new(adapter)))) + } + CoreDatabaseType::ClickHouse => { + let mut adapter = ClickHouseAdapter::new(conn_config); + adapter.connect().await.map_err(|e| e.to_string())?; + Ok(ActiveConnection::ClickHouse(Arc::new(Mutex::new(adapter)))) + } + _ => Err(format!("Native adapter not yet implemented for {:?}", core)), + }, ConnectionStrategy::JdbcBridge => { let mut adapter = JdbcBridgeAdapter::new(conn_config); adapter.connect().await.map_err(|e| e.to_string())?; @@ -250,7 +248,8 @@ pub async fn connection_host_port( return Ok((config.host.clone(), config.port)); } - match start_transport_layers(connection_id, &layers, &config.host, config.port, tunnels).await? { + match start_transport_layers(connection_id, &layers, &config.host, config.port, tunnels).await? + { Some(local_port) => Ok(("127.0.0.1".to_string(), local_port)), None => Ok((config.host.clone(), config.port)), } diff --git a/src-tauri/src/commands/jdbc.rs b/src-tauri/src/commands/jdbc.rs index 966da013..3d294600 100644 --- a/src-tauri/src/commands/jdbc.rs +++ b/src-tauri/src/commands/jdbc.rs @@ -1,5 +1,11 @@ use crate::database::config::DatabaseType; -use crate::database::jdbc_bridge::{download, jre, launcher::JdbcBridgeLauncher, protocol::{JdbcMethod, JdbcRequest}, registry::DriverRegistry, tns_parser}; +use crate::database::jdbc_bridge::{ + download, jre, + launcher::JdbcBridgeLauncher, + protocol::{JdbcMethod, JdbcRequest}, + registry::DriverRegistry, + tns_parser, +}; use serde::Serialize; use std::path::PathBuf; @@ -33,8 +39,7 @@ pub async fn check_jre_status() -> Result { source: "managed".to_string(), }) } else if let Some(system_java) = jre::JreDetector::detect_system_java() { - let version = jre::system_java_version(&system_java) - .map(|v| format!("{}.x", v)); + let version = jre::system_java_version(&system_java).map(|v| format!("{}.x", v)); Ok(JreStatus { installed: true, version, @@ -53,7 +58,18 @@ pub async fn check_jre_status() -> Result { #[tauri::command] pub async fn download_jre() -> Result<(), String> { - jre::download_managed_jre().await.map_err(|e| e.to_string()) + let result = jre::download_managed_jre().await; + match result { + Ok(()) => { + crate::download::emit_complete("jre", crate::download::DownloadKind::Jre); + Ok(()) + } + Err(e) => { + let err = e.to_string(); + crate::download::emit_error("jre", crate::download::DownloadKind::Jre, &err); + Err(err) + } + } } #[tauri::command] @@ -74,9 +90,13 @@ pub async fn list_drivers() -> Result, String> { version_cap: config.version_cap.clone(), filename: cached.as_ref().map(|c| c.0.clone()), file_size: cached.as_ref().map(|c| c.1), - resolved_version: cached - .as_ref() - .and_then(|c| parse_jar_version(&config.maven_artifact, config.maven_classifier.as_deref(), &c.0)), + resolved_version: cached.as_ref().and_then(|c| { + parse_jar_version( + &config.maven_artifact, + config.maven_classifier.as_deref(), + &c.0, + ) + }), }); } Ok(result) @@ -90,10 +110,16 @@ pub async fn download_driver(db_type: String) -> Result<(), String> { .get_config(dt) .ok_or_else(|| format!("No registry entry for {}", db_type))?; + crate::download::emit_progress(&db_type, crate::download::DownloadKind::Driver, 0, 1); + // Start a temporary Java bridge process to resolve the driver let bridge_jar = download::bridge_jar_path(); let mut launcher = JdbcBridgeLauncher::new(bridge_jar); - launcher.start(&[]).map_err(|e| e.to_string())?; + if let Err(e) = launcher.start(&[]) { + let err = e.to_string(); + crate::download::emit_error(&db_type, crate::download::DownloadKind::Driver, &err); + return Err(err); + } let req = JdbcRequest::new( JdbcMethod::ResolveDriver, @@ -104,14 +130,29 @@ pub async fn download_driver(db_type: String) -> Result<(), String> { "maven_classifier": config.maven_classifier, }), ); - let resp = launcher.send_request(&req).map_err(|e| e.to_string())?; + let resp = launcher + .send_request_with_progress(&req, |downloaded, total| { + crate::download::emit_progress( + &db_type, + crate::download::DownloadKind::Driver, + downloaded, + total, + ); + }) + .map_err(|e| { + let err = e.to_string(); + crate::download::emit_error(&db_type, crate::download::DownloadKind::Driver, &err); + err + })?; if let Some(err) = resp.error { launcher.shutdown(); + crate::download::emit_error(&db_type, crate::download::DownloadKind::Driver, &err); return Err(err); } // Driver is now cached on disk by the Java bridge launcher.shutdown(); + crate::download::emit_complete(&db_type, crate::download::DownloadKind::Driver); Ok(()) } @@ -154,6 +195,10 @@ fn parse_db_type(s: &str) -> Result { "cassandra" => Ok(DatabaseType::Cassandra), "iris" => Ok(DatabaseType::Iris), "access" => Ok(DatabaseType::Access), + "tdengine" | "td" => Ok(DatabaseType::TDengine), + "yashandb" => Ok(DatabaseType::YashanDB), + "kingbasees" | "kingbase" => Ok(DatabaseType::KingbaseES), + "oceanbase_oracle" | "oceanbase-oracle" => Ok(DatabaseType::OceanbaseOracle), _ => Err(format!("Unknown JDBC database type: {}", s)), } } @@ -218,9 +263,18 @@ pub async fn list_tns_aliases(tns_admin_dir: String) -> Result, Stri /// Does NOT require Java — purely HTTP download, parallel-safe. #[tauri::command] pub async fn download_jdbc_driver_direct(db_type: String) -> Result<(), String> { - download::download_jdbc_driver_direct(&db_type) - .await - .map_err(|e| e.to_string()) + let result = download::download_jdbc_driver_direct(&db_type).await; + match result { + Ok(()) => { + crate::download::emit_complete(&db_type, crate::download::DownloadKind::Driver); + Ok(()) + } + Err(e) => { + let err = e.to_string(); + crate::download::emit_error(&db_type, crate::download::DownloadKind::Driver, &err); + Err(err) + } + } } #[tauri::command] @@ -245,9 +299,18 @@ pub async fn check_bridge_status() -> Result { #[tauri::command] pub async fn download_bridge_jar() -> Result<(), String> { - download::download_bridge_plugin() - .await - .map_err(|e| e.to_string()) + let result = download::download_bridge_plugin().await; + match result { + Ok(()) => { + crate::download::emit_complete("bridge", crate::download::DownloadKind::Bridge); + Ok(()) + } + Err(e) => { + let err = e.to_string(); + crate::download::emit_error("bridge", crate::download::DownloadKind::Bridge, &err); + Err(err) + } + } } #[tauri::command] @@ -267,6 +330,102 @@ pub struct JreUpdateStatus { pub update_available: bool, } +#[derive(Serialize)] +pub struct DriverUpdateStatus { + pub current_version: Option, + pub latest_version: Option, + pub update_available: bool, +} + +#[tauri::command] +pub async fn check_driver_update(db_type: String) -> Result { + let registry = DriverRegistry::load(); + let config = registry + .get_config_by_name(&db_type) + .ok_or_else(|| format!("Unknown driver type: {}", db_type))?; + + let current_version = driver_cache_info(&config.maven_artifact).and_then(|(filename, _)| { + parse_jar_version( + &config.maven_artifact, + config.maven_classifier.as_deref(), + &filename, + ) + }); + + // Drivers served from a direct URL (not on Maven Central, e.g. GBase 8a). + // Check by HEAD'ing the URL and comparing the Content-Length against the + // locally cached JAR size — a different length signals a new release. + if let Some(download_url) = &config.download_url { + let local_size = driver_cache_info(&config.maven_artifact).map(|(_, size)| size); + match reqwest::Client::new().head(download_url).send().await { + Ok(resp) if resp.status().is_success() => { + let remote_size = resp + .headers() + .get(reqwest::header::CONTENT_LENGTH) + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()); + let update_available = match (remote_size, local_size) { + (Some(r), Some(l)) => r != l, + (Some(_), None) => true, + _ => false, + }; + let latest_version = remote_size.map(|s| format!("{} bytes", s)); + return Ok(DriverUpdateStatus { + current_version: local_size.map(|s| format!("{} bytes", s)), + latest_version, + update_available, + }); + } + _ => { + return Ok(DriverUpdateStatus { + current_version: local_size.map(|s| format!("{} bytes", s)), + latest_version: None, + update_available: false, + }); + } + } + } + + // If version_cap is set, treat it as the pinned latest version + // (avoids spurious "update available" for intentionally pinned drivers) + let latest_version = if let Some(cap) = &config.version_cap { + Some(cap.clone()) + } else { + let group_path = config.maven_group.replace('.', "/"); + let url = format!( + "https://repo1.maven.org/maven2/{}/{}/maven-metadata.xml", + group_path, config.maven_artifact + ); + match reqwest::get(&url).await { + Ok(resp) if resp.status().is_success() => match resp.text().await { + Ok(xml) => parse_maven_latest_version(&xml), + Err(_) => None, + }, + _ => None, + } + }; + + let update_available = match (¤t_version, &latest_version) { + (Some(c), Some(l)) => c != l, + (None, Some(_)) => true, + _ => false, + }; + + Ok(DriverUpdateStatus { + current_version, + latest_version, + update_available, + }) +} + +/// Extract the `` version tag from a Maven metadata XML string. +fn parse_maven_latest_version(xml: &str) -> Option { + let start = xml.find("")?; + let value_start = start + "".len(); + let end = xml[value_start..].find("")?; + Some(xml[value_start..value_start + end].to_string()) +} + #[tauri::command] pub async fn check_jre_update() -> Result { let current = jre::read_jre_version(); diff --git a/src-tauri/src/commands/query.rs b/src-tauri/src/commands/query.rs index 126ce971..50c63f68 100644 --- a/src-tauri/src/commands/query.rs +++ b/src-tauri/src/commands/query.rs @@ -7,9 +7,9 @@ use crate::api_response::{db_error_to_api_error, ApiResponse}; use crate::connection::guardian::HealthState; use crate::connection::handle::ConnectionHandle; use crate::database::{ - ClickHouseAdapter, ConnectionConfig, DatabaseAdapter, DbError, ExplainResult, - HttpSqlAdapter, JdbcBridgeAdapter, MySQLAdapter, PostgresAdapter, QueryResult, RqliteAdapter, - SqlServerAdapter, TursoAdapter, + ClickHouseAdapter, ConnectionConfig, DatabaseAdapter, DbError, ExplainResult, HttpSqlAdapter, + JdbcBridgeAdapter, MySQLAdapter, PostgresAdapter, QueryResult, RqliteAdapter, SqlServerAdapter, + TursoAdapter, }; use crate::state::{ActiveConnection, AppState}; use std::sync::Arc; @@ -259,10 +259,14 @@ pub async fn execute_query( }; // Cache this handle for future cross-database lookups - state.cache.get_or_create( - &crate::connection::cache::PoolKey::new(&connection_id, database.as_deref()), - state.inner(), - ).await.ok(); + state + .cache + .get_or_create( + &crate::connection::cache::PoolKey::new(&connection_id, database.as_deref()), + state.inner(), + ) + .await + .ok(); // Guardian health check if let Some(guardian) = crate::GUARDIAN.get() { @@ -296,7 +300,9 @@ pub async fn execute_query( match result { Ok(data) => { if let Some(guardian) = crate::GUARDIAN.get() { - guardian.mark_healthy(&connection_id, Some(elapsed_ms)).await; + guardian + .mark_healthy(&connection_id, Some(elapsed_ms)) + .await; } Ok(ApiResponse::success(data)) } @@ -479,9 +485,9 @@ pub async fn explain_query( return match kind { TempExplainKind::Postgres(cfg) => { let mut temp = PostgresAdapter::new(cfg); - temp.connect().await.map_err(|e| { - format!("Failed to connect to database for EXPLAIN: {}", e) - })?; + temp.connect() + .await + .map_err(|e| format!("Failed to connect to database for EXPLAIN: {}", e))?; let connection = ActiveConnection::Postgres(Arc::new(Mutex::new(temp))); let database_type = "postgresql"; let explain_sql = if analyze { @@ -504,9 +510,9 @@ pub async fn explain_query( } TempExplainKind::MySQL(cfg) => { let mut temp = MySQLAdapter::new(cfg); - temp.connect().await.map_err(|e| { - format!("Failed to connect to database for EXPLAIN: {}", e) - })?; + temp.connect() + .await + .map_err(|e| format!("Failed to connect to database for EXPLAIN: {}", e))?; let connection = ActiveConnection::MySQL(Arc::new(Mutex::new(temp))); let database_type = "mysql"; let (explain_sql, plan_format) = if analyze { @@ -529,9 +535,9 @@ pub async fn explain_query( } TempExplainKind::SQLServer(cfg) => { let mut temp = SqlServerAdapter::new(cfg); - temp.connect().await.map_err(|e| { - format!("Failed to connect to database for EXPLAIN: {}", e) - })?; + temp.connect() + .await + .map_err(|e| format!("Failed to connect to database for EXPLAIN: {}", e))?; let connection = ActiveConnection::SQLServer(Arc::new(Mutex::new(temp))); let database_type = "sqlserver"; let settings = if analyze { @@ -560,9 +566,9 @@ pub async fn explain_query( } TempExplainKind::ClickHouse(cfg) => { let mut temp = ClickHouseAdapter::new(cfg); - temp.connect().await.map_err(|e| { - format!("Failed to connect to database for EXPLAIN: {}", e) - })?; + temp.connect() + .await + .map_err(|e| format!("Failed to connect to database for EXPLAIN: {}", e))?; let connection = ActiveConnection::ClickHouse(Arc::new(Mutex::new(temp))); let database_type = "clickhouse"; let explain_sql = format!("EXPLAIN {}", sql); @@ -581,9 +587,9 @@ pub async fn explain_query( } TempExplainKind::JdbcBridge(cfg) => { let mut temp = JdbcBridgeAdapter::new(cfg); - temp.connect().await.map_err(|e| { - format!("Failed to connect to database for EXPLAIN: {}", e) - })?; + temp.connect() + .await + .map_err(|e| format!("Failed to connect to database for EXPLAIN: {}", e))?; let connection = ActiveConnection::JdbcBridge(Arc::new(Mutex::new(temp))); let database_type = "generic"; let explain_sql = format!("EXPLAIN {}", sql); @@ -602,9 +608,9 @@ pub async fn explain_query( } TempExplainKind::HttpSql(cfg) => { let mut temp = HttpSqlAdapter::new(cfg); - temp.connect().await.map_err(|e| { - format!("Failed to connect to database for EXPLAIN: {}", e) - })?; + temp.connect() + .await + .map_err(|e| format!("Failed to connect to database for EXPLAIN: {}", e))?; let connection = ActiveConnection::HttpSql(Arc::new(Mutex::new(temp))); let database_type = "generic"; let explain_sql = format!("EXPLAIN {}", sql); @@ -623,9 +629,9 @@ pub async fn explain_query( } TempExplainKind::Rqlite(cfg) => { let mut temp = RqliteAdapter::new(cfg); - temp.connect().await.map_err(|e| { - format!("Failed to connect to database for EXPLAIN: {}", e) - })?; + temp.connect() + .await + .map_err(|e| format!("Failed to connect to database for EXPLAIN: {}", e))?; let connection = ActiveConnection::Rqlite(Arc::new(Mutex::new(temp))); let database_type = "rqlite"; let explain_sql = format!("EXPLAIN {}", sql); @@ -644,9 +650,9 @@ pub async fn explain_query( } TempExplainKind::Turso(cfg) => { let mut temp = TursoAdapter::new(cfg); - temp.connect().await.map_err(|e| { - format!("Failed to connect to database for EXPLAIN: {}", e) - })?; + temp.connect() + .await + .map_err(|e| format!("Failed to connect to database for EXPLAIN: {}", e))?; let connection = ActiveConnection::Turso(Arc::new(Mutex::new(temp))); let database_type = "turso"; let explain_sql = format!("EXPLAIN {}", sql); diff --git a/src-tauri/src/commands/server.rs b/src-tauri/src/commands/server.rs index def130a8..8b660b11 100644 --- a/src-tauri/src/commands/server.rs +++ b/src-tauri/src/commands/server.rs @@ -162,12 +162,8 @@ pub async fn test_connection(config: ServerConfig) -> Result { conn_config.host = host; diff --git a/src-tauri/src/connection/guardian.rs b/src-tauri/src/connection/guardian.rs index cf7b940b..5edc4f4d 100644 --- a/src-tauri/src/connection/guardian.rs +++ b/src-tauri/src/connection/guardian.rs @@ -14,10 +14,10 @@ use crate::database::adapter::DatabaseAdapter; use crate::state::{ActiveConnection, AppState}; use crate::APP_HANDLE; use serde::Serialize; -use tauri::{Emitter, Manager}; use std::collections::HashMap; use std::sync::Arc; use std::time::{Duration, Instant}; +use tauri::{Emitter, Manager}; use tokio::sync::RwLock; use tokio::time::MissedTickBehavior; @@ -314,7 +314,10 @@ impl ConnectionGuardian { continue; } // Gracefully disconnect idle connection - log::info!("Connection '{conn_id}' idle for {}s, evicting", self.idle_eviction_secs); + log::info!( + "Connection '{conn_id}' idle for {}s, evicting", + self.idle_eviction_secs + ); let conns = state.connections.write().await; if let Some(connection) = conns.get(&conn_id) { self.disconnect_connection(connection).await; @@ -374,7 +377,8 @@ impl ConnectionGuardian { self.mark_healthy(connection_id, None).await; } else { // Exponential backoff - let delay = (RECONNECT_BASE_DELAY_SECS * 2u64.pow(attempt_num)).min(RECONNECT_MAX_DELAY_SECS); + let delay = + (RECONNECT_BASE_DELAY_SECS * 2u64.pow(attempt_num)).min(RECONNECT_MAX_DELAY_SECS); let next = Instant::now() + Duration::from_secs(delay); let mut health_map = self.health.write().await; if let Some(h) = health_map.get_mut(connection_id) { @@ -384,7 +388,10 @@ impl ConnectionGuardian { self.emit_state_change( connection_id, HealthState::Dead, - Some(&format!("Reconnect attempt {} failed, retrying in {}s", attempt_num, delay)), + Some(&format!( + "Reconnect attempt {} failed, retrying in {}s", + attempt_num, delay + )), ); } } @@ -405,39 +412,22 @@ impl ConnectionGuardian { if success { self.mark_healthy(connection_id, None).await; } else { - self.mark_error(connection_id, "Health check ping failed", None).await; + self.mark_error(connection_id, "Health check ping failed", None) + .await; } } async fn ping_connection(&self, connection: &ActiveConnection) -> bool { let result = match connection { - ActiveConnection::Postgres(adapter) => { - adapter.lock().await.test_connection().await - } - ActiveConnection::MySQL(adapter) => { - adapter.lock().await.test_connection().await - } - ActiveConnection::SQLServer(adapter) => { - adapter.lock().await.test_connection().await - } - ActiveConnection::SQLite(adapter) => { - adapter.lock().await.test_connection().await - } - ActiveConnection::ClickHouse(adapter) => { - adapter.lock().await.test_connection().await - } - ActiveConnection::JdbcBridge(adapter) => { - adapter.lock().await.test_connection().await - } - ActiveConnection::HttpSql(adapter) => { - adapter.lock().await.test_connection().await - } - ActiveConnection::Rqlite(adapter) => { - adapter.lock().await.test_connection().await - } - ActiveConnection::Turso(adapter) => { - adapter.lock().await.test_connection().await - } + ActiveConnection::Postgres(adapter) => adapter.lock().await.test_connection().await, + ActiveConnection::MySQL(adapter) => adapter.lock().await.test_connection().await, + ActiveConnection::SQLServer(adapter) => adapter.lock().await.test_connection().await, + ActiveConnection::SQLite(adapter) => adapter.lock().await.test_connection().await, + ActiveConnection::ClickHouse(adapter) => adapter.lock().await.test_connection().await, + ActiveConnection::JdbcBridge(adapter) => adapter.lock().await.test_connection().await, + ActiveConnection::HttpSql(adapter) => adapter.lock().await.test_connection().await, + ActiveConnection::Rqlite(adapter) => adapter.lock().await.test_connection().await, + ActiveConnection::Turso(adapter) => adapter.lock().await.test_connection().await, }; result.is_ok() } diff --git a/src-tauri/src/connection/handle.rs b/src-tauri/src/connection/handle.rs index 72ccc240..d69275fd 100644 --- a/src-tauri/src/connection/handle.rs +++ b/src-tauri/src/connection/handle.rs @@ -8,7 +8,10 @@ use crate::database::{ error::DbResult, - types::{ColumnInfo, ConnectionStatus, DatabaseSchema, ForeignKeyInfo, IndexInfo, ObjectInfo, QueryResult, TableInfo, TriggerInfo}, + types::{ + ColumnInfo, ConnectionStatus, DatabaseSchema, ForeignKeyInfo, IndexInfo, ObjectInfo, + QueryResult, TableInfo, TriggerInfo, + }, }; use async_trait::async_trait; @@ -22,19 +25,83 @@ pub trait ConnectionHandle: Send + Sync { async fn test_connection(&self) -> DbResult; async fn list_databases(&self) -> DbResult>; async fn list_schemas(&self, database: Option<&str>) -> DbResult>; - async fn list_tables(&self, database: Option<&str>, schema: Option<&str>) -> DbResult>; - async fn list_columns(&self, database: Option<&str>, schema: Option<&str>, table: &str) -> DbResult>; - async fn get_table_info(&self, database: Option<&str>, schema: Option<&str>, table: &str) -> DbResult; - async fn get_foreign_keys(&self, database: Option<&str>, schema: Option<&str>) -> DbResult>; - async fn list_views(&self, database: Option<&str>, schema: Option<&str>) -> DbResult>; - async fn list_procedures(&self, database: Option<&str>, schema: Option<&str>) -> DbResult>; - async fn list_functions(&self, database: Option<&str>, schema: Option<&str>) -> DbResult>; - async fn list_triggers(&self, database: Option<&str>, schema: Option<&str>, table: &str) -> DbResult>; - async fn list_indexes(&self, database: Option<&str>, schema: Option<&str>, table: &str) -> DbResult>; - async fn list_foreign_keys_for_table(&self, database: Option<&str>, schema: Option<&str>, table: &str) -> DbResult>; - async fn get_object_ddl(&self, database: Option<&str>, schema: Option<&str>, object_name: &str, object_type: &str) -> DbResult; - async fn drop_object(&self, database: Option<&str>, schema: Option<&str>, object_name: &str, object_type: &str) -> DbResult<()>; - async fn rename_object(&self, database: Option<&str>, schema: Option<&str>, object_name: &str, object_type: &str, new_name: &str) -> DbResult<()>; + async fn list_tables( + &self, + database: Option<&str>, + schema: Option<&str>, + ) -> DbResult>; + async fn list_columns( + &self, + database: Option<&str>, + schema: Option<&str>, + table: &str, + ) -> DbResult>; + async fn get_table_info( + &self, + database: Option<&str>, + schema: Option<&str>, + table: &str, + ) -> DbResult; + async fn get_foreign_keys( + &self, + database: Option<&str>, + schema: Option<&str>, + ) -> DbResult>; + async fn list_views( + &self, + database: Option<&str>, + schema: Option<&str>, + ) -> DbResult>; + async fn list_procedures( + &self, + database: Option<&str>, + schema: Option<&str>, + ) -> DbResult>; + async fn list_functions( + &self, + database: Option<&str>, + schema: Option<&str>, + ) -> DbResult>; + async fn list_triggers( + &self, + database: Option<&str>, + schema: Option<&str>, + table: &str, + ) -> DbResult>; + async fn list_indexes( + &self, + database: Option<&str>, + schema: Option<&str>, + table: &str, + ) -> DbResult>; + async fn list_foreign_keys_for_table( + &self, + database: Option<&str>, + schema: Option<&str>, + table: &str, + ) -> DbResult>; + async fn get_object_ddl( + &self, + database: Option<&str>, + schema: Option<&str>, + object_name: &str, + object_type: &str, + ) -> DbResult; + async fn drop_object( + &self, + database: Option<&str>, + schema: Option<&str>, + object_name: &str, + object_type: &str, + ) -> DbResult<()>; + async fn rename_object( + &self, + database: Option<&str>, + schema: Option<&str>, + object_name: &str, + object_type: &str, + new_name: &str, + ) -> DbResult<()>; async fn disconnect(&self) -> DbResult<()>; async fn query_timeout_secs(&self) -> u64; } diff --git a/src-tauri/src/connection/handle_impl.rs b/src-tauri/src/connection/handle_impl.rs index 618476f1..7a61ed8b 100644 --- a/src-tauri/src/connection/handle_impl.rs +++ b/src-tauri/src/connection/handle_impl.rs @@ -7,8 +7,8 @@ use crate::database::{ adapter::DatabaseAdapter, error::DbResult, types::{ - ColumnInfo, ConnectionStatus, DatabaseSchema, ForeignKeyInfo, IndexInfo, ObjectInfo, QueryResult, - TableInfo, TriggerInfo, + ColumnInfo, ConnectionStatus, DatabaseSchema, ForeignKeyInfo, IndexInfo, ObjectInfo, + QueryResult, TableInfo, TriggerInfo, }, }; use crate::state::ActiveConnection; @@ -142,7 +142,14 @@ impl ConnectionHandle for ActiveConnection { object_name: &str, object_type: &str, ) -> DbResult { - delegate!(self, get_object_ddl, database, schema, object_name, object_type) + delegate!( + self, + get_object_ddl, + database, + schema, + object_name, + object_type + ) } async fn drop_object( @@ -152,7 +159,14 @@ impl ConnectionHandle for ActiveConnection { object_name: &str, object_type: &str, ) -> DbResult<()> { - delegate!(self, drop_object, database, schema, object_name, object_type) + delegate!( + self, + drop_object, + database, + schema, + object_name, + object_type + ) } async fn rename_object( @@ -163,7 +177,15 @@ impl ConnectionHandle for ActiveConnection { object_type: &str, new_name: &str, ) -> DbResult<()> { - delegate!(self, rename_object, database, schema, object_name, object_type, new_name) + delegate!( + self, + rename_object, + database, + schema, + object_name, + object_type, + new_name + ) } async fn disconnect(&self) -> DbResult<()> { diff --git a/src-tauri/src/database/clickhouse.rs b/src-tauri/src/database/clickhouse.rs index 73c1052a..48fe747d 100644 --- a/src-tauri/src/database/clickhouse.rs +++ b/src-tauri/src/database/clickhouse.rs @@ -133,7 +133,11 @@ impl ClickHouseAdapter { /// Build the base URL from the configuration. fn build_base_url(&self) -> String { - let scheme = if self.config.ssl_mode == SslMode::Disable { "http" } else { "https" }; + let scheme = if self.config.ssl_mode == SslMode::Disable { + "http" + } else { + "https" + }; format!("{}://{}:{}", scheme, self.config.host, self.config.port) } @@ -145,11 +149,15 @@ impl ClickHouseAdapter { builder = self.apply_ssl_to_builder(builder)?; - builder.build() + builder + .build() .map_err(|e| DbError::Connection(format!("Failed to create HTTP client: {}", e))) } - fn apply_ssl_to_builder(&self, mut builder: reqwest::ClientBuilder) -> DbResult { + fn apply_ssl_to_builder( + &self, + mut builder: reqwest::ClientBuilder, + ) -> DbResult { match self.config.ssl_mode { SslMode::Disable => {} SslMode::Prefer | SslMode::Require => { @@ -157,10 +165,12 @@ impl ClickHouseAdapter { } SslMode::VerifyCA | SslMode::VerifyFull => { if let Some(ref ca_cert) = self.config.ssl_ca_cert { - let pem = std::fs::read(ca_cert) - .map_err(|e| DbError::Connection(format!("Failed to read CA certificate: {}", e)))?; - let cert = reqwest::Certificate::from_pem(&pem) - .map_err(|e| DbError::Connection(format!("Failed to parse CA certificate: {}", e)))?; + let pem = std::fs::read(ca_cert).map_err(|e| { + DbError::Connection(format!("Failed to read CA certificate: {}", e)) + })?; + let cert = reqwest::Certificate::from_pem(&pem).map_err(|e| { + DbError::Connection(format!("Failed to parse CA certificate: {}", e)) + })?; builder = builder.add_root_certificate(cert); } } @@ -168,14 +178,16 @@ impl ClickHouseAdapter { if let (Some(ref cert_path), Some(ref key_path)) = (&self.config.ssl_client_cert, &self.config.ssl_client_key) { - let cert_pem = std::fs::read(cert_path) - .map_err(|e| DbError::Connection(format!("Failed to read client certificate: {}", e)))?; + let cert_pem = std::fs::read(cert_path).map_err(|e| { + DbError::Connection(format!("Failed to read client certificate: {}", e)) + })?; let key_pem = std::fs::read(key_path) .map_err(|e| DbError::Connection(format!("Failed to read client key: {}", e)))?; let mut combined = cert_pem; combined.extend_from_slice(&key_pem); - let identity = reqwest::Identity::from_pem(&combined) - .map_err(|e| DbError::Connection(format!("Failed to parse client identity: {}", e)))?; + let identity = reqwest::Identity::from_pem(&combined).map_err(|e| { + DbError::Connection(format!("Failed to parse client identity: {}", e)) + })?; builder = builder.identity(identity); } Ok(builder) @@ -718,16 +730,18 @@ mod tests { #[test] fn test_build_base_url_https() { - let config = ConnectionConfig::new(DatabaseType::ClickHouse, "ch.example.com", 8123, "default") - .with_ssl_mode(SslMode::Prefer); + let config = + ConnectionConfig::new(DatabaseType::ClickHouse, "ch.example.com", 8123, "default") + .with_ssl_mode(SslMode::Prefer); let adapter = ClickHouseAdapter::new(config); assert_eq!(adapter.build_base_url(), "https://ch.example.com:8123"); } #[test] fn test_build_base_url_https_require() { - let config = ConnectionConfig::new(DatabaseType::ClickHouse, "ch.example.com", 8123, "default") - .with_ssl_mode(SslMode::Require); + let config = + ConnectionConfig::new(DatabaseType::ClickHouse, "ch.example.com", 8123, "default") + .with_ssl_mode(SslMode::Require); let adapter = ClickHouseAdapter::new(config); assert_eq!(adapter.build_base_url(), "https://ch.example.com:8123"); } diff --git a/src-tauri/src/database/config.rs b/src-tauri/src/database/config.rs index 1e4c111f..c0b7c164 100644 --- a/src-tauri/src/database/config.rs +++ b/src-tauri/src/database/config.rs @@ -265,8 +265,12 @@ pub struct OracleConnectionOptions { pub service_level: Option, } -fn default_connect_timeout() -> u64 { 10 } -fn default_query_timeout() -> u64 { 30 } +fn default_connect_timeout() -> u64 { + 10 +} +fn default_query_timeout() -> u64 { + 30 +} impl ConnectionConfig { /// Create a new connection configuration. diff --git a/src-tauri/src/database/http_sql.rs b/src-tauri/src/database/http_sql.rs index 9ace6de9..636cf160 100644 --- a/src-tauri/src/database/http_sql.rs +++ b/src-tauri/src/database/http_sql.rs @@ -79,7 +79,11 @@ impl HttpSqlAdapter { } fn base_url(&self) -> String { - let scheme = if self.config.ssl_mode == SslMode::Disable { "http" } else { "https" }; + let scheme = if self.config.ssl_mode == SslMode::Disable { + "http" + } else { + "https" + }; format!("{}://{}:{}", scheme, self.config.host, self.config.port) } } @@ -89,18 +93,19 @@ impl DatabaseAdapter for HttpSqlAdapter { type Pool = HttpSqlPool; async fn connect(&mut self) -> DbResult<()> { - let mut builder = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(30)); + let mut builder = reqwest::Client::builder().timeout(std::time::Duration::from_secs(30)); builder = match self.config.ssl_mode { SslMode::Disable => builder, SslMode::Prefer | SslMode::Require => builder.danger_accept_invalid_certs(true), SslMode::VerifyCA | SslMode::VerifyFull => { if let Some(ref ca_cert) = self.config.ssl_ca_cert { - let pem = std::fs::read(ca_cert) - .map_err(|e| DbError::Connection(format!("Failed to read CA certificate: {}", e)))?; - let cert = reqwest::Certificate::from_pem(&pem) - .map_err(|e| DbError::Connection(format!("Failed to parse CA certificate: {}", e)))?; + let pem = std::fs::read(ca_cert).map_err(|e| { + DbError::Connection(format!("Failed to read CA certificate: {}", e)) + })?; + let cert = reqwest::Certificate::from_pem(&pem).map_err(|e| { + DbError::Connection(format!("Failed to parse CA certificate: {}", e)) + })?; builder = builder.add_root_certificate(cert); } builder @@ -110,18 +115,21 @@ impl DatabaseAdapter for HttpSqlAdapter { if let (Some(ref cert_path), Some(ref key_path)) = (&self.config.ssl_client_cert, &self.config.ssl_client_key) { - let cert_pem = std::fs::read(cert_path) - .map_err(|e| DbError::Connection(format!("Failed to read client certificate: {}", e)))?; + let cert_pem = std::fs::read(cert_path).map_err(|e| { + DbError::Connection(format!("Failed to read client certificate: {}", e)) + })?; let key_pem = std::fs::read(key_path) .map_err(|e| DbError::Connection(format!("Failed to read client key: {}", e)))?; let mut combined = cert_pem; combined.extend_from_slice(&key_pem); - let identity = reqwest::Identity::from_pem(&combined) - .map_err(|e| DbError::Connection(format!("Failed to parse client identity: {}", e)))?; + let identity = reqwest::Identity::from_pem(&combined).map_err(|e| { + DbError::Connection(format!("Failed to parse client identity: {}", e)) + })?; builder = builder.identity(identity); } - let client = builder.build() + let client = builder + .build() .map_err(|e| DbError::Connection(e.to_string()))?; let resp = client diff --git a/src-tauri/src/database/jdbc_bridge/adapter.rs b/src-tauri/src/database/jdbc_bridge/adapter.rs index c01eab6b..cd968dab 100644 --- a/src-tauri/src/database/jdbc_bridge/adapter.rs +++ b/src-tauri/src/database/jdbc_bridge/adapter.rs @@ -15,9 +15,7 @@ use tokio::sync::Mutex; use super::launcher::JdbcBridgeLauncher; use super::pool::JdbcBridgePool; -use super::protocol::{ - ConnectionStatusData, JdbcMethod, JdbcRequest, QueryResultData, -}; +use super::protocol::{ConnectionStatusData, JdbcMethod, JdbcRequest, QueryResultData}; /// JDBC bridge adapter. /// @@ -531,15 +529,11 @@ mod tests { #[test] fn test_split_comment_statement() { - let stmts = split_sql_statements( - "CREATE TABLE t (id INT);\nCOMMENT ON TABLE t IS 'hello';", - ); + let stmts = + split_sql_statements("CREATE TABLE t (id INT);\nCOMMENT ON TABLE t IS 'hello';"); assert_eq!( stmts, - vec![ - "CREATE TABLE t (id INT)", - "COMMENT ON TABLE t IS 'hello'", - ] + vec!["CREATE TABLE t (id INT)", "COMMENT ON TABLE t IS 'hello'",] ); } diff --git a/src-tauri/src/database/jdbc_bridge/download.rs b/src-tauri/src/database/jdbc_bridge/download.rs index bfabb947..5c9efd62 100644 --- a/src-tauri/src/database/jdbc_bridge/download.rs +++ b/src-tauri/src/database/jdbc_bridge/download.rs @@ -7,11 +7,10 @@ //! (fallback if the Java bridge resolution is unavailable). use crate::database::error::{DbError, DbResult}; -use crate::APP_HANDLE; +use crate::download::DownloadKind; use futures::StreamExt; use std::path::{Path, PathBuf}; use std::process::Command; -use tauri::Emitter; const APP_VERSION: &str = env!("APP_VERSION"); @@ -31,7 +30,10 @@ fn bridge_dir() -> PathBuf { /// Get the path to the current version's bridge JAR (`~/.sqlkit/jdbc-bridge/jdbc-bridge-{ver}.jar`). pub fn bridge_jar_path() -> PathBuf { - bridge_dir().join(format!("{}{}{}", BRIDGE_JAR_PREFIX, APP_VERSION, BRIDGE_JAR_SUFFIX)) + bridge_dir().join(format!( + "{}{}{}", + BRIDGE_JAR_PREFIX, APP_VERSION, BRIDGE_JAR_SUFFIX + )) } /// Check if the current version's bridge JAR is installed. @@ -41,7 +43,13 @@ pub fn is_bridge_installed() -> bool { /// Download a file from URL to a temporary path, then atomically rename to final. /// Emits Tauri progress events if the global APP_HANDLE is set. -pub async fn download_to_path(url: &str, dest: &Path, event_label: &str, expected_size_hint: u64) -> DbResult<()> { +pub async fn download_to_path( + url: &str, + dest: &Path, + id: &str, + kind: DownloadKind, + expected_size_hint: u64, +) -> DbResult<()> { let tmp_path = dest.with_extension("tmp"); let response = reqwest::get(url) .await @@ -54,7 +62,10 @@ pub async fn download_to_path(url: &str, dest: &Path, event_label: &str, expecte ))); } - let total = response.content_length().unwrap_or(expected_size_hint).max(1); + let total = response + .content_length() + .unwrap_or(expected_size_hint) + .max(1); let mut downloaded: u64 = 0; if let Some(parent) = dest.parent() { @@ -70,24 +81,15 @@ pub async fn download_to_path(url: &str, dest: &Path, event_label: &str, expecte let mut stream = response.bytes_stream(); while let Some(chunk) = stream.next().await { - let chunk = chunk.map_err(|e| DbError::Connection(format!("Download stream error: {}", e)))?; + let chunk = + chunk.map_err(|e| DbError::Connection(format!("Download stream error: {}", e)))?; downloaded += chunk.len() as u64; use tokio::io::AsyncWriteExt; file.write_all(&chunk) .await .map_err(|e| DbError::Connection(format!("Failed to write chunk: {}", e)))?; - // Emit progress event - if let Some(handle) = crate::APP_HANDLE.get() { - let _ = handle.emit( - "connection-progress", - serde_json::json!({ - "step": event_label, - "downloaded": downloaded, - "total": total, - }), - ); - } + crate::download::emit_progress(id, kind.clone(), downloaded, total); } tokio::fs::rename(&tmp_path, dest) @@ -121,27 +123,17 @@ pub async fn download_bridge_plugin() -> DbResult<()> { let mut last_err = None::; for attempt in 0..2 { if attempt > 0 { - // Emit retry event - if let Some(handle) = APP_HANDLE.get() { - let _ = handle.emit( - "connection-progress", - serde_json::json!({ - "step": "retry", - "message": format!("Download failed, retrying... ({})", last_err.as_deref().unwrap_or("unknown error")), - "downloaded": 0, - "total": 1, - }), - ); - } + crate::download::emit_progress("bridge", DownloadKind::Bridge, 0, 1); } - if let Err(e) = download_to_path(&url, &jar_path, "bridge_jar", 10_000_000).await { + if let Err(e) = + download_to_path(&url, &jar_path, "bridge", DownloadKind::Bridge, 10_000_000).await + { last_err = Some(e.to_string()); continue; } - let meta = std::fs::metadata(&jar_path).map_err(|e| { - DbError::Connection(format!("Failed to check JAR file size: {}", e)) - })?; + let meta = std::fs::metadata(&jar_path) + .map_err(|e| DbError::Connection(format!("Failed to check JAR file size: {}", e)))?; if meta.len() < 1_000_000 { let _ = std::fs::remove_file(&jar_path); last_err = Some(format!( @@ -207,37 +199,88 @@ pub async fn download_bridge_plugin() -> DbResult<()> { Ok(()) } -/// Download a JDBC driver JAR directly from Maven Central via HTTP. -/// Uses the driver registry config (version_cap, maven coordinates) to -/// construct the download URL. Stores the JAR in the driver cache directory. +/// Download a JDBC driver JAR directly from Maven Central via HTTP (or from +/// a direct URL if `download_url` is set in the registry). +/// +/// Version resolution: +/// - `download_url` → download directly (no version resolution needed) +/// - `version_cap` → pinned version (skip network) +/// - Neither → fetch `maven-metadata.xml` to resolve the latest version +/// /// Does NOT require Java — purely HTTP. pub async fn download_jdbc_driver_direct(db_type: &str) -> DbResult<()> { use super::registry::{resolve_maven_url, DriverRegistry}; let registry = DriverRegistry::load(); - let config = registry - .get_config_by_name(db_type) - .ok_or_else(|| { - DbError::Connection(format!("No driver registry entry for '{}'", db_type)) - })?; - - let version = config.version_cap.as_deref().unwrap_or("latest"); - let classifier = config.maven_classifier.as_deref(); - let url = resolve_maven_url(&config.maven_group, &config.maven_artifact, version, classifier); + let config = registry.get_config_by_name(db_type).ok_or_else(|| { + DbError::Connection(format!("No driver registry entry for '{}'", db_type)) + })?; let dest_dir = super::jre::home_dir() .join(".sqlkit") .join("jdbc-bridge") .join("drivers") .join(&config.maven_artifact); + + // Drivers with a direct download URL (non-Maven, e.g. GBase 8a) + if let Some(download_url) = &config.download_url { + let jar_name = config.maven_artifact.clone() + ".jar"; + let dest = dest_dir.join(&jar_name); + if dest.exists() { + return Ok(()); + } + return download_to_path(download_url, &dest, db_type, DownloadKind::Driver, 5_000_000).await; + } + + // Resolve the version to download + let version: String = if let Some(cap) = &config.version_cap { + cap.clone() + } else { + let group_path = config.maven_group.replace('.', "/"); + let metadata_url = format!( + "https://repo1.maven.org/maven2/{}/{}/maven-metadata.xml", + group_path, config.maven_artifact + ); + let resp = reqwest::get(&metadata_url) + .await + .map_err(|e| DbError::Connection(format!("Failed to fetch Maven metadata: {}", e)))?; + if !resp.status().is_success() { + return Err(DbError::Connection(format!( + "Maven metadata not found (HTTP {}) from {}", + resp.status(), + metadata_url + ))); + } + let xml = resp + .text() + .await + .map_err(|e| DbError::Connection(format!("Failed to read Maven metadata: {}", e)))?; + let start = xml.find("").ok_or_else(|| { + DbError::Connection("No tag found in Maven metadata".to_string()) + })?; + let value_start = start + "".len(); + let end = xml[value_start..].find("").ok_or_else(|| { + DbError::Connection("Malformed tag in Maven metadata".to_string()) + })?; + xml[value_start..value_start + end].to_string() + }; + + let classifier = config.maven_classifier.as_deref(); + let url = resolve_maven_url( + &config.maven_group, + &config.maven_artifact, + &version, + classifier, + ); + let jar_name = format!("{}-{}.jar", config.maven_artifact, version); let dest = dest_dir.join(&jar_name); if dest.exists() { - return Ok(()); // Already cached + return Ok(()); } - download_to_path(&url, &dest, "jdbc_driver", 5_000_000).await + download_to_path(&url, &dest, db_type, DownloadKind::Driver, 5_000_000).await } /// Clean up old bridge JARs and stale version directories. @@ -266,9 +309,9 @@ async fn cleanup_old_bridge_versions() -> DbResult<()> { if path.is_dir() && fname != "drivers" { let old_jar = path.join("jdbc-bridge.jar"); if old_jar.exists() { - tokio::fs::remove_dir_all(&path) - .await - .map_err(|e| DbError::Connection(format!("Failed to remove old bridge dir: {}", e)))?; + tokio::fs::remove_dir_all(&path).await.map_err(|e| { + DbError::Connection(format!("Failed to remove old bridge dir: {}", e)) + })?; } continue; } @@ -277,13 +320,11 @@ async fn cleanup_old_bridge_versions() -> DbResult<()> { if fname.starts_with(BRIDGE_JAR_PREFIX) && fname.ends_with(BRIDGE_JAR_SUFFIX) { let ver = &fname[BRIDGE_JAR_PREFIX.len()..fname.len() - BRIDGE_JAR_SUFFIX.len()]; if ver != APP_VERSION { - tokio::fs::remove_file(&path) - .await - .map_err(|e| DbError::Connection(format!("Failed to remove old bridge JAR: {}", e)))?; + tokio::fs::remove_file(&path).await.map_err(|e| { + DbError::Connection(format!("Failed to remove old bridge JAR: {}", e)) + })?; } } } Ok(()) } - - diff --git a/src-tauri/src/database/jdbc_bridge/fallback.rs b/src-tauri/src/database/jdbc_bridge/fallback.rs index 61ed601d..60c4729d 100644 --- a/src-tauri/src/database/jdbc_bridge/fallback.rs +++ b/src-tauri/src/database/jdbc_bridge/fallback.rs @@ -40,7 +40,12 @@ fn build_oracle_url( if use_service { // Service name format: jdbc:oracle:thin:@//host:port/service_name if let Some(ref service_template) = config.jdbc_url_template_service { - super::registry::build_jdbc_url_from_template(service_template, host, port, database) + super::registry::build_jdbc_url_from_template( + service_template, + host, + port, + database, + ) } else { super::registry::build_jdbc_url(config, host, port, database) } @@ -69,7 +74,10 @@ fn build_jvm_args(oracle_options: Option<&OracleConnectionOptions>) -> Vec(v).ok()) { + match resp + .result + .and_then(|v| serde_json::from_value::(v).ok()) + { Some(result) => result.jar_path, None => { let mut guard = launcher.lock().await; @@ -227,7 +239,11 @@ pub async fn try_driver( } _ => { let stderr = stderr_from_launcher(&launcher); - let detail = if stderr.is_empty() { err.clone() } else { format!("{}. stderr: {}", err, stderr) }; + let detail = if stderr.is_empty() { + err.clone() + } else { + format!("{}. stderr: {}", err, stderr) + }; DriverAttempt::Fatal(DbError::Connection(detail)) } } @@ -241,7 +257,11 @@ pub async fn try_driver( } Ok(Err(e)) => { let stderr = stderr_from_launcher(&launcher); - let msg = if stderr.is_empty() { e.to_string() } else { format!("{}. stderr: {}", e, stderr) }; + let msg = if stderr.is_empty() { + e.to_string() + } else { + format!("{}. stderr: {}", e, stderr) + }; let category = classify_connection_error( db_type_from_config(config), &msg, @@ -331,14 +351,46 @@ pub async fn run_fallback_chain( // Two-phase driver resolution: // 1. Try LATEST (no version_cap) - match try_driver(config, host, port, database, username, password, false, oracle_options, ssl_mode, ssl_ca_cert, ssl_client_cert, ssl_client_key, trust_server_certificate).await { + match try_driver( + config, + host, + port, + database, + username, + password, + false, + oracle_options, + ssl_mode, + ssl_ca_cert, + ssl_client_cert, + ssl_client_key, + trust_server_certificate, + ) + .await + { DriverAttempt::Connected(conn_id, launcher) => { return Ok(("resolved".to_string(), conn_id, launcher)); } DriverAttempt::VersionMismatch(_) => { // 2. If LATEST fails with version_incompatible and a cap exists, retry with cap if config.version_cap.is_some() { - match try_driver(config, host, port, database, username, password, true, oracle_options, ssl_mode, ssl_ca_cert, ssl_client_cert, ssl_client_key, trust_server_certificate).await { + match try_driver( + config, + host, + port, + database, + username, + password, + true, + oracle_options, + ssl_mode, + ssl_ca_cert, + ssl_client_cert, + ssl_client_key, + trust_server_certificate, + ) + .await + { DriverAttempt::Connected(conn_id, launcher) => { return Ok(("capped".to_string(), conn_id, launcher)); } diff --git a/src-tauri/src/database/jdbc_bridge/jre.rs b/src-tauri/src/database/jdbc_bridge/jre.rs index 4f8a6366..8bd6dac6 100644 --- a/src-tauri/src/database/jdbc_bridge/jre.rs +++ b/src-tauri/src/database/jdbc_bridge/jre.rs @@ -6,12 +6,10 @@ //! `release` file, and cleaning it up. use crate::database::error::{DbError, DbResult}; -use crate::APP_HANDLE; use futures::StreamExt; use std::path::Path; use std::path::PathBuf; use std::sync::OnceLock; -use tauri::Emitter; use tokio::sync::Mutex; /// Subdirectory under user home for the managed JRE. @@ -75,18 +73,36 @@ pub fn is_managed_jre_installed() -> bool { /// Determine the Adoptium OS and arch strings for the current platform. fn adoptium_os_arch() -> Option<(&'static str, &'static str)> { - #[cfg(all(target_os = "macos", target_arch = "aarch64"))] { Some(("mac", "aarch64")) } - #[cfg(all(target_os = "macos", target_arch = "x86_64"))] { Some(("mac", "x64")) } - #[cfg(all(target_os = "linux", target_arch = "x86_64"))] { Some(("linux", "x64")) } - #[cfg(all(target_os = "linux", target_arch = "aarch64"))] { Some(("linux", "aarch64")) } - #[cfg(all(target_os = "windows", target_arch = "x86_64"))] { Some(("windows", "x64")) } + #[cfg(all(target_os = "macos", target_arch = "aarch64"))] + { + Some(("mac", "aarch64")) + } + #[cfg(all(target_os = "macos", target_arch = "x86_64"))] + { + Some(("mac", "x64")) + } + #[cfg(all(target_os = "linux", target_arch = "x86_64"))] + { + Some(("linux", "x64")) + } + #[cfg(all(target_os = "linux", target_arch = "aarch64"))] + { + Some(("linux", "aarch64")) + } + #[cfg(all(target_os = "windows", target_arch = "x86_64"))] + { + Some(("windows", "x64")) + } #[cfg(not(any( all(target_os = "macos", target_arch = "aarch64"), all(target_os = "macos", target_arch = "x86_64"), all(target_os = "linux", target_arch = "x86_64"), all(target_os = "linux", target_arch = "aarch64"), all(target_os = "windows", target_arch = "x86_64"), - )))] { None } + )))] + { + None + } } // ── detection ───────────────────────────────────────────── @@ -179,7 +195,10 @@ impl JreDetector { /// there). Returns `None` if the path doesn't exist, isn't a Java binary, /// or the version string can't be parsed. pub fn system_java_version(java: &PathBuf) -> Option { - let output = std::process::Command::new(java).arg("-version").output().ok()?; + let output = std::process::Command::new(java) + .arg("-version") + .output() + .ok()?; let stderr = String::from_utf8_lossy(&output.stderr); let stdout = String::from_utf8_lossy(&output.stdout); let version_str = stderr @@ -265,7 +284,11 @@ pub fn parse_adoptium_build_version(url: &str) -> Option { .chars() .take_while(|c| c.is_ascii_digit() || *c == '.') .collect(); - if version.is_empty() { None } else { Some(version) } + if version.is_empty() { + None + } else { + Some(version) + } } // ── download / remove ───────────────────────────────────── @@ -304,21 +327,18 @@ async fn download_jre_stream( let mut stream = response.bytes_stream(); let mut downloaded: u64 = 0; while let Some(chunk) = stream.next().await { - let chunk = chunk.map_err(|e| DbError::Connection(format!("Download stream error: {}", e)))?; + let chunk = + chunk.map_err(|e| DbError::Connection(format!("Download stream error: {}", e)))?; downloaded += chunk.len() as u64; file.write_all(&chunk) .await .map_err(|e| DbError::Connection(format!("Failed to write chunk: {}", e)))?; - if let Some(handle) = APP_HANDLE.get() { - let _ = handle.emit( - "connection-progress", - serde_json::json!({ - "step": "jre_download", - "downloaded": downloaded, - "total": total, - }), - ); - } + crate::download::emit_progress( + "jre", + crate::download::DownloadKind::Jre, + downloaded, + total, + ); } file.flush().await.ok(); Ok(()) @@ -330,17 +350,14 @@ async fn download_jre_stream( /// extracts the archive, and renames the extracted directory to `jre`. /// Uses atomic operations: download to temp → validate → extract to temp dir → replace. pub async fn download_managed_jre() -> DbResult<()> { - let _guard = JRE_INSTALL_LOCK - .get_or_init(|| Mutex::new(())) - .lock() - .await; + let _guard = JRE_INSTALL_LOCK.get_or_init(|| Mutex::new(())).lock().await; - let (os, arch) = adoptium_os_arch().ok_or_else(|| { - DbError::Connection("No JRE available for this platform".to_string()) - })?; + let (os, arch) = adoptium_os_arch() + .ok_or_else(|| DbError::Connection("No JRE available for this platform".to_string()))?; let base_dir = jre_base_dir(); // ~/.sqlkit/jre - let parent = base_dir.parent() + let parent = base_dir + .parent() .expect("jre_base_dir has a parent") .to_path_buf(); // ~/.sqlkit tokio::fs::create_dir_all(&parent) @@ -373,7 +390,8 @@ pub async fn download_managed_jre() -> DbResult<()> { if meta.len() < 10_000_000 { let _ = tokio::fs::remove_file(&tmp_archive).await; return Err(DbError::Connection(format!( - "JRE archive too small: {} bytes (expected ≥ 10MB)", meta.len() + "JRE archive too small: {} bytes (expected ≥ 10MB)", + meta.len() ))); } // Validate gzip magic bytes (1f 8b) if not a zip file @@ -383,7 +401,8 @@ pub async fn download_managed_jre() -> DbResult<()> { if magic.len() < 2 || magic[0] != 0x1f || magic[1] != 0x8b { let _ = tokio::fs::remove_file(&tmp_archive).await; return Err(DbError::Connection( - "Downloaded JRE archive has invalid gzip magic bytes — corrupt download".to_string() + "Downloaded JRE archive has invalid gzip magic bytes — corrupt download" + .to_string(), )); } } @@ -451,8 +470,8 @@ pub async fn download_managed_jre() -> DbResult<()> { .await .map_err(|e| DbError::Connection(format!("JRE extraction panicked: {}", e)))?; - let extracted_dir = extract_result - .map_err(|e| DbError::Connection(format!("JRE extraction failed: {}", e)))?; + let extracted_dir = + extract_result.map_err(|e| DbError::Connection(format!("JRE extraction failed: {}", e)))?; // Step 4: Atomic swap — rename temp to final, with rollback let _ = tokio::fs::remove_file(&tmp_archive).await; @@ -471,7 +490,9 @@ pub async fn download_managed_jre() -> DbResult<()> { // Rollback: restore backup let _ = tokio::fs::rename(&backup, &base_dir).await; let _ = tokio::fs::remove_dir_all(&extracted_dir).await; - return Err(DbError::Connection("Failed to install JRE — restored previous version".to_string())); + return Err(DbError::Connection( + "Failed to install JRE — restored previous version".to_string(), + )); } } } else { diff --git a/src-tauri/src/database/jdbc_bridge/launcher.rs b/src-tauri/src/database/jdbc_bridge/launcher.rs index ec208d4b..16e8c610 100644 --- a/src-tauri/src/database/jdbc_bridge/launcher.rs +++ b/src-tauri/src/database/jdbc_bridge/launcher.rs @@ -39,9 +39,7 @@ impl JdbcBridgeLauncher { } fn read_stderr_buffer(buf: &Arc>>) -> String { - buf.lock() - .unwrap_or_else(|e| e.into_inner()) - .join("\n") + buf.lock().unwrap_or_else(|e| e.into_inner()).join("\n") } fn drain_stderr(&self) -> String { @@ -89,8 +87,7 @@ impl JdbcBridgeLauncher { for arg in jvm_args { cmd.arg(arg); } - cmd.arg("-jar") - .arg(self.jar_path.to_str().unwrap_or("")); + cmd.arg("-jar").arg(self.jar_path.to_str().unwrap_or("")); let mut child = cmd .stdin(Stdio::piped()) .stdout(Stdio::piped()) @@ -200,7 +197,9 @@ impl JdbcBridgeLauncher { // Check if the process is still alive before trying to communicate if let Ok(Some(status)) = process.try_wait() { - let stderr = self.stderr_buffer.as_ref() + let stderr = self + .stderr_buffer + .as_ref() .map(Self::read_stderr_buffer) .unwrap_or_default(); return if stderr.is_empty() { @@ -230,29 +229,27 @@ impl JdbcBridgeLauncher { .map_err(|e| DbError::Connection(format!("Failed to serialize request: {}", e)))?; writeln!(stdin, "{}", json).map_err(|e| { - let stderr = self.stderr_buffer.as_ref() + let stderr = self + .stderr_buffer + .as_ref() .map(Self::read_stderr_buffer) .unwrap_or_default(); if stderr.is_empty() { DbError::Connection(format!("Failed to write to bridge stdin: {}", e)) } else { - DbError::Connection(format!( - "Bridge write error: {}. stderr: {}", - e, stderr - )) + DbError::Connection(format!("Bridge write error: {}. stderr: {}", e, stderr)) } })?; stdin.flush().map_err(|e| { - let stderr = self.stderr_buffer.as_ref() + let stderr = self + .stderr_buffer + .as_ref() .map(Self::read_stderr_buffer) .unwrap_or_default(); if stderr.is_empty() { DbError::Connection(format!("Failed to flush bridge stdin: {}", e)) } else { - DbError::Connection(format!( - "Bridge write error: {}. stderr: {}", - e, stderr - )) + DbError::Connection(format!("Bridge write error: {}. stderr: {}", e, stderr)) } })?; @@ -265,16 +262,15 @@ impl JdbcBridgeLauncher { loop { line.clear(); reader.read_line(&mut line).map_err(|e| { - let stderr = self.stderr_buffer.as_ref() + let stderr = self + .stderr_buffer + .as_ref() .map(Self::read_stderr_buffer) .unwrap_or_default(); if stderr.is_empty() { DbError::Connection(format!("Failed to read bridge response: {}", e)) } else { - DbError::Connection(format!( - "Bridge read error: {}. stderr: {}", - e, stderr - )) + DbError::Connection(format!("Bridge read error: {}. stderr: {}", e, stderr)) } })?; @@ -286,7 +282,9 @@ impl JdbcBridgeLauncher { std::thread::sleep(std::time::Duration::from_millis(1000)); continue; } - let stderr = self.stderr_buffer.as_ref() + let stderr = self + .stderr_buffer + .as_ref() .map(Self::read_stderr_buffer) .unwrap_or_default(); return if stderr.is_empty() { @@ -322,6 +320,160 @@ impl JdbcBridgeLauncher { Ok(resp) } + /// Send a request and receive a response, emitting progress events during the read phase. + /// + /// Same as [`send_request`] but intercepts intermediate JSON lines containing + /// `"phase":"progress"` emitted by the Java bridge during long operations + /// (e.g. `resolve_driver`). Such lines are parsed and the `downloaded`/`total` + /// fields are forwarded to `progress_cb`. The read loop continues until the + /// actual JSON-RPC response arrives. + pub fn send_request_with_progress( + &mut self, + req: &JdbcRequest, + mut progress_cb: impl FnMut(u64, u64), + ) -> DbResult { + let process = self + .process + .as_mut() + .ok_or_else(|| DbError::Connection("JDBC bridge not started".to_string()))?; + + // Check if the process is still alive before trying to communicate + if let Ok(Some(status)) = process.try_wait() { + let stderr = self + .stderr_buffer + .as_ref() + .map(Self::read_stderr_buffer) + .unwrap_or_default(); + return if stderr.is_empty() { + Err(DbError::Connection(format!( + "JDBC bridge exited before request (code: {}). No stderr output.", + status + ))) + } else { + Err(DbError::Connection(format!( + "JDBC bridge exited before request (code: {}). stderr: {}", + status, stderr + ))) + }; + } + + let stdout = process + .stdout + .as_mut() + .ok_or_else(|| DbError::Connection("JDBC bridge stdout not available".to_string()))?; + + let stdin = self + .stdin + .as_mut() + .ok_or_else(|| DbError::Connection("JDBC bridge stdin not available".to_string()))?; + + let json = serde_json::to_string(req) + .map_err(|e| DbError::Connection(format!("Failed to serialize request: {}", e)))?; + + writeln!(stdin, "{}", json).map_err(|e| { + let stderr = self + .stderr_buffer + .as_ref() + .map(Self::read_stderr_buffer) + .unwrap_or_default(); + if stderr.is_empty() { + DbError::Connection(format!("Failed to write to bridge stdin: {}", e)) + } else { + DbError::Connection(format!("Bridge write error: {}. stderr: {}", e, stderr)) + } + })?; + stdin.flush().map_err(|e| { + let stderr = self + .stderr_buffer + .as_ref() + .map(Self::read_stderr_buffer) + .unwrap_or_default(); + if stderr.is_empty() { + DbError::Connection(format!("Failed to flush bridge stdin: {}", e)) + } else { + DbError::Connection(format!("Bridge write error: {}. stderr: {}", e, stderr)) + } + })?; + + let mut reader = BufReader::new(stdout); + let mut line = String::new(); + let mut read_attempts = 0; + + // Skip any non-JSON lines (e.g. JVM prints version info to stdout). + // Intercept intermediate progress events from the bridge. + // Retry once if first read is empty (JVM may be slow to start). + loop { + line.clear(); + reader.read_line(&mut line).map_err(|e| { + let stderr = self + .stderr_buffer + .as_ref() + .map(Self::read_stderr_buffer) + .unwrap_or_default(); + if stderr.is_empty() { + DbError::Connection(format!("Failed to read bridge response: {}", e)) + } else { + DbError::Connection(format!("Bridge read error: {}. stderr: {}", e, stderr)) + } + })?; + + let trimmed = line.trim(); + if trimmed.is_empty() { + if read_attempts == 0 { + // JVM may be slow to start — wait and retry once + read_attempts += 1; + std::thread::sleep(std::time::Duration::from_millis(1000)); + continue; + } + let stderr = self + .stderr_buffer + .as_ref() + .map(Self::read_stderr_buffer) + .unwrap_or_default(); + return if stderr.is_empty() { + Err(DbError::Connection( + "Empty response from JDBC bridge".to_string(), + )) + } else { + Err(DbError::Connection(format!( + "Bridge read error. stderr: {}", + stderr + ))) + }; + } + // Skip lines that don't start with '{' (non-JSON noise from JVM) + if !trimmed.starts_with('{') { + continue; + } + // Intercept progress events: {"phase":"progress","downloaded":N,"total":M} + if trimmed.contains("\"phase\":\"progress\"") { + if let Ok(val) = serde_json::from_str::(trimmed) { + let downloaded = val.get("downloaded").and_then(|v| v.as_u64()).unwrap_or(0); + let total = val.get("total").and_then(|v| v.as_u64()).unwrap_or(0); + progress_cb(downloaded, total); + } + continue; + } + // This is the actual JSON-RPC response + break; + } + + let resp: JdbcResponse = serde_json::from_str(line.trim()) + .map_err(|e| DbError::Connection(format!("Failed to parse bridge response: {}", e)))?; + + if let Some(ref err) = resp.error { + let error_type = resp.error_type.as_deref().unwrap_or("unknown"); + return Err(match error_type { + "version_incompatible" => DbError::DriverVersionIncompatible(err.clone()), + "authentication_failed" => DbError::Authentication(err.clone()), + "network_error" | "timeout" => DbError::Connection(err.clone()), + _ => DbError::Connection(format!("JDBC bridge error: {}", err)), + }); + } + + Ok(resp) + } + /// Check if the bridge process is still alive. pub fn is_alive(&mut self) -> bool { match self.process.as_mut() { diff --git a/src-tauri/src/database/jdbc_bridge/mod.rs b/src-tauri/src/database/jdbc_bridge/mod.rs index 8ee1521d..4280c8b9 100644 --- a/src-tauri/src/database/jdbc_bridge/mod.rs +++ b/src-tauri/src/database/jdbc_bridge/mod.rs @@ -30,4 +30,6 @@ pub mod tns_parser; pub use adapter::JdbcBridgeAdapter; pub use launcher::JdbcBridgeLauncher; pub use pool::{JdbcBridgeConnection, JdbcBridgePool}; -pub use protocol::{JdbcMethod, JdbcRequest, JdbcResponse, ResolveDriverParams, ResolveDriverResult}; +pub use protocol::{ + JdbcMethod, JdbcRequest, JdbcResponse, ResolveDriverParams, ResolveDriverResult, +}; diff --git a/src-tauri/src/database/jdbc_bridge/registry.rs b/src-tauri/src/database/jdbc_bridge/registry.rs index 1a694051..59fb3e8a 100644 --- a/src-tauri/src/database/jdbc_bridge/registry.rs +++ b/src-tauri/src/database/jdbc_bridge/registry.rs @@ -117,11 +117,14 @@ impl DriverRegistry { /// /// When `classifier` is `Some`, the JAR filename becomes /// `{artifact}-{version}-{classifier}.jar` (e.g. `hive-jdbc-3.1.3-standalone.jar`). -pub fn resolve_maven_url(group: &str, artifact: &str, version: &str, classifier: Option<&str>) -> String { +pub fn resolve_maven_url( + group: &str, + artifact: &str, + version: &str, + classifier: Option<&str>, +) -> String { let group_path = group.replace('.', "/"); - let classifier_suffix = classifier - .map(|c| format!("-{c}")) - .unwrap_or_default(); + let classifier_suffix = classifier.map(|c| format!("-{c}")).unwrap_or_default(); format!( "https://repo1.maven.org/maven2/{group_path}/{artifact}/{version}/{artifact}-{version}{classifier_suffix}.jar" ) @@ -201,6 +204,7 @@ fn db_type_to_registry_key(db: DatabaseType) -> Option<&'static str> { DatabaseType::Access => Some("access"), DatabaseType::YashanDB => Some("yashandb"), DatabaseType::KingbaseES => Some("kingbase"), + DatabaseType::OceanbaseOracle => Some("oceanbase_oracle"), _ => None, } } diff --git a/src-tauri/src/database/jdbc_bridge/tns_parser.rs b/src-tauri/src/database/jdbc_bridge/tns_parser.rs index c7ca4f6b..7534593e 100644 --- a/src-tauri/src/database/jdbc_bridge/tns_parser.rs +++ b/src-tauri/src/database/jdbc_bridge/tns_parser.rs @@ -8,10 +8,11 @@ use std::path::Path; /// Tries common filename variants: `tnsnames.ora`, `TNSNAMES.ORA`. pub fn parse_tns_aliases(tns_admin_dir: &str) -> Vec { let dir = Path::new(tns_admin_dir); - + // Try common filename variants let filenames = ["tnsnames.ora", "TNSNAMES.ORA", "Tnsnames.ora"]; - let content = filenames.iter() + let content = filenames + .iter() .find_map(|name| fs::read_to_string(dir.join(name)).ok()); let content = match content { diff --git a/src-tauri/src/database/mysql.rs b/src-tauri/src/database/mysql.rs index f1be1150..0e58fa7b 100644 --- a/src-tauri/src/database/mysql.rs +++ b/src-tauri/src/database/mysql.rs @@ -14,11 +14,11 @@ use crate::database::{ }, }; use async_trait::async_trait; +use log; use mysql_async::{ - prelude::*, ClientIdentity, Conn, OptsBuilder, Pool, PoolConstraints, PoolOpts, Row, - SslOpts, Value, + prelude::*, ClientIdentity, Conn, OptsBuilder, Pool, PoolConstraints, PoolOpts, Row, SslOpts, + Value, }; -use log; use std::collections::HashMap; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -166,8 +166,7 @@ impl MySQLAdapter { opts = opts.ssl_opts(None); } SslMode::Prefer | SslMode::Require => { - let mut ssl_opts = SslOpts::default() - .with_danger_accept_invalid_certs(true); + let mut ssl_opts = SslOpts::default().with_danger_accept_invalid_certs(true); if let (Some(ref cert), Some(ref key)) = (&self.config.ssl_client_cert, &self.config.ssl_client_key) @@ -181,8 +180,7 @@ impl MySQLAdapter { opts = opts.ssl_opts(Some(ssl_opts)); } SslMode::VerifyCA => { - let mut ssl_opts = SslOpts::default() - .with_danger_skip_domain_validation(true); + let mut ssl_opts = SslOpts::default().with_danger_skip_domain_validation(true); if let Some(ref ca_cert) = self.config.ssl_ca_cert { let ca_path: std::path::PathBuf = ca_cert.into(); @@ -323,7 +321,10 @@ impl MySQLAdapter { /// Check if a mysql_async error is SSL/TLS related. fn is_ssl_error(e: &mysql_async::Error) -> bool { let msg = e.to_string().to_lowercase(); - msg.contains("ssl") || msg.contains("tls") || msg.contains("certificate") || msg.contains("handshake") + msg.contains("ssl") + || msg.contains("tls") + || msg.contains("certificate") + || msg.contains("handshake") } #[async_trait] @@ -337,22 +338,22 @@ impl DatabaseAdapter for MySQLAdapter { let mut conn = match pool.get_conn().await { Ok(conn) => conn, Err(e) if self.config.ssl_mode == SslMode::Prefer && is_ssl_error(&e) => { - log::warn!("SSL handshake failed with Prefer mode, retrying without SSL: {}", e); + log::warn!( + "SSL handshake failed with Prefer mode, retrying without SSL: {}", + e + ); self.config.ssl_mode = SslMode::Disable; let fallback_opts = self.build_connection_opts()?; self.config.ssl_mode = SslMode::Prefer; let fallback_pool = Pool::new(fallback_opts); - let mut conn = fallback_pool - .get_conn() - .await - .map_err(|retry_err| { - DbError::Connection(format!( - "Connection failed even without SSL: {}", - retry_err - )) - })?; + let mut conn = fallback_pool.get_conn().await.map_err(|retry_err| { + DbError::Connection(format!( + "Connection failed even without SSL: {}", + retry_err + )) + })?; conn.query_drop("SELECT 1") .await @@ -361,7 +362,9 @@ impl DatabaseAdapter for MySQLAdapter { drop(conn); self.raw_pool = Some(fallback_pool.clone()); - self.pool = Some(Arc::new(MySQLPool { pool: fallback_pool })); + self.pool = Some(Arc::new(MySQLPool { + pool: fallback_pool, + })); return Ok(()); } diff --git a/src-tauri/src/database/postgres.rs b/src-tauri/src/database/postgres.rs index 1581d161..74ba15cf 100644 --- a/src-tauri/src/database/postgres.rs +++ b/src-tauri/src/database/postgres.rs @@ -281,7 +281,10 @@ impl PostgresAdapter { parts.push(format!("sslmode={}", ssl_mode)); // Add connection timeout - parts.push(format!("connect_timeout={}", self.config.connect_timeout_secs)); + parts.push(format!( + "connect_timeout={}", + self.config.connect_timeout_secs + )); // Add additional options for (key, value) in &self.config.options { @@ -296,11 +299,13 @@ impl PostgresAdapter { skip_verification: bool, verify_hostname: bool, ) -> DbResult { - use rustls::RootCertStore; - use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; + use rustls::client::danger::{ + HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier, + }; use rustls::pki_types::UnixTime; - use rustls::{DigitallySignedStruct, SignatureScheme}; use rustls::server::ParsedCertificate; + use rustls::RootCertStore; + use rustls::{DigitallySignedStruct, SignatureScheme}; /// Certificate verifier that skips all verification (accepts any cert). #[derive(Debug)] @@ -386,8 +391,7 @@ impl PostgresAdapter { cert: &CertificateDer<'_>, dss: &DigitallySignedStruct, ) -> Result { - let spki = ParsedCertificate::try_from(cert)? - .subject_public_key_info(); + let spki = ParsedCertificate::try_from(cert)?.subject_public_key_info(); let provider = rustls::crypto::aws_lc_rs::default_provider(); let supported = &provider.signature_verification_algorithms; @@ -423,8 +427,7 @@ impl PostgresAdapter { cert: &CertificateDer<'_>, dss: &DigitallySignedStruct, ) -> Result { - let spki = ParsedCertificate::try_from(cert)? - .subject_public_key_info(); + let spki = ParsedCertificate::try_from(cert)?.subject_public_key_info(); let provider = rustls::crypto::aws_lc_rs::default_provider(); let supported = &provider.signature_verification_algorithms; @@ -481,22 +484,19 @@ impl PostgresAdapter { rustls_pemfile::certs(&mut cert_data.as_slice()) .collect::, _>>() .map_err(|e| { - DbError::Connection( - format!("Failed to parse client certificate: {}", e), - ) + DbError::Connection(format!( + "Failed to parse client certificate: {}", + e + )) })?; let key_data = std::fs::read(key_path).map_err(|e| { DbError::Connection(format!("Failed to read client key: {}", e)) })?; let key = rustls_pemfile::private_key(&mut key_data.as_slice()) - .map_err(|e| { - DbError::Connection(format!("Failed to parse client key: {}", e)) - })? + .map_err(|e| DbError::Connection(format!("Failed to parse client key: {}", e)))? .ok_or_else(|| { - DbError::Connection( - "No private key found in client key file".to_string(), - ) + DbError::Connection("No private key found in client key file".to_string()) })?; Some((certs, key)) @@ -558,9 +558,7 @@ impl PostgresAdapter { root_store: Arc::new(root_store), })) .with_client_auth_cert(certs, key) - .map_err(|e| { - DbError::Connection(format!("Failed to set client auth: {}", e)) - })? + .map_err(|e| DbError::Connection(format!("Failed to set client auth: {}", e)))? } else { ClientConfig::builder() .dangerous() @@ -2285,9 +2283,8 @@ mod tests { #[test] fn test_pg_compat_without_database_falls_back_to_username() { - let config = - ConnectionConfig::new(DatabaseType::OpenGauss, "10.84.1.213", 5432, "SYSTEM") - .with_password("kingbase@123"); + let config = ConnectionConfig::new(DatabaseType::OpenGauss, "10.84.1.213", 5432, "SYSTEM") + .with_password("kingbase@123"); let adapter = PostgresAdapter::new(config); let conn_str = adapter.build_connection_string(); diff --git a/src-tauri/src/database/rqlite.rs b/src-tauri/src/database/rqlite.rs index 048c42ea..f33c8f58 100644 --- a/src-tauri/src/database/rqlite.rs +++ b/src-tauri/src/database/rqlite.rs @@ -178,7 +178,11 @@ impl RqliteAdapter { /// Build the base URL from the configuration. fn build_base_url(&self) -> String { - let scheme = if self.config.ssl_mode == SslMode::Disable { "http" } else { "https" }; + let scheme = if self.config.ssl_mode == SslMode::Disable { + "http" + } else { + "https" + }; format!("{}://{}:{}", scheme, self.config.host, self.config.port) } @@ -190,11 +194,15 @@ impl RqliteAdapter { builder = self.apply_ssl_to_builder(builder)?; - builder.build() + builder + .build() .map_err(|e| DbError::Connection(format!("Failed to create HTTP client: {}", e))) } - fn apply_ssl_to_builder(&self, mut builder: reqwest::ClientBuilder) -> DbResult { + fn apply_ssl_to_builder( + &self, + mut builder: reqwest::ClientBuilder, + ) -> DbResult { match self.config.ssl_mode { SslMode::Disable => {} SslMode::Prefer | SslMode::Require => { @@ -202,10 +210,12 @@ impl RqliteAdapter { } SslMode::VerifyCA | SslMode::VerifyFull => { if let Some(ref ca_cert) = self.config.ssl_ca_cert { - let pem = std::fs::read(ca_cert) - .map_err(|e| DbError::Connection(format!("Failed to read CA certificate: {}", e)))?; - let cert = reqwest::Certificate::from_pem(&pem) - .map_err(|e| DbError::Connection(format!("Failed to parse CA certificate: {}", e)))?; + let pem = std::fs::read(ca_cert).map_err(|e| { + DbError::Connection(format!("Failed to read CA certificate: {}", e)) + })?; + let cert = reqwest::Certificate::from_pem(&pem).map_err(|e| { + DbError::Connection(format!("Failed to parse CA certificate: {}", e)) + })?; builder = builder.add_root_certificate(cert); } } @@ -213,14 +223,16 @@ impl RqliteAdapter { if let (Some(ref cert_path), Some(ref key_path)) = (&self.config.ssl_client_cert, &self.config.ssl_client_key) { - let cert_pem = std::fs::read(cert_path) - .map_err(|e| DbError::Connection(format!("Failed to read client certificate: {}", e)))?; + let cert_pem = std::fs::read(cert_path).map_err(|e| { + DbError::Connection(format!("Failed to read client certificate: {}", e)) + })?; let key_pem = std::fs::read(key_path) .map_err(|e| DbError::Connection(format!("Failed to read client key: {}", e)))?; let mut combined = cert_pem; combined.extend_from_slice(&key_pem); - let identity = reqwest::Identity::from_pem(&combined) - .map_err(|e| DbError::Connection(format!("Failed to parse client identity: {}", e)))?; + let identity = reqwest::Identity::from_pem(&combined).map_err(|e| { + DbError::Connection(format!("Failed to parse client identity: {}", e)) + })?; builder = builder.identity(identity); } Ok(builder) @@ -754,16 +766,18 @@ mod tests { #[test] fn test_build_base_url_https() { - let config = ConnectionConfig::new(DatabaseType::RQLite, "rqlite.example.com", 4001, "default") - .with_ssl_mode(SslMode::Prefer); + let config = + ConnectionConfig::new(DatabaseType::RQLite, "rqlite.example.com", 4001, "default") + .with_ssl_mode(SslMode::Prefer); let adapter = RqliteAdapter::new(config); assert_eq!(adapter.build_base_url(), "https://rqlite.example.com:4001"); } #[test] fn test_build_base_url_https_require() { - let config = ConnectionConfig::new(DatabaseType::RQLite, "rqlite.example.com", 4001, "default") - .with_ssl_mode(SslMode::Require); + let config = + ConnectionConfig::new(DatabaseType::RQLite, "rqlite.example.com", 4001, "default") + .with_ssl_mode(SslMode::Require); let adapter = RqliteAdapter::new(config); assert_eq!(adapter.build_base_url(), "https://rqlite.example.com:4001"); } diff --git a/src-tauri/src/database/sqlite.rs b/src-tauri/src/database/sqlite.rs index cac8c770..a4668e87 100644 --- a/src-tauri/src/database/sqlite.rs +++ b/src-tauri/src/database/sqlite.rs @@ -358,9 +358,7 @@ impl SQLiteAdapter { let columns: Vec = (0..column_count) .map(|i| stmt.column_name(i).unwrap_or("unknown").to_string()) .collect(); - let column_types: Vec = (0..column_count) - .map(|_| String::new()) - .collect(); + let column_types: Vec = (0..column_count).map(|_| String::new()).collect(); let rows_iter = stmt .query_map([], |row| Ok(Self::row_to_query_row(row))) diff --git a/src-tauri/src/database/strategy.rs b/src-tauri/src/database/strategy.rs index 90582a96..ea7c28f3 100644 --- a/src-tauri/src/database/strategy.rs +++ b/src-tauri/src/database/strategy.rs @@ -43,9 +43,8 @@ pub fn resolve_effective_type(db: DatabaseType) -> ConnectionStrategy { // Native PG adapter PostgreSQL => ConnectionStrategy::Native(CoreDatabaseType::PostgreSQL), // PG wire protocol compat - CockroachDB | Redshift | YugabyteDB | TimescaleDB | GaussDB | HighGo - | UXDB | OpenGauss | GBase8c | QuestDB | Vastbase - | Greenplum | EnterpriseDB | CrateDB | Materialize + CockroachDB | Redshift | YugabyteDB | TimescaleDB | GaussDB | HighGo | UXDB | OpenGauss + | GBase8c | QuestDB | Vastbase | Greenplum | EnterpriseDB | CrateDB | Materialize | AlloyDB | CloudSQLPG | FujitsuPG => { ConnectionStrategy::Native(CoreDatabaseType::PostgreSQL) } @@ -53,9 +52,8 @@ pub fn resolve_effective_type(db: DatabaseType) -> ConnectionStrategy { // Native MySQL adapter MySQL => ConnectionStrategy::Native(CoreDatabaseType::MySQL), // MySQL wire protocol compat - MariaDB | TiDB | OceanBase | TDSQL | PolarDB | Doris | SelectDB | StarRocks - | Databend | GoldenDB | ManticoreSearch - | SingleStoreMemSQL | CloudSQLMySQL => { + MariaDB | TiDB | OceanBase | TDSQL | PolarDB | Doris | SelectDB | StarRocks | Databend + | GoldenDB | ManticoreSearch | SingleStoreMemSQL | CloudSQLMySQL => { ConnectionStrategy::Native(CoreDatabaseType::MySQL) } @@ -120,16 +118,15 @@ pub fn is_pg_family(db: DatabaseType) -> bool { pub fn default_port(db: DatabaseType) -> Option { use DatabaseType::*; match db { - PostgreSQL | CockroachDB | Redshift | YugabyteDB | TimescaleDB | GaussDB - | HighGo | UXDB | OpenGauss | GBase8c | Vastbase - | Greenplum | EnterpriseDB | CrateDB | Materialize - | AlloyDB | CloudSQLPG | FujitsuPG => Some(5432), + PostgreSQL | CockroachDB | Redshift | YugabyteDB | TimescaleDB | GaussDB | HighGo + | UXDB | OpenGauss | GBase8c | Vastbase | Greenplum | EnterpriseDB | CrateDB + | Materialize | AlloyDB | CloudSQLPG | FujitsuPG => Some(5432), QuestDB => Some(8812), YashanDB => Some(1688), KingbaseES => Some(54321), OceanbaseOracle => Some(2881), - MySQL | MariaDB | TiDB | OceanBase | TDSQL | PolarDB | GoldenDB - | SingleStoreMemSQL | CloudSQLMySQL => Some(3306), + MySQL | MariaDB | TiDB | OceanBase | TDSQL | PolarDB | GoldenDB | SingleStoreMemSQL + | CloudSQLMySQL => Some(3306), Doris | SelectDB | StarRocks => Some(9030), Databend => Some(3307), ManticoreSearch => Some(9306), diff --git a/src-tauri/src/database/turso.rs b/src-tauri/src/database/turso.rs index 8794d79f..74a0801b 100644 --- a/src-tauri/src/database/turso.rs +++ b/src-tauri/src/database/turso.rs @@ -183,28 +183,33 @@ impl TursoAdapter { .user_agent("sqlkit-turso-adapter/0.1"); if let Some(ref ca_cert) = self.config.ssl_ca_cert { - let pem = std::fs::read(ca_cert) - .map_err(|e| DbError::Connection(format!("Failed to read CA certificate: {}", e)))?; - let cert = reqwest::Certificate::from_pem(&pem) - .map_err(|e| DbError::Connection(format!("Failed to parse CA certificate: {}", e)))?; + let pem = std::fs::read(ca_cert).map_err(|e| { + DbError::Connection(format!("Failed to read CA certificate: {}", e)) + })?; + let cert = reqwest::Certificate::from_pem(&pem).map_err(|e| { + DbError::Connection(format!("Failed to parse CA certificate: {}", e)) + })?; builder = builder.add_root_certificate(cert); } if let (Some(ref cert_path), Some(ref key_path)) = (&self.config.ssl_client_cert, &self.config.ssl_client_key) { - let cert_pem = std::fs::read(cert_path) - .map_err(|e| DbError::Connection(format!("Failed to read client certificate: {}", e)))?; + let cert_pem = std::fs::read(cert_path).map_err(|e| { + DbError::Connection(format!("Failed to read client certificate: {}", e)) + })?; let key_pem = std::fs::read(key_path) .map_err(|e| DbError::Connection(format!("Failed to read client key: {}", e)))?; let mut combined = cert_pem; combined.extend_from_slice(&key_pem); - let identity = reqwest::Identity::from_pem(&combined) - .map_err(|e| DbError::Connection(format!("Failed to parse client identity: {}", e)))?; + let identity = reqwest::Identity::from_pem(&combined).map_err(|e| { + DbError::Connection(format!("Failed to parse client identity: {}", e)) + })?; builder = builder.identity(identity); } - builder.build() + builder + .build() .map_err(|e| DbError::Connection(format!("Failed to create HTTP client: {}", e))) } diff --git a/src-tauri/src/download/mod.rs b/src-tauri/src/download/mod.rs new file mode 100644 index 00000000..e7da8dc9 --- /dev/null +++ b/src-tauri/src/download/mod.rs @@ -0,0 +1,99 @@ +//! Unified download event protocol. +//! +//! Provides typed download progress events for JRE, Bridge JAR, and JDBC driver +//! downloads. Helpers emit Tauri events via the global [`crate::APP_HANDLE`]. +//! +//! # Event shape +//! +//! All events use the `"download-progress"` event name with a JSON payload +//! tagged by `"phase"`: +//! +//! - `{"phase":"progress","id":"...","kind":"jre","downloaded":N,"total":M}` +//! - `{"phase":"complete","id":"...","kind":"bridge"}` +//! - `{"phase":"error","id":"...","kind":"driver","error":"..."}` + +use serde::Serialize; +use tauri::Emitter; + +/// Named Tauri event constant for all download progress events. +pub const DOWNLOAD_EVENT: &str = "download-progress"; + +/// The kind of download operation. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum DownloadKind { + /// Managed JRE download (Eclipse Temurin). + Jre, + /// JDBC bridge fat JAR download. + Bridge, + /// JDBC driver JAR download from Maven Central. + Driver, +} + +/// A typed download event tagged by phase. +/// +/// Serialized as a flat JSON object with a `"phase"` discriminator field. +#[derive(Debug, Clone, Serialize)] +#[serde(tag = "phase", rename_all = "snake_case")] +pub enum DownloadEvent { + /// Download is in progress with byte-level progress. + Progress { + id: String, + kind: DownloadKind, + downloaded: u64, + total: u64, + }, + /// Download completed successfully. + Complete { id: String, kind: DownloadKind }, + /// Download failed with an error message. + Error { + id: String, + kind: DownloadKind, + error: String, + }, +} + +/// Emit a [`DownloadEvent::Progress`] via the global app handle. +/// +/// Returns immediately and silently if [`crate::APP_HANDLE`] has not been set yet. +pub fn emit_progress(id: &str, kind: DownloadKind, downloaded: u64, total: u64) { + let event = DownloadEvent::Progress { + id: id.to_string(), + kind, + downloaded, + total, + }; + if let Some(handle) = crate::APP_HANDLE.get() { + let _ = handle.emit(DOWNLOAD_EVENT, &event); + } +} + +/// Emit a [`DownloadEvent::Complete`] via the global app handle. +/// +/// Returns immediately and silently if [`crate::APP_HANDLE`] has not been set yet. +pub fn emit_complete(id: &str, kind: DownloadKind) { + let event = DownloadEvent::Complete { + id: id.to_string(), + kind, + }; + if let Some(handle) = crate::APP_HANDLE.get() { + let _ = handle.emit(DOWNLOAD_EVENT, &event); + } +} + +/// Emit a [`DownloadEvent::Error`] via the global app handle. +/// +/// Accepts any type that implements `Into` for the error message +/// (e.g., `&str`, `String`, `Box`). +/// +/// Returns immediately and silently if [`crate::APP_HANDLE`] has not been set yet. +pub fn emit_error(id: &str, kind: DownloadKind, error: impl Into) { + let event = DownloadEvent::Error { + id: id.to_string(), + kind, + error: error.into(), + }; + if let Some(handle) = crate::APP_HANDLE.get() { + let _ = handle.emit(DOWNLOAD_EVENT, &event); + } +} diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 0f67eeb7..e0d2c5d8 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -1,5 +1,6 @@ pub mod api_response; pub mod database; +pub mod download; pub mod ssh; pub mod transfer; @@ -32,7 +33,8 @@ pub static APP_HANDLE: OnceLock = OnceLock::new(); /// Global ConnectionGuardian, set once during app setup. Used by query execution /// for health checks and by capability handlers for connection quality warnings. -pub static GUARDIAN: OnceLock> = OnceLock::new(); +pub static GUARDIAN: OnceLock> = + OnceLock::new(); use agent::executor::SqlKitToolExecutor; use agent::query_history::{add_query_history_entry, load_query_history}; @@ -50,7 +52,7 @@ use agent_adapters::{ }; use capabilities::commands::{get_available_tools, invoke_capability}; use data_studio_agent as lib; -use data_studio_agent::storage as storage; +use data_studio_agent::storage; #[derive(Clone, serde::Serialize)] struct AuthPayload { @@ -220,6 +222,7 @@ pub fn run() { commands::download_jre, commands::remove_jre, commands::check_jre_update, + commands::check_driver_update, commands::check_bridge_status, commands::download_bridge_jar, commands::remove_bridge_jar, diff --git a/src-tauri/src/ssh/transport.rs b/src-tauri/src/ssh/transport.rs index 46179b33..1cb89ef7 100644 --- a/src-tauri/src/ssh/transport.rs +++ b/src-tauri/src/ssh/transport.rs @@ -36,10 +36,7 @@ pub async fn start_transport_layers( } } -pub async fn stop_transport_layers( - connection_id: &str, - tunnels: &TunnelManager, -) { +pub async fn stop_transport_layers(connection_id: &str, tunnels: &TunnelManager) { tunnels.stop_tunnel(connection_id).await; } @@ -74,7 +71,8 @@ mod tests { config.enabled = false; let layers = vec![TransportLayerConfig::Ssh(config)]; let tunnels = TunnelManager::new(); - let result = start_transport_layers("test", &layers, "db.example.com", 5432, &tunnels).await; + let result = + start_transport_layers("test", &layers, "db.example.com", 5432, &tunnels).await; assert_eq!(result.unwrap(), None); } @@ -85,7 +83,8 @@ mod tests { TransportLayerConfig::Ssh(ssh_config()), ]; let tunnels = TunnelManager::new(); - let result = start_transport_layers("test", &layers, "db.example.com", 5432, &tunnels).await; + let result = + start_transport_layers("test", &layers, "db.example.com", 5432, &tunnels).await; assert!(result.is_err()); assert!(result.unwrap_err().contains("Multi-hop")); } diff --git a/src-tauri/src/ssh/tunnel.rs b/src-tauri/src/ssh/tunnel.rs index a572ae06..16d46918 100644 --- a/src-tauri/src/ssh/tunnel.rs +++ b/src-tauri/src/ssh/tunnel.rs @@ -66,7 +66,9 @@ fn ssh_client_config() -> client::Config { use russh::keys::agent::AgentIdentity; async fn authenticate_with_agent_inner( - mut agent: russh::keys::agent::client::AgentClient, + mut agent: russh::keys::agent::client::AgentClient< + impl tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, + >, session: &mut Handle, username: &str, timeout: &Duration, @@ -80,16 +82,30 @@ async fn authenticate_with_agent_inner( return Err("SSH agent has no identities".to_string()); } - let hash_alg = session.best_supported_rsa_hash().await.ok().flatten().flatten(); + let hash_alg = session + .best_supported_rsa_hash() + .await + .ok() + .flatten() + .flatten(); let auth_result = tokio::time::timeout(*timeout, async { for identity in &identities { let result = match identity { AgentIdentity::PublicKey { key, .. } => { - session.authenticate_publickey_with(username, key.clone(), hash_alg, &mut agent).await + session + .authenticate_publickey_with(username, key.clone(), hash_alg, &mut agent) + .await } AgentIdentity::Certificate { certificate, .. } => { - session.authenticate_certificate_with(username, certificate.clone(), hash_alg, &mut agent).await + session + .authenticate_certificate_with( + username, + certificate.clone(), + hash_alg, + &mut agent, + ) + .await } }; @@ -137,8 +153,12 @@ async fn authenticate_with_agent( authenticate_with_agent_inner(agent, session, username, timeout).await } -fn load_ssh_private_key(path: &str, passphrase: Option<&str>) -> Result { - let secret = std::fs::read_to_string(path).map_err(|e| format!("Cannot read SSH key file: {}", e))?; +fn load_ssh_private_key( + path: &str, + passphrase: Option<&str>, +) -> Result { + let secret = + std::fs::read_to_string(path).map_err(|e| format!("Cannot read SSH key file: {}", e))?; match russh::keys::decode_secret_key(&secret, passphrase) { Ok(key) => Ok(key), @@ -408,24 +428,42 @@ async fn tunnel_reconnect_loop( match client::connect( Arc::new(ssh_client_config()), (&*connect_host, connect_port), - SshClient { verify_host_key: current_config.verify_host_key }, + SshClient { + verify_host_key: current_config.verify_host_key, + }, ) .await { Ok(mut raw_session) => { - match authenticate_session(&mut raw_session, ¤t_config, connect_timeout_secs).await { + match authenticate_session(&mut raw_session, ¤t_config, connect_timeout_secs) + .await + { Ok(()) => { let ka = Duration::from_secs(current_config.keepalive_interval_secs); forward_loop(&raw_session, &listener, &remote_host, remote_port, ka).await; - log::warn!("SSH tunnel lost ({}:{}), reconnecting...", connect_host, connect_port); + log::warn!( + "SSH tunnel lost ({}:{}), reconnecting...", + connect_host, + connect_port + ); } Err(e) => { - log::error!("SSH tunnel auth failed ({}:{}): {}", connect_host, connect_port, e); + log::error!( + "SSH tunnel auth failed ({}:{}): {}", + connect_host, + connect_port, + e + ); } } } Err(e) => { - log::error!("SSH tunnel connect failed ({}:{}): {}", connect_host, connect_port, e); + log::error!( + "SSH tunnel connect failed ({}:{}): {}", + connect_host, + connect_port, + e + ); } } @@ -448,12 +486,20 @@ async fn tunnel_reconnect_loop( match client::connect( Arc::new(ssh_client_config()), (&*connect_host, connect_port), - SshClient { verify_host_key: current_config.verify_host_key }, + SshClient { + verify_host_key: current_config.verify_host_key, + }, ) .await { Ok(mut raw_session) => { - match authenticate_session(&mut raw_session, ¤t_config, connect_timeout_secs).await { + match authenticate_session( + &mut raw_session, + ¤t_config, + connect_timeout_secs, + ) + .await + { Ok(()) => { current_config = initial_config.clone(); log::info!( @@ -463,7 +509,8 @@ async fn tunnel_reconnect_loop( attempts + 1 ); let ka = Duration::from_secs(current_config.keepalive_interval_secs); - forward_loop(&raw_session, &listener, &remote_host, remote_port, ka).await; + forward_loop(&raw_session, &listener, &remote_host, remote_port, ka) + .await; break; } Err(e) => { @@ -504,10 +551,13 @@ async fn authenticate_session( match &config.auth_method { SshAuthMethod::Password { password } => { - let auth_res = tokio::time::timeout(timeout, session.authenticate_password(&config.username, password)) - .await - .map_err(|_| format!("Auth timed out ({}s)", connect_timeout_secs))? - .map_err(|e| format!("Auth failed: {}", e))?; + let auth_res = tokio::time::timeout( + timeout, + session.authenticate_password(&config.username, password), + ) + .await + .map_err(|_| format!("Auth timed out ({}s)", connect_timeout_secs))? + .map_err(|e| format!("Auth failed: {}", e))?; if !auth_res.success() { return Err("Password authentication failed".to_string()); } @@ -518,7 +568,12 @@ async fn authenticate_session( } => { let key_pair = load_ssh_private_key(private_key_path, passphrase.as_deref()) .map_err(|e| format!("Failed to load key: {}", e))?; - let hash_alg = session.best_supported_rsa_hash().await.ok().flatten().flatten(); + let hash_alg = session + .best_supported_rsa_hash() + .await + .ok() + .flatten() + .flatten(); let auth_res = tokio::time::timeout( timeout, session.authenticate_publickey( @@ -587,10 +642,7 @@ impl TunnelManager { tunnels.insert( connection_id.to_string(), - TunnelEntry { - handle, - local_port, - }, + TunnelEntry { handle, local_port }, ); Ok(local_port) } @@ -614,10 +666,7 @@ impl TunnelManager { } } -fn get_active_port( - tunnels: &mut HashMap, - connection_id: &str, -) -> Option { +fn get_active_port(tunnels: &mut HashMap, connection_id: &str) -> Option { let entry = tunnels.get(connection_id)?; if entry.handle.is_finished() { tunnels.remove(connection_id); @@ -649,7 +698,13 @@ async fn spawn_tunnel_task( let timeout_dur = Duration::from_secs(timeout); let mut init_session = tokio::time::timeout( timeout_dur, - client::connect(ssh_config_init, (&*config.host, config.port), SshClient { verify_host_key: config.verify_host_key }), + client::connect( + ssh_config_init, + (&*config.host, config.port), + SshClient { + verify_host_key: config.verify_host_key, + }, + ), ) .await .map_err(|_| format!("SSH connection timed out ({}s)", timeout))? @@ -701,7 +756,11 @@ mod tests { ); let result = sanitize_openssh_key_comment(&pem); - assert!(result.is_ok(), "sanitize should succeed: {:?}", result.err()); + assert!( + result.is_ok(), + "sanitize should succeed: {:?}", + result.err() + ); let sanitized = result.unwrap(); assert!(sanitized.starts_with("-----BEGIN OPENSSH PRIVATE KEY-----")); assert!(sanitized.ends_with("-----\n")); @@ -709,7 +768,8 @@ mod tests { #[test] fn test_sanitize_rejects_non_openssh() { - let pkcs1 = "-----BEGIN RSA PRIVATE KEY-----\nMIIEpAIBAAKCAQEA\n-----END RSA PRIVATE KEY-----"; + let pkcs1 = + "-----BEGIN RSA PRIVATE KEY-----\nMIIEpAIBAAKCAQEA\n-----END RSA PRIVATE KEY-----"; let result = sanitize_openssh_key_comment(pkcs1); assert!(result.is_err()); assert!(result.unwrap_err().contains("not an OpenSSH format")); @@ -727,7 +787,10 @@ mod tests { assert_eq!(find_padding_len(&data), Ok(8)); let data2 = vec![1, 2, 4, 8]; - assert_eq!(find_padding_len(&data2), Err("Invalid private key padding".to_string())); + assert_eq!( + find_padding_len(&data2), + Err("Invalid private key padding".to_string()) + ); } #[test] diff --git a/src-tauri/src/state.rs b/src-tauri/src/state.rs index a9776270..0b2af6fe 100644 --- a/src-tauri/src/state.rs +++ b/src-tauri/src/state.rs @@ -16,9 +16,9 @@ use crate::database::rqlite::RqliteAdapter; use crate::database::turso::TursoAdapter; /// Core adapter types used in dispatch logic. use crate::database::{ - clickhouse::ClickHouseAdapter, http_sql::HttpSqlAdapter, - jdbc_bridge::JdbcBridgeAdapter, mysql::MySQLAdapter, postgres::PostgresAdapter, - sqlite::SQLiteAdapter, sqlserver::SqlServerAdapter, + clickhouse::ClickHouseAdapter, http_sql::HttpSqlAdapter, jdbc_bridge::JdbcBridgeAdapter, + mysql::MySQLAdapter, postgres::PostgresAdapter, sqlite::SQLiteAdapter, + sqlserver::SqlServerAdapter, }; /// Server configuration with connection details. @@ -85,8 +85,12 @@ pub struct ServerConfig { pub transport_layers: Option>, } -fn default_timeout_10() -> u64 { 10 } -fn default_timeout_30() -> u64 { 30 } +fn default_timeout_10() -> u64 { + 10 +} +fn default_timeout_30() -> u64 { + 30 +} impl ServerConfig { /// Create a new server configuration. diff --git a/src/components/connections/ServerFormDialog.vue b/src/components/connections/ServerFormDialog.vue index 064d144b..f7011a96 100644 --- a/src/components/connections/ServerFormDialog.vue +++ b/src/components/connections/ServerFormDialog.vue @@ -1,9 +1,8 @@ @@ -278,31 +285,32 @@ async function handleCheckUpdates() { diff --git a/src/views/setting/jre-driver-section.vue b/src/views/setting/jre-driver-section.vue index 3fda2732..f578fc3f 100644 --- a/src/views/setting/jre-driver-section.vue +++ b/src/views/setting/jre-driver-section.vue @@ -7,12 +7,16 @@ import { Button } from '@/components/ui/button' import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card' import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '@/components/ui/tooltip' import { useDatabaseIcon } from '@/composables/useDatabaseIcon' +import { useDownloadEvents } from '@/composables/useDownloadEvents' +import { toast } from '@/composables/useNotifications' import { jdbcApi } from '@/datasources' import { dbTypeFromBackend } from '@/store' const { t } = useI18n() const { getDatabaseIcon } = useDatabaseIcon() +const dl = useDownloadEvents() + function getDriverIcon(dbType: string): string { const dbTypeEnum = dbTypeFromBackend[dbType] return dbTypeEnum ? getDatabaseIcon(dbTypeEnum) : '' @@ -23,27 +27,29 @@ const jreStatus = ref(null) const jreUpdate = ref(null) const jreLoading = ref(false) +const jreChecking = ref(false) + +const jreDownloading = computed(() => dl.isDownloading('jre')) +const jreDownloadProgress = computed(() => dl.getProgress('jre')) -const jreWarning = computed(() => { +const systemJreValid = computed(() => { const status = jreStatus.value - if (!status || status.source !== 'system') - return '' - const ver = status.version - if (!ver || ver === 'system') - return t('pages.settings.jre.jreCard.status.systemWarning') - const major = Number.parseInt(ver, 10) - if (!Number.isNaN(major) && major < 25) - return t('pages.settings.jre.jreCard.status.systemWarning') - return '' + if (!status || status.source !== 'system' || !status.version) + return false + const major = Number.parseInt(status.version, 10) + return !Number.isNaN(major) && major >= 25 }) const bridgeStatus = ref(null) const bridgeLoading = ref(false) +const bridgeDownloading = computed(() => dl.isDownloading('bridge')) +const bridgeDownloadProgress = computed(() => dl.getProgress('bridge')) + const drivers = ref([]) const driversLoading = ref(false) -const downloadingDriver = ref(null) const removingDriver = ref(null) +const checkingDriver = ref(null) // --- Computed sort: installed first, then alphabetical by name --- const sortedDrivers = ref([]) @@ -93,24 +99,47 @@ async function loadJreStatus() { } async function handleCheckJreUpdates() { + jreChecking.value = true jreLoading.value = true try { - jreUpdate.value = await jdbcApi.checkJreUpdate() + const update = await jdbcApi.checkJreUpdate() + jreUpdate.value = update + + if (update.update_available) { + toast.info(t('pages.settings.jre.jreCard.notifications.updateAvailable'), { + description: t('pages.settings.jre.jreCard.notifications.newVersion', { version: update.latest_version ?? '' }), + }) + // Auto-start download with progress (like uninstall → install flow) + await handleDownloadJre() + } + else { + toast.success(t('pages.settings.jre.jreCard.notifications.upToDate')) + } + } + catch (error) { + toast.error(t('pages.settings.jre.jreCard.notifications.checkUpdatesFailed'), { + description: error instanceof Error ? error.message : String(error), + }) } finally { + jreChecking.value = false jreLoading.value = false } } async function handleDownloadJre() { jreLoading.value = true - try { - await jdbcApi.downloadJre() - await loadJreStatus() + dl.reset('jre') + const ok = await dl.startDownload('jre', 'jre', () => jdbcApi.downloadJre()) + if (!ok) { + toast.error(t('pages.settings.jre.jreCard.notifications.downloadFailed'), { description: dl.getError('jre') ?? '' }) } - finally { - jreLoading.value = false + else { + toast.success(t('pages.settings.jre.jreCard.notifications.downloadSuccess')) + await loadJreStatus() } + jreLoading.value = false + dl.reset('jre') } async function handleRemoveJre() { @@ -121,6 +150,11 @@ async function handleRemoveJre() { jreUpdate.value = null await loadJreStatus() } + catch (error) { + toast.error(t('pages.settings.jre.jreCard.notifications.removeFailed'), { + description: error instanceof Error ? error.message : String(error), + }) + } finally { jreLoading.value = false } @@ -142,13 +176,17 @@ async function loadBridgeStatus() { async function handleDownloadBridge() { bridgeLoading.value = true - try { - await jdbcApi.downloadBridgeJar() - await loadBridgeStatus() + dl.reset('bridge') + const ok = await dl.startDownload('bridge', 'bridge', () => jdbcApi.downloadBridgeJar()) + if (!ok) { + toast.error(t('pages.settings.jre.bridgeCard.notifications.downloadFailed'), { description: dl.getError('bridge') ?? '' }) } - finally { - bridgeLoading.value = false + else { + toast.success(t('pages.settings.jre.bridgeCard.notifications.downloadSuccess')) + await loadBridgeStatus() } + bridgeLoading.value = false + dl.reset('bridge') } async function handleRemoveBridge() { @@ -158,6 +196,11 @@ async function handleRemoveBridge() { bridgeStatus.value = null await loadBridgeStatus() } + catch (error) { + toast.error(t('pages.settings.jre.bridgeCard.notifications.removeFailed'), { + description: error instanceof Error ? error.message : String(error), + }) + } finally { bridgeLoading.value = false } @@ -179,14 +222,20 @@ async function loadDrivers() { } async function handleDownloadDriver(dbType: string) { - downloadingDriver.value = dbType - try { - await jdbcApi.downloadDriver(dbType) - await loadDrivers() + if (!bridgeStatus.value?.installed) { + toast.error(t('pages.settings.jre.driversCard.notifications.installBridgeFirst')) + return } - finally { - downloadingDriver.value = null + dl.reset(dbType) + const ok = await dl.startDownload('driver', dbType, () => jdbcApi.downloadDriver(dbType)) + if (!ok) { + toast.error(t('pages.settings.jre.driversCard.notifications.downloadFailed'), { description: dl.getError(dbType) ?? '' }) + } + else { + toast.success(t('pages.settings.jre.driversCard.notifications.downloadSuccess')) + await loadDrivers() } + dl.reset(dbType) } async function handleRemoveDriver(dbType: string) { @@ -195,11 +244,40 @@ async function handleRemoveDriver(dbType: string) { await jdbcApi.removeDriver(dbType) await loadDrivers() } + catch (error) { + toast.error(t('pages.settings.jre.driversCard.notifications.removeFailed'), { + description: error instanceof Error ? error.message : String(error), + }) + } finally { removingDriver.value = null } } +async function handleCheckDriverUpdates(dbType: string) { + checkingDriver.value = dbType + try { + const update = await jdbcApi.checkDriverUpdate(dbType) + if (update.update_available) { + toast.info(t('pages.settings.jre.driversCard.notifications.updateAvailable'), { + description: t('pages.settings.jre.driversCard.notifications.newVersion', { version: update.latest_version ?? '' }), + }) + await handleDownloadDriver(dbType) + } + else { + toast.success(t('pages.settings.jre.driversCard.notifications.upToDate')) + } + } + catch (error) { + toast.error(t('pages.settings.jre.driversCard.notifications.checkUpdatesFailed'), { + description: error instanceof Error ? error.message : String(error), + }) + } + finally { + checkingDriver.value = null + } +} + // --- Refresh --- const allLoading = ref(false) @@ -254,6 +332,21 @@ onMounted(() => { {{ t('pages.settings.jre.jreCard.status.notInstalled') }}

+ +
+ + + + +
{{ jreStatus?.source === 'managed' ? t('pages.settings.jre.jreCard.status.installed') : t('pages.settings.jre.jreCard.status.notInstalled') }} @@ -264,7 +357,7 @@ onMounted(() => {

{{ t('pages.settings.jre.jreCard.actions.checkUpdates') }}

@@ -273,8 +366,9 @@ onMounted(() => { -

{{ jreStatus?.source !== 'managed' ? t('pages.settings.jre.jreCard.actions.download') : t('pages.settings.jre.jreCard.actions.redownload') }}

@@ -294,34 +388,43 @@ onMounted(() => {
-
+
-
-

- {{ t('pages.settings.jre.jreCard.status.system') }} -

-

- {{ jreStatus.path || t('pages.settings.jre.jreCard.status.notInstalled') }} - - -

-
-
- - {{ jreStatus.installed ? t('pages.settings.jre.jreCard.status.installed') : t('pages.settings.jre.jreCard.status.notInstalled') }} - -
+ +
- -
- - {{ jreWarning }} -
@@ -337,13 +440,30 @@ onMounted(() => {
- - + + {{ t('pages.settings.jre.bridgeCard.status.checking') }}

+ +
+ + + + +
{
+ + + + + + +

{{ t('pages.settings.jre.driversCard.actions.checkUpdates') }}

+
+
+
@@ -509,10 +668,11 @@ onMounted(() => { size="icon" variant="ghost" class="h-7 w-7" - :disabled="downloadingDriver === driver.db_type" + :disabled="dl.isDownloading(driver.db_type)" @click="handleDownloadDriver(driver.db_type)" > - + +