Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/include/sqlite_stmt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class SQLiteStatement {
int expected_type, idx_t col_idx);
void CheckTypeIsFloatOrInteger(sqlite3_value *val, int sqlite_column_type, idx_t col_idx);
void Reset();
void ClearBindings();
};

template <>
Expand Down
113 changes: 65 additions & 48 deletions src/sqlite_scanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ struct SqliteLocalState : public LocalTableFunctionState {
SQLiteDB *db;
SQLiteDB owned_db;
SQLiteStatement stmt;
string stmt_sql;
bool done = false;
vector<column_t> column_ids;
//! The amount of rows we scanned as part of this row group
Expand Down Expand Up @@ -85,48 +86,56 @@ static unique_ptr<FunctionData> SqliteBind(ClientContext &context, TableFunction
return std::move(result);
}

static string SqliteGetScanSQL(const SqliteBindData &bind_data, const vector<column_t> &column_ids) {
if (!bind_data.sql.empty()) {
return bind_data.sql;
}
auto col_names = StringUtil::Join(column_ids.data(), column_ids.size(), ", ", [&](const idx_t column_id) {
return column_id == (column_t)-1 ? "ROWID"
: '"' + SQLiteUtils::SanitizeIdentifier(bind_data.names[column_id]) + '"';
});

auto sql = StringUtil::Format("SELECT %s FROM \"%s\"", col_names, SQLiteUtils::SanitizeIdentifier(bind_data.table_name));
if (bind_data.rows_per_group.IsValid()) {
sql += " WHERE ROWID BETWEEN ? AND ?";
}
return sql;
}

static void SqlitePrepareStatement(SqliteLocalState &local_state, const string &sql) {
if (!local_state.stmt.IsOpen() || local_state.stmt_sql != sql) {
local_state.stmt.Close();
local_state.stmt = local_state.db->Prepare(sql.c_str());
local_state.stmt_sql = sql;
return;
}
local_state.stmt.Reset();
local_state.stmt.ClearBindings();
}

static void SqliteInitInternal(ClientContext &context, const SqliteBindData &bind_data, SqliteLocalState &local_state,
idx_t rowid_min, idx_t rowid_max) {
D_ASSERT(rowid_min <= rowid_max);

local_state.done = false;
// we may have leftover statements or connections from a previous call to this
// function
local_state.stmt.Close();
if (!local_state.db) {
SQLiteOpenOptions options;
options.access_mode = AccessMode::READ_ONLY;
local_state.owned_db = SQLiteDB::Open(bind_data.file_name.c_str(), options);
local_state.db = &local_state.owned_db;
}
string sql;
if (bind_data.sql.empty()) {
auto col_names = StringUtil::Join(
local_state.column_ids.data(), local_state.column_ids.size(), ", ", [&](const idx_t column_id) {
return column_id == (column_t)-1
? "ROWID"
: '"' + SQLiteUtils::SanitizeIdentifier(bind_data.names[column_id]) + '"';
});

sql = StringUtil::Format("SELECT %s FROM \"%s\"", col_names,
SQLiteUtils::SanitizeIdentifier(bind_data.table_name));
if (bind_data.rows_per_group.IsValid()) {
// we are scanning a subset of the rows - generate a WHERE clause based on
// the rowid
auto where_clause = StringUtil::Format(" WHERE ROWID BETWEEN %d AND %d", rowid_min, rowid_max);
sql += where_clause;
} else {
// we are scanning the entire table - no need for a WHERE clause
D_ASSERT(rowid_min == 0);
}
} else {
sql = bind_data.sql;
string sql = SqliteGetScanSQL(bind_data, local_state.column_ids);
SqlitePrepareStatement(local_state, sql);

idx_t param_idx = 0;
for (; param_idx < bind_data.params.size(); param_idx++) {
const Value &param = bind_data.params[param_idx];
local_state.stmt.BindParameter(param, param_idx);
}
local_state.stmt = local_state.db->Prepare(sql.c_str());

for (idx_t i = 0; i < bind_data.params.size(); i++) {
const Value &param = bind_data.params[i];
local_state.stmt.BindParameter(param, i);
if (bind_data.rows_per_group.IsValid()) {
local_state.stmt.Bind<int64_t>(param_idx++, UnsafeNumericCast<int64_t>(rowid_min));
local_state.stmt.Bind<int64_t>(param_idx++, UnsafeNumericCast<int64_t>(rowid_max));
Comment thread
staticlibs marked this conversation as resolved.
}
}

Expand All @@ -153,37 +162,45 @@ static idx_t SqliteMaxThreads(ClientContext &context, const FunctionData *bind_d
return row_count / bind_data.rows_per_group.GetIndex();
}

static bool SqliteParallelStateNext(ClientContext &context, const SqliteBindData &bind_data, SqliteLocalState &lstate,
SqliteGlobalState &gstate) {
static bool SqliteClaimNextSlice(const SqliteBindData &bind_data, SqliteLocalState &lstate, SqliteGlobalState &gstate,
idx_t &rowid_min, idx_t &rowid_max) {
lock_guard<mutex> parallel_lock(gstate.lock);
if (!bind_data.rows_per_group.IsValid()) {
// not doing a parallel scan - scan everything at once
if (gstate.position > 0) {
// already scanned
return false;
}
SqliteInitInternal(context, bind_data, lstate, 0, 0);
gstate.position = static_cast<idx_t>(-1);
lstate.scan_count = 0;
return true;
}
auto max_row_id = bind_data.row_id_info.max_rowid.GetIndex();
if (gstate.position < max_row_id) {
if (lstate.scan_count == 0 && gstate.rows_per_group < max_row_id) {
// we scanned no rows in our previous slice - double the rows per group
gstate.rows_per_group *= 2;
}
if (gstate.rows_per_group == 0) {
throw InternalException("SqliteParallelStateNext - gstate.rows_per_group not set");
}
auto start = gstate.position;
auto end = MinValue<idx_t>(max_row_id, start + gstate.rows_per_group - 1);
SqliteInitInternal(context, bind_data, lstate, start, end);
gstate.position = end + 1;
lstate.scan_count = 0;
return true;
if (gstate.position >= max_row_id) {
return false;
}
if (lstate.scan_count == 0 && gstate.rows_per_group < max_row_id) {
// we scanned no rows in our previous slice - double the rows per group
gstate.rows_per_group *= 2;
}
if (gstate.rows_per_group == 0) {
throw InternalException("SqliteParallelStateNext - gstate.rows_per_group not set");
}
rowid_min = gstate.position;
rowid_max = MinValue<idx_t>(max_row_id, rowid_min + gstate.rows_per_group - 1);
gstate.position = rowid_max + 1;
lstate.scan_count = 0;
return true;
}

static bool SqliteParallelStateNext(ClientContext &context, const SqliteBindData &bind_data, SqliteLocalState &lstate,
SqliteGlobalState &gstate) {
idx_t rowid_min = 0;
idx_t rowid_max = 0;
if (!SqliteClaimNextSlice(bind_data, lstate, gstate, rowid_min, rowid_max)) {
return false;
}
return false;
SqliteInitInternal(context, bind_data, lstate, rowid_min, rowid_max);
return true;
}

static unique_ptr<LocalTableFunctionState>
Expand Down
4 changes: 4 additions & 0 deletions src/sqlite_stmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ void SQLiteStatement::Reset() {
SQLiteUtils::Check(sqlite3_reset(stmt), db);
}

void SQLiteStatement::ClearBindings() {
SQLiteUtils::Check(sqlite3_clear_bindings(stmt), db);
}

template <>
string SQLiteStatement::GetValue(idx_t col) {
D_ASSERT(stmt);
Expand Down
Loading