diff --git a/src/include/sqlite_stmt.hpp b/src/include/sqlite_stmt.hpp index b737e49..c69a245 100644 --- a/src/include/sqlite_stmt.hpp +++ b/src/include/sqlite_stmt.hpp @@ -55,6 +55,7 @@ class SQLiteStatement { int expected_type, idx_t col_idx); void CheckTypeIsFloatOrInteger(sqlite3_value *val, int sqlite_column_type, idx_t col_idx); void Reset(); + void ClearBindings(); }; template <> diff --git a/src/sqlite_scanner.cpp b/src/sqlite_scanner.cpp index 568595a..91b7180 100644 --- a/src/sqlite_scanner.cpp +++ b/src/sqlite_scanner.cpp @@ -21,6 +21,7 @@ struct SqliteLocalState : public LocalTableFunctionState { SQLiteDB *db; SQLiteDB owned_db; SQLiteStatement stmt; + string stmt_sql; bool done = false; vector column_ids; //! The amount of rows we scanned as part of this row group @@ -85,48 +86,56 @@ static unique_ptr SqliteBind(ClientContext &context, TableFunction return std::move(result); } +static string SqliteGetScanSQL(const SqliteBindData &bind_data, const vector &column_ids) { + if (!bind_data.sql.empty()) { + return bind_data.sql; + } + auto col_names = StringUtil::Join(column_ids.data(), column_ids.size(), ", ", [&](const idx_t column_id) { + return column_id == (column_t)-1 ? "ROWID" + : '"' + SQLiteUtils::SanitizeIdentifier(bind_data.names[column_id]) + '"'; + }); + + auto sql = StringUtil::Format("SELECT %s FROM \"%s\"", col_names, SQLiteUtils::SanitizeIdentifier(bind_data.table_name)); + if (bind_data.rows_per_group.IsValid()) { + sql += " WHERE ROWID BETWEEN ? AND ?"; + } + return sql; +} + +static void SqlitePrepareStatement(SqliteLocalState &local_state, const string &sql) { + if (!local_state.stmt.IsOpen() || local_state.stmt_sql != sql) { + local_state.stmt.Close(); + local_state.stmt = local_state.db->Prepare(sql.c_str()); + local_state.stmt_sql = sql; + return; + } + local_state.stmt.Reset(); + local_state.stmt.ClearBindings(); +} + static void SqliteInitInternal(ClientContext &context, const SqliteBindData &bind_data, SqliteLocalState &local_state, idx_t rowid_min, idx_t rowid_max) { D_ASSERT(rowid_min <= rowid_max); local_state.done = false; - // we may have leftover statements or connections from a previous call to this - // function - local_state.stmt.Close(); if (!local_state.db) { SQLiteOpenOptions options; options.access_mode = AccessMode::READ_ONLY; local_state.owned_db = SQLiteDB::Open(bind_data.file_name.c_str(), options); local_state.db = &local_state.owned_db; } - string sql; - if (bind_data.sql.empty()) { - auto col_names = StringUtil::Join( - local_state.column_ids.data(), local_state.column_ids.size(), ", ", [&](const idx_t column_id) { - return column_id == (column_t)-1 - ? "ROWID" - : '"' + SQLiteUtils::SanitizeIdentifier(bind_data.names[column_id]) + '"'; - }); - - sql = StringUtil::Format("SELECT %s FROM \"%s\"", col_names, - SQLiteUtils::SanitizeIdentifier(bind_data.table_name)); - if (bind_data.rows_per_group.IsValid()) { - // we are scanning a subset of the rows - generate a WHERE clause based on - // the rowid - auto where_clause = StringUtil::Format(" WHERE ROWID BETWEEN %d AND %d", rowid_min, rowid_max); - sql += where_clause; - } else { - // we are scanning the entire table - no need for a WHERE clause - D_ASSERT(rowid_min == 0); - } - } else { - sql = bind_data.sql; + string sql = SqliteGetScanSQL(bind_data, local_state.column_ids); + SqlitePrepareStatement(local_state, sql); + + idx_t param_idx = 0; + for (; param_idx < bind_data.params.size(); param_idx++) { + const Value ¶m = bind_data.params[param_idx]; + local_state.stmt.BindParameter(param, param_idx); } - local_state.stmt = local_state.db->Prepare(sql.c_str()); - for (idx_t i = 0; i < bind_data.params.size(); i++) { - const Value ¶m = bind_data.params[i]; - local_state.stmt.BindParameter(param, i); + if (bind_data.rows_per_group.IsValid()) { + local_state.stmt.Bind(param_idx++, UnsafeNumericCast(rowid_min)); + local_state.stmt.Bind(param_idx++, UnsafeNumericCast(rowid_max)); } } @@ -153,37 +162,45 @@ static idx_t SqliteMaxThreads(ClientContext &context, const FunctionData *bind_d return row_count / bind_data.rows_per_group.GetIndex(); } -static bool SqliteParallelStateNext(ClientContext &context, const SqliteBindData &bind_data, SqliteLocalState &lstate, - SqliteGlobalState &gstate) { +static bool SqliteClaimNextSlice(const SqliteBindData &bind_data, SqliteLocalState &lstate, SqliteGlobalState &gstate, + idx_t &rowid_min, idx_t &rowid_max) { lock_guard parallel_lock(gstate.lock); if (!bind_data.rows_per_group.IsValid()) { // not doing a parallel scan - scan everything at once if (gstate.position > 0) { - // already scanned return false; } - SqliteInitInternal(context, bind_data, lstate, 0, 0); gstate.position = static_cast(-1); lstate.scan_count = 0; return true; } auto max_row_id = bind_data.row_id_info.max_rowid.GetIndex(); - if (gstate.position < max_row_id) { - if (lstate.scan_count == 0 && gstate.rows_per_group < max_row_id) { - // we scanned no rows in our previous slice - double the rows per group - gstate.rows_per_group *= 2; - } - if (gstate.rows_per_group == 0) { - throw InternalException("SqliteParallelStateNext - gstate.rows_per_group not set"); - } - auto start = gstate.position; - auto end = MinValue(max_row_id, start + gstate.rows_per_group - 1); - SqliteInitInternal(context, bind_data, lstate, start, end); - gstate.position = end + 1; - lstate.scan_count = 0; - return true; + if (gstate.position >= max_row_id) { + return false; + } + if (lstate.scan_count == 0 && gstate.rows_per_group < max_row_id) { + // we scanned no rows in our previous slice - double the rows per group + gstate.rows_per_group *= 2; + } + if (gstate.rows_per_group == 0) { + throw InternalException("SqliteParallelStateNext - gstate.rows_per_group not set"); + } + rowid_min = gstate.position; + rowid_max = MinValue(max_row_id, rowid_min + gstate.rows_per_group - 1); + gstate.position = rowid_max + 1; + lstate.scan_count = 0; + return true; +} + +static bool SqliteParallelStateNext(ClientContext &context, const SqliteBindData &bind_data, SqliteLocalState &lstate, + SqliteGlobalState &gstate) { + idx_t rowid_min = 0; + idx_t rowid_max = 0; + if (!SqliteClaimNextSlice(bind_data, lstate, gstate, rowid_min, rowid_max)) { + return false; } - return false; + SqliteInitInternal(context, bind_data, lstate, rowid_min, rowid_max); + return true; } static unique_ptr diff --git a/src/sqlite_stmt.cpp b/src/sqlite_stmt.cpp index 8237295..bbd3c5d 100644 --- a/src/sqlite_stmt.cpp +++ b/src/sqlite_stmt.cpp @@ -101,6 +101,10 @@ void SQLiteStatement::Reset() { SQLiteUtils::Check(sqlite3_reset(stmt), db); } +void SQLiteStatement::ClearBindings() { + SQLiteUtils::Check(sqlite3_clear_bindings(stmt), db); +} + template <> string SQLiteStatement::GetValue(idx_t col) { D_ASSERT(stmt);