Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 51 additions & 13 deletions src/common/js/scl-app.js
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ const ENABLE_SQL_AUTOCOMPLETE = true;
const TEXT_TO_SQL_PROVIDER_DEFAULT = 'chatgpt';
const TEXT_TO_SQL_MODEL_DEFAULT = 'gpt-4o-mini';
const TEXT_TO_SQL_CUSTOM_ENDPOINT_DEFAULT = '';
const TEXT_TO_SQL_CUSTOM_AUTH_TYPE_DEFAULT = 'Bearer';
const SQL_KEYWORDS = [
'SELECT',
'FROM',
Expand Down Expand Up @@ -298,6 +299,7 @@ const textToSqlProviderInput = $('#text-to-sql-provider-input');
const textToSqlModelInput = $('#text-to-sql-model-input');
const textToSqlApiKeyInput = $('#text-to-sql-api-key-input');
const textToSqlCustomEndpointInput = $('#text-to-sql-custom-endpoint-input');
const textToSqlCustomAuthTypeInput = $('#text-to-sql-custom-auth-type-input');
const textToSqlCustomEndpointGroup = $('#text-to-sql-custom-endpoint-group');
const saveSettingsBtn = $('#save-settings-btn');
const settingsModal = $('#settingsModal');
Expand Down Expand Up @@ -515,6 +517,8 @@ async function buildSchemaContext() {
function extractGeneratedSql(response) {
if (!response) return '';
if (typeof response === 'string') return response.trim();
const chatContent = response?.choices?.[0]?.message?.content;
if (typeof chatContent === 'string') return chatContent.trim();
if (typeof response.sql === 'string') return response.sql.trim();
if (typeof response.query === 'string') return response.query.trim();
if (response.data && typeof response.data.sql === 'string') return response.data.sql.trim();
Expand Down Expand Up @@ -565,6 +569,14 @@ function resolveCustomTextToSqlEndpoint(settings) {
return rawEndpoint;
}

function resolveCustomTextToSqlAuthType(settings) {
const rawAuthType = (
settings?.textToSqlCustomAuthType || TEXT_TO_SQL_CUSTOM_AUTH_TYPE_DEFAULT
).trim();
if (!rawAuthType) return TEXT_TO_SQL_CUSTOM_AUTH_TYPE_DEFAULT;
return rawAuthType.toLowerCase() === 'bearer' ? 'Bearer' : rawAuthType;
}

function handleTextToSqlHttpError(response, bodyText) {
const msg = bodyText || `${response.status} ${response.statusText}`;
throw new Error(`Text-to-SQL request failed: ${msg}`);
Expand All @@ -583,7 +595,8 @@ async function requestChatGptSql({ apiKey, model, finalPrompt }) {
messages: [
{
role: 'system',
content: 'You are a SQLite SQL generator. Return only SQL.',
content:
'You are a SQLite SQL generator. Use "," for cross joins. Do not use aliases unless necessary. Return only SQL.',
},
{
role: 'user',
Expand Down Expand Up @@ -612,7 +625,8 @@ async function requestClaudeSql({ apiKey, model, finalPrompt }) {
model,
max_tokens: 1024,
temperature: 0,
system: 'You are a SQLite SQL generator. Return only SQL.',
system:
'You are a SQLite SQL generator. Use "," for cross joins. Do not use aliases unless necessary. Return only SQL.',
messages: [{ role: 'user', content: finalPrompt }],
}),
});
Expand All @@ -627,12 +641,17 @@ async function requestClaudeSql({ apiKey, model, finalPrompt }) {

async function requestGeminiSql({ apiKey, model, finalPrompt }) {
const endpoint = `https://generativelanguage.googleapis.com/v1beta/models/${encodeURIComponent(model)}:generateContent?key=${encodeURIComponent(apiKey)}`;
const systemPrompt =
'You are a SQLite SQL generator. Use "," for cross joins. Do not use aliases unless necessary. Return only SQL.';
const response = await fetch(endpoint, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
systemInstruction: {
parts: [{ text: systemPrompt }],
},
generationConfig: {
temperature: 0,
},
Expand All @@ -652,19 +671,32 @@ async function requestGeminiSql({ apiKey, model, finalPrompt }) {
return data?.candidates?.[0]?.content?.parts?.[0]?.text || '';
}

async function requestCustomSql({ endpoint, apiKey, model, promptText, schema }) {
async function requestCustomSql({ endpoint, authType, apiKey, model, finalPrompt }) {
const headers = {
'Content-Type': 'application/json',
};
if (authType === 'Bearer') {
headers['Authorization'] = `Bearer ${apiKey}`;
} else if (authType) {
headers[authType] = `${apiKey}`;
}
const response = await fetch(endpoint, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${apiKey}`,
'x-api-key': apiKey,
},
headers,
body: JSON.stringify({
prompt: promptText,
dialect: 'sqlite',
schema,
model,
temperature: 0,
messages: [
{
role: 'system',
content:
'You are a SQLite SQL generator. Use "," for cross joins. Do not use aliases unless necessary. Return only SQL.',
},
{
role: 'user',
content: finalPrompt,
},
],
}),
});

Expand Down Expand Up @@ -705,12 +737,13 @@ async function generateSqlFromPrompt(promptText) {
generatedRaw = await requestGeminiSql({ apiKey, model, finalPrompt });
} else if (provider === 'custom') {
const customEndpoint = resolveCustomTextToSqlEndpoint(settings);
const customAuthType = resolveCustomTextToSqlAuthType(settings);
const customResponse = await requestCustomSql({
endpoint: customEndpoint,
authType: customAuthType,
apiKey,
model,
promptText,
schema,
finalPrompt,
});
generatedRaw = extractGeneratedSql(customResponse);
} else {
Expand Down Expand Up @@ -1491,6 +1524,8 @@ function bindEvents() {
textToSqlModelInput.value = settings.textToSqlModel || TEXT_TO_SQL_MODEL_DEFAULT;
textToSqlApiKeyInput.value = settings.textToSqlApiKey || '';
textToSqlCustomEndpointInput.value = settings.textToSqlCustomEndpoint || '';
textToSqlCustomAuthTypeInput.value =
settings.textToSqlCustomAuthType || TEXT_TO_SQL_CUSTOM_AUTH_TYPE_DEFAULT;
toggleCustomEndpointField();
});
}
Expand All @@ -1516,6 +1551,9 @@ function bindEvents() {
textToSqlModel: textToSqlModelInput.value.trim(),
textToSqlApiKey: textToSqlApiKeyInput.value.trim(),
textToSqlCustomEndpoint: customEndpoint,
textToSqlCustomAuthType:
(textToSqlCustomAuthTypeInput?.value || TEXT_TO_SQL_CUSTOM_AUTH_TYPE_DEFAULT).trim() ||
TEXT_TO_SQL_CUSTOM_AUTH_TYPE_DEFAULT,
});
// Close modal via Bootstrap
const modal = window.bootstrap.Modal.getInstance(settingsModal);
Expand Down
9 changes: 9 additions & 0 deletions src/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,15 @@ <h5 class="modal-title" id="settingsModalLabel">Settings</h5>
id="text-to-sql-custom-endpoint-input"
placeholder="https://openrouter.ai/api/v1/chat/completions"
/>
<label for="text-to-sql-custom-auth-type-input" class="form-label fw-bold mt-3">
Auth Type
</label>
<input
type="text"
class="form-control"
id="text-to-sql-custom-auth-type-input"
placeholder="Bearer"
/>
</div>
</div>
<div class="modal-footer">
Expand Down
2 changes: 1 addition & 1 deletion src/public/js/sqlite-worker.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* It imports sqlite-wasm from /sqlite-wasm/ which is copied from node_modules.
*/

import sqlite3InitModule from './sqlite-wasm/index.mjs';
import sqlite3InitModule from '../sqlite-wasm/index.mjs';

let sqlite3 = null;
let db = null;
Expand Down
Loading