From a64dbdf63123a35d0afb650bdf03d1fd32ba0100 Mon Sep 17 00:00:00 2001 From: Ilya Danilov Date: Wed, 13 May 2026 19:06:47 +0300 Subject: [PATCH] fix: align gorm v1.31.1 and postgres driver v1.6.0 across services --- auth-center/go.mod | 2 +- auth-center/go.sum | 4 +- auth-center/vendor/gorm.io/gorm/README.md | 2 +- .../vendor/gorm.io/gorm/association.go | 86 ++- auth-center/vendor/gorm.io/gorm/callbacks.go | 8 +- .../vendor/gorm.io/gorm/callbacks/create.go | 31 +- .../vendor/gorm.io/gorm/chainable_api.go | 4 +- .../vendor/gorm.io/gorm/clause/association.go | 35 + auth-center/vendor/gorm.io/gorm/clause/set.go | 11 + .../vendor/gorm.io/gorm/finisher_api.go | 10 +- auth-center/vendor/gorm.io/gorm/generics.go | 325 +++++++- auth-center/vendor/gorm.io/gorm/gorm.go | 9 + .../vendor/gorm.io/gorm/logger/slog.go | 116 +++ .../vendor/gorm.io/gorm/migrator/migrator.go | 18 +- .../vendor/gorm.io/gorm/schema/field.go | 16 +- .../gorm.io/gorm/schema/relationship.go | 15 +- .../vendor/gorm.io/gorm/schema/schema.go | 229 +++--- .../vendor/gorm.io/gorm/schema/serializer.go | 28 +- .../vendor/gorm.io/gorm/schema/utils.go | 37 +- auth-center/vendor/gorm.io/gorm/statement.go | 29 +- .../vendor/gorm.io/gorm/utils/utils.go | 22 +- auth-center/vendor/modules.txt | 2 +- cluster-manager/go.mod | 9 +- cluster-manager/go.sum | 18 +- .../jackc/pgservicefile/.travis.yml | 9 - .../github.com/jackc/pgservicefile/README.md | 5 +- .../jackc/pgservicefile/pgservicefile.go | 4 +- .../github.com/jackc/pgx/v5/CHANGELOG.md | 89 +++ .../github.com/jackc/pgx/v5/CONTRIBUTING.md | 19 +- .../vendor/github.com/jackc/pgx/v5/README.md | 7 +- .../vendor/github.com/jackc/pgx/v5/batch.go | 52 +- .../vendor/github.com/jackc/pgx/v5/conn.go | 183 +++-- .../github.com/jackc/pgx/v5/copy_from.go | 27 + .../vendor/github.com/jackc/pgx/v5/doc.go | 13 +- .../jackc/pgx/v5/extended_query_builder.go | 90 +-- .../jackc/pgx/v5/internal/anynil/anynil.go | 36 - .../pgx/v5/internal/sanitize/sanitize.go | 9 + .../pgx/v5/internal/stmtcache/lru_cache.go | 22 +- .../pgx/v5/internal/stmtcache/stmtcache.go | 40 +- .../v5/internal/stmtcache/unlimited_cache.go | 12 +- .../github.com/jackc/pgx/v5/large_objects.go | 76 +- .../github.com/jackc/pgx/v5/named_args.go | 63 +- .../jackc/pgx/v5/pgconn/auth_scram.go | 4 +- .../github.com/jackc/pgx/v5/pgconn/config.go | 83 +- .../ctxwatch/context_watcher.go | 25 +- .../github.com/jackc/pgx/v5/pgconn/doc.go | 16 +- .../github.com/jackc/pgx/v5/pgconn/errors.go | 104 ++- .../github.com/jackc/pgx/v5/pgconn/pgconn.go | 590 +++++++++----- .../jackc/pgx/v5/pgproto3/README.md | 2 +- .../authentication_cleartext_password.go | 7 +- .../pgx/v5/pgproto3/authentication_gss.go | 7 +- .../pgproto3/authentication_gss_continue.go | 7 +- .../pgproto3/authentication_md5_password.go | 7 +- .../pgx/v5/pgproto3/authentication_ok.go | 7 +- .../pgx/v5/pgproto3/authentication_sasl.go | 10 +- .../pgproto3/authentication_sasl_continue.go | 12 +- .../v5/pgproto3/authentication_sasl_final.go | 12 +- .../jackc/pgx/v5/pgproto3/backend.go | 38 +- .../jackc/pgx/v5/pgproto3/backend_key_data.go | 7 +- .../github.com/jackc/pgx/v5/pgproto3/bind.go | 21 +- .../jackc/pgx/v5/pgproto3/bind_complete.go | 4 +- .../jackc/pgx/v5/pgproto3/cancel_request.go | 4 +- .../github.com/jackc/pgx/v5/pgproto3/close.go | 14 +- .../jackc/pgx/v5/pgproto3/close_complete.go | 4 +- .../jackc/pgx/v5/pgproto3/command_complete.go | 14 +- .../pgx/v5/pgproto3/copy_both_response.go | 14 +- .../jackc/pgx/v5/pgproto3/copy_data.go | 9 +- .../jackc/pgx/v5/pgproto3/copy_done.go | 4 +- .../jackc/pgx/v5/pgproto3/copy_fail.go | 14 +- .../jackc/pgx/v5/pgproto3/copy_in_response.go | 14 +- .../pgx/v5/pgproto3/copy_out_response.go | 14 +- .../jackc/pgx/v5/pgproto3/data_row.go | 15 +- .../jackc/pgx/v5/pgproto3/describe.go | 14 +- .../github.com/jackc/pgx/v5/pgproto3/doc.go | 4 +- .../pgx/v5/pgproto3/empty_query_response.go | 4 +- .../jackc/pgx/v5/pgproto3/error_response.go | 135 ++-- .../jackc/pgx/v5/pgproto3/execute.go | 13 +- .../github.com/jackc/pgx/v5/pgproto3/flush.go | 4 +- .../jackc/pgx/v5/pgproto3/frontend.go | 137 +++- .../jackc/pgx/v5/pgproto3/function_call.go | 19 +- .../pgx/v5/pgproto3/function_call_response.go | 10 +- .../jackc/pgx/v5/pgproto3/gss_enc_request.go | 4 +- .../jackc/pgx/v5/pgproto3/gss_response.go | 9 +- .../jackc/pgx/v5/pgproto3/no_data.go | 4 +- .../jackc/pgx/v5/pgproto3/notice_response.go | 6 +- .../pgx/v5/pgproto3/notification_response.go | 12 +- .../pgx/v5/pgproto3/parameter_description.go | 15 +- .../jackc/pgx/v5/pgproto3/parameter_status.go | 14 +- .../github.com/jackc/pgx/v5/pgproto3/parse.go | 15 +- .../jackc/pgx/v5/pgproto3/parse_complete.go | 4 +- .../jackc/pgx/v5/pgproto3/password_message.go | 11 +- .../jackc/pgx/v5/pgproto3/pgproto3.go | 37 +- .../jackc/pgx/v5/pgproto3/portal_suspended.go | 4 +- .../github.com/jackc/pgx/v5/pgproto3/query.go | 11 +- .../jackc/pgx/v5/pgproto3/ready_for_query.go | 4 +- .../jackc/pgx/v5/pgproto3/row_description.go | 15 +- .../pgx/v5/pgproto3/sasl_initial_response.go | 10 +- .../jackc/pgx/v5/pgproto3/sasl_response.go | 11 +- .../jackc/pgx/v5/pgproto3/ssl_request.go | 4 +- .../jackc/pgx/v5/pgproto3/startup_message.go | 10 +- .../github.com/jackc/pgx/v5/pgproto3/sync.go | 4 +- .../jackc/pgx/v5/pgproto3/terminate.go | 4 +- .../github.com/jackc/pgx/v5/pgtype/array.go | 22 +- .../jackc/pgx/v5/pgtype/array_codec.go | 3 +- .../github.com/jackc/pgx/v5/pgtype/bits.go | 4 +- .../jackc/pgx/v5/pgtype/builtin_wrappers.go | 4 +- .../github.com/jackc/pgx/v5/pgtype/date.go | 6 +- .../github.com/jackc/pgx/v5/pgtype/doc.go | 14 +- .../github.com/jackc/pgx/v5/pgtype/float4.go | 28 +- .../github.com/jackc/pgx/v5/pgtype/float8.go | 30 +- .../github.com/jackc/pgx/v5/pgtype/inet.go | 2 +- .../jackc/pgx/v5/pgtype/interval.go | 49 +- .../github.com/jackc/pgx/v5/pgtype/json.go | 77 +- .../github.com/jackc/pgx/v5/pgtype/jsonb.go | 28 +- .../github.com/jackc/pgx/v5/pgtype/ltree.go | 122 +++ .../jackc/pgx/v5/pgtype/multirange.go | 8 +- .../github.com/jackc/pgx/v5/pgtype/numeric.go | 20 + .../github.com/jackc/pgx/v5/pgtype/pgtype.go | 87 ++- .../jackc/pgx/v5/pgtype/pgtype_default.go | 11 +- .../github.com/jackc/pgx/v5/pgtype/point.go | 16 +- .../github.com/jackc/pgx/v5/pgtype/range.go | 14 +- .../jackc/pgx/v5/pgtype/range_codec.go | 16 +- .../github.com/jackc/pgx/v5/pgtype/tid.go | 8 +- .../github.com/jackc/pgx/v5/pgtype/time.go | 43 +- .../jackc/pgx/v5/pgtype/timestamp.go | 39 +- .../jackc/pgx/v5/pgtype/timestamptz.go | 39 +- .../github.com/jackc/pgx/v5/pgtype/uuid.go | 14 +- .../jackc/pgx/v5/pgxpool/batch_results.go | 52 ++ .../github.com/jackc/pgx/v5/pgxpool/conn.go | 134 ++++ .../github.com/jackc/pgx/v5/pgxpool/doc.go | 27 + .../github.com/jackc/pgx/v5/pgxpool/pool.go | 717 ++++++++++++++++++ .../github.com/jackc/pgx/v5/pgxpool/rows.go | 116 +++ .../github.com/jackc/pgx/v5/pgxpool/stat.go | 84 ++ .../github.com/jackc/pgx/v5/pgxpool/tracer.go | 33 + .../github.com/jackc/pgx/v5/pgxpool/tx.go | 82 ++ .../vendor/github.com/jackc/pgx/v5/rows.go | 348 ++++++--- .../github.com/jackc/pgx/v5/stdlib/sql.go | 140 +++- .../vendor/github.com/jackc/pgx/v5/values.go | 15 +- .../github.com/jackc/puddle/v2/CHANGELOG.md | 79 ++ .../vendor/github.com/jackc/puddle/v2/LICENSE | 22 + .../github.com/jackc/puddle/v2/README.md | 80 ++ .../github.com/jackc/puddle/v2/context.go | 24 + .../vendor/github.com/jackc/puddle/v2/doc.go | 11 + .../puddle/v2/internal/genstack/gen_stack.go | 85 +++ .../puddle/v2/internal/genstack/stack.go | 39 + .../vendor/github.com/jackc/puddle/v2/log.go | 32 + .../github.com/jackc/puddle/v2/nanotime.go | 16 + .../vendor/github.com/jackc/puddle/v2/pool.go | 710 +++++++++++++++++ .../jackc/puddle/v2/resource_list.go | 28 + .../driver/postgres/error_translator.go | 2 + .../gorm.io/driver/postgres/migrator.go | 63 +- .../gorm.io/driver/postgres/postgres.go | 22 +- cluster-manager/vendor/gorm.io/gorm/README.md | 2 +- .../vendor/gorm.io/gorm/association.go | 86 ++- .../vendor/gorm.io/gorm/callbacks.go | 8 +- .../vendor/gorm.io/gorm/callbacks/create.go | 31 +- .../vendor/gorm.io/gorm/chainable_api.go | 4 +- .../vendor/gorm.io/gorm/clause/association.go | 35 + .../vendor/gorm.io/gorm/clause/set.go | 11 + .../vendor/gorm.io/gorm/finisher_api.go | 10 +- .../vendor/gorm.io/gorm/generics.go | 325 +++++++- cluster-manager/vendor/gorm.io/gorm/gorm.go | 9 + .../vendor/gorm.io/gorm/logger/slog.go | 116 +++ .../vendor/gorm.io/gorm/migrator/migrator.go | 18 +- .../vendor/gorm.io/gorm/schema/field.go | 16 +- .../gorm.io/gorm/schema/relationship.go | 15 +- .../vendor/gorm.io/gorm/schema/schema.go | 229 +++--- .../vendor/gorm.io/gorm/schema/serializer.go | 28 +- .../vendor/gorm.io/gorm/schema/utils.go | 37 +- .../vendor/gorm.io/gorm/statement.go | 29 +- .../vendor/gorm.io/gorm/utils/utils.go | 22 +- cluster-manager/vendor/modules.txt | 20 +- cs-manager/go.mod | 4 +- cs-manager/go.sum | 8 +- .../driver/postgres/error_translator.go | 2 + .../gorm.io/driver/postgres/migrator.go | 63 +- .../gorm.io/driver/postgres/postgres.go | 22 +- cs-manager/vendor/gorm.io/gorm/README.md | 2 +- cs-manager/vendor/gorm.io/gorm/association.go | 86 ++- cs-manager/vendor/gorm.io/gorm/callbacks.go | 8 +- .../vendor/gorm.io/gorm/callbacks/create.go | 31 +- .../vendor/gorm.io/gorm/chainable_api.go | 4 +- .../vendor/gorm.io/gorm/clause/association.go | 35 + cs-manager/vendor/gorm.io/gorm/clause/set.go | 11 + .../vendor/gorm.io/gorm/finisher_api.go | 10 +- cs-manager/vendor/gorm.io/gorm/generics.go | 325 +++++++- cs-manager/vendor/gorm.io/gorm/gorm.go | 9 + cs-manager/vendor/gorm.io/gorm/logger/slog.go | 116 +++ .../vendor/gorm.io/gorm/migrator/migrator.go | 18 +- .../vendor/gorm.io/gorm/schema/field.go | 16 +- .../gorm.io/gorm/schema/relationship.go | 15 +- .../vendor/gorm.io/gorm/schema/schema.go | 229 +++--- .../vendor/gorm.io/gorm/schema/serializer.go | 28 +- .../vendor/gorm.io/gorm/schema/utils.go | 37 +- cs-manager/vendor/gorm.io/gorm/statement.go | 29 +- cs-manager/vendor/gorm.io/gorm/utils/utils.go | 22 +- cs-manager/vendor/modules.txt | 6 +- event-processor/go.mod | 2 +- event-processor/go.sum | 4 +- event-processor/vendor/gorm.io/gorm/README.md | 2 +- .../vendor/gorm.io/gorm/association.go | 4 +- .../vendor/gorm.io/gorm/chainable_api.go | 4 +- .../vendor/gorm.io/gorm/finisher_api.go | 4 +- .../vendor/gorm.io/gorm/generics.go | 18 +- .../vendor/gorm.io/gorm/logger/slog.go | 28 +- .../vendor/gorm.io/gorm/migrator/migrator.go | 4 + .../vendor/gorm.io/gorm/schema/schema.go | 10 + .../vendor/gorm.io/gorm/schema/serializer.go | 28 +- .../vendor/gorm.io/gorm/schema/utils.go | 26 +- .../vendor/gorm.io/gorm/utils/utils.go | 22 +- event-processor/vendor/modules.txt | 2 +- history-api/go.mod | 2 +- history-api/go.sum | 4 +- history-api/vendor/gorm.io/gorm/README.md | 2 +- .../vendor/gorm.io/gorm/association.go | 4 +- .../vendor/gorm.io/gorm/chainable_api.go | 4 +- .../vendor/gorm.io/gorm/finisher_api.go | 4 +- history-api/vendor/gorm.io/gorm/generics.go | 18 +- .../vendor/gorm.io/gorm/logger/slog.go | 28 +- .../vendor/gorm.io/gorm/migrator/migrator.go | 4 + .../vendor/gorm.io/gorm/schema/schema.go | 10 + .../vendor/gorm.io/gorm/schema/serializer.go | 28 +- .../vendor/gorm.io/gorm/schema/utils.go | 26 +- .../vendor/gorm.io/gorm/utils/utils.go | 22 +- history-api/vendor/modules.txt | 2 +- lib/go.mod | 2 +- lib/go.sum | 4 +- lib/vendor/gorm.io/gorm/README.md | 2 +- lib/vendor/gorm.io/gorm/association.go | 86 ++- lib/vendor/gorm.io/gorm/callbacks.go | 8 +- lib/vendor/gorm.io/gorm/chainable_api.go | 4 +- lib/vendor/gorm.io/gorm/clause/association.go | 35 + lib/vendor/gorm.io/gorm/clause/set.go | 11 + lib/vendor/gorm.io/gorm/finisher_api.go | 10 +- lib/vendor/gorm.io/gorm/generics.go | 325 +++++++- lib/vendor/gorm.io/gorm/gorm.go | 9 + lib/vendor/gorm.io/gorm/logger/slog.go | 116 +++ lib/vendor/gorm.io/gorm/schema/field.go | 16 +- .../gorm.io/gorm/schema/relationship.go | 15 +- lib/vendor/gorm.io/gorm/schema/schema.go | 229 +++--- lib/vendor/gorm.io/gorm/schema/serializer.go | 28 +- lib/vendor/gorm.io/gorm/schema/utils.go | 37 +- lib/vendor/gorm.io/gorm/statement.go | 29 +- lib/vendor/gorm.io/gorm/utils/utils.go | 22 +- lib/vendor/modules.txt | 2 +- notifier/go.mod | 2 +- notifier/go.sum | 4 +- notifier/vendor/gorm.io/gorm/README.md | 2 +- notifier/vendor/gorm.io/gorm/association.go | 4 +- notifier/vendor/gorm.io/gorm/chainable_api.go | 4 +- notifier/vendor/gorm.io/gorm/finisher_api.go | 4 +- notifier/vendor/gorm.io/gorm/generics.go | 18 +- notifier/vendor/gorm.io/gorm/logger/slog.go | 28 +- .../vendor/gorm.io/gorm/migrator/migrator.go | 4 + notifier/vendor/gorm.io/gorm/schema/schema.go | 10 + .../vendor/gorm.io/gorm/schema/serializer.go | 28 +- notifier/vendor/gorm.io/gorm/schema/utils.go | 26 +- notifier/vendor/gorm.io/gorm/utils/utils.go | 22 +- notifier/vendor/modules.txt | 2 +- policy-enforcer/go.mod | 2 +- policy-enforcer/go.sum | 4 +- policy-enforcer/vendor/gorm.io/gorm/README.md | 2 +- .../vendor/gorm.io/gorm/association.go | 86 ++- .../vendor/gorm.io/gorm/callbacks.go | 8 +- .../vendor/gorm.io/gorm/callbacks/create.go | 31 +- .../vendor/gorm.io/gorm/chainable_api.go | 4 +- .../vendor/gorm.io/gorm/clause/association.go | 35 + .../vendor/gorm.io/gorm/clause/set.go | 11 + .../vendor/gorm.io/gorm/finisher_api.go | 10 +- .../vendor/gorm.io/gorm/generics.go | 325 +++++++- policy-enforcer/vendor/gorm.io/gorm/gorm.go | 9 + .../vendor/gorm.io/gorm/logger/slog.go | 116 +++ .../vendor/gorm.io/gorm/migrator/migrator.go | 18 +- .../vendor/gorm.io/gorm/schema/field.go | 16 +- .../gorm.io/gorm/schema/relationship.go | 15 +- .../vendor/gorm.io/gorm/schema/schema.go | 229 +++--- .../vendor/gorm.io/gorm/schema/serializer.go | 28 +- .../vendor/gorm.io/gorm/schema/utils.go | 37 +- .../vendor/gorm.io/gorm/statement.go | 29 +- .../vendor/gorm.io/gorm/utils/utils.go | 22 +- policy-enforcer/vendor/modules.txt | 2 +- public-api/go.mod | 2 +- public-api/go.sum | 4 +- public-api/vendor/gorm.io/gorm/README.md | 2 +- public-api/vendor/gorm.io/gorm/association.go | 4 +- .../vendor/gorm.io/gorm/chainable_api.go | 4 +- .../vendor/gorm.io/gorm/finisher_api.go | 4 +- public-api/vendor/gorm.io/gorm/generics.go | 18 +- public-api/vendor/gorm.io/gorm/logger/slog.go | 28 +- .../vendor/gorm.io/gorm/migrator/migrator.go | 4 + .../vendor/gorm.io/gorm/schema/schema.go | 10 + .../vendor/gorm.io/gorm/schema/serializer.go | 28 +- .../vendor/gorm.io/gorm/schema/utils.go | 26 +- public-api/vendor/gorm.io/gorm/utils/utils.go | 22 +- public-api/vendor/modules.txt | 2 +- runtime-monitor/go.mod | 6 +- runtime-monitor/go.sum | 12 +- .../github.com/jackc/puddle/v2/CHANGELOG.md | 5 + .../github.com/jackc/puddle/v2/README.md | 2 +- .../github.com/jackc/puddle/v2/nanotime.go | 16 + .../jackc/puddle/v2/nanotime_time.go | 13 - .../jackc/puddle/v2/nanotime_unsafe.go | 12 - .../vendor/github.com/jackc/puddle/v2/pool.go | 20 +- .../gorm.io/driver/postgres/migrator.go | 71 +- .../gorm.io/driver/postgres/postgres.go | 22 +- runtime-monitor/vendor/gorm.io/gorm/README.md | 2 +- .../vendor/gorm.io/gorm/association.go | 86 ++- .../vendor/gorm.io/gorm/callbacks.go | 8 +- .../vendor/gorm.io/gorm/callbacks/create.go | 31 +- .../vendor/gorm.io/gorm/chainable_api.go | 4 +- .../vendor/gorm.io/gorm/clause/association.go | 35 + .../vendor/gorm.io/gorm/clause/set.go | 11 + .../vendor/gorm.io/gorm/finisher_api.go | 10 +- .../vendor/gorm.io/gorm/generics.go | 325 +++++++- runtime-monitor/vendor/gorm.io/gorm/gorm.go | 9 + .../vendor/gorm.io/gorm/logger/slog.go | 116 +++ .../vendor/gorm.io/gorm/migrator/migrator.go | 18 +- .../vendor/gorm.io/gorm/schema/field.go | 16 +- .../gorm.io/gorm/schema/relationship.go | 15 +- .../vendor/gorm.io/gorm/schema/schema.go | 229 +++--- .../vendor/gorm.io/gorm/schema/serializer.go | 28 +- .../vendor/gorm.io/gorm/schema/utils.go | 37 +- .../vendor/gorm.io/gorm/statement.go | 29 +- .../vendor/gorm.io/gorm/utils/utils.go | 22 +- runtime-monitor/vendor/modules.txt | 8 +- 325 files changed, 10422 insertions(+), 2762 deletions(-) create mode 100644 auth-center/vendor/gorm.io/gorm/clause/association.go create mode 100644 auth-center/vendor/gorm.io/gorm/logger/slog.go delete mode 100644 cluster-manager/vendor/github.com/jackc/pgservicefile/.travis.yml delete mode 100644 cluster-manager/vendor/github.com/jackc/pgx/v5/internal/anynil/anynil.go rename cluster-manager/vendor/github.com/jackc/pgx/v5/pgconn/{internal => }/ctxwatch/context_watcher.go (71%) create mode 100644 cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/ltree.go create mode 100644 cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/batch_results.go create mode 100644 cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/conn.go create mode 100644 cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/doc.go create mode 100644 cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/pool.go create mode 100644 cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/rows.go create mode 100644 cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/stat.go create mode 100644 cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/tracer.go create mode 100644 cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/tx.go create mode 100644 cluster-manager/vendor/github.com/jackc/puddle/v2/CHANGELOG.md create mode 100644 cluster-manager/vendor/github.com/jackc/puddle/v2/LICENSE create mode 100644 cluster-manager/vendor/github.com/jackc/puddle/v2/README.md create mode 100644 cluster-manager/vendor/github.com/jackc/puddle/v2/context.go create mode 100644 cluster-manager/vendor/github.com/jackc/puddle/v2/doc.go create mode 100644 cluster-manager/vendor/github.com/jackc/puddle/v2/internal/genstack/gen_stack.go create mode 100644 cluster-manager/vendor/github.com/jackc/puddle/v2/internal/genstack/stack.go create mode 100644 cluster-manager/vendor/github.com/jackc/puddle/v2/log.go create mode 100644 cluster-manager/vendor/github.com/jackc/puddle/v2/nanotime.go create mode 100644 cluster-manager/vendor/github.com/jackc/puddle/v2/pool.go create mode 100644 cluster-manager/vendor/github.com/jackc/puddle/v2/resource_list.go create mode 100644 cluster-manager/vendor/gorm.io/gorm/clause/association.go create mode 100644 cluster-manager/vendor/gorm.io/gorm/logger/slog.go create mode 100644 cs-manager/vendor/gorm.io/gorm/clause/association.go create mode 100644 cs-manager/vendor/gorm.io/gorm/logger/slog.go create mode 100644 lib/vendor/gorm.io/gorm/clause/association.go create mode 100644 lib/vendor/gorm.io/gorm/logger/slog.go create mode 100644 policy-enforcer/vendor/gorm.io/gorm/clause/association.go create mode 100644 policy-enforcer/vendor/gorm.io/gorm/logger/slog.go create mode 100644 runtime-monitor/vendor/github.com/jackc/puddle/v2/nanotime.go delete mode 100644 runtime-monitor/vendor/github.com/jackc/puddle/v2/nanotime_time.go delete mode 100644 runtime-monitor/vendor/github.com/jackc/puddle/v2/nanotime_unsafe.go create mode 100644 runtime-monitor/vendor/gorm.io/gorm/clause/association.go create mode 100644 runtime-monitor/vendor/gorm.io/gorm/logger/slog.go diff --git a/auth-center/go.mod b/auth-center/go.mod index 192dd4b7..76bea461 100644 --- a/auth-center/go.mod +++ b/auth-center/go.mod @@ -22,7 +22,7 @@ require ( google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.5.1 google.golang.org/protobuf v1.36.6 gorm.io/driver/postgres v1.6.0 - gorm.io/gorm v1.30.0 + gorm.io/gorm v1.31.1 ) require ( diff --git a/auth-center/go.sum b/auth-center/go.sum index 11024fcd..6cac529e 100644 --- a/auth-center/go.sum +++ b/auth-center/go.sum @@ -673,8 +673,8 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4= gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo= -gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs= -gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= +gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= +gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= honnef.co/go/tools v0.6.1 h1:R094WgE8K4JirYjBaOpz/AvTyUu/3wbmAoskKN/pxTI= honnef.co/go/tools v0.6.1/go.mod h1:3puzxxljPCe8RGJX7BIy1plGbxEOZni5mR2aXe3/uk4= mvdan.cc/gofumpt v0.7.0 h1:bg91ttqXmi9y2xawvkuMXyvAA/1ZGJqYAEGjXuP0JXU= diff --git a/auth-center/vendor/gorm.io/gorm/README.md b/auth-center/vendor/gorm.io/gorm/README.md index 745dad60..24eb84c9 100644 --- a/auth-center/vendor/gorm.io/gorm/README.md +++ b/auth-center/vendor/gorm.io/gorm/README.md @@ -3,7 +3,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. [![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm) -[![test status](https://github.com/go-gorm/gorm/workflows/tests/badge.svg?branch=master "test status")](https://github.com/go-gorm/gorm/actions) +[![test status](https://github.com/go-gorm/gorm/actions/workflows/tests.yml/badge.svg)](https://github.com/go-gorm/gorm/actions) [![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT) [![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gorm?tab=doc) diff --git a/auth-center/vendor/gorm.io/gorm/association.go b/auth-center/vendor/gorm.io/gorm/association.go index e3f51d17..3a4e0e25 100644 --- a/auth-center/vendor/gorm.io/gorm/association.go +++ b/auth-center/vendor/gorm.io/gorm/association.go @@ -19,10 +19,10 @@ type Association struct { } func (db *DB) Association(column string) *Association { - association := &Association{DB: db} + association := &Association{DB: db, Unscope: db.Statement.Unscoped} table := db.Statement.Table - if err := db.Statement.Parse(db.Statement.Model); err == nil { + if association.Error = db.Statement.Parse(db.Statement.Model); association.Error == nil { db.Statement.Table = table association.Relationship = db.Statement.Schema.Relationships.Relations[column] @@ -34,8 +34,6 @@ func (db *DB) Association(column string) *Association { for db.Statement.ReflectValue.Kind() == reflect.Ptr { db.Statement.ReflectValue = db.Statement.ReflectValue.Elem() } - } else { - association.Error = err } return association @@ -58,6 +56,8 @@ func (association *Association) Find(out interface{}, conds ...interface{}) erro } func (association *Association) Append(values ...interface{}) error { + values = expandValues(values) + if association.Error == nil { switch association.Relationship.Type { case schema.HasOne, schema.BelongsTo: @@ -73,6 +73,8 @@ func (association *Association) Append(values ...interface{}) error { } func (association *Association) Replace(values ...interface{}) error { + values = expandValues(values) + if association.Error == nil { reflectValue := association.DB.Statement.ReflectValue rel := association.Relationship @@ -97,7 +99,7 @@ func (association *Association) Replace(values ...interface{}) error { return association.Error } - // set old associations's foreign key to null + // set old association's foreign key to null switch rel.Type { case schema.BelongsTo: if len(values) == 0 { @@ -195,6 +197,8 @@ func (association *Association) Replace(values ...interface{}) error { } func (association *Association) Delete(values ...interface{}) error { + values = expandValues(values) + if association.Error == nil { var ( reflectValue = association.DB.Statement.ReflectValue @@ -300,7 +304,7 @@ func (association *Association) Delete(values ...interface{}) error { } if association.Error == nil { - // clean up deleted values's foreign key + // clean up deleted values' foreign key relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields) cleanUpDeletedRelations := func(data reflect.Value) { @@ -431,10 +435,49 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } } + processMap := func(mapv reflect.Value) { + child := reflect.New(association.Relationship.FieldSchema.ModelType) + + switch association.Relationship.Type { + case schema.HasMany: + for _, ref := range association.Relationship.References { + key := reflect.ValueOf(ref.ForeignKey.DBName) + if ref.OwnPrimaryKey { + v := ref.PrimaryKey.ReflectValueOf(association.DB.Statement.Context, source) + mapv.SetMapIndex(key, v) + } else if ref.PrimaryValue != "" { + mapv.SetMapIndex(key, reflect.ValueOf(ref.PrimaryValue)) + } + } + association.Error = association.DB.Session(&Session{ + NewDB: true, + }).Model(child.Interface()).Create(mapv.Interface()).Error + case schema.Many2Many: + association.Error = association.DB.Session(&Session{ + NewDB: true, + }).Model(child.Interface()).Create(mapv.Interface()).Error + + for _, key := range mapv.MapKeys() { + k := strings.ToLower(key.String()) + if f, ok := association.Relationship.FieldSchema.FieldsByDBName[k]; ok { + _ = f.Set(association.DB.Statement.Context, child, mapv.MapIndex(key).Interface()) + } + } + appendToFieldValues(child) + } + } + switch rv.Kind() { + case reflect.Map: + processMap(rv) case reflect.Slice, reflect.Array: for i := 0; i < rv.Len(); i++ { - appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr()) + elem := reflect.Indirect(rv.Index(i)) + if elem.Kind() == reflect.Map { + processMap(elem) + continue + } + appendToFieldValues(elem.Addr()) } case reflect.Struct: if !rv.CanAddr() { @@ -591,3 +634,32 @@ func (association *Association) buildCondition() *DB { return tx } + +func expandValues(values ...any) (results []any) { + appendToResult := func(rv reflect.Value) { + // unwrap interface + if rv.IsValid() && rv.Kind() == reflect.Interface { + rv = rv.Elem() + } + if rv.IsValid() && rv.Kind() == reflect.Struct { + p := reflect.New(rv.Type()) + p.Elem().Set(rv) + results = append(results, p.Interface()) + } else if rv.IsValid() { + results = append(results, rv.Interface()) + } + } + + // Process each argument; if an argument is a slice/array, expand its elements + for _, value := range values { + rv := reflect.ValueOf(value) + if rv.Kind() == reflect.Slice || rv.Kind() == reflect.Array { + for i := 0; i < rv.Len(); i++ { + appendToResult(rv.Index(i)) + } + } else { + appendToResult(rv) + } + } + return +} diff --git a/auth-center/vendor/gorm.io/gorm/callbacks.go b/auth-center/vendor/gorm.io/gorm/callbacks.go index 50b5b0e9..bd97f040 100644 --- a/auth-center/vendor/gorm.io/gorm/callbacks.go +++ b/auth-center/vendor/gorm.io/gorm/callbacks.go @@ -89,10 +89,16 @@ func (p *processor) Execute(db *DB) *DB { resetBuildClauses = true } - if optimizer, ok := db.Statement.Dest.(StatementModifier); ok { + if optimizer, ok := stmt.Dest.(StatementModifier); ok { optimizer.ModifyStatement(stmt) } + if db.DefaultContextTimeout > 0 { + if _, ok := stmt.Context.Deadline(); !ok { + stmt.Context, _ = context.WithTimeout(stmt.Context, db.DefaultContextTimeout) + } + } + // assign model values if stmt.Model == nil { stmt.Model = stmt.Dest diff --git a/auth-center/vendor/gorm.io/gorm/callbacks/create.go b/auth-center/vendor/gorm.io/gorm/callbacks/create.go index d8701f51..e5929adb 100644 --- a/auth-center/vendor/gorm.io/gorm/callbacks/create.go +++ b/auth-center/vendor/gorm.io/gorm/callbacks/create.go @@ -53,9 +53,13 @@ func Create(config *Config) func(db *gorm.DB) { if _, ok := db.Statement.Clauses["RETURNING"]; !ok { fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue)) for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue { - fromColumns = append(fromColumns, clause.Column{Name: field.DBName}) + if field.Readable { + fromColumns = append(fromColumns, clause.Column{Name: field.DBName}) + } + } + if len(fromColumns) > 0 { + db.Statement.AddClause(clause.Returning{Columns: fromColumns}) } - db.Statement.AddClause(clause.Returning{Columns: fromColumns}) } } } @@ -76,8 +80,11 @@ func Create(config *Config) func(db *gorm.DB) { ok, mode := hasReturning(db, supportReturning) if ok { if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok { - if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing { + onConflict, _ := c.Expression.(clause.OnConflict) + if onConflict.DoNothing { mode |= gorm.ScanOnConflictDoNothing + } else if len(onConflict.DoUpdates) > 0 || onConflict.UpdateAll { + mode |= gorm.ScanUpdate } } @@ -122,6 +129,16 @@ func Create(config *Config) func(db *gorm.DB) { pkFieldName = "@id" ) + if db.Statement.Schema != nil { + if db.Statement.Schema.PrioritizedPrimaryField == nil || + !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue || + !db.Statement.Schema.PrioritizedPrimaryField.Readable { + return + } + pkField = db.Statement.Schema.PrioritizedPrimaryField + pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName + } + insertID, err := result.LastInsertId() insertOk := err == nil && insertID > 0 @@ -132,14 +149,6 @@ func Create(config *Config) func(db *gorm.DB) { return } - if db.Statement.Schema != nil { - if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { - return - } - pkField = db.Statement.Schema.PrioritizedPrimaryField - pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName - } - // append @id column with value for auto-increment primary key // the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1 switch values := db.Statement.Dest.(type) { diff --git a/auth-center/vendor/gorm.io/gorm/chainable_api.go b/auth-center/vendor/gorm.io/gorm/chainable_api.go index 8a6aea34..8f6113cc 100644 --- a/auth-center/vendor/gorm.io/gorm/chainable_api.go +++ b/auth-center/vendor/gorm.io/gorm/chainable_api.go @@ -178,7 +178,7 @@ func (db *DB) Omit(columns ...string) (tx *DB) { tx = db.getInstance() if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { - tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar) + tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsInvalidDBNameChar) } else { tx.Statement.Omits = columns } @@ -283,7 +283,7 @@ func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) func (db *DB) Group(name string) (tx *DB) { tx = db.getInstance() - fields := strings.FieldsFunc(name, utils.IsValidDBNameChar) + fields := strings.FieldsFunc(name, utils.IsInvalidDBNameChar) tx.Statement.AddClause(clause.GroupBy{ Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}}, }) diff --git a/auth-center/vendor/gorm.io/gorm/clause/association.go b/auth-center/vendor/gorm.io/gorm/clause/association.go new file mode 100644 index 00000000..a9bf7eb0 --- /dev/null +++ b/auth-center/vendor/gorm.io/gorm/clause/association.go @@ -0,0 +1,35 @@ +package clause + +// AssociationOpType represents association operation types +type AssociationOpType int + +const ( + OpUnlink AssociationOpType = iota // Unlink association + OpDelete // Delete association records + OpUpdate // Update association records + OpCreate // Create association records with assignments +) + +// Association represents an association operation +type Association struct { + Association string // Association name + Type AssociationOpType // Operation type + Conditions []Expression // Filter conditions + Set []Assignment // Assignment operations (for Update and Create) + Values []interface{} // Values for Create operation +} + +// AssociationAssigner is an interface for association operation providers +type AssociationAssigner interface { + AssociationAssignments() []Association +} + +// Assignments implements the Assigner interface so that AssociationOperation can be used as a Set method parameter +func (ao Association) Assignments() []Assignment { + return []Assignment{} +} + +// AssociationAssignments implements the AssociationAssigner interface +func (ao Association) AssociationAssignments() []Association { + return []Association{ao} +} diff --git a/auth-center/vendor/gorm.io/gorm/clause/set.go b/auth-center/vendor/gorm.io/gorm/clause/set.go index 75eb6bdd..cb5f36a0 100644 --- a/auth-center/vendor/gorm.io/gorm/clause/set.go +++ b/auth-center/vendor/gorm.io/gorm/clause/set.go @@ -9,6 +9,11 @@ type Assignment struct { Value interface{} } +// Assigner assignments provider interface +type Assigner interface { + Assignments() []Assignment +} + func (set Set) Name() string { return "SET" } @@ -37,6 +42,9 @@ func (set Set) MergeClause(clause *Clause) { clause.Expression = Set(copiedAssignments) } +// Assignments implements Assigner for Set. +func (set Set) Assignments() []Assignment { return []Assignment(set) } + func Assignments(values map[string]interface{}) Set { keys := make([]string, 0, len(values)) for key := range values { @@ -58,3 +66,6 @@ func AssignmentColumns(values []string) Set { } return assignments } + +// Assignments implements Assigner for a single Assignment. +func (a Assignment) Assignments() []Assignment { return []Assignment{a} } diff --git a/auth-center/vendor/gorm.io/gorm/finisher_api.go b/auth-center/vendor/gorm.io/gorm/finisher_api.go index 57809d17..e9e35f1b 100644 --- a/auth-center/vendor/gorm.io/gorm/finisher_api.go +++ b/auth-center/vendor/gorm.io/gorm/finisher_api.go @@ -465,7 +465,7 @@ func (db *DB) Count(count *int64) (tx *DB) { if len(tx.Statement.Selects) == 1 { dbName := tx.Statement.Selects[0] - fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar) + fields := strings.FieldsFunc(dbName, utils.IsInvalidDBNameChar) if len(fields) == 1 || (len(fields) == 3 && (strings.ToUpper(fields[1]) == "AS" || fields[1] == ".")) { if tx.Statement.Parse(tx.Statement.Model) == nil { if f := tx.Statement.Schema.LookUpField(dbName); f != nil { @@ -564,7 +564,7 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { } if len(tx.Statement.Selects) != 1 { - fields := strings.FieldsFunc(column, utils.IsValidDBNameChar) + fields := strings.FieldsFunc(column, utils.IsInvalidDBNameChar) tx.Statement.AddClauseIfNotExists(clause.Select{ Distinct: tx.Statement.Distinct, Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}}, @@ -675,9 +675,9 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { } ctx := tx.Statement.Context - if _, ok := ctx.Deadline(); !ok { - if db.Config.DefaultTransactionTimeout > 0 { - ctx, _ = context.WithTimeout(ctx, db.Config.DefaultTransactionTimeout) + if db.DefaultTransactionTimeout > 0 { + if _, ok := ctx.Deadline(); !ok { + ctx, _ = context.WithTimeout(ctx, db.DefaultTransactionTimeout) } } diff --git a/auth-center/vendor/gorm.io/gorm/generics.go b/auth-center/vendor/gorm.io/gorm/generics.go index ad2d063f..166d1520 100644 --- a/auth-center/vendor/gorm.io/gorm/generics.go +++ b/auth-center/vendor/gorm.io/gorm/generics.go @@ -3,12 +3,15 @@ package gorm import ( "context" "database/sql" + "errors" "fmt" + "reflect" "sort" "strings" "gorm.io/gorm/clause" "gorm.io/gorm/logger" + "gorm.io/gorm/schema" ) type result struct { @@ -35,10 +38,34 @@ type Interface[T any] interface { } type CreateInterface[T any] interface { - ChainInterface[T] + ExecInterface[T] + // chain methods available at start; Select/Omit keep CreateInterface to allow Create chaining + Scopes(scopes ...func(db *Statement)) ChainInterface[T] + Where(query interface{}, args ...interface{}) ChainInterface[T] + Not(query interface{}, args ...interface{}) ChainInterface[T] + Or(query interface{}, args ...interface{}) ChainInterface[T] + Limit(offset int) ChainInterface[T] + Offset(offset int) ChainInterface[T] + Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] + Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] + Select(query string, args ...interface{}) CreateInterface[T] + Omit(columns ...string) CreateInterface[T] + MapColumns(m map[string]string) ChainInterface[T] + Distinct(args ...interface{}) ChainInterface[T] + Group(name string) ChainInterface[T] + Having(query interface{}, args ...interface{}) ChainInterface[T] + Order(value interface{}) ChainInterface[T] + Build(builder clause.Builder) + + Delete(ctx context.Context) (rowsAffected int, err error) + Update(ctx context.Context, name string, value any) (rowsAffected int, err error) + Updates(ctx context.Context, t T) (rowsAffected int, err error) + Count(ctx context.Context, column string) (result int64, err error) + Table(name string, args ...interface{}) CreateInterface[T] Create(ctx context.Context, r *T) error CreateInBatches(ctx context.Context, r *[]T, batchSize int) error + Set(assignments ...clause.Assigner) SetCreateOrUpdateInterface[T] } type ChainInterface[T any] interface { @@ -58,15 +85,28 @@ type ChainInterface[T any] interface { Group(name string) ChainInterface[T] Having(query interface{}, args ...interface{}) ChainInterface[T] Order(value interface{}) ChainInterface[T] + Set(assignments ...clause.Assigner) SetUpdateOnlyInterface[T] Build(builder clause.Builder) + Table(name string, args ...interface{}) ChainInterface[T] Delete(ctx context.Context) (rowsAffected int, err error) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) Updates(ctx context.Context, t T) (rowsAffected int, err error) Count(ctx context.Context, column string) (result int64, err error) } +// SetUpdateOnlyInterface is returned by Set after chaining; only Update is allowed +type SetUpdateOnlyInterface[T any] interface { + Update(ctx context.Context) (rowsAffected int, err error) +} + +// SetCreateOrUpdateInterface is returned by Set at start; Create or Update are allowed +type SetCreateOrUpdateInterface[T any] interface { + Create(ctx context.Context) error + Update(ctx context.Context) (rowsAffected int, err error) +} + type ExecInterface[T any] interface { Scan(ctx context.Context, r interface{}) error First(context.Context) (T, error) @@ -142,13 +182,15 @@ func (c *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] { return execG[T]{g: &g[T]{ db: c.db, ops: append(c.ops, func(db *DB) *DB { - return db.Raw(sql, values...) + var r T + return db.Model(r).Raw(sql, values...) }), }} } func (c *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error { - return c.apply(ctx).Exec(sql, values...).Error + var r T + return c.apply(ctx).Model(r).Exec(sql, values...).Error } type createG[T any] struct { @@ -161,6 +203,22 @@ func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] { })} } +func (c createG[T]) Select(query string, args ...interface{}) CreateInterface[T] { + return createG[T]{c.with(func(db *DB) *DB { + return db.Select(query, args...) + })} +} + +func (c createG[T]) Omit(columns ...string) CreateInterface[T] { + return createG[T]{c.with(func(db *DB) *DB { + return db.Omit(columns...) + })} +} + +func (c createG[T]) Set(assignments ...clause.Assigner) SetCreateOrUpdateInterface[T] { + return c.processSet(assignments...) +} + func (c createG[T]) Create(ctx context.Context, r *T) error { return c.g.apply(ctx).Create(r).Error } @@ -187,6 +245,12 @@ func (c chainG[T]) with(v op) chainG[T] { } } +func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Table(name, args...) + }) +} + func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] { return c.with(func(db *DB) *DB { for _, fc := range scopes { @@ -196,12 +260,6 @@ func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] { }) } -func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] { - return c.with(func(db *DB) *DB { - return db.Table(name, args...) - }) -} - func (c chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] { return c.with(func(db *DB) *DB { return db.Where(query, args...) @@ -388,6 +446,10 @@ func (c chainG[T]) MapColumns(m map[string]string) ChainInterface[T] { }) } +func (c chainG[T]) Set(assignments ...clause.Assigner) SetUpdateOnlyInterface[T] { + return c.processSet(assignments...) +} + func (c chainG[T]) Distinct(args ...interface{}) ChainInterface[T] { return c.with(func(db *DB) *DB { return db.Distinct(args...) @@ -425,12 +487,12 @@ func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) err relation, ok := db.Statement.Schema.Relationships.Relations[association] if !ok { if preloadFields := strings.Split(association, "."); len(preloadFields) > 1 { - relationships := db.Statement.Schema.Relationships + relationships := &db.Statement.Schema.Relationships for _, field := range preloadFields { var ok bool relation, ok = relationships.Relations[field] if ok { - relationships = relation.FieldSchema.Relationships + relationships = &relation.FieldSchema.Relationships } else { db.AddError(fmt.Errorf("relation %s not found", association)) return nil @@ -567,7 +629,7 @@ func (g execG[T]) First(ctx context.Context) (T, error) { func (g execG[T]) Scan(ctx context.Context, result interface{}) error { var r T - err := g.g.apply(ctx).Model(r).Find(&result).Error + err := g.g.apply(ctx).Model(r).Find(result).Error return err } @@ -603,3 +665,242 @@ func (g execG[T]) Row(ctx context.Context) *sql.Row { func (g execG[T]) Rows(ctx context.Context) (*sql.Rows, error) { return g.g.apply(ctx).Rows() } + +func (c chainG[T]) processSet(items ...clause.Assigner) setCreateOrUpdateG[T] { + var ( + assigns []clause.Assignment + assocOps []clause.Association + ) + + for _, item := range items { + // Check if it's an AssociationAssigner + if assocAssigner, ok := item.(clause.AssociationAssigner); ok { + assocOps = append(assocOps, assocAssigner.AssociationAssignments()...) + } else { + assigns = append(assigns, item.Assignments()...) + } + } + + return setCreateOrUpdateG[T]{ + c: c, + assigns: assigns, + assocOps: assocOps, + } +} + +// setCreateOrUpdateG[T] is a struct that holds operations to be executed in a batch. +// It supports regular assignments and association operations. +type setCreateOrUpdateG[T any] struct { + c chainG[T] + assigns []clause.Assignment + assocOps []clause.Association +} + +func (s setCreateOrUpdateG[T]) Update(ctx context.Context) (rowsAffected int, err error) { + // Execute association operations + for _, assocOp := range s.assocOps { + if err := s.executeAssociationOperation(ctx, assocOp); err != nil { + return 0, err + } + } + + // Execute assignment operations + if len(s.assigns) > 0 { + var r T + res := s.c.g.apply(ctx).Model(r).Clauses(clause.Set(s.assigns)).Updates(map[string]interface{}{}) + return int(res.RowsAffected), res.Error + } + + return 0, nil +} + +func (s setCreateOrUpdateG[T]) Create(ctx context.Context) error { + // Execute association operations + for _, assocOp := range s.assocOps { + if err := s.executeAssociationOperation(ctx, assocOp); err != nil { + return err + } + } + + // Execute assignment operations + if len(s.assigns) > 0 { + data := make(map[string]interface{}, len(s.assigns)) + for _, a := range s.assigns { + data[a.Column.Name] = a.Value + } + var r T + return s.c.g.apply(ctx).Model(r).Create(data).Error + } + + return nil +} + +// executeAssociationOperation executes an association operation +func (s setCreateOrUpdateG[T]) executeAssociationOperation(ctx context.Context, op clause.Association) error { + var r T + base := s.c.g.apply(ctx).Model(r) + + switch op.Type { + case clause.OpCreate: + return s.handleAssociationCreate(ctx, base, op) + case clause.OpUnlink, clause.OpDelete, clause.OpUpdate: + return s.handleAssociation(ctx, base, op) + default: + return fmt.Errorf("unknown association operation type: %v", op.Type) + } +} + +func (s setCreateOrUpdateG[T]) handleAssociationCreate(ctx context.Context, base *DB, op clause.Association) error { + if len(op.Set) > 0 { + return s.handleAssociationForOwners(base, ctx, func(owner T, assoc *Association) error { + data := make(map[string]interface{}, len(op.Set)) + for _, a := range op.Set { + data[a.Column.Name] = a.Value + } + return assoc.Append(data) + }, op.Association) + } + + return s.handleAssociationForOwners(base, ctx, func(owner T, assoc *Association) error { + return assoc.Append(op.Values...) + }, op.Association) +} + +// handleAssociationForOwners is a helper function that handles associations for all owners +func (s setCreateOrUpdateG[T]) handleAssociationForOwners(base *DB, ctx context.Context, handler func(owner T, association *Association) error, associationName string) error { + var owners []T + if err := base.Find(&owners).Error; err != nil { + return err + } + + for _, owner := range owners { + assoc := base.Session(&Session{NewDB: true, Context: ctx}).Model(&owner).Association(associationName) + if assoc.Error != nil { + return assoc.Error + } + + if err := handler(owner, assoc); err != nil { + return err + } + } + return nil +} + +func (s setCreateOrUpdateG[T]) handleAssociation(ctx context.Context, base *DB, op clause.Association) error { + assoc := base.Association(op.Association) + if assoc.Error != nil { + return assoc.Error + } + + var ( + rel = assoc.Relationship + assocModel = reflect.New(rel.FieldSchema.ModelType).Interface() + fkNil = map[string]any{} + setMap = make(map[string]any, len(op.Set)) + ownerPKNames []string + ownerFKNames []string + primaryColumns []any + foreignColumns []any + ) + + for _, a := range op.Set { + setMap[a.Column.Name] = a.Value + } + + for _, ref := range rel.References { + fkNil[ref.ForeignKey.DBName] = nil + + if ref.OwnPrimaryKey && ref.PrimaryKey != nil { + ownerPKNames = append(ownerPKNames, ref.PrimaryKey.DBName) + primaryColumns = append(primaryColumns, clause.Column{Name: ref.PrimaryKey.DBName}) + foreignColumns = append(foreignColumns, clause.Column{Name: ref.ForeignKey.DBName}) + } else if !ref.OwnPrimaryKey && ref.PrimaryKey != nil { + ownerFKNames = append(ownerFKNames, ref.ForeignKey.DBName) + primaryColumns = append(primaryColumns, clause.Column{Name: ref.PrimaryKey.DBName}) + } + } + + assocDB := s.c.g.db.Session(&Session{NewDB: true, Context: ctx}).Model(assocModel).Where(op.Conditions) + + switch rel.Type { + case schema.HasOne, schema.HasMany: + assocDB = assocDB.Where("? IN (?)", foreignColumns, base.Select(ownerPKNames)) + switch op.Type { + case clause.OpUnlink: + return assocDB.Updates(fkNil).Error + case clause.OpDelete: + return assocDB.Delete(assocModel).Error + case clause.OpUpdate: + return assocDB.Updates(setMap).Error + } + case schema.BelongsTo: + switch op.Type { + case clause.OpDelete: + return base.Transaction(func(tx *DB) error { + assocDB.Statement.ConnPool = tx.Statement.ConnPool + base.Statement.ConnPool = tx.Statement.ConnPool + + if err := assocDB.Where("? IN (?)", primaryColumns, base.Select(ownerFKNames)).Delete(assocModel).Error; err != nil { + return err + } + return base.Updates(fkNil).Error + }) + case clause.OpUnlink: + return base.Updates(fkNil).Error + case clause.OpUpdate: + return assocDB.Where("? IN (?)", primaryColumns, base.Select(ownerFKNames)).Updates(setMap).Error + } + case schema.Many2Many: + joinModel := reflect.New(rel.JoinTable.ModelType).Interface() + joinDB := base.Session(&Session{NewDB: true, Context: ctx}).Model(joinModel) + + // EXISTS owners: owners.pk = join.owner_fk for all owner refs + ownersExists := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.Schema.Table).Select("1") + for _, ref := range rel.References { + if ref.OwnPrimaryKey && ref.PrimaryKey != nil { + ownersExists = ownersExists.Where(clause.Eq{ + Column: clause.Column{Table: rel.Schema.Table, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + }) + } + } + + // EXISTS related: related.pk = join.rel_fk for all related refs, plus optional conditions + relatedExists := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.FieldSchema.Table).Select("1") + for _, ref := range rel.References { + if !ref.OwnPrimaryKey && ref.PrimaryKey != nil { + relatedExists = relatedExists.Where(clause.Eq{ + Column: clause.Column{Table: rel.FieldSchema.Table, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + }) + } + } + relatedExists = relatedExists.Where(op.Conditions) + + switch op.Type { + case clause.OpUnlink, clause.OpDelete: + joinDB = joinDB.Where("EXISTS (?)", ownersExists) + if len(op.Conditions) > 0 { + joinDB = joinDB.Where("EXISTS (?)", relatedExists) + } + return joinDB.Delete(nil).Error + case clause.OpUpdate: + // Update related table rows that have join rows matching owners + relatedDB := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.FieldSchema.Table).Where(op.Conditions) + + // correlated join subquery: join.rel_fk = related.pk AND EXISTS owners + joinSub := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.JoinTable.Table).Select("1") + for _, ref := range rel.References { + if !ref.OwnPrimaryKey && ref.PrimaryKey != nil { + joinSub = joinSub.Where(clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: rel.FieldSchema.Table, Name: ref.PrimaryKey.DBName}, + }) + } + } + joinSub = joinSub.Where("EXISTS (?)", ownersExists) + return relatedDB.Where("EXISTS (?)", joinSub).Updates(setMap).Error + } + } + return errors.New("unsupported relationship") +} diff --git a/auth-center/vendor/gorm.io/gorm/gorm.go b/auth-center/vendor/gorm.io/gorm/gorm.go index 67889262..a209bb09 100644 --- a/auth-center/vendor/gorm.io/gorm/gorm.go +++ b/auth-center/vendor/gorm.io/gorm/gorm.go @@ -23,6 +23,7 @@ type Config struct { // You can disable it by setting `SkipDefaultTransaction` to true SkipDefaultTransaction bool DefaultTransactionTimeout time.Duration + DefaultContextTimeout time.Duration // NamingStrategy tables, columns naming strategy NamingStrategy schema.Namer @@ -137,6 +138,14 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { return isConfig && !isConfig2 }) + if len(opts) > 0 { + if c, ok := opts[0].(*Config); ok { + config = c + } else { + opts = append([]Option{config}, opts...) + } + } + var skipAfterInitialize bool for _, opt := range opts { if opt != nil { diff --git a/auth-center/vendor/gorm.io/gorm/logger/slog.go b/auth-center/vendor/gorm.io/gorm/logger/slog.go new file mode 100644 index 00000000..27c74683 --- /dev/null +++ b/auth-center/vendor/gorm.io/gorm/logger/slog.go @@ -0,0 +1,116 @@ +//go:build go1.21 + +package logger + +import ( + "context" + "errors" + "fmt" + "log/slog" + "time" + + "gorm.io/gorm/utils" +) + +type slogLogger struct { + Logger *slog.Logger + LogLevel LogLevel + SlowThreshold time.Duration + Parameterized bool + Colorful bool // Ignored in slog + IgnoreRecordNotFoundError bool +} + +func NewSlogLogger(logger *slog.Logger, config Config) Interface { + return &slogLogger{ + Logger: logger, + LogLevel: config.LogLevel, + SlowThreshold: config.SlowThreshold, + Parameterized: config.ParameterizedQueries, + IgnoreRecordNotFoundError: config.IgnoreRecordNotFoundError, + } +} + +func (l *slogLogger) LogMode(level LogLevel) Interface { + newLogger := *l + newLogger.LogLevel = level + return &newLogger +} + +func (l *slogLogger) Info(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Info { + l.log(ctx, slog.LevelInfo, msg, slog.Any("data", data)) + } +} + +func (l *slogLogger) Warn(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Warn { + l.log(ctx, slog.LevelWarn, msg, slog.Any("data", data)) + } +} + +func (l *slogLogger) Error(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Error { + l.log(ctx, slog.LevelError, msg, slog.Any("data", data)) + } +} + +func (l *slogLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + if l.LogLevel <= Silent { + return + } + + elapsed := time.Since(begin) + sql, rows := fc() + fields := []slog.Attr{ + slog.String("duration", fmt.Sprintf("%.3fms", float64(elapsed.Nanoseconds())/1e6)), + slog.String("sql", sql), + } + + if rows != -1 { + fields = append(fields, slog.Int64("rows", rows)) + } + + switch { + case err != nil && (!l.IgnoreRecordNotFoundError || !errors.Is(err, ErrRecordNotFound)): + fields = append(fields, slog.String("error", err.Error())) + l.log(ctx, slog.LevelError, "SQL executed", slog.Attr{ + Key: "trace", + Value: slog.GroupValue(fields...), + }) + + case l.SlowThreshold != 0 && elapsed > l.SlowThreshold: + l.log(ctx, slog.LevelWarn, "SQL executed", slog.Attr{ + Key: "trace", + Value: slog.GroupValue(fields...), + }) + + case l.LogLevel >= Info: + l.log(ctx, slog.LevelInfo, "SQL executed", slog.Attr{ + Key: "trace", + Value: slog.GroupValue(fields...), + }) + } +} + +func (l *slogLogger) log(ctx context.Context, level slog.Level, msg string, args ...any) { + if ctx == nil { + ctx = context.Background() + } + + if !l.Logger.Enabled(ctx, level) { + return + } + + r := slog.NewRecord(time.Now(), level, msg, utils.CallerFrame().PC) + r.Add(args...) + _ = l.Logger.Handler().Handle(ctx, r) +} + +// ParamsFilter filter params +func (l *slogLogger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) { + if l.Parameterized { + return sql, nil + } + return sql, params +} diff --git a/auth-center/vendor/gorm.io/gorm/migrator/migrator.go b/auth-center/vendor/gorm.io/gorm/migrator/migrator.go index cec4e30f..35107d57 100644 --- a/auth-center/vendor/gorm.io/gorm/migrator/migrator.go +++ b/auth-center/vendor/gorm.io/gorm/migrator/migrator.go @@ -474,7 +474,6 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy // found, smart migrate fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)) realDataType := strings.ToLower(columnType.DatabaseTypeName()) - var ( alterColumn bool isSameType = fullDataType == realDataType @@ -513,8 +512,19 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } } } + } - // check precision + // check precision + if realDataType == "decimal" || realDataType == "numeric" && + regexp.MustCompile(realDataType+`\(.*\)`).FindString(fullDataType) != "" { // if realDataType has no precision,ignore + precision, scale, ok := columnType.DecimalSize() + if ok { + if !strings.HasPrefix(fullDataType, fmt.Sprintf("%s(%d,%d)", realDataType, precision, scale)) && + !strings.HasPrefix(fullDataType, fmt.Sprintf("%s(%d)", realDataType, precision)) { + alterColumn = true + } + } + } else { if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) { alterColumn = true @@ -550,6 +560,10 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy v1, _ := strconv.ParseBool(dv) v2, _ := strconv.ParseBool(field.DefaultValue) alterColumn = v1 != v2 + case schema.String: + if dv != field.DefaultValue && dv != strings.Trim(field.DefaultValue, "'\"") { + alterColumn = true + } default: alterColumn = dv != field.DefaultValue } diff --git a/auth-center/vendor/gorm.io/gorm/schema/field.go b/auth-center/vendor/gorm.io/gorm/schema/field.go index a6ff1a72..de797402 100644 --- a/auth-center/vendor/gorm.io/gorm/schema/field.go +++ b/auth-center/vendor/gorm.io/gorm/schema/field.go @@ -448,16 +448,17 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } // create valuer, setter when parse struct -func (field *Field) setupValuerAndSetter() { +func (field *Field) setupValuerAndSetter(modelType reflect.Type) { // Setup NewValuePool field.setupNewValuePool() // ValueOf returns field's value and if it is zero fieldIndex := field.StructField.Index[0] switch { - case len(field.StructField.Index) == 1 && fieldIndex > 0: - field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) { - fieldValue := reflect.Indirect(value).Field(fieldIndex) + case len(field.StructField.Index) == 1 && fieldIndex >= 0: + field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { + v = reflect.Indirect(v) + fieldValue := v.Field(fieldIndex) return fieldValue.Interface(), fieldValue.IsZero() } default: @@ -504,9 +505,10 @@ func (field *Field) setupValuerAndSetter() { // ReflectValueOf returns field's reflect value switch { - case len(field.StructField.Index) == 1 && fieldIndex > 0: - field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value { - return reflect.Indirect(value).Field(fieldIndex) + case len(field.StructField.Index) == 1 && fieldIndex >= 0: + field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { + v = reflect.Indirect(v) + return v.Field(fieldIndex) } default: field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { diff --git a/auth-center/vendor/gorm.io/gorm/schema/relationship.go b/auth-center/vendor/gorm.io/gorm/schema/relationship.go index f1ace924..0535bba4 100644 --- a/auth-center/vendor/gorm.io/gorm/schema/relationship.go +++ b/auth-center/vendor/gorm.io/gorm/schema/relationship.go @@ -75,9 +75,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { } ) - cacheStore := schema.cacheStore - - if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil { + if relation.FieldSchema, err = getOrParse(fieldValue, schema.cacheStore, schema.namer); err != nil { schema.err = fmt.Errorf("failed to parse field: %s, error: %w", field.Name, err) return nil } @@ -147,6 +145,9 @@ func hasPolymorphicRelation(tagSettings map[string]string) bool { } func (schema *Schema) setRelation(relation *Relationship) { + schema.Relationships.Mux.Lock() + defer schema.Relationships.Mux.Unlock() + // set non-embedded relation if rel := schema.Relationships.Relations[relation.Name]; rel != nil { if len(rel.Field.BindNames) > 1 { @@ -590,6 +591,10 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu // build references for idx, foreignField := range foreignFields { // use same data type for foreign keys + schema.Relationships.Mux.Lock() + if schema != foreignField.Schema { + foreignField.Schema.Relationships.Mux.Lock() + } if copyableDataType(primaryFields[idx].DataType) { foreignField.DataType = primaryFields[idx].DataType } @@ -597,6 +602,10 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu if foreignField.Size == 0 { foreignField.Size = primaryFields[idx].Size } + schema.Relationships.Mux.Unlock() + if schema != foreignField.Schema { + foreignField.Schema.Relationships.Mux.Unlock() + } relation.References = append(relation.References, &Reference{ PrimaryKey: primaryFields[idx], diff --git a/auth-center/vendor/gorm.io/gorm/schema/schema.go b/auth-center/vendor/gorm.io/gorm/schema/schema.go index db236797..09697a7a 100644 --- a/auth-center/vendor/gorm.io/gorm/schema/schema.go +++ b/auth-center/vendor/gorm.io/gorm/schema/schema.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "go/ast" + "path" "reflect" "strings" "sync" @@ -59,14 +60,14 @@ type Schema struct { cacheStore *sync.Map } -func (schema Schema) String() string { +func (schema *Schema) String() string { if schema.ModelType.Name() == "" { return fmt.Sprintf("%s(%s)", schema.Name, schema.Table) } return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), schema.ModelType.Name()) } -func (schema Schema) MakeSlice() reflect.Value { +func (schema *Schema) MakeSlice() reflect.Value { slice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(schema.ModelType)), 0, 20) results := reflect.New(slice.Type()) results.Elem().Set(slice) @@ -74,13 +75,23 @@ func (schema Schema) MakeSlice() reflect.Value { return results } -func (schema Schema) LookUpField(name string) *Field { +func (schema *Schema) LookUpField(name string) *Field { if field, ok := schema.FieldsByDBName[name]; ok { return field } if field, ok := schema.FieldsByName[name]; ok { return field } + + // Lookup field using namer-driven ColumnName + if schema.namer == nil { + return nil + } + namerColumnName := schema.namer.ColumnName(schema.Table, name) + if field, ok := schema.FieldsByDBName[namerColumnName]; ok { + return field + } + return nil } @@ -92,10 +103,7 @@ func (schema Schema) LookUpField(name string) *Field { // } // ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID") // } -func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field { - if len(bindNames) == 0 { - return nil - } +func (schema *Schema) LookUpFieldByBindName(bindNames []string, name string) *Field { for i := len(bindNames) - 1; i >= 0; i-- { find := strings.Join(bindNames[:i], ".") + "." + name if field, ok := schema.FieldsByBindName[find]; ok { @@ -113,6 +121,14 @@ type TablerWithNamer interface { TableName(Namer) string } +var callbackTypes = []callbackType{ + callbackTypeBeforeCreate, callbackTypeAfterCreate, + callbackTypeBeforeUpdate, callbackTypeAfterUpdate, + callbackTypeBeforeSave, callbackTypeAfterSave, + callbackTypeBeforeDelete, callbackTypeAfterDelete, + callbackTypeAfterFind, +} + // Parse get data type from dialector func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { return ParseWithSpecialTableName(dest, cacheStore, namer, "") @@ -124,34 +140,33 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } - value := reflect.ValueOf(dest) - if value.Kind() == reflect.Ptr && value.IsNil() { - value = reflect.New(value.Type().Elem()) - } - modelType := reflect.Indirect(value).Type() - - if modelType.Kind() == reflect.Interface { - modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type() - } - - for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType := reflect.ValueOf(dest).Type() + if modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() } if modelType.Kind() != reflect.Struct { - if modelType.PkgPath() == "" { - return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + if modelType.Kind() == reflect.Interface { + modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type() + } + + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType.Kind() != reflect.Struct { + if modelType.PkgPath() == "" { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } - return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } // Cache the Schema for performance, // Use the modelType or modelType + schemaTable (if it present) as cache key. - var schemaCacheKey interface{} + var schemaCacheKey interface{} = modelType if specialTableName != "" { schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName) - } else { - schemaCacheKey = modelType } // Load exist schema cache, return if exists @@ -162,28 +177,29 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam return s, s.err } + var tableName string modelValue := reflect.New(modelType) - tableName := namer.TableName(modelType.Name()) - if tabler, ok := modelValue.Interface().(Tabler); ok { + if specialTableName != "" { + tableName = specialTableName + } else if en, ok := namer.(embeddedNamer); ok { + tableName = en.Table + } else if tabler, ok := modelValue.Interface().(Tabler); ok { tableName = tabler.TableName() - } - if tabler, ok := modelValue.Interface().(TablerWithNamer); ok { + } else if tabler, ok := modelValue.Interface().(TablerWithNamer); ok { tableName = tabler.TableName(namer) - } - if en, ok := namer.(embeddedNamer); ok { - tableName = en.Table - } - if specialTableName != "" && specialTableName != tableName { - tableName = specialTableName + } else { + tableName = namer.TableName(modelType.Name()) } schema := &Schema{ Name: modelType.Name(), ModelType: modelType, Table: tableName, - FieldsByName: map[string]*Field{}, - FieldsByBindName: map[string]*Field{}, - FieldsByDBName: map[string]*Field{}, + DBNames: make([]string, 0, 10), + Fields: make([]*Field, 0, 10), + FieldsByName: make(map[string]*Field, 10), + FieldsByBindName: make(map[string]*Field, 10), + FieldsByDBName: make(map[string]*Field, 10), Relationships: Relationships{Relations: map[string]*Relationship{}}, cacheStore: cacheStore, namer: namer, @@ -227,8 +243,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam schema.FieldsByBindName[bindName] = field if v != nil && v.PrimaryKey { + // remove the existing primary key field for idx, f := range schema.PrimaryFields { - if f == v { + if f.DBName == v.DBName { schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...) } } @@ -247,7 +264,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam schema.FieldsByBindName[bindName] = field } - field.setupValuerAndSetter() + field.setupValuerAndSetter(modelType) } prioritizedPrimaryField := schema.LookUpField("id") @@ -283,10 +300,37 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName) } + _, embedded := schema.cacheStore.Load(embeddedCacheKey) + relationshipFields := []*Field{} for _, field := range schema.Fields { if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil { schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) } + + if !embedded { + if field.DataType == "" && field.GORMDataType == "" && (field.Creatable || field.Updatable || field.Readable) { + relationshipFields = append(relationshipFields, field) + schema.FieldsByName[field.Name] = field + schema.FieldsByBindName[field.BindName()] = field + } + + fieldValue := reflect.New(field.IndirectFieldType).Interface() + if fc, ok := fieldValue.(CreateClausesInterface); ok { + field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) + } + + if fc, ok := fieldValue.(QueryClausesInterface); ok { + field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) + } + + if fc, ok := fieldValue.(UpdateClausesInterface); ok { + field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) + } + + if fc, ok := fieldValue.(DeleteClausesInterface); ok { + field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) + } + } } if field := schema.PrioritizedPrimaryField; field != nil { @@ -303,24 +347,6 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } } - callbackTypes := []callbackType{ - callbackTypeBeforeCreate, callbackTypeAfterCreate, - callbackTypeBeforeUpdate, callbackTypeAfterUpdate, - callbackTypeBeforeSave, callbackTypeAfterSave, - callbackTypeBeforeDelete, callbackTypeAfterDelete, - callbackTypeAfterFind, - } - for _, cbName := range callbackTypes { - if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() { - switch methodValue.Type().String() { - case "func(*gorm.DB) error": // TODO hack - reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true) - default: - logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName) - } - } - } - // Cache the schema if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded { s := v.(*Schema) @@ -336,84 +362,47 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } }() - if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { - for _, field := range schema.Fields { - if field.DataType == "" && field.GORMDataType == "" && (field.Creatable || field.Updatable || field.Readable) { - if schema.parseRelation(field); schema.err != nil { - return schema, schema.err + for _, cbName := range callbackTypes { + if methodValue := modelValue.MethodByName(string(cbName)); methodValue.IsValid() { + switch methodValue.Type().String() { + case "func(*gorm.DB) error": + expectedPkgPath := path.Dir(reflect.TypeOf(schema).Elem().PkgPath()) + if inVarPkg := methodValue.Type().In(0).Elem().PkgPath(); inVarPkg == expectedPkgPath { + reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true) } else { - schema.FieldsByName[field.Name] = field - schema.FieldsByBindName[field.BindName()] = field + logger.Default.Warn(context.Background(), "In model %v, the hook function `%v(*gorm.DB) error` has an incorrect parameter type. The expected parameter type is `%v`, but the provided type is `%v`.", schema, cbName, expectedPkgPath, inVarPkg) + // PASS } + default: + logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName) } + } + } - fieldValue := reflect.New(field.IndirectFieldType) - fieldInterface := fieldValue.Interface() - if fc, ok := fieldInterface.(CreateClausesInterface); ok { - field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) - } - - if fc, ok := fieldInterface.(QueryClausesInterface); ok { - field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) - } - - if fc, ok := fieldInterface.(UpdateClausesInterface); ok { - field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) - } - - if fc, ok := fieldInterface.(DeleteClausesInterface); ok { - field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) - } + // parse relationships + for _, field := range relationshipFields { + if schema.parseRelation(field); schema.err != nil { + return schema, schema.err } } return schema, schema.err } -// This unrolling is needed to show to the compiler the exact set of methods -// that can be used on the modelType. -// Prior to go1.22 any use of MethodByName would cause the linker to -// abandon dead code elimination for the entire binary. -// As of go1.22 the compiler supports one special case of a string constant -// being passed to MethodByName. For enterprise customers or those building -// large binaries, this gives a significant reduction in binary size. -// https://github.com/golang/go/issues/62257 -func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value { - switch cbType { - case callbackTypeBeforeCreate: - return modelType.MethodByName(string(callbackTypeBeforeCreate)) - case callbackTypeAfterCreate: - return modelType.MethodByName(string(callbackTypeAfterCreate)) - case callbackTypeBeforeUpdate: - return modelType.MethodByName(string(callbackTypeBeforeUpdate)) - case callbackTypeAfterUpdate: - return modelType.MethodByName(string(callbackTypeAfterUpdate)) - case callbackTypeBeforeSave: - return modelType.MethodByName(string(callbackTypeBeforeSave)) - case callbackTypeAfterSave: - return modelType.MethodByName(string(callbackTypeAfterSave)) - case callbackTypeBeforeDelete: - return modelType.MethodByName(string(callbackTypeBeforeDelete)) - case callbackTypeAfterDelete: - return modelType.MethodByName(string(callbackTypeAfterDelete)) - case callbackTypeAfterFind: - return modelType.MethodByName(string(callbackTypeAfterFind)) - default: - return reflect.ValueOf(nil) - } -} - func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { modelType := reflect.ValueOf(dest).Type() - for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { - modelType = modelType.Elem() - } if modelType.Kind() != reflect.Struct { - if modelType.PkgPath() == "" { - return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType.Kind() != reflect.Struct { + if modelType.PkgPath() == "" { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } - return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } if v, ok := cacheStore.Load(modelType); ok { diff --git a/auth-center/vendor/gorm.io/gorm/schema/serializer.go b/auth-center/vendor/gorm.io/gorm/schema/serializer.go index 0fafbcba..d774a8ed 100644 --- a/auth-center/vendor/gorm.io/gorm/schema/serializer.go +++ b/auth-center/vendor/gorm.io/gorm/schema/serializer.go @@ -8,6 +8,7 @@ import ( "encoding/gob" "encoding/json" "fmt" + "math" "reflect" "strings" "sync" @@ -127,16 +128,31 @@ func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect. // Value implements serializer interface func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) { rv := reflect.ValueOf(fieldValue) - switch v := fieldValue.(type) { - case int64, int, uint, uint64, int32, uint32, int16, uint16: - result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC() - case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16: + switch fieldValue.(type) { + case int, int8, int16, int32, int64: + result = time.Unix(rv.Int(), 0).UTC() + case uint, uint8, uint16, uint32, uint64: + if uv := rv.Uint(); uv > math.MaxInt64 { + err = fmt.Errorf("integer overflow conversion uint64(%d) -> int64", uv) + } else { + result = time.Unix(int64(uv), 0).UTC() //nolint:gosec + } + case *int, *int8, *int16, *int32, *int64: if rv.IsZero() { return nil, nil } - result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC() + result = time.Unix(rv.Elem().Int(), 0).UTC() + case *uint, *uint8, *uint16, *uint32, *uint64: + if rv.IsZero() { + return nil, nil + } + if uv := rv.Elem().Uint(); uv > math.MaxInt64 { + err = fmt.Errorf("integer overflow conversion uint64(%d) -> int64", uv) + } else { + result = time.Unix(int64(uv), 0).UTC() //nolint:gosec + } default: - err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v) + err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", fieldValue) } return } diff --git a/auth-center/vendor/gorm.io/gorm/schema/utils.go b/auth-center/vendor/gorm.io/gorm/schema/utils.go index fa1c65d4..86305d7b 100644 --- a/auth-center/vendor/gorm.io/gorm/schema/utils.go +++ b/auth-center/vendor/gorm.io/gorm/schema/utils.go @@ -17,25 +17,23 @@ func ParseTagSetting(str string, sep string) map[string]string { settings := map[string]string{} names := strings.Split(str, sep) + var parsedNames []string for i := 0; i < len(names); i++ { - j := i - if len(names[j]) > 0 { - for { - if names[j][len(names[j])-1] == '\\' { - i++ - names[j] = names[j][0:len(names[j])-1] + sep + names[i] - names[i] = "" - } else { - break - } - } + s := names[i] + for strings.HasSuffix(s, "\\") && i+1 < len(names) { + i++ + s = s[:len(s)-1] + sep + names[i] } + parsedNames = append(parsedNames, s) + } - values := strings.Split(names[j], ":") + for _, tag := range parsedNames { + values := strings.Split(tag, ":") k := strings.TrimSpace(strings.ToUpper(values[0])) - if len(values) >= 2 { - settings[k] = strings.Join(values[1:], ":") + val := strings.Join(values[1:], ":") + val = strings.ReplaceAll(val, `\"`, `"`) + settings[k] = val } else if k != "" { settings[k] = k } @@ -121,6 +119,17 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, } switch reflectValue.Kind() { + case reflect.Map: + results = [][]interface{}{make([]interface{}, len(fields))} + for idx, field := range fields { + mapValue := reflectValue.MapIndex(reflect.ValueOf(field.DBName)) + if mapValue.IsZero() { + mapValue = reflectValue.MapIndex(reflect.ValueOf(field.Name)) + } + results[0][idx] = mapValue.Interface() + } + + dataResults[utils.ToStringKey(results[0]...)] = []reflect.Value{reflectValue} case reflect.Struct: results = [][]interface{}{make([]interface{}, len(fields))} diff --git a/auth-center/vendor/gorm.io/gorm/statement.go b/auth-center/vendor/gorm.io/gorm/statement.go index c6183724..736087d7 100644 --- a/auth-center/vendor/gorm.io/gorm/statement.go +++ b/auth-center/vendor/gorm.io/gorm/statement.go @@ -96,7 +96,9 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { if v.Name == clause.CurrentTable { if stmt.TableExpr != nil { stmt.TableExpr.Build(stmt) - } else { + } else if stmt.Table != "" { + write(v.Raw, stmt.Table) + } else if stmt.AddError(stmt.Parse(stmt.Model)) == nil { write(v.Raw, stmt.Table) } } else { @@ -334,6 +336,8 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] switch v := arg.(type) { case clause.Expression: conds = append(conds, v) + case []clause.Expression: + conds = append(conds, v...) case *DB: v.executeScopes() @@ -341,7 +345,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] if where, ok := cs.Expression.(clause.Where); ok { if len(where.Exprs) == 1 { if orConds, ok := where.Exprs[0].(clause.OrConditions); ok { - where.Exprs[0] = clause.AndConditions(orConds) + if len(orConds.Exprs) == 1 { + where.Exprs[0] = clause.AndConditions(orConds) + } } } conds = append(conds, clause.And(where.Exprs...)) @@ -362,6 +368,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] for _, key := range keys { column := clause.Column{Name: key, Table: curTable} + if strings.Contains(key, ".") { + column = clause.Column{Name: key} + } conds = append(conds, clause.Eq{Column: column, Value: v[key]}) } case map[string]interface{}: @@ -374,6 +383,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] for _, key := range keys { reflectValue := reflect.Indirect(reflect.ValueOf(v[key])) column := clause.Column{Name: key, Table: curTable} + if strings.Contains(key, ".") { + column = clause.Column{Name: key} + } switch reflectValue.Kind() { case reflect.Slice, reflect.Array: if _, ok := v[key].(driver.Valuer); ok { @@ -650,12 +662,15 @@ func (stmt *Statement) Changed(fields ...string) bool { for destValue.Kind() == reflect.Ptr { destValue = destValue.Elem() } - - changedValue, zero := field.ValueOf(stmt.Context, destValue) - if v { - return !utils.AssertEqual(changedValue, fieldValue) + if descSchema, err := schema.Parse(stmt.Dest, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { + if destField := descSchema.LookUpField(field.DBName); destField != nil { + changedValue, zero := destField.ValueOf(stmt.Context, destValue) + if v { + return !utils.AssertEqual(changedValue, fieldValue) + } + return !zero && !utils.AssertEqual(changedValue, fieldValue) + } } - return !zero && !utils.AssertEqual(changedValue, fieldValue) } } return false diff --git a/auth-center/vendor/gorm.io/gorm/utils/utils.go b/auth-center/vendor/gorm.io/gorm/utils/utils.go index fc615d73..7e59264b 100644 --- a/auth-center/vendor/gorm.io/gorm/utils/utils.go +++ b/auth-center/vendor/gorm.io/gorm/utils/utils.go @@ -30,8 +30,12 @@ func sourceDir(file string) string { return filepath.ToSlash(s) + "/" } -// FileWithLineNum return the file name and line number of the current file -func FileWithLineNum() string { +// CallerFrame retrieves the first relevant stack frame outside of GORM's internal implementation files. +// It skips: +// - GORM's core source files (identified by gormSourceDir prefix) +// - Exclude test files (*_test.go) +// - go-gorm/gen's Generated files (*.gen.go) +func CallerFrame() runtime.Frame { pcs := [13]uintptr{} // the third caller usually from gorm internal len := runtime.Callers(3, pcs[:]) @@ -41,14 +45,24 @@ func FileWithLineNum() string { frame, _ := frames.Next() if (!strings.HasPrefix(frame.File, gormSourceDir) || strings.HasSuffix(frame.File, "_test.go")) && !strings.HasSuffix(frame.File, ".gen.go") { - return string(strconv.AppendInt(append([]byte(frame.File), ':'), int64(frame.Line), 10)) + return frame } } + return runtime.Frame{} +} + +// FileWithLineNum return the file name and line number of the current file +func FileWithLineNum() string { + frame := CallerFrame() + if frame.PC != 0 { + return string(strconv.AppendInt(append([]byte(frame.File), ':'), int64(frame.Line), 10)) + } + return "" } -func IsValidDBNameChar(c rune) bool { +func IsInvalidDBNameChar(c rune) bool { return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@' } diff --git a/auth-center/vendor/modules.txt b/auth-center/vendor/modules.txt index a345f74a..e3ae0150 100644 --- a/auth-center/vendor/modules.txt +++ b/auth-center/vendor/modules.txt @@ -1283,7 +1283,7 @@ gopkg.in/yaml.v3 # gorm.io/driver/postgres v1.6.0 ## explicit; go 1.20 gorm.io/driver/postgres -# gorm.io/gorm v1.30.0 +# gorm.io/gorm v1.31.1 ## explicit; go 1.18 gorm.io/gorm gorm.io/gorm/callbacks diff --git a/cluster-manager/go.mod b/cluster-manager/go.mod index 1b9ff6ad..de2065c2 100644 --- a/cluster-manager/go.mod +++ b/cluster-manager/go.mod @@ -17,8 +17,8 @@ require ( google.golang.org/grpc v1.73.0 google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.3.0 google.golang.org/protobuf v1.36.6 - gorm.io/driver/postgres v1.5.7 - gorm.io/gorm v1.30.0 + gorm.io/driver/postgres v1.6.0 + gorm.io/gorm v1.31.1 sigs.k8s.io/yaml v1.4.0 ) @@ -107,8 +107,9 @@ require ( github.com/hexops/gotextdiff v1.0.3 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect - github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/pgx/v5 v5.4.3 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/pgx/v5 v5.6.0 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/jgautheron/goconst v1.7.1 // indirect github.com/jingyugao/rowserrcheck v1.1.1 // indirect github.com/jinzhu/inflection v1.0.0 // indirect diff --git a/cluster-manager/go.sum b/cluster-manager/go.sum index 8fd0c75d..e73e98dd 100644 --- a/cluster-manager/go.sum +++ b/cluster-manager/go.sum @@ -218,10 +218,12 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= -github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY= -github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY= +github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jgautheron/goconst v1.7.1 h1:VpdAG7Ca7yvvJk5n8dMwQhfEZJh95kl/Hl9S1OI5Jkk= github.com/jgautheron/goconst v1.7.1/go.mod h1:aAosetZ5zaeC/2EfMeRswtxUFBpe2Hr7HzkgX4fanO4= github.com/jingyugao/rowserrcheck v1.1.1 h1:zibz55j/MJtLsjP1OF4bSdgXxwL1b+Vn7Tjzq7gFzUs= @@ -659,10 +661,10 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM= -gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA= -gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs= -gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= +gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4= +gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo= +gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= +gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= honnef.co/go/tools v0.6.1 h1:R094WgE8K4JirYjBaOpz/AvTyUu/3wbmAoskKN/pxTI= honnef.co/go/tools v0.6.1/go.mod h1:3puzxxljPCe8RGJX7BIy1plGbxEOZni5mR2aXe3/uk4= mvdan.cc/gofumpt v0.7.0 h1:bg91ttqXmi9y2xawvkuMXyvAA/1ZGJqYAEGjXuP0JXU= diff --git a/cluster-manager/vendor/github.com/jackc/pgservicefile/.travis.yml b/cluster-manager/vendor/github.com/jackc/pgservicefile/.travis.yml deleted file mode 100644 index e176228e..00000000 --- a/cluster-manager/vendor/github.com/jackc/pgservicefile/.travis.yml +++ /dev/null @@ -1,9 +0,0 @@ -language: go - -go: - - 1.x - - tip - -matrix: - allow_failures: - - go: tip diff --git a/cluster-manager/vendor/github.com/jackc/pgservicefile/README.md b/cluster-manager/vendor/github.com/jackc/pgservicefile/README.md index e50ca126..2fc7e012 100644 --- a/cluster-manager/vendor/github.com/jackc/pgservicefile/README.md +++ b/cluster-manager/vendor/github.com/jackc/pgservicefile/README.md @@ -1,5 +1,6 @@ -[![](https://godoc.org/github.com/jackc/pgservicefile?status.svg)](https://godoc.org/github.com/jackc/pgservicefile) -[![Build Status](https://travis-ci.org/jackc/pgservicefile.svg)](https://travis-ci.org/jackc/pgservicefile) +[![Go Reference](https://pkg.go.dev/badge/github.com/jackc/pgservicefile.svg)](https://pkg.go.dev/github.com/jackc/pgservicefile) +[![Build Status](https://github.com/jackc/pgservicefile/actions/workflows/ci.yml/badge.svg)](https://github.com/jackc/pgservicefile/actions/workflows/ci.yml) + # pgservicefile diff --git a/cluster-manager/vendor/github.com/jackc/pgservicefile/pgservicefile.go b/cluster-manager/vendor/github.com/jackc/pgservicefile/pgservicefile.go index 797bbab9..c62caa7f 100644 --- a/cluster-manager/vendor/github.com/jackc/pgservicefile/pgservicefile.go +++ b/cluster-manager/vendor/github.com/jackc/pgservicefile/pgservicefile.go @@ -57,7 +57,7 @@ func ParseServicefile(r io.Reader) (*Servicefile, error) { } else if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") { service = &Service{Name: line[1 : len(line)-1], Settings: make(map[string]string)} servicefile.Services = append(servicefile.Services, service) - } else { + } else if service != nil { parts := strings.SplitN(line, "=", 2) if len(parts) != 2 { return nil, fmt.Errorf("unable to parse line %d", lineNum) @@ -67,6 +67,8 @@ func ParseServicefile(r io.Reader) (*Servicefile, error) { value := strings.TrimSpace(parts[1]) service.Settings[key] = value + } else { + return nil, fmt.Errorf("line %d is not in a section", lineNum) } } diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/CHANGELOG.md b/cluster-manager/vendor/github.com/jackc/pgx/v5/CHANGELOG.md index fb2304a2..61b4695f 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/CHANGELOG.md +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/CHANGELOG.md @@ -1,3 +1,92 @@ +# 5.6.0 (May 25, 2024) + +* Add StrictNamedArgs (Tomas Zahradnicek) +* Add support for macaddr8 type (Carlos Pérez-Aradros Herce) +* Add SeverityUnlocalized field to PgError / Notice +* Performance optimization of RowToStructByPos/Name (Zach Olstein) +* Allow customizing context canceled behavior for pgconn +* Add ScanLocation to pgtype.Timestamp[tz]Codec +* Add custom data to pgconn.PgConn +* Fix ResultReader.Read() to handle nil values +* Do not encode interval microseconds when they are 0 (Carlos Pérez-Aradros Herce) +* pgconn.SafeToRetry checks for wrapped errors (tjasko) +* Failed connection attempts include all errors +* Optimize LargeObject.Read (Mitar) +* Add tracing for connection acquire and release from pool (ngavinsir) +* Fix encode driver.Valuer not called when nil +* Add support for custom JSON marshal and unmarshal (Mitar) +* Use Go default keepalive for TCP connections (Hans-Joachim Kliemeck) + +# 5.5.5 (March 9, 2024) + +Use spaces instead of parentheses for SQL sanitization. + +This still solves the problem of negative numbers creating a line comment, but this avoids breaking edge cases such as +`set foo to $1` where the substitution is taking place in a location where an arbitrary expression is not allowed. + +# 5.5.4 (March 4, 2024) + +Fix CVE-2024-27304 + +SQL injection can occur if an attacker can cause a single query or bind message to exceed 4 GB in size. An integer +overflow in the calculated message size can cause the one large message to be sent as multiple messages under the +attacker's control. + +Thanks to Paul Gerste for reporting this issue. + +* Fix behavior of CollectRows to return empty slice if Rows are empty (Felix) +* Fix simple protocol encoding of json.RawMessage +* Fix *Pipeline.getResults should close pipeline on error +* Fix panic in TryFindUnderlyingTypeScanPlan (David Kurman) +* Fix deallocation of invalidated cached statements in a transaction +* Handle invalid sslkey file +* Fix scan float4 into sql.Scanner +* Fix pgtype.Bits not making copy of data from read buffer. This would cause the data to be corrupted by future reads. + +# 5.5.3 (February 3, 2024) + +* Fix: prepared statement already exists +* Improve CopyFrom auto-conversion of text-ish values +* Add ltree type support (Florent Viel) +* Make some properties of Batch and QueuedQuery public (Pavlo Golub) +* Add AppendRows function (Edoardo Spadolini) +* Optimize convert UUID [16]byte to string (Kirill Malikov) +* Fix: LargeObject Read and Write of more than ~1GB at a time (Mitar) + +# 5.5.2 (January 13, 2024) + +* Allow NamedArgs to start with underscore +* pgproto3: Maximum message body length support (jeremy.spriet) +* Upgrade golang.org/x/crypto to v0.17.0 +* Add snake_case support to RowToStructByName (Tikhon Fedulov) +* Fix: update description cache after exec prepare (James Hartig) +* Fix: pipeline checks if it is closed (James Hartig and Ryan Fowler) +* Fix: normalize timeout / context errors during TLS startup (Samuel Stauffer) +* Add OnPgError for easier centralized error handling (James Hartig) + +# 5.5.1 (December 9, 2023) + +* Add CopyFromFunc helper function. (robford) +* Add PgConn.Deallocate method that uses PostgreSQL protocol Close message. +* pgx uses new PgConn.Deallocate method. This allows deallocating statements to work in a failed transaction. This fixes a case where the prepared statement map could become invalid. +* Fix: Prefer driver.Valuer over json.Marshaler for json fields. (Jacopo) +* Fix: simple protocol SQL sanitizer previously panicked if an invalid $0 placeholder was used. This now returns an error instead. (maksymnevajdev) +* Add pgtype.Numeric.ScanScientific (Eshton Robateau) + +# 5.5.0 (November 4, 2023) + +* Add CollectExactlyOneRow. (Julien GOTTELAND) +* Add OpenDBFromPool to create *database/sql.DB from *pgxpool.Pool. (Lev Zakharov) +* Prepare can automatically choose statement name based on sql. This makes it easier to explicitly manage prepared statements. +* Statement cache now uses deterministic, stable statement names. +* database/sql prepared statement names are deterministically generated. +* Fix: SendBatch wasn't respecting context cancellation. +* Fix: Timeout error from pipeline is now normalized. +* Fix: database/sql encoding json.RawMessage to []byte. +* CancelRequest: Wait for the cancel request to be acknowledged by the server. This should improve PgBouncer compatibility. (Anton Levakin) +* stdlib: Use Ping instead of CheckConn in ResetSession +* Add json.Marshaler and json.Unmarshaler for Float4, Float8 (Kirill Mironov) + # 5.4.3 (August 5, 2023) * Fix: QCharArrayOID was defined with the wrong OID (Christoph Engelbert) diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/CONTRIBUTING.md b/cluster-manager/vendor/github.com/jackc/pgx/v5/CONTRIBUTING.md index 3eb0da5b..c975a937 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/CONTRIBUTING.md +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/CONTRIBUTING.md @@ -29,6 +29,7 @@ Create and setup a test database: export PGDATABASE=pgx_test createdb psql -c 'create extension hstore;' +psql -c 'create extension ltree;' psql -c 'create domain uint64 as numeric(20,0);' ``` @@ -79,20 +80,11 @@ echo "listen_addresses = '127.0.0.1'" >> .testdb/$POSTGRESQL_DATA_DIR/postgresql echo "port = $PGPORT" >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf cat testsetup/postgresql_ssl.conf >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf cp testsetup/pg_hba.conf .testdb/$POSTGRESQL_DATA_DIR/pg_hba.conf -cp testsetup/ca.cnf .testdb -cp testsetup/localhost.cnf .testdb -cp testsetup/pgx_sslcert.cnf .testdb cd .testdb -# Generate a CA public / private key pair. -openssl genrsa -out ca.key 4096 -openssl req -x509 -config ca.cnf -new -nodes -key ca.key -sha256 -days 365 -subj '/O=pgx-test-root' -out ca.pem - -# Generate the certificate for localhost (the server). -openssl genrsa -out localhost.key 2048 -openssl req -new -config localhost.cnf -key localhost.key -out localhost.csr -openssl x509 -req -in localhost.csr -CA ca.pem -CAkey ca.key -CAcreateserial -out localhost.crt -days 364 -sha256 -extfile localhost.cnf -extensions v3_req +# Generate CA, server, and encrypted client certificates. +go run ../testsetup/generate_certs.go # Copy certificates to server directory and set permissions. cp ca.pem $POSTGRESQL_DATA_DIR/root.crt @@ -100,11 +92,6 @@ cp localhost.key $POSTGRESQL_DATA_DIR/server.key chmod 600 $POSTGRESQL_DATA_DIR/server.key cp localhost.crt $POSTGRESQL_DATA_DIR/server.crt -# Generate the certificate for client authentication. -openssl genrsa -des3 -out pgx_sslcert.key -passout pass:certpw 2048 -openssl req -new -config pgx_sslcert.cnf -key pgx_sslcert.key -passin pass:certpw -out pgx_sslcert.csr -openssl x509 -req -in pgx_sslcert.csr -CA ca.pem -CAkey ca.key -CAcreateserial -out pgx_sslcert.crt -days 363 -sha256 -extfile pgx_sslcert.cnf -extensions v3_req - cd .. ``` diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/README.md b/cluster-manager/vendor/github.com/jackc/pgx/v5/README.md index 522206f9..0cf2c291 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/README.md +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/README.md @@ -86,9 +86,13 @@ It is also possible to use the `database/sql` interface and convert a connection See CONTRIBUTING.md for setup instructions. +## Architecture + +See the presentation at Golang Estonia, [PGX Top to Bottom](https://www.youtube.com/watch?v=sXMSWhcHCf8) for a description of pgx architecture. + ## Supported Go and PostgreSQL Versions -pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.19 and higher and PostgreSQL 11 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/). +pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.21 and higher and PostgreSQL 12 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/). ## Version Policy @@ -116,6 +120,7 @@ pgerrcode contains constants for the PostgreSQL error codes. * [github.com/jackc/pgx-gofrs-uuid](https://github.com/jackc/pgx-gofrs-uuid) * [github.com/jackc/pgx-shopspring-decimal](https://github.com/jackc/pgx-shopspring-decimal) +* [github.com/twpayne/pgx-geos](https://github.com/twpayne/pgx-geos) ([PostGIS](https://postgis.net/) and [GEOS](https://libgeos.org/) via [go-geos](https://github.com/twpayne/go-geos)) * [github.com/vgarvardt/pgx-google-uuid](https://github.com/vgarvardt/pgx-google-uuid) diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/batch.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/batch.go index 8f6ea4f0..3540f57f 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/batch.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/batch.go @@ -10,9 +10,9 @@ import ( // QueuedQuery is a query that has been queued for execution via a Batch. type QueuedQuery struct { - query string - arguments []any - fn batchItemFunc + SQL string + Arguments []any + Fn batchItemFunc sd *pgconn.StatementDescription } @@ -20,7 +20,7 @@ type batchItemFunc func(br BatchResults) error // Query sets fn to be called when the response to qq is received. func (qq *QueuedQuery) Query(fn func(rows Rows) error) { - qq.fn = func(br BatchResults) error { + qq.Fn = func(br BatchResults) error { rows, _ := br.Query() defer rows.Close() @@ -36,7 +36,7 @@ func (qq *QueuedQuery) Query(fn func(rows Rows) error) { // Query sets fn to be called when the response to qq is received. func (qq *QueuedQuery) QueryRow(fn func(row Row) error) { - qq.fn = func(br BatchResults) error { + qq.Fn = func(br BatchResults) error { row := br.QueryRow() return fn(row) } @@ -44,7 +44,7 @@ func (qq *QueuedQuery) QueryRow(fn func(row Row) error) { // Exec sets fn to be called when the response to qq is received. func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) { - qq.fn = func(br BatchResults) error { + qq.Fn = func(br BatchResults) error { ct, err := br.Exec() if err != nil { return err @@ -57,22 +57,24 @@ func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) { // Batch queries are a way of bundling multiple queries together to avoid // unnecessary network round trips. A Batch must only be sent once. type Batch struct { - queuedQueries []*QueuedQuery + QueuedQueries []*QueuedQuery } // Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement. +// The only pgx option argument that is supported is QueryRewriter. Queries are executed using the +// connection's DefaultQueryExecMode. func (b *Batch) Queue(query string, arguments ...any) *QueuedQuery { qq := &QueuedQuery{ - query: query, - arguments: arguments, + SQL: query, + Arguments: arguments, } - b.queuedQueries = append(b.queuedQueries, qq) + b.QueuedQueries = append(b.QueuedQueries, qq) return qq } // Len returns number of queries that have been queued so far. func (b *Batch) Len() int { - return len(b.queuedQueries) + return len(b.QueuedQueries) } type BatchResults interface { @@ -225,9 +227,9 @@ func (br *batchResults) Close() error { } // Read and run fn for all remaining items - for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) { - if br.b.queuedQueries[br.qqIdx].fn != nil { - err := br.b.queuedQueries[br.qqIdx].fn(br) + for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) { + if br.b.QueuedQueries[br.qqIdx].Fn != nil { + err := br.b.QueuedQueries[br.qqIdx].Fn(br) if err != nil { br.err = err } @@ -251,10 +253,10 @@ func (br *batchResults) earlyError() error { } func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) { - if br.b != nil && br.qqIdx < len(br.b.queuedQueries) { - bi := br.b.queuedQueries[br.qqIdx] - query = bi.query - args = bi.arguments + if br.b != nil && br.qqIdx < len(br.b.QueuedQueries) { + bi := br.b.QueuedQueries[br.qqIdx] + query = bi.SQL + args = bi.Arguments ok = true br.qqIdx++ } @@ -394,9 +396,9 @@ func (br *pipelineBatchResults) Close() error { } // Read and run fn for all remaining items - for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) { - if br.b.queuedQueries[br.qqIdx].fn != nil { - err := br.b.queuedQueries[br.qqIdx].fn(br) + for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) { + if br.b.QueuedQueries[br.qqIdx].Fn != nil { + err := br.b.QueuedQueries[br.qqIdx].Fn(br) if err != nil { br.err = err } @@ -420,10 +422,10 @@ func (br *pipelineBatchResults) earlyError() error { } func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) { - if br.b != nil && br.qqIdx < len(br.b.queuedQueries) { - bi := br.b.queuedQueries[br.qqIdx] - query = bi.query - args = bi.arguments + if br.b != nil && br.qqIdx < len(br.b.QueuedQueries) { + bi := br.b.QueuedQueries[br.qqIdx] + query = bi.SQL + args = bi.Arguments ok = true br.qqIdx++ } diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/conn.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/conn.go index 7c7081b4..31172145 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/conn.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/conn.go @@ -2,13 +2,14 @@ package pgx import ( "context" + "crypto/sha256" + "encoding/hex" "errors" "fmt" "strconv" "strings" "time" - "github.com/jackc/pgx/v5/internal/anynil" "github.com/jackc/pgx/v5/internal/sanitize" "github.com/jackc/pgx/v5/internal/stmtcache" "github.com/jackc/pgx/v5/pgconn" @@ -35,7 +36,7 @@ type ConnConfig struct { // DefaultQueryExecMode controls the default mode for executing queries. By default pgx uses the extended protocol // and automatically prepares and caches prepared statements. However, this may be incompatible with proxies such as - // PGBouncer. In this case it may be preferrable to use QueryExecModeExec or QueryExecModeSimpleProtocol. The same + // PGBouncer. In this case it may be preferable to use QueryExecModeExec or QueryExecModeSimpleProtocol. The same // functionality can be controlled on a per query basis by passing a QueryExecMode as the first query argument. DefaultQueryExecMode QueryExecMode @@ -99,8 +100,12 @@ func (ident Identifier) Sanitize() string { return strings.Join(parts, ".") } -// ErrNoRows occurs when rows are expected but none are returned. -var ErrNoRows = errors.New("no rows in result set") +var ( + // ErrNoRows occurs when rows are expected but none are returned. + ErrNoRows = errors.New("no rows in result set") + // ErrTooManyRows occurs when more rows than expected are returned. + ErrTooManyRows = errors.New("too many rows in result set") +) var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache") var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache") @@ -269,7 +274,7 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { return c, nil } -// Close closes a connection. It is safe to call Close on a already closed +// Close closes a connection. It is safe to call Close on an already closed // connection. func (c *Conn) Close(ctx context.Context) error { if c.IsClosed() { @@ -280,12 +285,15 @@ func (c *Conn) Close(ctx context.Context) error { return err } -// Prepare creates a prepared statement with name and sql. sql can contain placeholders -// for bound parameters. These placeholders are referenced positional as $1, $2, etc. +// Prepare creates a prepared statement with name and sql. sql can contain placeholders for bound parameters. These +// placeholders are referenced positionally as $1, $2, etc. name can be used instead of sql with Query, QueryRow, and +// Exec to execute the statement. It can also be used with Batch.Queue. +// +// The underlying PostgreSQL identifier for the prepared statement will be name if name != sql or a digest of sql if +// name == sql. // -// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same -// name and sql arguments. This allows a code path to Prepare and Query/Exec without -// concern for if the statement has already been prepared. +// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same name and sql arguments. This +// allows a code path to Prepare and Query/Exec without concern for if the statement has already been prepared. func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) { if c.prepareTracer != nil { ctx = c.prepareTracer.TracePrepareStart(ctx, c, TracePrepareStartData{Name: name, SQL: sql}) @@ -307,23 +315,48 @@ func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.Statem }() } - sd, err = c.pgConn.Prepare(ctx, name, sql, nil) + var psName, psKey string + if name == sql { + digest := sha256.Sum256([]byte(sql)) + psName = "stmt_" + hex.EncodeToString(digest[0:24]) + psKey = sql + } else { + psName = name + psKey = name + } + + sd, err = c.pgConn.Prepare(ctx, psName, sql, nil) if err != nil { return nil, err } - if name != "" { - c.preparedStatements[name] = sd + if psKey != "" { + c.preparedStatements[psKey] = sd } return sd, nil } -// Deallocate released a prepared statement +// Deallocate releases a prepared statement. Calling Deallocate on a non-existent prepared statement will succeed. func (c *Conn) Deallocate(ctx context.Context, name string) error { - delete(c.preparedStatements, name) - _, err := c.pgConn.Exec(ctx, "deallocate "+quoteIdentifier(name)).ReadAll() - return err + var psName string + sd := c.preparedStatements[name] + if sd != nil { + psName = sd.Name + } else { + psName = name + } + + err := c.pgConn.Deallocate(ctx, psName) + if err != nil { + return err + } + + if sd != nil { + delete(c.preparedStatements, name) + } + + return nil } // DeallocateAll releases all previously prepared statements from the server and client, where it also resets the statement and description cache. @@ -441,7 +474,7 @@ optionLoop: if queryRewriter != nil { sql, arguments, err = queryRewriter.RewriteQuery(ctx, c, sql, arguments) if err != nil { - return pgconn.CommandTag{}, fmt.Errorf("rewrite query failed: %v", err) + return pgconn.CommandTag{}, fmt.Errorf("rewrite query failed: %w", err) } } @@ -461,7 +494,7 @@ optionLoop: } sd := c.statementCache.Get(sql) if sd == nil { - sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql) + sd, err = c.Prepare(ctx, stmtcache.StatementName(sql), sql) if err != nil { return pgconn.CommandTag{}, err } @@ -479,6 +512,7 @@ optionLoop: if err != nil { return pgconn.CommandTag{}, err } + c.descriptionCache.Put(sd) } return c.execParams(ctx, sd, arguments) @@ -573,32 +607,35 @@ type QueryExecMode int32 const ( _ QueryExecMode = iota - // Automatically prepare and cache statements. This uses the extended protocol. Queries are executed in a single - // round trip after the statement is cached. This is the default. + // Automatically prepare and cache statements. This uses the extended protocol. Queries are executed in a single round + // trip after the statement is cached. This is the default. If the database schema is modified or the search_path is + // changed after a statement is cached then the first execution of a previously cached query may fail. e.g. If the + // number of columns returned by a "SELECT *" changes or the type of a column is changed. QueryExecModeCacheStatement - // Cache statement descriptions (i.e. argument and result types) and assume they do not change. This uses the - // extended protocol. Queries are executed in a single round trip after the description is cached. If the database - // schema is modified or the search_path is changed this may result in undetected result decoding errors. + // Cache statement descriptions (i.e. argument and result types) and assume they do not change. This uses the extended + // protocol. Queries are executed in a single round trip after the description is cached. If the database schema is + // modified or the search_path is changed after a statement is cached then the first execution of a previously cached + // query may fail. e.g. If the number of columns returned by a "SELECT *" changes or the type of a column is changed. QueryExecModeCacheDescribe // Get the statement description on every execution. This uses the extended protocol. Queries require two round trips // to execute. It does not use named prepared statements. But it does use the unnamed prepared statement to get the // statement description on the first round trip and then uses it to execute the query on the second round trip. This // may cause problems with connection poolers that switch the underlying connection between round trips. It is safe - // even when the the database schema is modified concurrently. + // even when the database schema is modified concurrently. QueryExecModeDescribeExec // Assume the PostgreSQL query parameter types based on the Go type of the arguments. This uses the extended protocol // with text formatted parameters and results. Queries are executed in a single round trip. Type mappings can be // registered with pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are - // unregistered or ambigious. e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know + // unregistered or ambiguous. e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know // the PostgreSQL type can use a map[string]string directly as an argument. This mode cannot. QueryExecModeExec // Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments. // Queries are executed in a single round trip. Type mappings can be registered with - // pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambigious. + // pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambiguous. // e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use // a map[string]string directly as an argument. This mode cannot. // @@ -705,7 +742,7 @@ optionLoop: sql, args, err = queryRewriter.RewriteQuery(ctx, c, sql, args) if err != nil { rows := c.getRows(ctx, originalSQL, originalArgs) - err = fmt.Errorf("rewrite query failed: %v", err) + err = fmt.Errorf("rewrite query failed: %w", err) rows.fatal(err) return rows, err } @@ -717,7 +754,6 @@ optionLoop: } c.eqb.reset() - anynil.NormalizeSlice(args) rows := c.getRows(ctx, sql, args) var err error @@ -815,7 +851,7 @@ func (c *Conn) getStatementDescription( } sd = c.statementCache.Get(sql) if sd == nil { - sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql) + sd, err = c.Prepare(ctx, stmtcache.StatementName(sql), sql) if err != nil { return nil, err } @@ -865,15 +901,14 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) { return &batchResults{ctx: ctx, conn: c, err: err} } - mode := c.config.DefaultQueryExecMode - - for _, bi := range b.queuedQueries { + for _, bi := range b.QueuedQueries { var queryRewriter QueryRewriter - sql := bi.query - arguments := bi.arguments + sql := bi.SQL + arguments := bi.Arguments optionLoop: for len(arguments) > 0 { + // Update Batch.Queue function comment when additional options are implemented switch arg := arguments[0].(type) { case QueryRewriter: queryRewriter = arg @@ -887,21 +922,23 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) { var err error sql, arguments, err = queryRewriter.RewriteQuery(ctx, c, sql, arguments) if err != nil { - return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("rewrite query failed: %v", err)} + return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("rewrite query failed: %w", err)} } } - bi.query = sql - bi.arguments = arguments + bi.SQL = sql + bi.Arguments = arguments } + // TODO: changing mode per batch? Update Batch.Queue function comment when implemented + mode := c.config.DefaultQueryExecMode if mode == QueryExecModeSimpleProtocol { return c.sendBatchQueryExecModeSimpleProtocol(ctx, b) } // All other modes use extended protocol and thus can use prepared statements. - for _, bi := range b.queuedQueries { - if sd, ok := c.preparedStatements[bi.query]; ok { + for _, bi := range b.QueuedQueries { + if sd, ok := c.preparedStatements[bi.SQL]; ok { bi.sd = sd } } @@ -922,11 +959,11 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) { func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults { var sb strings.Builder - for i, bi := range b.queuedQueries { + for i, bi := range b.QueuedQueries { if i > 0 { sb.WriteByte(';') } - sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...) + sql, err := c.sanitizeForSimpleQuery(bi.SQL, bi.Arguments...) if err != nil { return &batchResults{ctx: ctx, conn: c, err: err} } @@ -945,21 +982,21 @@ func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batc func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults { batch := &pgconn.Batch{} - for _, bi := range b.queuedQueries { + for _, bi := range b.QueuedQueries { sd := bi.sd if sd != nil { - err := c.eqb.Build(c.typeMap, sd, bi.arguments) + err := c.eqb.Build(c.typeMap, sd, bi.Arguments) if err != nil { return &batchResults{ctx: ctx, conn: c, err: err} } batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats) } else { - err := c.eqb.Build(c.typeMap, nil, bi.arguments) + err := c.eqb.Build(c.typeMap, nil, bi.Arguments) if err != nil { return &batchResults{ctx: ctx, conn: c, err: err} } - batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats) + batch.ExecParams(bi.SQL, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats) } } @@ -984,18 +1021,18 @@ func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batc distinctNewQueries := []*pgconn.StatementDescription{} distinctNewQueriesIdxMap := make(map[string]int) - for _, bi := range b.queuedQueries { + for _, bi := range b.QueuedQueries { if bi.sd == nil { - sd := c.statementCache.Get(bi.query) + sd := c.statementCache.Get(bi.SQL) if sd != nil { bi.sd = sd } else { - if idx, present := distinctNewQueriesIdxMap[bi.query]; present { + if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present { bi.sd = distinctNewQueries[idx] } else { sd = &pgconn.StatementDescription{ - Name: stmtcache.NextStatementName(), - SQL: bi.query, + Name: stmtcache.StatementName(bi.SQL), + SQL: bi.SQL, } distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries) distinctNewQueries = append(distinctNewQueries, sd) @@ -1016,17 +1053,17 @@ func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch distinctNewQueries := []*pgconn.StatementDescription{} distinctNewQueriesIdxMap := make(map[string]int) - for _, bi := range b.queuedQueries { + for _, bi := range b.QueuedQueries { if bi.sd == nil { - sd := c.descriptionCache.Get(bi.query) + sd := c.descriptionCache.Get(bi.SQL) if sd != nil { bi.sd = sd } else { - if idx, present := distinctNewQueriesIdxMap[bi.query]; present { + if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present { bi.sd = distinctNewQueries[idx] } else { sd = &pgconn.StatementDescription{ - SQL: bi.query, + SQL: bi.SQL, } distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries) distinctNewQueries = append(distinctNewQueries, sd) @@ -1043,13 +1080,13 @@ func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch) distinctNewQueries := []*pgconn.StatementDescription{} distinctNewQueriesIdxMap := make(map[string]int) - for _, bi := range b.queuedQueries { + for _, bi := range b.QueuedQueries { if bi.sd == nil { - if idx, present := distinctNewQueriesIdxMap[bi.query]; present { + if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present { bi.sd = distinctNewQueries[idx] } else { sd := &pgconn.StatementDescription{ - SQL: bi.query, + SQL: bi.SQL, } distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries) distinctNewQueries = append(distinctNewQueries, sd) @@ -1062,7 +1099,7 @@ func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch) } func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, distinctNewQueries []*pgconn.StatementDescription, sdCache stmtcache.Cache) (pbr *pipelineBatchResults) { - pipeline := c.pgConn.StartPipeline(context.Background()) + pipeline := c.pgConn.StartPipeline(ctx) defer func() { if pbr != nil && pbr.err != nil { pipeline.Close() @@ -1115,11 +1152,11 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d } // Queue the queries. - for _, bi := range b.queuedQueries { - err := c.eqb.Build(c.typeMap, bi.sd, bi.arguments) + for _, bi := range b.QueuedQueries { + err := c.eqb.Build(c.typeMap, bi.sd, bi.Arguments) if err != nil { // we wrap the error so we the user can understand which query failed inside the batch - err = fmt.Errorf("error building query %s: %w", bi.query, err) + err = fmt.Errorf("error building query %s: %w", bi.SQL, err) return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} } @@ -1164,7 +1201,15 @@ func (c *Conn) sanitizeForSimpleQuery(sql string, args ...any) (string, error) { return sanitize.SanitizeSQL(sql, valueArgs...) } -// LoadType inspects the database for typeName and produces a pgtype.Type suitable for registration. +// LoadType inspects the database for typeName and produces a pgtype.Type suitable for registration. typeName must be +// the name of a type where the underlying type(s) is already understood by pgx. It is for derived types. In particular, +// typeName must be one of the following: +// - An array type name of a type that is already registered. e.g. "_foo" when "foo" is registered. +// - A composite type name where all field types are already registered. +// - A domain type name where the base type is already registered. +// - An enum type name. +// - A range type name where the element type is already registered. +// - A multirange type name where the element type is already registered. func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, error) { var oid uint32 @@ -1307,17 +1352,17 @@ order by attnum`, } func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error { - if c.pgConn.TxStatus() != 'I' { + if txStatus := c.pgConn.TxStatus(); txStatus != 'I' && txStatus != 'T' { return nil } if c.descriptionCache != nil { - c.descriptionCache.HandleInvalidated() + c.descriptionCache.RemoveInvalidated() } var invalidatedStatements []*pgconn.StatementDescription if c.statementCache != nil { - invalidatedStatements = c.statementCache.HandleInvalidated() + invalidatedStatements = c.statementCache.GetInvalidated() } if len(invalidatedStatements) == 0 { @@ -1329,7 +1374,6 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error for _, sd := range invalidatedStatements { pipeline.SendDeallocate(sd.Name) - delete(c.preparedStatements, sd.Name) } err := pipeline.Sync() @@ -1342,5 +1386,10 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error return fmt.Errorf("failed to deallocate cached statement(s): %w", err) } + c.statementCache.RemoveInvalidated() + for _, sd := range invalidatedStatements { + delete(c.preparedStatements, sd.Name) + } + return nil } diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/copy_from.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/copy_from.go index a2c227fd..abcd2239 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/copy_from.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/copy_from.go @@ -64,6 +64,33 @@ func (cts *copyFromSlice) Err() error { return cts.err } +// CopyFromFunc returns a CopyFromSource interface that relies on nxtf for values. +// nxtf returns rows until it either signals an 'end of data' by returning row=nil and err=nil, +// or it returns an error. If nxtf returns an error, the copy is aborted. +func CopyFromFunc(nxtf func() (row []any, err error)) CopyFromSource { + return ©FromFunc{next: nxtf} +} + +type copyFromFunc struct { + next func() ([]any, error) + valueRow []any + err error +} + +func (g *copyFromFunc) Next() bool { + g.valueRow, g.err = g.next() + // only return true if valueRow exists and no error + return g.valueRow != nil && g.err == nil +} + +func (g *copyFromFunc) Values() ([]any, error) { + return g.valueRow, g.err +} + +func (g *copyFromFunc) Err() error { + return g.err +} + // CopyFromSource is the interface used by *Conn.CopyFrom as the source for copy data. type CopyFromSource interface { // Next returns true if there is another row and makes the next row data diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/doc.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/doc.go index 7486f42c..bc0391dd 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/doc.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/doc.go @@ -11,9 +11,10 @@ The primary way of establishing a connection is with [pgx.Connect]: conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL")) -The database connection string can be in URL or DSN format. Both PostgreSQL settings and pgx settings can be specified -here. In addition, a config struct can be created by [ParseConfig] and modified before establishing the connection with -[ConnectConfig] to configure settings such as tracing that cannot be configured with a connection string. +The database connection string can be in URL or key/value format. Both PostgreSQL settings and pgx settings can be +specified here. In addition, a config struct can be created by [ParseConfig] and modified before establishing the +connection with [ConnectConfig] to configure settings such as tracing that cannot be configured with a connection +string. Connection Pool @@ -23,8 +24,8 @@ github.com/jackc/pgx/v5/pgxpool for a concurrency safe connection pool. Query Interface pgx implements Query in the familiar database/sql style. However, pgx provides generic functions such as CollectRows and -ForEachRow that are a simpler and safer way of processing rows than manually calling rows.Next(), rows.Scan, and -rows.Err(). +ForEachRow that are a simpler and safer way of processing rows than manually calling defer rows.Close(), rows.Next(), +rows.Scan, and rows.Err(). CollectRows can be used collect all returned rows into a slice. @@ -187,7 +188,7 @@ implemented on top of pgconn. The Conn.PgConn() method can be used to access thi PgBouncer -By default pgx automatically uses prepared statements. Prepared statements are incompaptible with PgBouncer. This can be +By default pgx automatically uses prepared statements. Prepared statements are incompatible with PgBouncer. This can be disabled by setting a different QueryExecMode in ConnConfig.DefaultQueryExecMode. */ package pgx diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/extended_query_builder.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/extended_query_builder.go index 0bbdfbb5..526b0e95 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/extended_query_builder.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/extended_query_builder.go @@ -1,10 +1,8 @@ package pgx import ( - "database/sql/driver" "fmt" - "github.com/jackc/pgx/v5/internal/anynil" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" ) @@ -23,10 +21,15 @@ type ExtendedQueryBuilder struct { func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescription, args []any) error { eqb.reset() - anynil.NormalizeSlice(args) - if sd == nil { - return eqb.appendParamsForQueryExecModeExec(m, args) + for i := range args { + err := eqb.appendParam(m, 0, pgtype.TextFormatCode, args[i]) + if err != nil { + err = fmt.Errorf("failed to encode args[%d]: %w", i, err) + return err + } + } + return nil } if len(sd.ParamOIDs) != len(args) { @@ -36,7 +39,7 @@ func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescri for i := range args { err := eqb.appendParam(m, sd.ParamOIDs[i], -1, args[i]) if err != nil { - err = fmt.Errorf("failed to encode args[%d]: %v", i, err) + err = fmt.Errorf("failed to encode args[%d]: %w", i, err) return err } } @@ -113,10 +116,6 @@ func (eqb *ExtendedQueryBuilder) reset() { } func (eqb *ExtendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg any) ([]byte, error) { - if anynil.Is(arg) { - return nil, nil - } - if eqb.paramValueBytes == nil { eqb.paramValueBytes = make([]byte, 0, 128) } @@ -145,74 +144,3 @@ func (eqb *ExtendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid ui return m.FormatCodeForOID(oid) } - -// appendParamsForQueryExecModeExec appends the args to eqb. -// -// Parameters must be encoded in the text format because of differences in type conversion between timestamps and -// dates. In QueryExecModeExec we don't know what the actual PostgreSQL type is. To determine the type we use the -// Go type to OID type mapping registered by RegisterDefaultPgType. However, the Go time.Time represents both -// PostgreSQL timestamp[tz] and date. To use the binary format we would need to also specify what the PostgreSQL -// type OID is. But that would mean telling PostgreSQL that we have sent a timestamp[tz] when what is needed is a date. -// This means that the value is converted from text to timestamp[tz] to date. This means it does a time zone conversion -// before converting it to date. This means that dates can be shifted by one day. In text format without that double -// type conversion it takes the date directly and ignores time zone (i.e. it works). -// -// Given that the whole point of QueryExecModeExec is to operate without having to know the PostgreSQL types there is -// no way to safely use binary or to specify the parameter OIDs. -func (eqb *ExtendedQueryBuilder) appendParamsForQueryExecModeExec(m *pgtype.Map, args []any) error { - for _, arg := range args { - if arg == nil { - err := eqb.appendParam(m, 0, TextFormatCode, arg) - if err != nil { - return err - } - } else { - dt, ok := m.TypeForValue(arg) - if !ok { - var tv pgtype.TextValuer - if tv, ok = arg.(pgtype.TextValuer); ok { - t, err := tv.TextValue() - if err != nil { - return err - } - - dt, ok = m.TypeForOID(pgtype.TextOID) - if ok { - arg = t - } - } - } - if !ok { - var dv driver.Valuer - if dv, ok = arg.(driver.Valuer); ok { - v, err := dv.Value() - if err != nil { - return err - } - dt, ok = m.TypeForValue(v) - if ok { - arg = v - } - } - } - if !ok { - var str fmt.Stringer - if str, ok = arg.(fmt.Stringer); ok { - dt, ok = m.TypeForOID(pgtype.TextOID) - if ok { - arg = str.String() - } - } - } - if !ok { - return &unknownArgumentTypeQueryExecModeExecError{arg: arg} - } - err := eqb.appendParam(m, dt.OID, TextFormatCode, arg) - if err != nil { - return err - } - } - } - - return nil -} diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/internal/anynil/anynil.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/internal/anynil/anynil.go deleted file mode 100644 index 9a48c1a8..00000000 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/internal/anynil/anynil.go +++ /dev/null @@ -1,36 +0,0 @@ -package anynil - -import "reflect" - -// Is returns true if value is any type of nil. e.g. nil or []byte(nil). -func Is(value any) bool { - if value == nil { - return true - } - - refVal := reflect.ValueOf(value) - switch refVal.Kind() { - case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice: - return refVal.IsNil() - default: - return false - } -} - -// Normalize converts typed nils (e.g. []byte(nil)) into untyped nil. Other values are returned unmodified. -func Normalize(v any) any { - if Is(v) { - return nil - } - return v -} - -// NormalizeSlice converts all typed nils (e.g. []byte(nil)) in s into untyped nils. Other values are unmodified. s is -// mutated in place. -func NormalizeSlice(s []any) { - for i := range s { - if Is(s[i]) { - s[i] = nil - } - } -} diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/internal/sanitize/sanitize.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/internal/sanitize/sanitize.go index e9e6d228..df58c448 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/internal/sanitize/sanitize.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/internal/sanitize/sanitize.go @@ -35,6 +35,11 @@ func (q *Query) Sanitize(args ...any) (string, error) { str = part case int: argIdx := part - 1 + + if argIdx < 0 { + return "", fmt.Errorf("first sql argument must be > 0") + } + if argIdx >= len(args) { return "", fmt.Errorf("insufficient arguments") } @@ -58,6 +63,10 @@ func (q *Query) Sanitize(args ...any) (string, error) { return "", fmt.Errorf("invalid arg type: %T", arg) } argUse[argIdx] = true + + // Prevent SQL injection via Line Comment Creation + // https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p + str = " " + str + " " default: return "", fmt.Errorf("invalid Part type: %T", part) } diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/internal/stmtcache/lru_cache.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/internal/stmtcache/lru_cache.go index a25cc8b1..dec83f47 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/internal/stmtcache/lru_cache.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/internal/stmtcache/lru_cache.go @@ -34,7 +34,8 @@ func (c *LRUCache) Get(key string) *pgconn.StatementDescription { } -// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache. +// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache or +// sd.SQL has been invalidated and HandleInvalidated has not been called yet. func (c *LRUCache) Put(sd *pgconn.StatementDescription) { if sd.SQL == "" { panic("cannot store statement description with empty SQL") @@ -44,6 +45,13 @@ func (c *LRUCache) Put(sd *pgconn.StatementDescription) { return } + // The statement may have been invalidated but not yet handled. Do not readd it to the cache. + for _, invalidSD := range c.invalidStmts { + if invalidSD.SQL == sd.SQL { + return + } + } + if c.l.Len() == c.cap { c.invalidateOldest() } @@ -73,10 +81,16 @@ func (c *LRUCache) InvalidateAll() { c.l = list.New() } -func (c *LRUCache) HandleInvalidated() []*pgconn.StatementDescription { - invalidStmts := c.invalidStmts +// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated. +func (c *LRUCache) GetInvalidated() []*pgconn.StatementDescription { + return c.invalidStmts +} + +// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a +// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were +// never seen by the call to GetInvalidated. +func (c *LRUCache) RemoveInvalidated() { c.invalidStmts = nil - return invalidStmts } // Len returns the number of cached prepared statement descriptions. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/internal/stmtcache/stmtcache.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/internal/stmtcache/stmtcache.go index e1bdcba5..d57bdd29 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/internal/stmtcache/stmtcache.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/internal/stmtcache/stmtcache.go @@ -2,18 +2,17 @@ package stmtcache import ( - "strconv" - "sync/atomic" + "crypto/sha256" + "encoding/hex" "github.com/jackc/pgx/v5/pgconn" ) -var stmtCounter int64 - -// NextStatementName returns a statement name that will be unique for the lifetime of the program. -func NextStatementName() string { - n := atomic.AddInt64(&stmtCounter, 1) - return "stmtcache_" + strconv.FormatInt(n, 10) +// StatementName returns a statement name that will be stable for sql across multiple connections and program +// executions. +func StatementName(sql string) string { + digest := sha256.Sum256([]byte(sql)) + return "stmtcache_" + hex.EncodeToString(digest[0:24]) } // Cache caches statement descriptions. @@ -30,8 +29,13 @@ type Cache interface { // InvalidateAll invalidates all statement descriptions. InvalidateAll() - // HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated. - HandleInvalidated() []*pgconn.StatementDescription + // GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated. + GetInvalidated() []*pgconn.StatementDescription + + // RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a + // call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were + // never seen by the call to GetInvalidated. + RemoveInvalidated() // Len returns the number of cached prepared statement descriptions. Len() int @@ -39,19 +43,3 @@ type Cache interface { // Cap returns the maximum number of cached prepared statement descriptions. Cap() int } - -func IsStatementInvalid(err error) bool { - pgErr, ok := err.(*pgconn.PgError) - if !ok { - return false - } - - // https://github.com/jackc/pgx/issues/1162 - // - // We used to look for the message "cached plan must not change result type". However, that message can be localized. - // Unfortunately, error code "0A000" - "FEATURE NOT SUPPORTED" is used for many different errors and the only way to - // tell the difference is by the message. But all that happens is we clear a statement that we otherwise wouldn't - // have so it should be safe. - possibleInvalidCachedPlanError := pgErr.Code == "0A000" - return possibleInvalidCachedPlanError -} diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/internal/stmtcache/unlimited_cache.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/internal/stmtcache/unlimited_cache.go index f5f59396..69641329 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/internal/stmtcache/unlimited_cache.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/internal/stmtcache/unlimited_cache.go @@ -54,10 +54,16 @@ func (c *UnlimitedCache) InvalidateAll() { c.m = make(map[string]*pgconn.StatementDescription) } -func (c *UnlimitedCache) HandleInvalidated() []*pgconn.StatementDescription { - invalidStmts := c.invalidStmts +// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated. +func (c *UnlimitedCache) GetInvalidated() []*pgconn.StatementDescription { + return c.invalidStmts +} + +// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a +// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were +// never seen by the call to GetInvalidated. +func (c *UnlimitedCache) RemoveInvalidated() { c.invalidStmts = nil - return invalidStmts } // Len returns the number of cached prepared statement descriptions. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/large_objects.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/large_objects.go index c238ab9c..9d21afdc 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/large_objects.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/large_objects.go @@ -4,8 +4,15 @@ import ( "context" "errors" "io" + + "github.com/jackc/pgx/v5/pgtype" ) +// The PostgreSQL wire protocol has a limit of 1 GB - 1 per message. See definition of +// PQ_LARGE_MESSAGE_LIMIT in the PostgreSQL source code. To allow for the other data +// in the message,maxLargeObjectMessageLength should be no larger than 1 GB - 1 KB. +var maxLargeObjectMessageLength = 1024*1024*1024 - 1024 + // LargeObjects is a structure used to access the large objects API. It is only valid within the transaction where it // was created. // @@ -68,32 +75,65 @@ type LargeObject struct { // Write writes p to the large object and returns the number of bytes written and an error if not all of p was written. func (o *LargeObject) Write(p []byte) (int, error) { - var n int - err := o.tx.QueryRow(o.ctx, "select lowrite($1, $2)", o.fd, p).Scan(&n) - if err != nil { - return n, err + nTotal := 0 + for { + expected := len(p) - nTotal + if expected == 0 { + break + } else if expected > maxLargeObjectMessageLength { + expected = maxLargeObjectMessageLength + } + + var n int + err := o.tx.QueryRow(o.ctx, "select lowrite($1, $2)", o.fd, p[nTotal:nTotal+expected]).Scan(&n) + if err != nil { + return nTotal, err + } + + if n < 0 { + return nTotal, errors.New("failed to write to large object") + } + + nTotal += n + + if n < expected { + return nTotal, errors.New("short write to large object") + } else if n > expected { + return nTotal, errors.New("invalid write to large object") + } } - if n < 0 { - return 0, errors.New("failed to write to large object") - } - - return n, nil + return nTotal, nil } // Read reads up to len(p) bytes into p returning the number of bytes read. func (o *LargeObject) Read(p []byte) (int, error) { - var res []byte - err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, len(p)).Scan(&res) - copy(p, res) - if err != nil { - return len(res), err + nTotal := 0 + for { + expected := len(p) - nTotal + if expected == 0 { + break + } else if expected > maxLargeObjectMessageLength { + expected = maxLargeObjectMessageLength + } + + res := pgtype.PreallocBytes(p[nTotal:]) + err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, expected).Scan(&res) + // We compute expected so that it always fits into p, so it should never happen + // that PreallocBytes's ScanBytes had to allocate a new slice. + nTotal += len(res) + if err != nil { + return nTotal, err + } + + if len(res) < expected { + return nTotal, io.EOF + } else if len(res) > expected { + return nTotal, errors.New("invalid read of large object") + } } - if len(res) < len(p) { - err = io.EOF - } - return len(res), err + return nTotal, nil } // Seek moves the current location pointer to the new location specified by offset. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/named_args.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/named_args.go index 1bc32337..c88991ee 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/named_args.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/named_args.go @@ -2,6 +2,7 @@ package pgx import ( "context" + "fmt" "strconv" "strings" "unicode/utf8" @@ -14,10 +15,41 @@ import ( // // conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2}) // conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2) +// +// Named placeholders are case sensitive and must start with a letter or underscore. Subsequent characters can be +// letters, numbers, or underscores. type NamedArgs map[string]any // RewriteQuery implements the QueryRewriter interface. func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) { + return rewriteQuery(na, sql, false) +} + +// StrictNamedArgs can be used in the same way as NamedArgs, but provided arguments are also checked to include all +// named arguments that the sql query uses, and no extra arguments. +type StrictNamedArgs map[string]any + +// RewriteQuery implements the QueryRewriter interface. +func (sna StrictNamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) { + return rewriteQuery(sna, sql, true) +} + +type namedArg string + +type sqlLexer struct { + src string + start int + pos int + nested int // multiline comment nesting level. + stateFn stateFn + parts []any + + nameToOrdinal map[namedArg]int +} + +type stateFn func(*sqlLexer) stateFn + +func rewriteQuery(na map[string]any, sql string, isStrict bool) (newSQL string, newArgs []any, err error) { l := &sqlLexer{ src: sql, stateFn: rawState, @@ -41,27 +73,24 @@ func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, ar newArgs = make([]any, len(l.nameToOrdinal)) for name, ordinal := range l.nameToOrdinal { - newArgs[ordinal-1] = na[string(name)] + var found bool + newArgs[ordinal-1], found = na[string(name)] + if isStrict && !found { + return "", nil, fmt.Errorf("argument %s found in sql query but not present in StrictNamedArgs", name) + } } - return sb.String(), newArgs, nil -} - -type namedArg string - -type sqlLexer struct { - src string - start int - pos int - nested int // multiline comment nesting level. - stateFn stateFn - parts []any + if isStrict { + for name := range na { + if _, found := l.nameToOrdinal[namedArg(name)]; !found { + return "", nil, fmt.Errorf("argument %s of StrictNamedArgs not found in sql query", name) + } + } + } - nameToOrdinal map[namedArg]int + return sb.String(), newArgs, nil } -type stateFn func(*sqlLexer) stateFn - func rawState(l *sqlLexer) stateFn { for { r, width := utf8.DecodeRuneInString(l.src[l.pos:]) @@ -80,7 +109,7 @@ func rawState(l *sqlLexer) stateFn { return doubleQuoteState case '@': nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:]) - if isLetter(nextRune) { + if isLetter(nextRune) || nextRune == '_' { if l.pos-l.start > 0 { l.parts = append(l.parts, l.src[l.start:l.pos-width]) } diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go index 8c4b2de3..06498361 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go @@ -47,7 +47,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { return err } - // Receive server-first-message payload in a AuthenticationSASLContinue. + // Receive server-first-message payload in an AuthenticationSASLContinue. saslContinue, err := c.rxSASLContinue() if err != nil { return err @@ -67,7 +67,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { return err } - // Receive server-final-message payload in a AuthenticationSASLFinal. + // Receive server-final-message payload in an AuthenticationSASLFinal. saslFinal, err := c.rxSASLFinal() if err != nil { return err diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgconn/config.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgconn/config.go index 1c2c647d..598917f5 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgconn/config.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgconn/config.go @@ -19,6 +19,7 @@ import ( "github.com/jackc/pgpassfile" "github.com/jackc/pgservicefile" + "github.com/jackc/pgx/v5/pgconn/ctxwatch" "github.com/jackc/pgx/v5/pgproto3" ) @@ -39,7 +40,12 @@ type Config struct { DialFunc DialFunc // e.g. net.Dialer.DialContext LookupFunc LookupFunc // e.g. net.Resolver.LookupHost BuildFrontend BuildFrontendFunc - RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) + + // BuildContextWatcherHandler is called to create a ContextWatcherHandler for a connection. The handler is called + // when a context passed to a PgConn method is canceled. + BuildContextWatcherHandler func(*PgConn) ctxwatch.Handler + + RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) KerberosSrvName string KerberosSpn string @@ -60,12 +66,17 @@ type Config struct { // OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received. OnNotification NotificationHandler + // OnPgError is a callback function called when a Postgres error is received by the server. The default handler will close + // the connection on any FATAL errors. If you override this handler you should call the previously set handler or ensure + // that you close on FATAL errors by returning false. + OnPgError PgErrorHandler + createdByParseConfig bool // Used to enforce created by ParseConfig rule. } // ParseConfigOptions contains options that control how a config is built such as GetSSLPassword. type ParseConfigOptions struct { - // GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the the libpq function + // GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the libpq function // PQsetSSLKeyPassHook_OpenSSL. GetSSLPassword GetSSLPasswordFunc } @@ -107,6 +118,14 @@ type FallbackConfig struct { TLSConfig *tls.Config // nil disables TLS } +// connectOneConfig is the configuration for a single attempt to connect to a single host. +type connectOneConfig struct { + network string + address string + originalHostname string // original hostname before resolving + tlsConfig *tls.Config // nil disables TLS +} + // isAbsolutePath checks if the provided value is an absolute path either // beginning with a forward slash (as on Linux-based systems) or with a capital // letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows). @@ -141,11 +160,11 @@ func NetworkAddress(host string, port uint16) (network, address string) { // ParseConfig builds a *Config from connString with similar behavior to the PostgreSQL standard C library libpq. It // uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely -// matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style). -// See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be -// empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file. +// matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format. See +// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be empty +// to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file. // -// # Example DSN +// # Example Keyword/Value // user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca // // # Example URL @@ -164,7 +183,7 @@ func NetworkAddress(host string, port uint16) (network, address string) { // postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb // // ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed -// via database URL or DSN: +// via database URL or keyword/value: // // PGHOST // PGPORT @@ -228,16 +247,16 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con connStringSettings := make(map[string]string) if connString != "" { var err error - // connString may be a database URL or a DSN + // connString may be a database URL or in PostgreSQL keyword/value format if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { connStringSettings, err = parseURLSettings(connString) if err != nil { - return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err} + return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as URL", err: err} } } else { - connStringSettings, err = parseDSNSettings(connString) + connStringSettings, err = parseKeywordValueSettings(connString) if err != nil { - return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err} + return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as keyword/value", err: err} } } } @@ -246,7 +265,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con if service, present := settings["service"]; present { serviceSettings, err := parseServiceSettings(settings["servicefile"], service) if err != nil { - return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err} + return nil, &ParseConfigError{ConnString: connString, msg: "failed to read service", err: err} } settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings) @@ -261,12 +280,22 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend { return pgproto3.NewFrontend(r, w) }, + BuildContextWatcherHandler: func(pgConn *PgConn) ctxwatch.Handler { + return &DeadlineContextWatcherHandler{Conn: pgConn.conn} + }, + OnPgError: func(_ *PgConn, pgErr *PgError) bool { + // we want to automatically close any fatal errors + if strings.EqualFold(pgErr.Severity, "FATAL") { + return false + } + return true + }, } if connectTimeoutSetting, present := settings["connect_timeout"]; present { connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting) if err != nil { - return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err} + return nil, &ParseConfigError{ConnString: connString, msg: "invalid connect_timeout", err: err} } config.ConnectTimeout = connectTimeout config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout) @@ -328,7 +357,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con port, err := parsePort(portStr) if err != nil { - return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err} + return nil, &ParseConfigError{ConnString: connString, msg: "invalid port", err: err} } var tlsConfigs []*tls.Config @@ -340,7 +369,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con var err error tlsConfigs, err = configTLS(settings, host, options) if err != nil { - return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err} + return nil, &ParseConfigError{ConnString: connString, msg: "failed to configure TLS", err: err} } } @@ -384,7 +413,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con case "any": // do nothing default: - return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)} + return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)} } return config, nil @@ -505,7 +534,7 @@ func isIPOnly(host string) bool { var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1} -func parseDSNSettings(s string) (map[string]string, error) { +func parseKeywordValueSettings(s string) (map[string]string, error) { settings := make(map[string]string) nameMap := map[string]string{ @@ -516,7 +545,7 @@ func parseDSNSettings(s string) (map[string]string, error) { var key, val string eqIdx := strings.IndexRune(s, '=') if eqIdx < 0 { - return nil, errors.New("invalid dsn") + return nil, errors.New("invalid keyword/value") } key = strings.Trim(s[:eqIdx], " \t\n\r\v\f") @@ -568,7 +597,7 @@ func parseDSNSettings(s string) (map[string]string, error) { } if key == "" { - return nil, errors.New("invalid dsn") + return nil, errors.New("invalid keyword/value") } settings[key] = val @@ -709,6 +738,9 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P return nil, fmt.Errorf("unable to read sslkey: %w", err) } block, _ := pem.Decode(buf) + if block == nil { + return nil, errors.New("failed to decode sslkey") + } var pemKey []byte var decryptedKey []byte var decryptedError error @@ -785,7 +817,8 @@ func parsePort(s string) (uint16, error) { } func makeDefaultDialer() *net.Dialer { - return &net.Dialer{KeepAlive: 5 * time.Minute} + // rely on GOLANG KeepAlive settings + return &net.Dialer{} } func makeDefaultResolver() *net.Resolver { @@ -809,7 +842,7 @@ func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc { return d.DialContext } -// ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible +// ValidateConnectTargetSessionAttrsReadWrite is a ValidateConnectFunc that implements libpq compatible // target_session_attrs=read-write. func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() @@ -824,7 +857,7 @@ func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgC return nil } -// ValidateConnectTargetSessionAttrsReadOnly is an ValidateConnectFunc that implements libpq compatible +// ValidateConnectTargetSessionAttrsReadOnly is a ValidateConnectFunc that implements libpq compatible // target_session_attrs=read-only. func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error { result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() @@ -839,7 +872,7 @@ func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgCo return nil } -// ValidateConnectTargetSessionAttrsStandby is an ValidateConnectFunc that implements libpq compatible +// ValidateConnectTargetSessionAttrsStandby is a ValidateConnectFunc that implements libpq compatible // target_session_attrs=standby. func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error { result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() @@ -854,7 +887,7 @@ func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgCon return nil } -// ValidateConnectTargetSessionAttrsPrimary is an ValidateConnectFunc that implements libpq compatible +// ValidateConnectTargetSessionAttrsPrimary is a ValidateConnectFunc that implements libpq compatible // target_session_attrs=primary. func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error { result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() @@ -869,7 +902,7 @@ func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgCon return nil } -// ValidateConnectTargetSessionAttrsPreferStandby is an ValidateConnectFunc that implements libpq compatible +// ValidateConnectTargetSessionAttrsPreferStandby is a ValidateConnectFunc that implements libpq compatible // target_session_attrs=prefer-standby. func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error { result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgconn/internal/ctxwatch/context_watcher.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgconn/ctxwatch/context_watcher.go similarity index 71% rename from cluster-manager/vendor/github.com/jackc/pgx/v5/pgconn/internal/ctxwatch/context_watcher.go rename to cluster-manager/vendor/github.com/jackc/pgx/v5/pgconn/ctxwatch/context_watcher.go index b39cb3ee..db8884eb 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgconn/internal/ctxwatch/context_watcher.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgconn/ctxwatch/context_watcher.go @@ -8,9 +8,8 @@ import ( // ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a // time. type ContextWatcher struct { - onCancel func() - onUnwatchAfterCancel func() - unwatchChan chan struct{} + handler Handler + unwatchChan chan struct{} lock sync.Mutex watchInProgress bool @@ -20,11 +19,10 @@ type ContextWatcher struct { // NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled. // OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and // onCancel called. -func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWatcher { +func NewContextWatcher(handler Handler) *ContextWatcher { cw := &ContextWatcher{ - onCancel: onCancel, - onUnwatchAfterCancel: onUnwatchAfterCancel, - unwatchChan: make(chan struct{}), + handler: handler, + unwatchChan: make(chan struct{}), } return cw @@ -46,7 +44,7 @@ func (cw *ContextWatcher) Watch(ctx context.Context) { go func() { select { case <-ctx.Done(): - cw.onCancel() + cw.handler.HandleCancel(ctx) cw.onCancelWasCalled = true <-cw.unwatchChan case <-cw.unwatchChan: @@ -66,8 +64,17 @@ func (cw *ContextWatcher) Unwatch() { if cw.watchInProgress { cw.unwatchChan <- struct{}{} if cw.onCancelWasCalled { - cw.onUnwatchAfterCancel() + cw.handler.HandleUnwatchAfterCancel() } cw.watchInProgress = false } } + +type Handler interface { + // HandleCancel is called when the context that a ContextWatcher is currently watching is canceled. canceledCtx is the + // context that was canceled. + HandleCancel(canceledCtx context.Context) + + // HandleUnwatchAfterCancel is called when a ContextWatcher that called HandleCancel on this Handler is unwatched. + HandleUnwatchAfterCancel() +} diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgconn/doc.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgconn/doc.go index e3242cf4..70137501 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgconn/doc.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgconn/doc.go @@ -5,8 +5,8 @@ nearly the same level is the C library libpq. Establishing a Connection -Use Connect to establish a connection. It accepts a connection string in URL or DSN and will read the environment for -libpq style environment variables. +Use Connect to establish a connection. It accepts a connection string in URL or keyword/value format and will read the +environment for libpq style environment variables. Executing a Query @@ -20,13 +20,17 @@ result. The ReadAll method reads all query results into memory. Pipeline Mode -Pipeline mode allows sending queries without having read the results of previously sent queries. It allows -control of exactly how many and when network round trips occur. +Pipeline mode allows sending queries without having read the results of previously sent queries. It allows control of +exactly how many and when network round trips occur. Context Support -All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the -method immediately returns. In most circumstances, this will close the underlying connection. +All potentially blocking operations take a context.Context. The default behavior when a context is canceled is for the +method to immediately return. In most circumstances, this will also close the underlying connection. This behavior can +be customized by using BuildContextWatcherHandler on the Config to create a ctxwatch.Handler with different behavior. +This can be especially useful when queries that are frequently canceled and the overhead of creating new connections is +a problem. DeadlineContextWatcherHandler and CancelRequestContextWatcherHandler can be used to introduce a delay before +interrupting the query in such a way as to close the connection. The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the client to abort. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgconn/errors.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgconn/errors.go index 3c54bbec..ec4a6d47 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgconn/errors.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgconn/errors.go @@ -12,13 +12,14 @@ import ( // SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server. func SafeToRetry(err error) bool { - if e, ok := err.(interface{ SafeToRetry() bool }); ok { - return e.SafeToRetry() + var retryableErr interface{ SafeToRetry() bool } + if errors.As(err, &retryableErr) { + return retryableErr.SafeToRetry() } return false } -// Timeout checks if err was was caused by a timeout. To be specific, it is true if err was caused within pgconn by a +// Timeout checks if err was caused by a timeout. To be specific, it is true if err was caused within pgconn by a // context.DeadlineExceeded or an implementer of net.Error where Timeout() is true. func Timeout(err error) bool { var timeoutErr *errTimeout @@ -29,23 +30,24 @@ func Timeout(err error) bool { // http://www.postgresql.org/docs/11/static/protocol-error-fields.html for // detailed field description. type PgError struct { - Severity string - Code string - Message string - Detail string - Hint string - Position int32 - InternalPosition int32 - InternalQuery string - Where string - SchemaName string - TableName string - ColumnName string - DataTypeName string - ConstraintName string - File string - Line int32 - Routine string + Severity string + SeverityUnlocalized string + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string } func (pe *PgError) Error() string { @@ -57,22 +59,37 @@ func (pe *PgError) SQLState() string { return pe.Code } -type connectError struct { - config *Config - msg string +// ConnectError is the error returned when a connection attempt fails. +type ConnectError struct { + Config *Config // The configuration that was used in the connection attempt. err error } -func (e *connectError) Error() string { - sb := &strings.Builder{} - fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg) - if e.err != nil { - fmt.Fprintf(sb, " (%s)", e.err.Error()) +func (e *ConnectError) Error() string { + prefix := fmt.Sprintf("failed to connect to `user=%s database=%s`:", e.Config.User, e.Config.Database) + details := e.err.Error() + if strings.Contains(details, "\n") { + return prefix + "\n\t" + strings.ReplaceAll(details, "\n", "\n\t") + } else { + return prefix + " " + details } - return sb.String() } -func (e *connectError) Unwrap() error { +func (e *ConnectError) Unwrap() error { + return e.err +} + +type perDialConnectError struct { + address string + originalHostname string + err error +} + +func (e *perDialConnectError) Error() string { + return fmt.Sprintf("%s (%s): %s", e.address, e.originalHostname, e.err.Error()) +} + +func (e *perDialConnectError) Unwrap() error { return e.err } @@ -88,33 +105,38 @@ func (e *connLockError) Error() string { return e.status } -type parseConfigError struct { - connString string +// ParseConfigError is the error returned when a connection string cannot be parsed. +type ParseConfigError struct { + ConnString string // The connection string that could not be parsed. msg string err error } -func (e *parseConfigError) Error() string { - connString := redactPW(e.connString) +func (e *ParseConfigError) Error() string { + // Now that ParseConfigError is public and ConnString is available to the developer, perhaps it would be better only + // return a static string. That would ensure that the error message cannot leak a password. The ConnString field would + // allow access to the original string if desired and Unwrap would allow access to the underlying error. + connString := redactPW(e.ConnString) if e.err == nil { return fmt.Sprintf("cannot parse `%s`: %s", connString, e.msg) } return fmt.Sprintf("cannot parse `%s`: %s (%s)", connString, e.msg, e.err.Error()) } -func (e *parseConfigError) Unwrap() error { +func (e *ParseConfigError) Unwrap() error { return e.err } func normalizeTimeoutError(ctx context.Context, err error) error { - if err, ok := err.(net.Error); ok && err.Timeout() { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { if ctx.Err() == context.Canceled { // Since the timeout was caused by a context cancellation, the actual error is context.Canceled not the timeout error. return context.Canceled } else if ctx.Err() == context.DeadlineExceeded { return &errTimeout{err: ctx.Err()} } else { - return &errTimeout{err: err} + return &errTimeout{err: netErr} } } return err @@ -189,10 +211,10 @@ func redactPW(connString string) string { return redactURL(u) } } - quotedDSN := regexp.MustCompile(`password='[^']*'`) - connString = quotedDSN.ReplaceAllLiteralString(connString, "password=xxxxx") - plainDSN := regexp.MustCompile(`password=[^ ]*`) - connString = plainDSN.ReplaceAllLiteralString(connString, "password=xxxxx") + quotedKV := regexp.MustCompile(`password='[^']*'`) + connString = quotedKV.ReplaceAllLiteralString(connString, "password=xxxxx") + plainKV := regexp.MustCompile(`password=[^ ]*`) + connString = plainKV.ReplaceAllLiteralString(connString, "password=xxxxx") brokenURL := regexp.MustCompile(`:[^:@]+?@`) connString = brokenURL.ReplaceAllLiteralString(connString, ":xxxxxx@") return connString diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go index 8f602e40..7efb522a 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go @@ -18,8 +18,8 @@ import ( "github.com/jackc/pgx/v5/internal/iobufpool" "github.com/jackc/pgx/v5/internal/pgio" + "github.com/jackc/pgx/v5/pgconn/ctxwatch" "github.com/jackc/pgx/v5/pgconn/internal/bgreader" - "github.com/jackc/pgx/v5/pgconn/internal/ctxwatch" "github.com/jackc/pgx/v5/pgproto3" ) @@ -52,6 +52,12 @@ type LookupFunc func(ctx context.Context, host string) (addrs []string, err erro // BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. type BuildFrontendFunc func(r io.Reader, w io.Writer) *pgproto3.Frontend +// PgErrorHandler is a function that handles errors returned from Postgres. This function must return true to keep +// the connection open. Returning false will cause the connection to be closed immediately. You should return +// false on any FATAL-severity errors. This will not receive network errors. The *PgConn is provided so the handler is +// aware of the origin of the error, but it must not invoke any query method. +type PgErrorHandler func(*PgConn, *PgError) bool + // NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at // any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin // of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY @@ -74,6 +80,9 @@ type PgConn struct { frontend *pgproto3.Frontend bgReader *bgreader.BGReader slowWriteTimer *time.Timer + bgReaderStarted chan struct{} + + customData map[string]any config *Config @@ -96,8 +105,9 @@ type PgConn struct { cleanupDone chan struct{} } -// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) -// to provide configuration. See documentation for [ParseConfig] for details. ctx can be used to cancel a connect attempt. +// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or keyword/value +// format) to provide configuration. See documentation for [ParseConfig] for details. ctx can be used to cancel a +// connect attempt. func Connect(ctx context.Context, connString string) (*PgConn, error) { config, err := ParseConfig(connString) if err != nil { @@ -107,9 +117,9 @@ func Connect(ctx context.Context, connString string) (*PgConn, error) { return ConnectConfig(ctx, config) } -// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) -// and ParseConfigOptions to provide additional configuration. See documentation for [ParseConfig] for details. ctx can be -// used to cancel a connect attempt. +// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or keyword/value +// format) and ParseConfigOptions to provide additional configuration. See documentation for [ParseConfig] for details. +// ctx can be used to cancel a connect attempt. func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptions ParseConfigOptions) (*PgConn, error) { config, err := ParseConfigWithOptions(connString, parseConfigOptions) if err != nil { @@ -124,113 +134,77 @@ func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptio // // If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An // authentication error will terminate the chain of attempts (like libpq: -// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. Otherwise, -// if all attempts fail the last error is returned. -func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err error) { +// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. +func ConnectConfig(ctx context.Context, config *Config) (*PgConn, error) { // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from // zero values. if !config.createdByParseConfig { panic("config must be created by ParseConfig") } - // Simplify usage by treating primary config and fallbacks the same. - fallbackConfigs := []*FallbackConfig{ - { - Host: config.Host, - Port: config.Port, - TLSConfig: config.TLSConfig, - }, - } - fallbackConfigs = append(fallbackConfigs, config.Fallbacks...) - ctx := octx - fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs) - if err != nil { - return nil, &connectError{config: config, msg: "hostname resolving error", err: err} - } + var allErrors []error - if len(fallbackConfigs) == 0 { - return nil, &connectError{config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")} - } - - foundBestServer := false - var fallbackConfig *FallbackConfig - for i, fc := range fallbackConfigs { - // ConnectTimeout restricts the whole connection process. - if config.ConnectTimeout != 0 { - // create new context first time or when previous host was different - if i == 0 || (fallbackConfigs[i].Host != fallbackConfigs[i-1].Host) { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(octx, config.ConnectTimeout) - defer cancel() - } - } else { - ctx = octx - } - pgConn, err = connect(ctx, config, fc, false) - if err == nil { - foundBestServer = true - break - } else if pgerr, ok := err.(*PgError); ok { - err = &connectError{config: config, msg: "server error", err: pgerr} - const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password - const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings - const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist - const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege - if pgerr.Code == ERRCODE_INVALID_PASSWORD || - pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION && fc.TLSConfig != nil || - pgerr.Code == ERRCODE_INVALID_CATALOG_NAME || - pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE { - break - } - } else if cerr, ok := err.(*connectError); ok { - if _, ok := cerr.err.(*NotPreferredError); ok { - fallbackConfig = fc - } - } + connectConfigs, errs := buildConnectOneConfigs(ctx, config) + if len(errs) > 0 { + allErrors = append(allErrors, errs...) } - if !foundBestServer && fallbackConfig != nil { - pgConn, err = connect(ctx, config, fallbackConfig, true) - if pgerr, ok := err.(*PgError); ok { - err = &connectError{config: config, msg: "server error", err: pgerr} - } + if len(connectConfigs) == 0 { + return nil, &ConnectError{Config: config, err: fmt.Errorf("hostname resolving error: %w", errors.Join(allErrors...))} } - if err != nil { - return nil, err // no need to wrap in connectError because it will already be wrapped in all cases except PgError + pgConn, errs := connectPreferred(ctx, config, connectConfigs) + if len(errs) > 0 { + allErrors = append(allErrors, errs...) + return nil, &ConnectError{Config: config, err: errors.Join(allErrors...)} } if config.AfterConnect != nil { err := config.AfterConnect(ctx, pgConn) if err != nil { pgConn.conn.Close() - return nil, &connectError{config: config, msg: "AfterConnect error", err: err} + return nil, &ConnectError{Config: config, err: fmt.Errorf("AfterConnect error: %w", err)} } } return pgConn, nil } -func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*FallbackConfig) ([]*FallbackConfig, error) { - var configs []*FallbackConfig +// buildConnectOneConfigs resolves hostnames and builds a list of connectOneConfigs to try connecting to. It returns a +// slice of successfully resolved connectOneConfigs and a slice of errors. It is possible for both slices to contain +// values if some hosts were successfully resolved and others were not. +func buildConnectOneConfigs(ctx context.Context, config *Config) ([]*connectOneConfig, []error) { + // Simplify usage by treating primary config and fallbacks the same. + fallbackConfigs := []*FallbackConfig{ + { + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }, + } + fallbackConfigs = append(fallbackConfigs, config.Fallbacks...) + + var configs []*connectOneConfig - var lookupErrors []error + var allErrors []error - for _, fb := range fallbacks { + for _, fb := range fallbackConfigs { // skip resolve for unix sockets if isAbsolutePath(fb.Host) { - configs = append(configs, &FallbackConfig{ - Host: fb.Host, - Port: fb.Port, - TLSConfig: fb.TLSConfig, + network, address := NetworkAddress(fb.Host, fb.Port) + configs = append(configs, &connectOneConfig{ + network: network, + address: address, + originalHostname: fb.Host, + tlsConfig: fb.TLSConfig, }) continue } - ips, err := lookupFn(ctx, fb.Host) + ips, err := config.LookupFunc(ctx, fb.Host) if err != nil { - lookupErrors = append(lookupErrors, err) + allErrors = append(allErrors, err) continue } @@ -239,70 +213,139 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba if err == nil { port, err := strconv.ParseUint(splitPort, 10, 16) if err != nil { - return nil, fmt.Errorf("error parsing port (%s) from lookup: %w", splitPort, err) + return nil, []error{fmt.Errorf("error parsing port (%s) from lookup: %w", splitPort, err)} } - configs = append(configs, &FallbackConfig{ - Host: splitIP, - Port: uint16(port), - TLSConfig: fb.TLSConfig, + network, address := NetworkAddress(splitIP, uint16(port)) + configs = append(configs, &connectOneConfig{ + network: network, + address: address, + originalHostname: fb.Host, + tlsConfig: fb.TLSConfig, }) } else { - configs = append(configs, &FallbackConfig{ - Host: ip, - Port: fb.Port, - TLSConfig: fb.TLSConfig, + network, address := NetworkAddress(ip, fb.Port) + configs = append(configs, &connectOneConfig{ + network: network, + address: address, + originalHostname: fb.Host, + tlsConfig: fb.TLSConfig, }) } } } - // See https://github.com/jackc/pgx/issues/1464. When Go 1.20 can be used in pgx consider using errors.Join so all - // errors are reported. - if len(configs) == 0 && len(lookupErrors) > 0 { - return nil, lookupErrors[0] + return configs, allErrors +} + +// connectPreferred attempts to connect to the preferred host from connectOneConfigs. The connections are attempted in +// order. If a connection is successful it is returned. If no connection is successful then all errors are returned. If +// a connection attempt returns a [NotPreferredError], then that host will be used if no other hosts are successful. +func connectPreferred(ctx context.Context, config *Config, connectOneConfigs []*connectOneConfig) (*PgConn, []error) { + octx := ctx + var allErrors []error + + var fallbackConnectOneConfig *connectOneConfig + for i, c := range connectOneConfigs { + // ConnectTimeout restricts the whole connection process. + if config.ConnectTimeout != 0 { + // create new context first time or when previous host was different + if i == 0 || (connectOneConfigs[i].address != connectOneConfigs[i-1].address) { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(octx, config.ConnectTimeout) + defer cancel() + } + } else { + ctx = octx + } + + pgConn, err := connectOne(ctx, config, c, false) + if pgConn != nil { + return pgConn, nil + } + + allErrors = append(allErrors, err) + + var pgErr *PgError + if errors.As(err, &pgErr) { + const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password + const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings + const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist + const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege + if pgErr.Code == ERRCODE_INVALID_PASSWORD || + pgErr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION && c.tlsConfig != nil || + pgErr.Code == ERRCODE_INVALID_CATALOG_NAME || + pgErr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE { + return nil, allErrors + } + } + + var npErr *NotPreferredError + if errors.As(err, &npErr) { + fallbackConnectOneConfig = c + } + } + + if fallbackConnectOneConfig != nil { + pgConn, err := connectOne(ctx, config, fallbackConnectOneConfig, true) + if err == nil { + return pgConn, nil + } + allErrors = append(allErrors, err) } - return configs, nil + return nil, allErrors } -func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig, +// connectOne makes one connection attempt to a single host. +func connectOne(ctx context.Context, config *Config, connectConfig *connectOneConfig, ignoreNotPreferredErr bool, ) (*PgConn, error) { pgConn := new(PgConn) pgConn.config = config pgConn.cleanupDone = make(chan struct{}) + pgConn.customData = make(map[string]any) var err error - network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) - netConn, err := config.DialFunc(ctx, network, address) - if err != nil { - return nil, &connectError{config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)} + + newPerDialConnectError := func(msg string, err error) *perDialConnectError { + err = normalizeTimeoutError(ctx, err) + e := &perDialConnectError{address: connectConfig.address, originalHostname: connectConfig.originalHostname, err: fmt.Errorf("%s: %w", msg, err)} + return e } - pgConn.conn = netConn - pgConn.contextWatcher = newContextWatcher(netConn) - pgConn.contextWatcher.Watch(ctx) + pgConn.conn, err = config.DialFunc(ctx, connectConfig.network, connectConfig.address) + if err != nil { + return nil, newPerDialConnectError("dial error", err) + } - if fallbackConfig.TLSConfig != nil { - nbTLSConn, err := startTLS(netConn, fallbackConfig.TLSConfig) + if connectConfig.tlsConfig != nil { + pgConn.contextWatcher = ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: pgConn.conn}) + pgConn.contextWatcher.Watch(ctx) + tlsConn, err := startTLS(pgConn.conn, connectConfig.tlsConfig) pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS. if err != nil { - netConn.Close() - return nil, &connectError{config: config, msg: "tls error", err: err} + pgConn.conn.Close() + return nil, newPerDialConnectError("tls error", err) } - pgConn.conn = nbTLSConn - pgConn.contextWatcher = newContextWatcher(nbTLSConn) - pgConn.contextWatcher.Watch(ctx) + pgConn.conn = tlsConn } + pgConn.contextWatcher = ctxwatch.NewContextWatcher(config.BuildContextWatcherHandler(pgConn)) + pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() pgConn.parameterStatuses = make(map[string]string) pgConn.status = connStatusConnecting pgConn.bgReader = bgreader.New(pgConn.conn) - pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start) + pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), + func() { + pgConn.bgReader.Start() + pgConn.bgReaderStarted <- struct{}{} + }, + ) pgConn.slowWriteTimer.Stop() + pgConn.bgReaderStarted = make(chan struct{}) pgConn.frontend = config.BuildFrontend(pgConn.bgReader, pgConn.conn) startupMsg := pgproto3.StartupMessage{ @@ -323,7 +366,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.frontend.Send(&startupMsg) if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil { pgConn.conn.Close() - return nil, &connectError{config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)} + return nil, newPerDialConnectError("failed to write startup message", err) } for { @@ -331,9 +374,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if err != nil { pgConn.conn.Close() if err, ok := err.(*PgError); ok { - return nil, err + return nil, newPerDialConnectError("server error", err) } - return nil, &connectError{config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)} + return nil, newPerDialConnectError("failed to receive message", err) } switch msg := msg.(type) { @@ -346,26 +389,26 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig err = pgConn.txPasswordMessage(pgConn.config.Password) if err != nil { pgConn.conn.Close() - return nil, &connectError{config: config, msg: "failed to write password message", err: err} + return nil, newPerDialConnectError("failed to write password message", err) } case *pgproto3.AuthenticationMD5Password: digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:])) err = pgConn.txPasswordMessage(digestedPassword) if err != nil { pgConn.conn.Close() - return nil, &connectError{config: config, msg: "failed to write password message", err: err} + return nil, newPerDialConnectError("failed to write password message", err) } case *pgproto3.AuthenticationSASL: err = pgConn.scramAuth(msg.AuthMechanisms) if err != nil { pgConn.conn.Close() - return nil, &connectError{config: config, msg: "failed SASL auth", err: err} + return nil, newPerDialConnectError("failed SASL auth", err) } case *pgproto3.AuthenticationGSS: err = pgConn.gssAuth() if err != nil { pgConn.conn.Close() - return nil, &connectError{config: config, msg: "failed GSS auth", err: err} + return nil, newPerDialConnectError("failed GSS auth", err) } case *pgproto3.ReadyForQuery: pgConn.status = connStatusIdle @@ -383,7 +426,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig return pgConn, nil } pgConn.conn.Close() - return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err} + return nil, newPerDialConnectError("ValidateConnect failed", err) } } return pgConn, nil @@ -391,21 +434,14 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig // handled by ReceiveMessage case *pgproto3.ErrorResponse: pgConn.conn.Close() - return nil, ErrorResponseToPgError(msg) + return nil, newPerDialConnectError("server error", ErrorResponseToPgError(msg)) default: pgConn.conn.Close() - return nil, &connectError{config: config, msg: "received unexpected message", err: err} + return nil, newPerDialConnectError("received unexpected message", err) } } } -func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher { - return ctxwatch.NewContextWatcher( - func() { conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, - func() { conn.SetDeadline(time.Time{}) }, - ) -} - func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103}) if err != nil { @@ -540,11 +576,12 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { case *pgproto3.ParameterStatus: pgConn.parameterStatuses[msg.Name] = msg.Value case *pgproto3.ErrorResponse: - if msg.Severity == "FATAL" { + err := ErrorResponseToPgError(msg) + if pgConn.config.OnPgError != nil && !pgConn.config.OnPgError(pgConn, err) { pgConn.status = connStatusClosed pgConn.conn.Close() // Ignore error as the connection is already broken and there is already an error to return. close(pgConn.cleanupDone) - return nil, ErrorResponseToPgError(msg) + return nil, err } case *pgproto3.NoticeResponse: if pgConn.config.OnNotice != nil { @@ -593,7 +630,7 @@ func (pgConn *PgConn) Frontend() *pgproto3.Frontend { return pgConn.frontend } -// Close closes a connection. It is safe to call Close on a already closed connection. Close attempts a clean close by +// Close closes a connection. It is safe to call Close on an already closed connection. Close attempts a clean close by // sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The // underlying net.Conn.Close() will always be called regardless of any other errors. func (pgConn *PgConn) Close(ctx context.Context) error { @@ -806,6 +843,9 @@ type StatementDescription struct { // Prepare creates a prepared statement. If the name is empty, the anonymous prepared statement will be used. This // allows Prepare to also to describe statements without creating a server-side prepared statement. +// +// Prepare does not send a PREPARE statement to the server. It uses the PostgreSQL Parse and Describe protocol messages +// directly. func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) { if err := pgConn.lock(); err != nil { return nil, err @@ -862,26 +902,73 @@ readloop: return psd, nil } +// Deallocate deallocates a prepared statement. +// +// Deallocate does not send a DEALLOCATE statement to the server. It uses the PostgreSQL Close protocol message +// directly. This has slightly different behavior than executing DEALLOCATE statement. +// - Deallocate can succeed in an aborted transaction. +// - Deallocating a non-existent prepared statement is not an error. +func (pgConn *PgConn) Deallocate(ctx context.Context, name string) error { + if err := pgConn.lock(); err != nil { + return err + } + defer pgConn.unlock() + + if ctx != context.Background() { + select { + case <-ctx.Done(): + return newContextAlreadyDoneError(ctx) + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } + + pgConn.frontend.SendClose(&pgproto3.Close{ObjectType: 'S', Name: name}) + pgConn.frontend.SendSync(&pgproto3.Sync{}) + err := pgConn.flushWithPotentialWriteReadDeadlock() + if err != nil { + pgConn.asyncClose() + return err + } + + for { + msg, err := pgConn.receiveMessage() + if err != nil { + pgConn.asyncClose() + return normalizeTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.ErrorResponse: + return ErrorResponseToPgError(msg) + case *pgproto3.ReadyForQuery: + return nil + } + } +} + // ErrorResponseToPgError converts a wire protocol error message to a *PgError. func ErrorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { return &PgError{ - Severity: msg.Severity, - Code: string(msg.Code), - Message: string(msg.Message), - Detail: string(msg.Detail), - Hint: msg.Hint, - Position: msg.Position, - InternalPosition: msg.InternalPosition, - InternalQuery: string(msg.InternalQuery), - Where: string(msg.Where), - SchemaName: string(msg.SchemaName), - TableName: string(msg.TableName), - ColumnName: string(msg.ColumnName), - DataTypeName: string(msg.DataTypeName), - ConstraintName: msg.ConstraintName, - File: string(msg.File), - Line: msg.Line, - Routine: string(msg.Routine), + Severity: msg.Severity, + SeverityUnlocalized: msg.SeverityUnlocalized, + Code: string(msg.Code), + Message: string(msg.Message), + Detail: string(msg.Detail), + Hint: msg.Hint, + Position: msg.Position, + InternalPosition: msg.InternalPosition, + InternalQuery: string(msg.InternalQuery), + Where: string(msg.Where), + SchemaName: string(msg.SchemaName), + TableName: string(msg.TableName), + ColumnName: string(msg.ColumnName), + DataTypeName: string(msg.DataTypeName), + ConstraintName: msg.ConstraintName, + File: string(msg.File), + Line: msg.Line, + Routine: string(msg.Routine), } } @@ -924,10 +1011,7 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { defer cancelConn.Close() if ctx != context.Background() { - contextWatcher := ctxwatch.NewContextWatcher( - func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, - func() { cancelConn.SetDeadline(time.Time{}) }, - ) + contextWatcher := ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: cancelConn}) contextWatcher.Watch(ctx) defer contextWatcher.Unwatch() } @@ -935,16 +1019,21 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { buf := make([]byte, 16) binary.BigEndian.PutUint32(buf[0:4], 16) binary.BigEndian.PutUint32(buf[4:8], 80877102) - binary.BigEndian.PutUint32(buf[8:12], uint32(pgConn.pid)) - binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey)) - // Postgres will process the request and close the connection - // so when don't need to read the reply - // https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.6.7.10 - _, err = cancelConn.Write(buf) - return err + binary.BigEndian.PutUint32(buf[8:12], pgConn.pid) + binary.BigEndian.PutUint32(buf[12:16], pgConn.secretKey) + + if _, err := cancelConn.Write(buf); err != nil { + return fmt.Errorf("write to connection for cancellation: %w", err) + } + + // Wait for the cancel request to be acknowledged by the server. + // It copies the behavior of the libpq: https://github.com/postgres/postgres/blob/REL_16_0/src/interfaces/libpq/fe-connect.c#L4946-L4960 + _, _ = cancelConn.Read(buf) + + return nil } -// WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not +// WaitForNotification waits for a LISTEN/NOTIFY message to be received. It returns an error if a notification was not // received. func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { if err := pgConn.lock(); err != nil { @@ -1455,8 +1544,10 @@ func (rr *ResultReader) Read() *Result { values := rr.Values() row := make([][]byte, len(values)) for i := range row { - row[i] = make([]byte, len(values[i])) - copy(row[i], values[i]) + if values[i] != nil { + row[i] = make([]byte, len(values[i])) + copy(row[i], values[i]) + } } br.Rows = append(br.Rows, row) } @@ -1606,25 +1697,55 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { // Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip. type Batch struct { buf []byte + err error } // ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions. func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { - batch.buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf) + if batch.err != nil { + return + } + + batch.buf, batch.err = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf) + if batch.err != nil { + return + } batch.ExecPrepared("", paramValues, paramFormats, resultFormats) } // ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions. func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { - batch.buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf) - batch.buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf) - batch.buf = (&pgproto3.Execute{}).Encode(batch.buf) + if batch.err != nil { + return + } + + batch.buf, batch.err = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf) + if batch.err != nil { + return + } + + batch.buf, batch.err = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf) + if batch.err != nil { + return + } + + batch.buf, batch.err = (&pgproto3.Execute{}).Encode(batch.buf) + if batch.err != nil { + return + } } // ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a // transaction is already in progress or SQL contains transaction control statements. This is a simpler way of executing // multiple queries in a single round trip than using pipeline mode. func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader { + if batch.err != nil { + return &MultiResultReader{ + closed: true, + err: batch.err, + } + } + if err := pgConn.lock(); err != nil { return &MultiResultReader{ closed: true, @@ -1650,7 +1771,13 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR pgConn.contextWatcher.Watch(ctx) } - batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) + batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf) + if batch.err != nil { + multiResult.closed = true + multiResult.err = batch.err + pgConn.unlock() + return multiResult + } pgConn.enterPotentialWriteReadDeadlock() defer pgConn.exitPotentialWriteReadDeadlock() @@ -1732,10 +1859,16 @@ func (pgConn *PgConn) enterPotentialWriteReadDeadlock() { // exitPotentialWriteReadDeadlock must be called after a call to enterPotentialWriteReadDeadlock. func (pgConn *PgConn) exitPotentialWriteReadDeadlock() { - // The state of the timer is not relevant upon exiting the potential slow write. It may both - // fire (due to a slow write), or not fire (due to a fast write). - _ = pgConn.slowWriteTimer.Stop() - pgConn.bgReader.Stop() + if !pgConn.slowWriteTimer.Stop() { + // The timer starts its function in a separate goroutine. It is necessary to ensure the background reader has + // started before calling Stop. Otherwise, the background reader may not be stopped. That on its own is not a + // serious problem. But what is a serious problem is that the background reader may start at an inopportune time in + // a subsequent query. For example, if a subsequent query was canceled then a deadline may be set on the net.Conn to + // interrupt an in-progress read. After the read is interrupted, but before the deadline is cleared, the background + // reader could start and read a deadline error. Then the next query would receive the an unexpected deadline error. + <-pgConn.bgReaderStarted + pgConn.bgReader.Stop() + } } func (pgConn *PgConn) flushWithPotentialWriteReadDeadlock() error { @@ -1764,11 +1897,16 @@ func (pgConn *PgConn) SyncConn(ctx context.Context) error { } } - // This should never happen. Only way I can imagine this occuring is if the server is constantly sending data such as + // This should never happen. Only way I can imagine this occurring is if the server is constantly sending data such as // LISTEN/NOTIFY or log notifications such that we never can get an empty buffer. return errors.New("SyncConn: conn never synchronized") } +// CustomData returns a map that can be used to associate custom data with the connection. +func (pgConn *PgConn) CustomData() map[string]any { + return pgConn.customData +} + // HijackedConn is the result of hijacking a connection. // // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning @@ -1781,6 +1919,7 @@ type HijackedConn struct { TxStatus byte Frontend *pgproto3.Frontend Config *Config + CustomData map[string]any } // Hijack extracts the internal connection data. pgConn must be in an idle state. SyncConn should be called immediately @@ -1803,6 +1942,7 @@ func (pgConn *PgConn) Hijack() (*HijackedConn, error) { TxStatus: pgConn.txStatus, Frontend: pgConn.frontend, Config: pgConn.config, + CustomData: pgConn.customData, }, nil } @@ -1822,16 +1962,23 @@ func Construct(hc *HijackedConn) (*PgConn, error) { txStatus: hc.TxStatus, frontend: hc.Frontend, config: hc.Config, + customData: hc.CustomData, status: connStatusIdle, cleanupDone: make(chan struct{}), } - pgConn.contextWatcher = newContextWatcher(pgConn.conn) + pgConn.contextWatcher = ctxwatch.NewContextWatcher(hc.Config.BuildContextWatcherHandler(pgConn)) pgConn.bgReader = bgreader.New(pgConn.conn) - pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start) + pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), + func() { + pgConn.bgReader.Start() + pgConn.bgReaderStarted <- struct{}{} + }, + ) pgConn.slowWriteTimer.Stop() + pgConn.bgReaderStarted = make(chan struct{}) pgConn.frontend = hc.Config.BuildFrontend(pgConn.bgReader, pgConn.conn) return pgConn, nil @@ -1973,6 +2120,13 @@ func (p *Pipeline) Flush() error { // Sync establishes a synchronization point and flushes the queued requests. func (p *Pipeline) Sync() error { + if p.closed { + if p.err != nil { + return p.err + } + return errors.New("pipeline closed") + } + p.conn.frontend.SendSync(&pgproto3.Sync{}) err := p.Flush() if err != nil { @@ -1989,14 +2143,28 @@ func (p *Pipeline) Sync() error { // *PipelineSync. If an ErrorResponse is received from the server, results will be nil and err will be a *PgError. If no // results are available, results and err will both be nil. func (p *Pipeline) GetResults() (results any, err error) { + if p.closed { + if p.err != nil { + return nil, p.err + } + return nil, errors.New("pipeline closed") + } + if p.expectedReadyForQueryCount == 0 { return nil, nil } + return p.getResults() +} + +func (p *Pipeline) getResults() (results any, err error) { for { msg, err := p.conn.receiveMessage() if err != nil { - return nil, err + p.closed = true + p.err = err + p.conn.asyncClose() + return nil, normalizeTimeoutError(p.ctx, err) } switch msg := msg.(type) { @@ -2018,7 +2186,8 @@ func (p *Pipeline) GetResults() (results any, err error) { case *pgproto3.ParseComplete: peekedMsg, err := p.conn.peekMessage() if err != nil { - return nil, err + p.conn.asyncClose() + return nil, normalizeTimeoutError(p.ctx, err) } if _, ok := peekedMsg.(*pgproto3.ParameterDescription); ok { return p.getResultsPrepare() @@ -2078,6 +2247,7 @@ func (p *Pipeline) Close() error { if p.closed { return p.err } + p.closed = true if p.pendingSync { @@ -2090,7 +2260,7 @@ func (p *Pipeline) Close() error { } for p.expectedReadyForQueryCount > 0 { - _, err := p.GetResults() + _, err := p.getResults() if err != nil { p.err = err var pgErr *PgError @@ -2106,3 +2276,71 @@ func (p *Pipeline) Close() error { return p.err } + +// DeadlineContextWatcherHandler handles canceled contexts by setting a deadline on a net.Conn. +type DeadlineContextWatcherHandler struct { + Conn net.Conn + + // DeadlineDelay is the delay to set on the deadline set on net.Conn when the context is canceled. + DeadlineDelay time.Duration +} + +func (h *DeadlineContextWatcherHandler) HandleCancel(ctx context.Context) { + h.Conn.SetDeadline(time.Now().Add(h.DeadlineDelay)) +} + +func (h *DeadlineContextWatcherHandler) HandleUnwatchAfterCancel() { + h.Conn.SetDeadline(time.Time{}) +} + +// CancelRequestContextWatcherHandler handles canceled contexts by sending a cancel request to the server. It also sets +// a deadline on a net.Conn as a fallback. +type CancelRequestContextWatcherHandler struct { + Conn *PgConn + + // CancelRequestDelay is the delay before sending the cancel request to the server. + CancelRequestDelay time.Duration + + // DeadlineDelay is the delay to set on the deadline set on net.Conn when the context is canceled. + DeadlineDelay time.Duration + + cancelFinishedChan chan struct{} + handleUnwatchAfterCancelCalled func() +} + +func (h *CancelRequestContextWatcherHandler) HandleCancel(context.Context) { + h.cancelFinishedChan = make(chan struct{}) + var handleUnwatchedAfterCancelCalledCtx context.Context + handleUnwatchedAfterCancelCalledCtx, h.handleUnwatchAfterCancelCalled = context.WithCancel(context.Background()) + + deadline := time.Now().Add(h.DeadlineDelay) + h.Conn.conn.SetDeadline(deadline) + + go func() { + defer close(h.cancelFinishedChan) + + select { + case <-handleUnwatchedAfterCancelCalledCtx.Done(): + return + case <-time.After(h.CancelRequestDelay): + } + + cancelRequestCtx, cancel := context.WithDeadline(handleUnwatchedAfterCancelCalledCtx, deadline) + defer cancel() + h.Conn.CancelRequest(cancelRequestCtx) + + // CancelRequest is inherently racy. Even though the cancel request has been received by the server at this point, + // it hasn't necessarily been delivered to the other connection. If we immediately return and the connection is + // immediately used then it is possible the CancelRequest will actually cancel our next query. The + // TestCancelRequestContextWatcherHandler Stress test can produce this error without the sleep below. The sleep time + // is arbitrary, but should be sufficient to prevent this error case. + time.Sleep(100 * time.Millisecond) + }() +} + +func (h *CancelRequestContextWatcherHandler) HandleUnwatchAfterCancel() { + h.handleUnwatchAfterCancelCalled() + <-h.cancelFinishedChan + + h.Conn.conn.SetDeadline(time.Time{}) +} diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/README.md b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/README.md index 79d3a68b..7a26f1cb 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/README.md +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/README.md @@ -1,6 +1,6 @@ # pgproto3 -Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3. +Package pgproto3 is an encoder and decoder of the PostgreSQL wire protocol version 3. pgproto3 can be used as a foundation for PostgreSQL drivers, proxies, mock servers, load balancers and more. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_cleartext_password.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_cleartext_password.go index d8f98b9a..ac2962e9 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_cleartext_password.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_cleartext_password.go @@ -35,11 +35,10 @@ func (dst *AuthenticationCleartextPassword) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte { - dst = append(dst, 'R') - dst = pgio.AppendInt32(dst, 8) +func (src *AuthenticationCleartextPassword) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword) - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_gss.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_gss.go index 0d234222..178ef31d 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_gss.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_gss.go @@ -27,11 +27,10 @@ func (a *AuthenticationGSS) Decode(src []byte) error { return nil } -func (a *AuthenticationGSS) Encode(dst []byte) []byte { - dst = append(dst, 'R') - dst = pgio.AppendInt32(dst, 4) +func (a *AuthenticationGSS) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeGSS) - return dst + return finishMessage(dst, sp) } func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) { diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_gss_continue.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_gss_continue.go index 63789dc1..2ba3f3b3 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_gss_continue.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_gss_continue.go @@ -31,12 +31,11 @@ func (a *AuthenticationGSSContinue) Decode(src []byte) error { return nil } -func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte { - dst = append(dst, 'R') - dst = pgio.AppendInt32(dst, int32(len(a.Data))+8) +func (a *AuthenticationGSSContinue) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeGSSCont) dst = append(dst, a.Data...) - return dst + return finishMessage(dst, sp) } func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) { diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_md5_password.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_md5_password.go index 5671c84c..854c6404 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_md5_password.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_md5_password.go @@ -38,12 +38,11 @@ func (dst *AuthenticationMD5Password) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *AuthenticationMD5Password) Encode(dst []byte) []byte { - dst = append(dst, 'R') - dst = pgio.AppendInt32(dst, 12) +func (src *AuthenticationMD5Password) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeMD5Password) dst = append(dst, src.Salt[:]...) - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_ok.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_ok.go index 88d648ae..ec11d39f 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_ok.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_ok.go @@ -35,11 +35,10 @@ func (dst *AuthenticationOk) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *AuthenticationOk) Encode(dst []byte) []byte { - dst = append(dst, 'R') - dst = pgio.AppendInt32(dst, 8) +func (src *AuthenticationOk) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeOk) - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_sasl.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_sasl.go index 59650d4c..e66580f4 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_sasl.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_sasl.go @@ -47,10 +47,8 @@ func (dst *AuthenticationSASL) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *AuthenticationSASL) Encode(dst []byte) []byte { - dst = append(dst, 'R') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *AuthenticationSASL) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeSASL) for _, s := range src.AuthMechanisms { @@ -59,9 +57,7 @@ func (src *AuthenticationSASL) Encode(dst []byte) []byte { } dst = append(dst, 0) - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_sasl_continue.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_sasl_continue.go index 2ce70a47..70fba4a6 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_sasl_continue.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_sasl_continue.go @@ -38,17 +38,11 @@ func (dst *AuthenticationSASLContinue) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte { - dst = append(dst, 'R') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *AuthenticationSASLContinue) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeSASLContinue) - dst = append(dst, src.Data...) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_sasl_final.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_sasl_final.go index a38a8b91..84976c2a 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_sasl_final.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_sasl_final.go @@ -38,17 +38,11 @@ func (dst *AuthenticationSASLFinal) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte { - dst = append(dst, 'R') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *AuthenticationSASLFinal) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeSASLFinal) - dst = append(dst, src.Data...) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Unmarshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/backend.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/backend.go index 6db77e4a..d146c338 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/backend.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/backend.go @@ -16,7 +16,8 @@ type Backend struct { // before it is actually transmitted (i.e. before Flush). tracer *tracer - wbuf []byte + wbuf []byte + encodeError error // Frontend message flyweights bind Bind @@ -38,6 +39,7 @@ type Backend struct { terminate Terminate bodyLen int + maxBodyLen int // maxBodyLen is the maximum length of a message body in octets. If a message body exceeds this length, Receive will return an error. msgType byte partialMsg bool authType uint32 @@ -54,11 +56,21 @@ func NewBackend(r io.Reader, w io.Writer) *Backend { return &Backend{cr: cr, w: w} } -// Send sends a message to the frontend (i.e. the client). The message is not guaranteed to be written until Flush is -// called. +// Send sends a message to the frontend (i.e. the client). The message is buffered until Flush is called. Any error +// encountered will be returned from Flush. func (b *Backend) Send(msg BackendMessage) { + if b.encodeError != nil { + return + } + prevLen := len(b.wbuf) - b.wbuf = msg.Encode(b.wbuf) + newBuf, err := msg.Encode(b.wbuf) + if err != nil { + b.encodeError = err + return + } + b.wbuf = newBuf + if b.tracer != nil { b.tracer.traceMessage('B', int32(len(b.wbuf)-prevLen), msg) } @@ -66,6 +78,12 @@ func (b *Backend) Send(msg BackendMessage) { // Flush writes any pending messages to the frontend (i.e. the client). func (b *Backend) Flush() error { + if err := b.encodeError; err != nil { + b.encodeError = nil + b.wbuf = b.wbuf[:0] + return &writeError{err: err, safeToRetry: true} + } + n, err := b.w.Write(b.wbuf) const maxLen = 1024 @@ -158,6 +176,9 @@ func (b *Backend) Receive() (FrontendMessage, error) { b.msgType = header[0] b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 + if b.maxBodyLen > 0 && b.bodyLen > b.maxBodyLen { + return nil, &ExceededMaxBodyLenErr{b.maxBodyLen, b.bodyLen} + } b.partialMsg = true } @@ -260,3 +281,12 @@ func (b *Backend) SetAuthType(authType uint32) error { return nil } + +// SetMaxBodyLen sets the maximum length of a message body in octets. If a message body exceeds this length, Receive will return +// an error. This is useful for protecting against malicious clients that send large messages with the intent of +// causing memory exhaustion. +// The default value is 0. +// If maxBodyLen is 0, then no maximum is enforced. +func (b *Backend) SetMaxBodyLen(maxBodyLen int) { + b.maxBodyLen = maxBodyLen +} diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/backend_key_data.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/backend_key_data.go index 12c60817..23f5da67 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/backend_key_data.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/backend_key_data.go @@ -29,12 +29,11 @@ func (dst *BackendKeyData) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *BackendKeyData) Encode(dst []byte) []byte { - dst = append(dst, 'K') - dst = pgio.AppendUint32(dst, 12) +func (src *BackendKeyData) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'K') dst = pgio.AppendUint32(dst, src.ProcessID) dst = pgio.AppendUint32(dst, src.SecretKey) - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/bind.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/bind.go index fdd2d3b8..ad6ac48b 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/bind.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/bind.go @@ -5,7 +5,9 @@ import ( "encoding/binary" "encoding/hex" "encoding/json" + "errors" "fmt" + "math" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -108,21 +110,25 @@ func (dst *Bind) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Bind) Encode(dst []byte) []byte { - dst = append(dst, 'B') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *Bind) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'B') dst = append(dst, src.DestinationPortal...) dst = append(dst, 0) dst = append(dst, src.PreparedStatement...) dst = append(dst, 0) + if len(src.ParameterFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many parameter format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes))) for _, fc := range src.ParameterFormatCodes { dst = pgio.AppendInt16(dst, fc) } + if len(src.Parameters) > math.MaxUint16 { + return nil, errors.New("too many parameters") + } dst = pgio.AppendUint16(dst, uint16(len(src.Parameters))) for _, p := range src.Parameters { if p == nil { @@ -134,14 +140,15 @@ func (src *Bind) Encode(dst []byte) []byte { dst = append(dst, p...) } + if len(src.ResultFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many result format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes))) for _, fc := range src.ResultFormatCodes { dst = pgio.AppendInt16(dst, fc) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/bind_complete.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/bind_complete.go index 3be256c8..bacf30d8 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/bind_complete.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/bind_complete.go @@ -20,8 +20,8 @@ func (dst *BindComplete) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *BindComplete) Encode(dst []byte) []byte { - return append(dst, '2', 0, 0, 0, 4) +func (src *BindComplete) Encode(dst []byte) ([]byte, error) { + return append(dst, '2', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/cancel_request.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/cancel_request.go index 8fcf8217..6b52dd97 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/cancel_request.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/cancel_request.go @@ -36,12 +36,12 @@ func (dst *CancelRequest) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 4 byte message length. -func (src *CancelRequest) Encode(dst []byte) []byte { +func (src *CancelRequest) Encode(dst []byte) ([]byte, error) { dst = pgio.AppendInt32(dst, 16) dst = pgio.AppendInt32(dst, cancelRequestCode) dst = pgio.AppendUint32(dst, src.ProcessID) dst = pgio.AppendUint32(dst, src.SecretKey) - return dst + return dst, nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/close.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/close.go index f99b5943..0b50f27c 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/close.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/close.go @@ -4,8 +4,6 @@ import ( "bytes" "encoding/json" "errors" - - "github.com/jackc/pgx/v5/internal/pgio" ) type Close struct { @@ -37,18 +35,12 @@ func (dst *Close) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Close) Encode(dst []byte) []byte { - dst = append(dst, 'C') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *Close) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'C') dst = append(dst, src.ObjectType) dst = append(dst, src.Name...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/close_complete.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/close_complete.go index 1d7b8f08..833f7a12 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/close_complete.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/close_complete.go @@ -20,8 +20,8 @@ func (dst *CloseComplete) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CloseComplete) Encode(dst []byte) []byte { - return append(dst, '3', 0, 0, 0, 4) +func (src *CloseComplete) Encode(dst []byte) ([]byte, error) { + return append(dst, '3', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/command_complete.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/command_complete.go index 814027ca..eba70947 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/command_complete.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/command_complete.go @@ -3,8 +3,6 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/jackc/pgx/v5/internal/pgio" ) type CommandComplete struct { @@ -31,17 +29,11 @@ func (dst *CommandComplete) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CommandComplete) Encode(dst []byte) []byte { - dst = append(dst, 'C') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *CommandComplete) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'C') dst = append(dst, src.CommandTag...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/copy_both_response.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/copy_both_response.go index 8840a89e..99e1afea 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/copy_both_response.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/copy_both_response.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "encoding/json" "errors" + "math" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -44,19 +45,18 @@ func (dst *CopyBothResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CopyBothResponse) Encode(dst []byte) []byte { - dst = append(dst, 'W') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *CopyBothResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'W') dst = append(dst, src.OverallFormat) + if len(src.ColumnFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many column format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { dst = pgio.AppendUint16(dst, fc) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/copy_data.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/copy_data.go index 59e3dd94..89ecdd4d 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/copy_data.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/copy_data.go @@ -3,8 +3,6 @@ package pgproto3 import ( "encoding/hex" "encoding/json" - - "github.com/jackc/pgx/v5/internal/pgio" ) type CopyData struct { @@ -25,11 +23,10 @@ func (dst *CopyData) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CopyData) Encode(dst []byte) []byte { - dst = append(dst, 'd') - dst = pgio.AppendInt32(dst, int32(4+len(src.Data))) +func (src *CopyData) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'd') dst = append(dst, src.Data...) - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/copy_done.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/copy_done.go index 0e13282b..040814db 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/copy_done.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/copy_done.go @@ -24,8 +24,8 @@ func (dst *CopyDone) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CopyDone) Encode(dst []byte) []byte { - return append(dst, 'c', 0, 0, 0, 4) +func (src *CopyDone) Encode(dst []byte) ([]byte, error) { + return append(dst, 'c', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/copy_fail.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/copy_fail.go index 0041bbb1..72a85fd0 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/copy_fail.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/copy_fail.go @@ -3,8 +3,6 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/jackc/pgx/v5/internal/pgio" ) type CopyFail struct { @@ -28,17 +26,11 @@ func (dst *CopyFail) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CopyFail) Encode(dst []byte) []byte { - dst = append(dst, 'f') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *CopyFail) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'f') dst = append(dst, src.Message...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/copy_in_response.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/copy_in_response.go index 4584f7df..06cf99ce 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/copy_in_response.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/copy_in_response.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "encoding/json" "errors" + "math" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -44,20 +45,19 @@ func (dst *CopyInResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CopyInResponse) Encode(dst []byte) []byte { - dst = append(dst, 'G') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *CopyInResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'G') dst = append(dst, src.OverallFormat) + if len(src.ColumnFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many column format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { dst = pgio.AppendUint16(dst, fc) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/copy_out_response.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/copy_out_response.go index 3175c6a4..549e916c 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/copy_out_response.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/copy_out_response.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "encoding/json" "errors" + "math" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -43,21 +44,20 @@ func (dst *CopyOutResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CopyOutResponse) Encode(dst []byte) []byte { - dst = append(dst, 'H') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *CopyOutResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'H') dst = append(dst, src.OverallFormat) + if len(src.ColumnFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many column format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { dst = pgio.AppendUint16(dst, fc) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/data_row.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/data_row.go index 4de77977..fdfb0f7f 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/data_row.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/data_row.go @@ -4,6 +4,8 @@ import ( "encoding/binary" "encoding/hex" "encoding/json" + "errors" + "math" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -63,11 +65,12 @@ func (dst *DataRow) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *DataRow) Encode(dst []byte) []byte { - dst = append(dst, 'D') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *DataRow) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'D') + if len(src.Values) > math.MaxUint16 { + return nil, errors.New("too many values") + } dst = pgio.AppendUint16(dst, uint16(len(src.Values))) for _, v := range src.Values { if v == nil { @@ -79,9 +82,7 @@ func (src *DataRow) Encode(dst []byte) []byte { dst = append(dst, v...) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/describe.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/describe.go index f131d1f4..89feff21 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/describe.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/describe.go @@ -4,8 +4,6 @@ import ( "bytes" "encoding/json" "errors" - - "github.com/jackc/pgx/v5/internal/pgio" ) type Describe struct { @@ -37,18 +35,12 @@ func (dst *Describe) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Describe) Encode(dst []byte) []byte { - dst = append(dst, 'D') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *Describe) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'D') dst = append(dst, src.ObjectType) dst = append(dst, src.Name...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/doc.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/doc.go index e0e1cf87..0afd18e2 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/doc.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/doc.go @@ -1,7 +1,7 @@ -// Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3. +// Package pgproto3 is an encoder and decoder of the PostgreSQL wire protocol version 3. // // The primary interfaces are Frontend and Backend. They correspond to a client and server respectively. Messages are -// sent with Send (or a specialized Send variant). Messages are automatically bufferred to minimize small writes. Call +// sent with Send (or a specialized Send variant). Messages are automatically buffered to minimize small writes. Call // Flush to ensure a message has actually been sent. // // The Trace method of Frontend and Backend can be used to examine the wire-level message traffic. It outputs in a diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/empty_query_response.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/empty_query_response.go index 2b85e744..cb6cca07 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/empty_query_response.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/empty_query_response.go @@ -20,8 +20,8 @@ func (dst *EmptyQueryResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *EmptyQueryResponse) Encode(dst []byte) []byte { - return append(dst, 'I', 0, 0, 0, 4) +func (src *EmptyQueryResponse) Encode(dst []byte) ([]byte, error) { + return append(dst, 'I', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/error_response.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/error_response.go index 45c9a981..6ef9bd06 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/error_response.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/error_response.go @@ -2,7 +2,6 @@ package pgproto3 import ( "bytes" - "encoding/binary" "encoding/json" "strconv" ) @@ -111,119 +110,113 @@ func (dst *ErrorResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *ErrorResponse) Encode(dst []byte) []byte { - return append(dst, src.marshalBinary('E')...) +func (src *ErrorResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'E') + dst = src.appendFields(dst) + return finishMessage(dst, sp) } -func (src *ErrorResponse) marshalBinary(typeByte byte) []byte { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - - buf.WriteByte(typeByte) - buf.Write(bigEndian.Uint32(0)) - +func (src *ErrorResponse) appendFields(dst []byte) []byte { if src.Severity != "" { - buf.WriteByte('S') - buf.WriteString(src.Severity) - buf.WriteByte(0) + dst = append(dst, 'S') + dst = append(dst, src.Severity...) + dst = append(dst, 0) } if src.SeverityUnlocalized != "" { - buf.WriteByte('V') - buf.WriteString(src.SeverityUnlocalized) - buf.WriteByte(0) + dst = append(dst, 'V') + dst = append(dst, src.SeverityUnlocalized...) + dst = append(dst, 0) } if src.Code != "" { - buf.WriteByte('C') - buf.WriteString(src.Code) - buf.WriteByte(0) + dst = append(dst, 'C') + dst = append(dst, src.Code...) + dst = append(dst, 0) } if src.Message != "" { - buf.WriteByte('M') - buf.WriteString(src.Message) - buf.WriteByte(0) + dst = append(dst, 'M') + dst = append(dst, src.Message...) + dst = append(dst, 0) } if src.Detail != "" { - buf.WriteByte('D') - buf.WriteString(src.Detail) - buf.WriteByte(0) + dst = append(dst, 'D') + dst = append(dst, src.Detail...) + dst = append(dst, 0) } if src.Hint != "" { - buf.WriteByte('H') - buf.WriteString(src.Hint) - buf.WriteByte(0) + dst = append(dst, 'H') + dst = append(dst, src.Hint...) + dst = append(dst, 0) } if src.Position != 0 { - buf.WriteByte('P') - buf.WriteString(strconv.Itoa(int(src.Position))) - buf.WriteByte(0) + dst = append(dst, 'P') + dst = append(dst, strconv.Itoa(int(src.Position))...) + dst = append(dst, 0) } if src.InternalPosition != 0 { - buf.WriteByte('p') - buf.WriteString(strconv.Itoa(int(src.InternalPosition))) - buf.WriteByte(0) + dst = append(dst, 'p') + dst = append(dst, strconv.Itoa(int(src.InternalPosition))...) + dst = append(dst, 0) } if src.InternalQuery != "" { - buf.WriteByte('q') - buf.WriteString(src.InternalQuery) - buf.WriteByte(0) + dst = append(dst, 'q') + dst = append(dst, src.InternalQuery...) + dst = append(dst, 0) } if src.Where != "" { - buf.WriteByte('W') - buf.WriteString(src.Where) - buf.WriteByte(0) + dst = append(dst, 'W') + dst = append(dst, src.Where...) + dst = append(dst, 0) } if src.SchemaName != "" { - buf.WriteByte('s') - buf.WriteString(src.SchemaName) - buf.WriteByte(0) + dst = append(dst, 's') + dst = append(dst, src.SchemaName...) + dst = append(dst, 0) } if src.TableName != "" { - buf.WriteByte('t') - buf.WriteString(src.TableName) - buf.WriteByte(0) + dst = append(dst, 't') + dst = append(dst, src.TableName...) + dst = append(dst, 0) } if src.ColumnName != "" { - buf.WriteByte('c') - buf.WriteString(src.ColumnName) - buf.WriteByte(0) + dst = append(dst, 'c') + dst = append(dst, src.ColumnName...) + dst = append(dst, 0) } if src.DataTypeName != "" { - buf.WriteByte('d') - buf.WriteString(src.DataTypeName) - buf.WriteByte(0) + dst = append(dst, 'd') + dst = append(dst, src.DataTypeName...) + dst = append(dst, 0) } if src.ConstraintName != "" { - buf.WriteByte('n') - buf.WriteString(src.ConstraintName) - buf.WriteByte(0) + dst = append(dst, 'n') + dst = append(dst, src.ConstraintName...) + dst = append(dst, 0) } if src.File != "" { - buf.WriteByte('F') - buf.WriteString(src.File) - buf.WriteByte(0) + dst = append(dst, 'F') + dst = append(dst, src.File...) + dst = append(dst, 0) } if src.Line != 0 { - buf.WriteByte('L') - buf.WriteString(strconv.Itoa(int(src.Line))) - buf.WriteByte(0) + dst = append(dst, 'L') + dst = append(dst, strconv.Itoa(int(src.Line))...) + dst = append(dst, 0) } if src.Routine != "" { - buf.WriteByte('R') - buf.WriteString(src.Routine) - buf.WriteByte(0) + dst = append(dst, 'R') + dst = append(dst, src.Routine...) + dst = append(dst, 0) } for k, v := range src.UnknownFields { - buf.WriteByte(k) - buf.WriteString(v) - buf.WriteByte(0) + dst = append(dst, k) + dst = append(dst, v...) + dst = append(dst, 0) } - buf.WriteByte(0) - - binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + dst = append(dst, 0) - return buf.Bytes() + return dst } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/execute.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/execute.go index a5fee7cb..31bc714d 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/execute.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/execute.go @@ -36,19 +36,12 @@ func (dst *Execute) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Execute) Encode(dst []byte) []byte { - dst = append(dst, 'E') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *Execute) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'E') dst = append(dst, src.Portal...) dst = append(dst, 0) - dst = pgio.AppendUint32(dst, src.MaxRows) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/flush.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/flush.go index 2725f689..e5dc1fbb 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/flush.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/flush.go @@ -20,8 +20,8 @@ func (dst *Flush) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Flush) Encode(dst []byte) []byte { - return append(dst, 'H', 0, 0, 0, 4) +func (src *Flush) Encode(dst []byte) ([]byte, error) { + return append(dst, 'H', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/frontend.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/frontend.go index 33c3882a..b41abbe1 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/frontend.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/frontend.go @@ -18,7 +18,8 @@ type Frontend struct { // idle. Setting and unsetting tracer provides equivalent functionality to PQtrace and PQuntrace in libpq. tracer *tracer - wbuf []byte + wbuf []byte + encodeError error // Backend message flyweights authenticationOk AuthenticationOk @@ -64,16 +65,26 @@ func NewFrontend(r io.Reader, w io.Writer) *Frontend { return &Frontend{cr: cr, w: w} } -// Send sends a message to the backend (i.e. the server). The message is not guaranteed to be written until Flush is -// called. +// Send sends a message to the backend (i.e. the server). The message is buffered until Flush is called. Any error +// encountered will be returned from Flush. // // Send can work with any FrontendMessage. Some commonly used message types such as Bind have specialized send methods // such as SendBind. These methods should be preferred when the type of message is known up front (e.g. when building an // extended query protocol query) as they may be faster due to knowing the type of msg rather than it being hidden // behind an interface. func (f *Frontend) Send(msg FrontendMessage) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.traceMessage('F', int32(len(f.wbuf)-prevLen), msg) } @@ -81,6 +92,12 @@ func (f *Frontend) Send(msg FrontendMessage) { // Flush writes any pending messages to the backend (i.e. the server). func (f *Frontend) Flush() error { + if err := f.encodeError; err != nil { + f.encodeError = nil + f.wbuf = f.wbuf[:0] + return &writeError{err: err, safeToRetry: true} + } + if len(f.wbuf) == 0 { return nil } @@ -116,71 +133,141 @@ func (f *Frontend) Untrace() { f.tracer = nil } -// SendBind sends a Bind message to the backend (i.e. the server). The message is not guaranteed to be written until -// Flush is called. +// SendBind sends a Bind message to the backend (i.e. the server). The message is buffered until Flush is called. Any +// error encountered will be returned from Flush. func (f *Frontend) SendBind(msg *Bind) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.traceBind('F', int32(len(f.wbuf)-prevLen), msg) } } -// SendParse sends a Parse message to the backend (i.e. the server). The message is not guaranteed to be written until -// Flush is called. +// SendParse sends a Parse message to the backend (i.e. the server). The message is buffered until Flush is called. Any +// error encountered will be returned from Flush. func (f *Frontend) SendParse(msg *Parse) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.traceParse('F', int32(len(f.wbuf)-prevLen), msg) } } -// SendClose sends a Close message to the backend (i.e. the server). The message is not guaranteed to be written until -// Flush is called. +// SendClose sends a Close message to the backend (i.e. the server). The message is buffered until Flush is called. Any +// error encountered will be returned from Flush. func (f *Frontend) SendClose(msg *Close) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.traceClose('F', int32(len(f.wbuf)-prevLen), msg) } } -// SendDescribe sends a Describe message to the backend (i.e. the server). The message is not guaranteed to be written until -// Flush is called. +// SendDescribe sends a Describe message to the backend (i.e. the server). The message is buffered until Flush is +// called. Any error encountered will be returned from Flush. func (f *Frontend) SendDescribe(msg *Describe) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.traceDescribe('F', int32(len(f.wbuf)-prevLen), msg) } } -// SendExecute sends a Execute message to the backend (i.e. the server). The message is not guaranteed to be written until -// Flush is called. +// SendExecute sends an Execute message to the backend (i.e. the server). The message is buffered until Flush is called. +// Any error encountered will be returned from Flush. func (f *Frontend) SendExecute(msg *Execute) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.TraceQueryute('F', int32(len(f.wbuf)-prevLen), msg) } } -// SendSync sends a Sync message to the backend (i.e. the server). The message is not guaranteed to be written until -// Flush is called. +// SendSync sends a Sync message to the backend (i.e. the server). The message is buffered until Flush is called. Any +// error encountered will be returned from Flush. func (f *Frontend) SendSync(msg *Sync) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.traceSync('F', int32(len(f.wbuf)-prevLen), msg) } } -// SendQuery sends a Query message to the backend (i.e. the server). The message is not guaranteed to be written until -// Flush is called. +// SendQuery sends a Query message to the backend (i.e. the server). The message is buffered until Flush is called. Any +// error encountered will be returned from Flush. func (f *Frontend) SendQuery(msg *Query) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.traceQuery('F', int32(len(f.wbuf)-prevLen), msg) } diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/function_call.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/function_call.go index 2c4f38df..7d83579f 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/function_call.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/function_call.go @@ -2,6 +2,8 @@ package pgproto3 import ( "encoding/binary" + "errors" + "math" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -71,15 +73,21 @@ func (dst *FunctionCall) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *FunctionCall) Encode(dst []byte) []byte { - dst = append(dst, 'F') - sp := len(dst) - dst = pgio.AppendUint32(dst, 0) // Unknown length, set it at the end +func (src *FunctionCall) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'F') dst = pgio.AppendUint32(dst, src.Function) + + if len(src.ArgFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many arg format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes))) for _, argFormatCode := range src.ArgFormatCodes { dst = pgio.AppendUint16(dst, argFormatCode) } + + if len(src.Arguments) > math.MaxUint16 { + return nil, errors.New("too many arguments") + } dst = pgio.AppendUint16(dst, uint16(len(src.Arguments))) for _, argument := range src.Arguments { if argument == nil { @@ -90,6 +98,5 @@ func (src *FunctionCall) Encode(dst []byte) []byte { } } dst = pgio.AppendUint16(dst, src.ResultFormatCode) - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - return dst + return finishMessage(dst, sp) } diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/function_call_response.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/function_call_response.go index 3d3606dd..1f273495 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/function_call_response.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/function_call_response.go @@ -39,10 +39,8 @@ func (dst *FunctionCallResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *FunctionCallResponse) Encode(dst []byte) []byte { - dst = append(dst, 'V') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *FunctionCallResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'V') if src.Result == nil { dst = pgio.AppendInt32(dst, -1) @@ -51,9 +49,7 @@ func (src *FunctionCallResponse) Encode(dst []byte) []byte { dst = append(dst, src.Result...) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/gss_enc_request.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/gss_enc_request.go index 30ffc08d..70cb20cd 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/gss_enc_request.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/gss_enc_request.go @@ -31,10 +31,10 @@ func (dst *GSSEncRequest) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 4 byte message length. -func (src *GSSEncRequest) Encode(dst []byte) []byte { +func (src *GSSEncRequest) Encode(dst []byte) ([]byte, error) { dst = pgio.AppendInt32(dst, 8) dst = pgio.AppendInt32(dst, gssEncReqNumber) - return dst + return dst, nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/gss_response.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/gss_response.go index 64bfbd04..10d93775 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/gss_response.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/gss_response.go @@ -2,8 +2,6 @@ package pgproto3 import ( "encoding/json" - - "github.com/jackc/pgx/v5/internal/pgio" ) type GSSResponse struct { @@ -18,11 +16,10 @@ func (g *GSSResponse) Decode(data []byte) error { return nil } -func (g *GSSResponse) Encode(dst []byte) []byte { - dst = append(dst, 'p') - dst = pgio.AppendInt32(dst, int32(4+len(g.Data))) +func (g *GSSResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'p') dst = append(dst, g.Data...) - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/no_data.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/no_data.go index d8f85d38..cbcaad40 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/no_data.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/no_data.go @@ -20,8 +20,8 @@ func (dst *NoData) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *NoData) Encode(dst []byte) []byte { - return append(dst, 'n', 0, 0, 0, 4) +func (src *NoData) Encode(dst []byte) ([]byte, error) { + return append(dst, 'n', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/notice_response.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/notice_response.go index 4ac28a79..497aba6d 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/notice_response.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/notice_response.go @@ -12,6 +12,8 @@ func (dst *NoticeResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *NoticeResponse) Encode(dst []byte) []byte { - return append(dst, (*ErrorResponse)(src).marshalBinary('N')...) +func (src *NoticeResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'N') + dst = (*ErrorResponse)(src).appendFields(dst) + return finishMessage(dst, sp) } diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/notification_response.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/notification_response.go index 228e0dac..243b6bf7 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/notification_response.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/notification_response.go @@ -45,20 +45,14 @@ func (dst *NotificationResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *NotificationResponse) Encode(dst []byte) []byte { - dst = append(dst, 'A') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *NotificationResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'A') dst = pgio.AppendUint32(dst, src.PID) dst = append(dst, src.Channel...) dst = append(dst, 0) dst = append(dst, src.Payload...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/parameter_description.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/parameter_description.go index 374d38a3..1ef27b75 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/parameter_description.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/parameter_description.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" + "math" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -39,19 +41,18 @@ func (dst *ParameterDescription) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *ParameterDescription) Encode(dst []byte) []byte { - dst = append(dst, 't') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *ParameterDescription) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 't') + if len(src.ParameterOIDs) > math.MaxUint16 { + return nil, errors.New("too many parameter oids") + } dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) for _, oid := range src.ParameterOIDs { dst = pgio.AppendUint32(dst, oid) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/parameter_status.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/parameter_status.go index a303e453..9ee0720b 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/parameter_status.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/parameter_status.go @@ -3,8 +3,6 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/jackc/pgx/v5/internal/pgio" ) type ParameterStatus struct { @@ -37,19 +35,13 @@ func (dst *ParameterStatus) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *ParameterStatus) Encode(dst []byte) []byte { - dst = append(dst, 'S') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *ParameterStatus) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'S') dst = append(dst, src.Name...) dst = append(dst, 0) dst = append(dst, src.Value...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/parse.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/parse.go index b53200dc..6ba3486c 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/parse.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/parse.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" + "math" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -52,24 +54,23 @@ func (dst *Parse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Parse) Encode(dst []byte) []byte { - dst = append(dst, 'P') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *Parse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'P') dst = append(dst, src.Name...) dst = append(dst, 0) dst = append(dst, src.Query...) dst = append(dst, 0) + if len(src.ParameterOIDs) > math.MaxUint16 { + return nil, errors.New("too many parameter oids") + } dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) for _, oid := range src.ParameterOIDs { dst = pgio.AppendUint32(dst, oid) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/parse_complete.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/parse_complete.go index 92c9498b..cff9e27d 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/parse_complete.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/parse_complete.go @@ -20,8 +20,8 @@ func (dst *ParseComplete) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *ParseComplete) Encode(dst []byte) []byte { - return append(dst, '1', 0, 0, 0, 4) +func (src *ParseComplete) Encode(dst []byte) ([]byte, error) { + return append(dst, '1', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/password_message.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/password_message.go index 41f98692..d820d327 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/password_message.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/password_message.go @@ -3,8 +3,6 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/jackc/pgx/v5/internal/pgio" ) type PasswordMessage struct { @@ -32,14 +30,11 @@ func (dst *PasswordMessage) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *PasswordMessage) Encode(dst []byte) []byte { - dst = append(dst, 'p') - dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1)) - +func (src *PasswordMessage) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'p') dst = append(dst, src.Password...) dst = append(dst, 0) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/pgproto3.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/pgproto3.go index ef5a5489..128f97f8 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/pgproto3.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/pgproto3.go @@ -4,8 +4,14 @@ import ( "encoding/hex" "errors" "fmt" + + "github.com/jackc/pgx/v5/internal/pgio" ) +// maxMessageBodyLen is the maximum length of a message body in bytes. See PG_LARGE_MESSAGE_LIMIT in the PostgreSQL +// source. It is defined as (MaxAllocSize - 1). MaxAllocSize is defined as 0x3fffffff. +const maxMessageBodyLen = (0x3fffffff - 1) + // Message is the interface implemented by an object that can decode and encode // a particular PostgreSQL message. type Message interface { @@ -14,7 +20,7 @@ type Message interface { Decode(data []byte) error // Encode appends itself to dst and returns the new buffer. - Encode(dst []byte) []byte + Encode(dst []byte) ([]byte, error) } // FrontendMessage is a message sent by the frontend (i.e. the client). @@ -70,6 +76,15 @@ func (e *writeError) Unwrap() error { return e.err } +type ExceededMaxBodyLenErr struct { + MaxExpectedBodyLen int + ActualBodyLen int +} + +func (e *ExceededMaxBodyLenErr) Error() string { + return fmt.Sprintf("invalid body length: expected at most %d, but got %d", e.MaxExpectedBodyLen, e.ActualBodyLen) +} + // getValueFromJSON gets the value from a protocol message representation in JSON. func getValueFromJSON(v map[string]string) ([]byte, error) { if v == nil { @@ -83,3 +98,23 @@ func getValueFromJSON(v map[string]string) ([]byte, error) { } return nil, errors.New("unknown protocol representation") } + +// beginMessage begins a new message of type t. It appends the message type and a placeholder for the message length to +// dst. It returns the new buffer and the position of the message length placeholder. +func beginMessage(dst []byte, t byte) ([]byte, int) { + dst = append(dst, t) + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + return dst, sp +} + +// finishMessage finishes a message that was started with beginMessage. It computes the message length and writes it to +// dst[sp]. If the message length is too large it returns an error. Otherwise it returns the final message buffer. +func finishMessage(dst []byte, sp int) ([]byte, error) { + messageBodyLen := len(dst[sp:]) + if messageBodyLen > maxMessageBodyLen { + return nil, errors.New("message body too large") + } + pgio.SetInt32(dst[sp:], int32(messageBodyLen)) + return dst, nil +} diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/portal_suspended.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/portal_suspended.go index 1a9e7bfb..9e2f8cbc 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/portal_suspended.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/portal_suspended.go @@ -20,8 +20,8 @@ func (dst *PortalSuspended) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *PortalSuspended) Encode(dst []byte) []byte { - return append(dst, 's', 0, 0, 0, 4) +func (src *PortalSuspended) Encode(dst []byte) ([]byte, error) { + return append(dst, 's', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/query.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/query.go index e963a0ec..aebdfde8 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/query.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/query.go @@ -3,8 +3,6 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/jackc/pgx/v5/internal/pgio" ) type Query struct { @@ -28,14 +26,11 @@ func (dst *Query) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Query) Encode(dst []byte) []byte { - dst = append(dst, 'Q') - dst = pgio.AppendInt32(dst, int32(4+len(src.String)+1)) - +func (src *Query) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'Q') dst = append(dst, src.String...) dst = append(dst, 0) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/ready_for_query.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/ready_for_query.go index 67a39be3..a56af9fb 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/ready_for_query.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/ready_for_query.go @@ -25,8 +25,8 @@ func (dst *ReadyForQuery) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *ReadyForQuery) Encode(dst []byte) []byte { - return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus) +func (src *ReadyForQuery) Encode(dst []byte) ([]byte, error) { + return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/row_description.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/row_description.go index 6f6f0681..dc2a4ddf 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/row_description.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/row_description.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" + "math" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -99,11 +101,12 @@ func (dst *RowDescription) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *RowDescription) Encode(dst []byte) []byte { - dst = append(dst, 'T') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *RowDescription) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'T') + if len(src.Fields) > math.MaxUint16 { + return nil, errors.New("too many fields") + } dst = pgio.AppendUint16(dst, uint16(len(src.Fields))) for _, fd := range src.Fields { dst = append(dst, fd.Name...) @@ -117,9 +120,7 @@ func (src *RowDescription) Encode(dst []byte) []byte { dst = pgio.AppendInt16(dst, fd.Format) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/sasl_initial_response.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/sasl_initial_response.go index eeda4691..9eb1b6a4 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/sasl_initial_response.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/sasl_initial_response.go @@ -39,10 +39,8 @@ func (dst *SASLInitialResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *SASLInitialResponse) Encode(dst []byte) []byte { - dst = append(dst, 'p') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *SASLInitialResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'p') dst = append(dst, []byte(src.AuthMechanism)...) dst = append(dst, 0) @@ -50,9 +48,7 @@ func (src *SASLInitialResponse) Encode(dst []byte) []byte { dst = pgio.AppendInt32(dst, int32(len(src.Data))) dst = append(dst, src.Data...) - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/sasl_response.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/sasl_response.go index 54c3d96f..1b604c25 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/sasl_response.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/sasl_response.go @@ -3,8 +3,6 @@ package pgproto3 import ( "encoding/hex" "encoding/json" - - "github.com/jackc/pgx/v5/internal/pgio" ) type SASLResponse struct { @@ -22,13 +20,10 @@ func (dst *SASLResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *SASLResponse) Encode(dst []byte) []byte { - dst = append(dst, 'p') - dst = pgio.AppendInt32(dst, int32(4+len(src.Data))) - +func (src *SASLResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'p') dst = append(dst, src.Data...) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/ssl_request.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/ssl_request.go index 1b00c16b..b0fc2847 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/ssl_request.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/ssl_request.go @@ -31,10 +31,10 @@ func (dst *SSLRequest) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 4 byte message length. -func (src *SSLRequest) Encode(dst []byte) []byte { +func (src *SSLRequest) Encode(dst []byte) ([]byte, error) { dst = pgio.AppendInt32(dst, 8) dst = pgio.AppendInt32(dst, sslRequestNumber) - return dst + return dst, nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/startup_message.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/startup_message.go index 5c974f02..3af4587d 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/startup_message.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/startup_message.go @@ -38,14 +38,14 @@ func (dst *StartupMessage) Decode(src []byte) error { for { idx := bytes.IndexByte(src[rp:], 0) if idx < 0 { - return &invalidMessageFormatErr{messageType: "StartupMesage"} + return &invalidMessageFormatErr{messageType: "StartupMessage"} } key := string(src[rp : rp+idx]) rp += idx + 1 idx = bytes.IndexByte(src[rp:], 0) if idx < 0 { - return &invalidMessageFormatErr{messageType: "StartupMesage"} + return &invalidMessageFormatErr{messageType: "StartupMessage"} } value := string(src[rp : rp+idx]) rp += idx + 1 @@ -64,7 +64,7 @@ func (dst *StartupMessage) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *StartupMessage) Encode(dst []byte) []byte { +func (src *StartupMessage) Encode(dst []byte) ([]byte, error) { sp := len(dst) dst = pgio.AppendInt32(dst, -1) @@ -77,9 +77,7 @@ func (src *StartupMessage) Encode(dst []byte) []byte { } dst = append(dst, 0) - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/sync.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/sync.go index 5db8e07a..ea4fc959 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/sync.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/sync.go @@ -20,8 +20,8 @@ func (dst *Sync) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Sync) Encode(dst []byte) []byte { - return append(dst, 'S', 0, 0, 0, 4) +func (src *Sync) Encode(dst []byte) ([]byte, error) { + return append(dst, 'S', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/terminate.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/terminate.go index 135191ea..35a9dc83 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/terminate.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgproto3/terminate.go @@ -20,8 +20,8 @@ func (dst *Terminate) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Terminate) Encode(dst []byte) []byte { - return append(dst, 'X', 0, 0, 0, 4) +func (src *Terminate) Encode(dst []byte) ([]byte, error) { + return append(dst, 'X', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/array.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/array.go index 73761956..06b824ad 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/array.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/array.go @@ -110,7 +110,7 @@ func parseUntypedTextArray(src string) (*untypedTextArray, error) { r, _, err := buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid array: %v", err) + return nil, fmt.Errorf("invalid array: %w", err) } var explicitDimensions []ArrayDimension @@ -122,7 +122,7 @@ func parseUntypedTextArray(src string) (*untypedTextArray, error) { for { r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid array: %v", err) + return nil, fmt.Errorf("invalid array: %w", err) } if r == '=' { @@ -133,12 +133,12 @@ func parseUntypedTextArray(src string) (*untypedTextArray, error) { lower, err := arrayParseInteger(buf) if err != nil { - return nil, fmt.Errorf("invalid array: %v", err) + return nil, fmt.Errorf("invalid array: %w", err) } r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid array: %v", err) + return nil, fmt.Errorf("invalid array: %w", err) } if r != ':' { @@ -147,12 +147,12 @@ func parseUntypedTextArray(src string) (*untypedTextArray, error) { upper, err := arrayParseInteger(buf) if err != nil { - return nil, fmt.Errorf("invalid array: %v", err) + return nil, fmt.Errorf("invalid array: %w", err) } r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid array: %v", err) + return nil, fmt.Errorf("invalid array: %w", err) } if r != ']' { @@ -164,12 +164,12 @@ func parseUntypedTextArray(src string) (*untypedTextArray, error) { r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid array: %v", err) + return nil, fmt.Errorf("invalid array: %w", err) } } if r != '{' { - return nil, fmt.Errorf("invalid array, expected '{': %v", err) + return nil, fmt.Errorf("invalid array, expected '{' got %v", r) } implicitDimensions := []ArrayDimension{{LowerBound: 1, Length: 0}} @@ -178,7 +178,7 @@ func parseUntypedTextArray(src string) (*untypedTextArray, error) { for { r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid array: %v", err) + return nil, fmt.Errorf("invalid array: %w", err) } if r == '{' { @@ -195,7 +195,7 @@ func parseUntypedTextArray(src string) (*untypedTextArray, error) { for { r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid array: %v", err) + return nil, fmt.Errorf("invalid array: %w", err) } switch r { @@ -214,7 +214,7 @@ func parseUntypedTextArray(src string) (*untypedTextArray, error) { buf.UnreadRune() value, quoted, err := arrayParseValue(buf) if err != nil { - return nil, fmt.Errorf("invalid array value: %v", err) + return nil, fmt.Errorf("invalid array value: %w", err) } if currentDim == counterDim { implicitDimensions[currentDim].Length++ diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/array_codec.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/array_codec.go index c1863b32..bf5f6989 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/array_codec.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/array_codec.go @@ -6,7 +6,6 @@ import ( "fmt" "reflect" - "github.com/jackc/pgx/v5/internal/anynil" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -230,7 +229,7 @@ func (c *ArrayCodec) PlanScan(m *Map, oid uint32, format int16, target any) Scan // target / arrayScanner might be a pointer to a nil. If it is create one so we can call ScanIndexType to plan the // scan of the elements. - if anynil.Is(target) { + if isNil, _ := isNilDriverValuer(target); isNil { arrayScanner = reflect.New(reflect.TypeOf(target).Elem()).Interface().(ArraySetter) } diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/bits.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/bits.go index 30558118..e7a1d016 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/bits.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/bits.go @@ -176,8 +176,10 @@ func (scanPlanBinaryBitsToBitsScanner) Scan(src []byte, dst any) error { bitLen := int32(binary.BigEndian.Uint32(src)) rp := 4 + buf := make([]byte, len(src[rp:])) + copy(buf, src[rp:]) - return scanner.ScanBits(Bits{Bytes: src[rp:], Len: bitLen, Valid: true}) + return scanner.ScanBits(Bits{Bytes: buf, Len: bitLen, Valid: true}) } type scanPlanTextAnyToBitsScanner struct{} diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/builtin_wrappers.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/builtin_wrappers.go index 8bf367c1..b39d3fa1 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/builtin_wrappers.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/builtin_wrappers.go @@ -231,7 +231,7 @@ func (w *uint64Wrapper) ScanNumeric(v Numeric) error { bi, err := v.toBigInt() if err != nil { - return fmt.Errorf("cannot scan into *uint64: %v", err) + return fmt.Errorf("cannot scan into *uint64: %w", err) } if !bi.IsUint64() { @@ -284,7 +284,7 @@ func (w *uintWrapper) ScanNumeric(v Numeric) error { bi, err := v.toBigInt() if err != nil { - return fmt.Errorf("cannot scan into *uint: %v", err) + return fmt.Errorf("cannot scan into *uint: %w", err) } if !bi.IsUint64() { diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/date.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/date.go index 009fc0db..784b16de 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/date.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/date.go @@ -282,17 +282,17 @@ func (scanPlanTextAnyToDateScanner) Scan(src []byte, dst any) error { if match != nil { year, err := strconv.ParseInt(match[1], 10, 32) if err != nil { - return fmt.Errorf("BUG: cannot parse date that regexp matched (year): %v", err) + return fmt.Errorf("BUG: cannot parse date that regexp matched (year): %w", err) } month, err := strconv.ParseInt(match[2], 10, 32) if err != nil { - return fmt.Errorf("BUG: cannot parse date that regexp matched (month): %v", err) + return fmt.Errorf("BUG: cannot parse date that regexp matched (month): %w", err) } day, err := strconv.ParseInt(match[3], 10, 32) if err != nil { - return fmt.Errorf("BUG: cannot parse date that regexp matched (month): %v", err) + return fmt.Errorf("BUG: cannot parse date that regexp matched (month): %w", err) } // BC matched diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/doc.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/doc.go index 6612c896..d56c1dc7 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/doc.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/doc.go @@ -67,7 +67,7 @@ See example_custom_type_test.go for an example of a custom type for the PostgreS Sometimes pgx supports a PostgreSQL type such as numeric but the Go type is in an external package that does not have pgx support such as github.com/shopspring/decimal. These types can be registered with pgtype with custom conversion -logic. See https://github.com/jackc/pgx-shopspring-decimal and https://github.com/jackc/pgx-gofrs-uuid for a example +logic. See https://github.com/jackc/pgx-shopspring-decimal and https://github.com/jackc/pgx-gofrs-uuid for example integrations. New PostgreSQL Type Support @@ -139,6 +139,16 @@ Compatibility with database/sql pgtype also includes support for custom types implementing the database/sql.Scanner and database/sql/driver.Valuer interfaces. +Encoding Typed Nils + +pgtype encodes untyped and typed nils (e.g. nil and []byte(nil)) to the SQL NULL value without going through the Codec +system. This means that Codecs and other encoding logic do not have to handle nil or *T(nil). + +However, database/sql compatibility requires Value to be called on T(nil) when T implements driver.Valuer. Therefore, +driver.Valuer values are only considered NULL when *T(nil) where driver.Valuer is implemented on T not on *T. See +https://github.com/golang/go/issues/8415 and +https://github.com/golang/go/commit/0ce1d79a6a771f7449ec493b993ed2a720917870. + Child Records pgtype's support for arrays and composite records can be used to load records and their children in a single query. See @@ -149,7 +159,7 @@ Overview of Scanning Implementation The first step is to use the OID to lookup the correct Codec. If the OID is unavailable, Map will try to find the OID from previous calls of Map.RegisterDefaultPgType. The Map will call the Codec's PlanScan method to get a plan for scanning into the Go value. A Codec will support scanning into one or more Go types. Oftentime these Go types are -interfaces rather than explicit types. For example, PointCodec can use any Go type that implments the PointScanner and +interfaces rather than explicit types. For example, PointCodec can use any Go type that implements the PointScanner and PointValuer interfaces. If a Go value is not supported directly by a Codec then Map will try wrapping it with additional logic and try again. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/float4.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/float4.go index 2540f9e5..8646d9d2 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/float4.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/float4.go @@ -3,6 +3,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "encoding/json" "fmt" "math" "strconv" @@ -65,6 +66,29 @@ func (f Float4) Value() (driver.Value, error) { return float64(f.Float32), nil } +func (f Float4) MarshalJSON() ([]byte, error) { + if !f.Valid { + return []byte("null"), nil + } + return json.Marshal(f.Float32) +} + +func (f *Float4) UnmarshalJSON(b []byte) error { + var n *float32 + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + + if n == nil { + *f = Float4{} + } else { + *f = Float4{Float32: *n, Valid: true} + } + + return nil +} + type Float4Codec struct{} func (Float4Codec) FormatSupported(format int16) bool { @@ -273,12 +297,12 @@ func (c Float4Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, sr return nil, nil } - var n float64 + var n float32 err := codecScan(c, m, oid, format, src, &n) if err != nil { return nil, err } - return n, nil + return float64(n), nil } func (c Float4Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/float8.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/float8.go index 65e5d8b3..9c923c9a 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/float8.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/float8.go @@ -74,6 +74,29 @@ func (f Float8) Value() (driver.Value, error) { return f.Float64, nil } +func (f Float8) MarshalJSON() ([]byte, error) { + if !f.Valid { + return []byte("null"), nil + } + return json.Marshal(f.Float64) +} + +func (f *Float8) UnmarshalJSON(b []byte) error { + var n *float64 + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + + if n == nil { + *f = Float8{} + } else { + *f = Float8{Float64: *n, Valid: true} + } + + return nil +} + type Float8Codec struct{} func (Float8Codec) FormatSupported(format int16) bool { @@ -109,13 +132,6 @@ func (Float8Codec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod return nil } -func (f *Float8) MarshalJSON() ([]byte, error) { - if !f.Valid { - return []byte("null"), nil - } - return json.Marshal(f.Float64) -} - type encodePlanFloat8CodecBinaryFloat64 struct{} func (encodePlanFloat8CodecBinaryFloat64) Encode(value any, buf []byte) (newBuf []byte, err error) { diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/inet.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/inet.go index a85646d7..6ca10ea0 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/inet.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/inet.go @@ -156,7 +156,7 @@ func (scanPlanBinaryInetToNetipPrefixScanner) Scan(src []byte, dst any) error { } if len(src) != 8 && len(src) != 20 { - return fmt.Errorf("Received an invalid size for a inet: %d", len(src)) + return fmt.Errorf("Received an invalid size for an inet: %d", len(src)) } // ignore family diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/interval.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/interval.go index a172ecdb..06703d4d 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/interval.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/interval.go @@ -132,22 +132,31 @@ func (encodePlanIntervalCodecText) Encode(value any, buf []byte) (newBuf []byte, if interval.Days != 0 { buf = append(buf, strconv.FormatInt(int64(interval.Days), 10)...) - buf = append(buf, " day "...) + buf = append(buf, " day"...) } - absMicroseconds := interval.Microseconds - if absMicroseconds < 0 { - absMicroseconds = -absMicroseconds - buf = append(buf, '-') - } + if interval.Microseconds != 0 { + buf = append(buf, " "...) + + absMicroseconds := interval.Microseconds + if absMicroseconds < 0 { + absMicroseconds = -absMicroseconds + buf = append(buf, '-') + } + + hours := absMicroseconds / microsecondsPerHour + minutes := (absMicroseconds % microsecondsPerHour) / microsecondsPerMinute + seconds := (absMicroseconds % microsecondsPerMinute) / microsecondsPerSecond - hours := absMicroseconds / microsecondsPerHour - minutes := (absMicroseconds % microsecondsPerHour) / microsecondsPerMinute - seconds := (absMicroseconds % microsecondsPerMinute) / microsecondsPerSecond - microseconds := absMicroseconds % microsecondsPerSecond + timeStr := fmt.Sprintf("%02d:%02d:%02d", hours, minutes, seconds) + buf = append(buf, timeStr...) + + microseconds := absMicroseconds % microsecondsPerSecond + if microseconds != 0 { + buf = append(buf, fmt.Sprintf(".%06d", microseconds)...) + } + } - timeStr := fmt.Sprintf("%02d:%02d:%02d.%06d", hours, minutes, seconds, microseconds) - buf = append(buf, timeStr...) return buf, nil } @@ -179,7 +188,7 @@ func (scanPlanBinaryIntervalToIntervalScanner) Scan(src []byte, dst any) error { } if len(src) != 16 { - return fmt.Errorf("Received an invalid size for a interval: %d", len(src)) + return fmt.Errorf("Received an invalid size for an interval: %d", len(src)) } microseconds := int64(binary.BigEndian.Uint64(src)) @@ -242,21 +251,21 @@ func (scanPlanTextAnyToIntervalScanner) Scan(src []byte, dst any) error { return fmt.Errorf("bad interval minute format: %s", timeParts[1]) } - secondParts := strings.SplitN(timeParts[2], ".", 2) + sec, secFrac, secFracFound := strings.Cut(timeParts[2], ".") - seconds, err := strconv.ParseInt(secondParts[0], 10, 64) + seconds, err := strconv.ParseInt(sec, 10, 64) if err != nil { - return fmt.Errorf("bad interval second format: %s", secondParts[0]) + return fmt.Errorf("bad interval second format: %s", sec) } var uSeconds int64 - if len(secondParts) == 2 { - uSeconds, err = strconv.ParseInt(secondParts[1], 10, 64) + if secFracFound { + uSeconds, err = strconv.ParseInt(secFrac, 10, 64) if err != nil { - return fmt.Errorf("bad interval decimal format: %s", secondParts[1]) + return fmt.Errorf("bad interval decimal format: %s", secFrac) } - for i := 0; i < 6-len(secondParts[1]); i++ { + for i := 0; i < 6-len(secFrac); i++ { uSeconds *= 10 } } diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/json.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/json.go index d332dd0d..e71dcb9b 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/json.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/json.go @@ -8,35 +8,48 @@ import ( "reflect" ) -type JSONCodec struct{} +type JSONCodec struct { + Marshal func(v any) ([]byte, error) + Unmarshal func(data []byte, v any) error +} -func (JSONCodec) FormatSupported(format int16) bool { +func (*JSONCodec) FormatSupported(format int16) bool { return format == TextFormatCode || format == BinaryFormatCode } -func (JSONCodec) PreferredFormat() int16 { +func (*JSONCodec) PreferredFormat() int16 { return TextFormatCode } -func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { +func (c *JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { switch value.(type) { case string: return encodePlanJSONCodecEitherFormatString{} case []byte: return encodePlanJSONCodecEitherFormatByteSlice{} - // Must come before trying wrap encode plans because a pointer to a struct may be unwrapped to a struct that can be - // marshalled. - // - // https://github.com/jackc/pgx/issues/1681 - case json.Marshaler: - return encodePlanJSONCodecEitherFormatMarshal{} + // Handle json.RawMessage specifically because if it is run through json.Marshal it may be mutated. + // e.g. `{"foo": "bar"}` -> `{"foo":"bar"}`. + case json.RawMessage: + return encodePlanJSONCodecEitherFormatJSONRawMessage{} // Cannot rely on driver.Valuer being handled later because anything can be marshalled. // // https://github.com/jackc/pgx/issues/1430 + // + // Check for driver.Valuer must come before json.Marshaler so that it is guaranteed to beused + // when both are implemented https://github.com/jackc/pgx/issues/1805 case driver.Valuer: return &encodePlanDriverValuer{m: m, oid: oid, formatCode: format} + + // Must come before trying wrap encode plans because a pointer to a struct may be unwrapped to a struct that can be + // marshalled. + // + // https://github.com/jackc/pgx/issues/1681 + case json.Marshaler: + return &encodePlanJSONCodecEitherFormatMarshal{ + marshal: c.Marshal, + } } // Because anything can be marshalled the normal wrapping in Map.PlanScan doesn't get a chance to run. So try the @@ -53,7 +66,9 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod } } - return encodePlanJSONCodecEitherFormatMarshal{} + return &encodePlanJSONCodecEitherFormatMarshal{ + marshal: c.Marshal, + } } type encodePlanJSONCodecEitherFormatString struct{} @@ -76,10 +91,24 @@ func (encodePlanJSONCodecEitherFormatByteSlice) Encode(value any, buf []byte) (n return buf, nil } -type encodePlanJSONCodecEitherFormatMarshal struct{} +type encodePlanJSONCodecEitherFormatJSONRawMessage struct{} -func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) { - jsonBytes, err := json.Marshal(value) +func (encodePlanJSONCodecEitherFormatJSONRawMessage) Encode(value any, buf []byte) (newBuf []byte, err error) { + jsonBytes := value.(json.RawMessage) + if jsonBytes == nil { + return nil, nil + } + + buf = append(buf, jsonBytes...) + return buf, nil +} + +type encodePlanJSONCodecEitherFormatMarshal struct { + marshal func(v any) ([]byte, error) +} + +func (e *encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) { + jsonBytes, err := e.marshal(value) if err != nil { return nil, err } @@ -88,7 +117,7 @@ func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (new return buf, nil } -func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { +func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch target.(type) { case *string: return scanPlanAnyToString{} @@ -121,7 +150,9 @@ func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan return &scanPlanSQLScanner{formatCode: format} } - return scanPlanJSONToJSONUnmarshal{} + return &scanPlanJSONToJSONUnmarshal{ + unmarshal: c.Unmarshal, + } } type scanPlanAnyToString struct{} @@ -153,9 +184,11 @@ func (scanPlanJSONToBytesScanner) Scan(src []byte, dst any) error { return scanner.ScanBytes(src) } -type scanPlanJSONToJSONUnmarshal struct{} +type scanPlanJSONToJSONUnmarshal struct { + unmarshal func(data []byte, v any) error +} -func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error { +func (s *scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error { if src == nil { dstValue := reflect.ValueOf(dst) if dstValue.Kind() == reflect.Ptr { @@ -173,10 +206,10 @@ func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error { elem := reflect.ValueOf(dst).Elem() elem.Set(reflect.Zero(elem.Type())) - return json.Unmarshal(src, dst) + return s.unmarshal(src, dst) } -func (c JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c *JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } @@ -186,12 +219,12 @@ func (c JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src return dstBuf, nil } -func (c JSONCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { +func (c *JSONCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } var dst any - err := json.Unmarshal(src, &dst) + err := c.Unmarshal(src, &dst) return dst, err } diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/jsonb.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/jsonb.go index 25555e7f..4d4eb58e 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/jsonb.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/jsonb.go @@ -2,29 +2,31 @@ package pgtype import ( "database/sql/driver" - "encoding/json" "fmt" ) -type JSONBCodec struct{} +type JSONBCodec struct { + Marshal func(v any) ([]byte, error) + Unmarshal func(data []byte, v any) error +} -func (JSONBCodec) FormatSupported(format int16) bool { +func (*JSONBCodec) FormatSupported(format int16) bool { return format == TextFormatCode || format == BinaryFormatCode } -func (JSONBCodec) PreferredFormat() int16 { +func (*JSONBCodec) PreferredFormat() int16 { return TextFormatCode } -func (JSONBCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { +func (c *JSONBCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { switch format { case BinaryFormatCode: - plan := JSONCodec{}.PlanEncode(m, oid, TextFormatCode, value) + plan := (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanEncode(m, oid, TextFormatCode, value) if plan != nil { return &encodePlanJSONBCodecBinaryWrapper{textPlan: plan} } case TextFormatCode: - return JSONCodec{}.PlanEncode(m, oid, format, value) + return (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanEncode(m, oid, format, value) } return nil @@ -39,15 +41,15 @@ func (plan *encodePlanJSONBCodecBinaryWrapper) Encode(value any, buf []byte) (ne return plan.textPlan.Encode(value, buf) } -func (JSONBCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { +func (c *JSONBCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: - plan := JSONCodec{}.PlanScan(m, oid, TextFormatCode, target) + plan := (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanScan(m, oid, TextFormatCode, target) if plan != nil { return &scanPlanJSONBCodecBinaryUnwrapper{textPlan: plan} } case TextFormatCode: - return JSONCodec{}.PlanScan(m, oid, format, target) + return (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanScan(m, oid, format, target) } return nil @@ -73,7 +75,7 @@ func (plan *scanPlanJSONBCodecBinaryUnwrapper) Scan(src []byte, dst any) error { return plan.textPlan.Scan(src[1:], dst) } -func (c JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c *JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } @@ -100,7 +102,7 @@ func (c JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src } } -func (c JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { +func (c *JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } @@ -122,6 +124,6 @@ func (c JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (a } var dst any - err := json.Unmarshal(src, &dst) + err := c.Unmarshal(src, &dst) return dst, err } diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/ltree.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/ltree.go new file mode 100644 index 00000000..6af31779 --- /dev/null +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/ltree.go @@ -0,0 +1,122 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" +) + +type LtreeCodec struct{} + +func (l LtreeCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +// PreferredFormat returns the preferred format. +func (l LtreeCodec) PreferredFormat() int16 { + return TextFormatCode +} + +// PlanEncode returns an EncodePlan for encoding value into PostgreSQL format for oid and format. If no plan can be +// found then nil is returned. +func (l LtreeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case TextFormatCode: + return (TextCodec)(l).PlanEncode(m, oid, format, value) + case BinaryFormatCode: + switch value.(type) { + case string: + return encodeLtreeCodecBinaryString{} + case []byte: + return encodeLtreeCodecBinaryByteSlice{} + case TextValuer: + return encodeLtreeCodecBinaryTextValuer{} + } + } + + return nil +} + +type encodeLtreeCodecBinaryString struct{} + +func (encodeLtreeCodecBinaryString) Encode(value any, buf []byte) (newBuf []byte, err error) { + ltree := value.(string) + buf = append(buf, 1) + return append(buf, ltree...), nil +} + +type encodeLtreeCodecBinaryByteSlice struct{} + +func (encodeLtreeCodecBinaryByteSlice) Encode(value any, buf []byte) (newBuf []byte, err error) { + ltree := value.([]byte) + buf = append(buf, 1) + return append(buf, ltree...), nil +} + +type encodeLtreeCodecBinaryTextValuer struct{} + +func (encodeLtreeCodecBinaryTextValuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + t, err := value.(TextValuer).TextValue() + if err != nil { + return nil, err + } + if !t.Valid { + return nil, nil + } + + buf = append(buf, 1) + return append(buf, t.String...), nil +} + +// PlanScan returns a ScanPlan for scanning a PostgreSQL value into a destination with the same type as target. If +// no plan can be found then nil is returned. +func (l LtreeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case TextFormatCode: + return (TextCodec)(l).PlanScan(m, oid, format, target) + case BinaryFormatCode: + switch target.(type) { + case *string: + return scanPlanBinaryLtreeToString{} + case TextScanner: + return scanPlanBinaryLtreeToTextScanner{} + } + } + + return nil +} + +type scanPlanBinaryLtreeToString struct{} + +func (scanPlanBinaryLtreeToString) Scan(src []byte, target any) error { + version := src[0] + if version != 1 { + return fmt.Errorf("unsupported ltree version %d", version) + } + + p := (target).(*string) + *p = string(src[1:]) + + return nil +} + +type scanPlanBinaryLtreeToTextScanner struct{} + +func (scanPlanBinaryLtreeToTextScanner) Scan(src []byte, target any) error { + version := src[0] + if version != 1 { + return fmt.Errorf("unsupported ltree version %d", version) + } + + scanner := (target).(TextScanner) + return scanner.ScanText(Text{String: string(src[1:]), Valid: true}) +} + +// DecodeDatabaseSQLValue returns src decoded into a value compatible with the sql.Scanner interface. +func (l LtreeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return (TextCodec)(l).DecodeDatabaseSQLValue(m, oid, format, src) +} + +// DecodeValue returns src decoded into its default format. +func (l LtreeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + return (TextCodec)(l).DecodeValue(m, oid, format, src) +} diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/multirange.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/multirange.go index 34950b34..e5763788 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/multirange.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/multirange.go @@ -339,18 +339,18 @@ func parseUntypedTextMultirange(src []byte) ([]string, error) { r, _, err := buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid array: %v", err) + return nil, fmt.Errorf("invalid array: %w", err) } if r != '{' { - return nil, fmt.Errorf("invalid multirange, expected '{': %v", err) + return nil, fmt.Errorf("invalid multirange, expected '{' got %v", r) } parseValueLoop: for { r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid multirange: %v", err) + return nil, fmt.Errorf("invalid multirange: %w", err) } switch r { @@ -361,7 +361,7 @@ parseValueLoop: buf.UnreadRune() value, err := parseRange(buf) if err != nil { - return nil, fmt.Errorf("invalid multirange value: %v", err) + return nil, fmt.Errorf("invalid multirange value: %w", err) } elements = append(elements, value) } diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/numeric.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/numeric.go index 0e58fd07..4dbec786 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/numeric.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/numeric.go @@ -119,6 +119,26 @@ func (n Numeric) Int64Value() (Int8, error) { return Int8{Int64: bi.Int64(), Valid: true}, nil } +func (n *Numeric) ScanScientific(src string) error { + if !strings.ContainsAny("eE", src) { + return scanPlanTextAnyToNumericScanner{}.Scan([]byte(src), n) + } + + if bigF, ok := new(big.Float).SetString(string(src)); ok { + smallF, _ := bigF.Float64() + src = strconv.FormatFloat(smallF, 'f', -1, 64) + } + + num, exp, err := parseNumericString(src) + if err != nil { + return err + } + + *n = Numeric{Int: num, Exp: exp, Valid: true} + + return nil +} + func (n *Numeric) toBigInt() (*big.Int, error) { if n.Exp == 0 { return n.Int, nil diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go index 59d833a1..40829568 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go @@ -41,6 +41,7 @@ const ( CircleOID = 718 CircleArrayOID = 719 UnknownOID = 705 + Macaddr8OID = 774 MacaddrOID = 829 InetOID = 869 BoolArrayOID = 1000 @@ -81,6 +82,8 @@ const ( IntervalOID = 1186 IntervalArrayOID = 1187 NumericArrayOID = 1231 + TimetzOID = 1266 + TimetzArrayOID = 1270 BitOID = 1560 BitArrayOID = 1561 VarbitOID = 1562 @@ -559,7 +562,7 @@ func TryFindUnderlyingTypeScanPlan(dst any) (plan WrappedScanPlanNextSetter, nex } } - if nextDstType != nil && dstValue.Type() != nextDstType { + if nextDstType != nil && dstValue.Type() != nextDstType && dstValue.CanConvert(nextDstType) { return &underlyingTypeScanPlan{dstType: dstValue.Type(), nextDstType: nextDstType}, dstValue.Convert(nextDstType).Interface(), true } @@ -1328,7 +1331,7 @@ func (plan *derefPointerEncodePlan) Encode(value any, buf []byte) (newBuf []byte } // TryWrapDerefPointerEncodePlan tries to dereference a pointer. e.g. If value was of type *string then a wrapper plan -// would be returned that derefences the value. +// would be returned that dereferences the value. func TryWrapDerefPointerEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) { if _, ok := value.(driver.Valuer); ok { return nil, nil, false @@ -1358,6 +1361,8 @@ var kindToTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]reflect.Type{ reflect.Bool: reflect.TypeOf(false), } +var byteSliceType = reflect.TypeOf([]byte{}) + type underlyingTypeEncodePlan struct { nextValueType reflect.Type next EncodePlan @@ -1372,6 +1377,10 @@ func (plan *underlyingTypeEncodePlan) Encode(value any, buf []byte) (newBuf []by // TryWrapFindUnderlyingTypeEncodePlan tries to convert to a Go builtin type. e.g. If value was of type MyString and // MyString was defined as a string then a wrapper plan would be returned that converts MyString to string. func TryWrapFindUnderlyingTypeEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) { + if value == nil { + return nil, nil, false + } + if _, ok := value.(driver.Valuer); ok { return nil, nil, false } @@ -1387,6 +1396,15 @@ func TryWrapFindUnderlyingTypeEncodePlan(value any) (plan WrappedEncodePlanNextS return &underlyingTypeEncodePlan{nextValueType: nextValueType}, refValue.Convert(nextValueType).Interface(), true } + // []byte is a special case. It is a slice but we treat it as a scalar type. In the case of a named type like + // json.RawMessage which is defined as []byte the underlying type should be considered as []byte. But any other slice + // does not have a special underlying type. + // + // https://github.com/jackc/pgx/issues/1763 + if refValue.Type() != byteSliceType && refValue.Type().AssignableTo(byteSliceType) { + return &underlyingTypeEncodePlan{nextValueType: byteSliceType}, refValue.Convert(byteSliceType).Interface(), true + } + return nil, nil, false } @@ -1894,8 +1912,17 @@ func newEncodeError(value any, m *Map, oid uint32, formatCode int16, err error) // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data // written. func (m *Map) Encode(oid uint32, formatCode int16, value any, buf []byte) (newBuf []byte, err error) { - if value == nil { - return nil, nil + if isNil, callNilDriverValuer := isNilDriverValuer(value); isNil { + if callNilDriverValuer { + newBuf, err = (&encodePlanDriverValuer{m: m, oid: oid, formatCode: formatCode}).Encode(value, buf) + if err != nil { + return nil, newEncodeError(value, m, oid, formatCode, err) + } + + return newBuf, nil + } else { + return nil, nil + } } plan := m.PlanEncode(oid, formatCode, value) @@ -1950,3 +1977,55 @@ func (w *sqlScannerWrapper) Scan(src any) error { return w.m.Scan(t.OID, TextFormatCode, bufSrc, w.v) } + +// canBeNil returns true if value can be nil. +func canBeNil(value any) bool { + refVal := reflect.ValueOf(value) + kind := refVal.Kind() + switch kind { + case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice: + return true + default: + return false + } +} + +// valuerReflectType is a reflect.Type for driver.Valuer. It has confusing syntax because reflect.TypeOf returns nil +// when it's argument is a nil interface value. So we use a pointer to the interface and call Elem to get the actual +// type. Yuck. +// +// This can be simplified in Go 1.22 with reflect.TypeFor. +// +// var valuerReflectType = reflect.TypeFor[driver.Valuer]() +var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + +// isNilDriverValuer returns true if value is any type of nil unless it implements driver.Valuer. *T is not considered to implement +// driver.Valuer if it is only implemented by T. +func isNilDriverValuer(value any) (isNil bool, callNilDriverValuer bool) { + if value == nil { + return true, false + } + + refVal := reflect.ValueOf(value) + kind := refVal.Kind() + switch kind { + case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice: + if !refVal.IsNil() { + return false, false + } + + if _, ok := value.(driver.Valuer); ok { + if kind == reflect.Ptr { + // The type assertion will succeed if driver.Valuer is implemented on T or *T. Check if it is implemented on *T + // by checking if it is not implemented on *T. + return true, !refVal.Type().Elem().Implements(valuerReflectType) + } else { + return true, true + } + } + + return true, false + default: + return false, false + } +} diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/pgtype_default.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/pgtype_default.go index 58f4b92c..9525f37c 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/pgtype_default.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/pgtype_default.go @@ -1,6 +1,7 @@ package pgtype import ( + "encoding/json" "net" "net/netip" "reflect" @@ -64,11 +65,12 @@ func initDefaultMap() { defaultMap.RegisterType(&Type{Name: "int4", OID: Int4OID, Codec: Int4Codec{}}) defaultMap.RegisterType(&Type{Name: "int8", OID: Int8OID, Codec: Int8Codec{}}) defaultMap.RegisterType(&Type{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}}) - defaultMap.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: JSONCodec{}}) - defaultMap.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}}) + defaultMap.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: &JSONCodec{Marshal: json.Marshal, Unmarshal: json.Unmarshal}}) + defaultMap.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: &JSONBCodec{Marshal: json.Marshal, Unmarshal: json.Unmarshal}}) defaultMap.RegisterType(&Type{Name: "jsonpath", OID: JSONPathOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) defaultMap.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}}) defaultMap.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}}) + defaultMap.RegisterType(&Type{Name: "macaddr8", OID: Macaddr8OID, Codec: MacaddrCodec{}}) defaultMap.RegisterType(&Type{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}}) defaultMap.RegisterType(&Type{Name: "name", OID: NameOID, Codec: TextCodec{}}) defaultMap.RegisterType(&Type{Name: "numeric", OID: NumericOID, Codec: NumericCodec{}}) @@ -80,8 +82,8 @@ func initDefaultMap() { defaultMap.RegisterType(&Type{Name: "text", OID: TextOID, Codec: TextCodec{}}) defaultMap.RegisterType(&Type{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) defaultMap.RegisterType(&Type{Name: "time", OID: TimeOID, Codec: TimeCodec{}}) - defaultMap.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}}) - defaultMap.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: TimestamptzCodec{}}) + defaultMap.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: &TimestampCodec{}}) + defaultMap.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: &TimestamptzCodec{}}) defaultMap.RegisterType(&Type{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}}) defaultMap.RegisterType(&Type{Name: "uuid", OID: UUIDOID, Codec: UUIDCodec{}}) defaultMap.RegisterType(&Type{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}}) @@ -173,6 +175,7 @@ func initDefaultMap() { registerDefaultPgTypeVariants[time.Time](defaultMap, "timestamptz") registerDefaultPgTypeVariants[time.Duration](defaultMap, "interval") registerDefaultPgTypeVariants[string](defaultMap, "text") + registerDefaultPgTypeVariants[json.RawMessage](defaultMap, "json") registerDefaultPgTypeVariants[[]byte](defaultMap, "bytea") registerDefaultPgTypeVariants[net.IP](defaultMap, "inet") diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/point.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/point.go index b5a4320b..09b19bb5 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/point.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/point.go @@ -50,17 +50,17 @@ func parsePoint(src []byte) (*Point, error) { if src[0] == '"' && src[len(src)-1] == '"' { src = src[1 : len(src)-1] } - parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) - if len(parts) < 2 { + sx, sy, found := strings.Cut(string(src[1:len(src)-1]), ",") + if !found { return nil, fmt.Errorf("invalid format for point") } - x, err := strconv.ParseFloat(parts[0], 64) + x, err := strconv.ParseFloat(sx, 64) if err != nil { return nil, err } - y, err := strconv.ParseFloat(parts[1], 64) + y, err := strconv.ParseFloat(sy, 64) if err != nil { return nil, err } @@ -247,17 +247,17 @@ func (scanPlanTextAnyToPointScanner) Scan(src []byte, dst any) error { return fmt.Errorf("invalid length for point: %v", len(src)) } - parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) - if len(parts) < 2 { + sx, sy, found := strings.Cut(string(src[1:len(src)-1]), ",") + if !found { return fmt.Errorf("invalid format for point") } - x, err := strconv.ParseFloat(parts[0], 64) + x, err := strconv.ParseFloat(sx, 64) if err != nil { return err } - y, err := strconv.ParseFloat(parts[1], 64) + y, err := strconv.ParseFloat(sy, 64) if err != nil { return err } diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/range.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/range.go index 8f408f9f..16427ccc 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/range.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/range.go @@ -40,7 +40,7 @@ func parseUntypedTextRange(src string) (*untypedTextRange, error) { r, _, err := buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid lower bound: %v", err) + return nil, fmt.Errorf("invalid lower bound: %w", err) } switch r { case '(': @@ -53,7 +53,7 @@ func parseUntypedTextRange(src string) (*untypedTextRange, error) { r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid lower value: %v", err) + return nil, fmt.Errorf("invalid lower value: %w", err) } buf.UnreadRune() @@ -62,13 +62,13 @@ func parseUntypedTextRange(src string) (*untypedTextRange, error) { } else { utr.Lower, err = rangeParseValue(buf) if err != nil { - return nil, fmt.Errorf("invalid lower value: %v", err) + return nil, fmt.Errorf("invalid lower value: %w", err) } } r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("missing range separator: %v", err) + return nil, fmt.Errorf("missing range separator: %w", err) } if r != ',' { return nil, fmt.Errorf("missing range separator: %v", r) @@ -76,7 +76,7 @@ func parseUntypedTextRange(src string) (*untypedTextRange, error) { r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid upper value: %v", err) + return nil, fmt.Errorf("invalid upper value: %w", err) } if r == ')' || r == ']' { @@ -85,12 +85,12 @@ func parseUntypedTextRange(src string) (*untypedTextRange, error) { buf.UnreadRune() utr.Upper, err = rangeParseValue(buf) if err != nil { - return nil, fmt.Errorf("invalid upper value: %v", err) + return nil, fmt.Errorf("invalid upper value: %w", err) } r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("missing upper bound: %v", err) + return nil, fmt.Errorf("missing upper bound: %w", err) } switch r { case ')': diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/range_codec.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/range_codec.go index 8cfb3a63..684f1bf7 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/range_codec.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/range_codec.go @@ -120,7 +120,7 @@ func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value any, buf []byt buf, err = lowerPlan.Encode(lower, buf) if err != nil { - return nil, fmt.Errorf("failed to encode %v as element of range: %v", lower, err) + return nil, fmt.Errorf("failed to encode %v as element of range: %w", lower, err) } if buf == nil { return nil, fmt.Errorf("Lower cannot be NULL unless LowerType is Unbounded") @@ -144,7 +144,7 @@ func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value any, buf []byt buf, err = upperPlan.Encode(upper, buf) if err != nil { - return nil, fmt.Errorf("failed to encode %v as element of range: %v", upper, err) + return nil, fmt.Errorf("failed to encode %v as element of range: %w", upper, err) } if buf == nil { return nil, fmt.Errorf("Upper cannot be NULL unless UpperType is Unbounded") @@ -194,7 +194,7 @@ func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte) buf, err = lowerPlan.Encode(lower, buf) if err != nil { - return nil, fmt.Errorf("failed to encode %v as element of range: %v", lower, err) + return nil, fmt.Errorf("failed to encode %v as element of range: %w", lower, err) } if buf == nil { return nil, fmt.Errorf("Lower cannot be NULL unless LowerType is Unbounded") @@ -215,7 +215,7 @@ func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte) buf, err = upperPlan.Encode(upper, buf) if err != nil { - return nil, fmt.Errorf("failed to encode %v as element of range: %v", upper, err) + return nil, fmt.Errorf("failed to encode %v as element of range: %w", upper, err) } if buf == nil { return nil, fmt.Errorf("Upper cannot be NULL unless UpperType is Unbounded") @@ -282,7 +282,7 @@ func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target any) erro err = lowerPlan.Scan(ubr.Lower, lowerTarget) if err != nil { - return fmt.Errorf("cannot scan into %v from range element: %v", lowerTarget, err) + return fmt.Errorf("cannot scan into %v from range element: %w", lowerTarget, err) } } @@ -294,7 +294,7 @@ func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target any) erro err = upperPlan.Scan(ubr.Upper, upperTarget) if err != nil { - return fmt.Errorf("cannot scan into %v from range element: %v", upperTarget, err) + return fmt.Errorf("cannot scan into %v from range element: %w", upperTarget, err) } } @@ -332,7 +332,7 @@ func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target any) error err = lowerPlan.Scan([]byte(utr.Lower), lowerTarget) if err != nil { - return fmt.Errorf("cannot scan into %v from range element: %v", lowerTarget, err) + return fmt.Errorf("cannot scan into %v from range element: %w", lowerTarget, err) } } @@ -344,7 +344,7 @@ func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target any) error err = upperPlan.Scan([]byte(utr.Upper), upperTarget) if err != nil { - return fmt.Errorf("cannot scan into %v from range element: %v", upperTarget, err) + return fmt.Errorf("cannot scan into %v from range element: %w", upperTarget, err) } } diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/tid.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/tid.go index 5839e874..9bc2c2a1 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/tid.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/tid.go @@ -205,17 +205,17 @@ func (scanPlanTextAnyToTIDScanner) Scan(src []byte, dst any) error { return fmt.Errorf("invalid length for tid: %v", len(src)) } - parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) - if len(parts) < 2 { + block, offset, found := strings.Cut(string(src[1:len(src)-1]), ",") + if !found { return fmt.Errorf("invalid format for tid") } - blockNumber, err := strconv.ParseUint(parts[0], 10, 32) + blockNumber, err := strconv.ParseUint(block, 10, 32) if err != nil { return err } - offsetNumber, err := strconv.ParseUint(parts[1], 10, 16) + offsetNumber, err := strconv.ParseUint(offset, 10, 16) if err != nil { return err } diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/time.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/time.go index 2eb6ace2..61a3abdf 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/time.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/time.go @@ -45,7 +45,12 @@ func (t *Time) Scan(src any) error { switch src := src.(type) { case string: - return scanPlanTextAnyToTimeScanner{}.Scan([]byte(src), t) + err := scanPlanTextAnyToTimeScanner{}.Scan([]byte(src), t) + if err != nil { + t.Microseconds = 0 + t.Valid = false + } + return err } return fmt.Errorf("cannot scan %T", src) @@ -136,6 +141,8 @@ func (TimeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan switch target.(type) { case TimeScanner: return scanPlanBinaryTimeToTimeScanner{} + case TextScanner: + return scanPlanBinaryTimeToTextScanner{} } case TextFormatCode: switch target.(type) { @@ -165,6 +172,34 @@ func (scanPlanBinaryTimeToTimeScanner) Scan(src []byte, dst any) error { return scanner.ScanTime(Time{Microseconds: usec, Valid: true}) } +type scanPlanBinaryTimeToTextScanner struct{} + +func (scanPlanBinaryTimeToTextScanner) Scan(src []byte, dst any) error { + ts, ok := (dst).(TextScanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return ts.ScanText(Text{}) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for time: %v", len(src)) + } + + usec := int64(binary.BigEndian.Uint64(src)) + + tim := Time{Microseconds: usec, Valid: true} + + buf, err := TimeCodec{}.PlanEncode(nil, 0, TextFormatCode, tim).Encode(tim, nil) + if err != nil { + return err + } + + return ts.ScanText(Text{String: string(buf), Valid: true}) +} + type scanPlanTextAnyToTimeScanner struct{} func (scanPlanTextAnyToTimeScanner) Scan(src []byte, dst any) error { @@ -176,7 +211,7 @@ func (scanPlanTextAnyToTimeScanner) Scan(src []byte, dst any) error { s := string(src) - if len(s) < 8 { + if len(s) < 8 || s[2] != ':' || s[5] != ':' { return fmt.Errorf("cannot decode %v into Time", s) } @@ -199,6 +234,10 @@ func (scanPlanTextAnyToTimeScanner) Scan(src []byte, dst any) error { usec += seconds * microsecondsPerSecond if len(s) > 9 { + if s[8] != '.' || len(s) > 15 { + return fmt.Errorf("cannot decode %v into Time", s) + } + fraction := s[9:] n, err := strconv.ParseInt(fraction, 10, 64) if err != nil { diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go index 35d73956..677a2c6e 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go @@ -46,7 +46,7 @@ func (ts *Timestamp) Scan(src any) error { switch src := src.(type) { case string: - return scanPlanTextTimestampToTimestampScanner{}.Scan([]byte(src), ts) + return (&scanPlanTextTimestampToTimestampScanner{}).Scan([]byte(src), ts) case time.Time: *ts = Timestamp{Time: src, Valid: true} return nil @@ -116,17 +116,21 @@ func (ts *Timestamp) UnmarshalJSON(b []byte) error { return nil } -type TimestampCodec struct{} +type TimestampCodec struct { + // ScanLocation is the location that the time is assumed to be in for scanning. This is different from + // TimestamptzCodec.ScanLocation in that this setting does change the instant in time that the timestamp represents. + ScanLocation *time.Location +} -func (TimestampCodec) FormatSupported(format int16) bool { +func (*TimestampCodec) FormatSupported(format int16) bool { return format == TextFormatCode || format == BinaryFormatCode } -func (TimestampCodec) PreferredFormat() int16 { +func (*TimestampCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (TimestampCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { +func (*TimestampCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(TimestampValuer); !ok { return nil } @@ -220,27 +224,27 @@ func discardTimeZone(t time.Time) time.Time { return t } -func (TimestampCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { +func (c *TimestampCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { case TimestampScanner: - return scanPlanBinaryTimestampToTimestampScanner{} + return &scanPlanBinaryTimestampToTimestampScanner{location: c.ScanLocation} } case TextFormatCode: switch target.(type) { case TimestampScanner: - return scanPlanTextTimestampToTimestampScanner{} + return &scanPlanTextTimestampToTimestampScanner{location: c.ScanLocation} } } return nil } -type scanPlanBinaryTimestampToTimestampScanner struct{} +type scanPlanBinaryTimestampToTimestampScanner struct{ location *time.Location } -func (scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst any) error { +func (plan *scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst any) error { scanner := (dst).(TimestampScanner) if src == nil { @@ -264,15 +268,18 @@ func (scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst any) error microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000, (microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000), ).UTC() + if plan.location != nil { + tim = time.Date(tim.Year(), tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), plan.location) + } ts = Timestamp{Time: tim, Valid: true} } return scanner.ScanTimestamp(ts) } -type scanPlanTextTimestampToTimestampScanner struct{} +type scanPlanTextTimestampToTimestampScanner struct{ location *time.Location } -func (scanPlanTextTimestampToTimestampScanner) Scan(src []byte, dst any) error { +func (plan *scanPlanTextTimestampToTimestampScanner) Scan(src []byte, dst any) error { scanner := (dst).(TimestampScanner) if src == nil { @@ -302,13 +309,17 @@ func (scanPlanTextTimestampToTimestampScanner) Scan(src []byte, dst any) error { tim = time.Date(year, tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), tim.Location()) } + if plan.location != nil { + tim = time.Date(tim.Year(), tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), plan.location) + } + ts = Timestamp{Time: tim, Valid: true} } return scanner.ScanTimestamp(ts) } -func (c TimestampCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c *TimestampCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } @@ -326,7 +337,7 @@ func (c TimestampCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, return ts.Time, nil } -func (c TimestampCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { +func (c *TimestampCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/timestamptz.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/timestamptz.go index f568fe30..7efbcffd 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/timestamptz.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/timestamptz.go @@ -54,7 +54,7 @@ func (tstz *Timestamptz) Scan(src any) error { switch src := src.(type) { case string: - return scanPlanTextTimestamptzToTimestamptzScanner{}.Scan([]byte(src), tstz) + return (&scanPlanTextTimestamptzToTimestamptzScanner{}).Scan([]byte(src), tstz) case time.Time: *tstz = Timestamptz{Time: src, Valid: true} return nil @@ -124,17 +124,21 @@ func (tstz *Timestamptz) UnmarshalJSON(b []byte) error { return nil } -type TimestamptzCodec struct{} +type TimestamptzCodec struct { + // ScanLocation is the location to return scanned timestamptz values in. This does not change the instant in time that + // the timestamptz represents. + ScanLocation *time.Location +} -func (TimestamptzCodec) FormatSupported(format int16) bool { +func (*TimestamptzCodec) FormatSupported(format int16) bool { return format == TextFormatCode || format == BinaryFormatCode } -func (TimestamptzCodec) PreferredFormat() int16 { +func (*TimestamptzCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (TimestamptzCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { +func (*TimestamptzCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(TimestamptzValuer); !ok { return nil } @@ -220,27 +224,27 @@ func (encodePlanTimestamptzCodecText) Encode(value any, buf []byte) (newBuf []by return buf, nil } -func (TimestamptzCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { +func (c *TimestamptzCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { case TimestamptzScanner: - return scanPlanBinaryTimestamptzToTimestamptzScanner{} + return &scanPlanBinaryTimestamptzToTimestamptzScanner{location: c.ScanLocation} } case TextFormatCode: switch target.(type) { case TimestamptzScanner: - return scanPlanTextTimestamptzToTimestamptzScanner{} + return &scanPlanTextTimestamptzToTimestamptzScanner{location: c.ScanLocation} } } return nil } -type scanPlanBinaryTimestamptzToTimestamptzScanner struct{} +type scanPlanBinaryTimestamptzToTimestamptzScanner struct{ location *time.Location } -func (scanPlanBinaryTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) error { +func (plan *scanPlanBinaryTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) error { scanner := (dst).(TimestamptzScanner) if src == nil { @@ -264,15 +268,18 @@ func (scanPlanBinaryTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) e microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000, (microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000), ) + if plan.location != nil { + tim = tim.In(plan.location) + } tstz = Timestamptz{Time: tim, Valid: true} } return scanner.ScanTimestamptz(tstz) } -type scanPlanTextTimestamptzToTimestamptzScanner struct{} +type scanPlanTextTimestamptzToTimestamptzScanner struct{ location *time.Location } -func (scanPlanTextTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) error { +func (plan *scanPlanTextTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) error { scanner := (dst).(TimestamptzScanner) if src == nil { @@ -312,13 +319,17 @@ func (scanPlanTextTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) err tim = time.Date(year, tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), tim.Location()) } + if plan.location != nil { + tim = tim.In(plan.location) + } + tstz = Timestamptz{Time: tim, Valid: true} } return scanner.ScanTimestamptz(tstz) } -func (c TimestamptzCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c *TimestamptzCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } @@ -336,7 +347,7 @@ func (c TimestamptzCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int1 return tstz.Time, nil } -func (c TimestamptzCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { +func (c *TimestamptzCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/uuid.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/uuid.go index b59d6e76..d57c0f2f 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/uuid.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgtype/uuid.go @@ -52,7 +52,19 @@ func parseUUID(src string) (dst [16]byte, err error) { // encodeUUID converts a uuid byte array to UUID standard string form. func encodeUUID(src [16]byte) string { - return fmt.Sprintf("%x-%x-%x-%x-%x", src[0:4], src[4:6], src[6:8], src[8:10], src[10:16]) + var buf [36]byte + + hex.Encode(buf[0:8], src[:4]) + buf[8] = '-' + hex.Encode(buf[9:13], src[4:6]) + buf[13] = '-' + hex.Encode(buf[14:18], src[6:8]) + buf[18] = '-' + hex.Encode(buf[19:23], src[8:10]) + buf[23] = '-' + hex.Encode(buf[24:], src[10:]) + + return string(buf[:]) } // Scan implements the database/sql Scanner interface. diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/batch_results.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/batch_results.go new file mode 100644 index 00000000..5d5c681d --- /dev/null +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/batch_results.go @@ -0,0 +1,52 @@ +package pgxpool + +import ( + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +type errBatchResults struct { + err error +} + +func (br errBatchResults) Exec() (pgconn.CommandTag, error) { + return pgconn.CommandTag{}, br.err +} + +func (br errBatchResults) Query() (pgx.Rows, error) { + return errRows{err: br.err}, br.err +} + +func (br errBatchResults) QueryRow() pgx.Row { + return errRow{err: br.err} +} + +func (br errBatchResults) Close() error { + return br.err +} + +type poolBatchResults struct { + br pgx.BatchResults + c *Conn +} + +func (br *poolBatchResults) Exec() (pgconn.CommandTag, error) { + return br.br.Exec() +} + +func (br *poolBatchResults) Query() (pgx.Rows, error) { + return br.br.Query() +} + +func (br *poolBatchResults) QueryRow() pgx.Row { + return br.br.QueryRow() +} + +func (br *poolBatchResults) Close() error { + err := br.br.Close() + if br.c != nil { + br.c.Release() + br.c = nil + } + return err +} diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/conn.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/conn.go new file mode 100644 index 00000000..38c90f3d --- /dev/null +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/conn.go @@ -0,0 +1,134 @@ +package pgxpool + +import ( + "context" + "sync/atomic" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/puddle/v2" +) + +// Conn is an acquired *pgx.Conn from a Pool. +type Conn struct { + res *puddle.Resource[*connResource] + p *Pool +} + +// Release returns c to the pool it was acquired from. Once Release has been called, other methods must not be called. +// However, it is safe to call Release multiple times. Subsequent calls after the first will be ignored. +func (c *Conn) Release() { + if c.res == nil { + return + } + + conn := c.Conn() + res := c.res + c.res = nil + + if c.p.releaseTracer != nil { + c.p.releaseTracer.TraceRelease(c.p, TraceReleaseData{Conn: conn}) + } + + if conn.IsClosed() || conn.PgConn().IsBusy() || conn.PgConn().TxStatus() != 'I' { + res.Destroy() + // Signal to the health check to run since we just destroyed a connections + // and we might be below minConns now + c.p.triggerHealthCheck() + return + } + + // If the pool is consistently being used, we might never get to check the + // lifetime of a connection since we only check idle connections in checkConnsHealth + // so we also check the lifetime here and force a health check + if c.p.isExpired(res) { + atomic.AddInt64(&c.p.lifetimeDestroyCount, 1) + res.Destroy() + // Signal to the health check to run since we just destroyed a connections + // and we might be below minConns now + c.p.triggerHealthCheck() + return + } + + if c.p.afterRelease == nil { + res.Release() + return + } + + go func() { + if c.p.afterRelease(conn) { + res.Release() + } else { + res.Destroy() + // Signal to the health check to run since we just destroyed a connections + // and we might be below minConns now + c.p.triggerHealthCheck() + } + }() +} + +// Hijack assumes ownership of the connection from the pool. Caller is responsible for closing the connection. Hijack +// will panic if called on an already released or hijacked connection. +func (c *Conn) Hijack() *pgx.Conn { + if c.res == nil { + panic("cannot hijack already released or hijacked connection") + } + + conn := c.Conn() + res := c.res + c.res = nil + + res.Hijack() + + return conn +} + +func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { + return c.Conn().Exec(ctx, sql, arguments...) +} + +func (c *Conn) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { + return c.Conn().Query(ctx, sql, args...) +} + +func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { + return c.Conn().QueryRow(ctx, sql, args...) +} + +func (c *Conn) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { + return c.Conn().SendBatch(ctx, b) +} + +func (c *Conn) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { + return c.Conn().CopyFrom(ctx, tableName, columnNames, rowSrc) +} + +// Begin starts a transaction block from the *Conn without explicitly setting a transaction mode (see BeginTx with TxOptions if transaction mode is required). +func (c *Conn) Begin(ctx context.Context) (pgx.Tx, error) { + return c.Conn().Begin(ctx) +} + +// BeginTx starts a transaction block from the *Conn with txOptions determining the transaction mode. +func (c *Conn) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) { + return c.Conn().BeginTx(ctx, txOptions) +} + +func (c *Conn) Ping(ctx context.Context) error { + return c.Conn().Ping(ctx) +} + +func (c *Conn) Conn() *pgx.Conn { + return c.connResource().conn +} + +func (c *Conn) connResource() *connResource { + return c.res.Value() +} + +func (c *Conn) getPoolRow(r pgx.Row) *poolRow { + return c.connResource().getPoolRow(c, r) +} + +func (c *Conn) getPoolRows(r pgx.Rows) *poolRows { + return c.connResource().getPoolRows(c, r) +} diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/doc.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/doc.go new file mode 100644 index 00000000..099443bc --- /dev/null +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/doc.go @@ -0,0 +1,27 @@ +// Package pgxpool is a concurrency-safe connection pool for pgx. +/* +pgxpool implements a nearly identical interface to pgx connections. + +Creating a Pool + +The primary way of creating a pool is with [pgxpool.New]: + + pool, err := pgxpool.New(context.Background(), os.Getenv("DATABASE_URL")) + +The database connection string can be in URL or keyword/value format. PostgreSQL settings, pgx settings, and pool settings can be +specified here. In addition, a config struct can be created by [ParseConfig]. + + config, err := pgxpool.ParseConfig(os.Getenv("DATABASE_URL")) + if err != nil { + // ... + } + config.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error { + // do something with every new connection + } + + pool, err := pgxpool.NewWithConfig(context.Background(), config) + +A pool returns without waiting for any connections to be established. Acquire a connection immediately after creating +the pool to check if a connection can successfully be established. +*/ +package pgxpool diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/pool.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/pool.go new file mode 100644 index 00000000..fdcba724 --- /dev/null +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/pool.go @@ -0,0 +1,717 @@ +package pgxpool + +import ( + "context" + "fmt" + "math/rand" + "runtime" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/puddle/v2" +) + +var defaultMaxConns = int32(4) +var defaultMinConns = int32(0) +var defaultMaxConnLifetime = time.Hour +var defaultMaxConnIdleTime = time.Minute * 30 +var defaultHealthCheckPeriod = time.Minute + +type connResource struct { + conn *pgx.Conn + conns []Conn + poolRows []poolRow + poolRowss []poolRows + maxAgeTime time.Time +} + +func (cr *connResource) getConn(p *Pool, res *puddle.Resource[*connResource]) *Conn { + if len(cr.conns) == 0 { + cr.conns = make([]Conn, 128) + } + + c := &cr.conns[len(cr.conns)-1] + cr.conns = cr.conns[0 : len(cr.conns)-1] + + c.res = res + c.p = p + + return c +} + +func (cr *connResource) getPoolRow(c *Conn, r pgx.Row) *poolRow { + if len(cr.poolRows) == 0 { + cr.poolRows = make([]poolRow, 128) + } + + pr := &cr.poolRows[len(cr.poolRows)-1] + cr.poolRows = cr.poolRows[0 : len(cr.poolRows)-1] + + pr.c = c + pr.r = r + + return pr +} + +func (cr *connResource) getPoolRows(c *Conn, r pgx.Rows) *poolRows { + if len(cr.poolRowss) == 0 { + cr.poolRowss = make([]poolRows, 128) + } + + pr := &cr.poolRowss[len(cr.poolRowss)-1] + cr.poolRowss = cr.poolRowss[0 : len(cr.poolRowss)-1] + + pr.c = c + pr.r = r + + return pr +} + +// Pool allows for connection reuse. +type Pool struct { + // 64 bit fields accessed with atomics must be at beginning of struct to guarantee alignment for certain 32-bit + // architectures. See BUGS section of https://pkg.go.dev/sync/atomic and https://github.com/jackc/pgx/issues/1288. + newConnsCount int64 + lifetimeDestroyCount int64 + idleDestroyCount int64 + + p *puddle.Pool[*connResource] + config *Config + beforeConnect func(context.Context, *pgx.ConnConfig) error + afterConnect func(context.Context, *pgx.Conn) error + beforeAcquire func(context.Context, *pgx.Conn) bool + afterRelease func(*pgx.Conn) bool + beforeClose func(*pgx.Conn) + minConns int32 + maxConns int32 + maxConnLifetime time.Duration + maxConnLifetimeJitter time.Duration + maxConnIdleTime time.Duration + healthCheckPeriod time.Duration + + healthCheckChan chan struct{} + + acquireTracer AcquireTracer + releaseTracer ReleaseTracer + + closeOnce sync.Once + closeChan chan struct{} +} + +// Config is the configuration struct for creating a pool. It must be created by [ParseConfig] and then it can be +// modified. +type Config struct { + ConnConfig *pgx.ConnConfig + + // BeforeConnect is called before a new connection is made. It is passed a copy of the underlying pgx.ConnConfig and + // will not impact any existing open connections. + BeforeConnect func(context.Context, *pgx.ConnConfig) error + + // AfterConnect is called after a connection is established, but before it is added to the pool. + AfterConnect func(context.Context, *pgx.Conn) error + + // BeforeAcquire is called before a connection is acquired from the pool. It must return true to allow the + // acquisition or false to indicate that the connection should be destroyed and a different connection should be + // acquired. + BeforeAcquire func(context.Context, *pgx.Conn) bool + + // AfterRelease is called after a connection is released, but before it is returned to the pool. It must return true to + // return the connection to the pool or false to destroy the connection. + AfterRelease func(*pgx.Conn) bool + + // BeforeClose is called right before a connection is closed and removed from the pool. + BeforeClose func(*pgx.Conn) + + // MaxConnLifetime is the duration since creation after which a connection will be automatically closed. + MaxConnLifetime time.Duration + + // MaxConnLifetimeJitter is the duration after MaxConnLifetime to randomly decide to close a connection. + // This helps prevent all connections from being closed at the exact same time, starving the pool. + MaxConnLifetimeJitter time.Duration + + // MaxConnIdleTime is the duration after which an idle connection will be automatically closed by the health check. + MaxConnIdleTime time.Duration + + // MaxConns is the maximum size of the pool. The default is the greater of 4 or runtime.NumCPU(). + MaxConns int32 + + // MinConns is the minimum size of the pool. After connection closes, the pool might dip below MinConns. A low + // number of MinConns might mean the pool is empty after MaxConnLifetime until the health check has a chance + // to create new connections. + MinConns int32 + + // HealthCheckPeriod is the duration between checks of the health of idle connections. + HealthCheckPeriod time.Duration + + createdByParseConfig bool // Used to enforce created by ParseConfig rule. +} + +// Copy returns a deep copy of the config that is safe to use and modify. +// The only exception is the tls.Config: +// according to the tls.Config docs it must not be modified after creation. +func (c *Config) Copy() *Config { + newConfig := new(Config) + *newConfig = *c + newConfig.ConnConfig = c.ConnConfig.Copy() + return newConfig +} + +// ConnString returns the connection string as parsed by pgxpool.ParseConfig into pgxpool.Config. +func (c *Config) ConnString() string { return c.ConnConfig.ConnString() } + +// New creates a new Pool. See [ParseConfig] for information on connString format. +func New(ctx context.Context, connString string) (*Pool, error) { + config, err := ParseConfig(connString) + if err != nil { + return nil, err + } + + return NewWithConfig(ctx, config) +} + +// NewWithConfig creates a new Pool. config must have been created by [ParseConfig]. +func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { + // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from + // zero values. + if !config.createdByParseConfig { + panic("config must be created by ParseConfig") + } + + p := &Pool{ + config: config, + beforeConnect: config.BeforeConnect, + afterConnect: config.AfterConnect, + beforeAcquire: config.BeforeAcquire, + afterRelease: config.AfterRelease, + beforeClose: config.BeforeClose, + minConns: config.MinConns, + maxConns: config.MaxConns, + maxConnLifetime: config.MaxConnLifetime, + maxConnLifetimeJitter: config.MaxConnLifetimeJitter, + maxConnIdleTime: config.MaxConnIdleTime, + healthCheckPeriod: config.HealthCheckPeriod, + healthCheckChan: make(chan struct{}, 1), + closeChan: make(chan struct{}), + } + + if t, ok := config.ConnConfig.Tracer.(AcquireTracer); ok { + p.acquireTracer = t + } + + if t, ok := config.ConnConfig.Tracer.(ReleaseTracer); ok { + p.releaseTracer = t + } + + var err error + p.p, err = puddle.NewPool( + &puddle.Config[*connResource]{ + Constructor: func(ctx context.Context) (*connResource, error) { + atomic.AddInt64(&p.newConnsCount, 1) + connConfig := p.config.ConnConfig.Copy() + + // Connection will continue in background even if Acquire is canceled. Ensure that a connect won't hang forever. + if connConfig.ConnectTimeout <= 0 { + connConfig.ConnectTimeout = 2 * time.Minute + } + + if p.beforeConnect != nil { + if err := p.beforeConnect(ctx, connConfig); err != nil { + return nil, err + } + } + + conn, err := pgx.ConnectConfig(ctx, connConfig) + if err != nil { + return nil, err + } + + if p.afterConnect != nil { + err = p.afterConnect(ctx, conn) + if err != nil { + conn.Close(ctx) + return nil, err + } + } + + jitterSecs := rand.Float64() * config.MaxConnLifetimeJitter.Seconds() + maxAgeTime := time.Now().Add(config.MaxConnLifetime).Add(time.Duration(jitterSecs) * time.Second) + + cr := &connResource{ + conn: conn, + conns: make([]Conn, 64), + poolRows: make([]poolRow, 64), + poolRowss: make([]poolRows, 64), + maxAgeTime: maxAgeTime, + } + + return cr, nil + }, + Destructor: func(value *connResource) { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + conn := value.conn + if p.beforeClose != nil { + p.beforeClose(conn) + } + conn.Close(ctx) + select { + case <-conn.PgConn().CleanupDone(): + case <-ctx.Done(): + } + cancel() + }, + MaxSize: config.MaxConns, + }, + ) + if err != nil { + return nil, err + } + + go func() { + p.createIdleResources(ctx, int(p.minConns)) + p.backgroundHealthCheck() + }() + + return p, nil +} + +// ParseConfig builds a Config from connString. It parses connString with the same behavior as [pgx.ParseConfig] with the +// addition of the following variables: +// +// - pool_max_conns: integer greater than 0 +// - pool_min_conns: integer 0 or greater +// - pool_max_conn_lifetime: duration string +// - pool_max_conn_idle_time: duration string +// - pool_health_check_period: duration string +// - pool_max_conn_lifetime_jitter: duration string +// +// See Config for definitions of these arguments. +// +// # Example Keyword/Value +// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca pool_max_conns=10 +// +// # Example URL +// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca&pool_max_conns=10 +func ParseConfig(connString string) (*Config, error) { + connConfig, err := pgx.ParseConfig(connString) + if err != nil { + return nil, err + } + + config := &Config{ + ConnConfig: connConfig, + createdByParseConfig: true, + } + + if s, ok := config.ConnConfig.Config.RuntimeParams["pool_max_conns"]; ok { + delete(connConfig.Config.RuntimeParams, "pool_max_conns") + n, err := strconv.ParseInt(s, 10, 32) + if err != nil { + return nil, fmt.Errorf("cannot parse pool_max_conns: %w", err) + } + if n < 1 { + return nil, fmt.Errorf("pool_max_conns too small: %d", n) + } + config.MaxConns = int32(n) + } else { + config.MaxConns = defaultMaxConns + if numCPU := int32(runtime.NumCPU()); numCPU > config.MaxConns { + config.MaxConns = numCPU + } + } + + if s, ok := config.ConnConfig.Config.RuntimeParams["pool_min_conns"]; ok { + delete(connConfig.Config.RuntimeParams, "pool_min_conns") + n, err := strconv.ParseInt(s, 10, 32) + if err != nil { + return nil, fmt.Errorf("cannot parse pool_min_conns: %w", err) + } + config.MinConns = int32(n) + } else { + config.MinConns = defaultMinConns + } + + if s, ok := config.ConnConfig.Config.RuntimeParams["pool_max_conn_lifetime"]; ok { + delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime") + d, err := time.ParseDuration(s) + if err != nil { + return nil, fmt.Errorf("invalid pool_max_conn_lifetime: %w", err) + } + config.MaxConnLifetime = d + } else { + config.MaxConnLifetime = defaultMaxConnLifetime + } + + if s, ok := config.ConnConfig.Config.RuntimeParams["pool_max_conn_idle_time"]; ok { + delete(connConfig.Config.RuntimeParams, "pool_max_conn_idle_time") + d, err := time.ParseDuration(s) + if err != nil { + return nil, fmt.Errorf("invalid pool_max_conn_idle_time: %w", err) + } + config.MaxConnIdleTime = d + } else { + config.MaxConnIdleTime = defaultMaxConnIdleTime + } + + if s, ok := config.ConnConfig.Config.RuntimeParams["pool_health_check_period"]; ok { + delete(connConfig.Config.RuntimeParams, "pool_health_check_period") + d, err := time.ParseDuration(s) + if err != nil { + return nil, fmt.Errorf("invalid pool_health_check_period: %w", err) + } + config.HealthCheckPeriod = d + } else { + config.HealthCheckPeriod = defaultHealthCheckPeriod + } + + if s, ok := config.ConnConfig.Config.RuntimeParams["pool_max_conn_lifetime_jitter"]; ok { + delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime_jitter") + d, err := time.ParseDuration(s) + if err != nil { + return nil, fmt.Errorf("invalid pool_max_conn_lifetime_jitter: %w", err) + } + config.MaxConnLifetimeJitter = d + } + + return config, nil +} + +// Close closes all connections in the pool and rejects future Acquire calls. Blocks until all connections are returned +// to pool and closed. +func (p *Pool) Close() { + p.closeOnce.Do(func() { + close(p.closeChan) + p.p.Close() + }) +} + +func (p *Pool) isExpired(res *puddle.Resource[*connResource]) bool { + return time.Now().After(res.Value().maxAgeTime) +} + +func (p *Pool) triggerHealthCheck() { + go func() { + // Destroy is asynchronous so we give it time to actually remove itself from + // the pool otherwise we might try to check the pool size too soon + time.Sleep(500 * time.Millisecond) + select { + case p.healthCheckChan <- struct{}{}: + default: + } + }() +} + +func (p *Pool) backgroundHealthCheck() { + ticker := time.NewTicker(p.healthCheckPeriod) + defer ticker.Stop() + for { + select { + case <-p.closeChan: + return + case <-p.healthCheckChan: + p.checkHealth() + case <-ticker.C: + p.checkHealth() + } + } +} + +func (p *Pool) checkHealth() { + for { + // If checkMinConns failed we don't destroy any connections since we couldn't + // even get to minConns + if err := p.checkMinConns(); err != nil { + // Should we log this error somewhere? + break + } + if !p.checkConnsHealth() { + // Since we didn't destroy any connections we can stop looping + break + } + // Technically Destroy is asynchronous but 500ms should be enough for it to + // remove it from the underlying pool + select { + case <-p.closeChan: + return + case <-time.After(500 * time.Millisecond): + } + } +} + +// checkConnsHealth will check all idle connections, destroy a connection if +// it's idle or too old, and returns true if any were destroyed +func (p *Pool) checkConnsHealth() bool { + var destroyed bool + totalConns := p.Stat().TotalConns() + resources := p.p.AcquireAllIdle() + for _, res := range resources { + // We're okay going under minConns if the lifetime is up + if p.isExpired(res) && totalConns >= p.minConns { + atomic.AddInt64(&p.lifetimeDestroyCount, 1) + res.Destroy() + destroyed = true + // Since Destroy is async we manually decrement totalConns. + totalConns-- + } else if res.IdleDuration() > p.maxConnIdleTime && totalConns > p.minConns { + atomic.AddInt64(&p.idleDestroyCount, 1) + res.Destroy() + destroyed = true + // Since Destroy is async we manually decrement totalConns. + totalConns-- + } else { + res.ReleaseUnused() + } + } + return destroyed +} + +func (p *Pool) checkMinConns() error { + // TotalConns can include ones that are being destroyed but we should have + // sleep(500ms) around all of the destroys to help prevent that from throwing + // off this check + toCreate := p.minConns - p.Stat().TotalConns() + if toCreate > 0 { + return p.createIdleResources(context.Background(), int(toCreate)) + } + return nil +} + +func (p *Pool) createIdleResources(parentCtx context.Context, targetResources int) error { + ctx, cancel := context.WithCancel(parentCtx) + defer cancel() + + errs := make(chan error, targetResources) + + for i := 0; i < targetResources; i++ { + go func() { + err := p.p.CreateResource(ctx) + // Ignore ErrNotAvailable since it means that the pool has become full since we started creating resource. + if err == puddle.ErrNotAvailable { + err = nil + } + errs <- err + }() + } + + var firstError error + for i := 0; i < targetResources; i++ { + err := <-errs + if err != nil && firstError == nil { + cancel() + firstError = err + } + } + + return firstError +} + +// Acquire returns a connection (*Conn) from the Pool +func (p *Pool) Acquire(ctx context.Context) (c *Conn, err error) { + if p.acquireTracer != nil { + ctx = p.acquireTracer.TraceAcquireStart(ctx, p, TraceAcquireStartData{}) + defer func() { + var conn *pgx.Conn + if c != nil { + conn = c.Conn() + } + p.acquireTracer.TraceAcquireEnd(ctx, p, TraceAcquireEndData{Conn: conn, Err: err}) + }() + } + + for { + res, err := p.p.Acquire(ctx) + if err != nil { + return nil, err + } + + cr := res.Value() + + if res.IdleDuration() > time.Second { + err := cr.conn.Ping(ctx) + if err != nil { + res.Destroy() + continue + } + } + + if p.beforeAcquire == nil || p.beforeAcquire(ctx, cr.conn) { + return cr.getConn(p, res), nil + } + + res.Destroy() + } +} + +// AcquireFunc acquires a *Conn and calls f with that *Conn. ctx will only affect the Acquire. It has no effect on the +// call of f. The return value is either an error acquiring the *Conn or the return value of f. The *Conn is +// automatically released after the call of f. +func (p *Pool) AcquireFunc(ctx context.Context, f func(*Conn) error) error { + conn, err := p.Acquire(ctx) + if err != nil { + return err + } + defer conn.Release() + + return f(conn) +} + +// AcquireAllIdle atomically acquires all currently idle connections. Its intended use is for health check and +// keep-alive functionality. It does not update pool statistics. +func (p *Pool) AcquireAllIdle(ctx context.Context) []*Conn { + resources := p.p.AcquireAllIdle() + conns := make([]*Conn, 0, len(resources)) + for _, res := range resources { + cr := res.Value() + if p.beforeAcquire == nil || p.beforeAcquire(ctx, cr.conn) { + conns = append(conns, cr.getConn(p, res)) + } else { + res.Destroy() + } + } + + return conns +} + +// Reset closes all connections, but leaves the pool open. It is intended for use when an error is detected that would +// disrupt all connections (such as a network interruption or a server state change). +// +// It is safe to reset a pool while connections are checked out. Those connections will be closed when they are returned +// to the pool. +func (p *Pool) Reset() { + p.p.Reset() +} + +// Config returns a copy of config that was used to initialize this pool. +func (p *Pool) Config() *Config { return p.config.Copy() } + +// Stat returns a pgxpool.Stat struct with a snapshot of Pool statistics. +func (p *Pool) Stat() *Stat { + return &Stat{ + s: p.p.Stat(), + newConnsCount: atomic.LoadInt64(&p.newConnsCount), + lifetimeDestroyCount: atomic.LoadInt64(&p.lifetimeDestroyCount), + idleDestroyCount: atomic.LoadInt64(&p.idleDestroyCount), + } +} + +// Exec acquires a connection from the Pool and executes the given SQL. +// SQL can be either a prepared statement name or an SQL string. +// Arguments should be referenced positionally from the SQL string as $1, $2, etc. +// The acquired connection is returned to the pool when the Exec function returns. +func (p *Pool) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { + c, err := p.Acquire(ctx) + if err != nil { + return pgconn.CommandTag{}, err + } + defer c.Release() + + return c.Exec(ctx, sql, arguments...) +} + +// Query acquires a connection and executes a query that returns pgx.Rows. +// Arguments should be referenced positionally from the SQL string as $1, $2, etc. +// See pgx.Rows documentation to close the returned Rows and return the acquired connection to the Pool. +// +// If there is an error, the returned pgx.Rows will be returned in an error state. +// If preferred, ignore the error returned from Query and handle errors using the returned pgx.Rows. +// +// For extra control over how the query is executed, the types QuerySimpleProtocol, QueryResultFormats, and +// QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely +// needed. See the documentation for those types for details. +func (p *Pool) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { + c, err := p.Acquire(ctx) + if err != nil { + return errRows{err: err}, err + } + + rows, err := c.Query(ctx, sql, args...) + if err != nil { + c.Release() + return errRows{err: err}, err + } + + return c.getPoolRows(rows), nil +} + +// QueryRow acquires a connection and executes a query that is expected +// to return at most one row (pgx.Row). Errors are deferred until pgx.Row's +// Scan method is called. If the query selects no rows, pgx.Row's Scan will +// return ErrNoRows. Otherwise, pgx.Row's Scan scans the first selected row +// and discards the rest. The acquired connection is returned to the Pool when +// pgx.Row's Scan method is called. +// +// Arguments should be referenced positionally from the SQL string as $1, $2, etc. +// +// For extra control over how the query is executed, the types QuerySimpleProtocol, QueryResultFormats, and +// QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely +// needed. See the documentation for those types for details. +func (p *Pool) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { + c, err := p.Acquire(ctx) + if err != nil { + return errRow{err: err} + } + + row := c.QueryRow(ctx, sql, args...) + return c.getPoolRow(row) +} + +func (p *Pool) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { + c, err := p.Acquire(ctx) + if err != nil { + return errBatchResults{err: err} + } + + br := c.SendBatch(ctx, b) + return &poolBatchResults{br: br, c: c} +} + +// Begin acquires a connection from the Pool and starts a transaction. Unlike database/sql, the context only affects the begin command. i.e. there is no +// auto-rollback on context cancellation. Begin initiates a transaction block without explicitly setting a transaction mode for the block (see BeginTx with TxOptions if transaction mode is required). +// *pgxpool.Tx is returned, which implements the pgx.Tx interface. +// Commit or Rollback must be called on the returned transaction to finalize the transaction block. +func (p *Pool) Begin(ctx context.Context) (pgx.Tx, error) { + return p.BeginTx(ctx, pgx.TxOptions{}) +} + +// BeginTx acquires a connection from the Pool and starts a transaction with pgx.TxOptions determining the transaction mode. +// Unlike database/sql, the context only affects the begin command. i.e. there is no auto-rollback on context cancellation. +// *pgxpool.Tx is returned, which implements the pgx.Tx interface. +// Commit or Rollback must be called on the returned transaction to finalize the transaction block. +func (p *Pool) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) { + c, err := p.Acquire(ctx) + if err != nil { + return nil, err + } + + t, err := c.BeginTx(ctx, txOptions) + if err != nil { + c.Release() + return nil, err + } + + return &Tx{t: t, c: c}, nil +} + +func (p *Pool) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { + c, err := p.Acquire(ctx) + if err != nil { + return 0, err + } + defer c.Release() + + return c.Conn().CopyFrom(ctx, tableName, columnNames, rowSrc) +} + +// Ping acquires a connection from the Pool and executes an empty sql statement against it. +// If the sql returns without error, the database Ping is considered successful, otherwise, the error is returned. +func (p *Pool) Ping(ctx context.Context) error { + c, err := p.Acquire(ctx) + if err != nil { + return err + } + defer c.Release() + return c.Ping(ctx) +} diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/rows.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/rows.go new file mode 100644 index 00000000..f834b7ec --- /dev/null +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/rows.go @@ -0,0 +1,116 @@ +package pgxpool + +import ( + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +type errRows struct { + err error +} + +func (errRows) Close() {} +func (e errRows) Err() error { return e.err } +func (errRows) CommandTag() pgconn.CommandTag { return pgconn.CommandTag{} } +func (errRows) FieldDescriptions() []pgconn.FieldDescription { return nil } +func (errRows) Next() bool { return false } +func (e errRows) Scan(dest ...any) error { return e.err } +func (e errRows) Values() ([]any, error) { return nil, e.err } +func (e errRows) RawValues() [][]byte { return nil } +func (e errRows) Conn() *pgx.Conn { return nil } + +type errRow struct { + err error +} + +func (e errRow) Scan(dest ...any) error { return e.err } + +type poolRows struct { + r pgx.Rows + c *Conn + err error +} + +func (rows *poolRows) Close() { + rows.r.Close() + if rows.c != nil { + rows.c.Release() + rows.c = nil + } +} + +func (rows *poolRows) Err() error { + if rows.err != nil { + return rows.err + } + return rows.r.Err() +} + +func (rows *poolRows) CommandTag() pgconn.CommandTag { + return rows.r.CommandTag() +} + +func (rows *poolRows) FieldDescriptions() []pgconn.FieldDescription { + return rows.r.FieldDescriptions() +} + +func (rows *poolRows) Next() bool { + if rows.err != nil { + return false + } + + n := rows.r.Next() + if !n { + rows.Close() + } + return n +} + +func (rows *poolRows) Scan(dest ...any) error { + err := rows.r.Scan(dest...) + if err != nil { + rows.Close() + } + return err +} + +func (rows *poolRows) Values() ([]any, error) { + values, err := rows.r.Values() + if err != nil { + rows.Close() + } + return values, err +} + +func (rows *poolRows) RawValues() [][]byte { + return rows.r.RawValues() +} + +func (rows *poolRows) Conn() *pgx.Conn { + return rows.r.Conn() +} + +type poolRow struct { + r pgx.Row + c *Conn + err error +} + +func (row *poolRow) Scan(dest ...any) error { + if row.err != nil { + return row.err + } + + panicked := true + defer func() { + if panicked && row.c != nil { + row.c.Release() + } + }() + err := row.r.Scan(dest...) + panicked = false + if row.c != nil { + row.c.Release() + } + return err +} diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/stat.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/stat.go new file mode 100644 index 00000000..cfa0c4c5 --- /dev/null +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/stat.go @@ -0,0 +1,84 @@ +package pgxpool + +import ( + "time" + + "github.com/jackc/puddle/v2" +) + +// Stat is a snapshot of Pool statistics. +type Stat struct { + s *puddle.Stat + newConnsCount int64 + lifetimeDestroyCount int64 + idleDestroyCount int64 +} + +// AcquireCount returns the cumulative count of successful acquires from the pool. +func (s *Stat) AcquireCount() int64 { + return s.s.AcquireCount() +} + +// AcquireDuration returns the total duration of all successful acquires from +// the pool. +func (s *Stat) AcquireDuration() time.Duration { + return s.s.AcquireDuration() +} + +// AcquiredConns returns the number of currently acquired connections in the pool. +func (s *Stat) AcquiredConns() int32 { + return s.s.AcquiredResources() +} + +// CanceledAcquireCount returns the cumulative count of acquires from the pool +// that were canceled by a context. +func (s *Stat) CanceledAcquireCount() int64 { + return s.s.CanceledAcquireCount() +} + +// ConstructingConns returns the number of conns with construction in progress in +// the pool. +func (s *Stat) ConstructingConns() int32 { + return s.s.ConstructingResources() +} + +// EmptyAcquireCount returns the cumulative count of successful acquires from the pool +// that waited for a resource to be released or constructed because the pool was +// empty. +func (s *Stat) EmptyAcquireCount() int64 { + return s.s.EmptyAcquireCount() +} + +// IdleConns returns the number of currently idle conns in the pool. +func (s *Stat) IdleConns() int32 { + return s.s.IdleResources() +} + +// MaxConns returns the maximum size of the pool. +func (s *Stat) MaxConns() int32 { + return s.s.MaxResources() +} + +// TotalConns returns the total number of resources currently in the pool. +// The value is the sum of ConstructingConns, AcquiredConns, and +// IdleConns. +func (s *Stat) TotalConns() int32 { + return s.s.TotalResources() +} + +// NewConnsCount returns the cumulative count of new connections opened. +func (s *Stat) NewConnsCount() int64 { + return s.newConnsCount +} + +// MaxLifetimeDestroyCount returns the cumulative count of connections destroyed +// because they exceeded MaxConnLifetime. +func (s *Stat) MaxLifetimeDestroyCount() int64 { + return s.lifetimeDestroyCount +} + +// MaxIdleDestroyCount returns the cumulative count of connections destroyed because +// they exceeded MaxConnIdleTime. +func (s *Stat) MaxIdleDestroyCount() int64 { + return s.idleDestroyCount +} diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/tracer.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/tracer.go new file mode 100644 index 00000000..78b9d15a --- /dev/null +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/tracer.go @@ -0,0 +1,33 @@ +package pgxpool + +import ( + "context" + + "github.com/jackc/pgx/v5" +) + +// AcquireTracer traces Acquire. +type AcquireTracer interface { + // TraceAcquireStart is called at the beginning of Acquire. + // The returned context is used for the rest of the call and will be passed to the TraceAcquireEnd. + TraceAcquireStart(ctx context.Context, pool *Pool, data TraceAcquireStartData) context.Context + // TraceAcquireEnd is called when a connection has been acquired. + TraceAcquireEnd(ctx context.Context, pool *Pool, data TraceAcquireEndData) +} + +type TraceAcquireStartData struct{} + +type TraceAcquireEndData struct { + Conn *pgx.Conn + Err error +} + +// ReleaseTracer traces Release. +type ReleaseTracer interface { + // TraceRelease is called at the beginning of Release. + TraceRelease(pool *Pool, data TraceReleaseData) +} + +type TraceReleaseData struct { + Conn *pgx.Conn +} diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/tx.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/tx.go new file mode 100644 index 00000000..74df8593 --- /dev/null +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/pgxpool/tx.go @@ -0,0 +1,82 @@ +package pgxpool + +import ( + "context" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +// Tx represents a database transaction acquired from a Pool. +type Tx struct { + t pgx.Tx + c *Conn +} + +// Begin starts a pseudo nested transaction implemented with a savepoint. +func (tx *Tx) Begin(ctx context.Context) (pgx.Tx, error) { + return tx.t.Begin(ctx) +} + +// Commit commits the transaction and returns the associated connection back to the Pool. Commit will return ErrTxClosed +// if the Tx is already closed, but is otherwise safe to call multiple times. If the commit fails with a rollback status +// (e.g. the transaction was already in a broken state) then ErrTxCommitRollback will be returned. +func (tx *Tx) Commit(ctx context.Context) error { + err := tx.t.Commit(ctx) + if tx.c != nil { + tx.c.Release() + tx.c = nil + } + return err +} + +// Rollback rolls back the transaction and returns the associated connection back to the Pool. Rollback will return ErrTxClosed +// if the Tx is already closed, but is otherwise safe to call multiple times. Hence, defer tx.Rollback() is safe even if +// tx.Commit() will be called first in a non-error condition. +func (tx *Tx) Rollback(ctx context.Context) error { + err := tx.t.Rollback(ctx) + if tx.c != nil { + tx.c.Release() + tx.c = nil + } + return err +} + +func (tx *Tx) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { + return tx.t.CopyFrom(ctx, tableName, columnNames, rowSrc) +} + +func (tx *Tx) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { + return tx.t.SendBatch(ctx, b) +} + +func (tx *Tx) LargeObjects() pgx.LargeObjects { + return tx.t.LargeObjects() +} + +// Prepare creates a prepared statement with name and sql. If the name is empty, +// an anonymous prepared statement will be used. sql can contain placeholders +// for bound parameters. These placeholders are referenced positionally as $1, $2, etc. +// +// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same +// name and sql arguments. This allows a code path to Prepare and Query/Exec without +// needing to first check whether the statement has already been prepared. +func (tx *Tx) Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) { + return tx.t.Prepare(ctx, name, sql) +} + +func (tx *Tx) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { + return tx.t.Exec(ctx, sql, arguments...) +} + +func (tx *Tx) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { + return tx.t.Query(ctx, sql, args...) +} + +func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { + return tx.t.QueryRow(ctx, sql, args...) +} + +func (tx *Tx) Conn() *pgx.Conn { + return tx.t.Conn() +} diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/rows.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/rows.go index 1b1c8ac9..d4f7a901 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/rows.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/rows.go @@ -6,9 +6,9 @@ import ( "fmt" "reflect" "strings" + "sync" "time" - "github.com/jackc/pgx/v5/internal/stmtcache" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" ) @@ -17,7 +17,8 @@ import ( // the *Conn can be used again. Rows are closed by explicitly calling Close(), // calling Next() until it returns false, or when a fatal error occurs. // -// Once a Rows is closed the only methods that may be called are Close(), Err(), and CommandTag(). +// Once a Rows is closed the only methods that may be called are Close(), Err(), +// and CommandTag(). // // Rows is an interface instead of a struct to allow tests to mock Query. However, // adding a method to an interface is technically a breaking change. Because of this @@ -41,8 +42,15 @@ type Rows interface { FieldDescriptions() []pgconn.FieldDescription // Next prepares the next row for reading. It returns true if there is another - // row and false if no more rows are available. It automatically closes rows - // when all rows are read. + // row and false if no more rows are available or a fatal error has occurred. + // It automatically closes rows when all rows are read. + // + // Callers should check rows.Err() after rows.Next() returns false to detect + // whether result-set reading ended prematurely due to an error. See + // Conn.Query for details. + // + // For simpler error handling, consider using the higher-level pgx v5 + // CollectRows() and ForEachRow() helpers instead. Next() bool // Scan reads the values from the current row into dest values positionally. @@ -166,14 +174,12 @@ func (rows *baseRows) Close() { } if rows.err != nil && rows.conn != nil && rows.sql != "" { - if stmtcache.IsStatementInvalid(rows.err) { - if sc := rows.conn.statementCache; sc != nil { - sc.Invalidate(rows.sql) - } + if sc := rows.conn.statementCache; sc != nil { + sc.Invalidate(rows.sql) + } - if sc := rows.conn.descriptionCache; sc != nil { - sc.Invalidate(rows.sql) - } + if sc := rows.conn.descriptionCache; sc != nil { + sc.Invalidate(rows.sql) } } @@ -412,12 +418,12 @@ type CollectableRow interface { // RowToFunc is a function that scans or otherwise converts row to a T. type RowToFunc[T any] func(row CollectableRow) (T, error) -// CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T. -func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) { +// AppendRows iterates through rows, calling fn for each row, and appending the results into a slice of T. +// +// This function closes the rows automatically on return. +func AppendRows[T any, S ~[]T](slice S, rows Rows, fn RowToFunc[T]) (S, error) { defer rows.Close() - slice := []T{} - for rows.Next() { value, err := fn(rows) if err != nil { @@ -433,8 +439,17 @@ func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) { return slice, nil } +// CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T. +// +// This function closes the rows automatically on return. +func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) { + return AppendRows([]T{}, rows, fn) +} + // CollectOneRow calls fn for the first row in rows and returns the result. If no rows are found returns an error where errors.Is(ErrNoRows) is true. // CollectOneRow is to CollectRows as QueryRow is to Query. +// +// This function closes the rows automatically on return. func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) { defer rows.Close() @@ -457,6 +472,41 @@ func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) { return value, rows.Err() } +// CollectExactlyOneRow calls fn for the first row in rows and returns the result. +// - If no rows are found returns an error where errors.Is(ErrNoRows) is true. +// - If more than 1 row is found returns an error where errors.Is(ErrTooManyRows) is true. +// +// This function closes the rows automatically on return. +func CollectExactlyOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) { + defer rows.Close() + + var ( + err error + value T + ) + + if !rows.Next() { + if err = rows.Err(); err != nil { + return value, err + } + + return value, ErrNoRows + } + + value, err = fn(rows) + if err != nil { + return value, err + } + + if rows.Next() { + var zero T + + return zero, ErrTooManyRows + } + + return value, rows.Err() +} + // RowTo returns a T scanned from row. func RowTo[T any](row CollectableRow) (T, error) { var value T @@ -496,20 +546,20 @@ func (rs *mapRowScanner) ScanRow(rows Rows) error { } // RowToStructByPos returns a T scanned from row. T must be a struct. T must have the same number a public fields as row -// has fields. The row and T fields will by matched by position. If the "db" struct tag is "-" then the field will be +// has fields. The row and T fields will be matched by position. If the "db" struct tag is "-" then the field will be // ignored. func RowToStructByPos[T any](row CollectableRow) (T, error) { var value T - err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value}) + err := (&positionalStructRowScanner{ptrToStruct: &value}).ScanRow(row) return value, err } // RowToAddrOfStructByPos returns the address of a T scanned from row. T must be a struct. T must have the same number a -// public fields as row has fields. The row and T fields will by matched by position. If the "db" struct tag is "-" then +// public fields as row has fields. The row and T fields will be matched by position. If the "db" struct tag is "-" then // the field will be ignored. func RowToAddrOfStructByPos[T any](row CollectableRow) (*T, error) { var value T - err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value}) + err := (&positionalStructRowScanner{ptrToStruct: &value}).ScanRow(row) return &value, err } @@ -517,83 +567,97 @@ type positionalStructRowScanner struct { ptrToStruct any } -func (rs *positionalStructRowScanner) ScanRow(rows Rows) error { - dst := rs.ptrToStruct - dstValue := reflect.ValueOf(dst) - if dstValue.Kind() != reflect.Ptr { - return fmt.Errorf("dst not a pointer") - } - - dstElemValue := dstValue.Elem() - scanTargets := rs.appendScanTargets(dstElemValue, nil) - - if len(rows.RawValues()) > len(scanTargets) { - return fmt.Errorf("got %d values, but dst struct has only %d fields", len(rows.RawValues()), len(scanTargets)) +func (rs *positionalStructRowScanner) ScanRow(rows CollectableRow) error { + typ := reflect.TypeOf(rs.ptrToStruct).Elem() + fields := lookupStructFields(typ) + if len(rows.RawValues()) > len(fields) { + return fmt.Errorf( + "got %d values, but dst struct has only %d fields", + len(rows.RawValues()), + len(fields), + ) } - + scanTargets := setupStructScanTargets(rs.ptrToStruct, fields) return rows.Scan(scanTargets...) } -func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any) []any { - dstElemType := dstElemValue.Type() +// Map from reflect.Type -> []structRowField +var positionalStructFieldMap sync.Map - if scanTargets == nil { - scanTargets = make([]any, 0, dstElemType.NumField()) +func lookupStructFields(t reflect.Type) []structRowField { + if cached, ok := positionalStructFieldMap.Load(t); ok { + return cached.([]structRowField) } - for i := 0; i < dstElemType.NumField(); i++ { - sf := dstElemType.Field(i) + fieldStack := make([]int, 0, 1) + fields := computeStructFields(t, make([]structRowField, 0, t.NumField()), &fieldStack) + fieldsIface, _ := positionalStructFieldMap.LoadOrStore(t, fields) + return fieldsIface.([]structRowField) +} + +func computeStructFields( + t reflect.Type, + fields []structRowField, + fieldStack *[]int, +) []structRowField { + tail := len(*fieldStack) + *fieldStack = append(*fieldStack, 0) + for i := 0; i < t.NumField(); i++ { + sf := t.Field(i) + (*fieldStack)[tail] = i // Handle anonymous struct embedding, but do not try to handle embedded pointers. if sf.Anonymous && sf.Type.Kind() == reflect.Struct { - scanTargets = rs.appendScanTargets(dstElemValue.Field(i), scanTargets) + fields = computeStructFields(sf.Type, fields, fieldStack) } else if sf.PkgPath == "" { dbTag, _ := sf.Tag.Lookup(structTagKey) if dbTag == "-" { // Field is ignored, skip it. continue } - scanTargets = append(scanTargets, dstElemValue.Field(i).Addr().Interface()) + fields = append(fields, structRowField{ + path: append([]int(nil), *fieldStack...), + }) } } - - return scanTargets + *fieldStack = (*fieldStack)[:tail] + return fields } // RowToStructByName returns a T scanned from row. T must be a struct. T must have the same number of named public -// fields as row has fields. The row and T fields will by matched by name. The match is case-insensitive. The database +// fields as row has fields. The row and T fields will be matched by name. The match is case-insensitive. The database // column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored. func RowToStructByName[T any](row CollectableRow) (T, error) { var value T - err := row.Scan(&namedStructRowScanner{ptrToStruct: &value}) + err := (&namedStructRowScanner{ptrToStruct: &value}).ScanRow(row) return value, err } // RowToAddrOfStructByName returns the address of a T scanned from row. T must be a struct. T must have the same number -// of named public fields as row has fields. The row and T fields will by matched by name. The match is +// of named public fields as row has fields. The row and T fields will be matched by name. The match is // case-insensitive. The database column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" // then the field will be ignored. func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) { var value T - err := row.Scan(&namedStructRowScanner{ptrToStruct: &value}) + err := (&namedStructRowScanner{ptrToStruct: &value}).ScanRow(row) return &value, err } // RowToStructByNameLax returns a T scanned from row. T must be a struct. T must have greater than or equal number of named public -// fields as row has fields. The row and T fields will by matched by name. The match is case-insensitive. The database +// fields as row has fields. The row and T fields will be matched by name. The match is case-insensitive. The database // column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored. func RowToStructByNameLax[T any](row CollectableRow) (T, error) { var value T - err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true}) + err := (&namedStructRowScanner{ptrToStruct: &value, lax: true}).ScanRow(row) return value, err } // RowToAddrOfStructByNameLax returns the address of a T scanned from row. T must be a struct. T must have greater than or -// equal number of named public fields as row has fields. The row and T fields will by matched by name. The match is +// equal number of named public fields as row has fields. The row and T fields will be matched by name. The match is // case-insensitive. The database column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" // then the field will be ignored. func RowToAddrOfStructByNameLax[T any](row CollectableRow) (*T, error) { var value T - err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true}) + err := (&namedStructRowScanner{ptrToStruct: &value, lax: true}).ScanRow(row) return &value, err } @@ -602,64 +666,128 @@ type namedStructRowScanner struct { lax bool } -func (rs *namedStructRowScanner) ScanRow(rows Rows) error { - dst := rs.ptrToStruct - dstValue := reflect.ValueOf(dst) - if dstValue.Kind() != reflect.Ptr { - return fmt.Errorf("dst not a pointer") - } - - dstElemValue := dstValue.Elem() - scanTargets, err := rs.appendScanTargets(dstElemValue, nil, rows.FieldDescriptions()) +func (rs *namedStructRowScanner) ScanRow(rows CollectableRow) error { + typ := reflect.TypeOf(rs.ptrToStruct).Elem() + fldDescs := rows.FieldDescriptions() + namedStructFields, err := lookupNamedStructFields(typ, fldDescs) if err != nil { return err } - - for i, t := range scanTargets { - if t == nil { - return fmt.Errorf("struct doesn't have corresponding row field %s", rows.FieldDescriptions()[i].Name) - } + if !rs.lax && namedStructFields.missingField != "" { + return fmt.Errorf("cannot find field %s in returned row", namedStructFields.missingField) } - + fields := namedStructFields.fields + scanTargets := setupStructScanTargets(rs.ptrToStruct, fields) return rows.Scan(scanTargets...) } -const structTagKey = "db" - -func fieldPosByName(fldDescs []pgconn.FieldDescription, field string) (i int) { - i = -1 - for i, desc := range fldDescs { - if strings.EqualFold(desc.Name, field) { - return i +// Map from namedStructFieldMap -> *namedStructFields +var namedStructFieldMap sync.Map + +type namedStructFieldsKey struct { + t reflect.Type + colNames string +} + +type namedStructFields struct { + fields []structRowField + // missingField is the first field from the struct without a corresponding row field. + // This is used to construct the correct error message for non-lax queries. + missingField string +} + +func lookupNamedStructFields( + t reflect.Type, + fldDescs []pgconn.FieldDescription, +) (*namedStructFields, error) { + key := namedStructFieldsKey{ + t: t, + colNames: joinFieldNames(fldDescs), + } + if cached, ok := namedStructFieldMap.Load(key); ok { + return cached.(*namedStructFields), nil + } + + // We could probably do two-levels of caching, where we compute the key -> fields mapping + // for a type only once, cache it by type, then use that to compute the column -> fields + // mapping for a given set of columns. + fieldStack := make([]int, 0, 1) + fields, missingField := computeNamedStructFields( + fldDescs, + t, + make([]structRowField, len(fldDescs)), + &fieldStack, + ) + for i, f := range fields { + if f.path == nil { + return nil, fmt.Errorf( + "struct doesn't have corresponding row field %s", + fldDescs[i].Name, + ) } } - return + + fieldsIface, _ := namedStructFieldMap.LoadOrStore( + key, + &namedStructFields{fields: fields, missingField: missingField}, + ) + return fieldsIface.(*namedStructFields), nil } -func (rs *namedStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any, fldDescs []pgconn.FieldDescription) ([]any, error) { - var err error - dstElemType := dstElemValue.Type() +func joinFieldNames(fldDescs []pgconn.FieldDescription) string { + switch len(fldDescs) { + case 0: + return "" + case 1: + return fldDescs[0].Name + } - if scanTargets == nil { - scanTargets = make([]any, len(fldDescs)) + totalSize := len(fldDescs) - 1 // Space for separator bytes. + for _, d := range fldDescs { + totalSize += len(d.Name) + } + var b strings.Builder + b.Grow(totalSize) + b.WriteString(fldDescs[0].Name) + for _, d := range fldDescs[1:] { + b.WriteByte(0) // Join with NUL byte as it's (presumably) not a valid column character. + b.WriteString(d.Name) } + return b.String() +} - for i := 0; i < dstElemType.NumField(); i++ { - sf := dstElemType.Field(i) +func computeNamedStructFields( + fldDescs []pgconn.FieldDescription, + t reflect.Type, + fields []structRowField, + fieldStack *[]int, +) ([]structRowField, string) { + var missingField string + tail := len(*fieldStack) + *fieldStack = append(*fieldStack, 0) + for i := 0; i < t.NumField(); i++ { + sf := t.Field(i) + (*fieldStack)[tail] = i if sf.PkgPath != "" && !sf.Anonymous { // Field is unexported, skip it. continue } - // Handle anoymous struct embedding, but do not try to handle embedded pointers. + // Handle anonymous struct embedding, but do not try to handle embedded pointers. if sf.Anonymous && sf.Type.Kind() == reflect.Struct { - scanTargets, err = rs.appendScanTargets(dstElemValue.Field(i), scanTargets, fldDescs) - if err != nil { - return nil, err + var missingSubField string + fields, missingSubField = computeNamedStructFields( + fldDescs, + sf.Type, + fields, + fieldStack, + ) + if missingField == "" { + missingField = missingSubField } } else { dbTag, dbTagPresent := sf.Tag.Lookup(structTagKey) if dbTagPresent { - dbTag = strings.Split(dbTag, ",")[0] + dbTag, _, _ = strings.Cut(dbTag, ",") } if dbTag == "-" { // Field is ignored, skip it. @@ -671,17 +799,53 @@ func (rs *namedStructRowScanner) appendScanTargets(dstElemValue reflect.Value, s } fpos := fieldPosByName(fldDescs, colName) if fpos == -1 { - if rs.lax { - continue + if missingField == "" { + missingField = colName } - return nil, fmt.Errorf("cannot find field %s in returned row", colName) + continue } - if fpos >= len(scanTargets) && !rs.lax { - return nil, fmt.Errorf("cannot find field %s in returned row", colName) + fields[fpos] = structRowField{ + path: append([]int(nil), *fieldStack...), } - scanTargets[fpos] = dstElemValue.Field(i).Addr().Interface() } } + *fieldStack = (*fieldStack)[:tail] + + return fields, missingField +} - return scanTargets, err +const structTagKey = "db" + +func fieldPosByName(fldDescs []pgconn.FieldDescription, field string) (i int) { + i = -1 + for i, desc := range fldDescs { + + // Snake case support. + field = strings.ReplaceAll(field, "_", "") + descName := strings.ReplaceAll(desc.Name, "_", "") + + if strings.EqualFold(descName, field) { + return i + } + } + return +} + +// structRowField describes a field of a struct. +// +// TODO: It would be a bit more efficient to track the path using the pointer +// offset within the (outermost) struct and use unsafe.Pointer arithmetic to +// construct references when scanning rows. However, it's not clear it's worth +// using unsafe for this. +type structRowField struct { + path []int +} + +func setupStructScanTargets(receiver any, fields []structRowField) []any { + scanTargets := make([]any, len(fields)) + v := reflect.ValueOf(receiver).Elem() + for i, f := range fields { + scanTargets[i] = v.FieldByIndex(f.path).Addr().Interface() + } + return scanTargets } diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/stdlib/sql.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/stdlib/sql.go index 97ecc9b2..29cd3fbb 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/stdlib/sql.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/stdlib/sql.go @@ -7,19 +7,28 @@ // return err // } // -// Or from a DSN string. +// Or from a keyword/value string. // // db, err := sql.Open("pgx", "user=postgres password=secret host=localhost port=5432 database=pgx_test sslmode=disable") // if err != nil { // return err // } // +// Or from a *pgxpool.Pool. +// +// pool, err := pgxpool.New(context.Background(), os.Getenv("DATABASE_URL")) +// if err != nil { +// return err +// } +// +// db := stdlib.OpenDBFromPool(pool) +// // Or a pgx.ConnConfig can be used to set configuration not accessible via connection string. In this case the // pgx.ConnConfig must first be registered with the driver. This registration returns a connection string which is used // with sql.Open. // // connConfig, _ := pgx.ParseConfig(os.Getenv("DATABASE_URL")) -// connConfig.Logger = myLogger +// connConfig.Tracer = &tracelog.TraceLog{Logger: myLogger, LogLevel: tracelog.LogLevelInfo} // connStr := stdlib.RegisterConnConfig(connConfig) // db, _ := sql.Open("pgx", connStr) // @@ -74,6 +83,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" ) // Only intrinsic types should be binary format with database/sql. @@ -125,14 +135,14 @@ func contains(list []string, y string) bool { type OptionOpenDB func(*connector) // OptionBeforeConnect provides a callback for before connect. It is passed a shallow copy of the ConnConfig that will -// be used to connect, so only its immediate members should be modified. +// be used to connect, so only its immediate members should be modified. Used only if db is opened with *pgx.ConnConfig. func OptionBeforeConnect(bc func(context.Context, *pgx.ConnConfig) error) OptionOpenDB { return func(dc *connector) { dc.BeforeConnect = bc } } -// OptionAfterConnect provides a callback for after connect. +// OptionAfterConnect provides a callback for after connect. Used only if db is opened with *pgx.ConnConfig. func OptionAfterConnect(ac func(context.Context, *pgx.Conn) error) OptionOpenDB { return func(dc *connector) { dc.AfterConnect = ac @@ -191,13 +201,42 @@ func GetConnector(config pgx.ConnConfig, opts ...OptionOpenDB) driver.Connector return c } +// GetPoolConnector creates a new driver.Connector from the given *pgxpool.Pool. By using this be sure to set the +// maximum idle connections of the *sql.DB created with this connector to zero since they must be managed from the +// *pgxpool.Pool. This is required to avoid acquiring all the connections from the pgxpool and starving any direct +// users of the pgxpool. +func GetPoolConnector(pool *pgxpool.Pool, opts ...OptionOpenDB) driver.Connector { + c := connector{ + pool: pool, + ResetSession: func(context.Context, *pgx.Conn) error { return nil }, // noop reset session by default + driver: pgxDriver, + } + + for _, opt := range opts { + opt(&c) + } + + return c +} + func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB { c := GetConnector(config, opts...) return sql.OpenDB(c) } +// OpenDBFromPool creates a new *sql.DB from the given *pgxpool.Pool. Note that this method automatically sets the +// maximum number of idle connections in *sql.DB to zero, since they must be managed from the *pgxpool.Pool. This is +// required to avoid acquiring all the connections from the pgxpool and starving any direct users of the pgxpool. +func OpenDBFromPool(pool *pgxpool.Pool, opts ...OptionOpenDB) *sql.DB { + c := GetPoolConnector(pool, opts...) + db := sql.OpenDB(c) + db.SetMaxIdleConns(0) + return db +} + type connector struct { pgx.ConnConfig + pool *pgxpool.Pool BeforeConnect func(context.Context, *pgx.ConnConfig) error // function to call before creation of every new connection AfterConnect func(context.Context, *pgx.Conn) error // function to call after creation of every new connection ResetSession func(context.Context, *pgx.Conn) error // function is called before a connection is reused @@ -207,25 +246,53 @@ type connector struct { // Connect implement driver.Connector interface func (c connector) Connect(ctx context.Context) (driver.Conn, error) { var ( - err error - conn *pgx.Conn + connConfig pgx.ConnConfig + conn *pgx.Conn + close func(context.Context) error + err error ) - // Create a shallow copy of the config, so that BeforeConnect can safely modify it - connConfig := c.ConnConfig - if err = c.BeforeConnect(ctx, &connConfig); err != nil { - return nil, err - } + if c.pool == nil { + // Create a shallow copy of the config, so that BeforeConnect can safely modify it + connConfig = c.ConnConfig - if conn, err = pgx.ConnectConfig(ctx, &connConfig); err != nil { - return nil, err - } + if err = c.BeforeConnect(ctx, &connConfig); err != nil { + return nil, err + } - if err = c.AfterConnect(ctx, conn); err != nil { - return nil, err + if conn, err = pgx.ConnectConfig(ctx, &connConfig); err != nil { + return nil, err + } + + if err = c.AfterConnect(ctx, conn); err != nil { + return nil, err + } + + close = conn.Close + } else { + var pconn *pgxpool.Conn + + pconn, err = c.pool.Acquire(ctx) + if err != nil { + return nil, err + } + + conn = pconn.Conn() + + close = func(_ context.Context) error { + pconn.Release() + return nil + } } - return &Conn{conn: conn, driver: c.driver, connConfig: connConfig, resetSessionFunc: c.ResetSession}, nil + return &Conn{ + conn: conn, + close: close, + driver: c.driver, + connConfig: connConfig, + resetSessionFunc: c.ResetSession, + psRefCounts: make(map[*pgconn.StatementDescription]int), + }, nil } // Driver implement driver.Connector interface @@ -302,9 +369,11 @@ func (dc *driverConnector) Connect(ctx context.Context) (driver.Conn, error) { c := &Conn{ conn: conn, + close: conn.Close, driver: dc.driver, connConfig: *connConfig, resetSessionFunc: func(context.Context, *pgx.Conn) error { return nil }, + psRefCounts: make(map[*pgconn.StatementDescription]int), } return c, nil @@ -326,11 +395,19 @@ func UnregisterConnConfig(connStr string) { type Conn struct { conn *pgx.Conn - psCount int64 // Counter used for creating unique prepared statement names + close func(context.Context) error driver *Driver connConfig pgx.ConnConfig resetSessionFunc func(context.Context, *pgx.Conn) error // Function is called before a connection is reused lastResetSessionTime time.Time + + // psRefCounts contains reference counts for prepared statements. Prepare uses the underlying pgx logic to generate + // deterministic statement names from the statement text. If this query has already been prepared then the existing + // *pgconn.StatementDescription will be returned. However, this means that if Close is called on the returned Stmt + // then the underlying prepared statement will be closed even when the underlying prepared statement is still in use + // by another database/sql Stmt. To prevent this psRefCounts keeps track of how many database/sql statements are using + // the same underlying statement and only closes the underlying statement when the reference count reaches 0. + psRefCounts map[*pgconn.StatementDescription]int } // Conn returns the underlying *pgx.Conn @@ -347,13 +424,11 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e return nil, driver.ErrBadConn } - name := fmt.Sprintf("pgx_%d", c.psCount) - c.psCount++ - - sd, err := c.conn.Prepare(ctx, name, query) + sd, err := c.conn.Prepare(ctx, query, query) if err != nil { return nil, err } + c.psRefCounts[sd]++ return &Stmt{sd: sd, conn: c}, nil } @@ -361,7 +436,7 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e func (c *Conn) Close() error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - return c.conn.Close(ctx) + return c.close(ctx) } func (c *Conn) Begin() (driver.Tx, error) { @@ -470,7 +545,7 @@ func (c *Conn) ResetSession(ctx context.Context) error { now := time.Now() if now.Sub(c.lastResetSessionTime) > time.Second { - if err := c.conn.PgConn().CheckConn(); err != nil { + if err := c.conn.PgConn().Ping(ctx); err != nil { return driver.ErrBadConn } } @@ -487,7 +562,16 @@ type Stmt struct { func (s *Stmt) Close() error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - return s.conn.conn.Deallocate(ctx, s.sd.Name) + + refCount := s.conn.psRefCounts[s.sd] + if refCount == 1 { + delete(s.conn.psRefCounts, s.sd) + } else { + s.conn.psRefCounts[s.sd]-- + return nil + } + + return s.conn.conn.Deallocate(ctx, s.sd.SQL) } func (s *Stmt) NumInput() int { @@ -499,7 +583,7 @@ func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) { } func (s *Stmt) ExecContext(ctx context.Context, argsV []driver.NamedValue) (driver.Result, error) { - return s.conn.ExecContext(ctx, s.sd.Name, argsV) + return s.conn.ExecContext(ctx, s.sd.SQL, argsV) } func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) { @@ -507,7 +591,7 @@ func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) { } func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (driver.Rows, error) { - return s.conn.QueryContext(ctx, s.sd.Name, argsV) + return s.conn.QueryContext(ctx, s.sd.SQL, argsV) } type rowValueFunc func(src []byte) (driver.Value, error) @@ -753,7 +837,7 @@ func (r *Rows) Next(dest []driver.Value) error { var err error dest[i], err = r.valueFuncs[i](rv) if err != nil { - return fmt.Errorf("convert field %d failed: %v", i, err) + return fmt.Errorf("convert field %d failed: %w", i, err) } } else { dest[i] = nil diff --git a/cluster-manager/vendor/github.com/jackc/pgx/v5/values.go b/cluster-manager/vendor/github.com/jackc/pgx/v5/values.go index 19c642fa..6e2ff300 100644 --- a/cluster-manager/vendor/github.com/jackc/pgx/v5/values.go +++ b/cluster-manager/vendor/github.com/jackc/pgx/v5/values.go @@ -3,7 +3,6 @@ package pgx import ( "errors" - "github.com/jackc/pgx/v5/internal/anynil" "github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/pgtype" ) @@ -15,10 +14,6 @@ const ( ) func convertSimpleArgument(m *pgtype.Map, arg any) (any, error) { - if anynil.Is(arg) { - return nil, nil - } - buf, err := m.Encode(0, TextFormatCode, arg, []byte{}) if err != nil { return nil, err @@ -30,10 +25,6 @@ func convertSimpleArgument(m *pgtype.Map, arg any) (any, error) { } func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, error) { - if anynil.Is(arg) { - return pgio.AppendInt32(buf, -1), nil - } - sp := len(buf) buf = pgio.AppendInt32(buf, -1) argBuf, err := m.Encode(oid, BinaryFormatCode, arg, buf) @@ -55,7 +46,11 @@ func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, er func tryScanStringCopyValueThenEncode(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, error) { s, ok := arg.(string) if !ok { - return nil, errors.New("not a string") + textBuf, err := m.Encode(oid, TextFormatCode, arg, nil) + if err != nil { + return nil, errors.New("not a string and cannot be encoded as text") + } + s = string(textBuf) } var v any diff --git a/cluster-manager/vendor/github.com/jackc/puddle/v2/CHANGELOG.md b/cluster-manager/vendor/github.com/jackc/puddle/v2/CHANGELOG.md new file mode 100644 index 00000000..d0d202c7 --- /dev/null +++ b/cluster-manager/vendor/github.com/jackc/puddle/v2/CHANGELOG.md @@ -0,0 +1,79 @@ +# 2.2.2 (September 10, 2024) + +* Add empty acquire time to stats (Maxim Ivanov) +* Stop importing nanotime from runtime via linkname (maypok86) + +# 2.2.1 (July 15, 2023) + +* Fix: CreateResource cannot overflow pool. This changes documented behavior of CreateResource. Previously, + CreateResource could create a resource even if the pool was full. This could cause the pool to overflow. While this + was documented, it was documenting incorrect behavior. CreateResource now returns an error if the pool is full. + +# 2.2.0 (February 11, 2023) + +* Use Go 1.19 atomics and drop go.uber.org/atomic dependency + +# 2.1.2 (November 12, 2022) + +* Restore support to Go 1.18 via go.uber.org/atomic + +# 2.1.1 (November 11, 2022) + +* Fix create resource concurrently with Stat call race + +# 2.1.0 (October 28, 2022) + +* Concurrency control is now implemented with a semaphore. This simplifies some internal logic, resolves a few error conditions (including a deadlock), and improves performance. (Jan Dubsky) +* Go 1.19 is now required for the improved atomic support. + +# 2.0.1 (October 28, 2022) + +* Fix race condition when Close is called concurrently with multiple constructors + +# 2.0.0 (September 17, 2022) + +* Use generics instead of interface{} (Столяров Владимир Алексеевич) +* Add Reset +* Do not cancel resource construction when Acquire is canceled +* NewPool takes Config + +# 1.3.0 (August 27, 2022) + +* Acquire creates resources in background to allow creation to continue after Acquire is canceled (James Hartig) + +# 1.2.1 (December 2, 2021) + +* TryAcquire now does not block when background constructing resource + +# 1.2.0 (November 20, 2021) + +* Add TryAcquire (A. Jensen) +* Fix: remove memory leak / unintentionally pinned memory when shrinking slices (Alexander Staubo) +* Fix: Do not leave pool locked after panic from nil context + +# 1.1.4 (September 11, 2021) + +* Fix: Deadlock in CreateResource if pool was closed during resource acquisition (Dmitriy Matrenichev) + +# 1.1.3 (December 3, 2020) + +* Fix: Failed resource creation could cause concurrent Acquire to hang. (Evgeny Vanslov) + +# 1.1.2 (September 26, 2020) + +* Fix: Resource.Destroy no longer removes itself from the pool before its destructor has completed. +* Fix: Prevent crash when pool is closed while resource is being created. + +# 1.1.1 (April 2, 2020) + +* Pool.Close can be safely called multiple times +* AcquireAllIDle immediately returns nil if pool is closed +* CreateResource checks if pool is closed before taking any action +* Fix potential race condition when CreateResource and Close are called concurrently. CreateResource now checks if pool is closed before adding newly created resource to pool. + +# 1.1.0 (February 5, 2020) + +* Use runtime.nanotime for faster tracking of acquire time and last usage time. +* Track resource idle time to enable client health check logic. (Patrick Ellul) +* Add CreateResource to construct a new resource without acquiring it. (Patrick Ellul) +* Fix deadlock race when acquire is cancelled. (Michael Tharp) diff --git a/cluster-manager/vendor/github.com/jackc/puddle/v2/LICENSE b/cluster-manager/vendor/github.com/jackc/puddle/v2/LICENSE new file mode 100644 index 00000000..bcc286c5 --- /dev/null +++ b/cluster-manager/vendor/github.com/jackc/puddle/v2/LICENSE @@ -0,0 +1,22 @@ +Copyright (c) 2018 Jack Christensen + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/cluster-manager/vendor/github.com/jackc/puddle/v2/README.md b/cluster-manager/vendor/github.com/jackc/puddle/v2/README.md new file mode 100644 index 00000000..fa82a9d4 --- /dev/null +++ b/cluster-manager/vendor/github.com/jackc/puddle/v2/README.md @@ -0,0 +1,80 @@ +[![Go Reference](https://pkg.go.dev/badge/github.com/jackc/puddle/v2.svg)](https://pkg.go.dev/github.com/jackc/puddle/v2) +![Build Status](https://github.com/jackc/puddle/actions/workflows/ci.yml/badge.svg) + +# Puddle + +Puddle is a tiny generic resource pool library for Go that uses the standard +context library to signal cancellation of acquires. It is designed to contain +the minimum functionality required for a resource pool. It can be used directly +or it can be used as the base for a domain specific resource pool. For example, +a database connection pool may use puddle internally and implement health checks +and keep-alive behavior without needing to implement any concurrent code of its +own. + +## Features + +* Acquire cancellation via context standard library +* Statistics API for monitoring pool pressure +* No dependencies outside of standard library and golang.org/x/sync +* High performance +* 100% test coverage of reachable code + +## Example Usage + +```go +package main + +import ( + "context" + "log" + "net" + + "github.com/jackc/puddle/v2" +) + +func main() { + constructor := func(context.Context) (net.Conn, error) { + return net.Dial("tcp", "127.0.0.1:8080") + } + destructor := func(value net.Conn) { + value.Close() + } + maxPoolSize := int32(10) + + pool, err := puddle.NewPool(&puddle.Config[net.Conn]{Constructor: constructor, Destructor: destructor, MaxSize: maxPoolSize}) + if err != nil { + log.Fatal(err) + } + + // Acquire resource from the pool. + res, err := pool.Acquire(context.Background()) + if err != nil { + log.Fatal(err) + } + + // Use resource. + _, err = res.Value().Write([]byte{1}) + if err != nil { + log.Fatal(err) + } + + // Release when done. + res.Release() +} +``` + +## Status + +Puddle is stable and feature complete. + +* Bug reports and fixes are welcome. +* New features will usually not be accepted if they can be feasibly implemented in a wrapper. +* Performance optimizations will usually not be accepted unless the performance issue rises to the level of a bug. + +## Supported Go Versions + +puddle supports the same versions of Go that are supported by the Go project. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases. This means puddle supports Go 1.19 and higher. + +## License + +MIT diff --git a/cluster-manager/vendor/github.com/jackc/puddle/v2/context.go b/cluster-manager/vendor/github.com/jackc/puddle/v2/context.go new file mode 100644 index 00000000..e19d2a60 --- /dev/null +++ b/cluster-manager/vendor/github.com/jackc/puddle/v2/context.go @@ -0,0 +1,24 @@ +package puddle + +import ( + "context" + "time" +) + +// valueCancelCtx combines two contexts into one. One context is used for values and the other is used for cancellation. +type valueCancelCtx struct { + valueCtx context.Context + cancelCtx context.Context +} + +func (ctx *valueCancelCtx) Deadline() (time.Time, bool) { return ctx.cancelCtx.Deadline() } +func (ctx *valueCancelCtx) Done() <-chan struct{} { return ctx.cancelCtx.Done() } +func (ctx *valueCancelCtx) Err() error { return ctx.cancelCtx.Err() } +func (ctx *valueCancelCtx) Value(key any) any { return ctx.valueCtx.Value(key) } + +func newValueCancelCtx(valueCtx, cancelContext context.Context) context.Context { + return &valueCancelCtx{ + valueCtx: valueCtx, + cancelCtx: cancelContext, + } +} diff --git a/cluster-manager/vendor/github.com/jackc/puddle/v2/doc.go b/cluster-manager/vendor/github.com/jackc/puddle/v2/doc.go new file mode 100644 index 00000000..818e4a69 --- /dev/null +++ b/cluster-manager/vendor/github.com/jackc/puddle/v2/doc.go @@ -0,0 +1,11 @@ +// Package puddle is a generic resource pool with type-parametrized api. +/* + +Puddle is a tiny generic resource pool library for Go that uses the standard +context library to signal cancellation of acquires. It is designed to contain +the minimum functionality a resource pool needs that cannot be implemented +without concurrency concerns. For example, a database connection pool may use +puddle internally and implement health checks and keep-alive behavior without +needing to implement any concurrent code of its own. +*/ +package puddle diff --git a/cluster-manager/vendor/github.com/jackc/puddle/v2/internal/genstack/gen_stack.go b/cluster-manager/vendor/github.com/jackc/puddle/v2/internal/genstack/gen_stack.go new file mode 100644 index 00000000..7e4660c8 --- /dev/null +++ b/cluster-manager/vendor/github.com/jackc/puddle/v2/internal/genstack/gen_stack.go @@ -0,0 +1,85 @@ +package genstack + +// GenStack implements a generational stack. +// +// GenStack works as common stack except for the fact that all elements in the +// older generation are guaranteed to be popped before any element in the newer +// generation. New elements are always pushed to the current (newest) +// generation. +// +// We could also say that GenStack behaves as a stack in case of a single +// generation, but it behaves as a queue of individual generation stacks. +type GenStack[T any] struct { + // We can represent arbitrary number of generations using 2 stacks. The + // new stack stores all new pushes and the old stack serves all reads. + // Old stack can represent multiple generations. If old == new, then all + // elements pushed in previous (not current) generations have already + // been popped. + + old *stack[T] + new *stack[T] +} + +// NewGenStack creates a new empty GenStack. +func NewGenStack[T any]() *GenStack[T] { + s := &stack[T]{} + return &GenStack[T]{ + old: s, + new: s, + } +} + +func (s *GenStack[T]) Pop() (T, bool) { + // Pushes always append to the new stack, so if the old once becomes + // empty, it will remail empty forever. + if s.old.len() == 0 && s.old != s.new { + s.old = s.new + } + + if s.old.len() == 0 { + var zero T + return zero, false + } + + return s.old.pop(), true +} + +// Push pushes a new element at the top of the stack. +func (s *GenStack[T]) Push(v T) { s.new.push(v) } + +// NextGen starts a new stack generation. +func (s *GenStack[T]) NextGen() { + if s.old == s.new { + s.new = &stack[T]{} + return + } + + // We need to pop from the old stack to the top of the new stack. Let's + // have an example: + // + // Old: 4 3 2 1 + // New: 8 7 6 5 + // PopOrder: 1 2 3 4 5 6 7 8 + // + // + // To preserve pop order, we have to take all elements from the old + // stack and push them to the top of new stack: + // + // New: 8 7 6 5 4 3 2 1 + // + s.new.push(s.old.takeAll()...) + + // We have the old stack allocated and empty, so why not to reuse it as + // new new stack. + s.old, s.new = s.new, s.old +} + +// Len returns number of elements in the stack. +func (s *GenStack[T]) Len() int { + l := s.old.len() + if s.old != s.new { + l += s.new.len() + } + + return l +} diff --git a/cluster-manager/vendor/github.com/jackc/puddle/v2/internal/genstack/stack.go b/cluster-manager/vendor/github.com/jackc/puddle/v2/internal/genstack/stack.go new file mode 100644 index 00000000..dbced0c7 --- /dev/null +++ b/cluster-manager/vendor/github.com/jackc/puddle/v2/internal/genstack/stack.go @@ -0,0 +1,39 @@ +package genstack + +// stack is a wrapper around an array implementing a stack. +// +// We cannot use slice to represent the stack because append might change the +// pointer value of the slice. That would be an issue in GenStack +// implementation. +type stack[T any] struct { + arr []T +} + +// push pushes a new element at the top of a stack. +func (s *stack[T]) push(vs ...T) { s.arr = append(s.arr, vs...) } + +// pop pops the stack top-most element. +// +// If stack length is zero, this method panics. +func (s *stack[T]) pop() T { + idx := s.len() - 1 + val := s.arr[idx] + + // Avoid memory leak + var zero T + s.arr[idx] = zero + + s.arr = s.arr[:idx] + return val +} + +// takeAll returns all elements in the stack in order as they are stored - i.e. +// the top-most stack element is the last one. +func (s *stack[T]) takeAll() []T { + arr := s.arr + s.arr = nil + return arr +} + +// len returns number of elements in the stack. +func (s *stack[T]) len() int { return len(s.arr) } diff --git a/cluster-manager/vendor/github.com/jackc/puddle/v2/log.go b/cluster-manager/vendor/github.com/jackc/puddle/v2/log.go new file mode 100644 index 00000000..b21b9463 --- /dev/null +++ b/cluster-manager/vendor/github.com/jackc/puddle/v2/log.go @@ -0,0 +1,32 @@ +package puddle + +import "unsafe" + +type ints interface { + int | int8 | int16 | int32 | int64 | uint | uint8 | uint16 | uint32 | uint64 +} + +// log2Int returns log2 of an integer. This function panics if val < 0. For val +// == 0, returns 0. +func log2Int[T ints](val T) uint8 { + if val <= 0 { + panic("log2 of non-positive number does not exist") + } + + return log2IntRange(val, 0, uint8(8*unsafe.Sizeof(val))) +} + +func log2IntRange[T ints](val T, begin, end uint8) uint8 { + length := end - begin + if length == 1 { + return begin + } + + delim := begin + length/2 + mask := T(1) << delim + if mask > val { + return log2IntRange(val, begin, delim) + } else { + return log2IntRange(val, delim, end) + } +} diff --git a/cluster-manager/vendor/github.com/jackc/puddle/v2/nanotime.go b/cluster-manager/vendor/github.com/jackc/puddle/v2/nanotime.go new file mode 100644 index 00000000..8a5351a0 --- /dev/null +++ b/cluster-manager/vendor/github.com/jackc/puddle/v2/nanotime.go @@ -0,0 +1,16 @@ +package puddle + +import "time" + +// nanotime returns the time in nanoseconds since process start. +// +// This approach, described at +// https://github.com/golang/go/issues/61765#issuecomment-1672090302, +// is fast, monotonic, and portable, and avoids the previous +// dependence on runtime.nanotime using the (unsafe) linkname hack. +// In particular, time.Since does less work than time.Now. +func nanotime() int64 { + return time.Since(globalStart).Nanoseconds() +} + +var globalStart = time.Now() diff --git a/cluster-manager/vendor/github.com/jackc/puddle/v2/pool.go b/cluster-manager/vendor/github.com/jackc/puddle/v2/pool.go new file mode 100644 index 00000000..c411d2f6 --- /dev/null +++ b/cluster-manager/vendor/github.com/jackc/puddle/v2/pool.go @@ -0,0 +1,710 @@ +package puddle + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "time" + + "github.com/jackc/puddle/v2/internal/genstack" + "golang.org/x/sync/semaphore" +) + +const ( + resourceStatusConstructing = 0 + resourceStatusIdle = iota + resourceStatusAcquired = iota + resourceStatusHijacked = iota +) + +// ErrClosedPool occurs on an attempt to acquire a connection from a closed pool +// or a pool that is closed while the acquire is waiting. +var ErrClosedPool = errors.New("closed pool") + +// ErrNotAvailable occurs on an attempt to acquire a resource from a pool +// that is at maximum capacity and has no available resources. +var ErrNotAvailable = errors.New("resource not available") + +// Constructor is a function called by the pool to construct a resource. +type Constructor[T any] func(ctx context.Context) (res T, err error) + +// Destructor is a function called by the pool to destroy a resource. +type Destructor[T any] func(res T) + +// Resource is the resource handle returned by acquiring from the pool. +type Resource[T any] struct { + value T + pool *Pool[T] + creationTime time.Time + lastUsedNano int64 + poolResetCount int + status byte +} + +// Value returns the resource value. +func (res *Resource[T]) Value() T { + if !(res.status == resourceStatusAcquired || res.status == resourceStatusHijacked) { + panic("tried to access resource that is not acquired or hijacked") + } + return res.value +} + +// Release returns the resource to the pool. res must not be subsequently used. +func (res *Resource[T]) Release() { + if res.status != resourceStatusAcquired { + panic("tried to release resource that is not acquired") + } + res.pool.releaseAcquiredResource(res, nanotime()) +} + +// ReleaseUnused returns the resource to the pool without updating when it was last used used. i.e. LastUsedNanotime +// will not change. res must not be subsequently used. +func (res *Resource[T]) ReleaseUnused() { + if res.status != resourceStatusAcquired { + panic("tried to release resource that is not acquired") + } + res.pool.releaseAcquiredResource(res, res.lastUsedNano) +} + +// Destroy returns the resource to the pool for destruction. res must not be +// subsequently used. +func (res *Resource[T]) Destroy() { + if res.status != resourceStatusAcquired { + panic("tried to destroy resource that is not acquired") + } + go res.pool.destroyAcquiredResource(res) +} + +// Hijack assumes ownership of the resource from the pool. Caller is responsible +// for cleanup of resource value. +func (res *Resource[T]) Hijack() { + if res.status != resourceStatusAcquired { + panic("tried to hijack resource that is not acquired") + } + res.pool.hijackAcquiredResource(res) +} + +// CreationTime returns when the resource was created by the pool. +func (res *Resource[T]) CreationTime() time.Time { + if !(res.status == resourceStatusAcquired || res.status == resourceStatusHijacked) { + panic("tried to access resource that is not acquired or hijacked") + } + return res.creationTime +} + +// LastUsedNanotime returns when Release was last called on the resource measured in nanoseconds from an arbitrary time +// (a monotonic time). Returns creation time if Release has never been called. This is only useful to compare with +// other calls to LastUsedNanotime. In almost all cases, IdleDuration should be used instead. +func (res *Resource[T]) LastUsedNanotime() int64 { + if !(res.status == resourceStatusAcquired || res.status == resourceStatusHijacked) { + panic("tried to access resource that is not acquired or hijacked") + } + + return res.lastUsedNano +} + +// IdleDuration returns the duration since Release was last called on the resource. This is equivalent to subtracting +// LastUsedNanotime to the current nanotime. +func (res *Resource[T]) IdleDuration() time.Duration { + if !(res.status == resourceStatusAcquired || res.status == resourceStatusHijacked) { + panic("tried to access resource that is not acquired or hijacked") + } + + return time.Duration(nanotime() - res.lastUsedNano) +} + +// Pool is a concurrency-safe resource pool. +type Pool[T any] struct { + // mux is the pool internal lock. Any modification of shared state of + // the pool (but Acquires of acquireSem) must be performed only by + // holder of the lock. Long running operations are not allowed when mux + // is held. + mux sync.Mutex + // acquireSem provides an allowance to acquire a resource. + // + // Releases are allowed only when caller holds mux. Acquires have to + // happen before mux is locked (doesn't apply to semaphore.TryAcquire in + // AcquireAllIdle). + acquireSem *semaphore.Weighted + destructWG sync.WaitGroup + + allResources resList[T] + idleResources *genstack.GenStack[*Resource[T]] + + constructor Constructor[T] + destructor Destructor[T] + maxSize int32 + + acquireCount int64 + acquireDuration time.Duration + emptyAcquireCount int64 + emptyAcquireWaitTime time.Duration + canceledAcquireCount atomic.Int64 + + resetCount int + + baseAcquireCtx context.Context + cancelBaseAcquireCtx context.CancelFunc + closed bool +} + +type Config[T any] struct { + Constructor Constructor[T] + Destructor Destructor[T] + MaxSize int32 +} + +// NewPool creates a new pool. Returns an error iff MaxSize is less than 1. +func NewPool[T any](config *Config[T]) (*Pool[T], error) { + if config.MaxSize < 1 { + return nil, errors.New("MaxSize must be >= 1") + } + + baseAcquireCtx, cancelBaseAcquireCtx := context.WithCancel(context.Background()) + + return &Pool[T]{ + acquireSem: semaphore.NewWeighted(int64(config.MaxSize)), + idleResources: genstack.NewGenStack[*Resource[T]](), + maxSize: config.MaxSize, + constructor: config.Constructor, + destructor: config.Destructor, + baseAcquireCtx: baseAcquireCtx, + cancelBaseAcquireCtx: cancelBaseAcquireCtx, + }, nil +} + +// Close destroys all resources in the pool and rejects future Acquire calls. +// Blocks until all resources are returned to pool and destroyed. +func (p *Pool[T]) Close() { + defer p.destructWG.Wait() + + p.mux.Lock() + defer p.mux.Unlock() + + if p.closed { + return + } + p.closed = true + p.cancelBaseAcquireCtx() + + for res, ok := p.idleResources.Pop(); ok; res, ok = p.idleResources.Pop() { + p.allResources.remove(res) + go p.destructResourceValue(res.value) + } +} + +// Stat is a snapshot of Pool statistics. +type Stat struct { + constructingResources int32 + acquiredResources int32 + idleResources int32 + maxResources int32 + acquireCount int64 + acquireDuration time.Duration + emptyAcquireCount int64 + emptyAcquireWaitTime time.Duration + canceledAcquireCount int64 +} + +// TotalResources returns the total number of resources currently in the pool. +// The value is the sum of ConstructingResources, AcquiredResources, and +// IdleResources. +func (s *Stat) TotalResources() int32 { + return s.constructingResources + s.acquiredResources + s.idleResources +} + +// ConstructingResources returns the number of resources with construction in progress in +// the pool. +func (s *Stat) ConstructingResources() int32 { + return s.constructingResources +} + +// AcquiredResources returns the number of currently acquired resources in the pool. +func (s *Stat) AcquiredResources() int32 { + return s.acquiredResources +} + +// IdleResources returns the number of currently idle resources in the pool. +func (s *Stat) IdleResources() int32 { + return s.idleResources +} + +// MaxResources returns the maximum size of the pool. +func (s *Stat) MaxResources() int32 { + return s.maxResources +} + +// AcquireCount returns the cumulative count of successful acquires from the pool. +func (s *Stat) AcquireCount() int64 { + return s.acquireCount +} + +// AcquireDuration returns the total duration of all successful acquires from +// the pool. +func (s *Stat) AcquireDuration() time.Duration { + return s.acquireDuration +} + +// EmptyAcquireCount returns the cumulative count of successful acquires from the pool +// that waited for a resource to be released or constructed because the pool was +// empty. +func (s *Stat) EmptyAcquireCount() int64 { + return s.emptyAcquireCount +} + +// EmptyAcquireWaitTime returns the cumulative time waited for successful acquires +// from the pool for a resource to be released or constructed because the pool was +// empty. +func (s *Stat) EmptyAcquireWaitTime() time.Duration { + return s.emptyAcquireWaitTime +} + +// CanceledAcquireCount returns the cumulative count of acquires from the pool +// that were canceled by a context. +func (s *Stat) CanceledAcquireCount() int64 { + return s.canceledAcquireCount +} + +// Stat returns the current pool statistics. +func (p *Pool[T]) Stat() *Stat { + p.mux.Lock() + defer p.mux.Unlock() + + s := &Stat{ + maxResources: p.maxSize, + acquireCount: p.acquireCount, + emptyAcquireCount: p.emptyAcquireCount, + emptyAcquireWaitTime: p.emptyAcquireWaitTime, + canceledAcquireCount: p.canceledAcquireCount.Load(), + acquireDuration: p.acquireDuration, + } + + for _, res := range p.allResources { + switch res.status { + case resourceStatusConstructing: + s.constructingResources += 1 + case resourceStatusIdle: + s.idleResources += 1 + case resourceStatusAcquired: + s.acquiredResources += 1 + } + } + + return s +} + +// tryAcquireIdleResource checks if there is any idle resource. If there is +// some, this method removes it from idle list and returns it. If the idle pool +// is empty, this method returns nil and doesn't modify the idleResources slice. +// +// WARNING: Caller of this method must hold the pool mutex! +func (p *Pool[T]) tryAcquireIdleResource() *Resource[T] { + res, ok := p.idleResources.Pop() + if !ok { + return nil + } + + res.status = resourceStatusAcquired + return res +} + +// createNewResource creates a new resource and inserts it into list of pool +// resources. +// +// WARNING: Caller of this method must hold the pool mutex! +func (p *Pool[T]) createNewResource() *Resource[T] { + res := &Resource[T]{ + pool: p, + creationTime: time.Now(), + lastUsedNano: nanotime(), + poolResetCount: p.resetCount, + status: resourceStatusConstructing, + } + + p.allResources.append(res) + p.destructWG.Add(1) + + return res +} + +// Acquire gets a resource from the pool. If no resources are available and the pool is not at maximum capacity it will +// create a new resource. If the pool is at maximum capacity it will block until a resource is available. ctx can be +// used to cancel the Acquire. +// +// If Acquire creates a new resource the resource constructor function will receive a context that delegates Value() to +// ctx. Canceling ctx will cause Acquire to return immediately but it will not cancel the resource creation. This avoids +// the problem of it being impossible to create resources when the time to create a resource is greater than any one +// caller of Acquire is willing to wait. +func (p *Pool[T]) Acquire(ctx context.Context) (_ *Resource[T], err error) { + select { + case <-ctx.Done(): + p.canceledAcquireCount.Add(1) + return nil, ctx.Err() + default: + } + + return p.acquire(ctx) +} + +// acquire is a continuation of Acquire function that doesn't check context +// validity. +// +// This function exists solely only for benchmarking purposes. +func (p *Pool[T]) acquire(ctx context.Context) (*Resource[T], error) { + startNano := nanotime() + + var waitedForLock bool + if !p.acquireSem.TryAcquire(1) { + waitedForLock = true + err := p.acquireSem.Acquire(ctx, 1) + if err != nil { + p.canceledAcquireCount.Add(1) + return nil, err + } + } + + p.mux.Lock() + if p.closed { + p.acquireSem.Release(1) + p.mux.Unlock() + return nil, ErrClosedPool + } + + // If a resource is available in the pool. + if res := p.tryAcquireIdleResource(); res != nil { + waitTime := time.Duration(nanotime() - startNano) + if waitedForLock { + p.emptyAcquireCount += 1 + p.emptyAcquireWaitTime += waitTime + } + p.acquireCount += 1 + p.acquireDuration += waitTime + p.mux.Unlock() + return res, nil + } + + if len(p.allResources) >= int(p.maxSize) { + // Unreachable code. + panic("bug: semaphore allowed more acquires than pool allows") + } + + // The resource is not idle, but there is enough space to create one. + res := p.createNewResource() + p.mux.Unlock() + + res, err := p.initResourceValue(ctx, res) + if err != nil { + return nil, err + } + + p.mux.Lock() + defer p.mux.Unlock() + + p.emptyAcquireCount += 1 + p.acquireCount += 1 + waitTime := time.Duration(nanotime() - startNano) + p.acquireDuration += waitTime + p.emptyAcquireWaitTime += waitTime + + return res, nil +} + +func (p *Pool[T]) initResourceValue(ctx context.Context, res *Resource[T]) (*Resource[T], error) { + // Create the resource in a goroutine to immediately return from Acquire + // if ctx is canceled without also canceling the constructor. + // + // See: + // - https://github.com/jackc/pgx/issues/1287 + // - https://github.com/jackc/pgx/issues/1259 + constructErrChan := make(chan error) + go func() { + constructorCtx := newValueCancelCtx(ctx, p.baseAcquireCtx) + value, err := p.constructor(constructorCtx) + if err != nil { + p.mux.Lock() + p.allResources.remove(res) + p.destructWG.Done() + + // The resource won't be acquired because its + // construction failed. We have to allow someone else to + // take that resouce. + p.acquireSem.Release(1) + p.mux.Unlock() + + select { + case constructErrChan <- err: + case <-ctx.Done(): + // The caller is cancelled, so no-one awaits the + // error. This branch avoid goroutine leak. + } + return + } + + // The resource is already in p.allResources where it might be read. So we need to acquire the lock to update its + // status. + p.mux.Lock() + res.value = value + res.status = resourceStatusAcquired + p.mux.Unlock() + + // This select works because the channel is unbuffered. + select { + case constructErrChan <- nil: + case <-ctx.Done(): + p.releaseAcquiredResource(res, res.lastUsedNano) + } + }() + + select { + case <-ctx.Done(): + p.canceledAcquireCount.Add(1) + return nil, ctx.Err() + case err := <-constructErrChan: + if err != nil { + return nil, err + } + return res, nil + } +} + +// TryAcquire gets a resource from the pool if one is immediately available. If not, it returns ErrNotAvailable. If no +// resources are available but the pool has room to grow, a resource will be created in the background. ctx is only +// used to cancel the background creation. +func (p *Pool[T]) TryAcquire(ctx context.Context) (*Resource[T], error) { + if !p.acquireSem.TryAcquire(1) { + return nil, ErrNotAvailable + } + + p.mux.Lock() + defer p.mux.Unlock() + + if p.closed { + p.acquireSem.Release(1) + return nil, ErrClosedPool + } + + // If a resource is available now + if res := p.tryAcquireIdleResource(); res != nil { + p.acquireCount += 1 + return res, nil + } + + if len(p.allResources) >= int(p.maxSize) { + // Unreachable code. + panic("bug: semaphore allowed more acquires than pool allows") + } + + res := p.createNewResource() + go func() { + value, err := p.constructor(ctx) + + p.mux.Lock() + defer p.mux.Unlock() + // We have to create the resource and only then release the + // semaphore - For the time being there is no resource that + // someone could acquire. + defer p.acquireSem.Release(1) + + if err != nil { + p.allResources.remove(res) + p.destructWG.Done() + return + } + + res.value = value + res.status = resourceStatusIdle + p.idleResources.Push(res) + }() + + return nil, ErrNotAvailable +} + +// acquireSemAll tries to acquire num free tokens from sem. This function is +// guaranteed to acquire at least the lowest number of tokens that has been +// available in the semaphore during runtime of this function. +// +// For the time being, semaphore doesn't allow to acquire all tokens atomically +// (see https://github.com/golang/sync/pull/19). We simulate this by trying all +// powers of 2 that are less or equal to num. +// +// For example, let's immagine we have 19 free tokens in the semaphore which in +// total has 24 tokens (i.e. the maxSize of the pool is 24 resources). Then if +// num is 24, the log2Uint(24) is 4 and we try to acquire 16, 8, 4, 2 and 1 +// tokens. Out of those, the acquire of 16, 2 and 1 tokens will succeed. +// +// Naturally, Acquires and Releases of the semaphore might take place +// concurrently. For this reason, it's not guaranteed that absolutely all free +// tokens in the semaphore will be acquired. But it's guaranteed that at least +// the minimal number of tokens that has been present over the whole process +// will be acquired. This is sufficient for the use-case we have in this +// package. +// +// TODO: Replace this with acquireSem.TryAcquireAll() if it gets to +// upstream. https://github.com/golang/sync/pull/19 +func acquireSemAll(sem *semaphore.Weighted, num int) int { + if sem.TryAcquire(int64(num)) { + return num + } + + var acquired int + for i := int(log2Int(num)); i >= 0; i-- { + val := 1 << i + if sem.TryAcquire(int64(val)) { + acquired += val + } + } + + return acquired +} + +// AcquireAllIdle acquires all currently idle resources. Its intended use is for +// health check and keep-alive functionality. It does not update pool +// statistics. +func (p *Pool[T]) AcquireAllIdle() []*Resource[T] { + p.mux.Lock() + defer p.mux.Unlock() + + if p.closed { + return nil + } + + numIdle := p.idleResources.Len() + if numIdle == 0 { + return nil + } + + // In acquireSemAll we use only TryAcquire and not Acquire. Because + // TryAcquire cannot block, the fact that we hold mutex locked and try + // to acquire semaphore cannot result in dead-lock. + // + // Because the mutex is locked, no parallel Release can run. This + // implies that the number of tokens can only decrease because some + // Acquire/TryAcquire call can consume the semaphore token. Consequently + // acquired is always less or equal to numIdle. Moreover if acquired < + // numIdle, then there are some parallel Acquire/TryAcquire calls that + // will take the remaining idle connections. + acquired := acquireSemAll(p.acquireSem, numIdle) + + idle := make([]*Resource[T], acquired) + for i := range idle { + res, _ := p.idleResources.Pop() + res.status = resourceStatusAcquired + idle[i] = res + } + + // We have to bump the generation to ensure that Acquire/TryAcquire + // calls running in parallel (those which caused acquired < numIdle) + // will consume old connections and not freshly released connections + // instead. + p.idleResources.NextGen() + + return idle +} + +// CreateResource constructs a new resource without acquiring it. It goes straight in the IdlePool. If the pool is full +// it returns an error. It can be useful to maintain warm resources under little load. +func (p *Pool[T]) CreateResource(ctx context.Context) error { + if !p.acquireSem.TryAcquire(1) { + return ErrNotAvailable + } + + p.mux.Lock() + if p.closed { + p.acquireSem.Release(1) + p.mux.Unlock() + return ErrClosedPool + } + + if len(p.allResources) >= int(p.maxSize) { + p.acquireSem.Release(1) + p.mux.Unlock() + return ErrNotAvailable + } + + res := p.createNewResource() + p.mux.Unlock() + + value, err := p.constructor(ctx) + p.mux.Lock() + defer p.mux.Unlock() + defer p.acquireSem.Release(1) + if err != nil { + p.allResources.remove(res) + p.destructWG.Done() + return err + } + + res.value = value + res.status = resourceStatusIdle + + // If closed while constructing resource then destroy it and return an error + if p.closed { + go p.destructResourceValue(res.value) + return ErrClosedPool + } + + p.idleResources.Push(res) + + return nil +} + +// Reset destroys all resources, but leaves the pool open. It is intended for use when an error is detected that would +// disrupt all resources (such as a network interruption or a server state change). +// +// It is safe to reset a pool while resources are checked out. Those resources will be destroyed when they are returned +// to the pool. +func (p *Pool[T]) Reset() { + p.mux.Lock() + defer p.mux.Unlock() + + p.resetCount++ + + for res, ok := p.idleResources.Pop(); ok; res, ok = p.idleResources.Pop() { + p.allResources.remove(res) + go p.destructResourceValue(res.value) + } +} + +// releaseAcquiredResource returns res to the the pool. +func (p *Pool[T]) releaseAcquiredResource(res *Resource[T], lastUsedNano int64) { + p.mux.Lock() + defer p.mux.Unlock() + defer p.acquireSem.Release(1) + + if p.closed || res.poolResetCount != p.resetCount { + p.allResources.remove(res) + go p.destructResourceValue(res.value) + } else { + res.lastUsedNano = lastUsedNano + res.status = resourceStatusIdle + p.idleResources.Push(res) + } +} + +// Remove removes res from the pool and closes it. If res is not part of the +// pool Remove will panic. +func (p *Pool[T]) destroyAcquiredResource(res *Resource[T]) { + p.destructResourceValue(res.value) + + p.mux.Lock() + defer p.mux.Unlock() + defer p.acquireSem.Release(1) + + p.allResources.remove(res) +} + +func (p *Pool[T]) hijackAcquiredResource(res *Resource[T]) { + p.mux.Lock() + defer p.mux.Unlock() + defer p.acquireSem.Release(1) + + p.allResources.remove(res) + res.status = resourceStatusHijacked + p.destructWG.Done() // not responsible for destructing hijacked resources +} + +func (p *Pool[T]) destructResourceValue(value T) { + p.destructor(value) + p.destructWG.Done() +} diff --git a/cluster-manager/vendor/github.com/jackc/puddle/v2/resource_list.go b/cluster-manager/vendor/github.com/jackc/puddle/v2/resource_list.go new file mode 100644 index 00000000..b2430959 --- /dev/null +++ b/cluster-manager/vendor/github.com/jackc/puddle/v2/resource_list.go @@ -0,0 +1,28 @@ +package puddle + +type resList[T any] []*Resource[T] + +func (l *resList[T]) append(val *Resource[T]) { *l = append(*l, val) } + +func (l *resList[T]) popBack() *Resource[T] { + idx := len(*l) - 1 + val := (*l)[idx] + (*l)[idx] = nil // Avoid memory leak + *l = (*l)[:idx] + + return val +} + +func (l *resList[T]) remove(val *Resource[T]) { + for i, elem := range *l { + if elem == val { + lastIdx := len(*l) - 1 + (*l)[i] = (*l)[lastIdx] + (*l)[lastIdx] = nil // Avoid memory leak + (*l) = (*l)[:lastIdx] + return + } + } + + panic("BUG: removeResource could not find res in slice") +} diff --git a/cluster-manager/vendor/gorm.io/driver/postgres/error_translator.go b/cluster-manager/vendor/gorm.io/driver/postgres/error_translator.go index 9c0ef253..5f813501 100644 --- a/cluster-manager/vendor/gorm.io/driver/postgres/error_translator.go +++ b/cluster-manager/vendor/gorm.io/driver/postgres/error_translator.go @@ -8,10 +8,12 @@ import ( "github.com/jackc/pgx/v5/pgconn" ) +// The error codes to map PostgreSQL errors to gorm errors, here is the PostgreSQL error codes reference https://www.postgresql.org/docs/current/errcodes-appendix.html. var errCodes = map[string]error{ "23505": gorm.ErrDuplicatedKey, "23503": gorm.ErrForeignKeyViolated, "42703": gorm.ErrInvalidField, + "23514": gorm.ErrCheckConstraintViolated, } type ErrMessage struct { diff --git a/cluster-manager/vendor/gorm.io/driver/postgres/migrator.go b/cluster-manager/vendor/gorm.io/driver/postgres/migrator.go index 6174e1c1..6b57ce69 100644 --- a/cluster-manager/vendor/gorm.io/driver/postgres/migrator.go +++ b/cluster-manager/vendor/gorm.io/driver/postgres/migrator.go @@ -3,10 +3,10 @@ package postgres import ( "database/sql" "fmt" + "github.com/jackc/pgx/v5" "regexp" "strings" - "github.com/jackc/pgx/v5" "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/migrator" @@ -38,18 +38,34 @@ WHERE ` var typeAliasMap = map[string][]string{ - "int2": {"smallint"}, - "int4": {"integer"}, - "int8": {"bigint"}, - "smallint": {"int2"}, - "integer": {"int4"}, - "bigint": {"int8"}, - "decimal": {"numeric"}, - "numeric": {"decimal"}, - "timestamptz": {"timestamp with time zone"}, - "timestamp with time zone": {"timestamptz"}, - "bool": {"boolean"}, - "boolean": {"bool"}, + "int": {"integer"}, + "int2": {"smallint"}, + "int4": {"integer"}, + "int8": {"bigint"}, + "smallint": {"int2"}, + "integer": {"int4"}, + "bigint": {"int8"}, + "date": {"date"}, + "decimal": {"numeric"}, + "numeric": {"decimal"}, + "timestamp": {"timestamp"}, + "timestamptz": {"timestamp with time zone"}, + "timestamp without time zone": {"timestamp"}, + "timestamp with time zone": {"timestamptz"}, + "bool": {"boolean"}, + "boolean": {"bool"}, + "serial2": {"smallserial"}, + "serial4": {"serial"}, + "serial8": {"bigserial"}, + "varbit": {"bit varying"}, + "char": {"character"}, + "varchar": {"character varying"}, + "float4": {"real"}, + "float8": {"double precision"}, + "time": {"time"}, + "timetz": {"time with time zone"}, + "time without time zone": {"time"}, + "time with time zone": {"timetz"}, } type Migrator struct { @@ -120,7 +136,8 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { } createIndexSQL += "INDEX " - if strings.TrimSpace(strings.ToUpper(idx.Option)) == "CONCURRENTLY" { + hasConcurrentOption := strings.TrimSpace(strings.ToUpper(idx.Option)) == "CONCURRENTLY" + if hasConcurrentOption { createIndexSQL += "CONCURRENTLY " } @@ -132,6 +149,10 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { createIndexSQL += " ?" } + if idx.Option != "" && !hasConcurrentOption { + createIndexSQL += " " + idx.Option + } + if idx.Where != "" { createIndexSQL += " WHERE " + idx.Where } @@ -312,7 +333,7 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { fileType := clause.Expr{SQL: m.DataTypeOf(field)} // check for typeName and SQL name isSameType := true - if fieldColumnType.DatabaseTypeName() != fileType.SQL { + if !strings.EqualFold(fieldColumnType.DatabaseTypeName(), fileType.SQL) { isSameType = false // if different, also check for aliases aliases := m.GetTypeAliases(fieldColumnType.DatabaseTypeName()) @@ -375,10 +396,16 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { return err } } else { - if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil { + if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { return err } } + } else if !field.HasDefaultValue { + // case - as-is column has default value and to-be column has no default value + // need to drop default + if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { + return err + } } } return nil @@ -474,8 +501,8 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, column.LengthValue = typeLenValue } - if (strings.HasPrefix(column.DefaultValueValue.String, "nextval('") && - strings.HasSuffix(column.DefaultValueValue.String, "seq'::regclass)")) || (identityIncrement.Valid && identityIncrement.String != "") { + autoIncrementValuePattern := regexp.MustCompile(`^nextval\('"?[^']+seq"?'::regclass\)$`) + if autoIncrementValuePattern.MatchString(column.DefaultValueValue.String) || (identityIncrement.Valid && identityIncrement.String != "") { column.AutoIncrementValue = sql.NullBool{Bool: true, Valid: true} column.DefaultValueValue = sql.NullString{} } diff --git a/cluster-manager/vendor/gorm.io/driver/postgres/postgres.go b/cluster-manager/vendor/gorm.io/driver/postgres/postgres.go index e865b0f8..2d8fd997 100644 --- a/cluster-manager/vendor/gorm.io/driver/postgres/postgres.go +++ b/cluster-manager/vendor/gorm.io/driver/postgres/postgres.go @@ -1,13 +1,16 @@ package postgres import ( + "context" "database/sql" "fmt" "regexp" "strconv" "strings" + "time" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/stdlib" "gorm.io/gorm" "gorm.io/gorm/callbacks" @@ -31,7 +34,7 @@ type Config struct { } var ( - timeZoneMatcher = regexp.MustCompile("(time_zone|TimeZone)=(.*?)($|&| )") + timeZoneMatcher = regexp.MustCompile("(time_zone|TimeZone|timezone)=(.*?)($|&| )") defaultIdentifierLength = 63 //maximum identifier length for postgres ) @@ -99,10 +102,23 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol } result := timeZoneMatcher.FindStringSubmatch(dialector.Config.DSN) + var options []stdlib.OptionOpenDB if len(result) > 2 { config.RuntimeParams["timezone"] = result[2] + options = append(options, stdlib.OptionAfterConnect(func(ctx context.Context, conn *pgx.Conn) error { + loc, tzErr := time.LoadLocation(result[2]) + if tzErr != nil { + return tzErr + } + conn.TypeMap().RegisterType(&pgtype.Type{ + Name: "timestamp", + OID: pgtype.TimestampOID, + Codec: &pgtype.TimestampCodec{ScanLocation: loc}, + }) + return nil + })) } - db.ConnPool = stdlib.OpenDB(*config) + db.ConnPool = stdlib.OpenDB(*config, options...) } return } @@ -228,7 +244,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { } return "decimal" case schema.String: - if field.Size > 0 { + if field.Size > 0 && field.Size <= 10485760 { return fmt.Sprintf("varchar(%d)", field.Size) } return "text" diff --git a/cluster-manager/vendor/gorm.io/gorm/README.md b/cluster-manager/vendor/gorm.io/gorm/README.md index 745dad60..24eb84c9 100644 --- a/cluster-manager/vendor/gorm.io/gorm/README.md +++ b/cluster-manager/vendor/gorm.io/gorm/README.md @@ -3,7 +3,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. [![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm) -[![test status](https://github.com/go-gorm/gorm/workflows/tests/badge.svg?branch=master "test status")](https://github.com/go-gorm/gorm/actions) +[![test status](https://github.com/go-gorm/gorm/actions/workflows/tests.yml/badge.svg)](https://github.com/go-gorm/gorm/actions) [![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT) [![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gorm?tab=doc) diff --git a/cluster-manager/vendor/gorm.io/gorm/association.go b/cluster-manager/vendor/gorm.io/gorm/association.go index e3f51d17..3a4e0e25 100644 --- a/cluster-manager/vendor/gorm.io/gorm/association.go +++ b/cluster-manager/vendor/gorm.io/gorm/association.go @@ -19,10 +19,10 @@ type Association struct { } func (db *DB) Association(column string) *Association { - association := &Association{DB: db} + association := &Association{DB: db, Unscope: db.Statement.Unscoped} table := db.Statement.Table - if err := db.Statement.Parse(db.Statement.Model); err == nil { + if association.Error = db.Statement.Parse(db.Statement.Model); association.Error == nil { db.Statement.Table = table association.Relationship = db.Statement.Schema.Relationships.Relations[column] @@ -34,8 +34,6 @@ func (db *DB) Association(column string) *Association { for db.Statement.ReflectValue.Kind() == reflect.Ptr { db.Statement.ReflectValue = db.Statement.ReflectValue.Elem() } - } else { - association.Error = err } return association @@ -58,6 +56,8 @@ func (association *Association) Find(out interface{}, conds ...interface{}) erro } func (association *Association) Append(values ...interface{}) error { + values = expandValues(values) + if association.Error == nil { switch association.Relationship.Type { case schema.HasOne, schema.BelongsTo: @@ -73,6 +73,8 @@ func (association *Association) Append(values ...interface{}) error { } func (association *Association) Replace(values ...interface{}) error { + values = expandValues(values) + if association.Error == nil { reflectValue := association.DB.Statement.ReflectValue rel := association.Relationship @@ -97,7 +99,7 @@ func (association *Association) Replace(values ...interface{}) error { return association.Error } - // set old associations's foreign key to null + // set old association's foreign key to null switch rel.Type { case schema.BelongsTo: if len(values) == 0 { @@ -195,6 +197,8 @@ func (association *Association) Replace(values ...interface{}) error { } func (association *Association) Delete(values ...interface{}) error { + values = expandValues(values) + if association.Error == nil { var ( reflectValue = association.DB.Statement.ReflectValue @@ -300,7 +304,7 @@ func (association *Association) Delete(values ...interface{}) error { } if association.Error == nil { - // clean up deleted values's foreign key + // clean up deleted values' foreign key relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields) cleanUpDeletedRelations := func(data reflect.Value) { @@ -431,10 +435,49 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } } + processMap := func(mapv reflect.Value) { + child := reflect.New(association.Relationship.FieldSchema.ModelType) + + switch association.Relationship.Type { + case schema.HasMany: + for _, ref := range association.Relationship.References { + key := reflect.ValueOf(ref.ForeignKey.DBName) + if ref.OwnPrimaryKey { + v := ref.PrimaryKey.ReflectValueOf(association.DB.Statement.Context, source) + mapv.SetMapIndex(key, v) + } else if ref.PrimaryValue != "" { + mapv.SetMapIndex(key, reflect.ValueOf(ref.PrimaryValue)) + } + } + association.Error = association.DB.Session(&Session{ + NewDB: true, + }).Model(child.Interface()).Create(mapv.Interface()).Error + case schema.Many2Many: + association.Error = association.DB.Session(&Session{ + NewDB: true, + }).Model(child.Interface()).Create(mapv.Interface()).Error + + for _, key := range mapv.MapKeys() { + k := strings.ToLower(key.String()) + if f, ok := association.Relationship.FieldSchema.FieldsByDBName[k]; ok { + _ = f.Set(association.DB.Statement.Context, child, mapv.MapIndex(key).Interface()) + } + } + appendToFieldValues(child) + } + } + switch rv.Kind() { + case reflect.Map: + processMap(rv) case reflect.Slice, reflect.Array: for i := 0; i < rv.Len(); i++ { - appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr()) + elem := reflect.Indirect(rv.Index(i)) + if elem.Kind() == reflect.Map { + processMap(elem) + continue + } + appendToFieldValues(elem.Addr()) } case reflect.Struct: if !rv.CanAddr() { @@ -591,3 +634,32 @@ func (association *Association) buildCondition() *DB { return tx } + +func expandValues(values ...any) (results []any) { + appendToResult := func(rv reflect.Value) { + // unwrap interface + if rv.IsValid() && rv.Kind() == reflect.Interface { + rv = rv.Elem() + } + if rv.IsValid() && rv.Kind() == reflect.Struct { + p := reflect.New(rv.Type()) + p.Elem().Set(rv) + results = append(results, p.Interface()) + } else if rv.IsValid() { + results = append(results, rv.Interface()) + } + } + + // Process each argument; if an argument is a slice/array, expand its elements + for _, value := range values { + rv := reflect.ValueOf(value) + if rv.Kind() == reflect.Slice || rv.Kind() == reflect.Array { + for i := 0; i < rv.Len(); i++ { + appendToResult(rv.Index(i)) + } + } else { + appendToResult(rv) + } + } + return +} diff --git a/cluster-manager/vendor/gorm.io/gorm/callbacks.go b/cluster-manager/vendor/gorm.io/gorm/callbacks.go index 50b5b0e9..bd97f040 100644 --- a/cluster-manager/vendor/gorm.io/gorm/callbacks.go +++ b/cluster-manager/vendor/gorm.io/gorm/callbacks.go @@ -89,10 +89,16 @@ func (p *processor) Execute(db *DB) *DB { resetBuildClauses = true } - if optimizer, ok := db.Statement.Dest.(StatementModifier); ok { + if optimizer, ok := stmt.Dest.(StatementModifier); ok { optimizer.ModifyStatement(stmt) } + if db.DefaultContextTimeout > 0 { + if _, ok := stmt.Context.Deadline(); !ok { + stmt.Context, _ = context.WithTimeout(stmt.Context, db.DefaultContextTimeout) + } + } + // assign model values if stmt.Model == nil { stmt.Model = stmt.Dest diff --git a/cluster-manager/vendor/gorm.io/gorm/callbacks/create.go b/cluster-manager/vendor/gorm.io/gorm/callbacks/create.go index d8701f51..e5929adb 100644 --- a/cluster-manager/vendor/gorm.io/gorm/callbacks/create.go +++ b/cluster-manager/vendor/gorm.io/gorm/callbacks/create.go @@ -53,9 +53,13 @@ func Create(config *Config) func(db *gorm.DB) { if _, ok := db.Statement.Clauses["RETURNING"]; !ok { fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue)) for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue { - fromColumns = append(fromColumns, clause.Column{Name: field.DBName}) + if field.Readable { + fromColumns = append(fromColumns, clause.Column{Name: field.DBName}) + } + } + if len(fromColumns) > 0 { + db.Statement.AddClause(clause.Returning{Columns: fromColumns}) } - db.Statement.AddClause(clause.Returning{Columns: fromColumns}) } } } @@ -76,8 +80,11 @@ func Create(config *Config) func(db *gorm.DB) { ok, mode := hasReturning(db, supportReturning) if ok { if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok { - if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing { + onConflict, _ := c.Expression.(clause.OnConflict) + if onConflict.DoNothing { mode |= gorm.ScanOnConflictDoNothing + } else if len(onConflict.DoUpdates) > 0 || onConflict.UpdateAll { + mode |= gorm.ScanUpdate } } @@ -122,6 +129,16 @@ func Create(config *Config) func(db *gorm.DB) { pkFieldName = "@id" ) + if db.Statement.Schema != nil { + if db.Statement.Schema.PrioritizedPrimaryField == nil || + !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue || + !db.Statement.Schema.PrioritizedPrimaryField.Readable { + return + } + pkField = db.Statement.Schema.PrioritizedPrimaryField + pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName + } + insertID, err := result.LastInsertId() insertOk := err == nil && insertID > 0 @@ -132,14 +149,6 @@ func Create(config *Config) func(db *gorm.DB) { return } - if db.Statement.Schema != nil { - if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { - return - } - pkField = db.Statement.Schema.PrioritizedPrimaryField - pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName - } - // append @id column with value for auto-increment primary key // the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1 switch values := db.Statement.Dest.(type) { diff --git a/cluster-manager/vendor/gorm.io/gorm/chainable_api.go b/cluster-manager/vendor/gorm.io/gorm/chainable_api.go index 8a6aea34..8f6113cc 100644 --- a/cluster-manager/vendor/gorm.io/gorm/chainable_api.go +++ b/cluster-manager/vendor/gorm.io/gorm/chainable_api.go @@ -178,7 +178,7 @@ func (db *DB) Omit(columns ...string) (tx *DB) { tx = db.getInstance() if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { - tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar) + tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsInvalidDBNameChar) } else { tx.Statement.Omits = columns } @@ -283,7 +283,7 @@ func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) func (db *DB) Group(name string) (tx *DB) { tx = db.getInstance() - fields := strings.FieldsFunc(name, utils.IsValidDBNameChar) + fields := strings.FieldsFunc(name, utils.IsInvalidDBNameChar) tx.Statement.AddClause(clause.GroupBy{ Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}}, }) diff --git a/cluster-manager/vendor/gorm.io/gorm/clause/association.go b/cluster-manager/vendor/gorm.io/gorm/clause/association.go new file mode 100644 index 00000000..a9bf7eb0 --- /dev/null +++ b/cluster-manager/vendor/gorm.io/gorm/clause/association.go @@ -0,0 +1,35 @@ +package clause + +// AssociationOpType represents association operation types +type AssociationOpType int + +const ( + OpUnlink AssociationOpType = iota // Unlink association + OpDelete // Delete association records + OpUpdate // Update association records + OpCreate // Create association records with assignments +) + +// Association represents an association operation +type Association struct { + Association string // Association name + Type AssociationOpType // Operation type + Conditions []Expression // Filter conditions + Set []Assignment // Assignment operations (for Update and Create) + Values []interface{} // Values for Create operation +} + +// AssociationAssigner is an interface for association operation providers +type AssociationAssigner interface { + AssociationAssignments() []Association +} + +// Assignments implements the Assigner interface so that AssociationOperation can be used as a Set method parameter +func (ao Association) Assignments() []Assignment { + return []Assignment{} +} + +// AssociationAssignments implements the AssociationAssigner interface +func (ao Association) AssociationAssignments() []Association { + return []Association{ao} +} diff --git a/cluster-manager/vendor/gorm.io/gorm/clause/set.go b/cluster-manager/vendor/gorm.io/gorm/clause/set.go index 75eb6bdd..cb5f36a0 100644 --- a/cluster-manager/vendor/gorm.io/gorm/clause/set.go +++ b/cluster-manager/vendor/gorm.io/gorm/clause/set.go @@ -9,6 +9,11 @@ type Assignment struct { Value interface{} } +// Assigner assignments provider interface +type Assigner interface { + Assignments() []Assignment +} + func (set Set) Name() string { return "SET" } @@ -37,6 +42,9 @@ func (set Set) MergeClause(clause *Clause) { clause.Expression = Set(copiedAssignments) } +// Assignments implements Assigner for Set. +func (set Set) Assignments() []Assignment { return []Assignment(set) } + func Assignments(values map[string]interface{}) Set { keys := make([]string, 0, len(values)) for key := range values { @@ -58,3 +66,6 @@ func AssignmentColumns(values []string) Set { } return assignments } + +// Assignments implements Assigner for a single Assignment. +func (a Assignment) Assignments() []Assignment { return []Assignment{a} } diff --git a/cluster-manager/vendor/gorm.io/gorm/finisher_api.go b/cluster-manager/vendor/gorm.io/gorm/finisher_api.go index 57809d17..e9e35f1b 100644 --- a/cluster-manager/vendor/gorm.io/gorm/finisher_api.go +++ b/cluster-manager/vendor/gorm.io/gorm/finisher_api.go @@ -465,7 +465,7 @@ func (db *DB) Count(count *int64) (tx *DB) { if len(tx.Statement.Selects) == 1 { dbName := tx.Statement.Selects[0] - fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar) + fields := strings.FieldsFunc(dbName, utils.IsInvalidDBNameChar) if len(fields) == 1 || (len(fields) == 3 && (strings.ToUpper(fields[1]) == "AS" || fields[1] == ".")) { if tx.Statement.Parse(tx.Statement.Model) == nil { if f := tx.Statement.Schema.LookUpField(dbName); f != nil { @@ -564,7 +564,7 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { } if len(tx.Statement.Selects) != 1 { - fields := strings.FieldsFunc(column, utils.IsValidDBNameChar) + fields := strings.FieldsFunc(column, utils.IsInvalidDBNameChar) tx.Statement.AddClauseIfNotExists(clause.Select{ Distinct: tx.Statement.Distinct, Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}}, @@ -675,9 +675,9 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { } ctx := tx.Statement.Context - if _, ok := ctx.Deadline(); !ok { - if db.Config.DefaultTransactionTimeout > 0 { - ctx, _ = context.WithTimeout(ctx, db.Config.DefaultTransactionTimeout) + if db.DefaultTransactionTimeout > 0 { + if _, ok := ctx.Deadline(); !ok { + ctx, _ = context.WithTimeout(ctx, db.DefaultTransactionTimeout) } } diff --git a/cluster-manager/vendor/gorm.io/gorm/generics.go b/cluster-manager/vendor/gorm.io/gorm/generics.go index ad2d063f..166d1520 100644 --- a/cluster-manager/vendor/gorm.io/gorm/generics.go +++ b/cluster-manager/vendor/gorm.io/gorm/generics.go @@ -3,12 +3,15 @@ package gorm import ( "context" "database/sql" + "errors" "fmt" + "reflect" "sort" "strings" "gorm.io/gorm/clause" "gorm.io/gorm/logger" + "gorm.io/gorm/schema" ) type result struct { @@ -35,10 +38,34 @@ type Interface[T any] interface { } type CreateInterface[T any] interface { - ChainInterface[T] + ExecInterface[T] + // chain methods available at start; Select/Omit keep CreateInterface to allow Create chaining + Scopes(scopes ...func(db *Statement)) ChainInterface[T] + Where(query interface{}, args ...interface{}) ChainInterface[T] + Not(query interface{}, args ...interface{}) ChainInterface[T] + Or(query interface{}, args ...interface{}) ChainInterface[T] + Limit(offset int) ChainInterface[T] + Offset(offset int) ChainInterface[T] + Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] + Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] + Select(query string, args ...interface{}) CreateInterface[T] + Omit(columns ...string) CreateInterface[T] + MapColumns(m map[string]string) ChainInterface[T] + Distinct(args ...interface{}) ChainInterface[T] + Group(name string) ChainInterface[T] + Having(query interface{}, args ...interface{}) ChainInterface[T] + Order(value interface{}) ChainInterface[T] + Build(builder clause.Builder) + + Delete(ctx context.Context) (rowsAffected int, err error) + Update(ctx context.Context, name string, value any) (rowsAffected int, err error) + Updates(ctx context.Context, t T) (rowsAffected int, err error) + Count(ctx context.Context, column string) (result int64, err error) + Table(name string, args ...interface{}) CreateInterface[T] Create(ctx context.Context, r *T) error CreateInBatches(ctx context.Context, r *[]T, batchSize int) error + Set(assignments ...clause.Assigner) SetCreateOrUpdateInterface[T] } type ChainInterface[T any] interface { @@ -58,15 +85,28 @@ type ChainInterface[T any] interface { Group(name string) ChainInterface[T] Having(query interface{}, args ...interface{}) ChainInterface[T] Order(value interface{}) ChainInterface[T] + Set(assignments ...clause.Assigner) SetUpdateOnlyInterface[T] Build(builder clause.Builder) + Table(name string, args ...interface{}) ChainInterface[T] Delete(ctx context.Context) (rowsAffected int, err error) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) Updates(ctx context.Context, t T) (rowsAffected int, err error) Count(ctx context.Context, column string) (result int64, err error) } +// SetUpdateOnlyInterface is returned by Set after chaining; only Update is allowed +type SetUpdateOnlyInterface[T any] interface { + Update(ctx context.Context) (rowsAffected int, err error) +} + +// SetCreateOrUpdateInterface is returned by Set at start; Create or Update are allowed +type SetCreateOrUpdateInterface[T any] interface { + Create(ctx context.Context) error + Update(ctx context.Context) (rowsAffected int, err error) +} + type ExecInterface[T any] interface { Scan(ctx context.Context, r interface{}) error First(context.Context) (T, error) @@ -142,13 +182,15 @@ func (c *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] { return execG[T]{g: &g[T]{ db: c.db, ops: append(c.ops, func(db *DB) *DB { - return db.Raw(sql, values...) + var r T + return db.Model(r).Raw(sql, values...) }), }} } func (c *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error { - return c.apply(ctx).Exec(sql, values...).Error + var r T + return c.apply(ctx).Model(r).Exec(sql, values...).Error } type createG[T any] struct { @@ -161,6 +203,22 @@ func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] { })} } +func (c createG[T]) Select(query string, args ...interface{}) CreateInterface[T] { + return createG[T]{c.with(func(db *DB) *DB { + return db.Select(query, args...) + })} +} + +func (c createG[T]) Omit(columns ...string) CreateInterface[T] { + return createG[T]{c.with(func(db *DB) *DB { + return db.Omit(columns...) + })} +} + +func (c createG[T]) Set(assignments ...clause.Assigner) SetCreateOrUpdateInterface[T] { + return c.processSet(assignments...) +} + func (c createG[T]) Create(ctx context.Context, r *T) error { return c.g.apply(ctx).Create(r).Error } @@ -187,6 +245,12 @@ func (c chainG[T]) with(v op) chainG[T] { } } +func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Table(name, args...) + }) +} + func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] { return c.with(func(db *DB) *DB { for _, fc := range scopes { @@ -196,12 +260,6 @@ func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] { }) } -func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] { - return c.with(func(db *DB) *DB { - return db.Table(name, args...) - }) -} - func (c chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] { return c.with(func(db *DB) *DB { return db.Where(query, args...) @@ -388,6 +446,10 @@ func (c chainG[T]) MapColumns(m map[string]string) ChainInterface[T] { }) } +func (c chainG[T]) Set(assignments ...clause.Assigner) SetUpdateOnlyInterface[T] { + return c.processSet(assignments...) +} + func (c chainG[T]) Distinct(args ...interface{}) ChainInterface[T] { return c.with(func(db *DB) *DB { return db.Distinct(args...) @@ -425,12 +487,12 @@ func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) err relation, ok := db.Statement.Schema.Relationships.Relations[association] if !ok { if preloadFields := strings.Split(association, "."); len(preloadFields) > 1 { - relationships := db.Statement.Schema.Relationships + relationships := &db.Statement.Schema.Relationships for _, field := range preloadFields { var ok bool relation, ok = relationships.Relations[field] if ok { - relationships = relation.FieldSchema.Relationships + relationships = &relation.FieldSchema.Relationships } else { db.AddError(fmt.Errorf("relation %s not found", association)) return nil @@ -567,7 +629,7 @@ func (g execG[T]) First(ctx context.Context) (T, error) { func (g execG[T]) Scan(ctx context.Context, result interface{}) error { var r T - err := g.g.apply(ctx).Model(r).Find(&result).Error + err := g.g.apply(ctx).Model(r).Find(result).Error return err } @@ -603,3 +665,242 @@ func (g execG[T]) Row(ctx context.Context) *sql.Row { func (g execG[T]) Rows(ctx context.Context) (*sql.Rows, error) { return g.g.apply(ctx).Rows() } + +func (c chainG[T]) processSet(items ...clause.Assigner) setCreateOrUpdateG[T] { + var ( + assigns []clause.Assignment + assocOps []clause.Association + ) + + for _, item := range items { + // Check if it's an AssociationAssigner + if assocAssigner, ok := item.(clause.AssociationAssigner); ok { + assocOps = append(assocOps, assocAssigner.AssociationAssignments()...) + } else { + assigns = append(assigns, item.Assignments()...) + } + } + + return setCreateOrUpdateG[T]{ + c: c, + assigns: assigns, + assocOps: assocOps, + } +} + +// setCreateOrUpdateG[T] is a struct that holds operations to be executed in a batch. +// It supports regular assignments and association operations. +type setCreateOrUpdateG[T any] struct { + c chainG[T] + assigns []clause.Assignment + assocOps []clause.Association +} + +func (s setCreateOrUpdateG[T]) Update(ctx context.Context) (rowsAffected int, err error) { + // Execute association operations + for _, assocOp := range s.assocOps { + if err := s.executeAssociationOperation(ctx, assocOp); err != nil { + return 0, err + } + } + + // Execute assignment operations + if len(s.assigns) > 0 { + var r T + res := s.c.g.apply(ctx).Model(r).Clauses(clause.Set(s.assigns)).Updates(map[string]interface{}{}) + return int(res.RowsAffected), res.Error + } + + return 0, nil +} + +func (s setCreateOrUpdateG[T]) Create(ctx context.Context) error { + // Execute association operations + for _, assocOp := range s.assocOps { + if err := s.executeAssociationOperation(ctx, assocOp); err != nil { + return err + } + } + + // Execute assignment operations + if len(s.assigns) > 0 { + data := make(map[string]interface{}, len(s.assigns)) + for _, a := range s.assigns { + data[a.Column.Name] = a.Value + } + var r T + return s.c.g.apply(ctx).Model(r).Create(data).Error + } + + return nil +} + +// executeAssociationOperation executes an association operation +func (s setCreateOrUpdateG[T]) executeAssociationOperation(ctx context.Context, op clause.Association) error { + var r T + base := s.c.g.apply(ctx).Model(r) + + switch op.Type { + case clause.OpCreate: + return s.handleAssociationCreate(ctx, base, op) + case clause.OpUnlink, clause.OpDelete, clause.OpUpdate: + return s.handleAssociation(ctx, base, op) + default: + return fmt.Errorf("unknown association operation type: %v", op.Type) + } +} + +func (s setCreateOrUpdateG[T]) handleAssociationCreate(ctx context.Context, base *DB, op clause.Association) error { + if len(op.Set) > 0 { + return s.handleAssociationForOwners(base, ctx, func(owner T, assoc *Association) error { + data := make(map[string]interface{}, len(op.Set)) + for _, a := range op.Set { + data[a.Column.Name] = a.Value + } + return assoc.Append(data) + }, op.Association) + } + + return s.handleAssociationForOwners(base, ctx, func(owner T, assoc *Association) error { + return assoc.Append(op.Values...) + }, op.Association) +} + +// handleAssociationForOwners is a helper function that handles associations for all owners +func (s setCreateOrUpdateG[T]) handleAssociationForOwners(base *DB, ctx context.Context, handler func(owner T, association *Association) error, associationName string) error { + var owners []T + if err := base.Find(&owners).Error; err != nil { + return err + } + + for _, owner := range owners { + assoc := base.Session(&Session{NewDB: true, Context: ctx}).Model(&owner).Association(associationName) + if assoc.Error != nil { + return assoc.Error + } + + if err := handler(owner, assoc); err != nil { + return err + } + } + return nil +} + +func (s setCreateOrUpdateG[T]) handleAssociation(ctx context.Context, base *DB, op clause.Association) error { + assoc := base.Association(op.Association) + if assoc.Error != nil { + return assoc.Error + } + + var ( + rel = assoc.Relationship + assocModel = reflect.New(rel.FieldSchema.ModelType).Interface() + fkNil = map[string]any{} + setMap = make(map[string]any, len(op.Set)) + ownerPKNames []string + ownerFKNames []string + primaryColumns []any + foreignColumns []any + ) + + for _, a := range op.Set { + setMap[a.Column.Name] = a.Value + } + + for _, ref := range rel.References { + fkNil[ref.ForeignKey.DBName] = nil + + if ref.OwnPrimaryKey && ref.PrimaryKey != nil { + ownerPKNames = append(ownerPKNames, ref.PrimaryKey.DBName) + primaryColumns = append(primaryColumns, clause.Column{Name: ref.PrimaryKey.DBName}) + foreignColumns = append(foreignColumns, clause.Column{Name: ref.ForeignKey.DBName}) + } else if !ref.OwnPrimaryKey && ref.PrimaryKey != nil { + ownerFKNames = append(ownerFKNames, ref.ForeignKey.DBName) + primaryColumns = append(primaryColumns, clause.Column{Name: ref.PrimaryKey.DBName}) + } + } + + assocDB := s.c.g.db.Session(&Session{NewDB: true, Context: ctx}).Model(assocModel).Where(op.Conditions) + + switch rel.Type { + case schema.HasOne, schema.HasMany: + assocDB = assocDB.Where("? IN (?)", foreignColumns, base.Select(ownerPKNames)) + switch op.Type { + case clause.OpUnlink: + return assocDB.Updates(fkNil).Error + case clause.OpDelete: + return assocDB.Delete(assocModel).Error + case clause.OpUpdate: + return assocDB.Updates(setMap).Error + } + case schema.BelongsTo: + switch op.Type { + case clause.OpDelete: + return base.Transaction(func(tx *DB) error { + assocDB.Statement.ConnPool = tx.Statement.ConnPool + base.Statement.ConnPool = tx.Statement.ConnPool + + if err := assocDB.Where("? IN (?)", primaryColumns, base.Select(ownerFKNames)).Delete(assocModel).Error; err != nil { + return err + } + return base.Updates(fkNil).Error + }) + case clause.OpUnlink: + return base.Updates(fkNil).Error + case clause.OpUpdate: + return assocDB.Where("? IN (?)", primaryColumns, base.Select(ownerFKNames)).Updates(setMap).Error + } + case schema.Many2Many: + joinModel := reflect.New(rel.JoinTable.ModelType).Interface() + joinDB := base.Session(&Session{NewDB: true, Context: ctx}).Model(joinModel) + + // EXISTS owners: owners.pk = join.owner_fk for all owner refs + ownersExists := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.Schema.Table).Select("1") + for _, ref := range rel.References { + if ref.OwnPrimaryKey && ref.PrimaryKey != nil { + ownersExists = ownersExists.Where(clause.Eq{ + Column: clause.Column{Table: rel.Schema.Table, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + }) + } + } + + // EXISTS related: related.pk = join.rel_fk for all related refs, plus optional conditions + relatedExists := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.FieldSchema.Table).Select("1") + for _, ref := range rel.References { + if !ref.OwnPrimaryKey && ref.PrimaryKey != nil { + relatedExists = relatedExists.Where(clause.Eq{ + Column: clause.Column{Table: rel.FieldSchema.Table, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + }) + } + } + relatedExists = relatedExists.Where(op.Conditions) + + switch op.Type { + case clause.OpUnlink, clause.OpDelete: + joinDB = joinDB.Where("EXISTS (?)", ownersExists) + if len(op.Conditions) > 0 { + joinDB = joinDB.Where("EXISTS (?)", relatedExists) + } + return joinDB.Delete(nil).Error + case clause.OpUpdate: + // Update related table rows that have join rows matching owners + relatedDB := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.FieldSchema.Table).Where(op.Conditions) + + // correlated join subquery: join.rel_fk = related.pk AND EXISTS owners + joinSub := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.JoinTable.Table).Select("1") + for _, ref := range rel.References { + if !ref.OwnPrimaryKey && ref.PrimaryKey != nil { + joinSub = joinSub.Where(clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: rel.FieldSchema.Table, Name: ref.PrimaryKey.DBName}, + }) + } + } + joinSub = joinSub.Where("EXISTS (?)", ownersExists) + return relatedDB.Where("EXISTS (?)", joinSub).Updates(setMap).Error + } + } + return errors.New("unsupported relationship") +} diff --git a/cluster-manager/vendor/gorm.io/gorm/gorm.go b/cluster-manager/vendor/gorm.io/gorm/gorm.go index 67889262..a209bb09 100644 --- a/cluster-manager/vendor/gorm.io/gorm/gorm.go +++ b/cluster-manager/vendor/gorm.io/gorm/gorm.go @@ -23,6 +23,7 @@ type Config struct { // You can disable it by setting `SkipDefaultTransaction` to true SkipDefaultTransaction bool DefaultTransactionTimeout time.Duration + DefaultContextTimeout time.Duration // NamingStrategy tables, columns naming strategy NamingStrategy schema.Namer @@ -137,6 +138,14 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { return isConfig && !isConfig2 }) + if len(opts) > 0 { + if c, ok := opts[0].(*Config); ok { + config = c + } else { + opts = append([]Option{config}, opts...) + } + } + var skipAfterInitialize bool for _, opt := range opts { if opt != nil { diff --git a/cluster-manager/vendor/gorm.io/gorm/logger/slog.go b/cluster-manager/vendor/gorm.io/gorm/logger/slog.go new file mode 100644 index 00000000..27c74683 --- /dev/null +++ b/cluster-manager/vendor/gorm.io/gorm/logger/slog.go @@ -0,0 +1,116 @@ +//go:build go1.21 + +package logger + +import ( + "context" + "errors" + "fmt" + "log/slog" + "time" + + "gorm.io/gorm/utils" +) + +type slogLogger struct { + Logger *slog.Logger + LogLevel LogLevel + SlowThreshold time.Duration + Parameterized bool + Colorful bool // Ignored in slog + IgnoreRecordNotFoundError bool +} + +func NewSlogLogger(logger *slog.Logger, config Config) Interface { + return &slogLogger{ + Logger: logger, + LogLevel: config.LogLevel, + SlowThreshold: config.SlowThreshold, + Parameterized: config.ParameterizedQueries, + IgnoreRecordNotFoundError: config.IgnoreRecordNotFoundError, + } +} + +func (l *slogLogger) LogMode(level LogLevel) Interface { + newLogger := *l + newLogger.LogLevel = level + return &newLogger +} + +func (l *slogLogger) Info(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Info { + l.log(ctx, slog.LevelInfo, msg, slog.Any("data", data)) + } +} + +func (l *slogLogger) Warn(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Warn { + l.log(ctx, slog.LevelWarn, msg, slog.Any("data", data)) + } +} + +func (l *slogLogger) Error(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Error { + l.log(ctx, slog.LevelError, msg, slog.Any("data", data)) + } +} + +func (l *slogLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + if l.LogLevel <= Silent { + return + } + + elapsed := time.Since(begin) + sql, rows := fc() + fields := []slog.Attr{ + slog.String("duration", fmt.Sprintf("%.3fms", float64(elapsed.Nanoseconds())/1e6)), + slog.String("sql", sql), + } + + if rows != -1 { + fields = append(fields, slog.Int64("rows", rows)) + } + + switch { + case err != nil && (!l.IgnoreRecordNotFoundError || !errors.Is(err, ErrRecordNotFound)): + fields = append(fields, slog.String("error", err.Error())) + l.log(ctx, slog.LevelError, "SQL executed", slog.Attr{ + Key: "trace", + Value: slog.GroupValue(fields...), + }) + + case l.SlowThreshold != 0 && elapsed > l.SlowThreshold: + l.log(ctx, slog.LevelWarn, "SQL executed", slog.Attr{ + Key: "trace", + Value: slog.GroupValue(fields...), + }) + + case l.LogLevel >= Info: + l.log(ctx, slog.LevelInfo, "SQL executed", slog.Attr{ + Key: "trace", + Value: slog.GroupValue(fields...), + }) + } +} + +func (l *slogLogger) log(ctx context.Context, level slog.Level, msg string, args ...any) { + if ctx == nil { + ctx = context.Background() + } + + if !l.Logger.Enabled(ctx, level) { + return + } + + r := slog.NewRecord(time.Now(), level, msg, utils.CallerFrame().PC) + r.Add(args...) + _ = l.Logger.Handler().Handle(ctx, r) +} + +// ParamsFilter filter params +func (l *slogLogger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) { + if l.Parameterized { + return sql, nil + } + return sql, params +} diff --git a/cluster-manager/vendor/gorm.io/gorm/migrator/migrator.go b/cluster-manager/vendor/gorm.io/gorm/migrator/migrator.go index cec4e30f..35107d57 100644 --- a/cluster-manager/vendor/gorm.io/gorm/migrator/migrator.go +++ b/cluster-manager/vendor/gorm.io/gorm/migrator/migrator.go @@ -474,7 +474,6 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy // found, smart migrate fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)) realDataType := strings.ToLower(columnType.DatabaseTypeName()) - var ( alterColumn bool isSameType = fullDataType == realDataType @@ -513,8 +512,19 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } } } + } - // check precision + // check precision + if realDataType == "decimal" || realDataType == "numeric" && + regexp.MustCompile(realDataType+`\(.*\)`).FindString(fullDataType) != "" { // if realDataType has no precision,ignore + precision, scale, ok := columnType.DecimalSize() + if ok { + if !strings.HasPrefix(fullDataType, fmt.Sprintf("%s(%d,%d)", realDataType, precision, scale)) && + !strings.HasPrefix(fullDataType, fmt.Sprintf("%s(%d)", realDataType, precision)) { + alterColumn = true + } + } + } else { if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) { alterColumn = true @@ -550,6 +560,10 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy v1, _ := strconv.ParseBool(dv) v2, _ := strconv.ParseBool(field.DefaultValue) alterColumn = v1 != v2 + case schema.String: + if dv != field.DefaultValue && dv != strings.Trim(field.DefaultValue, "'\"") { + alterColumn = true + } default: alterColumn = dv != field.DefaultValue } diff --git a/cluster-manager/vendor/gorm.io/gorm/schema/field.go b/cluster-manager/vendor/gorm.io/gorm/schema/field.go index a6ff1a72..de797402 100644 --- a/cluster-manager/vendor/gorm.io/gorm/schema/field.go +++ b/cluster-manager/vendor/gorm.io/gorm/schema/field.go @@ -448,16 +448,17 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } // create valuer, setter when parse struct -func (field *Field) setupValuerAndSetter() { +func (field *Field) setupValuerAndSetter(modelType reflect.Type) { // Setup NewValuePool field.setupNewValuePool() // ValueOf returns field's value and if it is zero fieldIndex := field.StructField.Index[0] switch { - case len(field.StructField.Index) == 1 && fieldIndex > 0: - field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) { - fieldValue := reflect.Indirect(value).Field(fieldIndex) + case len(field.StructField.Index) == 1 && fieldIndex >= 0: + field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { + v = reflect.Indirect(v) + fieldValue := v.Field(fieldIndex) return fieldValue.Interface(), fieldValue.IsZero() } default: @@ -504,9 +505,10 @@ func (field *Field) setupValuerAndSetter() { // ReflectValueOf returns field's reflect value switch { - case len(field.StructField.Index) == 1 && fieldIndex > 0: - field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value { - return reflect.Indirect(value).Field(fieldIndex) + case len(field.StructField.Index) == 1 && fieldIndex >= 0: + field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { + v = reflect.Indirect(v) + return v.Field(fieldIndex) } default: field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { diff --git a/cluster-manager/vendor/gorm.io/gorm/schema/relationship.go b/cluster-manager/vendor/gorm.io/gorm/schema/relationship.go index f1ace924..0535bba4 100644 --- a/cluster-manager/vendor/gorm.io/gorm/schema/relationship.go +++ b/cluster-manager/vendor/gorm.io/gorm/schema/relationship.go @@ -75,9 +75,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { } ) - cacheStore := schema.cacheStore - - if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil { + if relation.FieldSchema, err = getOrParse(fieldValue, schema.cacheStore, schema.namer); err != nil { schema.err = fmt.Errorf("failed to parse field: %s, error: %w", field.Name, err) return nil } @@ -147,6 +145,9 @@ func hasPolymorphicRelation(tagSettings map[string]string) bool { } func (schema *Schema) setRelation(relation *Relationship) { + schema.Relationships.Mux.Lock() + defer schema.Relationships.Mux.Unlock() + // set non-embedded relation if rel := schema.Relationships.Relations[relation.Name]; rel != nil { if len(rel.Field.BindNames) > 1 { @@ -590,6 +591,10 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu // build references for idx, foreignField := range foreignFields { // use same data type for foreign keys + schema.Relationships.Mux.Lock() + if schema != foreignField.Schema { + foreignField.Schema.Relationships.Mux.Lock() + } if copyableDataType(primaryFields[idx].DataType) { foreignField.DataType = primaryFields[idx].DataType } @@ -597,6 +602,10 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu if foreignField.Size == 0 { foreignField.Size = primaryFields[idx].Size } + schema.Relationships.Mux.Unlock() + if schema != foreignField.Schema { + foreignField.Schema.Relationships.Mux.Unlock() + } relation.References = append(relation.References, &Reference{ PrimaryKey: primaryFields[idx], diff --git a/cluster-manager/vendor/gorm.io/gorm/schema/schema.go b/cluster-manager/vendor/gorm.io/gorm/schema/schema.go index db236797..09697a7a 100644 --- a/cluster-manager/vendor/gorm.io/gorm/schema/schema.go +++ b/cluster-manager/vendor/gorm.io/gorm/schema/schema.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "go/ast" + "path" "reflect" "strings" "sync" @@ -59,14 +60,14 @@ type Schema struct { cacheStore *sync.Map } -func (schema Schema) String() string { +func (schema *Schema) String() string { if schema.ModelType.Name() == "" { return fmt.Sprintf("%s(%s)", schema.Name, schema.Table) } return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), schema.ModelType.Name()) } -func (schema Schema) MakeSlice() reflect.Value { +func (schema *Schema) MakeSlice() reflect.Value { slice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(schema.ModelType)), 0, 20) results := reflect.New(slice.Type()) results.Elem().Set(slice) @@ -74,13 +75,23 @@ func (schema Schema) MakeSlice() reflect.Value { return results } -func (schema Schema) LookUpField(name string) *Field { +func (schema *Schema) LookUpField(name string) *Field { if field, ok := schema.FieldsByDBName[name]; ok { return field } if field, ok := schema.FieldsByName[name]; ok { return field } + + // Lookup field using namer-driven ColumnName + if schema.namer == nil { + return nil + } + namerColumnName := schema.namer.ColumnName(schema.Table, name) + if field, ok := schema.FieldsByDBName[namerColumnName]; ok { + return field + } + return nil } @@ -92,10 +103,7 @@ func (schema Schema) LookUpField(name string) *Field { // } // ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID") // } -func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field { - if len(bindNames) == 0 { - return nil - } +func (schema *Schema) LookUpFieldByBindName(bindNames []string, name string) *Field { for i := len(bindNames) - 1; i >= 0; i-- { find := strings.Join(bindNames[:i], ".") + "." + name if field, ok := schema.FieldsByBindName[find]; ok { @@ -113,6 +121,14 @@ type TablerWithNamer interface { TableName(Namer) string } +var callbackTypes = []callbackType{ + callbackTypeBeforeCreate, callbackTypeAfterCreate, + callbackTypeBeforeUpdate, callbackTypeAfterUpdate, + callbackTypeBeforeSave, callbackTypeAfterSave, + callbackTypeBeforeDelete, callbackTypeAfterDelete, + callbackTypeAfterFind, +} + // Parse get data type from dialector func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { return ParseWithSpecialTableName(dest, cacheStore, namer, "") @@ -124,34 +140,33 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } - value := reflect.ValueOf(dest) - if value.Kind() == reflect.Ptr && value.IsNil() { - value = reflect.New(value.Type().Elem()) - } - modelType := reflect.Indirect(value).Type() - - if modelType.Kind() == reflect.Interface { - modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type() - } - - for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType := reflect.ValueOf(dest).Type() + if modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() } if modelType.Kind() != reflect.Struct { - if modelType.PkgPath() == "" { - return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + if modelType.Kind() == reflect.Interface { + modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type() + } + + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType.Kind() != reflect.Struct { + if modelType.PkgPath() == "" { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } - return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } // Cache the Schema for performance, // Use the modelType or modelType + schemaTable (if it present) as cache key. - var schemaCacheKey interface{} + var schemaCacheKey interface{} = modelType if specialTableName != "" { schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName) - } else { - schemaCacheKey = modelType } // Load exist schema cache, return if exists @@ -162,28 +177,29 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam return s, s.err } + var tableName string modelValue := reflect.New(modelType) - tableName := namer.TableName(modelType.Name()) - if tabler, ok := modelValue.Interface().(Tabler); ok { + if specialTableName != "" { + tableName = specialTableName + } else if en, ok := namer.(embeddedNamer); ok { + tableName = en.Table + } else if tabler, ok := modelValue.Interface().(Tabler); ok { tableName = tabler.TableName() - } - if tabler, ok := modelValue.Interface().(TablerWithNamer); ok { + } else if tabler, ok := modelValue.Interface().(TablerWithNamer); ok { tableName = tabler.TableName(namer) - } - if en, ok := namer.(embeddedNamer); ok { - tableName = en.Table - } - if specialTableName != "" && specialTableName != tableName { - tableName = specialTableName + } else { + tableName = namer.TableName(modelType.Name()) } schema := &Schema{ Name: modelType.Name(), ModelType: modelType, Table: tableName, - FieldsByName: map[string]*Field{}, - FieldsByBindName: map[string]*Field{}, - FieldsByDBName: map[string]*Field{}, + DBNames: make([]string, 0, 10), + Fields: make([]*Field, 0, 10), + FieldsByName: make(map[string]*Field, 10), + FieldsByBindName: make(map[string]*Field, 10), + FieldsByDBName: make(map[string]*Field, 10), Relationships: Relationships{Relations: map[string]*Relationship{}}, cacheStore: cacheStore, namer: namer, @@ -227,8 +243,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam schema.FieldsByBindName[bindName] = field if v != nil && v.PrimaryKey { + // remove the existing primary key field for idx, f := range schema.PrimaryFields { - if f == v { + if f.DBName == v.DBName { schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...) } } @@ -247,7 +264,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam schema.FieldsByBindName[bindName] = field } - field.setupValuerAndSetter() + field.setupValuerAndSetter(modelType) } prioritizedPrimaryField := schema.LookUpField("id") @@ -283,10 +300,37 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName) } + _, embedded := schema.cacheStore.Load(embeddedCacheKey) + relationshipFields := []*Field{} for _, field := range schema.Fields { if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil { schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) } + + if !embedded { + if field.DataType == "" && field.GORMDataType == "" && (field.Creatable || field.Updatable || field.Readable) { + relationshipFields = append(relationshipFields, field) + schema.FieldsByName[field.Name] = field + schema.FieldsByBindName[field.BindName()] = field + } + + fieldValue := reflect.New(field.IndirectFieldType).Interface() + if fc, ok := fieldValue.(CreateClausesInterface); ok { + field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) + } + + if fc, ok := fieldValue.(QueryClausesInterface); ok { + field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) + } + + if fc, ok := fieldValue.(UpdateClausesInterface); ok { + field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) + } + + if fc, ok := fieldValue.(DeleteClausesInterface); ok { + field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) + } + } } if field := schema.PrioritizedPrimaryField; field != nil { @@ -303,24 +347,6 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } } - callbackTypes := []callbackType{ - callbackTypeBeforeCreate, callbackTypeAfterCreate, - callbackTypeBeforeUpdate, callbackTypeAfterUpdate, - callbackTypeBeforeSave, callbackTypeAfterSave, - callbackTypeBeforeDelete, callbackTypeAfterDelete, - callbackTypeAfterFind, - } - for _, cbName := range callbackTypes { - if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() { - switch methodValue.Type().String() { - case "func(*gorm.DB) error": // TODO hack - reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true) - default: - logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName) - } - } - } - // Cache the schema if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded { s := v.(*Schema) @@ -336,84 +362,47 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } }() - if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { - for _, field := range schema.Fields { - if field.DataType == "" && field.GORMDataType == "" && (field.Creatable || field.Updatable || field.Readable) { - if schema.parseRelation(field); schema.err != nil { - return schema, schema.err + for _, cbName := range callbackTypes { + if methodValue := modelValue.MethodByName(string(cbName)); methodValue.IsValid() { + switch methodValue.Type().String() { + case "func(*gorm.DB) error": + expectedPkgPath := path.Dir(reflect.TypeOf(schema).Elem().PkgPath()) + if inVarPkg := methodValue.Type().In(0).Elem().PkgPath(); inVarPkg == expectedPkgPath { + reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true) } else { - schema.FieldsByName[field.Name] = field - schema.FieldsByBindName[field.BindName()] = field + logger.Default.Warn(context.Background(), "In model %v, the hook function `%v(*gorm.DB) error` has an incorrect parameter type. The expected parameter type is `%v`, but the provided type is `%v`.", schema, cbName, expectedPkgPath, inVarPkg) + // PASS } + default: + logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName) } + } + } - fieldValue := reflect.New(field.IndirectFieldType) - fieldInterface := fieldValue.Interface() - if fc, ok := fieldInterface.(CreateClausesInterface); ok { - field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) - } - - if fc, ok := fieldInterface.(QueryClausesInterface); ok { - field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) - } - - if fc, ok := fieldInterface.(UpdateClausesInterface); ok { - field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) - } - - if fc, ok := fieldInterface.(DeleteClausesInterface); ok { - field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) - } + // parse relationships + for _, field := range relationshipFields { + if schema.parseRelation(field); schema.err != nil { + return schema, schema.err } } return schema, schema.err } -// This unrolling is needed to show to the compiler the exact set of methods -// that can be used on the modelType. -// Prior to go1.22 any use of MethodByName would cause the linker to -// abandon dead code elimination for the entire binary. -// As of go1.22 the compiler supports one special case of a string constant -// being passed to MethodByName. For enterprise customers or those building -// large binaries, this gives a significant reduction in binary size. -// https://github.com/golang/go/issues/62257 -func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value { - switch cbType { - case callbackTypeBeforeCreate: - return modelType.MethodByName(string(callbackTypeBeforeCreate)) - case callbackTypeAfterCreate: - return modelType.MethodByName(string(callbackTypeAfterCreate)) - case callbackTypeBeforeUpdate: - return modelType.MethodByName(string(callbackTypeBeforeUpdate)) - case callbackTypeAfterUpdate: - return modelType.MethodByName(string(callbackTypeAfterUpdate)) - case callbackTypeBeforeSave: - return modelType.MethodByName(string(callbackTypeBeforeSave)) - case callbackTypeAfterSave: - return modelType.MethodByName(string(callbackTypeAfterSave)) - case callbackTypeBeforeDelete: - return modelType.MethodByName(string(callbackTypeBeforeDelete)) - case callbackTypeAfterDelete: - return modelType.MethodByName(string(callbackTypeAfterDelete)) - case callbackTypeAfterFind: - return modelType.MethodByName(string(callbackTypeAfterFind)) - default: - return reflect.ValueOf(nil) - } -} - func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { modelType := reflect.ValueOf(dest).Type() - for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { - modelType = modelType.Elem() - } if modelType.Kind() != reflect.Struct { - if modelType.PkgPath() == "" { - return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType.Kind() != reflect.Struct { + if modelType.PkgPath() == "" { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } - return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } if v, ok := cacheStore.Load(modelType); ok { diff --git a/cluster-manager/vendor/gorm.io/gorm/schema/serializer.go b/cluster-manager/vendor/gorm.io/gorm/schema/serializer.go index 0fafbcba..d774a8ed 100644 --- a/cluster-manager/vendor/gorm.io/gorm/schema/serializer.go +++ b/cluster-manager/vendor/gorm.io/gorm/schema/serializer.go @@ -8,6 +8,7 @@ import ( "encoding/gob" "encoding/json" "fmt" + "math" "reflect" "strings" "sync" @@ -127,16 +128,31 @@ func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect. // Value implements serializer interface func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) { rv := reflect.ValueOf(fieldValue) - switch v := fieldValue.(type) { - case int64, int, uint, uint64, int32, uint32, int16, uint16: - result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC() - case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16: + switch fieldValue.(type) { + case int, int8, int16, int32, int64: + result = time.Unix(rv.Int(), 0).UTC() + case uint, uint8, uint16, uint32, uint64: + if uv := rv.Uint(); uv > math.MaxInt64 { + err = fmt.Errorf("integer overflow conversion uint64(%d) -> int64", uv) + } else { + result = time.Unix(int64(uv), 0).UTC() //nolint:gosec + } + case *int, *int8, *int16, *int32, *int64: if rv.IsZero() { return nil, nil } - result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC() + result = time.Unix(rv.Elem().Int(), 0).UTC() + case *uint, *uint8, *uint16, *uint32, *uint64: + if rv.IsZero() { + return nil, nil + } + if uv := rv.Elem().Uint(); uv > math.MaxInt64 { + err = fmt.Errorf("integer overflow conversion uint64(%d) -> int64", uv) + } else { + result = time.Unix(int64(uv), 0).UTC() //nolint:gosec + } default: - err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v) + err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", fieldValue) } return } diff --git a/cluster-manager/vendor/gorm.io/gorm/schema/utils.go b/cluster-manager/vendor/gorm.io/gorm/schema/utils.go index fa1c65d4..86305d7b 100644 --- a/cluster-manager/vendor/gorm.io/gorm/schema/utils.go +++ b/cluster-manager/vendor/gorm.io/gorm/schema/utils.go @@ -17,25 +17,23 @@ func ParseTagSetting(str string, sep string) map[string]string { settings := map[string]string{} names := strings.Split(str, sep) + var parsedNames []string for i := 0; i < len(names); i++ { - j := i - if len(names[j]) > 0 { - for { - if names[j][len(names[j])-1] == '\\' { - i++ - names[j] = names[j][0:len(names[j])-1] + sep + names[i] - names[i] = "" - } else { - break - } - } + s := names[i] + for strings.HasSuffix(s, "\\") && i+1 < len(names) { + i++ + s = s[:len(s)-1] + sep + names[i] } + parsedNames = append(parsedNames, s) + } - values := strings.Split(names[j], ":") + for _, tag := range parsedNames { + values := strings.Split(tag, ":") k := strings.TrimSpace(strings.ToUpper(values[0])) - if len(values) >= 2 { - settings[k] = strings.Join(values[1:], ":") + val := strings.Join(values[1:], ":") + val = strings.ReplaceAll(val, `\"`, `"`) + settings[k] = val } else if k != "" { settings[k] = k } @@ -121,6 +119,17 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, } switch reflectValue.Kind() { + case reflect.Map: + results = [][]interface{}{make([]interface{}, len(fields))} + for idx, field := range fields { + mapValue := reflectValue.MapIndex(reflect.ValueOf(field.DBName)) + if mapValue.IsZero() { + mapValue = reflectValue.MapIndex(reflect.ValueOf(field.Name)) + } + results[0][idx] = mapValue.Interface() + } + + dataResults[utils.ToStringKey(results[0]...)] = []reflect.Value{reflectValue} case reflect.Struct: results = [][]interface{}{make([]interface{}, len(fields))} diff --git a/cluster-manager/vendor/gorm.io/gorm/statement.go b/cluster-manager/vendor/gorm.io/gorm/statement.go index c6183724..736087d7 100644 --- a/cluster-manager/vendor/gorm.io/gorm/statement.go +++ b/cluster-manager/vendor/gorm.io/gorm/statement.go @@ -96,7 +96,9 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { if v.Name == clause.CurrentTable { if stmt.TableExpr != nil { stmt.TableExpr.Build(stmt) - } else { + } else if stmt.Table != "" { + write(v.Raw, stmt.Table) + } else if stmt.AddError(stmt.Parse(stmt.Model)) == nil { write(v.Raw, stmt.Table) } } else { @@ -334,6 +336,8 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] switch v := arg.(type) { case clause.Expression: conds = append(conds, v) + case []clause.Expression: + conds = append(conds, v...) case *DB: v.executeScopes() @@ -341,7 +345,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] if where, ok := cs.Expression.(clause.Where); ok { if len(where.Exprs) == 1 { if orConds, ok := where.Exprs[0].(clause.OrConditions); ok { - where.Exprs[0] = clause.AndConditions(orConds) + if len(orConds.Exprs) == 1 { + where.Exprs[0] = clause.AndConditions(orConds) + } } } conds = append(conds, clause.And(where.Exprs...)) @@ -362,6 +368,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] for _, key := range keys { column := clause.Column{Name: key, Table: curTable} + if strings.Contains(key, ".") { + column = clause.Column{Name: key} + } conds = append(conds, clause.Eq{Column: column, Value: v[key]}) } case map[string]interface{}: @@ -374,6 +383,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] for _, key := range keys { reflectValue := reflect.Indirect(reflect.ValueOf(v[key])) column := clause.Column{Name: key, Table: curTable} + if strings.Contains(key, ".") { + column = clause.Column{Name: key} + } switch reflectValue.Kind() { case reflect.Slice, reflect.Array: if _, ok := v[key].(driver.Valuer); ok { @@ -650,12 +662,15 @@ func (stmt *Statement) Changed(fields ...string) bool { for destValue.Kind() == reflect.Ptr { destValue = destValue.Elem() } - - changedValue, zero := field.ValueOf(stmt.Context, destValue) - if v { - return !utils.AssertEqual(changedValue, fieldValue) + if descSchema, err := schema.Parse(stmt.Dest, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { + if destField := descSchema.LookUpField(field.DBName); destField != nil { + changedValue, zero := destField.ValueOf(stmt.Context, destValue) + if v { + return !utils.AssertEqual(changedValue, fieldValue) + } + return !zero && !utils.AssertEqual(changedValue, fieldValue) + } } - return !zero && !utils.AssertEqual(changedValue, fieldValue) } } return false diff --git a/cluster-manager/vendor/gorm.io/gorm/utils/utils.go b/cluster-manager/vendor/gorm.io/gorm/utils/utils.go index fc615d73..7e59264b 100644 --- a/cluster-manager/vendor/gorm.io/gorm/utils/utils.go +++ b/cluster-manager/vendor/gorm.io/gorm/utils/utils.go @@ -30,8 +30,12 @@ func sourceDir(file string) string { return filepath.ToSlash(s) + "/" } -// FileWithLineNum return the file name and line number of the current file -func FileWithLineNum() string { +// CallerFrame retrieves the first relevant stack frame outside of GORM's internal implementation files. +// It skips: +// - GORM's core source files (identified by gormSourceDir prefix) +// - Exclude test files (*_test.go) +// - go-gorm/gen's Generated files (*.gen.go) +func CallerFrame() runtime.Frame { pcs := [13]uintptr{} // the third caller usually from gorm internal len := runtime.Callers(3, pcs[:]) @@ -41,14 +45,24 @@ func FileWithLineNum() string { frame, _ := frames.Next() if (!strings.HasPrefix(frame.File, gormSourceDir) || strings.HasSuffix(frame.File, "_test.go")) && !strings.HasSuffix(frame.File, ".gen.go") { - return string(strconv.AppendInt(append([]byte(frame.File), ':'), int64(frame.Line), 10)) + return frame } } + return runtime.Frame{} +} + +// FileWithLineNum return the file name and line number of the current file +func FileWithLineNum() string { + frame := CallerFrame() + if frame.PC != 0 { + return string(strconv.AppendInt(append([]byte(frame.File), ':'), int64(frame.Line), 10)) + } + return "" } -func IsValidDBNameChar(c rune) bool { +func IsInvalidDBNameChar(c rune) bool { return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@' } diff --git a/cluster-manager/vendor/modules.txt b/cluster-manager/vendor/modules.txt index d663f0bc..5a8739b3 100644 --- a/cluster-manager/vendor/modules.txt +++ b/cluster-manager/vendor/modules.txt @@ -484,23 +484,27 @@ github.com/inconshreveable/mousetrap # github.com/jackc/pgpassfile v1.0.0 ## explicit; go 1.12 github.com/jackc/pgpassfile -# github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a +# github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 ## explicit; go 1.14 github.com/jackc/pgservicefile -# github.com/jackc/pgx/v5 v5.4.3 -## explicit; go 1.19 +# github.com/jackc/pgx/v5 v5.6.0 +## explicit; go 1.20 github.com/jackc/pgx/v5 -github.com/jackc/pgx/v5/internal/anynil github.com/jackc/pgx/v5/internal/iobufpool github.com/jackc/pgx/v5/internal/pgio github.com/jackc/pgx/v5/internal/sanitize github.com/jackc/pgx/v5/internal/stmtcache github.com/jackc/pgx/v5/pgconn +github.com/jackc/pgx/v5/pgconn/ctxwatch github.com/jackc/pgx/v5/pgconn/internal/bgreader -github.com/jackc/pgx/v5/pgconn/internal/ctxwatch github.com/jackc/pgx/v5/pgproto3 github.com/jackc/pgx/v5/pgtype +github.com/jackc/pgx/v5/pgxpool github.com/jackc/pgx/v5/stdlib +# github.com/jackc/puddle/v2 v2.2.2 +## explicit; go 1.19 +github.com/jackc/puddle/v2 +github.com/jackc/puddle/v2/internal/genstack # github.com/jgautheron/goconst v1.7.1 ## explicit; go 1.13 github.com/jgautheron/goconst @@ -1244,10 +1248,10 @@ gopkg.in/yaml.v2 # gopkg.in/yaml.v3 v3.0.1 ## explicit gopkg.in/yaml.v3 -# gorm.io/driver/postgres v1.5.7 -## explicit; go 1.18 +# gorm.io/driver/postgres v1.6.0 +## explicit; go 1.20 gorm.io/driver/postgres -# gorm.io/gorm v1.30.0 +# gorm.io/gorm v1.31.1 ## explicit; go 1.18 gorm.io/gorm gorm.io/gorm/callbacks diff --git a/cs-manager/go.mod b/cs-manager/go.mod index 768c9874..f9adb1f9 100644 --- a/cs-manager/go.mod +++ b/cs-manager/go.mod @@ -15,8 +15,8 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect google.golang.org/grpc v1.73.0 google.golang.org/protobuf v1.36.6 - gorm.io/driver/postgres v1.5.7 - gorm.io/gorm v1.30.0 + gorm.io/driver/postgres v1.6.0 + gorm.io/gorm v1.31.1 ) require ( diff --git a/cs-manager/go.sum b/cs-manager/go.sum index 05ad1c2a..058ee5d9 100644 --- a/cs-manager/go.sum +++ b/cs-manager/go.sum @@ -661,10 +661,10 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM= -gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA= -gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs= -gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= +gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4= +gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo= +gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= +gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= honnef.co/go/tools v0.6.1 h1:R094WgE8K4JirYjBaOpz/AvTyUu/3wbmAoskKN/pxTI= honnef.co/go/tools v0.6.1/go.mod h1:3puzxxljPCe8RGJX7BIy1plGbxEOZni5mR2aXe3/uk4= mvdan.cc/gofumpt v0.7.0 h1:bg91ttqXmi9y2xawvkuMXyvAA/1ZGJqYAEGjXuP0JXU= diff --git a/cs-manager/vendor/gorm.io/driver/postgres/error_translator.go b/cs-manager/vendor/gorm.io/driver/postgres/error_translator.go index 9c0ef253..5f813501 100644 --- a/cs-manager/vendor/gorm.io/driver/postgres/error_translator.go +++ b/cs-manager/vendor/gorm.io/driver/postgres/error_translator.go @@ -8,10 +8,12 @@ import ( "github.com/jackc/pgx/v5/pgconn" ) +// The error codes to map PostgreSQL errors to gorm errors, here is the PostgreSQL error codes reference https://www.postgresql.org/docs/current/errcodes-appendix.html. var errCodes = map[string]error{ "23505": gorm.ErrDuplicatedKey, "23503": gorm.ErrForeignKeyViolated, "42703": gorm.ErrInvalidField, + "23514": gorm.ErrCheckConstraintViolated, } type ErrMessage struct { diff --git a/cs-manager/vendor/gorm.io/driver/postgres/migrator.go b/cs-manager/vendor/gorm.io/driver/postgres/migrator.go index 6174e1c1..6b57ce69 100644 --- a/cs-manager/vendor/gorm.io/driver/postgres/migrator.go +++ b/cs-manager/vendor/gorm.io/driver/postgres/migrator.go @@ -3,10 +3,10 @@ package postgres import ( "database/sql" "fmt" + "github.com/jackc/pgx/v5" "regexp" "strings" - "github.com/jackc/pgx/v5" "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/migrator" @@ -38,18 +38,34 @@ WHERE ` var typeAliasMap = map[string][]string{ - "int2": {"smallint"}, - "int4": {"integer"}, - "int8": {"bigint"}, - "smallint": {"int2"}, - "integer": {"int4"}, - "bigint": {"int8"}, - "decimal": {"numeric"}, - "numeric": {"decimal"}, - "timestamptz": {"timestamp with time zone"}, - "timestamp with time zone": {"timestamptz"}, - "bool": {"boolean"}, - "boolean": {"bool"}, + "int": {"integer"}, + "int2": {"smallint"}, + "int4": {"integer"}, + "int8": {"bigint"}, + "smallint": {"int2"}, + "integer": {"int4"}, + "bigint": {"int8"}, + "date": {"date"}, + "decimal": {"numeric"}, + "numeric": {"decimal"}, + "timestamp": {"timestamp"}, + "timestamptz": {"timestamp with time zone"}, + "timestamp without time zone": {"timestamp"}, + "timestamp with time zone": {"timestamptz"}, + "bool": {"boolean"}, + "boolean": {"bool"}, + "serial2": {"smallserial"}, + "serial4": {"serial"}, + "serial8": {"bigserial"}, + "varbit": {"bit varying"}, + "char": {"character"}, + "varchar": {"character varying"}, + "float4": {"real"}, + "float8": {"double precision"}, + "time": {"time"}, + "timetz": {"time with time zone"}, + "time without time zone": {"time"}, + "time with time zone": {"timetz"}, } type Migrator struct { @@ -120,7 +136,8 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { } createIndexSQL += "INDEX " - if strings.TrimSpace(strings.ToUpper(idx.Option)) == "CONCURRENTLY" { + hasConcurrentOption := strings.TrimSpace(strings.ToUpper(idx.Option)) == "CONCURRENTLY" + if hasConcurrentOption { createIndexSQL += "CONCURRENTLY " } @@ -132,6 +149,10 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { createIndexSQL += " ?" } + if idx.Option != "" && !hasConcurrentOption { + createIndexSQL += " " + idx.Option + } + if idx.Where != "" { createIndexSQL += " WHERE " + idx.Where } @@ -312,7 +333,7 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { fileType := clause.Expr{SQL: m.DataTypeOf(field)} // check for typeName and SQL name isSameType := true - if fieldColumnType.DatabaseTypeName() != fileType.SQL { + if !strings.EqualFold(fieldColumnType.DatabaseTypeName(), fileType.SQL) { isSameType = false // if different, also check for aliases aliases := m.GetTypeAliases(fieldColumnType.DatabaseTypeName()) @@ -375,10 +396,16 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { return err } } else { - if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil { + if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { return err } } + } else if !field.HasDefaultValue { + // case - as-is column has default value and to-be column has no default value + // need to drop default + if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { + return err + } } } return nil @@ -474,8 +501,8 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, column.LengthValue = typeLenValue } - if (strings.HasPrefix(column.DefaultValueValue.String, "nextval('") && - strings.HasSuffix(column.DefaultValueValue.String, "seq'::regclass)")) || (identityIncrement.Valid && identityIncrement.String != "") { + autoIncrementValuePattern := regexp.MustCompile(`^nextval\('"?[^']+seq"?'::regclass\)$`) + if autoIncrementValuePattern.MatchString(column.DefaultValueValue.String) || (identityIncrement.Valid && identityIncrement.String != "") { column.AutoIncrementValue = sql.NullBool{Bool: true, Valid: true} column.DefaultValueValue = sql.NullString{} } diff --git a/cs-manager/vendor/gorm.io/driver/postgres/postgres.go b/cs-manager/vendor/gorm.io/driver/postgres/postgres.go index e865b0f8..2d8fd997 100644 --- a/cs-manager/vendor/gorm.io/driver/postgres/postgres.go +++ b/cs-manager/vendor/gorm.io/driver/postgres/postgres.go @@ -1,13 +1,16 @@ package postgres import ( + "context" "database/sql" "fmt" "regexp" "strconv" "strings" + "time" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/stdlib" "gorm.io/gorm" "gorm.io/gorm/callbacks" @@ -31,7 +34,7 @@ type Config struct { } var ( - timeZoneMatcher = regexp.MustCompile("(time_zone|TimeZone)=(.*?)($|&| )") + timeZoneMatcher = regexp.MustCompile("(time_zone|TimeZone|timezone)=(.*?)($|&| )") defaultIdentifierLength = 63 //maximum identifier length for postgres ) @@ -99,10 +102,23 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol } result := timeZoneMatcher.FindStringSubmatch(dialector.Config.DSN) + var options []stdlib.OptionOpenDB if len(result) > 2 { config.RuntimeParams["timezone"] = result[2] + options = append(options, stdlib.OptionAfterConnect(func(ctx context.Context, conn *pgx.Conn) error { + loc, tzErr := time.LoadLocation(result[2]) + if tzErr != nil { + return tzErr + } + conn.TypeMap().RegisterType(&pgtype.Type{ + Name: "timestamp", + OID: pgtype.TimestampOID, + Codec: &pgtype.TimestampCodec{ScanLocation: loc}, + }) + return nil + })) } - db.ConnPool = stdlib.OpenDB(*config) + db.ConnPool = stdlib.OpenDB(*config, options...) } return } @@ -228,7 +244,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { } return "decimal" case schema.String: - if field.Size > 0 { + if field.Size > 0 && field.Size <= 10485760 { return fmt.Sprintf("varchar(%d)", field.Size) } return "text" diff --git a/cs-manager/vendor/gorm.io/gorm/README.md b/cs-manager/vendor/gorm.io/gorm/README.md index 745dad60..24eb84c9 100644 --- a/cs-manager/vendor/gorm.io/gorm/README.md +++ b/cs-manager/vendor/gorm.io/gorm/README.md @@ -3,7 +3,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. [![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm) -[![test status](https://github.com/go-gorm/gorm/workflows/tests/badge.svg?branch=master "test status")](https://github.com/go-gorm/gorm/actions) +[![test status](https://github.com/go-gorm/gorm/actions/workflows/tests.yml/badge.svg)](https://github.com/go-gorm/gorm/actions) [![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT) [![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gorm?tab=doc) diff --git a/cs-manager/vendor/gorm.io/gorm/association.go b/cs-manager/vendor/gorm.io/gorm/association.go index e3f51d17..3a4e0e25 100644 --- a/cs-manager/vendor/gorm.io/gorm/association.go +++ b/cs-manager/vendor/gorm.io/gorm/association.go @@ -19,10 +19,10 @@ type Association struct { } func (db *DB) Association(column string) *Association { - association := &Association{DB: db} + association := &Association{DB: db, Unscope: db.Statement.Unscoped} table := db.Statement.Table - if err := db.Statement.Parse(db.Statement.Model); err == nil { + if association.Error = db.Statement.Parse(db.Statement.Model); association.Error == nil { db.Statement.Table = table association.Relationship = db.Statement.Schema.Relationships.Relations[column] @@ -34,8 +34,6 @@ func (db *DB) Association(column string) *Association { for db.Statement.ReflectValue.Kind() == reflect.Ptr { db.Statement.ReflectValue = db.Statement.ReflectValue.Elem() } - } else { - association.Error = err } return association @@ -58,6 +56,8 @@ func (association *Association) Find(out interface{}, conds ...interface{}) erro } func (association *Association) Append(values ...interface{}) error { + values = expandValues(values) + if association.Error == nil { switch association.Relationship.Type { case schema.HasOne, schema.BelongsTo: @@ -73,6 +73,8 @@ func (association *Association) Append(values ...interface{}) error { } func (association *Association) Replace(values ...interface{}) error { + values = expandValues(values) + if association.Error == nil { reflectValue := association.DB.Statement.ReflectValue rel := association.Relationship @@ -97,7 +99,7 @@ func (association *Association) Replace(values ...interface{}) error { return association.Error } - // set old associations's foreign key to null + // set old association's foreign key to null switch rel.Type { case schema.BelongsTo: if len(values) == 0 { @@ -195,6 +197,8 @@ func (association *Association) Replace(values ...interface{}) error { } func (association *Association) Delete(values ...interface{}) error { + values = expandValues(values) + if association.Error == nil { var ( reflectValue = association.DB.Statement.ReflectValue @@ -300,7 +304,7 @@ func (association *Association) Delete(values ...interface{}) error { } if association.Error == nil { - // clean up deleted values's foreign key + // clean up deleted values' foreign key relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields) cleanUpDeletedRelations := func(data reflect.Value) { @@ -431,10 +435,49 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } } + processMap := func(mapv reflect.Value) { + child := reflect.New(association.Relationship.FieldSchema.ModelType) + + switch association.Relationship.Type { + case schema.HasMany: + for _, ref := range association.Relationship.References { + key := reflect.ValueOf(ref.ForeignKey.DBName) + if ref.OwnPrimaryKey { + v := ref.PrimaryKey.ReflectValueOf(association.DB.Statement.Context, source) + mapv.SetMapIndex(key, v) + } else if ref.PrimaryValue != "" { + mapv.SetMapIndex(key, reflect.ValueOf(ref.PrimaryValue)) + } + } + association.Error = association.DB.Session(&Session{ + NewDB: true, + }).Model(child.Interface()).Create(mapv.Interface()).Error + case schema.Many2Many: + association.Error = association.DB.Session(&Session{ + NewDB: true, + }).Model(child.Interface()).Create(mapv.Interface()).Error + + for _, key := range mapv.MapKeys() { + k := strings.ToLower(key.String()) + if f, ok := association.Relationship.FieldSchema.FieldsByDBName[k]; ok { + _ = f.Set(association.DB.Statement.Context, child, mapv.MapIndex(key).Interface()) + } + } + appendToFieldValues(child) + } + } + switch rv.Kind() { + case reflect.Map: + processMap(rv) case reflect.Slice, reflect.Array: for i := 0; i < rv.Len(); i++ { - appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr()) + elem := reflect.Indirect(rv.Index(i)) + if elem.Kind() == reflect.Map { + processMap(elem) + continue + } + appendToFieldValues(elem.Addr()) } case reflect.Struct: if !rv.CanAddr() { @@ -591,3 +634,32 @@ func (association *Association) buildCondition() *DB { return tx } + +func expandValues(values ...any) (results []any) { + appendToResult := func(rv reflect.Value) { + // unwrap interface + if rv.IsValid() && rv.Kind() == reflect.Interface { + rv = rv.Elem() + } + if rv.IsValid() && rv.Kind() == reflect.Struct { + p := reflect.New(rv.Type()) + p.Elem().Set(rv) + results = append(results, p.Interface()) + } else if rv.IsValid() { + results = append(results, rv.Interface()) + } + } + + // Process each argument; if an argument is a slice/array, expand its elements + for _, value := range values { + rv := reflect.ValueOf(value) + if rv.Kind() == reflect.Slice || rv.Kind() == reflect.Array { + for i := 0; i < rv.Len(); i++ { + appendToResult(rv.Index(i)) + } + } else { + appendToResult(rv) + } + } + return +} diff --git a/cs-manager/vendor/gorm.io/gorm/callbacks.go b/cs-manager/vendor/gorm.io/gorm/callbacks.go index 50b5b0e9..bd97f040 100644 --- a/cs-manager/vendor/gorm.io/gorm/callbacks.go +++ b/cs-manager/vendor/gorm.io/gorm/callbacks.go @@ -89,10 +89,16 @@ func (p *processor) Execute(db *DB) *DB { resetBuildClauses = true } - if optimizer, ok := db.Statement.Dest.(StatementModifier); ok { + if optimizer, ok := stmt.Dest.(StatementModifier); ok { optimizer.ModifyStatement(stmt) } + if db.DefaultContextTimeout > 0 { + if _, ok := stmt.Context.Deadline(); !ok { + stmt.Context, _ = context.WithTimeout(stmt.Context, db.DefaultContextTimeout) + } + } + // assign model values if stmt.Model == nil { stmt.Model = stmt.Dest diff --git a/cs-manager/vendor/gorm.io/gorm/callbacks/create.go b/cs-manager/vendor/gorm.io/gorm/callbacks/create.go index d8701f51..e5929adb 100644 --- a/cs-manager/vendor/gorm.io/gorm/callbacks/create.go +++ b/cs-manager/vendor/gorm.io/gorm/callbacks/create.go @@ -53,9 +53,13 @@ func Create(config *Config) func(db *gorm.DB) { if _, ok := db.Statement.Clauses["RETURNING"]; !ok { fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue)) for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue { - fromColumns = append(fromColumns, clause.Column{Name: field.DBName}) + if field.Readable { + fromColumns = append(fromColumns, clause.Column{Name: field.DBName}) + } + } + if len(fromColumns) > 0 { + db.Statement.AddClause(clause.Returning{Columns: fromColumns}) } - db.Statement.AddClause(clause.Returning{Columns: fromColumns}) } } } @@ -76,8 +80,11 @@ func Create(config *Config) func(db *gorm.DB) { ok, mode := hasReturning(db, supportReturning) if ok { if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok { - if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing { + onConflict, _ := c.Expression.(clause.OnConflict) + if onConflict.DoNothing { mode |= gorm.ScanOnConflictDoNothing + } else if len(onConflict.DoUpdates) > 0 || onConflict.UpdateAll { + mode |= gorm.ScanUpdate } } @@ -122,6 +129,16 @@ func Create(config *Config) func(db *gorm.DB) { pkFieldName = "@id" ) + if db.Statement.Schema != nil { + if db.Statement.Schema.PrioritizedPrimaryField == nil || + !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue || + !db.Statement.Schema.PrioritizedPrimaryField.Readable { + return + } + pkField = db.Statement.Schema.PrioritizedPrimaryField + pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName + } + insertID, err := result.LastInsertId() insertOk := err == nil && insertID > 0 @@ -132,14 +149,6 @@ func Create(config *Config) func(db *gorm.DB) { return } - if db.Statement.Schema != nil { - if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { - return - } - pkField = db.Statement.Schema.PrioritizedPrimaryField - pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName - } - // append @id column with value for auto-increment primary key // the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1 switch values := db.Statement.Dest.(type) { diff --git a/cs-manager/vendor/gorm.io/gorm/chainable_api.go b/cs-manager/vendor/gorm.io/gorm/chainable_api.go index 8a6aea34..8f6113cc 100644 --- a/cs-manager/vendor/gorm.io/gorm/chainable_api.go +++ b/cs-manager/vendor/gorm.io/gorm/chainable_api.go @@ -178,7 +178,7 @@ func (db *DB) Omit(columns ...string) (tx *DB) { tx = db.getInstance() if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { - tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar) + tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsInvalidDBNameChar) } else { tx.Statement.Omits = columns } @@ -283,7 +283,7 @@ func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) func (db *DB) Group(name string) (tx *DB) { tx = db.getInstance() - fields := strings.FieldsFunc(name, utils.IsValidDBNameChar) + fields := strings.FieldsFunc(name, utils.IsInvalidDBNameChar) tx.Statement.AddClause(clause.GroupBy{ Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}}, }) diff --git a/cs-manager/vendor/gorm.io/gorm/clause/association.go b/cs-manager/vendor/gorm.io/gorm/clause/association.go new file mode 100644 index 00000000..a9bf7eb0 --- /dev/null +++ b/cs-manager/vendor/gorm.io/gorm/clause/association.go @@ -0,0 +1,35 @@ +package clause + +// AssociationOpType represents association operation types +type AssociationOpType int + +const ( + OpUnlink AssociationOpType = iota // Unlink association + OpDelete // Delete association records + OpUpdate // Update association records + OpCreate // Create association records with assignments +) + +// Association represents an association operation +type Association struct { + Association string // Association name + Type AssociationOpType // Operation type + Conditions []Expression // Filter conditions + Set []Assignment // Assignment operations (for Update and Create) + Values []interface{} // Values for Create operation +} + +// AssociationAssigner is an interface for association operation providers +type AssociationAssigner interface { + AssociationAssignments() []Association +} + +// Assignments implements the Assigner interface so that AssociationOperation can be used as a Set method parameter +func (ao Association) Assignments() []Assignment { + return []Assignment{} +} + +// AssociationAssignments implements the AssociationAssigner interface +func (ao Association) AssociationAssignments() []Association { + return []Association{ao} +} diff --git a/cs-manager/vendor/gorm.io/gorm/clause/set.go b/cs-manager/vendor/gorm.io/gorm/clause/set.go index 75eb6bdd..cb5f36a0 100644 --- a/cs-manager/vendor/gorm.io/gorm/clause/set.go +++ b/cs-manager/vendor/gorm.io/gorm/clause/set.go @@ -9,6 +9,11 @@ type Assignment struct { Value interface{} } +// Assigner assignments provider interface +type Assigner interface { + Assignments() []Assignment +} + func (set Set) Name() string { return "SET" } @@ -37,6 +42,9 @@ func (set Set) MergeClause(clause *Clause) { clause.Expression = Set(copiedAssignments) } +// Assignments implements Assigner for Set. +func (set Set) Assignments() []Assignment { return []Assignment(set) } + func Assignments(values map[string]interface{}) Set { keys := make([]string, 0, len(values)) for key := range values { @@ -58,3 +66,6 @@ func AssignmentColumns(values []string) Set { } return assignments } + +// Assignments implements Assigner for a single Assignment. +func (a Assignment) Assignments() []Assignment { return []Assignment{a} } diff --git a/cs-manager/vendor/gorm.io/gorm/finisher_api.go b/cs-manager/vendor/gorm.io/gorm/finisher_api.go index 57809d17..e9e35f1b 100644 --- a/cs-manager/vendor/gorm.io/gorm/finisher_api.go +++ b/cs-manager/vendor/gorm.io/gorm/finisher_api.go @@ -465,7 +465,7 @@ func (db *DB) Count(count *int64) (tx *DB) { if len(tx.Statement.Selects) == 1 { dbName := tx.Statement.Selects[0] - fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar) + fields := strings.FieldsFunc(dbName, utils.IsInvalidDBNameChar) if len(fields) == 1 || (len(fields) == 3 && (strings.ToUpper(fields[1]) == "AS" || fields[1] == ".")) { if tx.Statement.Parse(tx.Statement.Model) == nil { if f := tx.Statement.Schema.LookUpField(dbName); f != nil { @@ -564,7 +564,7 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { } if len(tx.Statement.Selects) != 1 { - fields := strings.FieldsFunc(column, utils.IsValidDBNameChar) + fields := strings.FieldsFunc(column, utils.IsInvalidDBNameChar) tx.Statement.AddClauseIfNotExists(clause.Select{ Distinct: tx.Statement.Distinct, Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}}, @@ -675,9 +675,9 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { } ctx := tx.Statement.Context - if _, ok := ctx.Deadline(); !ok { - if db.Config.DefaultTransactionTimeout > 0 { - ctx, _ = context.WithTimeout(ctx, db.Config.DefaultTransactionTimeout) + if db.DefaultTransactionTimeout > 0 { + if _, ok := ctx.Deadline(); !ok { + ctx, _ = context.WithTimeout(ctx, db.DefaultTransactionTimeout) } } diff --git a/cs-manager/vendor/gorm.io/gorm/generics.go b/cs-manager/vendor/gorm.io/gorm/generics.go index ad2d063f..166d1520 100644 --- a/cs-manager/vendor/gorm.io/gorm/generics.go +++ b/cs-manager/vendor/gorm.io/gorm/generics.go @@ -3,12 +3,15 @@ package gorm import ( "context" "database/sql" + "errors" "fmt" + "reflect" "sort" "strings" "gorm.io/gorm/clause" "gorm.io/gorm/logger" + "gorm.io/gorm/schema" ) type result struct { @@ -35,10 +38,34 @@ type Interface[T any] interface { } type CreateInterface[T any] interface { - ChainInterface[T] + ExecInterface[T] + // chain methods available at start; Select/Omit keep CreateInterface to allow Create chaining + Scopes(scopes ...func(db *Statement)) ChainInterface[T] + Where(query interface{}, args ...interface{}) ChainInterface[T] + Not(query interface{}, args ...interface{}) ChainInterface[T] + Or(query interface{}, args ...interface{}) ChainInterface[T] + Limit(offset int) ChainInterface[T] + Offset(offset int) ChainInterface[T] + Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] + Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] + Select(query string, args ...interface{}) CreateInterface[T] + Omit(columns ...string) CreateInterface[T] + MapColumns(m map[string]string) ChainInterface[T] + Distinct(args ...interface{}) ChainInterface[T] + Group(name string) ChainInterface[T] + Having(query interface{}, args ...interface{}) ChainInterface[T] + Order(value interface{}) ChainInterface[T] + Build(builder clause.Builder) + + Delete(ctx context.Context) (rowsAffected int, err error) + Update(ctx context.Context, name string, value any) (rowsAffected int, err error) + Updates(ctx context.Context, t T) (rowsAffected int, err error) + Count(ctx context.Context, column string) (result int64, err error) + Table(name string, args ...interface{}) CreateInterface[T] Create(ctx context.Context, r *T) error CreateInBatches(ctx context.Context, r *[]T, batchSize int) error + Set(assignments ...clause.Assigner) SetCreateOrUpdateInterface[T] } type ChainInterface[T any] interface { @@ -58,15 +85,28 @@ type ChainInterface[T any] interface { Group(name string) ChainInterface[T] Having(query interface{}, args ...interface{}) ChainInterface[T] Order(value interface{}) ChainInterface[T] + Set(assignments ...clause.Assigner) SetUpdateOnlyInterface[T] Build(builder clause.Builder) + Table(name string, args ...interface{}) ChainInterface[T] Delete(ctx context.Context) (rowsAffected int, err error) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) Updates(ctx context.Context, t T) (rowsAffected int, err error) Count(ctx context.Context, column string) (result int64, err error) } +// SetUpdateOnlyInterface is returned by Set after chaining; only Update is allowed +type SetUpdateOnlyInterface[T any] interface { + Update(ctx context.Context) (rowsAffected int, err error) +} + +// SetCreateOrUpdateInterface is returned by Set at start; Create or Update are allowed +type SetCreateOrUpdateInterface[T any] interface { + Create(ctx context.Context) error + Update(ctx context.Context) (rowsAffected int, err error) +} + type ExecInterface[T any] interface { Scan(ctx context.Context, r interface{}) error First(context.Context) (T, error) @@ -142,13 +182,15 @@ func (c *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] { return execG[T]{g: &g[T]{ db: c.db, ops: append(c.ops, func(db *DB) *DB { - return db.Raw(sql, values...) + var r T + return db.Model(r).Raw(sql, values...) }), }} } func (c *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error { - return c.apply(ctx).Exec(sql, values...).Error + var r T + return c.apply(ctx).Model(r).Exec(sql, values...).Error } type createG[T any] struct { @@ -161,6 +203,22 @@ func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] { })} } +func (c createG[T]) Select(query string, args ...interface{}) CreateInterface[T] { + return createG[T]{c.with(func(db *DB) *DB { + return db.Select(query, args...) + })} +} + +func (c createG[T]) Omit(columns ...string) CreateInterface[T] { + return createG[T]{c.with(func(db *DB) *DB { + return db.Omit(columns...) + })} +} + +func (c createG[T]) Set(assignments ...clause.Assigner) SetCreateOrUpdateInterface[T] { + return c.processSet(assignments...) +} + func (c createG[T]) Create(ctx context.Context, r *T) error { return c.g.apply(ctx).Create(r).Error } @@ -187,6 +245,12 @@ func (c chainG[T]) with(v op) chainG[T] { } } +func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Table(name, args...) + }) +} + func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] { return c.with(func(db *DB) *DB { for _, fc := range scopes { @@ -196,12 +260,6 @@ func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] { }) } -func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] { - return c.with(func(db *DB) *DB { - return db.Table(name, args...) - }) -} - func (c chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] { return c.with(func(db *DB) *DB { return db.Where(query, args...) @@ -388,6 +446,10 @@ func (c chainG[T]) MapColumns(m map[string]string) ChainInterface[T] { }) } +func (c chainG[T]) Set(assignments ...clause.Assigner) SetUpdateOnlyInterface[T] { + return c.processSet(assignments...) +} + func (c chainG[T]) Distinct(args ...interface{}) ChainInterface[T] { return c.with(func(db *DB) *DB { return db.Distinct(args...) @@ -425,12 +487,12 @@ func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) err relation, ok := db.Statement.Schema.Relationships.Relations[association] if !ok { if preloadFields := strings.Split(association, "."); len(preloadFields) > 1 { - relationships := db.Statement.Schema.Relationships + relationships := &db.Statement.Schema.Relationships for _, field := range preloadFields { var ok bool relation, ok = relationships.Relations[field] if ok { - relationships = relation.FieldSchema.Relationships + relationships = &relation.FieldSchema.Relationships } else { db.AddError(fmt.Errorf("relation %s not found", association)) return nil @@ -567,7 +629,7 @@ func (g execG[T]) First(ctx context.Context) (T, error) { func (g execG[T]) Scan(ctx context.Context, result interface{}) error { var r T - err := g.g.apply(ctx).Model(r).Find(&result).Error + err := g.g.apply(ctx).Model(r).Find(result).Error return err } @@ -603,3 +665,242 @@ func (g execG[T]) Row(ctx context.Context) *sql.Row { func (g execG[T]) Rows(ctx context.Context) (*sql.Rows, error) { return g.g.apply(ctx).Rows() } + +func (c chainG[T]) processSet(items ...clause.Assigner) setCreateOrUpdateG[T] { + var ( + assigns []clause.Assignment + assocOps []clause.Association + ) + + for _, item := range items { + // Check if it's an AssociationAssigner + if assocAssigner, ok := item.(clause.AssociationAssigner); ok { + assocOps = append(assocOps, assocAssigner.AssociationAssignments()...) + } else { + assigns = append(assigns, item.Assignments()...) + } + } + + return setCreateOrUpdateG[T]{ + c: c, + assigns: assigns, + assocOps: assocOps, + } +} + +// setCreateOrUpdateG[T] is a struct that holds operations to be executed in a batch. +// It supports regular assignments and association operations. +type setCreateOrUpdateG[T any] struct { + c chainG[T] + assigns []clause.Assignment + assocOps []clause.Association +} + +func (s setCreateOrUpdateG[T]) Update(ctx context.Context) (rowsAffected int, err error) { + // Execute association operations + for _, assocOp := range s.assocOps { + if err := s.executeAssociationOperation(ctx, assocOp); err != nil { + return 0, err + } + } + + // Execute assignment operations + if len(s.assigns) > 0 { + var r T + res := s.c.g.apply(ctx).Model(r).Clauses(clause.Set(s.assigns)).Updates(map[string]interface{}{}) + return int(res.RowsAffected), res.Error + } + + return 0, nil +} + +func (s setCreateOrUpdateG[T]) Create(ctx context.Context) error { + // Execute association operations + for _, assocOp := range s.assocOps { + if err := s.executeAssociationOperation(ctx, assocOp); err != nil { + return err + } + } + + // Execute assignment operations + if len(s.assigns) > 0 { + data := make(map[string]interface{}, len(s.assigns)) + for _, a := range s.assigns { + data[a.Column.Name] = a.Value + } + var r T + return s.c.g.apply(ctx).Model(r).Create(data).Error + } + + return nil +} + +// executeAssociationOperation executes an association operation +func (s setCreateOrUpdateG[T]) executeAssociationOperation(ctx context.Context, op clause.Association) error { + var r T + base := s.c.g.apply(ctx).Model(r) + + switch op.Type { + case clause.OpCreate: + return s.handleAssociationCreate(ctx, base, op) + case clause.OpUnlink, clause.OpDelete, clause.OpUpdate: + return s.handleAssociation(ctx, base, op) + default: + return fmt.Errorf("unknown association operation type: %v", op.Type) + } +} + +func (s setCreateOrUpdateG[T]) handleAssociationCreate(ctx context.Context, base *DB, op clause.Association) error { + if len(op.Set) > 0 { + return s.handleAssociationForOwners(base, ctx, func(owner T, assoc *Association) error { + data := make(map[string]interface{}, len(op.Set)) + for _, a := range op.Set { + data[a.Column.Name] = a.Value + } + return assoc.Append(data) + }, op.Association) + } + + return s.handleAssociationForOwners(base, ctx, func(owner T, assoc *Association) error { + return assoc.Append(op.Values...) + }, op.Association) +} + +// handleAssociationForOwners is a helper function that handles associations for all owners +func (s setCreateOrUpdateG[T]) handleAssociationForOwners(base *DB, ctx context.Context, handler func(owner T, association *Association) error, associationName string) error { + var owners []T + if err := base.Find(&owners).Error; err != nil { + return err + } + + for _, owner := range owners { + assoc := base.Session(&Session{NewDB: true, Context: ctx}).Model(&owner).Association(associationName) + if assoc.Error != nil { + return assoc.Error + } + + if err := handler(owner, assoc); err != nil { + return err + } + } + return nil +} + +func (s setCreateOrUpdateG[T]) handleAssociation(ctx context.Context, base *DB, op clause.Association) error { + assoc := base.Association(op.Association) + if assoc.Error != nil { + return assoc.Error + } + + var ( + rel = assoc.Relationship + assocModel = reflect.New(rel.FieldSchema.ModelType).Interface() + fkNil = map[string]any{} + setMap = make(map[string]any, len(op.Set)) + ownerPKNames []string + ownerFKNames []string + primaryColumns []any + foreignColumns []any + ) + + for _, a := range op.Set { + setMap[a.Column.Name] = a.Value + } + + for _, ref := range rel.References { + fkNil[ref.ForeignKey.DBName] = nil + + if ref.OwnPrimaryKey && ref.PrimaryKey != nil { + ownerPKNames = append(ownerPKNames, ref.PrimaryKey.DBName) + primaryColumns = append(primaryColumns, clause.Column{Name: ref.PrimaryKey.DBName}) + foreignColumns = append(foreignColumns, clause.Column{Name: ref.ForeignKey.DBName}) + } else if !ref.OwnPrimaryKey && ref.PrimaryKey != nil { + ownerFKNames = append(ownerFKNames, ref.ForeignKey.DBName) + primaryColumns = append(primaryColumns, clause.Column{Name: ref.PrimaryKey.DBName}) + } + } + + assocDB := s.c.g.db.Session(&Session{NewDB: true, Context: ctx}).Model(assocModel).Where(op.Conditions) + + switch rel.Type { + case schema.HasOne, schema.HasMany: + assocDB = assocDB.Where("? IN (?)", foreignColumns, base.Select(ownerPKNames)) + switch op.Type { + case clause.OpUnlink: + return assocDB.Updates(fkNil).Error + case clause.OpDelete: + return assocDB.Delete(assocModel).Error + case clause.OpUpdate: + return assocDB.Updates(setMap).Error + } + case schema.BelongsTo: + switch op.Type { + case clause.OpDelete: + return base.Transaction(func(tx *DB) error { + assocDB.Statement.ConnPool = tx.Statement.ConnPool + base.Statement.ConnPool = tx.Statement.ConnPool + + if err := assocDB.Where("? IN (?)", primaryColumns, base.Select(ownerFKNames)).Delete(assocModel).Error; err != nil { + return err + } + return base.Updates(fkNil).Error + }) + case clause.OpUnlink: + return base.Updates(fkNil).Error + case clause.OpUpdate: + return assocDB.Where("? IN (?)", primaryColumns, base.Select(ownerFKNames)).Updates(setMap).Error + } + case schema.Many2Many: + joinModel := reflect.New(rel.JoinTable.ModelType).Interface() + joinDB := base.Session(&Session{NewDB: true, Context: ctx}).Model(joinModel) + + // EXISTS owners: owners.pk = join.owner_fk for all owner refs + ownersExists := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.Schema.Table).Select("1") + for _, ref := range rel.References { + if ref.OwnPrimaryKey && ref.PrimaryKey != nil { + ownersExists = ownersExists.Where(clause.Eq{ + Column: clause.Column{Table: rel.Schema.Table, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + }) + } + } + + // EXISTS related: related.pk = join.rel_fk for all related refs, plus optional conditions + relatedExists := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.FieldSchema.Table).Select("1") + for _, ref := range rel.References { + if !ref.OwnPrimaryKey && ref.PrimaryKey != nil { + relatedExists = relatedExists.Where(clause.Eq{ + Column: clause.Column{Table: rel.FieldSchema.Table, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + }) + } + } + relatedExists = relatedExists.Where(op.Conditions) + + switch op.Type { + case clause.OpUnlink, clause.OpDelete: + joinDB = joinDB.Where("EXISTS (?)", ownersExists) + if len(op.Conditions) > 0 { + joinDB = joinDB.Where("EXISTS (?)", relatedExists) + } + return joinDB.Delete(nil).Error + case clause.OpUpdate: + // Update related table rows that have join rows matching owners + relatedDB := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.FieldSchema.Table).Where(op.Conditions) + + // correlated join subquery: join.rel_fk = related.pk AND EXISTS owners + joinSub := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.JoinTable.Table).Select("1") + for _, ref := range rel.References { + if !ref.OwnPrimaryKey && ref.PrimaryKey != nil { + joinSub = joinSub.Where(clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: rel.FieldSchema.Table, Name: ref.PrimaryKey.DBName}, + }) + } + } + joinSub = joinSub.Where("EXISTS (?)", ownersExists) + return relatedDB.Where("EXISTS (?)", joinSub).Updates(setMap).Error + } + } + return errors.New("unsupported relationship") +} diff --git a/cs-manager/vendor/gorm.io/gorm/gorm.go b/cs-manager/vendor/gorm.io/gorm/gorm.go index 67889262..a209bb09 100644 --- a/cs-manager/vendor/gorm.io/gorm/gorm.go +++ b/cs-manager/vendor/gorm.io/gorm/gorm.go @@ -23,6 +23,7 @@ type Config struct { // You can disable it by setting `SkipDefaultTransaction` to true SkipDefaultTransaction bool DefaultTransactionTimeout time.Duration + DefaultContextTimeout time.Duration // NamingStrategy tables, columns naming strategy NamingStrategy schema.Namer @@ -137,6 +138,14 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { return isConfig && !isConfig2 }) + if len(opts) > 0 { + if c, ok := opts[0].(*Config); ok { + config = c + } else { + opts = append([]Option{config}, opts...) + } + } + var skipAfterInitialize bool for _, opt := range opts { if opt != nil { diff --git a/cs-manager/vendor/gorm.io/gorm/logger/slog.go b/cs-manager/vendor/gorm.io/gorm/logger/slog.go new file mode 100644 index 00000000..27c74683 --- /dev/null +++ b/cs-manager/vendor/gorm.io/gorm/logger/slog.go @@ -0,0 +1,116 @@ +//go:build go1.21 + +package logger + +import ( + "context" + "errors" + "fmt" + "log/slog" + "time" + + "gorm.io/gorm/utils" +) + +type slogLogger struct { + Logger *slog.Logger + LogLevel LogLevel + SlowThreshold time.Duration + Parameterized bool + Colorful bool // Ignored in slog + IgnoreRecordNotFoundError bool +} + +func NewSlogLogger(logger *slog.Logger, config Config) Interface { + return &slogLogger{ + Logger: logger, + LogLevel: config.LogLevel, + SlowThreshold: config.SlowThreshold, + Parameterized: config.ParameterizedQueries, + IgnoreRecordNotFoundError: config.IgnoreRecordNotFoundError, + } +} + +func (l *slogLogger) LogMode(level LogLevel) Interface { + newLogger := *l + newLogger.LogLevel = level + return &newLogger +} + +func (l *slogLogger) Info(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Info { + l.log(ctx, slog.LevelInfo, msg, slog.Any("data", data)) + } +} + +func (l *slogLogger) Warn(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Warn { + l.log(ctx, slog.LevelWarn, msg, slog.Any("data", data)) + } +} + +func (l *slogLogger) Error(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Error { + l.log(ctx, slog.LevelError, msg, slog.Any("data", data)) + } +} + +func (l *slogLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + if l.LogLevel <= Silent { + return + } + + elapsed := time.Since(begin) + sql, rows := fc() + fields := []slog.Attr{ + slog.String("duration", fmt.Sprintf("%.3fms", float64(elapsed.Nanoseconds())/1e6)), + slog.String("sql", sql), + } + + if rows != -1 { + fields = append(fields, slog.Int64("rows", rows)) + } + + switch { + case err != nil && (!l.IgnoreRecordNotFoundError || !errors.Is(err, ErrRecordNotFound)): + fields = append(fields, slog.String("error", err.Error())) + l.log(ctx, slog.LevelError, "SQL executed", slog.Attr{ + Key: "trace", + Value: slog.GroupValue(fields...), + }) + + case l.SlowThreshold != 0 && elapsed > l.SlowThreshold: + l.log(ctx, slog.LevelWarn, "SQL executed", slog.Attr{ + Key: "trace", + Value: slog.GroupValue(fields...), + }) + + case l.LogLevel >= Info: + l.log(ctx, slog.LevelInfo, "SQL executed", slog.Attr{ + Key: "trace", + Value: slog.GroupValue(fields...), + }) + } +} + +func (l *slogLogger) log(ctx context.Context, level slog.Level, msg string, args ...any) { + if ctx == nil { + ctx = context.Background() + } + + if !l.Logger.Enabled(ctx, level) { + return + } + + r := slog.NewRecord(time.Now(), level, msg, utils.CallerFrame().PC) + r.Add(args...) + _ = l.Logger.Handler().Handle(ctx, r) +} + +// ParamsFilter filter params +func (l *slogLogger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) { + if l.Parameterized { + return sql, nil + } + return sql, params +} diff --git a/cs-manager/vendor/gorm.io/gorm/migrator/migrator.go b/cs-manager/vendor/gorm.io/gorm/migrator/migrator.go index cec4e30f..35107d57 100644 --- a/cs-manager/vendor/gorm.io/gorm/migrator/migrator.go +++ b/cs-manager/vendor/gorm.io/gorm/migrator/migrator.go @@ -474,7 +474,6 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy // found, smart migrate fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)) realDataType := strings.ToLower(columnType.DatabaseTypeName()) - var ( alterColumn bool isSameType = fullDataType == realDataType @@ -513,8 +512,19 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } } } + } - // check precision + // check precision + if realDataType == "decimal" || realDataType == "numeric" && + regexp.MustCompile(realDataType+`\(.*\)`).FindString(fullDataType) != "" { // if realDataType has no precision,ignore + precision, scale, ok := columnType.DecimalSize() + if ok { + if !strings.HasPrefix(fullDataType, fmt.Sprintf("%s(%d,%d)", realDataType, precision, scale)) && + !strings.HasPrefix(fullDataType, fmt.Sprintf("%s(%d)", realDataType, precision)) { + alterColumn = true + } + } + } else { if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) { alterColumn = true @@ -550,6 +560,10 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy v1, _ := strconv.ParseBool(dv) v2, _ := strconv.ParseBool(field.DefaultValue) alterColumn = v1 != v2 + case schema.String: + if dv != field.DefaultValue && dv != strings.Trim(field.DefaultValue, "'\"") { + alterColumn = true + } default: alterColumn = dv != field.DefaultValue } diff --git a/cs-manager/vendor/gorm.io/gorm/schema/field.go b/cs-manager/vendor/gorm.io/gorm/schema/field.go index a6ff1a72..de797402 100644 --- a/cs-manager/vendor/gorm.io/gorm/schema/field.go +++ b/cs-manager/vendor/gorm.io/gorm/schema/field.go @@ -448,16 +448,17 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } // create valuer, setter when parse struct -func (field *Field) setupValuerAndSetter() { +func (field *Field) setupValuerAndSetter(modelType reflect.Type) { // Setup NewValuePool field.setupNewValuePool() // ValueOf returns field's value and if it is zero fieldIndex := field.StructField.Index[0] switch { - case len(field.StructField.Index) == 1 && fieldIndex > 0: - field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) { - fieldValue := reflect.Indirect(value).Field(fieldIndex) + case len(field.StructField.Index) == 1 && fieldIndex >= 0: + field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { + v = reflect.Indirect(v) + fieldValue := v.Field(fieldIndex) return fieldValue.Interface(), fieldValue.IsZero() } default: @@ -504,9 +505,10 @@ func (field *Field) setupValuerAndSetter() { // ReflectValueOf returns field's reflect value switch { - case len(field.StructField.Index) == 1 && fieldIndex > 0: - field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value { - return reflect.Indirect(value).Field(fieldIndex) + case len(field.StructField.Index) == 1 && fieldIndex >= 0: + field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { + v = reflect.Indirect(v) + return v.Field(fieldIndex) } default: field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { diff --git a/cs-manager/vendor/gorm.io/gorm/schema/relationship.go b/cs-manager/vendor/gorm.io/gorm/schema/relationship.go index f1ace924..0535bba4 100644 --- a/cs-manager/vendor/gorm.io/gorm/schema/relationship.go +++ b/cs-manager/vendor/gorm.io/gorm/schema/relationship.go @@ -75,9 +75,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { } ) - cacheStore := schema.cacheStore - - if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil { + if relation.FieldSchema, err = getOrParse(fieldValue, schema.cacheStore, schema.namer); err != nil { schema.err = fmt.Errorf("failed to parse field: %s, error: %w", field.Name, err) return nil } @@ -147,6 +145,9 @@ func hasPolymorphicRelation(tagSettings map[string]string) bool { } func (schema *Schema) setRelation(relation *Relationship) { + schema.Relationships.Mux.Lock() + defer schema.Relationships.Mux.Unlock() + // set non-embedded relation if rel := schema.Relationships.Relations[relation.Name]; rel != nil { if len(rel.Field.BindNames) > 1 { @@ -590,6 +591,10 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu // build references for idx, foreignField := range foreignFields { // use same data type for foreign keys + schema.Relationships.Mux.Lock() + if schema != foreignField.Schema { + foreignField.Schema.Relationships.Mux.Lock() + } if copyableDataType(primaryFields[idx].DataType) { foreignField.DataType = primaryFields[idx].DataType } @@ -597,6 +602,10 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu if foreignField.Size == 0 { foreignField.Size = primaryFields[idx].Size } + schema.Relationships.Mux.Unlock() + if schema != foreignField.Schema { + foreignField.Schema.Relationships.Mux.Unlock() + } relation.References = append(relation.References, &Reference{ PrimaryKey: primaryFields[idx], diff --git a/cs-manager/vendor/gorm.io/gorm/schema/schema.go b/cs-manager/vendor/gorm.io/gorm/schema/schema.go index db236797..09697a7a 100644 --- a/cs-manager/vendor/gorm.io/gorm/schema/schema.go +++ b/cs-manager/vendor/gorm.io/gorm/schema/schema.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "go/ast" + "path" "reflect" "strings" "sync" @@ -59,14 +60,14 @@ type Schema struct { cacheStore *sync.Map } -func (schema Schema) String() string { +func (schema *Schema) String() string { if schema.ModelType.Name() == "" { return fmt.Sprintf("%s(%s)", schema.Name, schema.Table) } return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), schema.ModelType.Name()) } -func (schema Schema) MakeSlice() reflect.Value { +func (schema *Schema) MakeSlice() reflect.Value { slice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(schema.ModelType)), 0, 20) results := reflect.New(slice.Type()) results.Elem().Set(slice) @@ -74,13 +75,23 @@ func (schema Schema) MakeSlice() reflect.Value { return results } -func (schema Schema) LookUpField(name string) *Field { +func (schema *Schema) LookUpField(name string) *Field { if field, ok := schema.FieldsByDBName[name]; ok { return field } if field, ok := schema.FieldsByName[name]; ok { return field } + + // Lookup field using namer-driven ColumnName + if schema.namer == nil { + return nil + } + namerColumnName := schema.namer.ColumnName(schema.Table, name) + if field, ok := schema.FieldsByDBName[namerColumnName]; ok { + return field + } + return nil } @@ -92,10 +103,7 @@ func (schema Schema) LookUpField(name string) *Field { // } // ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID") // } -func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field { - if len(bindNames) == 0 { - return nil - } +func (schema *Schema) LookUpFieldByBindName(bindNames []string, name string) *Field { for i := len(bindNames) - 1; i >= 0; i-- { find := strings.Join(bindNames[:i], ".") + "." + name if field, ok := schema.FieldsByBindName[find]; ok { @@ -113,6 +121,14 @@ type TablerWithNamer interface { TableName(Namer) string } +var callbackTypes = []callbackType{ + callbackTypeBeforeCreate, callbackTypeAfterCreate, + callbackTypeBeforeUpdate, callbackTypeAfterUpdate, + callbackTypeBeforeSave, callbackTypeAfterSave, + callbackTypeBeforeDelete, callbackTypeAfterDelete, + callbackTypeAfterFind, +} + // Parse get data type from dialector func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { return ParseWithSpecialTableName(dest, cacheStore, namer, "") @@ -124,34 +140,33 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } - value := reflect.ValueOf(dest) - if value.Kind() == reflect.Ptr && value.IsNil() { - value = reflect.New(value.Type().Elem()) - } - modelType := reflect.Indirect(value).Type() - - if modelType.Kind() == reflect.Interface { - modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type() - } - - for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType := reflect.ValueOf(dest).Type() + if modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() } if modelType.Kind() != reflect.Struct { - if modelType.PkgPath() == "" { - return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + if modelType.Kind() == reflect.Interface { + modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type() + } + + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType.Kind() != reflect.Struct { + if modelType.PkgPath() == "" { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } - return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } // Cache the Schema for performance, // Use the modelType or modelType + schemaTable (if it present) as cache key. - var schemaCacheKey interface{} + var schemaCacheKey interface{} = modelType if specialTableName != "" { schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName) - } else { - schemaCacheKey = modelType } // Load exist schema cache, return if exists @@ -162,28 +177,29 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam return s, s.err } + var tableName string modelValue := reflect.New(modelType) - tableName := namer.TableName(modelType.Name()) - if tabler, ok := modelValue.Interface().(Tabler); ok { + if specialTableName != "" { + tableName = specialTableName + } else if en, ok := namer.(embeddedNamer); ok { + tableName = en.Table + } else if tabler, ok := modelValue.Interface().(Tabler); ok { tableName = tabler.TableName() - } - if tabler, ok := modelValue.Interface().(TablerWithNamer); ok { + } else if tabler, ok := modelValue.Interface().(TablerWithNamer); ok { tableName = tabler.TableName(namer) - } - if en, ok := namer.(embeddedNamer); ok { - tableName = en.Table - } - if specialTableName != "" && specialTableName != tableName { - tableName = specialTableName + } else { + tableName = namer.TableName(modelType.Name()) } schema := &Schema{ Name: modelType.Name(), ModelType: modelType, Table: tableName, - FieldsByName: map[string]*Field{}, - FieldsByBindName: map[string]*Field{}, - FieldsByDBName: map[string]*Field{}, + DBNames: make([]string, 0, 10), + Fields: make([]*Field, 0, 10), + FieldsByName: make(map[string]*Field, 10), + FieldsByBindName: make(map[string]*Field, 10), + FieldsByDBName: make(map[string]*Field, 10), Relationships: Relationships{Relations: map[string]*Relationship{}}, cacheStore: cacheStore, namer: namer, @@ -227,8 +243,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam schema.FieldsByBindName[bindName] = field if v != nil && v.PrimaryKey { + // remove the existing primary key field for idx, f := range schema.PrimaryFields { - if f == v { + if f.DBName == v.DBName { schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...) } } @@ -247,7 +264,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam schema.FieldsByBindName[bindName] = field } - field.setupValuerAndSetter() + field.setupValuerAndSetter(modelType) } prioritizedPrimaryField := schema.LookUpField("id") @@ -283,10 +300,37 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName) } + _, embedded := schema.cacheStore.Load(embeddedCacheKey) + relationshipFields := []*Field{} for _, field := range schema.Fields { if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil { schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) } + + if !embedded { + if field.DataType == "" && field.GORMDataType == "" && (field.Creatable || field.Updatable || field.Readable) { + relationshipFields = append(relationshipFields, field) + schema.FieldsByName[field.Name] = field + schema.FieldsByBindName[field.BindName()] = field + } + + fieldValue := reflect.New(field.IndirectFieldType).Interface() + if fc, ok := fieldValue.(CreateClausesInterface); ok { + field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) + } + + if fc, ok := fieldValue.(QueryClausesInterface); ok { + field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) + } + + if fc, ok := fieldValue.(UpdateClausesInterface); ok { + field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) + } + + if fc, ok := fieldValue.(DeleteClausesInterface); ok { + field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) + } + } } if field := schema.PrioritizedPrimaryField; field != nil { @@ -303,24 +347,6 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } } - callbackTypes := []callbackType{ - callbackTypeBeforeCreate, callbackTypeAfterCreate, - callbackTypeBeforeUpdate, callbackTypeAfterUpdate, - callbackTypeBeforeSave, callbackTypeAfterSave, - callbackTypeBeforeDelete, callbackTypeAfterDelete, - callbackTypeAfterFind, - } - for _, cbName := range callbackTypes { - if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() { - switch methodValue.Type().String() { - case "func(*gorm.DB) error": // TODO hack - reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true) - default: - logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName) - } - } - } - // Cache the schema if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded { s := v.(*Schema) @@ -336,84 +362,47 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } }() - if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { - for _, field := range schema.Fields { - if field.DataType == "" && field.GORMDataType == "" && (field.Creatable || field.Updatable || field.Readable) { - if schema.parseRelation(field); schema.err != nil { - return schema, schema.err + for _, cbName := range callbackTypes { + if methodValue := modelValue.MethodByName(string(cbName)); methodValue.IsValid() { + switch methodValue.Type().String() { + case "func(*gorm.DB) error": + expectedPkgPath := path.Dir(reflect.TypeOf(schema).Elem().PkgPath()) + if inVarPkg := methodValue.Type().In(0).Elem().PkgPath(); inVarPkg == expectedPkgPath { + reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true) } else { - schema.FieldsByName[field.Name] = field - schema.FieldsByBindName[field.BindName()] = field + logger.Default.Warn(context.Background(), "In model %v, the hook function `%v(*gorm.DB) error` has an incorrect parameter type. The expected parameter type is `%v`, but the provided type is `%v`.", schema, cbName, expectedPkgPath, inVarPkg) + // PASS } + default: + logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName) } + } + } - fieldValue := reflect.New(field.IndirectFieldType) - fieldInterface := fieldValue.Interface() - if fc, ok := fieldInterface.(CreateClausesInterface); ok { - field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) - } - - if fc, ok := fieldInterface.(QueryClausesInterface); ok { - field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) - } - - if fc, ok := fieldInterface.(UpdateClausesInterface); ok { - field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) - } - - if fc, ok := fieldInterface.(DeleteClausesInterface); ok { - field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) - } + // parse relationships + for _, field := range relationshipFields { + if schema.parseRelation(field); schema.err != nil { + return schema, schema.err } } return schema, schema.err } -// This unrolling is needed to show to the compiler the exact set of methods -// that can be used on the modelType. -// Prior to go1.22 any use of MethodByName would cause the linker to -// abandon dead code elimination for the entire binary. -// As of go1.22 the compiler supports one special case of a string constant -// being passed to MethodByName. For enterprise customers or those building -// large binaries, this gives a significant reduction in binary size. -// https://github.com/golang/go/issues/62257 -func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value { - switch cbType { - case callbackTypeBeforeCreate: - return modelType.MethodByName(string(callbackTypeBeforeCreate)) - case callbackTypeAfterCreate: - return modelType.MethodByName(string(callbackTypeAfterCreate)) - case callbackTypeBeforeUpdate: - return modelType.MethodByName(string(callbackTypeBeforeUpdate)) - case callbackTypeAfterUpdate: - return modelType.MethodByName(string(callbackTypeAfterUpdate)) - case callbackTypeBeforeSave: - return modelType.MethodByName(string(callbackTypeBeforeSave)) - case callbackTypeAfterSave: - return modelType.MethodByName(string(callbackTypeAfterSave)) - case callbackTypeBeforeDelete: - return modelType.MethodByName(string(callbackTypeBeforeDelete)) - case callbackTypeAfterDelete: - return modelType.MethodByName(string(callbackTypeAfterDelete)) - case callbackTypeAfterFind: - return modelType.MethodByName(string(callbackTypeAfterFind)) - default: - return reflect.ValueOf(nil) - } -} - func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { modelType := reflect.ValueOf(dest).Type() - for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { - modelType = modelType.Elem() - } if modelType.Kind() != reflect.Struct { - if modelType.PkgPath() == "" { - return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType.Kind() != reflect.Struct { + if modelType.PkgPath() == "" { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } - return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } if v, ok := cacheStore.Load(modelType); ok { diff --git a/cs-manager/vendor/gorm.io/gorm/schema/serializer.go b/cs-manager/vendor/gorm.io/gorm/schema/serializer.go index 0fafbcba..d774a8ed 100644 --- a/cs-manager/vendor/gorm.io/gorm/schema/serializer.go +++ b/cs-manager/vendor/gorm.io/gorm/schema/serializer.go @@ -8,6 +8,7 @@ import ( "encoding/gob" "encoding/json" "fmt" + "math" "reflect" "strings" "sync" @@ -127,16 +128,31 @@ func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect. // Value implements serializer interface func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) { rv := reflect.ValueOf(fieldValue) - switch v := fieldValue.(type) { - case int64, int, uint, uint64, int32, uint32, int16, uint16: - result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC() - case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16: + switch fieldValue.(type) { + case int, int8, int16, int32, int64: + result = time.Unix(rv.Int(), 0).UTC() + case uint, uint8, uint16, uint32, uint64: + if uv := rv.Uint(); uv > math.MaxInt64 { + err = fmt.Errorf("integer overflow conversion uint64(%d) -> int64", uv) + } else { + result = time.Unix(int64(uv), 0).UTC() //nolint:gosec + } + case *int, *int8, *int16, *int32, *int64: if rv.IsZero() { return nil, nil } - result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC() + result = time.Unix(rv.Elem().Int(), 0).UTC() + case *uint, *uint8, *uint16, *uint32, *uint64: + if rv.IsZero() { + return nil, nil + } + if uv := rv.Elem().Uint(); uv > math.MaxInt64 { + err = fmt.Errorf("integer overflow conversion uint64(%d) -> int64", uv) + } else { + result = time.Unix(int64(uv), 0).UTC() //nolint:gosec + } default: - err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v) + err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", fieldValue) } return } diff --git a/cs-manager/vendor/gorm.io/gorm/schema/utils.go b/cs-manager/vendor/gorm.io/gorm/schema/utils.go index fa1c65d4..86305d7b 100644 --- a/cs-manager/vendor/gorm.io/gorm/schema/utils.go +++ b/cs-manager/vendor/gorm.io/gorm/schema/utils.go @@ -17,25 +17,23 @@ func ParseTagSetting(str string, sep string) map[string]string { settings := map[string]string{} names := strings.Split(str, sep) + var parsedNames []string for i := 0; i < len(names); i++ { - j := i - if len(names[j]) > 0 { - for { - if names[j][len(names[j])-1] == '\\' { - i++ - names[j] = names[j][0:len(names[j])-1] + sep + names[i] - names[i] = "" - } else { - break - } - } + s := names[i] + for strings.HasSuffix(s, "\\") && i+1 < len(names) { + i++ + s = s[:len(s)-1] + sep + names[i] } + parsedNames = append(parsedNames, s) + } - values := strings.Split(names[j], ":") + for _, tag := range parsedNames { + values := strings.Split(tag, ":") k := strings.TrimSpace(strings.ToUpper(values[0])) - if len(values) >= 2 { - settings[k] = strings.Join(values[1:], ":") + val := strings.Join(values[1:], ":") + val = strings.ReplaceAll(val, `\"`, `"`) + settings[k] = val } else if k != "" { settings[k] = k } @@ -121,6 +119,17 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, } switch reflectValue.Kind() { + case reflect.Map: + results = [][]interface{}{make([]interface{}, len(fields))} + for idx, field := range fields { + mapValue := reflectValue.MapIndex(reflect.ValueOf(field.DBName)) + if mapValue.IsZero() { + mapValue = reflectValue.MapIndex(reflect.ValueOf(field.Name)) + } + results[0][idx] = mapValue.Interface() + } + + dataResults[utils.ToStringKey(results[0]...)] = []reflect.Value{reflectValue} case reflect.Struct: results = [][]interface{}{make([]interface{}, len(fields))} diff --git a/cs-manager/vendor/gorm.io/gorm/statement.go b/cs-manager/vendor/gorm.io/gorm/statement.go index c6183724..736087d7 100644 --- a/cs-manager/vendor/gorm.io/gorm/statement.go +++ b/cs-manager/vendor/gorm.io/gorm/statement.go @@ -96,7 +96,9 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { if v.Name == clause.CurrentTable { if stmt.TableExpr != nil { stmt.TableExpr.Build(stmt) - } else { + } else if stmt.Table != "" { + write(v.Raw, stmt.Table) + } else if stmt.AddError(stmt.Parse(stmt.Model)) == nil { write(v.Raw, stmt.Table) } } else { @@ -334,6 +336,8 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] switch v := arg.(type) { case clause.Expression: conds = append(conds, v) + case []clause.Expression: + conds = append(conds, v...) case *DB: v.executeScopes() @@ -341,7 +345,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] if where, ok := cs.Expression.(clause.Where); ok { if len(where.Exprs) == 1 { if orConds, ok := where.Exprs[0].(clause.OrConditions); ok { - where.Exprs[0] = clause.AndConditions(orConds) + if len(orConds.Exprs) == 1 { + where.Exprs[0] = clause.AndConditions(orConds) + } } } conds = append(conds, clause.And(where.Exprs...)) @@ -362,6 +368,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] for _, key := range keys { column := clause.Column{Name: key, Table: curTable} + if strings.Contains(key, ".") { + column = clause.Column{Name: key} + } conds = append(conds, clause.Eq{Column: column, Value: v[key]}) } case map[string]interface{}: @@ -374,6 +383,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] for _, key := range keys { reflectValue := reflect.Indirect(reflect.ValueOf(v[key])) column := clause.Column{Name: key, Table: curTable} + if strings.Contains(key, ".") { + column = clause.Column{Name: key} + } switch reflectValue.Kind() { case reflect.Slice, reflect.Array: if _, ok := v[key].(driver.Valuer); ok { @@ -650,12 +662,15 @@ func (stmt *Statement) Changed(fields ...string) bool { for destValue.Kind() == reflect.Ptr { destValue = destValue.Elem() } - - changedValue, zero := field.ValueOf(stmt.Context, destValue) - if v { - return !utils.AssertEqual(changedValue, fieldValue) + if descSchema, err := schema.Parse(stmt.Dest, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { + if destField := descSchema.LookUpField(field.DBName); destField != nil { + changedValue, zero := destField.ValueOf(stmt.Context, destValue) + if v { + return !utils.AssertEqual(changedValue, fieldValue) + } + return !zero && !utils.AssertEqual(changedValue, fieldValue) + } } - return !zero && !utils.AssertEqual(changedValue, fieldValue) } } return false diff --git a/cs-manager/vendor/gorm.io/gorm/utils/utils.go b/cs-manager/vendor/gorm.io/gorm/utils/utils.go index fc615d73..7e59264b 100644 --- a/cs-manager/vendor/gorm.io/gorm/utils/utils.go +++ b/cs-manager/vendor/gorm.io/gorm/utils/utils.go @@ -30,8 +30,12 @@ func sourceDir(file string) string { return filepath.ToSlash(s) + "/" } -// FileWithLineNum return the file name and line number of the current file -func FileWithLineNum() string { +// CallerFrame retrieves the first relevant stack frame outside of GORM's internal implementation files. +// It skips: +// - GORM's core source files (identified by gormSourceDir prefix) +// - Exclude test files (*_test.go) +// - go-gorm/gen's Generated files (*.gen.go) +func CallerFrame() runtime.Frame { pcs := [13]uintptr{} // the third caller usually from gorm internal len := runtime.Callers(3, pcs[:]) @@ -41,14 +45,24 @@ func FileWithLineNum() string { frame, _ := frames.Next() if (!strings.HasPrefix(frame.File, gormSourceDir) || strings.HasSuffix(frame.File, "_test.go")) && !strings.HasSuffix(frame.File, ".gen.go") { - return string(strconv.AppendInt(append([]byte(frame.File), ':'), int64(frame.Line), 10)) + return frame } } + return runtime.Frame{} +} + +// FileWithLineNum return the file name and line number of the current file +func FileWithLineNum() string { + frame := CallerFrame() + if frame.PC != 0 { + return string(strconv.AppendInt(append([]byte(frame.File), ':'), int64(frame.Line), 10)) + } + return "" } -func IsValidDBNameChar(c rune) bool { +func IsInvalidDBNameChar(c rune) bool { return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@' } diff --git a/cs-manager/vendor/modules.txt b/cs-manager/vendor/modules.txt index cd53e9b4..c60a9d35 100644 --- a/cs-manager/vendor/modules.txt +++ b/cs-manager/vendor/modules.txt @@ -1249,10 +1249,10 @@ gopkg.in/yaml.v2 # gopkg.in/yaml.v3 v3.0.1 ## explicit gopkg.in/yaml.v3 -# gorm.io/driver/postgres v1.5.7 -## explicit; go 1.18 +# gorm.io/driver/postgres v1.6.0 +## explicit; go 1.20 gorm.io/driver/postgres -# gorm.io/gorm v1.30.0 +# gorm.io/gorm v1.31.1 ## explicit; go 1.18 gorm.io/gorm gorm.io/gorm/callbacks diff --git a/event-processor/go.mod b/event-processor/go.mod index 4e879bb4..248a3788 100644 --- a/event-processor/go.mod +++ b/event-processor/go.mod @@ -27,7 +27,7 @@ require ( google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.5.1 google.golang.org/protobuf v1.36.10 gorm.io/driver/postgres v1.6.0 - gorm.io/gorm v1.31.0 + gorm.io/gorm v1.31.1 ) require ( diff --git a/event-processor/go.sum b/event-processor/go.sum index 9c69ba23..67a0600a 100644 --- a/event-processor/go.sum +++ b/event-processor/go.sum @@ -678,8 +678,8 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4= gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo= -gorm.io/gorm v1.31.0 h1:0VlycGreVhK7RF/Bwt51Fk8v0xLiiiFdbGDPIZQ7mJY= -gorm.io/gorm v1.31.0/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= +gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= +gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= honnef.co/go/tools v0.6.1 h1:R094WgE8K4JirYjBaOpz/AvTyUu/3wbmAoskKN/pxTI= honnef.co/go/tools v0.6.1/go.mod h1:3puzxxljPCe8RGJX7BIy1plGbxEOZni5mR2aXe3/uk4= mvdan.cc/gofumpt v0.7.0 h1:bg91ttqXmi9y2xawvkuMXyvAA/1ZGJqYAEGjXuP0JXU= diff --git a/event-processor/vendor/gorm.io/gorm/README.md b/event-processor/vendor/gorm.io/gorm/README.md index 745dad60..24eb84c9 100644 --- a/event-processor/vendor/gorm.io/gorm/README.md +++ b/event-processor/vendor/gorm.io/gorm/README.md @@ -3,7 +3,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. [![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm) -[![test status](https://github.com/go-gorm/gorm/workflows/tests/badge.svg?branch=master "test status")](https://github.com/go-gorm/gorm/actions) +[![test status](https://github.com/go-gorm/gorm/actions/workflows/tests.yml/badge.svg)](https://github.com/go-gorm/gorm/actions) [![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT) [![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gorm?tab=doc) diff --git a/event-processor/vendor/gorm.io/gorm/association.go b/event-processor/vendor/gorm.io/gorm/association.go index f210ca0a..3a4e0e25 100644 --- a/event-processor/vendor/gorm.io/gorm/association.go +++ b/event-processor/vendor/gorm.io/gorm/association.go @@ -99,7 +99,7 @@ func (association *Association) Replace(values ...interface{}) error { return association.Error } - // set old associations's foreign key to null + // set old association's foreign key to null switch rel.Type { case schema.BelongsTo: if len(values) == 0 { @@ -304,7 +304,7 @@ func (association *Association) Delete(values ...interface{}) error { } if association.Error == nil { - // clean up deleted values's foreign key + // clean up deleted values' foreign key relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields) cleanUpDeletedRelations := func(data reflect.Value) { diff --git a/event-processor/vendor/gorm.io/gorm/chainable_api.go b/event-processor/vendor/gorm.io/gorm/chainable_api.go index 8a6aea34..8f6113cc 100644 --- a/event-processor/vendor/gorm.io/gorm/chainable_api.go +++ b/event-processor/vendor/gorm.io/gorm/chainable_api.go @@ -178,7 +178,7 @@ func (db *DB) Omit(columns ...string) (tx *DB) { tx = db.getInstance() if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { - tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar) + tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsInvalidDBNameChar) } else { tx.Statement.Omits = columns } @@ -283,7 +283,7 @@ func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) func (db *DB) Group(name string) (tx *DB) { tx = db.getInstance() - fields := strings.FieldsFunc(name, utils.IsValidDBNameChar) + fields := strings.FieldsFunc(name, utils.IsInvalidDBNameChar) tx.Statement.AddClause(clause.GroupBy{ Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}}, }) diff --git a/event-processor/vendor/gorm.io/gorm/finisher_api.go b/event-processor/vendor/gorm.io/gorm/finisher_api.go index e601fe66..e9e35f1b 100644 --- a/event-processor/vendor/gorm.io/gorm/finisher_api.go +++ b/event-processor/vendor/gorm.io/gorm/finisher_api.go @@ -465,7 +465,7 @@ func (db *DB) Count(count *int64) (tx *DB) { if len(tx.Statement.Selects) == 1 { dbName := tx.Statement.Selects[0] - fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar) + fields := strings.FieldsFunc(dbName, utils.IsInvalidDBNameChar) if len(fields) == 1 || (len(fields) == 3 && (strings.ToUpper(fields[1]) == "AS" || fields[1] == ".")) { if tx.Statement.Parse(tx.Statement.Model) == nil { if f := tx.Statement.Schema.LookUpField(dbName); f != nil { @@ -564,7 +564,7 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { } if len(tx.Statement.Selects) != 1 { - fields := strings.FieldsFunc(column, utils.IsValidDBNameChar) + fields := strings.FieldsFunc(column, utils.IsInvalidDBNameChar) tx.Statement.AddClauseIfNotExists(clause.Select{ Distinct: tx.Statement.Distinct, Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}}, diff --git a/event-processor/vendor/gorm.io/gorm/generics.go b/event-processor/vendor/gorm.io/gorm/generics.go index 79238d5f..166d1520 100644 --- a/event-processor/vendor/gorm.io/gorm/generics.go +++ b/event-processor/vendor/gorm.io/gorm/generics.go @@ -39,7 +39,7 @@ type Interface[T any] interface { type CreateInterface[T any] interface { ExecInterface[T] - // chain methods available at start; return ChainInterface + // chain methods available at start; Select/Omit keep CreateInterface to allow Create chaining Scopes(scopes ...func(db *Statement)) ChainInterface[T] Where(query interface{}, args ...interface{}) ChainInterface[T] Not(query interface{}, args ...interface{}) ChainInterface[T] @@ -48,8 +48,8 @@ type CreateInterface[T any] interface { Offset(offset int) ChainInterface[T] Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] - Select(query string, args ...interface{}) ChainInterface[T] - Omit(columns ...string) ChainInterface[T] + Select(query string, args ...interface{}) CreateInterface[T] + Omit(columns ...string) CreateInterface[T] MapColumns(m map[string]string) ChainInterface[T] Distinct(args ...interface{}) ChainInterface[T] Group(name string) ChainInterface[T] @@ -203,6 +203,18 @@ func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] { })} } +func (c createG[T]) Select(query string, args ...interface{}) CreateInterface[T] { + return createG[T]{c.with(func(db *DB) *DB { + return db.Select(query, args...) + })} +} + +func (c createG[T]) Omit(columns ...string) CreateInterface[T] { + return createG[T]{c.with(func(db *DB) *DB { + return db.Omit(columns...) + })} +} + func (c createG[T]) Set(assignments ...clause.Assigner) SetCreateOrUpdateInterface[T] { return c.processSet(assignments...) } diff --git a/event-processor/vendor/gorm.io/gorm/logger/slog.go b/event-processor/vendor/gorm.io/gorm/logger/slog.go index 613234ca..27c74683 100644 --- a/event-processor/vendor/gorm.io/gorm/logger/slog.go +++ b/event-processor/vendor/gorm.io/gorm/logger/slog.go @@ -8,6 +8,8 @@ import ( "fmt" "log/slog" "time" + + "gorm.io/gorm/utils" ) type slogLogger struct { @@ -37,19 +39,19 @@ func (l *slogLogger) LogMode(level LogLevel) Interface { func (l *slogLogger) Info(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Info { - l.Logger.InfoContext(ctx, msg, slog.Any("data", data)) + l.log(ctx, slog.LevelInfo, msg, slog.Any("data", data)) } } func (l *slogLogger) Warn(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Warn { - l.Logger.WarnContext(ctx, msg, slog.Any("data", data)) + l.log(ctx, slog.LevelWarn, msg, slog.Any("data", data)) } } func (l *slogLogger) Error(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Error { - l.Logger.ErrorContext(ctx, msg, slog.Any("data", data)) + l.log(ctx, slog.LevelError, msg, slog.Any("data", data)) } } @@ -72,25 +74,39 @@ func (l *slogLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql switch { case err != nil && (!l.IgnoreRecordNotFoundError || !errors.Is(err, ErrRecordNotFound)): fields = append(fields, slog.String("error", err.Error())) - l.Logger.ErrorContext(ctx, "SQL executed", slog.Attr{ + l.log(ctx, slog.LevelError, "SQL executed", slog.Attr{ Key: "trace", Value: slog.GroupValue(fields...), }) case l.SlowThreshold != 0 && elapsed > l.SlowThreshold: - l.Logger.WarnContext(ctx, "SQL executed", slog.Attr{ + l.log(ctx, slog.LevelWarn, "SQL executed", slog.Attr{ Key: "trace", Value: slog.GroupValue(fields...), }) case l.LogLevel >= Info: - l.Logger.InfoContext(ctx, "SQL executed", slog.Attr{ + l.log(ctx, slog.LevelInfo, "SQL executed", slog.Attr{ Key: "trace", Value: slog.GroupValue(fields...), }) } } +func (l *slogLogger) log(ctx context.Context, level slog.Level, msg string, args ...any) { + if ctx == nil { + ctx = context.Background() + } + + if !l.Logger.Enabled(ctx, level) { + return + } + + r := slog.NewRecord(time.Now(), level, msg, utils.CallerFrame().PC) + r.Add(args...) + _ = l.Logger.Handler().Handle(ctx, r) +} + // ParamsFilter filter params func (l *slogLogger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) { if l.Parameterized { diff --git a/event-processor/vendor/gorm.io/gorm/migrator/migrator.go b/event-processor/vendor/gorm.io/gorm/migrator/migrator.go index 50a36d10..35107d57 100644 --- a/event-processor/vendor/gorm.io/gorm/migrator/migrator.go +++ b/event-processor/vendor/gorm.io/gorm/migrator/migrator.go @@ -560,6 +560,10 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy v1, _ := strconv.ParseBool(dv) v2, _ := strconv.ParseBool(field.DefaultValue) alterColumn = v1 != v2 + case schema.String: + if dv != field.DefaultValue && dv != strings.Trim(field.DefaultValue, "'\"") { + alterColumn = true + } default: alterColumn = dv != field.DefaultValue } diff --git a/event-processor/vendor/gorm.io/gorm/schema/schema.go b/event-processor/vendor/gorm.io/gorm/schema/schema.go index 9419846b..09697a7a 100644 --- a/event-processor/vendor/gorm.io/gorm/schema/schema.go +++ b/event-processor/vendor/gorm.io/gorm/schema/schema.go @@ -82,6 +82,16 @@ func (schema *Schema) LookUpField(name string) *Field { if field, ok := schema.FieldsByName[name]; ok { return field } + + // Lookup field using namer-driven ColumnName + if schema.namer == nil { + return nil + } + namerColumnName := schema.namer.ColumnName(schema.Table, name) + if field, ok := schema.FieldsByDBName[namerColumnName]; ok { + return field + } + return nil } diff --git a/event-processor/vendor/gorm.io/gorm/schema/serializer.go b/event-processor/vendor/gorm.io/gorm/schema/serializer.go index 0fafbcba..d774a8ed 100644 --- a/event-processor/vendor/gorm.io/gorm/schema/serializer.go +++ b/event-processor/vendor/gorm.io/gorm/schema/serializer.go @@ -8,6 +8,7 @@ import ( "encoding/gob" "encoding/json" "fmt" + "math" "reflect" "strings" "sync" @@ -127,16 +128,31 @@ func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect. // Value implements serializer interface func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) { rv := reflect.ValueOf(fieldValue) - switch v := fieldValue.(type) { - case int64, int, uint, uint64, int32, uint32, int16, uint16: - result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC() - case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16: + switch fieldValue.(type) { + case int, int8, int16, int32, int64: + result = time.Unix(rv.Int(), 0).UTC() + case uint, uint8, uint16, uint32, uint64: + if uv := rv.Uint(); uv > math.MaxInt64 { + err = fmt.Errorf("integer overflow conversion uint64(%d) -> int64", uv) + } else { + result = time.Unix(int64(uv), 0).UTC() //nolint:gosec + } + case *int, *int8, *int16, *int32, *int64: if rv.IsZero() { return nil, nil } - result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC() + result = time.Unix(rv.Elem().Int(), 0).UTC() + case *uint, *uint8, *uint16, *uint32, *uint64: + if rv.IsZero() { + return nil, nil + } + if uv := rv.Elem().Uint(); uv > math.MaxInt64 { + err = fmt.Errorf("integer overflow conversion uint64(%d) -> int64", uv) + } else { + result = time.Unix(int64(uv), 0).UTC() //nolint:gosec + } default: - err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v) + err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", fieldValue) } return } diff --git a/event-processor/vendor/gorm.io/gorm/schema/utils.go b/event-processor/vendor/gorm.io/gorm/schema/utils.go index d4fe252e..86305d7b 100644 --- a/event-processor/vendor/gorm.io/gorm/schema/utils.go +++ b/event-processor/vendor/gorm.io/gorm/schema/utils.go @@ -17,25 +17,23 @@ func ParseTagSetting(str string, sep string) map[string]string { settings := map[string]string{} names := strings.Split(str, sep) + var parsedNames []string for i := 0; i < len(names); i++ { - j := i - if len(names[j]) > 0 { - for { - if names[j][len(names[j])-1] == '\\' { - i++ - names[j] = names[j][0:len(names[j])-1] + sep + names[i] - names[i] = "" - } else { - break - } - } + s := names[i] + for strings.HasSuffix(s, "\\") && i+1 < len(names) { + i++ + s = s[:len(s)-1] + sep + names[i] } + parsedNames = append(parsedNames, s) + } - values := strings.Split(names[j], ":") + for _, tag := range parsedNames { + values := strings.Split(tag, ":") k := strings.TrimSpace(strings.ToUpper(values[0])) - if len(values) >= 2 { - settings[k] = strings.Join(values[1:], ":") + val := strings.Join(values[1:], ":") + val = strings.ReplaceAll(val, `\"`, `"`) + settings[k] = val } else if k != "" { settings[k] = k } diff --git a/event-processor/vendor/gorm.io/gorm/utils/utils.go b/event-processor/vendor/gorm.io/gorm/utils/utils.go index fc615d73..7e59264b 100644 --- a/event-processor/vendor/gorm.io/gorm/utils/utils.go +++ b/event-processor/vendor/gorm.io/gorm/utils/utils.go @@ -30,8 +30,12 @@ func sourceDir(file string) string { return filepath.ToSlash(s) + "/" } -// FileWithLineNum return the file name and line number of the current file -func FileWithLineNum() string { +// CallerFrame retrieves the first relevant stack frame outside of GORM's internal implementation files. +// It skips: +// - GORM's core source files (identified by gormSourceDir prefix) +// - Exclude test files (*_test.go) +// - go-gorm/gen's Generated files (*.gen.go) +func CallerFrame() runtime.Frame { pcs := [13]uintptr{} // the third caller usually from gorm internal len := runtime.Callers(3, pcs[:]) @@ -41,14 +45,24 @@ func FileWithLineNum() string { frame, _ := frames.Next() if (!strings.HasPrefix(frame.File, gormSourceDir) || strings.HasSuffix(frame.File, "_test.go")) && !strings.HasSuffix(frame.File, ".gen.go") { - return string(strconv.AppendInt(append([]byte(frame.File), ':'), int64(frame.Line), 10)) + return frame } } + return runtime.Frame{} +} + +// FileWithLineNum return the file name and line number of the current file +func FileWithLineNum() string { + frame := CallerFrame() + if frame.PC != 0 { + return string(strconv.AppendInt(append([]byte(frame.File), ':'), int64(frame.Line), 10)) + } + return "" } -func IsValidDBNameChar(c rune) bool { +func IsInvalidDBNameChar(c rune) bool { return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@' } diff --git a/event-processor/vendor/modules.txt b/event-processor/vendor/modules.txt index b34ed1e8..06e7a45b 100644 --- a/event-processor/vendor/modules.txt +++ b/event-processor/vendor/modules.txt @@ -1337,7 +1337,7 @@ gopkg.in/yaml.v3 # gorm.io/driver/postgres v1.6.0 ## explicit; go 1.20 gorm.io/driver/postgres -# gorm.io/gorm v1.31.0 +# gorm.io/gorm v1.31.1 ## explicit; go 1.18 gorm.io/gorm gorm.io/gorm/callbacks diff --git a/history-api/go.mod b/history-api/go.mod index fb790952..b627e1f8 100644 --- a/history-api/go.mod +++ b/history-api/go.mod @@ -21,7 +21,7 @@ require ( google.golang.org/protobuf v1.36.10 gorm.io/driver/clickhouse v0.6.1 gorm.io/driver/postgres v1.6.0 - gorm.io/gorm v1.31.0 + gorm.io/gorm v1.31.1 ) require ( diff --git a/history-api/go.sum b/history-api/go.sum index b47fad38..7b7e0df0 100644 --- a/history-api/go.sum +++ b/history-api/go.sum @@ -714,8 +714,8 @@ gorm.io/driver/clickhouse v0.6.1 h1:t7JMB6sLBXxN8hEO6RdzCbJCwq/jAEVZdwXlmQs1Sd4= gorm.io/driver/clickhouse v0.6.1/go.mod h1:riMYpJcGZ3sJ/OAZZ1rEP1j/Y0H6cByOAnwz7fo2AyM= gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4= gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo= -gorm.io/gorm v1.31.0 h1:0VlycGreVhK7RF/Bwt51Fk8v0xLiiiFdbGDPIZQ7mJY= -gorm.io/gorm v1.31.0/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= +gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= +gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= honnef.co/go/tools v0.6.1 h1:R094WgE8K4JirYjBaOpz/AvTyUu/3wbmAoskKN/pxTI= honnef.co/go/tools v0.6.1/go.mod h1:3puzxxljPCe8RGJX7BIy1plGbxEOZni5mR2aXe3/uk4= mvdan.cc/gofumpt v0.7.0 h1:bg91ttqXmi9y2xawvkuMXyvAA/1ZGJqYAEGjXuP0JXU= diff --git a/history-api/vendor/gorm.io/gorm/README.md b/history-api/vendor/gorm.io/gorm/README.md index 745dad60..24eb84c9 100644 --- a/history-api/vendor/gorm.io/gorm/README.md +++ b/history-api/vendor/gorm.io/gorm/README.md @@ -3,7 +3,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. [![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm) -[![test status](https://github.com/go-gorm/gorm/workflows/tests/badge.svg?branch=master "test status")](https://github.com/go-gorm/gorm/actions) +[![test status](https://github.com/go-gorm/gorm/actions/workflows/tests.yml/badge.svg)](https://github.com/go-gorm/gorm/actions) [![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT) [![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gorm?tab=doc) diff --git a/history-api/vendor/gorm.io/gorm/association.go b/history-api/vendor/gorm.io/gorm/association.go index f210ca0a..3a4e0e25 100644 --- a/history-api/vendor/gorm.io/gorm/association.go +++ b/history-api/vendor/gorm.io/gorm/association.go @@ -99,7 +99,7 @@ func (association *Association) Replace(values ...interface{}) error { return association.Error } - // set old associations's foreign key to null + // set old association's foreign key to null switch rel.Type { case schema.BelongsTo: if len(values) == 0 { @@ -304,7 +304,7 @@ func (association *Association) Delete(values ...interface{}) error { } if association.Error == nil { - // clean up deleted values's foreign key + // clean up deleted values' foreign key relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields) cleanUpDeletedRelations := func(data reflect.Value) { diff --git a/history-api/vendor/gorm.io/gorm/chainable_api.go b/history-api/vendor/gorm.io/gorm/chainable_api.go index 8a6aea34..8f6113cc 100644 --- a/history-api/vendor/gorm.io/gorm/chainable_api.go +++ b/history-api/vendor/gorm.io/gorm/chainable_api.go @@ -178,7 +178,7 @@ func (db *DB) Omit(columns ...string) (tx *DB) { tx = db.getInstance() if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { - tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar) + tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsInvalidDBNameChar) } else { tx.Statement.Omits = columns } @@ -283,7 +283,7 @@ func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) func (db *DB) Group(name string) (tx *DB) { tx = db.getInstance() - fields := strings.FieldsFunc(name, utils.IsValidDBNameChar) + fields := strings.FieldsFunc(name, utils.IsInvalidDBNameChar) tx.Statement.AddClause(clause.GroupBy{ Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}}, }) diff --git a/history-api/vendor/gorm.io/gorm/finisher_api.go b/history-api/vendor/gorm.io/gorm/finisher_api.go index e601fe66..e9e35f1b 100644 --- a/history-api/vendor/gorm.io/gorm/finisher_api.go +++ b/history-api/vendor/gorm.io/gorm/finisher_api.go @@ -465,7 +465,7 @@ func (db *DB) Count(count *int64) (tx *DB) { if len(tx.Statement.Selects) == 1 { dbName := tx.Statement.Selects[0] - fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar) + fields := strings.FieldsFunc(dbName, utils.IsInvalidDBNameChar) if len(fields) == 1 || (len(fields) == 3 && (strings.ToUpper(fields[1]) == "AS" || fields[1] == ".")) { if tx.Statement.Parse(tx.Statement.Model) == nil { if f := tx.Statement.Schema.LookUpField(dbName); f != nil { @@ -564,7 +564,7 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { } if len(tx.Statement.Selects) != 1 { - fields := strings.FieldsFunc(column, utils.IsValidDBNameChar) + fields := strings.FieldsFunc(column, utils.IsInvalidDBNameChar) tx.Statement.AddClauseIfNotExists(clause.Select{ Distinct: tx.Statement.Distinct, Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}}, diff --git a/history-api/vendor/gorm.io/gorm/generics.go b/history-api/vendor/gorm.io/gorm/generics.go index 79238d5f..166d1520 100644 --- a/history-api/vendor/gorm.io/gorm/generics.go +++ b/history-api/vendor/gorm.io/gorm/generics.go @@ -39,7 +39,7 @@ type Interface[T any] interface { type CreateInterface[T any] interface { ExecInterface[T] - // chain methods available at start; return ChainInterface + // chain methods available at start; Select/Omit keep CreateInterface to allow Create chaining Scopes(scopes ...func(db *Statement)) ChainInterface[T] Where(query interface{}, args ...interface{}) ChainInterface[T] Not(query interface{}, args ...interface{}) ChainInterface[T] @@ -48,8 +48,8 @@ type CreateInterface[T any] interface { Offset(offset int) ChainInterface[T] Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] - Select(query string, args ...interface{}) ChainInterface[T] - Omit(columns ...string) ChainInterface[T] + Select(query string, args ...interface{}) CreateInterface[T] + Omit(columns ...string) CreateInterface[T] MapColumns(m map[string]string) ChainInterface[T] Distinct(args ...interface{}) ChainInterface[T] Group(name string) ChainInterface[T] @@ -203,6 +203,18 @@ func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] { })} } +func (c createG[T]) Select(query string, args ...interface{}) CreateInterface[T] { + return createG[T]{c.with(func(db *DB) *DB { + return db.Select(query, args...) + })} +} + +func (c createG[T]) Omit(columns ...string) CreateInterface[T] { + return createG[T]{c.with(func(db *DB) *DB { + return db.Omit(columns...) + })} +} + func (c createG[T]) Set(assignments ...clause.Assigner) SetCreateOrUpdateInterface[T] { return c.processSet(assignments...) } diff --git a/history-api/vendor/gorm.io/gorm/logger/slog.go b/history-api/vendor/gorm.io/gorm/logger/slog.go index 613234ca..27c74683 100644 --- a/history-api/vendor/gorm.io/gorm/logger/slog.go +++ b/history-api/vendor/gorm.io/gorm/logger/slog.go @@ -8,6 +8,8 @@ import ( "fmt" "log/slog" "time" + + "gorm.io/gorm/utils" ) type slogLogger struct { @@ -37,19 +39,19 @@ func (l *slogLogger) LogMode(level LogLevel) Interface { func (l *slogLogger) Info(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Info { - l.Logger.InfoContext(ctx, msg, slog.Any("data", data)) + l.log(ctx, slog.LevelInfo, msg, slog.Any("data", data)) } } func (l *slogLogger) Warn(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Warn { - l.Logger.WarnContext(ctx, msg, slog.Any("data", data)) + l.log(ctx, slog.LevelWarn, msg, slog.Any("data", data)) } } func (l *slogLogger) Error(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Error { - l.Logger.ErrorContext(ctx, msg, slog.Any("data", data)) + l.log(ctx, slog.LevelError, msg, slog.Any("data", data)) } } @@ -72,25 +74,39 @@ func (l *slogLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql switch { case err != nil && (!l.IgnoreRecordNotFoundError || !errors.Is(err, ErrRecordNotFound)): fields = append(fields, slog.String("error", err.Error())) - l.Logger.ErrorContext(ctx, "SQL executed", slog.Attr{ + l.log(ctx, slog.LevelError, "SQL executed", slog.Attr{ Key: "trace", Value: slog.GroupValue(fields...), }) case l.SlowThreshold != 0 && elapsed > l.SlowThreshold: - l.Logger.WarnContext(ctx, "SQL executed", slog.Attr{ + l.log(ctx, slog.LevelWarn, "SQL executed", slog.Attr{ Key: "trace", Value: slog.GroupValue(fields...), }) case l.LogLevel >= Info: - l.Logger.InfoContext(ctx, "SQL executed", slog.Attr{ + l.log(ctx, slog.LevelInfo, "SQL executed", slog.Attr{ Key: "trace", Value: slog.GroupValue(fields...), }) } } +func (l *slogLogger) log(ctx context.Context, level slog.Level, msg string, args ...any) { + if ctx == nil { + ctx = context.Background() + } + + if !l.Logger.Enabled(ctx, level) { + return + } + + r := slog.NewRecord(time.Now(), level, msg, utils.CallerFrame().PC) + r.Add(args...) + _ = l.Logger.Handler().Handle(ctx, r) +} + // ParamsFilter filter params func (l *slogLogger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) { if l.Parameterized { diff --git a/history-api/vendor/gorm.io/gorm/migrator/migrator.go b/history-api/vendor/gorm.io/gorm/migrator/migrator.go index 50a36d10..35107d57 100644 --- a/history-api/vendor/gorm.io/gorm/migrator/migrator.go +++ b/history-api/vendor/gorm.io/gorm/migrator/migrator.go @@ -560,6 +560,10 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy v1, _ := strconv.ParseBool(dv) v2, _ := strconv.ParseBool(field.DefaultValue) alterColumn = v1 != v2 + case schema.String: + if dv != field.DefaultValue && dv != strings.Trim(field.DefaultValue, "'\"") { + alterColumn = true + } default: alterColumn = dv != field.DefaultValue } diff --git a/history-api/vendor/gorm.io/gorm/schema/schema.go b/history-api/vendor/gorm.io/gorm/schema/schema.go index 9419846b..09697a7a 100644 --- a/history-api/vendor/gorm.io/gorm/schema/schema.go +++ b/history-api/vendor/gorm.io/gorm/schema/schema.go @@ -82,6 +82,16 @@ func (schema *Schema) LookUpField(name string) *Field { if field, ok := schema.FieldsByName[name]; ok { return field } + + // Lookup field using namer-driven ColumnName + if schema.namer == nil { + return nil + } + namerColumnName := schema.namer.ColumnName(schema.Table, name) + if field, ok := schema.FieldsByDBName[namerColumnName]; ok { + return field + } + return nil } diff --git a/history-api/vendor/gorm.io/gorm/schema/serializer.go b/history-api/vendor/gorm.io/gorm/schema/serializer.go index 0fafbcba..d774a8ed 100644 --- a/history-api/vendor/gorm.io/gorm/schema/serializer.go +++ b/history-api/vendor/gorm.io/gorm/schema/serializer.go @@ -8,6 +8,7 @@ import ( "encoding/gob" "encoding/json" "fmt" + "math" "reflect" "strings" "sync" @@ -127,16 +128,31 @@ func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect. // Value implements serializer interface func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) { rv := reflect.ValueOf(fieldValue) - switch v := fieldValue.(type) { - case int64, int, uint, uint64, int32, uint32, int16, uint16: - result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC() - case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16: + switch fieldValue.(type) { + case int, int8, int16, int32, int64: + result = time.Unix(rv.Int(), 0).UTC() + case uint, uint8, uint16, uint32, uint64: + if uv := rv.Uint(); uv > math.MaxInt64 { + err = fmt.Errorf("integer overflow conversion uint64(%d) -> int64", uv) + } else { + result = time.Unix(int64(uv), 0).UTC() //nolint:gosec + } + case *int, *int8, *int16, *int32, *int64: if rv.IsZero() { return nil, nil } - result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC() + result = time.Unix(rv.Elem().Int(), 0).UTC() + case *uint, *uint8, *uint16, *uint32, *uint64: + if rv.IsZero() { + return nil, nil + } + if uv := rv.Elem().Uint(); uv > math.MaxInt64 { + err = fmt.Errorf("integer overflow conversion uint64(%d) -> int64", uv) + } else { + result = time.Unix(int64(uv), 0).UTC() //nolint:gosec + } default: - err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v) + err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", fieldValue) } return } diff --git a/history-api/vendor/gorm.io/gorm/schema/utils.go b/history-api/vendor/gorm.io/gorm/schema/utils.go index d4fe252e..86305d7b 100644 --- a/history-api/vendor/gorm.io/gorm/schema/utils.go +++ b/history-api/vendor/gorm.io/gorm/schema/utils.go @@ -17,25 +17,23 @@ func ParseTagSetting(str string, sep string) map[string]string { settings := map[string]string{} names := strings.Split(str, sep) + var parsedNames []string for i := 0; i < len(names); i++ { - j := i - if len(names[j]) > 0 { - for { - if names[j][len(names[j])-1] == '\\' { - i++ - names[j] = names[j][0:len(names[j])-1] + sep + names[i] - names[i] = "" - } else { - break - } - } + s := names[i] + for strings.HasSuffix(s, "\\") && i+1 < len(names) { + i++ + s = s[:len(s)-1] + sep + names[i] } + parsedNames = append(parsedNames, s) + } - values := strings.Split(names[j], ":") + for _, tag := range parsedNames { + values := strings.Split(tag, ":") k := strings.TrimSpace(strings.ToUpper(values[0])) - if len(values) >= 2 { - settings[k] = strings.Join(values[1:], ":") + val := strings.Join(values[1:], ":") + val = strings.ReplaceAll(val, `\"`, `"`) + settings[k] = val } else if k != "" { settings[k] = k } diff --git a/history-api/vendor/gorm.io/gorm/utils/utils.go b/history-api/vendor/gorm.io/gorm/utils/utils.go index fc615d73..7e59264b 100644 --- a/history-api/vendor/gorm.io/gorm/utils/utils.go +++ b/history-api/vendor/gorm.io/gorm/utils/utils.go @@ -30,8 +30,12 @@ func sourceDir(file string) string { return filepath.ToSlash(s) + "/" } -// FileWithLineNum return the file name and line number of the current file -func FileWithLineNum() string { +// CallerFrame retrieves the first relevant stack frame outside of GORM's internal implementation files. +// It skips: +// - GORM's core source files (identified by gormSourceDir prefix) +// - Exclude test files (*_test.go) +// - go-gorm/gen's Generated files (*.gen.go) +func CallerFrame() runtime.Frame { pcs := [13]uintptr{} // the third caller usually from gorm internal len := runtime.Callers(3, pcs[:]) @@ -41,14 +45,24 @@ func FileWithLineNum() string { frame, _ := frames.Next() if (!strings.HasPrefix(frame.File, gormSourceDir) || strings.HasSuffix(frame.File, "_test.go")) && !strings.HasSuffix(frame.File, ".gen.go") { - return string(strconv.AppendInt(append([]byte(frame.File), ':'), int64(frame.Line), 10)) + return frame } } + return runtime.Frame{} +} + +// FileWithLineNum return the file name and line number of the current file +func FileWithLineNum() string { + frame := CallerFrame() + if frame.PC != 0 { + return string(strconv.AppendInt(append([]byte(frame.File), ':'), int64(frame.Line), 10)) + } + return "" } -func IsValidDBNameChar(c rune) bool { +func IsInvalidDBNameChar(c rune) bool { return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@' } diff --git a/history-api/vendor/modules.txt b/history-api/vendor/modules.txt index 367043d6..36958200 100644 --- a/history-api/vendor/modules.txt +++ b/history-api/vendor/modules.txt @@ -1336,7 +1336,7 @@ gorm.io/driver/clickhouse # gorm.io/driver/postgres v1.6.0 ## explicit; go 1.20 gorm.io/driver/postgres -# gorm.io/gorm v1.31.0 +# gorm.io/gorm v1.31.1 ## explicit; go 1.18 gorm.io/gorm gorm.io/gorm/callbacks diff --git a/lib/go.mod b/lib/go.mod index 79282ab7..b9ef5778 100644 --- a/lib/go.mod +++ b/lib/go.mod @@ -14,7 +14,7 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d google.golang.org/grpc v1.59.0 google.golang.org/protobuf v1.36.6 - gorm.io/gorm v1.30.0 + gorm.io/gorm v1.31.1 gorm.io/plugin/prometheus v0.1.0 ) diff --git a/lib/go.sum b/lib/go.sum index feb0695d..1c2bcac4 100644 --- a/lib/go.sum +++ b/lib/go.sum @@ -592,8 +592,8 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= -gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs= -gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= +gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= +gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= gorm.io/plugin/prometheus v0.1.0 h1:kDQwAfCUsT9D6jDUpIp7pnc7bCJu/6voM8I/BmFjxUQ= gorm.io/plugin/prometheus v0.1.0/go.mod h1:5nrc/JrWCUNoDXCY4eOae/FK/J5WjQ0axXuFusCzdTc= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/lib/vendor/gorm.io/gorm/README.md b/lib/vendor/gorm.io/gorm/README.md index 745dad60..24eb84c9 100644 --- a/lib/vendor/gorm.io/gorm/README.md +++ b/lib/vendor/gorm.io/gorm/README.md @@ -3,7 +3,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. [![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm) -[![test status](https://github.com/go-gorm/gorm/workflows/tests/badge.svg?branch=master "test status")](https://github.com/go-gorm/gorm/actions) +[![test status](https://github.com/go-gorm/gorm/actions/workflows/tests.yml/badge.svg)](https://github.com/go-gorm/gorm/actions) [![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT) [![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gorm?tab=doc) diff --git a/lib/vendor/gorm.io/gorm/association.go b/lib/vendor/gorm.io/gorm/association.go index e3f51d17..3a4e0e25 100644 --- a/lib/vendor/gorm.io/gorm/association.go +++ b/lib/vendor/gorm.io/gorm/association.go @@ -19,10 +19,10 @@ type Association struct { } func (db *DB) Association(column string) *Association { - association := &Association{DB: db} + association := &Association{DB: db, Unscope: db.Statement.Unscoped} table := db.Statement.Table - if err := db.Statement.Parse(db.Statement.Model); err == nil { + if association.Error = db.Statement.Parse(db.Statement.Model); association.Error == nil { db.Statement.Table = table association.Relationship = db.Statement.Schema.Relationships.Relations[column] @@ -34,8 +34,6 @@ func (db *DB) Association(column string) *Association { for db.Statement.ReflectValue.Kind() == reflect.Ptr { db.Statement.ReflectValue = db.Statement.ReflectValue.Elem() } - } else { - association.Error = err } return association @@ -58,6 +56,8 @@ func (association *Association) Find(out interface{}, conds ...interface{}) erro } func (association *Association) Append(values ...interface{}) error { + values = expandValues(values) + if association.Error == nil { switch association.Relationship.Type { case schema.HasOne, schema.BelongsTo: @@ -73,6 +73,8 @@ func (association *Association) Append(values ...interface{}) error { } func (association *Association) Replace(values ...interface{}) error { + values = expandValues(values) + if association.Error == nil { reflectValue := association.DB.Statement.ReflectValue rel := association.Relationship @@ -97,7 +99,7 @@ func (association *Association) Replace(values ...interface{}) error { return association.Error } - // set old associations's foreign key to null + // set old association's foreign key to null switch rel.Type { case schema.BelongsTo: if len(values) == 0 { @@ -195,6 +197,8 @@ func (association *Association) Replace(values ...interface{}) error { } func (association *Association) Delete(values ...interface{}) error { + values = expandValues(values) + if association.Error == nil { var ( reflectValue = association.DB.Statement.ReflectValue @@ -300,7 +304,7 @@ func (association *Association) Delete(values ...interface{}) error { } if association.Error == nil { - // clean up deleted values's foreign key + // clean up deleted values' foreign key relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields) cleanUpDeletedRelations := func(data reflect.Value) { @@ -431,10 +435,49 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } } + processMap := func(mapv reflect.Value) { + child := reflect.New(association.Relationship.FieldSchema.ModelType) + + switch association.Relationship.Type { + case schema.HasMany: + for _, ref := range association.Relationship.References { + key := reflect.ValueOf(ref.ForeignKey.DBName) + if ref.OwnPrimaryKey { + v := ref.PrimaryKey.ReflectValueOf(association.DB.Statement.Context, source) + mapv.SetMapIndex(key, v) + } else if ref.PrimaryValue != "" { + mapv.SetMapIndex(key, reflect.ValueOf(ref.PrimaryValue)) + } + } + association.Error = association.DB.Session(&Session{ + NewDB: true, + }).Model(child.Interface()).Create(mapv.Interface()).Error + case schema.Many2Many: + association.Error = association.DB.Session(&Session{ + NewDB: true, + }).Model(child.Interface()).Create(mapv.Interface()).Error + + for _, key := range mapv.MapKeys() { + k := strings.ToLower(key.String()) + if f, ok := association.Relationship.FieldSchema.FieldsByDBName[k]; ok { + _ = f.Set(association.DB.Statement.Context, child, mapv.MapIndex(key).Interface()) + } + } + appendToFieldValues(child) + } + } + switch rv.Kind() { + case reflect.Map: + processMap(rv) case reflect.Slice, reflect.Array: for i := 0; i < rv.Len(); i++ { - appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr()) + elem := reflect.Indirect(rv.Index(i)) + if elem.Kind() == reflect.Map { + processMap(elem) + continue + } + appendToFieldValues(elem.Addr()) } case reflect.Struct: if !rv.CanAddr() { @@ -591,3 +634,32 @@ func (association *Association) buildCondition() *DB { return tx } + +func expandValues(values ...any) (results []any) { + appendToResult := func(rv reflect.Value) { + // unwrap interface + if rv.IsValid() && rv.Kind() == reflect.Interface { + rv = rv.Elem() + } + if rv.IsValid() && rv.Kind() == reflect.Struct { + p := reflect.New(rv.Type()) + p.Elem().Set(rv) + results = append(results, p.Interface()) + } else if rv.IsValid() { + results = append(results, rv.Interface()) + } + } + + // Process each argument; if an argument is a slice/array, expand its elements + for _, value := range values { + rv := reflect.ValueOf(value) + if rv.Kind() == reflect.Slice || rv.Kind() == reflect.Array { + for i := 0; i < rv.Len(); i++ { + appendToResult(rv.Index(i)) + } + } else { + appendToResult(rv) + } + } + return +} diff --git a/lib/vendor/gorm.io/gorm/callbacks.go b/lib/vendor/gorm.io/gorm/callbacks.go index 50b5b0e9..bd97f040 100644 --- a/lib/vendor/gorm.io/gorm/callbacks.go +++ b/lib/vendor/gorm.io/gorm/callbacks.go @@ -89,10 +89,16 @@ func (p *processor) Execute(db *DB) *DB { resetBuildClauses = true } - if optimizer, ok := db.Statement.Dest.(StatementModifier); ok { + if optimizer, ok := stmt.Dest.(StatementModifier); ok { optimizer.ModifyStatement(stmt) } + if db.DefaultContextTimeout > 0 { + if _, ok := stmt.Context.Deadline(); !ok { + stmt.Context, _ = context.WithTimeout(stmt.Context, db.DefaultContextTimeout) + } + } + // assign model values if stmt.Model == nil { stmt.Model = stmt.Dest diff --git a/lib/vendor/gorm.io/gorm/chainable_api.go b/lib/vendor/gorm.io/gorm/chainable_api.go index 8a6aea34..8f6113cc 100644 --- a/lib/vendor/gorm.io/gorm/chainable_api.go +++ b/lib/vendor/gorm.io/gorm/chainable_api.go @@ -178,7 +178,7 @@ func (db *DB) Omit(columns ...string) (tx *DB) { tx = db.getInstance() if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { - tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar) + tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsInvalidDBNameChar) } else { tx.Statement.Omits = columns } @@ -283,7 +283,7 @@ func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) func (db *DB) Group(name string) (tx *DB) { tx = db.getInstance() - fields := strings.FieldsFunc(name, utils.IsValidDBNameChar) + fields := strings.FieldsFunc(name, utils.IsInvalidDBNameChar) tx.Statement.AddClause(clause.GroupBy{ Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}}, }) diff --git a/lib/vendor/gorm.io/gorm/clause/association.go b/lib/vendor/gorm.io/gorm/clause/association.go new file mode 100644 index 00000000..a9bf7eb0 --- /dev/null +++ b/lib/vendor/gorm.io/gorm/clause/association.go @@ -0,0 +1,35 @@ +package clause + +// AssociationOpType represents association operation types +type AssociationOpType int + +const ( + OpUnlink AssociationOpType = iota // Unlink association + OpDelete // Delete association records + OpUpdate // Update association records + OpCreate // Create association records with assignments +) + +// Association represents an association operation +type Association struct { + Association string // Association name + Type AssociationOpType // Operation type + Conditions []Expression // Filter conditions + Set []Assignment // Assignment operations (for Update and Create) + Values []interface{} // Values for Create operation +} + +// AssociationAssigner is an interface for association operation providers +type AssociationAssigner interface { + AssociationAssignments() []Association +} + +// Assignments implements the Assigner interface so that AssociationOperation can be used as a Set method parameter +func (ao Association) Assignments() []Assignment { + return []Assignment{} +} + +// AssociationAssignments implements the AssociationAssigner interface +func (ao Association) AssociationAssignments() []Association { + return []Association{ao} +} diff --git a/lib/vendor/gorm.io/gorm/clause/set.go b/lib/vendor/gorm.io/gorm/clause/set.go index 75eb6bdd..cb5f36a0 100644 --- a/lib/vendor/gorm.io/gorm/clause/set.go +++ b/lib/vendor/gorm.io/gorm/clause/set.go @@ -9,6 +9,11 @@ type Assignment struct { Value interface{} } +// Assigner assignments provider interface +type Assigner interface { + Assignments() []Assignment +} + func (set Set) Name() string { return "SET" } @@ -37,6 +42,9 @@ func (set Set) MergeClause(clause *Clause) { clause.Expression = Set(copiedAssignments) } +// Assignments implements Assigner for Set. +func (set Set) Assignments() []Assignment { return []Assignment(set) } + func Assignments(values map[string]interface{}) Set { keys := make([]string, 0, len(values)) for key := range values { @@ -58,3 +66,6 @@ func AssignmentColumns(values []string) Set { } return assignments } + +// Assignments implements Assigner for a single Assignment. +func (a Assignment) Assignments() []Assignment { return []Assignment{a} } diff --git a/lib/vendor/gorm.io/gorm/finisher_api.go b/lib/vendor/gorm.io/gorm/finisher_api.go index 57809d17..e9e35f1b 100644 --- a/lib/vendor/gorm.io/gorm/finisher_api.go +++ b/lib/vendor/gorm.io/gorm/finisher_api.go @@ -465,7 +465,7 @@ func (db *DB) Count(count *int64) (tx *DB) { if len(tx.Statement.Selects) == 1 { dbName := tx.Statement.Selects[0] - fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar) + fields := strings.FieldsFunc(dbName, utils.IsInvalidDBNameChar) if len(fields) == 1 || (len(fields) == 3 && (strings.ToUpper(fields[1]) == "AS" || fields[1] == ".")) { if tx.Statement.Parse(tx.Statement.Model) == nil { if f := tx.Statement.Schema.LookUpField(dbName); f != nil { @@ -564,7 +564,7 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { } if len(tx.Statement.Selects) != 1 { - fields := strings.FieldsFunc(column, utils.IsValidDBNameChar) + fields := strings.FieldsFunc(column, utils.IsInvalidDBNameChar) tx.Statement.AddClauseIfNotExists(clause.Select{ Distinct: tx.Statement.Distinct, Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}}, @@ -675,9 +675,9 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { } ctx := tx.Statement.Context - if _, ok := ctx.Deadline(); !ok { - if db.Config.DefaultTransactionTimeout > 0 { - ctx, _ = context.WithTimeout(ctx, db.Config.DefaultTransactionTimeout) + if db.DefaultTransactionTimeout > 0 { + if _, ok := ctx.Deadline(); !ok { + ctx, _ = context.WithTimeout(ctx, db.DefaultTransactionTimeout) } } diff --git a/lib/vendor/gorm.io/gorm/generics.go b/lib/vendor/gorm.io/gorm/generics.go index ad2d063f..166d1520 100644 --- a/lib/vendor/gorm.io/gorm/generics.go +++ b/lib/vendor/gorm.io/gorm/generics.go @@ -3,12 +3,15 @@ package gorm import ( "context" "database/sql" + "errors" "fmt" + "reflect" "sort" "strings" "gorm.io/gorm/clause" "gorm.io/gorm/logger" + "gorm.io/gorm/schema" ) type result struct { @@ -35,10 +38,34 @@ type Interface[T any] interface { } type CreateInterface[T any] interface { - ChainInterface[T] + ExecInterface[T] + // chain methods available at start; Select/Omit keep CreateInterface to allow Create chaining + Scopes(scopes ...func(db *Statement)) ChainInterface[T] + Where(query interface{}, args ...interface{}) ChainInterface[T] + Not(query interface{}, args ...interface{}) ChainInterface[T] + Or(query interface{}, args ...interface{}) ChainInterface[T] + Limit(offset int) ChainInterface[T] + Offset(offset int) ChainInterface[T] + Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] + Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] + Select(query string, args ...interface{}) CreateInterface[T] + Omit(columns ...string) CreateInterface[T] + MapColumns(m map[string]string) ChainInterface[T] + Distinct(args ...interface{}) ChainInterface[T] + Group(name string) ChainInterface[T] + Having(query interface{}, args ...interface{}) ChainInterface[T] + Order(value interface{}) ChainInterface[T] + Build(builder clause.Builder) + + Delete(ctx context.Context) (rowsAffected int, err error) + Update(ctx context.Context, name string, value any) (rowsAffected int, err error) + Updates(ctx context.Context, t T) (rowsAffected int, err error) + Count(ctx context.Context, column string) (result int64, err error) + Table(name string, args ...interface{}) CreateInterface[T] Create(ctx context.Context, r *T) error CreateInBatches(ctx context.Context, r *[]T, batchSize int) error + Set(assignments ...clause.Assigner) SetCreateOrUpdateInterface[T] } type ChainInterface[T any] interface { @@ -58,15 +85,28 @@ type ChainInterface[T any] interface { Group(name string) ChainInterface[T] Having(query interface{}, args ...interface{}) ChainInterface[T] Order(value interface{}) ChainInterface[T] + Set(assignments ...clause.Assigner) SetUpdateOnlyInterface[T] Build(builder clause.Builder) + Table(name string, args ...interface{}) ChainInterface[T] Delete(ctx context.Context) (rowsAffected int, err error) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) Updates(ctx context.Context, t T) (rowsAffected int, err error) Count(ctx context.Context, column string) (result int64, err error) } +// SetUpdateOnlyInterface is returned by Set after chaining; only Update is allowed +type SetUpdateOnlyInterface[T any] interface { + Update(ctx context.Context) (rowsAffected int, err error) +} + +// SetCreateOrUpdateInterface is returned by Set at start; Create or Update are allowed +type SetCreateOrUpdateInterface[T any] interface { + Create(ctx context.Context) error + Update(ctx context.Context) (rowsAffected int, err error) +} + type ExecInterface[T any] interface { Scan(ctx context.Context, r interface{}) error First(context.Context) (T, error) @@ -142,13 +182,15 @@ func (c *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] { return execG[T]{g: &g[T]{ db: c.db, ops: append(c.ops, func(db *DB) *DB { - return db.Raw(sql, values...) + var r T + return db.Model(r).Raw(sql, values...) }), }} } func (c *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error { - return c.apply(ctx).Exec(sql, values...).Error + var r T + return c.apply(ctx).Model(r).Exec(sql, values...).Error } type createG[T any] struct { @@ -161,6 +203,22 @@ func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] { })} } +func (c createG[T]) Select(query string, args ...interface{}) CreateInterface[T] { + return createG[T]{c.with(func(db *DB) *DB { + return db.Select(query, args...) + })} +} + +func (c createG[T]) Omit(columns ...string) CreateInterface[T] { + return createG[T]{c.with(func(db *DB) *DB { + return db.Omit(columns...) + })} +} + +func (c createG[T]) Set(assignments ...clause.Assigner) SetCreateOrUpdateInterface[T] { + return c.processSet(assignments...) +} + func (c createG[T]) Create(ctx context.Context, r *T) error { return c.g.apply(ctx).Create(r).Error } @@ -187,6 +245,12 @@ func (c chainG[T]) with(v op) chainG[T] { } } +func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Table(name, args...) + }) +} + func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] { return c.with(func(db *DB) *DB { for _, fc := range scopes { @@ -196,12 +260,6 @@ func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] { }) } -func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] { - return c.with(func(db *DB) *DB { - return db.Table(name, args...) - }) -} - func (c chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] { return c.with(func(db *DB) *DB { return db.Where(query, args...) @@ -388,6 +446,10 @@ func (c chainG[T]) MapColumns(m map[string]string) ChainInterface[T] { }) } +func (c chainG[T]) Set(assignments ...clause.Assigner) SetUpdateOnlyInterface[T] { + return c.processSet(assignments...) +} + func (c chainG[T]) Distinct(args ...interface{}) ChainInterface[T] { return c.with(func(db *DB) *DB { return db.Distinct(args...) @@ -425,12 +487,12 @@ func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) err relation, ok := db.Statement.Schema.Relationships.Relations[association] if !ok { if preloadFields := strings.Split(association, "."); len(preloadFields) > 1 { - relationships := db.Statement.Schema.Relationships + relationships := &db.Statement.Schema.Relationships for _, field := range preloadFields { var ok bool relation, ok = relationships.Relations[field] if ok { - relationships = relation.FieldSchema.Relationships + relationships = &relation.FieldSchema.Relationships } else { db.AddError(fmt.Errorf("relation %s not found", association)) return nil @@ -567,7 +629,7 @@ func (g execG[T]) First(ctx context.Context) (T, error) { func (g execG[T]) Scan(ctx context.Context, result interface{}) error { var r T - err := g.g.apply(ctx).Model(r).Find(&result).Error + err := g.g.apply(ctx).Model(r).Find(result).Error return err } @@ -603,3 +665,242 @@ func (g execG[T]) Row(ctx context.Context) *sql.Row { func (g execG[T]) Rows(ctx context.Context) (*sql.Rows, error) { return g.g.apply(ctx).Rows() } + +func (c chainG[T]) processSet(items ...clause.Assigner) setCreateOrUpdateG[T] { + var ( + assigns []clause.Assignment + assocOps []clause.Association + ) + + for _, item := range items { + // Check if it's an AssociationAssigner + if assocAssigner, ok := item.(clause.AssociationAssigner); ok { + assocOps = append(assocOps, assocAssigner.AssociationAssignments()...) + } else { + assigns = append(assigns, item.Assignments()...) + } + } + + return setCreateOrUpdateG[T]{ + c: c, + assigns: assigns, + assocOps: assocOps, + } +} + +// setCreateOrUpdateG[T] is a struct that holds operations to be executed in a batch. +// It supports regular assignments and association operations. +type setCreateOrUpdateG[T any] struct { + c chainG[T] + assigns []clause.Assignment + assocOps []clause.Association +} + +func (s setCreateOrUpdateG[T]) Update(ctx context.Context) (rowsAffected int, err error) { + // Execute association operations + for _, assocOp := range s.assocOps { + if err := s.executeAssociationOperation(ctx, assocOp); err != nil { + return 0, err + } + } + + // Execute assignment operations + if len(s.assigns) > 0 { + var r T + res := s.c.g.apply(ctx).Model(r).Clauses(clause.Set(s.assigns)).Updates(map[string]interface{}{}) + return int(res.RowsAffected), res.Error + } + + return 0, nil +} + +func (s setCreateOrUpdateG[T]) Create(ctx context.Context) error { + // Execute association operations + for _, assocOp := range s.assocOps { + if err := s.executeAssociationOperation(ctx, assocOp); err != nil { + return err + } + } + + // Execute assignment operations + if len(s.assigns) > 0 { + data := make(map[string]interface{}, len(s.assigns)) + for _, a := range s.assigns { + data[a.Column.Name] = a.Value + } + var r T + return s.c.g.apply(ctx).Model(r).Create(data).Error + } + + return nil +} + +// executeAssociationOperation executes an association operation +func (s setCreateOrUpdateG[T]) executeAssociationOperation(ctx context.Context, op clause.Association) error { + var r T + base := s.c.g.apply(ctx).Model(r) + + switch op.Type { + case clause.OpCreate: + return s.handleAssociationCreate(ctx, base, op) + case clause.OpUnlink, clause.OpDelete, clause.OpUpdate: + return s.handleAssociation(ctx, base, op) + default: + return fmt.Errorf("unknown association operation type: %v", op.Type) + } +} + +func (s setCreateOrUpdateG[T]) handleAssociationCreate(ctx context.Context, base *DB, op clause.Association) error { + if len(op.Set) > 0 { + return s.handleAssociationForOwners(base, ctx, func(owner T, assoc *Association) error { + data := make(map[string]interface{}, len(op.Set)) + for _, a := range op.Set { + data[a.Column.Name] = a.Value + } + return assoc.Append(data) + }, op.Association) + } + + return s.handleAssociationForOwners(base, ctx, func(owner T, assoc *Association) error { + return assoc.Append(op.Values...) + }, op.Association) +} + +// handleAssociationForOwners is a helper function that handles associations for all owners +func (s setCreateOrUpdateG[T]) handleAssociationForOwners(base *DB, ctx context.Context, handler func(owner T, association *Association) error, associationName string) error { + var owners []T + if err := base.Find(&owners).Error; err != nil { + return err + } + + for _, owner := range owners { + assoc := base.Session(&Session{NewDB: true, Context: ctx}).Model(&owner).Association(associationName) + if assoc.Error != nil { + return assoc.Error + } + + if err := handler(owner, assoc); err != nil { + return err + } + } + return nil +} + +func (s setCreateOrUpdateG[T]) handleAssociation(ctx context.Context, base *DB, op clause.Association) error { + assoc := base.Association(op.Association) + if assoc.Error != nil { + return assoc.Error + } + + var ( + rel = assoc.Relationship + assocModel = reflect.New(rel.FieldSchema.ModelType).Interface() + fkNil = map[string]any{} + setMap = make(map[string]any, len(op.Set)) + ownerPKNames []string + ownerFKNames []string + primaryColumns []any + foreignColumns []any + ) + + for _, a := range op.Set { + setMap[a.Column.Name] = a.Value + } + + for _, ref := range rel.References { + fkNil[ref.ForeignKey.DBName] = nil + + if ref.OwnPrimaryKey && ref.PrimaryKey != nil { + ownerPKNames = append(ownerPKNames, ref.PrimaryKey.DBName) + primaryColumns = append(primaryColumns, clause.Column{Name: ref.PrimaryKey.DBName}) + foreignColumns = append(foreignColumns, clause.Column{Name: ref.ForeignKey.DBName}) + } else if !ref.OwnPrimaryKey && ref.PrimaryKey != nil { + ownerFKNames = append(ownerFKNames, ref.ForeignKey.DBName) + primaryColumns = append(primaryColumns, clause.Column{Name: ref.PrimaryKey.DBName}) + } + } + + assocDB := s.c.g.db.Session(&Session{NewDB: true, Context: ctx}).Model(assocModel).Where(op.Conditions) + + switch rel.Type { + case schema.HasOne, schema.HasMany: + assocDB = assocDB.Where("? IN (?)", foreignColumns, base.Select(ownerPKNames)) + switch op.Type { + case clause.OpUnlink: + return assocDB.Updates(fkNil).Error + case clause.OpDelete: + return assocDB.Delete(assocModel).Error + case clause.OpUpdate: + return assocDB.Updates(setMap).Error + } + case schema.BelongsTo: + switch op.Type { + case clause.OpDelete: + return base.Transaction(func(tx *DB) error { + assocDB.Statement.ConnPool = tx.Statement.ConnPool + base.Statement.ConnPool = tx.Statement.ConnPool + + if err := assocDB.Where("? IN (?)", primaryColumns, base.Select(ownerFKNames)).Delete(assocModel).Error; err != nil { + return err + } + return base.Updates(fkNil).Error + }) + case clause.OpUnlink: + return base.Updates(fkNil).Error + case clause.OpUpdate: + return assocDB.Where("? IN (?)", primaryColumns, base.Select(ownerFKNames)).Updates(setMap).Error + } + case schema.Many2Many: + joinModel := reflect.New(rel.JoinTable.ModelType).Interface() + joinDB := base.Session(&Session{NewDB: true, Context: ctx}).Model(joinModel) + + // EXISTS owners: owners.pk = join.owner_fk for all owner refs + ownersExists := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.Schema.Table).Select("1") + for _, ref := range rel.References { + if ref.OwnPrimaryKey && ref.PrimaryKey != nil { + ownersExists = ownersExists.Where(clause.Eq{ + Column: clause.Column{Table: rel.Schema.Table, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + }) + } + } + + // EXISTS related: related.pk = join.rel_fk for all related refs, plus optional conditions + relatedExists := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.FieldSchema.Table).Select("1") + for _, ref := range rel.References { + if !ref.OwnPrimaryKey && ref.PrimaryKey != nil { + relatedExists = relatedExists.Where(clause.Eq{ + Column: clause.Column{Table: rel.FieldSchema.Table, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + }) + } + } + relatedExists = relatedExists.Where(op.Conditions) + + switch op.Type { + case clause.OpUnlink, clause.OpDelete: + joinDB = joinDB.Where("EXISTS (?)", ownersExists) + if len(op.Conditions) > 0 { + joinDB = joinDB.Where("EXISTS (?)", relatedExists) + } + return joinDB.Delete(nil).Error + case clause.OpUpdate: + // Update related table rows that have join rows matching owners + relatedDB := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.FieldSchema.Table).Where(op.Conditions) + + // correlated join subquery: join.rel_fk = related.pk AND EXISTS owners + joinSub := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.JoinTable.Table).Select("1") + for _, ref := range rel.References { + if !ref.OwnPrimaryKey && ref.PrimaryKey != nil { + joinSub = joinSub.Where(clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: rel.FieldSchema.Table, Name: ref.PrimaryKey.DBName}, + }) + } + } + joinSub = joinSub.Where("EXISTS (?)", ownersExists) + return relatedDB.Where("EXISTS (?)", joinSub).Updates(setMap).Error + } + } + return errors.New("unsupported relationship") +} diff --git a/lib/vendor/gorm.io/gorm/gorm.go b/lib/vendor/gorm.io/gorm/gorm.go index 67889262..a209bb09 100644 --- a/lib/vendor/gorm.io/gorm/gorm.go +++ b/lib/vendor/gorm.io/gorm/gorm.go @@ -23,6 +23,7 @@ type Config struct { // You can disable it by setting `SkipDefaultTransaction` to true SkipDefaultTransaction bool DefaultTransactionTimeout time.Duration + DefaultContextTimeout time.Duration // NamingStrategy tables, columns naming strategy NamingStrategy schema.Namer @@ -137,6 +138,14 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { return isConfig && !isConfig2 }) + if len(opts) > 0 { + if c, ok := opts[0].(*Config); ok { + config = c + } else { + opts = append([]Option{config}, opts...) + } + } + var skipAfterInitialize bool for _, opt := range opts { if opt != nil { diff --git a/lib/vendor/gorm.io/gorm/logger/slog.go b/lib/vendor/gorm.io/gorm/logger/slog.go new file mode 100644 index 00000000..27c74683 --- /dev/null +++ b/lib/vendor/gorm.io/gorm/logger/slog.go @@ -0,0 +1,116 @@ +//go:build go1.21 + +package logger + +import ( + "context" + "errors" + "fmt" + "log/slog" + "time" + + "gorm.io/gorm/utils" +) + +type slogLogger struct { + Logger *slog.Logger + LogLevel LogLevel + SlowThreshold time.Duration + Parameterized bool + Colorful bool // Ignored in slog + IgnoreRecordNotFoundError bool +} + +func NewSlogLogger(logger *slog.Logger, config Config) Interface { + return &slogLogger{ + Logger: logger, + LogLevel: config.LogLevel, + SlowThreshold: config.SlowThreshold, + Parameterized: config.ParameterizedQueries, + IgnoreRecordNotFoundError: config.IgnoreRecordNotFoundError, + } +} + +func (l *slogLogger) LogMode(level LogLevel) Interface { + newLogger := *l + newLogger.LogLevel = level + return &newLogger +} + +func (l *slogLogger) Info(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Info { + l.log(ctx, slog.LevelInfo, msg, slog.Any("data", data)) + } +} + +func (l *slogLogger) Warn(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Warn { + l.log(ctx, slog.LevelWarn, msg, slog.Any("data", data)) + } +} + +func (l *slogLogger) Error(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Error { + l.log(ctx, slog.LevelError, msg, slog.Any("data", data)) + } +} + +func (l *slogLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + if l.LogLevel <= Silent { + return + } + + elapsed := time.Since(begin) + sql, rows := fc() + fields := []slog.Attr{ + slog.String("duration", fmt.Sprintf("%.3fms", float64(elapsed.Nanoseconds())/1e6)), + slog.String("sql", sql), + } + + if rows != -1 { + fields = append(fields, slog.Int64("rows", rows)) + } + + switch { + case err != nil && (!l.IgnoreRecordNotFoundError || !errors.Is(err, ErrRecordNotFound)): + fields = append(fields, slog.String("error", err.Error())) + l.log(ctx, slog.LevelError, "SQL executed", slog.Attr{ + Key: "trace", + Value: slog.GroupValue(fields...), + }) + + case l.SlowThreshold != 0 && elapsed > l.SlowThreshold: + l.log(ctx, slog.LevelWarn, "SQL executed", slog.Attr{ + Key: "trace", + Value: slog.GroupValue(fields...), + }) + + case l.LogLevel >= Info: + l.log(ctx, slog.LevelInfo, "SQL executed", slog.Attr{ + Key: "trace", + Value: slog.GroupValue(fields...), + }) + } +} + +func (l *slogLogger) log(ctx context.Context, level slog.Level, msg string, args ...any) { + if ctx == nil { + ctx = context.Background() + } + + if !l.Logger.Enabled(ctx, level) { + return + } + + r := slog.NewRecord(time.Now(), level, msg, utils.CallerFrame().PC) + r.Add(args...) + _ = l.Logger.Handler().Handle(ctx, r) +} + +// ParamsFilter filter params +func (l *slogLogger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) { + if l.Parameterized { + return sql, nil + } + return sql, params +} diff --git a/lib/vendor/gorm.io/gorm/schema/field.go b/lib/vendor/gorm.io/gorm/schema/field.go index a6ff1a72..de797402 100644 --- a/lib/vendor/gorm.io/gorm/schema/field.go +++ b/lib/vendor/gorm.io/gorm/schema/field.go @@ -448,16 +448,17 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } // create valuer, setter when parse struct -func (field *Field) setupValuerAndSetter() { +func (field *Field) setupValuerAndSetter(modelType reflect.Type) { // Setup NewValuePool field.setupNewValuePool() // ValueOf returns field's value and if it is zero fieldIndex := field.StructField.Index[0] switch { - case len(field.StructField.Index) == 1 && fieldIndex > 0: - field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) { - fieldValue := reflect.Indirect(value).Field(fieldIndex) + case len(field.StructField.Index) == 1 && fieldIndex >= 0: + field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { + v = reflect.Indirect(v) + fieldValue := v.Field(fieldIndex) return fieldValue.Interface(), fieldValue.IsZero() } default: @@ -504,9 +505,10 @@ func (field *Field) setupValuerAndSetter() { // ReflectValueOf returns field's reflect value switch { - case len(field.StructField.Index) == 1 && fieldIndex > 0: - field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value { - return reflect.Indirect(value).Field(fieldIndex) + case len(field.StructField.Index) == 1 && fieldIndex >= 0: + field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { + v = reflect.Indirect(v) + return v.Field(fieldIndex) } default: field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { diff --git a/lib/vendor/gorm.io/gorm/schema/relationship.go b/lib/vendor/gorm.io/gorm/schema/relationship.go index f1ace924..0535bba4 100644 --- a/lib/vendor/gorm.io/gorm/schema/relationship.go +++ b/lib/vendor/gorm.io/gorm/schema/relationship.go @@ -75,9 +75,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { } ) - cacheStore := schema.cacheStore - - if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil { + if relation.FieldSchema, err = getOrParse(fieldValue, schema.cacheStore, schema.namer); err != nil { schema.err = fmt.Errorf("failed to parse field: %s, error: %w", field.Name, err) return nil } @@ -147,6 +145,9 @@ func hasPolymorphicRelation(tagSettings map[string]string) bool { } func (schema *Schema) setRelation(relation *Relationship) { + schema.Relationships.Mux.Lock() + defer schema.Relationships.Mux.Unlock() + // set non-embedded relation if rel := schema.Relationships.Relations[relation.Name]; rel != nil { if len(rel.Field.BindNames) > 1 { @@ -590,6 +591,10 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu // build references for idx, foreignField := range foreignFields { // use same data type for foreign keys + schema.Relationships.Mux.Lock() + if schema != foreignField.Schema { + foreignField.Schema.Relationships.Mux.Lock() + } if copyableDataType(primaryFields[idx].DataType) { foreignField.DataType = primaryFields[idx].DataType } @@ -597,6 +602,10 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu if foreignField.Size == 0 { foreignField.Size = primaryFields[idx].Size } + schema.Relationships.Mux.Unlock() + if schema != foreignField.Schema { + foreignField.Schema.Relationships.Mux.Unlock() + } relation.References = append(relation.References, &Reference{ PrimaryKey: primaryFields[idx], diff --git a/lib/vendor/gorm.io/gorm/schema/schema.go b/lib/vendor/gorm.io/gorm/schema/schema.go index db236797..09697a7a 100644 --- a/lib/vendor/gorm.io/gorm/schema/schema.go +++ b/lib/vendor/gorm.io/gorm/schema/schema.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "go/ast" + "path" "reflect" "strings" "sync" @@ -59,14 +60,14 @@ type Schema struct { cacheStore *sync.Map } -func (schema Schema) String() string { +func (schema *Schema) String() string { if schema.ModelType.Name() == "" { return fmt.Sprintf("%s(%s)", schema.Name, schema.Table) } return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), schema.ModelType.Name()) } -func (schema Schema) MakeSlice() reflect.Value { +func (schema *Schema) MakeSlice() reflect.Value { slice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(schema.ModelType)), 0, 20) results := reflect.New(slice.Type()) results.Elem().Set(slice) @@ -74,13 +75,23 @@ func (schema Schema) MakeSlice() reflect.Value { return results } -func (schema Schema) LookUpField(name string) *Field { +func (schema *Schema) LookUpField(name string) *Field { if field, ok := schema.FieldsByDBName[name]; ok { return field } if field, ok := schema.FieldsByName[name]; ok { return field } + + // Lookup field using namer-driven ColumnName + if schema.namer == nil { + return nil + } + namerColumnName := schema.namer.ColumnName(schema.Table, name) + if field, ok := schema.FieldsByDBName[namerColumnName]; ok { + return field + } + return nil } @@ -92,10 +103,7 @@ func (schema Schema) LookUpField(name string) *Field { // } // ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID") // } -func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field { - if len(bindNames) == 0 { - return nil - } +func (schema *Schema) LookUpFieldByBindName(bindNames []string, name string) *Field { for i := len(bindNames) - 1; i >= 0; i-- { find := strings.Join(bindNames[:i], ".") + "." + name if field, ok := schema.FieldsByBindName[find]; ok { @@ -113,6 +121,14 @@ type TablerWithNamer interface { TableName(Namer) string } +var callbackTypes = []callbackType{ + callbackTypeBeforeCreate, callbackTypeAfterCreate, + callbackTypeBeforeUpdate, callbackTypeAfterUpdate, + callbackTypeBeforeSave, callbackTypeAfterSave, + callbackTypeBeforeDelete, callbackTypeAfterDelete, + callbackTypeAfterFind, +} + // Parse get data type from dialector func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { return ParseWithSpecialTableName(dest, cacheStore, namer, "") @@ -124,34 +140,33 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } - value := reflect.ValueOf(dest) - if value.Kind() == reflect.Ptr && value.IsNil() { - value = reflect.New(value.Type().Elem()) - } - modelType := reflect.Indirect(value).Type() - - if modelType.Kind() == reflect.Interface { - modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type() - } - - for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType := reflect.ValueOf(dest).Type() + if modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() } if modelType.Kind() != reflect.Struct { - if modelType.PkgPath() == "" { - return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + if modelType.Kind() == reflect.Interface { + modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type() + } + + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType.Kind() != reflect.Struct { + if modelType.PkgPath() == "" { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } - return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } // Cache the Schema for performance, // Use the modelType or modelType + schemaTable (if it present) as cache key. - var schemaCacheKey interface{} + var schemaCacheKey interface{} = modelType if specialTableName != "" { schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName) - } else { - schemaCacheKey = modelType } // Load exist schema cache, return if exists @@ -162,28 +177,29 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam return s, s.err } + var tableName string modelValue := reflect.New(modelType) - tableName := namer.TableName(modelType.Name()) - if tabler, ok := modelValue.Interface().(Tabler); ok { + if specialTableName != "" { + tableName = specialTableName + } else if en, ok := namer.(embeddedNamer); ok { + tableName = en.Table + } else if tabler, ok := modelValue.Interface().(Tabler); ok { tableName = tabler.TableName() - } - if tabler, ok := modelValue.Interface().(TablerWithNamer); ok { + } else if tabler, ok := modelValue.Interface().(TablerWithNamer); ok { tableName = tabler.TableName(namer) - } - if en, ok := namer.(embeddedNamer); ok { - tableName = en.Table - } - if specialTableName != "" && specialTableName != tableName { - tableName = specialTableName + } else { + tableName = namer.TableName(modelType.Name()) } schema := &Schema{ Name: modelType.Name(), ModelType: modelType, Table: tableName, - FieldsByName: map[string]*Field{}, - FieldsByBindName: map[string]*Field{}, - FieldsByDBName: map[string]*Field{}, + DBNames: make([]string, 0, 10), + Fields: make([]*Field, 0, 10), + FieldsByName: make(map[string]*Field, 10), + FieldsByBindName: make(map[string]*Field, 10), + FieldsByDBName: make(map[string]*Field, 10), Relationships: Relationships{Relations: map[string]*Relationship{}}, cacheStore: cacheStore, namer: namer, @@ -227,8 +243,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam schema.FieldsByBindName[bindName] = field if v != nil && v.PrimaryKey { + // remove the existing primary key field for idx, f := range schema.PrimaryFields { - if f == v { + if f.DBName == v.DBName { schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...) } } @@ -247,7 +264,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam schema.FieldsByBindName[bindName] = field } - field.setupValuerAndSetter() + field.setupValuerAndSetter(modelType) } prioritizedPrimaryField := schema.LookUpField("id") @@ -283,10 +300,37 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName) } + _, embedded := schema.cacheStore.Load(embeddedCacheKey) + relationshipFields := []*Field{} for _, field := range schema.Fields { if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil { schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) } + + if !embedded { + if field.DataType == "" && field.GORMDataType == "" && (field.Creatable || field.Updatable || field.Readable) { + relationshipFields = append(relationshipFields, field) + schema.FieldsByName[field.Name] = field + schema.FieldsByBindName[field.BindName()] = field + } + + fieldValue := reflect.New(field.IndirectFieldType).Interface() + if fc, ok := fieldValue.(CreateClausesInterface); ok { + field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) + } + + if fc, ok := fieldValue.(QueryClausesInterface); ok { + field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) + } + + if fc, ok := fieldValue.(UpdateClausesInterface); ok { + field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) + } + + if fc, ok := fieldValue.(DeleteClausesInterface); ok { + field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) + } + } } if field := schema.PrioritizedPrimaryField; field != nil { @@ -303,24 +347,6 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } } - callbackTypes := []callbackType{ - callbackTypeBeforeCreate, callbackTypeAfterCreate, - callbackTypeBeforeUpdate, callbackTypeAfterUpdate, - callbackTypeBeforeSave, callbackTypeAfterSave, - callbackTypeBeforeDelete, callbackTypeAfterDelete, - callbackTypeAfterFind, - } - for _, cbName := range callbackTypes { - if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() { - switch methodValue.Type().String() { - case "func(*gorm.DB) error": // TODO hack - reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true) - default: - logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName) - } - } - } - // Cache the schema if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded { s := v.(*Schema) @@ -336,84 +362,47 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } }() - if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { - for _, field := range schema.Fields { - if field.DataType == "" && field.GORMDataType == "" && (field.Creatable || field.Updatable || field.Readable) { - if schema.parseRelation(field); schema.err != nil { - return schema, schema.err + for _, cbName := range callbackTypes { + if methodValue := modelValue.MethodByName(string(cbName)); methodValue.IsValid() { + switch methodValue.Type().String() { + case "func(*gorm.DB) error": + expectedPkgPath := path.Dir(reflect.TypeOf(schema).Elem().PkgPath()) + if inVarPkg := methodValue.Type().In(0).Elem().PkgPath(); inVarPkg == expectedPkgPath { + reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true) } else { - schema.FieldsByName[field.Name] = field - schema.FieldsByBindName[field.BindName()] = field + logger.Default.Warn(context.Background(), "In model %v, the hook function `%v(*gorm.DB) error` has an incorrect parameter type. The expected parameter type is `%v`, but the provided type is `%v`.", schema, cbName, expectedPkgPath, inVarPkg) + // PASS } + default: + logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName) } + } + } - fieldValue := reflect.New(field.IndirectFieldType) - fieldInterface := fieldValue.Interface() - if fc, ok := fieldInterface.(CreateClausesInterface); ok { - field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) - } - - if fc, ok := fieldInterface.(QueryClausesInterface); ok { - field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) - } - - if fc, ok := fieldInterface.(UpdateClausesInterface); ok { - field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) - } - - if fc, ok := fieldInterface.(DeleteClausesInterface); ok { - field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) - } + // parse relationships + for _, field := range relationshipFields { + if schema.parseRelation(field); schema.err != nil { + return schema, schema.err } } return schema, schema.err } -// This unrolling is needed to show to the compiler the exact set of methods -// that can be used on the modelType. -// Prior to go1.22 any use of MethodByName would cause the linker to -// abandon dead code elimination for the entire binary. -// As of go1.22 the compiler supports one special case of a string constant -// being passed to MethodByName. For enterprise customers or those building -// large binaries, this gives a significant reduction in binary size. -// https://github.com/golang/go/issues/62257 -func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value { - switch cbType { - case callbackTypeBeforeCreate: - return modelType.MethodByName(string(callbackTypeBeforeCreate)) - case callbackTypeAfterCreate: - return modelType.MethodByName(string(callbackTypeAfterCreate)) - case callbackTypeBeforeUpdate: - return modelType.MethodByName(string(callbackTypeBeforeUpdate)) - case callbackTypeAfterUpdate: - return modelType.MethodByName(string(callbackTypeAfterUpdate)) - case callbackTypeBeforeSave: - return modelType.MethodByName(string(callbackTypeBeforeSave)) - case callbackTypeAfterSave: - return modelType.MethodByName(string(callbackTypeAfterSave)) - case callbackTypeBeforeDelete: - return modelType.MethodByName(string(callbackTypeBeforeDelete)) - case callbackTypeAfterDelete: - return modelType.MethodByName(string(callbackTypeAfterDelete)) - case callbackTypeAfterFind: - return modelType.MethodByName(string(callbackTypeAfterFind)) - default: - return reflect.ValueOf(nil) - } -} - func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { modelType := reflect.ValueOf(dest).Type() - for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { - modelType = modelType.Elem() - } if modelType.Kind() != reflect.Struct { - if modelType.PkgPath() == "" { - return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType.Kind() != reflect.Struct { + if modelType.PkgPath() == "" { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } - return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } if v, ok := cacheStore.Load(modelType); ok { diff --git a/lib/vendor/gorm.io/gorm/schema/serializer.go b/lib/vendor/gorm.io/gorm/schema/serializer.go index 0fafbcba..d774a8ed 100644 --- a/lib/vendor/gorm.io/gorm/schema/serializer.go +++ b/lib/vendor/gorm.io/gorm/schema/serializer.go @@ -8,6 +8,7 @@ import ( "encoding/gob" "encoding/json" "fmt" + "math" "reflect" "strings" "sync" @@ -127,16 +128,31 @@ func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect. // Value implements serializer interface func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) { rv := reflect.ValueOf(fieldValue) - switch v := fieldValue.(type) { - case int64, int, uint, uint64, int32, uint32, int16, uint16: - result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC() - case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16: + switch fieldValue.(type) { + case int, int8, int16, int32, int64: + result = time.Unix(rv.Int(), 0).UTC() + case uint, uint8, uint16, uint32, uint64: + if uv := rv.Uint(); uv > math.MaxInt64 { + err = fmt.Errorf("integer overflow conversion uint64(%d) -> int64", uv) + } else { + result = time.Unix(int64(uv), 0).UTC() //nolint:gosec + } + case *int, *int8, *int16, *int32, *int64: if rv.IsZero() { return nil, nil } - result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC() + result = time.Unix(rv.Elem().Int(), 0).UTC() + case *uint, *uint8, *uint16, *uint32, *uint64: + if rv.IsZero() { + return nil, nil + } + if uv := rv.Elem().Uint(); uv > math.MaxInt64 { + err = fmt.Errorf("integer overflow conversion uint64(%d) -> int64", uv) + } else { + result = time.Unix(int64(uv), 0).UTC() //nolint:gosec + } default: - err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v) + err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", fieldValue) } return } diff --git a/lib/vendor/gorm.io/gorm/schema/utils.go b/lib/vendor/gorm.io/gorm/schema/utils.go index fa1c65d4..86305d7b 100644 --- a/lib/vendor/gorm.io/gorm/schema/utils.go +++ b/lib/vendor/gorm.io/gorm/schema/utils.go @@ -17,25 +17,23 @@ func ParseTagSetting(str string, sep string) map[string]string { settings := map[string]string{} names := strings.Split(str, sep) + var parsedNames []string for i := 0; i < len(names); i++ { - j := i - if len(names[j]) > 0 { - for { - if names[j][len(names[j])-1] == '\\' { - i++ - names[j] = names[j][0:len(names[j])-1] + sep + names[i] - names[i] = "" - } else { - break - } - } + s := names[i] + for strings.HasSuffix(s, "\\") && i+1 < len(names) { + i++ + s = s[:len(s)-1] + sep + names[i] } + parsedNames = append(parsedNames, s) + } - values := strings.Split(names[j], ":") + for _, tag := range parsedNames { + values := strings.Split(tag, ":") k := strings.TrimSpace(strings.ToUpper(values[0])) - if len(values) >= 2 { - settings[k] = strings.Join(values[1:], ":") + val := strings.Join(values[1:], ":") + val = strings.ReplaceAll(val, `\"`, `"`) + settings[k] = val } else if k != "" { settings[k] = k } @@ -121,6 +119,17 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, } switch reflectValue.Kind() { + case reflect.Map: + results = [][]interface{}{make([]interface{}, len(fields))} + for idx, field := range fields { + mapValue := reflectValue.MapIndex(reflect.ValueOf(field.DBName)) + if mapValue.IsZero() { + mapValue = reflectValue.MapIndex(reflect.ValueOf(field.Name)) + } + results[0][idx] = mapValue.Interface() + } + + dataResults[utils.ToStringKey(results[0]...)] = []reflect.Value{reflectValue} case reflect.Struct: results = [][]interface{}{make([]interface{}, len(fields))} diff --git a/lib/vendor/gorm.io/gorm/statement.go b/lib/vendor/gorm.io/gorm/statement.go index c6183724..736087d7 100644 --- a/lib/vendor/gorm.io/gorm/statement.go +++ b/lib/vendor/gorm.io/gorm/statement.go @@ -96,7 +96,9 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { if v.Name == clause.CurrentTable { if stmt.TableExpr != nil { stmt.TableExpr.Build(stmt) - } else { + } else if stmt.Table != "" { + write(v.Raw, stmt.Table) + } else if stmt.AddError(stmt.Parse(stmt.Model)) == nil { write(v.Raw, stmt.Table) } } else { @@ -334,6 +336,8 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] switch v := arg.(type) { case clause.Expression: conds = append(conds, v) + case []clause.Expression: + conds = append(conds, v...) case *DB: v.executeScopes() @@ -341,7 +345,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] if where, ok := cs.Expression.(clause.Where); ok { if len(where.Exprs) == 1 { if orConds, ok := where.Exprs[0].(clause.OrConditions); ok { - where.Exprs[0] = clause.AndConditions(orConds) + if len(orConds.Exprs) == 1 { + where.Exprs[0] = clause.AndConditions(orConds) + } } } conds = append(conds, clause.And(where.Exprs...)) @@ -362,6 +368,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] for _, key := range keys { column := clause.Column{Name: key, Table: curTable} + if strings.Contains(key, ".") { + column = clause.Column{Name: key} + } conds = append(conds, clause.Eq{Column: column, Value: v[key]}) } case map[string]interface{}: @@ -374,6 +383,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] for _, key := range keys { reflectValue := reflect.Indirect(reflect.ValueOf(v[key])) column := clause.Column{Name: key, Table: curTable} + if strings.Contains(key, ".") { + column = clause.Column{Name: key} + } switch reflectValue.Kind() { case reflect.Slice, reflect.Array: if _, ok := v[key].(driver.Valuer); ok { @@ -650,12 +662,15 @@ func (stmt *Statement) Changed(fields ...string) bool { for destValue.Kind() == reflect.Ptr { destValue = destValue.Elem() } - - changedValue, zero := field.ValueOf(stmt.Context, destValue) - if v { - return !utils.AssertEqual(changedValue, fieldValue) + if descSchema, err := schema.Parse(stmt.Dest, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { + if destField := descSchema.LookUpField(field.DBName); destField != nil { + changedValue, zero := destField.ValueOf(stmt.Context, destValue) + if v { + return !utils.AssertEqual(changedValue, fieldValue) + } + return !zero && !utils.AssertEqual(changedValue, fieldValue) + } } - return !zero && !utils.AssertEqual(changedValue, fieldValue) } } return false diff --git a/lib/vendor/gorm.io/gorm/utils/utils.go b/lib/vendor/gorm.io/gorm/utils/utils.go index fc615d73..7e59264b 100644 --- a/lib/vendor/gorm.io/gorm/utils/utils.go +++ b/lib/vendor/gorm.io/gorm/utils/utils.go @@ -30,8 +30,12 @@ func sourceDir(file string) string { return filepath.ToSlash(s) + "/" } -// FileWithLineNum return the file name and line number of the current file -func FileWithLineNum() string { +// CallerFrame retrieves the first relevant stack frame outside of GORM's internal implementation files. +// It skips: +// - GORM's core source files (identified by gormSourceDir prefix) +// - Exclude test files (*_test.go) +// - go-gorm/gen's Generated files (*.gen.go) +func CallerFrame() runtime.Frame { pcs := [13]uintptr{} // the third caller usually from gorm internal len := runtime.Callers(3, pcs[:]) @@ -41,14 +45,24 @@ func FileWithLineNum() string { frame, _ := frames.Next() if (!strings.HasPrefix(frame.File, gormSourceDir) || strings.HasSuffix(frame.File, "_test.go")) && !strings.HasSuffix(frame.File, ".gen.go") { - return string(strconv.AppendInt(append([]byte(frame.File), ':'), int64(frame.Line), 10)) + return frame } } + return runtime.Frame{} +} + +// FileWithLineNum return the file name and line number of the current file +func FileWithLineNum() string { + frame := CallerFrame() + if frame.PC != 0 { + return string(strconv.AppendInt(append([]byte(frame.File), ':'), int64(frame.Line), 10)) + } + return "" } -func IsValidDBNameChar(c rune) bool { +func IsInvalidDBNameChar(c rune) bool { return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@' } diff --git a/lib/vendor/modules.txt b/lib/vendor/modules.txt index 23023b19..faed04f7 100644 --- a/lib/vendor/modules.txt +++ b/lib/vendor/modules.txt @@ -191,7 +191,7 @@ google.golang.org/protobuf/types/gofeaturespb google.golang.org/protobuf/types/known/anypb google.golang.org/protobuf/types/known/durationpb google.golang.org/protobuf/types/known/timestamppb -# gorm.io/gorm v1.30.0 +# gorm.io/gorm v1.31.1 ## explicit; go 1.18 gorm.io/gorm gorm.io/gorm/clause diff --git a/notifier/go.mod b/notifier/go.mod index 86e7e062..6dc42822 100644 --- a/notifier/go.mod +++ b/notifier/go.mod @@ -26,7 +26,7 @@ require ( google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.5.1 google.golang.org/protobuf v1.36.10 gorm.io/driver/postgres v1.6.0 - gorm.io/gorm v1.31.0 + gorm.io/gorm v1.31.1 ) require ( diff --git a/notifier/go.sum b/notifier/go.sum index ecf43287..27df52ba 100644 --- a/notifier/go.sum +++ b/notifier/go.sum @@ -673,8 +673,8 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4= gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo= -gorm.io/gorm v1.31.0 h1:0VlycGreVhK7RF/Bwt51Fk8v0xLiiiFdbGDPIZQ7mJY= -gorm.io/gorm v1.31.0/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= +gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= +gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= honnef.co/go/tools v0.6.1 h1:R094WgE8K4JirYjBaOpz/AvTyUu/3wbmAoskKN/pxTI= honnef.co/go/tools v0.6.1/go.mod h1:3puzxxljPCe8RGJX7BIy1plGbxEOZni5mR2aXe3/uk4= mvdan.cc/gofumpt v0.7.0 h1:bg91ttqXmi9y2xawvkuMXyvAA/1ZGJqYAEGjXuP0JXU= diff --git a/notifier/vendor/gorm.io/gorm/README.md b/notifier/vendor/gorm.io/gorm/README.md index 745dad60..24eb84c9 100644 --- a/notifier/vendor/gorm.io/gorm/README.md +++ b/notifier/vendor/gorm.io/gorm/README.md @@ -3,7 +3,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. [![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm) -[![test status](https://github.com/go-gorm/gorm/workflows/tests/badge.svg?branch=master "test status")](https://github.com/go-gorm/gorm/actions) +[![test status](https://github.com/go-gorm/gorm/actions/workflows/tests.yml/badge.svg)](https://github.com/go-gorm/gorm/actions) [![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT) [![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gorm?tab=doc) diff --git a/notifier/vendor/gorm.io/gorm/association.go b/notifier/vendor/gorm.io/gorm/association.go index f210ca0a..3a4e0e25 100644 --- a/notifier/vendor/gorm.io/gorm/association.go +++ b/notifier/vendor/gorm.io/gorm/association.go @@ -99,7 +99,7 @@ func (association *Association) Replace(values ...interface{}) error { return association.Error } - // set old associations's foreign key to null + // set old association's foreign key to null switch rel.Type { case schema.BelongsTo: if len(values) == 0 { @@ -304,7 +304,7 @@ func (association *Association) Delete(values ...interface{}) error { } if association.Error == nil { - // clean up deleted values's foreign key + // clean up deleted values' foreign key relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields) cleanUpDeletedRelations := func(data reflect.Value) { diff --git a/notifier/vendor/gorm.io/gorm/chainable_api.go b/notifier/vendor/gorm.io/gorm/chainable_api.go index 8a6aea34..8f6113cc 100644 --- a/notifier/vendor/gorm.io/gorm/chainable_api.go +++ b/notifier/vendor/gorm.io/gorm/chainable_api.go @@ -178,7 +178,7 @@ func (db *DB) Omit(columns ...string) (tx *DB) { tx = db.getInstance() if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { - tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar) + tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsInvalidDBNameChar) } else { tx.Statement.Omits = columns } @@ -283,7 +283,7 @@ func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) func (db *DB) Group(name string) (tx *DB) { tx = db.getInstance() - fields := strings.FieldsFunc(name, utils.IsValidDBNameChar) + fields := strings.FieldsFunc(name, utils.IsInvalidDBNameChar) tx.Statement.AddClause(clause.GroupBy{ Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}}, }) diff --git a/notifier/vendor/gorm.io/gorm/finisher_api.go b/notifier/vendor/gorm.io/gorm/finisher_api.go index e601fe66..e9e35f1b 100644 --- a/notifier/vendor/gorm.io/gorm/finisher_api.go +++ b/notifier/vendor/gorm.io/gorm/finisher_api.go @@ -465,7 +465,7 @@ func (db *DB) Count(count *int64) (tx *DB) { if len(tx.Statement.Selects) == 1 { dbName := tx.Statement.Selects[0] - fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar) + fields := strings.FieldsFunc(dbName, utils.IsInvalidDBNameChar) if len(fields) == 1 || (len(fields) == 3 && (strings.ToUpper(fields[1]) == "AS" || fields[1] == ".")) { if tx.Statement.Parse(tx.Statement.Model) == nil { if f := tx.Statement.Schema.LookUpField(dbName); f != nil { @@ -564,7 +564,7 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { } if len(tx.Statement.Selects) != 1 { - fields := strings.FieldsFunc(column, utils.IsValidDBNameChar) + fields := strings.FieldsFunc(column, utils.IsInvalidDBNameChar) tx.Statement.AddClauseIfNotExists(clause.Select{ Distinct: tx.Statement.Distinct, Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}}, diff --git a/notifier/vendor/gorm.io/gorm/generics.go b/notifier/vendor/gorm.io/gorm/generics.go index 79238d5f..166d1520 100644 --- a/notifier/vendor/gorm.io/gorm/generics.go +++ b/notifier/vendor/gorm.io/gorm/generics.go @@ -39,7 +39,7 @@ type Interface[T any] interface { type CreateInterface[T any] interface { ExecInterface[T] - // chain methods available at start; return ChainInterface + // chain methods available at start; Select/Omit keep CreateInterface to allow Create chaining Scopes(scopes ...func(db *Statement)) ChainInterface[T] Where(query interface{}, args ...interface{}) ChainInterface[T] Not(query interface{}, args ...interface{}) ChainInterface[T] @@ -48,8 +48,8 @@ type CreateInterface[T any] interface { Offset(offset int) ChainInterface[T] Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] - Select(query string, args ...interface{}) ChainInterface[T] - Omit(columns ...string) ChainInterface[T] + Select(query string, args ...interface{}) CreateInterface[T] + Omit(columns ...string) CreateInterface[T] MapColumns(m map[string]string) ChainInterface[T] Distinct(args ...interface{}) ChainInterface[T] Group(name string) ChainInterface[T] @@ -203,6 +203,18 @@ func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] { })} } +func (c createG[T]) Select(query string, args ...interface{}) CreateInterface[T] { + return createG[T]{c.with(func(db *DB) *DB { + return db.Select(query, args...) + })} +} + +func (c createG[T]) Omit(columns ...string) CreateInterface[T] { + return createG[T]{c.with(func(db *DB) *DB { + return db.Omit(columns...) + })} +} + func (c createG[T]) Set(assignments ...clause.Assigner) SetCreateOrUpdateInterface[T] { return c.processSet(assignments...) } diff --git a/notifier/vendor/gorm.io/gorm/logger/slog.go b/notifier/vendor/gorm.io/gorm/logger/slog.go index 613234ca..27c74683 100644 --- a/notifier/vendor/gorm.io/gorm/logger/slog.go +++ b/notifier/vendor/gorm.io/gorm/logger/slog.go @@ -8,6 +8,8 @@ import ( "fmt" "log/slog" "time" + + "gorm.io/gorm/utils" ) type slogLogger struct { @@ -37,19 +39,19 @@ func (l *slogLogger) LogMode(level LogLevel) Interface { func (l *slogLogger) Info(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Info { - l.Logger.InfoContext(ctx, msg, slog.Any("data", data)) + l.log(ctx, slog.LevelInfo, msg, slog.Any("data", data)) } } func (l *slogLogger) Warn(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Warn { - l.Logger.WarnContext(ctx, msg, slog.Any("data", data)) + l.log(ctx, slog.LevelWarn, msg, slog.Any("data", data)) } } func (l *slogLogger) Error(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Error { - l.Logger.ErrorContext(ctx, msg, slog.Any("data", data)) + l.log(ctx, slog.LevelError, msg, slog.Any("data", data)) } } @@ -72,25 +74,39 @@ func (l *slogLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql switch { case err != nil && (!l.IgnoreRecordNotFoundError || !errors.Is(err, ErrRecordNotFound)): fields = append(fields, slog.String("error", err.Error())) - l.Logger.ErrorContext(ctx, "SQL executed", slog.Attr{ + l.log(ctx, slog.LevelError, "SQL executed", slog.Attr{ Key: "trace", Value: slog.GroupValue(fields...), }) case l.SlowThreshold != 0 && elapsed > l.SlowThreshold: - l.Logger.WarnContext(ctx, "SQL executed", slog.Attr{ + l.log(ctx, slog.LevelWarn, "SQL executed", slog.Attr{ Key: "trace", Value: slog.GroupValue(fields...), }) case l.LogLevel >= Info: - l.Logger.InfoContext(ctx, "SQL executed", slog.Attr{ + l.log(ctx, slog.LevelInfo, "SQL executed", slog.Attr{ Key: "trace", Value: slog.GroupValue(fields...), }) } } +func (l *slogLogger) log(ctx context.Context, level slog.Level, msg string, args ...any) { + if ctx == nil { + ctx = context.Background() + } + + if !l.Logger.Enabled(ctx, level) { + return + } + + r := slog.NewRecord(time.Now(), level, msg, utils.CallerFrame().PC) + r.Add(args...) + _ = l.Logger.Handler().Handle(ctx, r) +} + // ParamsFilter filter params func (l *slogLogger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) { if l.Parameterized { diff --git a/notifier/vendor/gorm.io/gorm/migrator/migrator.go b/notifier/vendor/gorm.io/gorm/migrator/migrator.go index 50a36d10..35107d57 100644 --- a/notifier/vendor/gorm.io/gorm/migrator/migrator.go +++ b/notifier/vendor/gorm.io/gorm/migrator/migrator.go @@ -560,6 +560,10 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy v1, _ := strconv.ParseBool(dv) v2, _ := strconv.ParseBool(field.DefaultValue) alterColumn = v1 != v2 + case schema.String: + if dv != field.DefaultValue && dv != strings.Trim(field.DefaultValue, "'\"") { + alterColumn = true + } default: alterColumn = dv != field.DefaultValue } diff --git a/notifier/vendor/gorm.io/gorm/schema/schema.go b/notifier/vendor/gorm.io/gorm/schema/schema.go index 9419846b..09697a7a 100644 --- a/notifier/vendor/gorm.io/gorm/schema/schema.go +++ b/notifier/vendor/gorm.io/gorm/schema/schema.go @@ -82,6 +82,16 @@ func (schema *Schema) LookUpField(name string) *Field { if field, ok := schema.FieldsByName[name]; ok { return field } + + // Lookup field using namer-driven ColumnName + if schema.namer == nil { + return nil + } + namerColumnName := schema.namer.ColumnName(schema.Table, name) + if field, ok := schema.FieldsByDBName[namerColumnName]; ok { + return field + } + return nil } diff --git a/notifier/vendor/gorm.io/gorm/schema/serializer.go b/notifier/vendor/gorm.io/gorm/schema/serializer.go index 0fafbcba..d774a8ed 100644 --- a/notifier/vendor/gorm.io/gorm/schema/serializer.go +++ b/notifier/vendor/gorm.io/gorm/schema/serializer.go @@ -8,6 +8,7 @@ import ( "encoding/gob" "encoding/json" "fmt" + "math" "reflect" "strings" "sync" @@ -127,16 +128,31 @@ func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect. // Value implements serializer interface func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) { rv := reflect.ValueOf(fieldValue) - switch v := fieldValue.(type) { - case int64, int, uint, uint64, int32, uint32, int16, uint16: - result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC() - case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16: + switch fieldValue.(type) { + case int, int8, int16, int32, int64: + result = time.Unix(rv.Int(), 0).UTC() + case uint, uint8, uint16, uint32, uint64: + if uv := rv.Uint(); uv > math.MaxInt64 { + err = fmt.Errorf("integer overflow conversion uint64(%d) -> int64", uv) + } else { + result = time.Unix(int64(uv), 0).UTC() //nolint:gosec + } + case *int, *int8, *int16, *int32, *int64: if rv.IsZero() { return nil, nil } - result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC() + result = time.Unix(rv.Elem().Int(), 0).UTC() + case *uint, *uint8, *uint16, *uint32, *uint64: + if rv.IsZero() { + return nil, nil + } + if uv := rv.Elem().Uint(); uv > math.MaxInt64 { + err = fmt.Errorf("integer overflow conversion uint64(%d) -> int64", uv) + } else { + result = time.Unix(int64(uv), 0).UTC() //nolint:gosec + } default: - err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v) + err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", fieldValue) } return } diff --git a/notifier/vendor/gorm.io/gorm/schema/utils.go b/notifier/vendor/gorm.io/gorm/schema/utils.go index d4fe252e..86305d7b 100644 --- a/notifier/vendor/gorm.io/gorm/schema/utils.go +++ b/notifier/vendor/gorm.io/gorm/schema/utils.go @@ -17,25 +17,23 @@ func ParseTagSetting(str string, sep string) map[string]string { settings := map[string]string{} names := strings.Split(str, sep) + var parsedNames []string for i := 0; i < len(names); i++ { - j := i - if len(names[j]) > 0 { - for { - if names[j][len(names[j])-1] == '\\' { - i++ - names[j] = names[j][0:len(names[j])-1] + sep + names[i] - names[i] = "" - } else { - break - } - } + s := names[i] + for strings.HasSuffix(s, "\\") && i+1 < len(names) { + i++ + s = s[:len(s)-1] + sep + names[i] } + parsedNames = append(parsedNames, s) + } - values := strings.Split(names[j], ":") + for _, tag := range parsedNames { + values := strings.Split(tag, ":") k := strings.TrimSpace(strings.ToUpper(values[0])) - if len(values) >= 2 { - settings[k] = strings.Join(values[1:], ":") + val := strings.Join(values[1:], ":") + val = strings.ReplaceAll(val, `\"`, `"`) + settings[k] = val } else if k != "" { settings[k] = k } diff --git a/notifier/vendor/gorm.io/gorm/utils/utils.go b/notifier/vendor/gorm.io/gorm/utils/utils.go index fc615d73..7e59264b 100644 --- a/notifier/vendor/gorm.io/gorm/utils/utils.go +++ b/notifier/vendor/gorm.io/gorm/utils/utils.go @@ -30,8 +30,12 @@ func sourceDir(file string) string { return filepath.ToSlash(s) + "/" } -// FileWithLineNum return the file name and line number of the current file -func FileWithLineNum() string { +// CallerFrame retrieves the first relevant stack frame outside of GORM's internal implementation files. +// It skips: +// - GORM's core source files (identified by gormSourceDir prefix) +// - Exclude test files (*_test.go) +// - go-gorm/gen's Generated files (*.gen.go) +func CallerFrame() runtime.Frame { pcs := [13]uintptr{} // the third caller usually from gorm internal len := runtime.Callers(3, pcs[:]) @@ -41,14 +45,24 @@ func FileWithLineNum() string { frame, _ := frames.Next() if (!strings.HasPrefix(frame.File, gormSourceDir) || strings.HasSuffix(frame.File, "_test.go")) && !strings.HasSuffix(frame.File, ".gen.go") { - return string(strconv.AppendInt(append([]byte(frame.File), ':'), int64(frame.Line), 10)) + return frame } } + return runtime.Frame{} +} + +// FileWithLineNum return the file name and line number of the current file +func FileWithLineNum() string { + frame := CallerFrame() + if frame.PC != 0 { + return string(strconv.AppendInt(append([]byte(frame.File), ':'), int64(frame.Line), 10)) + } + return "" } -func IsValidDBNameChar(c rune) bool { +func IsInvalidDBNameChar(c rune) bool { return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@' } diff --git a/notifier/vendor/modules.txt b/notifier/vendor/modules.txt index c8f68649..04061cec 100644 --- a/notifier/vendor/modules.txt +++ b/notifier/vendor/modules.txt @@ -1284,7 +1284,7 @@ gopkg.in/yaml.v3 # gorm.io/driver/postgres v1.6.0 ## explicit; go 1.20 gorm.io/driver/postgres -# gorm.io/gorm v1.31.0 +# gorm.io/gorm v1.31.1 ## explicit; go 1.18 gorm.io/gorm gorm.io/gorm/callbacks diff --git a/policy-enforcer/go.mod b/policy-enforcer/go.mod index 4710810e..81151740 100644 --- a/policy-enforcer/go.mod +++ b/policy-enforcer/go.mod @@ -23,7 +23,7 @@ require ( google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.5.1 google.golang.org/protobuf v1.36.6 gorm.io/driver/postgres v1.6.0 - gorm.io/gorm v1.30.0 + gorm.io/gorm v1.31.1 ) require ( diff --git a/policy-enforcer/go.sum b/policy-enforcer/go.sum index d056ecfd..3c147a7f 100644 --- a/policy-enforcer/go.sum +++ b/policy-enforcer/go.sum @@ -673,8 +673,8 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4= gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo= -gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs= -gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= +gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= +gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= honnef.co/go/tools v0.6.1 h1:R094WgE8K4JirYjBaOpz/AvTyUu/3wbmAoskKN/pxTI= honnef.co/go/tools v0.6.1/go.mod h1:3puzxxljPCe8RGJX7BIy1plGbxEOZni5mR2aXe3/uk4= mvdan.cc/gofumpt v0.7.0 h1:bg91ttqXmi9y2xawvkuMXyvAA/1ZGJqYAEGjXuP0JXU= diff --git a/policy-enforcer/vendor/gorm.io/gorm/README.md b/policy-enforcer/vendor/gorm.io/gorm/README.md index 745dad60..24eb84c9 100644 --- a/policy-enforcer/vendor/gorm.io/gorm/README.md +++ b/policy-enforcer/vendor/gorm.io/gorm/README.md @@ -3,7 +3,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. [![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm) -[![test status](https://github.com/go-gorm/gorm/workflows/tests/badge.svg?branch=master "test status")](https://github.com/go-gorm/gorm/actions) +[![test status](https://github.com/go-gorm/gorm/actions/workflows/tests.yml/badge.svg)](https://github.com/go-gorm/gorm/actions) [![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT) [![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gorm?tab=doc) diff --git a/policy-enforcer/vendor/gorm.io/gorm/association.go b/policy-enforcer/vendor/gorm.io/gorm/association.go index e3f51d17..3a4e0e25 100644 --- a/policy-enforcer/vendor/gorm.io/gorm/association.go +++ b/policy-enforcer/vendor/gorm.io/gorm/association.go @@ -19,10 +19,10 @@ type Association struct { } func (db *DB) Association(column string) *Association { - association := &Association{DB: db} + association := &Association{DB: db, Unscope: db.Statement.Unscoped} table := db.Statement.Table - if err := db.Statement.Parse(db.Statement.Model); err == nil { + if association.Error = db.Statement.Parse(db.Statement.Model); association.Error == nil { db.Statement.Table = table association.Relationship = db.Statement.Schema.Relationships.Relations[column] @@ -34,8 +34,6 @@ func (db *DB) Association(column string) *Association { for db.Statement.ReflectValue.Kind() == reflect.Ptr { db.Statement.ReflectValue = db.Statement.ReflectValue.Elem() } - } else { - association.Error = err } return association @@ -58,6 +56,8 @@ func (association *Association) Find(out interface{}, conds ...interface{}) erro } func (association *Association) Append(values ...interface{}) error { + values = expandValues(values) + if association.Error == nil { switch association.Relationship.Type { case schema.HasOne, schema.BelongsTo: @@ -73,6 +73,8 @@ func (association *Association) Append(values ...interface{}) error { } func (association *Association) Replace(values ...interface{}) error { + values = expandValues(values) + if association.Error == nil { reflectValue := association.DB.Statement.ReflectValue rel := association.Relationship @@ -97,7 +99,7 @@ func (association *Association) Replace(values ...interface{}) error { return association.Error } - // set old associations's foreign key to null + // set old association's foreign key to null switch rel.Type { case schema.BelongsTo: if len(values) == 0 { @@ -195,6 +197,8 @@ func (association *Association) Replace(values ...interface{}) error { } func (association *Association) Delete(values ...interface{}) error { + values = expandValues(values) + if association.Error == nil { var ( reflectValue = association.DB.Statement.ReflectValue @@ -300,7 +304,7 @@ func (association *Association) Delete(values ...interface{}) error { } if association.Error == nil { - // clean up deleted values's foreign key + // clean up deleted values' foreign key relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields) cleanUpDeletedRelations := func(data reflect.Value) { @@ -431,10 +435,49 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } } + processMap := func(mapv reflect.Value) { + child := reflect.New(association.Relationship.FieldSchema.ModelType) + + switch association.Relationship.Type { + case schema.HasMany: + for _, ref := range association.Relationship.References { + key := reflect.ValueOf(ref.ForeignKey.DBName) + if ref.OwnPrimaryKey { + v := ref.PrimaryKey.ReflectValueOf(association.DB.Statement.Context, source) + mapv.SetMapIndex(key, v) + } else if ref.PrimaryValue != "" { + mapv.SetMapIndex(key, reflect.ValueOf(ref.PrimaryValue)) + } + } + association.Error = association.DB.Session(&Session{ + NewDB: true, + }).Model(child.Interface()).Create(mapv.Interface()).Error + case schema.Many2Many: + association.Error = association.DB.Session(&Session{ + NewDB: true, + }).Model(child.Interface()).Create(mapv.Interface()).Error + + for _, key := range mapv.MapKeys() { + k := strings.ToLower(key.String()) + if f, ok := association.Relationship.FieldSchema.FieldsByDBName[k]; ok { + _ = f.Set(association.DB.Statement.Context, child, mapv.MapIndex(key).Interface()) + } + } + appendToFieldValues(child) + } + } + switch rv.Kind() { + case reflect.Map: + processMap(rv) case reflect.Slice, reflect.Array: for i := 0; i < rv.Len(); i++ { - appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr()) + elem := reflect.Indirect(rv.Index(i)) + if elem.Kind() == reflect.Map { + processMap(elem) + continue + } + appendToFieldValues(elem.Addr()) } case reflect.Struct: if !rv.CanAddr() { @@ -591,3 +634,32 @@ func (association *Association) buildCondition() *DB { return tx } + +func expandValues(values ...any) (results []any) { + appendToResult := func(rv reflect.Value) { + // unwrap interface + if rv.IsValid() && rv.Kind() == reflect.Interface { + rv = rv.Elem() + } + if rv.IsValid() && rv.Kind() == reflect.Struct { + p := reflect.New(rv.Type()) + p.Elem().Set(rv) + results = append(results, p.Interface()) + } else if rv.IsValid() { + results = append(results, rv.Interface()) + } + } + + // Process each argument; if an argument is a slice/array, expand its elements + for _, value := range values { + rv := reflect.ValueOf(value) + if rv.Kind() == reflect.Slice || rv.Kind() == reflect.Array { + for i := 0; i < rv.Len(); i++ { + appendToResult(rv.Index(i)) + } + } else { + appendToResult(rv) + } + } + return +} diff --git a/policy-enforcer/vendor/gorm.io/gorm/callbacks.go b/policy-enforcer/vendor/gorm.io/gorm/callbacks.go index 50b5b0e9..bd97f040 100644 --- a/policy-enforcer/vendor/gorm.io/gorm/callbacks.go +++ b/policy-enforcer/vendor/gorm.io/gorm/callbacks.go @@ -89,10 +89,16 @@ func (p *processor) Execute(db *DB) *DB { resetBuildClauses = true } - if optimizer, ok := db.Statement.Dest.(StatementModifier); ok { + if optimizer, ok := stmt.Dest.(StatementModifier); ok { optimizer.ModifyStatement(stmt) } + if db.DefaultContextTimeout > 0 { + if _, ok := stmt.Context.Deadline(); !ok { + stmt.Context, _ = context.WithTimeout(stmt.Context, db.DefaultContextTimeout) + } + } + // assign model values if stmt.Model == nil { stmt.Model = stmt.Dest diff --git a/policy-enforcer/vendor/gorm.io/gorm/callbacks/create.go b/policy-enforcer/vendor/gorm.io/gorm/callbacks/create.go index d8701f51..e5929adb 100644 --- a/policy-enforcer/vendor/gorm.io/gorm/callbacks/create.go +++ b/policy-enforcer/vendor/gorm.io/gorm/callbacks/create.go @@ -53,9 +53,13 @@ func Create(config *Config) func(db *gorm.DB) { if _, ok := db.Statement.Clauses["RETURNING"]; !ok { fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue)) for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue { - fromColumns = append(fromColumns, clause.Column{Name: field.DBName}) + if field.Readable { + fromColumns = append(fromColumns, clause.Column{Name: field.DBName}) + } + } + if len(fromColumns) > 0 { + db.Statement.AddClause(clause.Returning{Columns: fromColumns}) } - db.Statement.AddClause(clause.Returning{Columns: fromColumns}) } } } @@ -76,8 +80,11 @@ func Create(config *Config) func(db *gorm.DB) { ok, mode := hasReturning(db, supportReturning) if ok { if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok { - if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing { + onConflict, _ := c.Expression.(clause.OnConflict) + if onConflict.DoNothing { mode |= gorm.ScanOnConflictDoNothing + } else if len(onConflict.DoUpdates) > 0 || onConflict.UpdateAll { + mode |= gorm.ScanUpdate } } @@ -122,6 +129,16 @@ func Create(config *Config) func(db *gorm.DB) { pkFieldName = "@id" ) + if db.Statement.Schema != nil { + if db.Statement.Schema.PrioritizedPrimaryField == nil || + !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue || + !db.Statement.Schema.PrioritizedPrimaryField.Readable { + return + } + pkField = db.Statement.Schema.PrioritizedPrimaryField + pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName + } + insertID, err := result.LastInsertId() insertOk := err == nil && insertID > 0 @@ -132,14 +149,6 @@ func Create(config *Config) func(db *gorm.DB) { return } - if db.Statement.Schema != nil { - if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { - return - } - pkField = db.Statement.Schema.PrioritizedPrimaryField - pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName - } - // append @id column with value for auto-increment primary key // the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1 switch values := db.Statement.Dest.(type) { diff --git a/policy-enforcer/vendor/gorm.io/gorm/chainable_api.go b/policy-enforcer/vendor/gorm.io/gorm/chainable_api.go index 8a6aea34..8f6113cc 100644 --- a/policy-enforcer/vendor/gorm.io/gorm/chainable_api.go +++ b/policy-enforcer/vendor/gorm.io/gorm/chainable_api.go @@ -178,7 +178,7 @@ func (db *DB) Omit(columns ...string) (tx *DB) { tx = db.getInstance() if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { - tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar) + tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsInvalidDBNameChar) } else { tx.Statement.Omits = columns } @@ -283,7 +283,7 @@ func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) func (db *DB) Group(name string) (tx *DB) { tx = db.getInstance() - fields := strings.FieldsFunc(name, utils.IsValidDBNameChar) + fields := strings.FieldsFunc(name, utils.IsInvalidDBNameChar) tx.Statement.AddClause(clause.GroupBy{ Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}}, }) diff --git a/policy-enforcer/vendor/gorm.io/gorm/clause/association.go b/policy-enforcer/vendor/gorm.io/gorm/clause/association.go new file mode 100644 index 00000000..a9bf7eb0 --- /dev/null +++ b/policy-enforcer/vendor/gorm.io/gorm/clause/association.go @@ -0,0 +1,35 @@ +package clause + +// AssociationOpType represents association operation types +type AssociationOpType int + +const ( + OpUnlink AssociationOpType = iota // Unlink association + OpDelete // Delete association records + OpUpdate // Update association records + OpCreate // Create association records with assignments +) + +// Association represents an association operation +type Association struct { + Association string // Association name + Type AssociationOpType // Operation type + Conditions []Expression // Filter conditions + Set []Assignment // Assignment operations (for Update and Create) + Values []interface{} // Values for Create operation +} + +// AssociationAssigner is an interface for association operation providers +type AssociationAssigner interface { + AssociationAssignments() []Association +} + +// Assignments implements the Assigner interface so that AssociationOperation can be used as a Set method parameter +func (ao Association) Assignments() []Assignment { + return []Assignment{} +} + +// AssociationAssignments implements the AssociationAssigner interface +func (ao Association) AssociationAssignments() []Association { + return []Association{ao} +} diff --git a/policy-enforcer/vendor/gorm.io/gorm/clause/set.go b/policy-enforcer/vendor/gorm.io/gorm/clause/set.go index 75eb6bdd..cb5f36a0 100644 --- a/policy-enforcer/vendor/gorm.io/gorm/clause/set.go +++ b/policy-enforcer/vendor/gorm.io/gorm/clause/set.go @@ -9,6 +9,11 @@ type Assignment struct { Value interface{} } +// Assigner assignments provider interface +type Assigner interface { + Assignments() []Assignment +} + func (set Set) Name() string { return "SET" } @@ -37,6 +42,9 @@ func (set Set) MergeClause(clause *Clause) { clause.Expression = Set(copiedAssignments) } +// Assignments implements Assigner for Set. +func (set Set) Assignments() []Assignment { return []Assignment(set) } + func Assignments(values map[string]interface{}) Set { keys := make([]string, 0, len(values)) for key := range values { @@ -58,3 +66,6 @@ func AssignmentColumns(values []string) Set { } return assignments } + +// Assignments implements Assigner for a single Assignment. +func (a Assignment) Assignments() []Assignment { return []Assignment{a} } diff --git a/policy-enforcer/vendor/gorm.io/gorm/finisher_api.go b/policy-enforcer/vendor/gorm.io/gorm/finisher_api.go index 57809d17..e9e35f1b 100644 --- a/policy-enforcer/vendor/gorm.io/gorm/finisher_api.go +++ b/policy-enforcer/vendor/gorm.io/gorm/finisher_api.go @@ -465,7 +465,7 @@ func (db *DB) Count(count *int64) (tx *DB) { if len(tx.Statement.Selects) == 1 { dbName := tx.Statement.Selects[0] - fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar) + fields := strings.FieldsFunc(dbName, utils.IsInvalidDBNameChar) if len(fields) == 1 || (len(fields) == 3 && (strings.ToUpper(fields[1]) == "AS" || fields[1] == ".")) { if tx.Statement.Parse(tx.Statement.Model) == nil { if f := tx.Statement.Schema.LookUpField(dbName); f != nil { @@ -564,7 +564,7 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { } if len(tx.Statement.Selects) != 1 { - fields := strings.FieldsFunc(column, utils.IsValidDBNameChar) + fields := strings.FieldsFunc(column, utils.IsInvalidDBNameChar) tx.Statement.AddClauseIfNotExists(clause.Select{ Distinct: tx.Statement.Distinct, Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}}, @@ -675,9 +675,9 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { } ctx := tx.Statement.Context - if _, ok := ctx.Deadline(); !ok { - if db.Config.DefaultTransactionTimeout > 0 { - ctx, _ = context.WithTimeout(ctx, db.Config.DefaultTransactionTimeout) + if db.DefaultTransactionTimeout > 0 { + if _, ok := ctx.Deadline(); !ok { + ctx, _ = context.WithTimeout(ctx, db.DefaultTransactionTimeout) } } diff --git a/policy-enforcer/vendor/gorm.io/gorm/generics.go b/policy-enforcer/vendor/gorm.io/gorm/generics.go index ad2d063f..166d1520 100644 --- a/policy-enforcer/vendor/gorm.io/gorm/generics.go +++ b/policy-enforcer/vendor/gorm.io/gorm/generics.go @@ -3,12 +3,15 @@ package gorm import ( "context" "database/sql" + "errors" "fmt" + "reflect" "sort" "strings" "gorm.io/gorm/clause" "gorm.io/gorm/logger" + "gorm.io/gorm/schema" ) type result struct { @@ -35,10 +38,34 @@ type Interface[T any] interface { } type CreateInterface[T any] interface { - ChainInterface[T] + ExecInterface[T] + // chain methods available at start; Select/Omit keep CreateInterface to allow Create chaining + Scopes(scopes ...func(db *Statement)) ChainInterface[T] + Where(query interface{}, args ...interface{}) ChainInterface[T] + Not(query interface{}, args ...interface{}) ChainInterface[T] + Or(query interface{}, args ...interface{}) ChainInterface[T] + Limit(offset int) ChainInterface[T] + Offset(offset int) ChainInterface[T] + Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] + Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] + Select(query string, args ...interface{}) CreateInterface[T] + Omit(columns ...string) CreateInterface[T] + MapColumns(m map[string]string) ChainInterface[T] + Distinct(args ...interface{}) ChainInterface[T] + Group(name string) ChainInterface[T] + Having(query interface{}, args ...interface{}) ChainInterface[T] + Order(value interface{}) ChainInterface[T] + Build(builder clause.Builder) + + Delete(ctx context.Context) (rowsAffected int, err error) + Update(ctx context.Context, name string, value any) (rowsAffected int, err error) + Updates(ctx context.Context, t T) (rowsAffected int, err error) + Count(ctx context.Context, column string) (result int64, err error) + Table(name string, args ...interface{}) CreateInterface[T] Create(ctx context.Context, r *T) error CreateInBatches(ctx context.Context, r *[]T, batchSize int) error + Set(assignments ...clause.Assigner) SetCreateOrUpdateInterface[T] } type ChainInterface[T any] interface { @@ -58,15 +85,28 @@ type ChainInterface[T any] interface { Group(name string) ChainInterface[T] Having(query interface{}, args ...interface{}) ChainInterface[T] Order(value interface{}) ChainInterface[T] + Set(assignments ...clause.Assigner) SetUpdateOnlyInterface[T] Build(builder clause.Builder) + Table(name string, args ...interface{}) ChainInterface[T] Delete(ctx context.Context) (rowsAffected int, err error) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) Updates(ctx context.Context, t T) (rowsAffected int, err error) Count(ctx context.Context, column string) (result int64, err error) } +// SetUpdateOnlyInterface is returned by Set after chaining; only Update is allowed +type SetUpdateOnlyInterface[T any] interface { + Update(ctx context.Context) (rowsAffected int, err error) +} + +// SetCreateOrUpdateInterface is returned by Set at start; Create or Update are allowed +type SetCreateOrUpdateInterface[T any] interface { + Create(ctx context.Context) error + Update(ctx context.Context) (rowsAffected int, err error) +} + type ExecInterface[T any] interface { Scan(ctx context.Context, r interface{}) error First(context.Context) (T, error) @@ -142,13 +182,15 @@ func (c *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] { return execG[T]{g: &g[T]{ db: c.db, ops: append(c.ops, func(db *DB) *DB { - return db.Raw(sql, values...) + var r T + return db.Model(r).Raw(sql, values...) }), }} } func (c *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error { - return c.apply(ctx).Exec(sql, values...).Error + var r T + return c.apply(ctx).Model(r).Exec(sql, values...).Error } type createG[T any] struct { @@ -161,6 +203,22 @@ func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] { })} } +func (c createG[T]) Select(query string, args ...interface{}) CreateInterface[T] { + return createG[T]{c.with(func(db *DB) *DB { + return db.Select(query, args...) + })} +} + +func (c createG[T]) Omit(columns ...string) CreateInterface[T] { + return createG[T]{c.with(func(db *DB) *DB { + return db.Omit(columns...) + })} +} + +func (c createG[T]) Set(assignments ...clause.Assigner) SetCreateOrUpdateInterface[T] { + return c.processSet(assignments...) +} + func (c createG[T]) Create(ctx context.Context, r *T) error { return c.g.apply(ctx).Create(r).Error } @@ -187,6 +245,12 @@ func (c chainG[T]) with(v op) chainG[T] { } } +func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Table(name, args...) + }) +} + func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] { return c.with(func(db *DB) *DB { for _, fc := range scopes { @@ -196,12 +260,6 @@ func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] { }) } -func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] { - return c.with(func(db *DB) *DB { - return db.Table(name, args...) - }) -} - func (c chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] { return c.with(func(db *DB) *DB { return db.Where(query, args...) @@ -388,6 +446,10 @@ func (c chainG[T]) MapColumns(m map[string]string) ChainInterface[T] { }) } +func (c chainG[T]) Set(assignments ...clause.Assigner) SetUpdateOnlyInterface[T] { + return c.processSet(assignments...) +} + func (c chainG[T]) Distinct(args ...interface{}) ChainInterface[T] { return c.with(func(db *DB) *DB { return db.Distinct(args...) @@ -425,12 +487,12 @@ func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) err relation, ok := db.Statement.Schema.Relationships.Relations[association] if !ok { if preloadFields := strings.Split(association, "."); len(preloadFields) > 1 { - relationships := db.Statement.Schema.Relationships + relationships := &db.Statement.Schema.Relationships for _, field := range preloadFields { var ok bool relation, ok = relationships.Relations[field] if ok { - relationships = relation.FieldSchema.Relationships + relationships = &relation.FieldSchema.Relationships } else { db.AddError(fmt.Errorf("relation %s not found", association)) return nil @@ -567,7 +629,7 @@ func (g execG[T]) First(ctx context.Context) (T, error) { func (g execG[T]) Scan(ctx context.Context, result interface{}) error { var r T - err := g.g.apply(ctx).Model(r).Find(&result).Error + err := g.g.apply(ctx).Model(r).Find(result).Error return err } @@ -603,3 +665,242 @@ func (g execG[T]) Row(ctx context.Context) *sql.Row { func (g execG[T]) Rows(ctx context.Context) (*sql.Rows, error) { return g.g.apply(ctx).Rows() } + +func (c chainG[T]) processSet(items ...clause.Assigner) setCreateOrUpdateG[T] { + var ( + assigns []clause.Assignment + assocOps []clause.Association + ) + + for _, item := range items { + // Check if it's an AssociationAssigner + if assocAssigner, ok := item.(clause.AssociationAssigner); ok { + assocOps = append(assocOps, assocAssigner.AssociationAssignments()...) + } else { + assigns = append(assigns, item.Assignments()...) + } + } + + return setCreateOrUpdateG[T]{ + c: c, + assigns: assigns, + assocOps: assocOps, + } +} + +// setCreateOrUpdateG[T] is a struct that holds operations to be executed in a batch. +// It supports regular assignments and association operations. +type setCreateOrUpdateG[T any] struct { + c chainG[T] + assigns []clause.Assignment + assocOps []clause.Association +} + +func (s setCreateOrUpdateG[T]) Update(ctx context.Context) (rowsAffected int, err error) { + // Execute association operations + for _, assocOp := range s.assocOps { + if err := s.executeAssociationOperation(ctx, assocOp); err != nil { + return 0, err + } + } + + // Execute assignment operations + if len(s.assigns) > 0 { + var r T + res := s.c.g.apply(ctx).Model(r).Clauses(clause.Set(s.assigns)).Updates(map[string]interface{}{}) + return int(res.RowsAffected), res.Error + } + + return 0, nil +} + +func (s setCreateOrUpdateG[T]) Create(ctx context.Context) error { + // Execute association operations + for _, assocOp := range s.assocOps { + if err := s.executeAssociationOperation(ctx, assocOp); err != nil { + return err + } + } + + // Execute assignment operations + if len(s.assigns) > 0 { + data := make(map[string]interface{}, len(s.assigns)) + for _, a := range s.assigns { + data[a.Column.Name] = a.Value + } + var r T + return s.c.g.apply(ctx).Model(r).Create(data).Error + } + + return nil +} + +// executeAssociationOperation executes an association operation +func (s setCreateOrUpdateG[T]) executeAssociationOperation(ctx context.Context, op clause.Association) error { + var r T + base := s.c.g.apply(ctx).Model(r) + + switch op.Type { + case clause.OpCreate: + return s.handleAssociationCreate(ctx, base, op) + case clause.OpUnlink, clause.OpDelete, clause.OpUpdate: + return s.handleAssociation(ctx, base, op) + default: + return fmt.Errorf("unknown association operation type: %v", op.Type) + } +} + +func (s setCreateOrUpdateG[T]) handleAssociationCreate(ctx context.Context, base *DB, op clause.Association) error { + if len(op.Set) > 0 { + return s.handleAssociationForOwners(base, ctx, func(owner T, assoc *Association) error { + data := make(map[string]interface{}, len(op.Set)) + for _, a := range op.Set { + data[a.Column.Name] = a.Value + } + return assoc.Append(data) + }, op.Association) + } + + return s.handleAssociationForOwners(base, ctx, func(owner T, assoc *Association) error { + return assoc.Append(op.Values...) + }, op.Association) +} + +// handleAssociationForOwners is a helper function that handles associations for all owners +func (s setCreateOrUpdateG[T]) handleAssociationForOwners(base *DB, ctx context.Context, handler func(owner T, association *Association) error, associationName string) error { + var owners []T + if err := base.Find(&owners).Error; err != nil { + return err + } + + for _, owner := range owners { + assoc := base.Session(&Session{NewDB: true, Context: ctx}).Model(&owner).Association(associationName) + if assoc.Error != nil { + return assoc.Error + } + + if err := handler(owner, assoc); err != nil { + return err + } + } + return nil +} + +func (s setCreateOrUpdateG[T]) handleAssociation(ctx context.Context, base *DB, op clause.Association) error { + assoc := base.Association(op.Association) + if assoc.Error != nil { + return assoc.Error + } + + var ( + rel = assoc.Relationship + assocModel = reflect.New(rel.FieldSchema.ModelType).Interface() + fkNil = map[string]any{} + setMap = make(map[string]any, len(op.Set)) + ownerPKNames []string + ownerFKNames []string + primaryColumns []any + foreignColumns []any + ) + + for _, a := range op.Set { + setMap[a.Column.Name] = a.Value + } + + for _, ref := range rel.References { + fkNil[ref.ForeignKey.DBName] = nil + + if ref.OwnPrimaryKey && ref.PrimaryKey != nil { + ownerPKNames = append(ownerPKNames, ref.PrimaryKey.DBName) + primaryColumns = append(primaryColumns, clause.Column{Name: ref.PrimaryKey.DBName}) + foreignColumns = append(foreignColumns, clause.Column{Name: ref.ForeignKey.DBName}) + } else if !ref.OwnPrimaryKey && ref.PrimaryKey != nil { + ownerFKNames = append(ownerFKNames, ref.ForeignKey.DBName) + primaryColumns = append(primaryColumns, clause.Column{Name: ref.PrimaryKey.DBName}) + } + } + + assocDB := s.c.g.db.Session(&Session{NewDB: true, Context: ctx}).Model(assocModel).Where(op.Conditions) + + switch rel.Type { + case schema.HasOne, schema.HasMany: + assocDB = assocDB.Where("? IN (?)", foreignColumns, base.Select(ownerPKNames)) + switch op.Type { + case clause.OpUnlink: + return assocDB.Updates(fkNil).Error + case clause.OpDelete: + return assocDB.Delete(assocModel).Error + case clause.OpUpdate: + return assocDB.Updates(setMap).Error + } + case schema.BelongsTo: + switch op.Type { + case clause.OpDelete: + return base.Transaction(func(tx *DB) error { + assocDB.Statement.ConnPool = tx.Statement.ConnPool + base.Statement.ConnPool = tx.Statement.ConnPool + + if err := assocDB.Where("? IN (?)", primaryColumns, base.Select(ownerFKNames)).Delete(assocModel).Error; err != nil { + return err + } + return base.Updates(fkNil).Error + }) + case clause.OpUnlink: + return base.Updates(fkNil).Error + case clause.OpUpdate: + return assocDB.Where("? IN (?)", primaryColumns, base.Select(ownerFKNames)).Updates(setMap).Error + } + case schema.Many2Many: + joinModel := reflect.New(rel.JoinTable.ModelType).Interface() + joinDB := base.Session(&Session{NewDB: true, Context: ctx}).Model(joinModel) + + // EXISTS owners: owners.pk = join.owner_fk for all owner refs + ownersExists := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.Schema.Table).Select("1") + for _, ref := range rel.References { + if ref.OwnPrimaryKey && ref.PrimaryKey != nil { + ownersExists = ownersExists.Where(clause.Eq{ + Column: clause.Column{Table: rel.Schema.Table, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + }) + } + } + + // EXISTS related: related.pk = join.rel_fk for all related refs, plus optional conditions + relatedExists := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.FieldSchema.Table).Select("1") + for _, ref := range rel.References { + if !ref.OwnPrimaryKey && ref.PrimaryKey != nil { + relatedExists = relatedExists.Where(clause.Eq{ + Column: clause.Column{Table: rel.FieldSchema.Table, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + }) + } + } + relatedExists = relatedExists.Where(op.Conditions) + + switch op.Type { + case clause.OpUnlink, clause.OpDelete: + joinDB = joinDB.Where("EXISTS (?)", ownersExists) + if len(op.Conditions) > 0 { + joinDB = joinDB.Where("EXISTS (?)", relatedExists) + } + return joinDB.Delete(nil).Error + case clause.OpUpdate: + // Update related table rows that have join rows matching owners + relatedDB := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.FieldSchema.Table).Where(op.Conditions) + + // correlated join subquery: join.rel_fk = related.pk AND EXISTS owners + joinSub := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.JoinTable.Table).Select("1") + for _, ref := range rel.References { + if !ref.OwnPrimaryKey && ref.PrimaryKey != nil { + joinSub = joinSub.Where(clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: rel.FieldSchema.Table, Name: ref.PrimaryKey.DBName}, + }) + } + } + joinSub = joinSub.Where("EXISTS (?)", ownersExists) + return relatedDB.Where("EXISTS (?)", joinSub).Updates(setMap).Error + } + } + return errors.New("unsupported relationship") +} diff --git a/policy-enforcer/vendor/gorm.io/gorm/gorm.go b/policy-enforcer/vendor/gorm.io/gorm/gorm.go index 67889262..a209bb09 100644 --- a/policy-enforcer/vendor/gorm.io/gorm/gorm.go +++ b/policy-enforcer/vendor/gorm.io/gorm/gorm.go @@ -23,6 +23,7 @@ type Config struct { // You can disable it by setting `SkipDefaultTransaction` to true SkipDefaultTransaction bool DefaultTransactionTimeout time.Duration + DefaultContextTimeout time.Duration // NamingStrategy tables, columns naming strategy NamingStrategy schema.Namer @@ -137,6 +138,14 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { return isConfig && !isConfig2 }) + if len(opts) > 0 { + if c, ok := opts[0].(*Config); ok { + config = c + } else { + opts = append([]Option{config}, opts...) + } + } + var skipAfterInitialize bool for _, opt := range opts { if opt != nil { diff --git a/policy-enforcer/vendor/gorm.io/gorm/logger/slog.go b/policy-enforcer/vendor/gorm.io/gorm/logger/slog.go new file mode 100644 index 00000000..27c74683 --- /dev/null +++ b/policy-enforcer/vendor/gorm.io/gorm/logger/slog.go @@ -0,0 +1,116 @@ +//go:build go1.21 + +package logger + +import ( + "context" + "errors" + "fmt" + "log/slog" + "time" + + "gorm.io/gorm/utils" +) + +type slogLogger struct { + Logger *slog.Logger + LogLevel LogLevel + SlowThreshold time.Duration + Parameterized bool + Colorful bool // Ignored in slog + IgnoreRecordNotFoundError bool +} + +func NewSlogLogger(logger *slog.Logger, config Config) Interface { + return &slogLogger{ + Logger: logger, + LogLevel: config.LogLevel, + SlowThreshold: config.SlowThreshold, + Parameterized: config.ParameterizedQueries, + IgnoreRecordNotFoundError: config.IgnoreRecordNotFoundError, + } +} + +func (l *slogLogger) LogMode(level LogLevel) Interface { + newLogger := *l + newLogger.LogLevel = level + return &newLogger +} + +func (l *slogLogger) Info(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Info { + l.log(ctx, slog.LevelInfo, msg, slog.Any("data", data)) + } +} + +func (l *slogLogger) Warn(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Warn { + l.log(ctx, slog.LevelWarn, msg, slog.Any("data", data)) + } +} + +func (l *slogLogger) Error(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Error { + l.log(ctx, slog.LevelError, msg, slog.Any("data", data)) + } +} + +func (l *slogLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + if l.LogLevel <= Silent { + return + } + + elapsed := time.Since(begin) + sql, rows := fc() + fields := []slog.Attr{ + slog.String("duration", fmt.Sprintf("%.3fms", float64(elapsed.Nanoseconds())/1e6)), + slog.String("sql", sql), + } + + if rows != -1 { + fields = append(fields, slog.Int64("rows", rows)) + } + + switch { + case err != nil && (!l.IgnoreRecordNotFoundError || !errors.Is(err, ErrRecordNotFound)): + fields = append(fields, slog.String("error", err.Error())) + l.log(ctx, slog.LevelError, "SQL executed", slog.Attr{ + Key: "trace", + Value: slog.GroupValue(fields...), + }) + + case l.SlowThreshold != 0 && elapsed > l.SlowThreshold: + l.log(ctx, slog.LevelWarn, "SQL executed", slog.Attr{ + Key: "trace", + Value: slog.GroupValue(fields...), + }) + + case l.LogLevel >= Info: + l.log(ctx, slog.LevelInfo, "SQL executed", slog.Attr{ + Key: "trace", + Value: slog.GroupValue(fields...), + }) + } +} + +func (l *slogLogger) log(ctx context.Context, level slog.Level, msg string, args ...any) { + if ctx == nil { + ctx = context.Background() + } + + if !l.Logger.Enabled(ctx, level) { + return + } + + r := slog.NewRecord(time.Now(), level, msg, utils.CallerFrame().PC) + r.Add(args...) + _ = l.Logger.Handler().Handle(ctx, r) +} + +// ParamsFilter filter params +func (l *slogLogger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) { + if l.Parameterized { + return sql, nil + } + return sql, params +} diff --git a/policy-enforcer/vendor/gorm.io/gorm/migrator/migrator.go b/policy-enforcer/vendor/gorm.io/gorm/migrator/migrator.go index cec4e30f..35107d57 100644 --- a/policy-enforcer/vendor/gorm.io/gorm/migrator/migrator.go +++ b/policy-enforcer/vendor/gorm.io/gorm/migrator/migrator.go @@ -474,7 +474,6 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy // found, smart migrate fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)) realDataType := strings.ToLower(columnType.DatabaseTypeName()) - var ( alterColumn bool isSameType = fullDataType == realDataType @@ -513,8 +512,19 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } } } + } - // check precision + // check precision + if realDataType == "decimal" || realDataType == "numeric" && + regexp.MustCompile(realDataType+`\(.*\)`).FindString(fullDataType) != "" { // if realDataType has no precision,ignore + precision, scale, ok := columnType.DecimalSize() + if ok { + if !strings.HasPrefix(fullDataType, fmt.Sprintf("%s(%d,%d)", realDataType, precision, scale)) && + !strings.HasPrefix(fullDataType, fmt.Sprintf("%s(%d)", realDataType, precision)) { + alterColumn = true + } + } + } else { if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) { alterColumn = true @@ -550,6 +560,10 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy v1, _ := strconv.ParseBool(dv) v2, _ := strconv.ParseBool(field.DefaultValue) alterColumn = v1 != v2 + case schema.String: + if dv != field.DefaultValue && dv != strings.Trim(field.DefaultValue, "'\"") { + alterColumn = true + } default: alterColumn = dv != field.DefaultValue } diff --git a/policy-enforcer/vendor/gorm.io/gorm/schema/field.go b/policy-enforcer/vendor/gorm.io/gorm/schema/field.go index a6ff1a72..de797402 100644 --- a/policy-enforcer/vendor/gorm.io/gorm/schema/field.go +++ b/policy-enforcer/vendor/gorm.io/gorm/schema/field.go @@ -448,16 +448,17 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } // create valuer, setter when parse struct -func (field *Field) setupValuerAndSetter() { +func (field *Field) setupValuerAndSetter(modelType reflect.Type) { // Setup NewValuePool field.setupNewValuePool() // ValueOf returns field's value and if it is zero fieldIndex := field.StructField.Index[0] switch { - case len(field.StructField.Index) == 1 && fieldIndex > 0: - field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) { - fieldValue := reflect.Indirect(value).Field(fieldIndex) + case len(field.StructField.Index) == 1 && fieldIndex >= 0: + field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { + v = reflect.Indirect(v) + fieldValue := v.Field(fieldIndex) return fieldValue.Interface(), fieldValue.IsZero() } default: @@ -504,9 +505,10 @@ func (field *Field) setupValuerAndSetter() { // ReflectValueOf returns field's reflect value switch { - case len(field.StructField.Index) == 1 && fieldIndex > 0: - field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value { - return reflect.Indirect(value).Field(fieldIndex) + case len(field.StructField.Index) == 1 && fieldIndex >= 0: + field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { + v = reflect.Indirect(v) + return v.Field(fieldIndex) } default: field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { diff --git a/policy-enforcer/vendor/gorm.io/gorm/schema/relationship.go b/policy-enforcer/vendor/gorm.io/gorm/schema/relationship.go index f1ace924..0535bba4 100644 --- a/policy-enforcer/vendor/gorm.io/gorm/schema/relationship.go +++ b/policy-enforcer/vendor/gorm.io/gorm/schema/relationship.go @@ -75,9 +75,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { } ) - cacheStore := schema.cacheStore - - if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil { + if relation.FieldSchema, err = getOrParse(fieldValue, schema.cacheStore, schema.namer); err != nil { schema.err = fmt.Errorf("failed to parse field: %s, error: %w", field.Name, err) return nil } @@ -147,6 +145,9 @@ func hasPolymorphicRelation(tagSettings map[string]string) bool { } func (schema *Schema) setRelation(relation *Relationship) { + schema.Relationships.Mux.Lock() + defer schema.Relationships.Mux.Unlock() + // set non-embedded relation if rel := schema.Relationships.Relations[relation.Name]; rel != nil { if len(rel.Field.BindNames) > 1 { @@ -590,6 +591,10 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu // build references for idx, foreignField := range foreignFields { // use same data type for foreign keys + schema.Relationships.Mux.Lock() + if schema != foreignField.Schema { + foreignField.Schema.Relationships.Mux.Lock() + } if copyableDataType(primaryFields[idx].DataType) { foreignField.DataType = primaryFields[idx].DataType } @@ -597,6 +602,10 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu if foreignField.Size == 0 { foreignField.Size = primaryFields[idx].Size } + schema.Relationships.Mux.Unlock() + if schema != foreignField.Schema { + foreignField.Schema.Relationships.Mux.Unlock() + } relation.References = append(relation.References, &Reference{ PrimaryKey: primaryFields[idx], diff --git a/policy-enforcer/vendor/gorm.io/gorm/schema/schema.go b/policy-enforcer/vendor/gorm.io/gorm/schema/schema.go index db236797..09697a7a 100644 --- a/policy-enforcer/vendor/gorm.io/gorm/schema/schema.go +++ b/policy-enforcer/vendor/gorm.io/gorm/schema/schema.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "go/ast" + "path" "reflect" "strings" "sync" @@ -59,14 +60,14 @@ type Schema struct { cacheStore *sync.Map } -func (schema Schema) String() string { +func (schema *Schema) String() string { if schema.ModelType.Name() == "" { return fmt.Sprintf("%s(%s)", schema.Name, schema.Table) } return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), schema.ModelType.Name()) } -func (schema Schema) MakeSlice() reflect.Value { +func (schema *Schema) MakeSlice() reflect.Value { slice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(schema.ModelType)), 0, 20) results := reflect.New(slice.Type()) results.Elem().Set(slice) @@ -74,13 +75,23 @@ func (schema Schema) MakeSlice() reflect.Value { return results } -func (schema Schema) LookUpField(name string) *Field { +func (schema *Schema) LookUpField(name string) *Field { if field, ok := schema.FieldsByDBName[name]; ok { return field } if field, ok := schema.FieldsByName[name]; ok { return field } + + // Lookup field using namer-driven ColumnName + if schema.namer == nil { + return nil + } + namerColumnName := schema.namer.ColumnName(schema.Table, name) + if field, ok := schema.FieldsByDBName[namerColumnName]; ok { + return field + } + return nil } @@ -92,10 +103,7 @@ func (schema Schema) LookUpField(name string) *Field { // } // ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID") // } -func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field { - if len(bindNames) == 0 { - return nil - } +func (schema *Schema) LookUpFieldByBindName(bindNames []string, name string) *Field { for i := len(bindNames) - 1; i >= 0; i-- { find := strings.Join(bindNames[:i], ".") + "." + name if field, ok := schema.FieldsByBindName[find]; ok { @@ -113,6 +121,14 @@ type TablerWithNamer interface { TableName(Namer) string } +var callbackTypes = []callbackType{ + callbackTypeBeforeCreate, callbackTypeAfterCreate, + callbackTypeBeforeUpdate, callbackTypeAfterUpdate, + callbackTypeBeforeSave, callbackTypeAfterSave, + callbackTypeBeforeDelete, callbackTypeAfterDelete, + callbackTypeAfterFind, +} + // Parse get data type from dialector func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { return ParseWithSpecialTableName(dest, cacheStore, namer, "") @@ -124,34 +140,33 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } - value := reflect.ValueOf(dest) - if value.Kind() == reflect.Ptr && value.IsNil() { - value = reflect.New(value.Type().Elem()) - } - modelType := reflect.Indirect(value).Type() - - if modelType.Kind() == reflect.Interface { - modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type() - } - - for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType := reflect.ValueOf(dest).Type() + if modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() } if modelType.Kind() != reflect.Struct { - if modelType.PkgPath() == "" { - return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + if modelType.Kind() == reflect.Interface { + modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type() + } + + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType.Kind() != reflect.Struct { + if modelType.PkgPath() == "" { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } - return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } // Cache the Schema for performance, // Use the modelType or modelType + schemaTable (if it present) as cache key. - var schemaCacheKey interface{} + var schemaCacheKey interface{} = modelType if specialTableName != "" { schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName) - } else { - schemaCacheKey = modelType } // Load exist schema cache, return if exists @@ -162,28 +177,29 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam return s, s.err } + var tableName string modelValue := reflect.New(modelType) - tableName := namer.TableName(modelType.Name()) - if tabler, ok := modelValue.Interface().(Tabler); ok { + if specialTableName != "" { + tableName = specialTableName + } else if en, ok := namer.(embeddedNamer); ok { + tableName = en.Table + } else if tabler, ok := modelValue.Interface().(Tabler); ok { tableName = tabler.TableName() - } - if tabler, ok := modelValue.Interface().(TablerWithNamer); ok { + } else if tabler, ok := modelValue.Interface().(TablerWithNamer); ok { tableName = tabler.TableName(namer) - } - if en, ok := namer.(embeddedNamer); ok { - tableName = en.Table - } - if specialTableName != "" && specialTableName != tableName { - tableName = specialTableName + } else { + tableName = namer.TableName(modelType.Name()) } schema := &Schema{ Name: modelType.Name(), ModelType: modelType, Table: tableName, - FieldsByName: map[string]*Field{}, - FieldsByBindName: map[string]*Field{}, - FieldsByDBName: map[string]*Field{}, + DBNames: make([]string, 0, 10), + Fields: make([]*Field, 0, 10), + FieldsByName: make(map[string]*Field, 10), + FieldsByBindName: make(map[string]*Field, 10), + FieldsByDBName: make(map[string]*Field, 10), Relationships: Relationships{Relations: map[string]*Relationship{}}, cacheStore: cacheStore, namer: namer, @@ -227,8 +243,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam schema.FieldsByBindName[bindName] = field if v != nil && v.PrimaryKey { + // remove the existing primary key field for idx, f := range schema.PrimaryFields { - if f == v { + if f.DBName == v.DBName { schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...) } } @@ -247,7 +264,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam schema.FieldsByBindName[bindName] = field } - field.setupValuerAndSetter() + field.setupValuerAndSetter(modelType) } prioritizedPrimaryField := schema.LookUpField("id") @@ -283,10 +300,37 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName) } + _, embedded := schema.cacheStore.Load(embeddedCacheKey) + relationshipFields := []*Field{} for _, field := range schema.Fields { if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil { schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) } + + if !embedded { + if field.DataType == "" && field.GORMDataType == "" && (field.Creatable || field.Updatable || field.Readable) { + relationshipFields = append(relationshipFields, field) + schema.FieldsByName[field.Name] = field + schema.FieldsByBindName[field.BindName()] = field + } + + fieldValue := reflect.New(field.IndirectFieldType).Interface() + if fc, ok := fieldValue.(CreateClausesInterface); ok { + field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) + } + + if fc, ok := fieldValue.(QueryClausesInterface); ok { + field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) + } + + if fc, ok := fieldValue.(UpdateClausesInterface); ok { + field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) + } + + if fc, ok := fieldValue.(DeleteClausesInterface); ok { + field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) + } + } } if field := schema.PrioritizedPrimaryField; field != nil { @@ -303,24 +347,6 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } } - callbackTypes := []callbackType{ - callbackTypeBeforeCreate, callbackTypeAfterCreate, - callbackTypeBeforeUpdate, callbackTypeAfterUpdate, - callbackTypeBeforeSave, callbackTypeAfterSave, - callbackTypeBeforeDelete, callbackTypeAfterDelete, - callbackTypeAfterFind, - } - for _, cbName := range callbackTypes { - if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() { - switch methodValue.Type().String() { - case "func(*gorm.DB) error": // TODO hack - reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true) - default: - logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName) - } - } - } - // Cache the schema if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded { s := v.(*Schema) @@ -336,84 +362,47 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } }() - if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { - for _, field := range schema.Fields { - if field.DataType == "" && field.GORMDataType == "" && (field.Creatable || field.Updatable || field.Readable) { - if schema.parseRelation(field); schema.err != nil { - return schema, schema.err + for _, cbName := range callbackTypes { + if methodValue := modelValue.MethodByName(string(cbName)); methodValue.IsValid() { + switch methodValue.Type().String() { + case "func(*gorm.DB) error": + expectedPkgPath := path.Dir(reflect.TypeOf(schema).Elem().PkgPath()) + if inVarPkg := methodValue.Type().In(0).Elem().PkgPath(); inVarPkg == expectedPkgPath { + reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true) } else { - schema.FieldsByName[field.Name] = field - schema.FieldsByBindName[field.BindName()] = field + logger.Default.Warn(context.Background(), "In model %v, the hook function `%v(*gorm.DB) error` has an incorrect parameter type. The expected parameter type is `%v`, but the provided type is `%v`.", schema, cbName, expectedPkgPath, inVarPkg) + // PASS } + default: + logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName) } + } + } - fieldValue := reflect.New(field.IndirectFieldType) - fieldInterface := fieldValue.Interface() - if fc, ok := fieldInterface.(CreateClausesInterface); ok { - field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) - } - - if fc, ok := fieldInterface.(QueryClausesInterface); ok { - field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) - } - - if fc, ok := fieldInterface.(UpdateClausesInterface); ok { - field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) - } - - if fc, ok := fieldInterface.(DeleteClausesInterface); ok { - field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) - } + // parse relationships + for _, field := range relationshipFields { + if schema.parseRelation(field); schema.err != nil { + return schema, schema.err } } return schema, schema.err } -// This unrolling is needed to show to the compiler the exact set of methods -// that can be used on the modelType. -// Prior to go1.22 any use of MethodByName would cause the linker to -// abandon dead code elimination for the entire binary. -// As of go1.22 the compiler supports one special case of a string constant -// being passed to MethodByName. For enterprise customers or those building -// large binaries, this gives a significant reduction in binary size. -// https://github.com/golang/go/issues/62257 -func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value { - switch cbType { - case callbackTypeBeforeCreate: - return modelType.MethodByName(string(callbackTypeBeforeCreate)) - case callbackTypeAfterCreate: - return modelType.MethodByName(string(callbackTypeAfterCreate)) - case callbackTypeBeforeUpdate: - return modelType.MethodByName(string(callbackTypeBeforeUpdate)) - case callbackTypeAfterUpdate: - return modelType.MethodByName(string(callbackTypeAfterUpdate)) - case callbackTypeBeforeSave: - return modelType.MethodByName(string(callbackTypeBeforeSave)) - case callbackTypeAfterSave: - return modelType.MethodByName(string(callbackTypeAfterSave)) - case callbackTypeBeforeDelete: - return modelType.MethodByName(string(callbackTypeBeforeDelete)) - case callbackTypeAfterDelete: - return modelType.MethodByName(string(callbackTypeAfterDelete)) - case callbackTypeAfterFind: - return modelType.MethodByName(string(callbackTypeAfterFind)) - default: - return reflect.ValueOf(nil) - } -} - func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { modelType := reflect.ValueOf(dest).Type() - for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { - modelType = modelType.Elem() - } if modelType.Kind() != reflect.Struct { - if modelType.PkgPath() == "" { - return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType.Kind() != reflect.Struct { + if modelType.PkgPath() == "" { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } - return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } if v, ok := cacheStore.Load(modelType); ok { diff --git a/policy-enforcer/vendor/gorm.io/gorm/schema/serializer.go b/policy-enforcer/vendor/gorm.io/gorm/schema/serializer.go index 0fafbcba..d774a8ed 100644 --- a/policy-enforcer/vendor/gorm.io/gorm/schema/serializer.go +++ b/policy-enforcer/vendor/gorm.io/gorm/schema/serializer.go @@ -8,6 +8,7 @@ import ( "encoding/gob" "encoding/json" "fmt" + "math" "reflect" "strings" "sync" @@ -127,16 +128,31 @@ func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect. // Value implements serializer interface func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) { rv := reflect.ValueOf(fieldValue) - switch v := fieldValue.(type) { - case int64, int, uint, uint64, int32, uint32, int16, uint16: - result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC() - case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16: + switch fieldValue.(type) { + case int, int8, int16, int32, int64: + result = time.Unix(rv.Int(), 0).UTC() + case uint, uint8, uint16, uint32, uint64: + if uv := rv.Uint(); uv > math.MaxInt64 { + err = fmt.Errorf("integer overflow conversion uint64(%d) -> int64", uv) + } else { + result = time.Unix(int64(uv), 0).UTC() //nolint:gosec + } + case *int, *int8, *int16, *int32, *int64: if rv.IsZero() { return nil, nil } - result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC() + result = time.Unix(rv.Elem().Int(), 0).UTC() + case *uint, *uint8, *uint16, *uint32, *uint64: + if rv.IsZero() { + return nil, nil + } + if uv := rv.Elem().Uint(); uv > math.MaxInt64 { + err = fmt.Errorf("integer overflow conversion uint64(%d) -> int64", uv) + } else { + result = time.Unix(int64(uv), 0).UTC() //nolint:gosec + } default: - err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v) + err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", fieldValue) } return } diff --git a/policy-enforcer/vendor/gorm.io/gorm/schema/utils.go b/policy-enforcer/vendor/gorm.io/gorm/schema/utils.go index fa1c65d4..86305d7b 100644 --- a/policy-enforcer/vendor/gorm.io/gorm/schema/utils.go +++ b/policy-enforcer/vendor/gorm.io/gorm/schema/utils.go @@ -17,25 +17,23 @@ func ParseTagSetting(str string, sep string) map[string]string { settings := map[string]string{} names := strings.Split(str, sep) + var parsedNames []string for i := 0; i < len(names); i++ { - j := i - if len(names[j]) > 0 { - for { - if names[j][len(names[j])-1] == '\\' { - i++ - names[j] = names[j][0:len(names[j])-1] + sep + names[i] - names[i] = "" - } else { - break - } - } + s := names[i] + for strings.HasSuffix(s, "\\") && i+1 < len(names) { + i++ + s = s[:len(s)-1] + sep + names[i] } + parsedNames = append(parsedNames, s) + } - values := strings.Split(names[j], ":") + for _, tag := range parsedNames { + values := strings.Split(tag, ":") k := strings.TrimSpace(strings.ToUpper(values[0])) - if len(values) >= 2 { - settings[k] = strings.Join(values[1:], ":") + val := strings.Join(values[1:], ":") + val = strings.ReplaceAll(val, `\"`, `"`) + settings[k] = val } else if k != "" { settings[k] = k } @@ -121,6 +119,17 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, } switch reflectValue.Kind() { + case reflect.Map: + results = [][]interface{}{make([]interface{}, len(fields))} + for idx, field := range fields { + mapValue := reflectValue.MapIndex(reflect.ValueOf(field.DBName)) + if mapValue.IsZero() { + mapValue = reflectValue.MapIndex(reflect.ValueOf(field.Name)) + } + results[0][idx] = mapValue.Interface() + } + + dataResults[utils.ToStringKey(results[0]...)] = []reflect.Value{reflectValue} case reflect.Struct: results = [][]interface{}{make([]interface{}, len(fields))} diff --git a/policy-enforcer/vendor/gorm.io/gorm/statement.go b/policy-enforcer/vendor/gorm.io/gorm/statement.go index c6183724..736087d7 100644 --- a/policy-enforcer/vendor/gorm.io/gorm/statement.go +++ b/policy-enforcer/vendor/gorm.io/gorm/statement.go @@ -96,7 +96,9 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { if v.Name == clause.CurrentTable { if stmt.TableExpr != nil { stmt.TableExpr.Build(stmt) - } else { + } else if stmt.Table != "" { + write(v.Raw, stmt.Table) + } else if stmt.AddError(stmt.Parse(stmt.Model)) == nil { write(v.Raw, stmt.Table) } } else { @@ -334,6 +336,8 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] switch v := arg.(type) { case clause.Expression: conds = append(conds, v) + case []clause.Expression: + conds = append(conds, v...) case *DB: v.executeScopes() @@ -341,7 +345,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] if where, ok := cs.Expression.(clause.Where); ok { if len(where.Exprs) == 1 { if orConds, ok := where.Exprs[0].(clause.OrConditions); ok { - where.Exprs[0] = clause.AndConditions(orConds) + if len(orConds.Exprs) == 1 { + where.Exprs[0] = clause.AndConditions(orConds) + } } } conds = append(conds, clause.And(where.Exprs...)) @@ -362,6 +368,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] for _, key := range keys { column := clause.Column{Name: key, Table: curTable} + if strings.Contains(key, ".") { + column = clause.Column{Name: key} + } conds = append(conds, clause.Eq{Column: column, Value: v[key]}) } case map[string]interface{}: @@ -374,6 +383,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] for _, key := range keys { reflectValue := reflect.Indirect(reflect.ValueOf(v[key])) column := clause.Column{Name: key, Table: curTable} + if strings.Contains(key, ".") { + column = clause.Column{Name: key} + } switch reflectValue.Kind() { case reflect.Slice, reflect.Array: if _, ok := v[key].(driver.Valuer); ok { @@ -650,12 +662,15 @@ func (stmt *Statement) Changed(fields ...string) bool { for destValue.Kind() == reflect.Ptr { destValue = destValue.Elem() } - - changedValue, zero := field.ValueOf(stmt.Context, destValue) - if v { - return !utils.AssertEqual(changedValue, fieldValue) + if descSchema, err := schema.Parse(stmt.Dest, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { + if destField := descSchema.LookUpField(field.DBName); destField != nil { + changedValue, zero := destField.ValueOf(stmt.Context, destValue) + if v { + return !utils.AssertEqual(changedValue, fieldValue) + } + return !zero && !utils.AssertEqual(changedValue, fieldValue) + } } - return !zero && !utils.AssertEqual(changedValue, fieldValue) } } return false diff --git a/policy-enforcer/vendor/gorm.io/gorm/utils/utils.go b/policy-enforcer/vendor/gorm.io/gorm/utils/utils.go index fc615d73..7e59264b 100644 --- a/policy-enforcer/vendor/gorm.io/gorm/utils/utils.go +++ b/policy-enforcer/vendor/gorm.io/gorm/utils/utils.go @@ -30,8 +30,12 @@ func sourceDir(file string) string { return filepath.ToSlash(s) + "/" } -// FileWithLineNum return the file name and line number of the current file -func FileWithLineNum() string { +// CallerFrame retrieves the first relevant stack frame outside of GORM's internal implementation files. +// It skips: +// - GORM's core source files (identified by gormSourceDir prefix) +// - Exclude test files (*_test.go) +// - go-gorm/gen's Generated files (*.gen.go) +func CallerFrame() runtime.Frame { pcs := [13]uintptr{} // the third caller usually from gorm internal len := runtime.Callers(3, pcs[:]) @@ -41,14 +45,24 @@ func FileWithLineNum() string { frame, _ := frames.Next() if (!strings.HasPrefix(frame.File, gormSourceDir) || strings.HasSuffix(frame.File, "_test.go")) && !strings.HasSuffix(frame.File, ".gen.go") { - return string(strconv.AppendInt(append([]byte(frame.File), ':'), int64(frame.Line), 10)) + return frame } } + return runtime.Frame{} +} + +// FileWithLineNum return the file name and line number of the current file +func FileWithLineNum() string { + frame := CallerFrame() + if frame.PC != 0 { + return string(strconv.AppendInt(append([]byte(frame.File), ':'), int64(frame.Line), 10)) + } + return "" } -func IsValidDBNameChar(c rune) bool { +func IsInvalidDBNameChar(c rune) bool { return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@' } diff --git a/policy-enforcer/vendor/modules.txt b/policy-enforcer/vendor/modules.txt index f8a6f534..65cd2a47 100644 --- a/policy-enforcer/vendor/modules.txt +++ b/policy-enforcer/vendor/modules.txt @@ -1269,7 +1269,7 @@ gopkg.in/yaml.v3 # gorm.io/driver/postgres v1.6.0 ## explicit; go 1.20 gorm.io/driver/postgres -# gorm.io/gorm v1.30.0 +# gorm.io/gorm v1.31.1 ## explicit; go 1.18 gorm.io/gorm gorm.io/gorm/callbacks diff --git a/public-api/go.mod b/public-api/go.mod index 0500d5b5..45e35706 100644 --- a/public-api/go.mod +++ b/public-api/go.mod @@ -18,7 +18,7 @@ require ( google.golang.org/grpc v1.75.1 google.golang.org/protobuf v1.36.10 gorm.io/driver/postgres v1.6.0 - gorm.io/gorm v1.31.0 + gorm.io/gorm v1.31.1 ) require ( diff --git a/public-api/go.sum b/public-api/go.sum index d1e0336f..1c1dc6c4 100644 --- a/public-api/go.sum +++ b/public-api/go.sum @@ -658,8 +658,8 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4= gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo= -gorm.io/gorm v1.31.0 h1:0VlycGreVhK7RF/Bwt51Fk8v0xLiiiFdbGDPIZQ7mJY= -gorm.io/gorm v1.31.0/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= +gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= +gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= honnef.co/go/tools v0.6.1 h1:R094WgE8K4JirYjBaOpz/AvTyUu/3wbmAoskKN/pxTI= honnef.co/go/tools v0.6.1/go.mod h1:3puzxxljPCe8RGJX7BIy1plGbxEOZni5mR2aXe3/uk4= mvdan.cc/gofumpt v0.7.0 h1:bg91ttqXmi9y2xawvkuMXyvAA/1ZGJqYAEGjXuP0JXU= diff --git a/public-api/vendor/gorm.io/gorm/README.md b/public-api/vendor/gorm.io/gorm/README.md index 745dad60..24eb84c9 100644 --- a/public-api/vendor/gorm.io/gorm/README.md +++ b/public-api/vendor/gorm.io/gorm/README.md @@ -3,7 +3,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. [![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm) -[![test status](https://github.com/go-gorm/gorm/workflows/tests/badge.svg?branch=master "test status")](https://github.com/go-gorm/gorm/actions) +[![test status](https://github.com/go-gorm/gorm/actions/workflows/tests.yml/badge.svg)](https://github.com/go-gorm/gorm/actions) [![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT) [![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gorm?tab=doc) diff --git a/public-api/vendor/gorm.io/gorm/association.go b/public-api/vendor/gorm.io/gorm/association.go index f210ca0a..3a4e0e25 100644 --- a/public-api/vendor/gorm.io/gorm/association.go +++ b/public-api/vendor/gorm.io/gorm/association.go @@ -99,7 +99,7 @@ func (association *Association) Replace(values ...interface{}) error { return association.Error } - // set old associations's foreign key to null + // set old association's foreign key to null switch rel.Type { case schema.BelongsTo: if len(values) == 0 { @@ -304,7 +304,7 @@ func (association *Association) Delete(values ...interface{}) error { } if association.Error == nil { - // clean up deleted values's foreign key + // clean up deleted values' foreign key relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields) cleanUpDeletedRelations := func(data reflect.Value) { diff --git a/public-api/vendor/gorm.io/gorm/chainable_api.go b/public-api/vendor/gorm.io/gorm/chainable_api.go index 8a6aea34..8f6113cc 100644 --- a/public-api/vendor/gorm.io/gorm/chainable_api.go +++ b/public-api/vendor/gorm.io/gorm/chainable_api.go @@ -178,7 +178,7 @@ func (db *DB) Omit(columns ...string) (tx *DB) { tx = db.getInstance() if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { - tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar) + tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsInvalidDBNameChar) } else { tx.Statement.Omits = columns } @@ -283,7 +283,7 @@ func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) func (db *DB) Group(name string) (tx *DB) { tx = db.getInstance() - fields := strings.FieldsFunc(name, utils.IsValidDBNameChar) + fields := strings.FieldsFunc(name, utils.IsInvalidDBNameChar) tx.Statement.AddClause(clause.GroupBy{ Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}}, }) diff --git a/public-api/vendor/gorm.io/gorm/finisher_api.go b/public-api/vendor/gorm.io/gorm/finisher_api.go index e601fe66..e9e35f1b 100644 --- a/public-api/vendor/gorm.io/gorm/finisher_api.go +++ b/public-api/vendor/gorm.io/gorm/finisher_api.go @@ -465,7 +465,7 @@ func (db *DB) Count(count *int64) (tx *DB) { if len(tx.Statement.Selects) == 1 { dbName := tx.Statement.Selects[0] - fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar) + fields := strings.FieldsFunc(dbName, utils.IsInvalidDBNameChar) if len(fields) == 1 || (len(fields) == 3 && (strings.ToUpper(fields[1]) == "AS" || fields[1] == ".")) { if tx.Statement.Parse(tx.Statement.Model) == nil { if f := tx.Statement.Schema.LookUpField(dbName); f != nil { @@ -564,7 +564,7 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { } if len(tx.Statement.Selects) != 1 { - fields := strings.FieldsFunc(column, utils.IsValidDBNameChar) + fields := strings.FieldsFunc(column, utils.IsInvalidDBNameChar) tx.Statement.AddClauseIfNotExists(clause.Select{ Distinct: tx.Statement.Distinct, Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}}, diff --git a/public-api/vendor/gorm.io/gorm/generics.go b/public-api/vendor/gorm.io/gorm/generics.go index 79238d5f..166d1520 100644 --- a/public-api/vendor/gorm.io/gorm/generics.go +++ b/public-api/vendor/gorm.io/gorm/generics.go @@ -39,7 +39,7 @@ type Interface[T any] interface { type CreateInterface[T any] interface { ExecInterface[T] - // chain methods available at start; return ChainInterface + // chain methods available at start; Select/Omit keep CreateInterface to allow Create chaining Scopes(scopes ...func(db *Statement)) ChainInterface[T] Where(query interface{}, args ...interface{}) ChainInterface[T] Not(query interface{}, args ...interface{}) ChainInterface[T] @@ -48,8 +48,8 @@ type CreateInterface[T any] interface { Offset(offset int) ChainInterface[T] Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] - Select(query string, args ...interface{}) ChainInterface[T] - Omit(columns ...string) ChainInterface[T] + Select(query string, args ...interface{}) CreateInterface[T] + Omit(columns ...string) CreateInterface[T] MapColumns(m map[string]string) ChainInterface[T] Distinct(args ...interface{}) ChainInterface[T] Group(name string) ChainInterface[T] @@ -203,6 +203,18 @@ func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] { })} } +func (c createG[T]) Select(query string, args ...interface{}) CreateInterface[T] { + return createG[T]{c.with(func(db *DB) *DB { + return db.Select(query, args...) + })} +} + +func (c createG[T]) Omit(columns ...string) CreateInterface[T] { + return createG[T]{c.with(func(db *DB) *DB { + return db.Omit(columns...) + })} +} + func (c createG[T]) Set(assignments ...clause.Assigner) SetCreateOrUpdateInterface[T] { return c.processSet(assignments...) } diff --git a/public-api/vendor/gorm.io/gorm/logger/slog.go b/public-api/vendor/gorm.io/gorm/logger/slog.go index 613234ca..27c74683 100644 --- a/public-api/vendor/gorm.io/gorm/logger/slog.go +++ b/public-api/vendor/gorm.io/gorm/logger/slog.go @@ -8,6 +8,8 @@ import ( "fmt" "log/slog" "time" + + "gorm.io/gorm/utils" ) type slogLogger struct { @@ -37,19 +39,19 @@ func (l *slogLogger) LogMode(level LogLevel) Interface { func (l *slogLogger) Info(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Info { - l.Logger.InfoContext(ctx, msg, slog.Any("data", data)) + l.log(ctx, slog.LevelInfo, msg, slog.Any("data", data)) } } func (l *slogLogger) Warn(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Warn { - l.Logger.WarnContext(ctx, msg, slog.Any("data", data)) + l.log(ctx, slog.LevelWarn, msg, slog.Any("data", data)) } } func (l *slogLogger) Error(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Error { - l.Logger.ErrorContext(ctx, msg, slog.Any("data", data)) + l.log(ctx, slog.LevelError, msg, slog.Any("data", data)) } } @@ -72,25 +74,39 @@ func (l *slogLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql switch { case err != nil && (!l.IgnoreRecordNotFoundError || !errors.Is(err, ErrRecordNotFound)): fields = append(fields, slog.String("error", err.Error())) - l.Logger.ErrorContext(ctx, "SQL executed", slog.Attr{ + l.log(ctx, slog.LevelError, "SQL executed", slog.Attr{ Key: "trace", Value: slog.GroupValue(fields...), }) case l.SlowThreshold != 0 && elapsed > l.SlowThreshold: - l.Logger.WarnContext(ctx, "SQL executed", slog.Attr{ + l.log(ctx, slog.LevelWarn, "SQL executed", slog.Attr{ Key: "trace", Value: slog.GroupValue(fields...), }) case l.LogLevel >= Info: - l.Logger.InfoContext(ctx, "SQL executed", slog.Attr{ + l.log(ctx, slog.LevelInfo, "SQL executed", slog.Attr{ Key: "trace", Value: slog.GroupValue(fields...), }) } } +func (l *slogLogger) log(ctx context.Context, level slog.Level, msg string, args ...any) { + if ctx == nil { + ctx = context.Background() + } + + if !l.Logger.Enabled(ctx, level) { + return + } + + r := slog.NewRecord(time.Now(), level, msg, utils.CallerFrame().PC) + r.Add(args...) + _ = l.Logger.Handler().Handle(ctx, r) +} + // ParamsFilter filter params func (l *slogLogger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) { if l.Parameterized { diff --git a/public-api/vendor/gorm.io/gorm/migrator/migrator.go b/public-api/vendor/gorm.io/gorm/migrator/migrator.go index 50a36d10..35107d57 100644 --- a/public-api/vendor/gorm.io/gorm/migrator/migrator.go +++ b/public-api/vendor/gorm.io/gorm/migrator/migrator.go @@ -560,6 +560,10 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy v1, _ := strconv.ParseBool(dv) v2, _ := strconv.ParseBool(field.DefaultValue) alterColumn = v1 != v2 + case schema.String: + if dv != field.DefaultValue && dv != strings.Trim(field.DefaultValue, "'\"") { + alterColumn = true + } default: alterColumn = dv != field.DefaultValue } diff --git a/public-api/vendor/gorm.io/gorm/schema/schema.go b/public-api/vendor/gorm.io/gorm/schema/schema.go index 9419846b..09697a7a 100644 --- a/public-api/vendor/gorm.io/gorm/schema/schema.go +++ b/public-api/vendor/gorm.io/gorm/schema/schema.go @@ -82,6 +82,16 @@ func (schema *Schema) LookUpField(name string) *Field { if field, ok := schema.FieldsByName[name]; ok { return field } + + // Lookup field using namer-driven ColumnName + if schema.namer == nil { + return nil + } + namerColumnName := schema.namer.ColumnName(schema.Table, name) + if field, ok := schema.FieldsByDBName[namerColumnName]; ok { + return field + } + return nil } diff --git a/public-api/vendor/gorm.io/gorm/schema/serializer.go b/public-api/vendor/gorm.io/gorm/schema/serializer.go index 0fafbcba..d774a8ed 100644 --- a/public-api/vendor/gorm.io/gorm/schema/serializer.go +++ b/public-api/vendor/gorm.io/gorm/schema/serializer.go @@ -8,6 +8,7 @@ import ( "encoding/gob" "encoding/json" "fmt" + "math" "reflect" "strings" "sync" @@ -127,16 +128,31 @@ func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect. // Value implements serializer interface func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) { rv := reflect.ValueOf(fieldValue) - switch v := fieldValue.(type) { - case int64, int, uint, uint64, int32, uint32, int16, uint16: - result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC() - case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16: + switch fieldValue.(type) { + case int, int8, int16, int32, int64: + result = time.Unix(rv.Int(), 0).UTC() + case uint, uint8, uint16, uint32, uint64: + if uv := rv.Uint(); uv > math.MaxInt64 { + err = fmt.Errorf("integer overflow conversion uint64(%d) -> int64", uv) + } else { + result = time.Unix(int64(uv), 0).UTC() //nolint:gosec + } + case *int, *int8, *int16, *int32, *int64: if rv.IsZero() { return nil, nil } - result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC() + result = time.Unix(rv.Elem().Int(), 0).UTC() + case *uint, *uint8, *uint16, *uint32, *uint64: + if rv.IsZero() { + return nil, nil + } + if uv := rv.Elem().Uint(); uv > math.MaxInt64 { + err = fmt.Errorf("integer overflow conversion uint64(%d) -> int64", uv) + } else { + result = time.Unix(int64(uv), 0).UTC() //nolint:gosec + } default: - err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v) + err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", fieldValue) } return } diff --git a/public-api/vendor/gorm.io/gorm/schema/utils.go b/public-api/vendor/gorm.io/gorm/schema/utils.go index d4fe252e..86305d7b 100644 --- a/public-api/vendor/gorm.io/gorm/schema/utils.go +++ b/public-api/vendor/gorm.io/gorm/schema/utils.go @@ -17,25 +17,23 @@ func ParseTagSetting(str string, sep string) map[string]string { settings := map[string]string{} names := strings.Split(str, sep) + var parsedNames []string for i := 0; i < len(names); i++ { - j := i - if len(names[j]) > 0 { - for { - if names[j][len(names[j])-1] == '\\' { - i++ - names[j] = names[j][0:len(names[j])-1] + sep + names[i] - names[i] = "" - } else { - break - } - } + s := names[i] + for strings.HasSuffix(s, "\\") && i+1 < len(names) { + i++ + s = s[:len(s)-1] + sep + names[i] } + parsedNames = append(parsedNames, s) + } - values := strings.Split(names[j], ":") + for _, tag := range parsedNames { + values := strings.Split(tag, ":") k := strings.TrimSpace(strings.ToUpper(values[0])) - if len(values) >= 2 { - settings[k] = strings.Join(values[1:], ":") + val := strings.Join(values[1:], ":") + val = strings.ReplaceAll(val, `\"`, `"`) + settings[k] = val } else if k != "" { settings[k] = k } diff --git a/public-api/vendor/gorm.io/gorm/utils/utils.go b/public-api/vendor/gorm.io/gorm/utils/utils.go index fc615d73..7e59264b 100644 --- a/public-api/vendor/gorm.io/gorm/utils/utils.go +++ b/public-api/vendor/gorm.io/gorm/utils/utils.go @@ -30,8 +30,12 @@ func sourceDir(file string) string { return filepath.ToSlash(s) + "/" } -// FileWithLineNum return the file name and line number of the current file -func FileWithLineNum() string { +// CallerFrame retrieves the first relevant stack frame outside of GORM's internal implementation files. +// It skips: +// - GORM's core source files (identified by gormSourceDir prefix) +// - Exclude test files (*_test.go) +// - go-gorm/gen's Generated files (*.gen.go) +func CallerFrame() runtime.Frame { pcs := [13]uintptr{} // the third caller usually from gorm internal len := runtime.Callers(3, pcs[:]) @@ -41,14 +45,24 @@ func FileWithLineNum() string { frame, _ := frames.Next() if (!strings.HasPrefix(frame.File, gormSourceDir) || strings.HasSuffix(frame.File, "_test.go")) && !strings.HasSuffix(frame.File, ".gen.go") { - return string(strconv.AppendInt(append([]byte(frame.File), ':'), int64(frame.Line), 10)) + return frame } } + return runtime.Frame{} +} + +// FileWithLineNum return the file name and line number of the current file +func FileWithLineNum() string { + frame := CallerFrame() + if frame.PC != 0 { + return string(strconv.AppendInt(append([]byte(frame.File), ':'), int64(frame.Line), 10)) + } + return "" } -func IsValidDBNameChar(c rune) bool { +func IsInvalidDBNameChar(c rune) bool { return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@' } diff --git a/public-api/vendor/modules.txt b/public-api/vendor/modules.txt index 884eda3a..74c157d9 100644 --- a/public-api/vendor/modules.txt +++ b/public-api/vendor/modules.txt @@ -1231,7 +1231,7 @@ gopkg.in/yaml.v3 # gorm.io/driver/postgres v1.6.0 ## explicit; go 1.20 gorm.io/driver/postgres -# gorm.io/gorm v1.31.0 +# gorm.io/gorm v1.31.1 ## explicit; go 1.18 gorm.io/gorm gorm.io/gorm/callbacks diff --git a/runtime-monitor/go.mod b/runtime-monitor/go.mod index 45c7b5cd..204f83ca 100644 --- a/runtime-monitor/go.mod +++ b/runtime-monitor/go.mod @@ -20,8 +20,8 @@ require ( google.golang.org/grpc v1.70.0 google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.3.0 google.golang.org/protobuf v1.36.6 - gorm.io/driver/postgres v1.5.9 - gorm.io/gorm v1.30.0 + gorm.io/driver/postgres v1.6.0 + gorm.io/gorm v1.31.1 ) require ( @@ -127,7 +127,7 @@ require ( github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgx/v5 v5.6.0 // indirect - github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/jgautheron/goconst v1.7.1 // indirect github.com/jingyugao/rowserrcheck v1.1.1 // indirect github.com/jinzhu/inflection v1.0.0 // indirect diff --git a/runtime-monitor/go.sum b/runtime-monitor/go.sum index 39f32f70..608e6871 100644 --- a/runtime-monitor/go.sum +++ b/runtime-monitor/go.sum @@ -286,8 +286,8 @@ github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7Ulw github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY= github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw= -github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= -github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jgautheron/goconst v1.7.1 h1:VpdAG7Ca7yvvJk5n8dMwQhfEZJh95kl/Hl9S1OI5Jkk= github.com/jgautheron/goconst v1.7.1/go.mod h1:aAosetZ5zaeC/2EfMeRswtxUFBpe2Hr7HzkgX4fanO4= github.com/jingyugao/rowserrcheck v1.1.1 h1:zibz55j/MJtLsjP1OF4bSdgXxwL1b+Vn7Tjzq7gFzUs= @@ -787,10 +787,10 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gorm.io/driver/postgres v1.5.9 h1:DkegyItji119OlcaLjqN11kHoUgZ/j13E0jkJZgD6A8= -gorm.io/driver/postgres v1.5.9/go.mod h1:DX3GReXH+3FPWGrrgffdvCk3DQ1dwDPdmbenSkweRGI= -gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs= -gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= +gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4= +gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo= +gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= +gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= honnef.co/go/tools v0.6.1 h1:R094WgE8K4JirYjBaOpz/AvTyUu/3wbmAoskKN/pxTI= honnef.co/go/tools v0.6.1/go.mod h1:3puzxxljPCe8RGJX7BIy1plGbxEOZni5mR2aXe3/uk4= k8s.io/api v0.32.0-rc.0 h1:/JeK0EoDPuDmV4YhcojORdB38o3tfSJlEXx6zBIFVBE= diff --git a/runtime-monitor/vendor/github.com/jackc/puddle/v2/CHANGELOG.md b/runtime-monitor/vendor/github.com/jackc/puddle/v2/CHANGELOG.md index a15991c5..d0d202c7 100644 --- a/runtime-monitor/vendor/github.com/jackc/puddle/v2/CHANGELOG.md +++ b/runtime-monitor/vendor/github.com/jackc/puddle/v2/CHANGELOG.md @@ -1,3 +1,8 @@ +# 2.2.2 (September 10, 2024) + +* Add empty acquire time to stats (Maxim Ivanov) +* Stop importing nanotime from runtime via linkname (maypok86) + # 2.2.1 (July 15, 2023) * Fix: CreateResource cannot overflow pool. This changes documented behavior of CreateResource. Previously, diff --git a/runtime-monitor/vendor/github.com/jackc/puddle/v2/README.md b/runtime-monitor/vendor/github.com/jackc/puddle/v2/README.md index 0ad07ec4..fa82a9d4 100644 --- a/runtime-monitor/vendor/github.com/jackc/puddle/v2/README.md +++ b/runtime-monitor/vendor/github.com/jackc/puddle/v2/README.md @@ -1,4 +1,4 @@ -[![](https://godoc.org/github.com/jackc/puddle?status.svg)](https://godoc.org/github.com/jackc/puddle) +[![Go Reference](https://pkg.go.dev/badge/github.com/jackc/puddle/v2.svg)](https://pkg.go.dev/github.com/jackc/puddle/v2) ![Build Status](https://github.com/jackc/puddle/actions/workflows/ci.yml/badge.svg) # Puddle diff --git a/runtime-monitor/vendor/github.com/jackc/puddle/v2/nanotime.go b/runtime-monitor/vendor/github.com/jackc/puddle/v2/nanotime.go new file mode 100644 index 00000000..8a5351a0 --- /dev/null +++ b/runtime-monitor/vendor/github.com/jackc/puddle/v2/nanotime.go @@ -0,0 +1,16 @@ +package puddle + +import "time" + +// nanotime returns the time in nanoseconds since process start. +// +// This approach, described at +// https://github.com/golang/go/issues/61765#issuecomment-1672090302, +// is fast, monotonic, and portable, and avoids the previous +// dependence on runtime.nanotime using the (unsafe) linkname hack. +// In particular, time.Since does less work than time.Now. +func nanotime() int64 { + return time.Since(globalStart).Nanoseconds() +} + +var globalStart = time.Now() diff --git a/runtime-monitor/vendor/github.com/jackc/puddle/v2/nanotime_time.go b/runtime-monitor/vendor/github.com/jackc/puddle/v2/nanotime_time.go deleted file mode 100644 index f8e75938..00000000 --- a/runtime-monitor/vendor/github.com/jackc/puddle/v2/nanotime_time.go +++ /dev/null @@ -1,13 +0,0 @@ -//go:build purego || appengine || js - -// This file contains the safe implementation of nanotime using time.Now(). - -package puddle - -import ( - "time" -) - -func nanotime() int64 { - return time.Now().UnixNano() -} diff --git a/runtime-monitor/vendor/github.com/jackc/puddle/v2/nanotime_unsafe.go b/runtime-monitor/vendor/github.com/jackc/puddle/v2/nanotime_unsafe.go deleted file mode 100644 index fc3b8a25..00000000 --- a/runtime-monitor/vendor/github.com/jackc/puddle/v2/nanotime_unsafe.go +++ /dev/null @@ -1,12 +0,0 @@ -//go:build !purego && !appengine && !js - -// This file contains the implementation of nanotime using runtime.nanotime. - -package puddle - -import "unsafe" - -var _ = unsafe.Sizeof(0) - -//go:linkname nanotime runtime.nanotime -func nanotime() int64 diff --git a/runtime-monitor/vendor/github.com/jackc/puddle/v2/pool.go b/runtime-monitor/vendor/github.com/jackc/puddle/v2/pool.go index c8edc0fb..c411d2f6 100644 --- a/runtime-monitor/vendor/github.com/jackc/puddle/v2/pool.go +++ b/runtime-monitor/vendor/github.com/jackc/puddle/v2/pool.go @@ -139,6 +139,7 @@ type Pool[T any] struct { acquireCount int64 acquireDuration time.Duration emptyAcquireCount int64 + emptyAcquireWaitTime time.Duration canceledAcquireCount atomic.Int64 resetCount int @@ -154,7 +155,7 @@ type Config[T any] struct { MaxSize int32 } -// NewPool creates a new pool. Panics if maxSize is less than 1. +// NewPool creates a new pool. Returns an error iff MaxSize is less than 1. func NewPool[T any](config *Config[T]) (*Pool[T], error) { if config.MaxSize < 1 { return nil, errors.New("MaxSize must be >= 1") @@ -202,6 +203,7 @@ type Stat struct { acquireCount int64 acquireDuration time.Duration emptyAcquireCount int64 + emptyAcquireWaitTime time.Duration canceledAcquireCount int64 } @@ -251,6 +253,13 @@ func (s *Stat) EmptyAcquireCount() int64 { return s.emptyAcquireCount } +// EmptyAcquireWaitTime returns the cumulative time waited for successful acquires +// from the pool for a resource to be released or constructed because the pool was +// empty. +func (s *Stat) EmptyAcquireWaitTime() time.Duration { + return s.emptyAcquireWaitTime +} + // CanceledAcquireCount returns the cumulative count of acquires from the pool // that were canceled by a context. func (s *Stat) CanceledAcquireCount() int64 { @@ -266,6 +275,7 @@ func (p *Pool[T]) Stat() *Stat { maxResources: p.maxSize, acquireCount: p.acquireCount, emptyAcquireCount: p.emptyAcquireCount, + emptyAcquireWaitTime: p.emptyAcquireWaitTime, canceledAcquireCount: p.canceledAcquireCount.Load(), acquireDuration: p.acquireDuration, } @@ -363,11 +373,13 @@ func (p *Pool[T]) acquire(ctx context.Context) (*Resource[T], error) { // If a resource is available in the pool. if res := p.tryAcquireIdleResource(); res != nil { + waitTime := time.Duration(nanotime() - startNano) if waitedForLock { p.emptyAcquireCount += 1 + p.emptyAcquireWaitTime += waitTime } p.acquireCount += 1 - p.acquireDuration += time.Duration(nanotime() - startNano) + p.acquireDuration += waitTime p.mux.Unlock() return res, nil } @@ -391,7 +403,9 @@ func (p *Pool[T]) acquire(ctx context.Context) (*Resource[T], error) { p.emptyAcquireCount += 1 p.acquireCount += 1 - p.acquireDuration += time.Duration(nanotime() - startNano) + waitTime := time.Duration(nanotime() - startNano) + p.acquireDuration += waitTime + p.emptyAcquireWaitTime += waitTime return res, nil } diff --git a/runtime-monitor/vendor/gorm.io/driver/postgres/migrator.go b/runtime-monitor/vendor/gorm.io/driver/postgres/migrator.go index df18db1b..6b57ce69 100644 --- a/runtime-monitor/vendor/gorm.io/driver/postgres/migrator.go +++ b/runtime-monitor/vendor/gorm.io/driver/postgres/migrator.go @@ -3,10 +3,10 @@ package postgres import ( "database/sql" "fmt" + "github.com/jackc/pgx/v5" "regexp" "strings" - "github.com/jackc/pgx/v5" "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/migrator" @@ -38,28 +38,34 @@ WHERE ` var typeAliasMap = map[string][]string{ - "int": {"integer"}, - "int2": {"smallint"}, - "int4": {"integer"}, - "int8": {"bigint"}, - "smallint": {"int2"}, - "integer": {"int4"}, - "bigint": {"int8"}, - "decimal": {"numeric"}, - "numeric": {"decimal"}, - "timestamptz": {"timestamp with time zone"}, - "timestamp with time zone": {"timestamptz"}, - "bool": {"boolean"}, - "boolean": {"bool"}, - "serial2": {"smallserial"}, - "serial4": {"serial"}, - "serial8": {"bigserial"}, - "varbit": {"bit varying"}, - "char": {"character"}, - "varchar": {"character varying"}, - "float4": {"real"}, - "float8": {"double precision"}, - "timetz": {"time with time zone"}, + "int": {"integer"}, + "int2": {"smallint"}, + "int4": {"integer"}, + "int8": {"bigint"}, + "smallint": {"int2"}, + "integer": {"int4"}, + "bigint": {"int8"}, + "date": {"date"}, + "decimal": {"numeric"}, + "numeric": {"decimal"}, + "timestamp": {"timestamp"}, + "timestamptz": {"timestamp with time zone"}, + "timestamp without time zone": {"timestamp"}, + "timestamp with time zone": {"timestamptz"}, + "bool": {"boolean"}, + "boolean": {"bool"}, + "serial2": {"smallserial"}, + "serial4": {"serial"}, + "serial8": {"bigserial"}, + "varbit": {"bit varying"}, + "char": {"character"}, + "varchar": {"character varying"}, + "float4": {"real"}, + "float8": {"double precision"}, + "time": {"time"}, + "timetz": {"time with time zone"}, + "time without time zone": {"time"}, + "time with time zone": {"timetz"}, } type Migrator struct { @@ -130,7 +136,8 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { } createIndexSQL += "INDEX " - if strings.TrimSpace(strings.ToUpper(idx.Option)) == "CONCURRENTLY" { + hasConcurrentOption := strings.TrimSpace(strings.ToUpper(idx.Option)) == "CONCURRENTLY" + if hasConcurrentOption { createIndexSQL += "CONCURRENTLY " } @@ -142,6 +149,10 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { createIndexSQL += " ?" } + if idx.Option != "" && !hasConcurrentOption { + createIndexSQL += " " + idx.Option + } + if idx.Where != "" { createIndexSQL += " WHERE " + idx.Where } @@ -385,10 +396,16 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { return err } } else { - if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil { + if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { return err } } + } else if !field.HasDefaultValue { + // case - as-is column has default value and to-be column has no default value + // need to drop default + if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { + return err + } } } return nil @@ -484,8 +501,8 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, column.LengthValue = typeLenValue } - if (strings.HasPrefix(column.DefaultValueValue.String, "nextval('") && - strings.HasSuffix(column.DefaultValueValue.String, "seq'::regclass)")) || (identityIncrement.Valid && identityIncrement.String != "") { + autoIncrementValuePattern := regexp.MustCompile(`^nextval\('"?[^']+seq"?'::regclass\)$`) + if autoIncrementValuePattern.MatchString(column.DefaultValueValue.String) || (identityIncrement.Valid && identityIncrement.String != "") { column.AutoIncrementValue = sql.NullBool{Bool: true, Valid: true} column.DefaultValueValue = sql.NullString{} } diff --git a/runtime-monitor/vendor/gorm.io/driver/postgres/postgres.go b/runtime-monitor/vendor/gorm.io/driver/postgres/postgres.go index e865b0f8..2d8fd997 100644 --- a/runtime-monitor/vendor/gorm.io/driver/postgres/postgres.go +++ b/runtime-monitor/vendor/gorm.io/driver/postgres/postgres.go @@ -1,13 +1,16 @@ package postgres import ( + "context" "database/sql" "fmt" "regexp" "strconv" "strings" + "time" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/stdlib" "gorm.io/gorm" "gorm.io/gorm/callbacks" @@ -31,7 +34,7 @@ type Config struct { } var ( - timeZoneMatcher = regexp.MustCompile("(time_zone|TimeZone)=(.*?)($|&| )") + timeZoneMatcher = regexp.MustCompile("(time_zone|TimeZone|timezone)=(.*?)($|&| )") defaultIdentifierLength = 63 //maximum identifier length for postgres ) @@ -99,10 +102,23 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol } result := timeZoneMatcher.FindStringSubmatch(dialector.Config.DSN) + var options []stdlib.OptionOpenDB if len(result) > 2 { config.RuntimeParams["timezone"] = result[2] + options = append(options, stdlib.OptionAfterConnect(func(ctx context.Context, conn *pgx.Conn) error { + loc, tzErr := time.LoadLocation(result[2]) + if tzErr != nil { + return tzErr + } + conn.TypeMap().RegisterType(&pgtype.Type{ + Name: "timestamp", + OID: pgtype.TimestampOID, + Codec: &pgtype.TimestampCodec{ScanLocation: loc}, + }) + return nil + })) } - db.ConnPool = stdlib.OpenDB(*config) + db.ConnPool = stdlib.OpenDB(*config, options...) } return } @@ -228,7 +244,7 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { } return "decimal" case schema.String: - if field.Size > 0 { + if field.Size > 0 && field.Size <= 10485760 { return fmt.Sprintf("varchar(%d)", field.Size) } return "text" diff --git a/runtime-monitor/vendor/gorm.io/gorm/README.md b/runtime-monitor/vendor/gorm.io/gorm/README.md index 745dad60..24eb84c9 100644 --- a/runtime-monitor/vendor/gorm.io/gorm/README.md +++ b/runtime-monitor/vendor/gorm.io/gorm/README.md @@ -3,7 +3,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. [![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm) -[![test status](https://github.com/go-gorm/gorm/workflows/tests/badge.svg?branch=master "test status")](https://github.com/go-gorm/gorm/actions) +[![test status](https://github.com/go-gorm/gorm/actions/workflows/tests.yml/badge.svg)](https://github.com/go-gorm/gorm/actions) [![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT) [![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gorm?tab=doc) diff --git a/runtime-monitor/vendor/gorm.io/gorm/association.go b/runtime-monitor/vendor/gorm.io/gorm/association.go index e3f51d17..3a4e0e25 100644 --- a/runtime-monitor/vendor/gorm.io/gorm/association.go +++ b/runtime-monitor/vendor/gorm.io/gorm/association.go @@ -19,10 +19,10 @@ type Association struct { } func (db *DB) Association(column string) *Association { - association := &Association{DB: db} + association := &Association{DB: db, Unscope: db.Statement.Unscoped} table := db.Statement.Table - if err := db.Statement.Parse(db.Statement.Model); err == nil { + if association.Error = db.Statement.Parse(db.Statement.Model); association.Error == nil { db.Statement.Table = table association.Relationship = db.Statement.Schema.Relationships.Relations[column] @@ -34,8 +34,6 @@ func (db *DB) Association(column string) *Association { for db.Statement.ReflectValue.Kind() == reflect.Ptr { db.Statement.ReflectValue = db.Statement.ReflectValue.Elem() } - } else { - association.Error = err } return association @@ -58,6 +56,8 @@ func (association *Association) Find(out interface{}, conds ...interface{}) erro } func (association *Association) Append(values ...interface{}) error { + values = expandValues(values) + if association.Error == nil { switch association.Relationship.Type { case schema.HasOne, schema.BelongsTo: @@ -73,6 +73,8 @@ func (association *Association) Append(values ...interface{}) error { } func (association *Association) Replace(values ...interface{}) error { + values = expandValues(values) + if association.Error == nil { reflectValue := association.DB.Statement.ReflectValue rel := association.Relationship @@ -97,7 +99,7 @@ func (association *Association) Replace(values ...interface{}) error { return association.Error } - // set old associations's foreign key to null + // set old association's foreign key to null switch rel.Type { case schema.BelongsTo: if len(values) == 0 { @@ -195,6 +197,8 @@ func (association *Association) Replace(values ...interface{}) error { } func (association *Association) Delete(values ...interface{}) error { + values = expandValues(values) + if association.Error == nil { var ( reflectValue = association.DB.Statement.ReflectValue @@ -300,7 +304,7 @@ func (association *Association) Delete(values ...interface{}) error { } if association.Error == nil { - // clean up deleted values's foreign key + // clean up deleted values' foreign key relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields) cleanUpDeletedRelations := func(data reflect.Value) { @@ -431,10 +435,49 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } } + processMap := func(mapv reflect.Value) { + child := reflect.New(association.Relationship.FieldSchema.ModelType) + + switch association.Relationship.Type { + case schema.HasMany: + for _, ref := range association.Relationship.References { + key := reflect.ValueOf(ref.ForeignKey.DBName) + if ref.OwnPrimaryKey { + v := ref.PrimaryKey.ReflectValueOf(association.DB.Statement.Context, source) + mapv.SetMapIndex(key, v) + } else if ref.PrimaryValue != "" { + mapv.SetMapIndex(key, reflect.ValueOf(ref.PrimaryValue)) + } + } + association.Error = association.DB.Session(&Session{ + NewDB: true, + }).Model(child.Interface()).Create(mapv.Interface()).Error + case schema.Many2Many: + association.Error = association.DB.Session(&Session{ + NewDB: true, + }).Model(child.Interface()).Create(mapv.Interface()).Error + + for _, key := range mapv.MapKeys() { + k := strings.ToLower(key.String()) + if f, ok := association.Relationship.FieldSchema.FieldsByDBName[k]; ok { + _ = f.Set(association.DB.Statement.Context, child, mapv.MapIndex(key).Interface()) + } + } + appendToFieldValues(child) + } + } + switch rv.Kind() { + case reflect.Map: + processMap(rv) case reflect.Slice, reflect.Array: for i := 0; i < rv.Len(); i++ { - appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr()) + elem := reflect.Indirect(rv.Index(i)) + if elem.Kind() == reflect.Map { + processMap(elem) + continue + } + appendToFieldValues(elem.Addr()) } case reflect.Struct: if !rv.CanAddr() { @@ -591,3 +634,32 @@ func (association *Association) buildCondition() *DB { return tx } + +func expandValues(values ...any) (results []any) { + appendToResult := func(rv reflect.Value) { + // unwrap interface + if rv.IsValid() && rv.Kind() == reflect.Interface { + rv = rv.Elem() + } + if rv.IsValid() && rv.Kind() == reflect.Struct { + p := reflect.New(rv.Type()) + p.Elem().Set(rv) + results = append(results, p.Interface()) + } else if rv.IsValid() { + results = append(results, rv.Interface()) + } + } + + // Process each argument; if an argument is a slice/array, expand its elements + for _, value := range values { + rv := reflect.ValueOf(value) + if rv.Kind() == reflect.Slice || rv.Kind() == reflect.Array { + for i := 0; i < rv.Len(); i++ { + appendToResult(rv.Index(i)) + } + } else { + appendToResult(rv) + } + } + return +} diff --git a/runtime-monitor/vendor/gorm.io/gorm/callbacks.go b/runtime-monitor/vendor/gorm.io/gorm/callbacks.go index 50b5b0e9..bd97f040 100644 --- a/runtime-monitor/vendor/gorm.io/gorm/callbacks.go +++ b/runtime-monitor/vendor/gorm.io/gorm/callbacks.go @@ -89,10 +89,16 @@ func (p *processor) Execute(db *DB) *DB { resetBuildClauses = true } - if optimizer, ok := db.Statement.Dest.(StatementModifier); ok { + if optimizer, ok := stmt.Dest.(StatementModifier); ok { optimizer.ModifyStatement(stmt) } + if db.DefaultContextTimeout > 0 { + if _, ok := stmt.Context.Deadline(); !ok { + stmt.Context, _ = context.WithTimeout(stmt.Context, db.DefaultContextTimeout) + } + } + // assign model values if stmt.Model == nil { stmt.Model = stmt.Dest diff --git a/runtime-monitor/vendor/gorm.io/gorm/callbacks/create.go b/runtime-monitor/vendor/gorm.io/gorm/callbacks/create.go index d8701f51..e5929adb 100644 --- a/runtime-monitor/vendor/gorm.io/gorm/callbacks/create.go +++ b/runtime-monitor/vendor/gorm.io/gorm/callbacks/create.go @@ -53,9 +53,13 @@ func Create(config *Config) func(db *gorm.DB) { if _, ok := db.Statement.Clauses["RETURNING"]; !ok { fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue)) for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue { - fromColumns = append(fromColumns, clause.Column{Name: field.DBName}) + if field.Readable { + fromColumns = append(fromColumns, clause.Column{Name: field.DBName}) + } + } + if len(fromColumns) > 0 { + db.Statement.AddClause(clause.Returning{Columns: fromColumns}) } - db.Statement.AddClause(clause.Returning{Columns: fromColumns}) } } } @@ -76,8 +80,11 @@ func Create(config *Config) func(db *gorm.DB) { ok, mode := hasReturning(db, supportReturning) if ok { if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok { - if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing { + onConflict, _ := c.Expression.(clause.OnConflict) + if onConflict.DoNothing { mode |= gorm.ScanOnConflictDoNothing + } else if len(onConflict.DoUpdates) > 0 || onConflict.UpdateAll { + mode |= gorm.ScanUpdate } } @@ -122,6 +129,16 @@ func Create(config *Config) func(db *gorm.DB) { pkFieldName = "@id" ) + if db.Statement.Schema != nil { + if db.Statement.Schema.PrioritizedPrimaryField == nil || + !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue || + !db.Statement.Schema.PrioritizedPrimaryField.Readable { + return + } + pkField = db.Statement.Schema.PrioritizedPrimaryField + pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName + } + insertID, err := result.LastInsertId() insertOk := err == nil && insertID > 0 @@ -132,14 +149,6 @@ func Create(config *Config) func(db *gorm.DB) { return } - if db.Statement.Schema != nil { - if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { - return - } - pkField = db.Statement.Schema.PrioritizedPrimaryField - pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName - } - // append @id column with value for auto-increment primary key // the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1 switch values := db.Statement.Dest.(type) { diff --git a/runtime-monitor/vendor/gorm.io/gorm/chainable_api.go b/runtime-monitor/vendor/gorm.io/gorm/chainable_api.go index 8a6aea34..8f6113cc 100644 --- a/runtime-monitor/vendor/gorm.io/gorm/chainable_api.go +++ b/runtime-monitor/vendor/gorm.io/gorm/chainable_api.go @@ -178,7 +178,7 @@ func (db *DB) Omit(columns ...string) (tx *DB) { tx = db.getInstance() if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { - tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar) + tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsInvalidDBNameChar) } else { tx.Statement.Omits = columns } @@ -283,7 +283,7 @@ func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) func (db *DB) Group(name string) (tx *DB) { tx = db.getInstance() - fields := strings.FieldsFunc(name, utils.IsValidDBNameChar) + fields := strings.FieldsFunc(name, utils.IsInvalidDBNameChar) tx.Statement.AddClause(clause.GroupBy{ Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}}, }) diff --git a/runtime-monitor/vendor/gorm.io/gorm/clause/association.go b/runtime-monitor/vendor/gorm.io/gorm/clause/association.go new file mode 100644 index 00000000..a9bf7eb0 --- /dev/null +++ b/runtime-monitor/vendor/gorm.io/gorm/clause/association.go @@ -0,0 +1,35 @@ +package clause + +// AssociationOpType represents association operation types +type AssociationOpType int + +const ( + OpUnlink AssociationOpType = iota // Unlink association + OpDelete // Delete association records + OpUpdate // Update association records + OpCreate // Create association records with assignments +) + +// Association represents an association operation +type Association struct { + Association string // Association name + Type AssociationOpType // Operation type + Conditions []Expression // Filter conditions + Set []Assignment // Assignment operations (for Update and Create) + Values []interface{} // Values for Create operation +} + +// AssociationAssigner is an interface for association operation providers +type AssociationAssigner interface { + AssociationAssignments() []Association +} + +// Assignments implements the Assigner interface so that AssociationOperation can be used as a Set method parameter +func (ao Association) Assignments() []Assignment { + return []Assignment{} +} + +// AssociationAssignments implements the AssociationAssigner interface +func (ao Association) AssociationAssignments() []Association { + return []Association{ao} +} diff --git a/runtime-monitor/vendor/gorm.io/gorm/clause/set.go b/runtime-monitor/vendor/gorm.io/gorm/clause/set.go index 75eb6bdd..cb5f36a0 100644 --- a/runtime-monitor/vendor/gorm.io/gorm/clause/set.go +++ b/runtime-monitor/vendor/gorm.io/gorm/clause/set.go @@ -9,6 +9,11 @@ type Assignment struct { Value interface{} } +// Assigner assignments provider interface +type Assigner interface { + Assignments() []Assignment +} + func (set Set) Name() string { return "SET" } @@ -37,6 +42,9 @@ func (set Set) MergeClause(clause *Clause) { clause.Expression = Set(copiedAssignments) } +// Assignments implements Assigner for Set. +func (set Set) Assignments() []Assignment { return []Assignment(set) } + func Assignments(values map[string]interface{}) Set { keys := make([]string, 0, len(values)) for key := range values { @@ -58,3 +66,6 @@ func AssignmentColumns(values []string) Set { } return assignments } + +// Assignments implements Assigner for a single Assignment. +func (a Assignment) Assignments() []Assignment { return []Assignment{a} } diff --git a/runtime-monitor/vendor/gorm.io/gorm/finisher_api.go b/runtime-monitor/vendor/gorm.io/gorm/finisher_api.go index 57809d17..e9e35f1b 100644 --- a/runtime-monitor/vendor/gorm.io/gorm/finisher_api.go +++ b/runtime-monitor/vendor/gorm.io/gorm/finisher_api.go @@ -465,7 +465,7 @@ func (db *DB) Count(count *int64) (tx *DB) { if len(tx.Statement.Selects) == 1 { dbName := tx.Statement.Selects[0] - fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar) + fields := strings.FieldsFunc(dbName, utils.IsInvalidDBNameChar) if len(fields) == 1 || (len(fields) == 3 && (strings.ToUpper(fields[1]) == "AS" || fields[1] == ".")) { if tx.Statement.Parse(tx.Statement.Model) == nil { if f := tx.Statement.Schema.LookUpField(dbName); f != nil { @@ -564,7 +564,7 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { } if len(tx.Statement.Selects) != 1 { - fields := strings.FieldsFunc(column, utils.IsValidDBNameChar) + fields := strings.FieldsFunc(column, utils.IsInvalidDBNameChar) tx.Statement.AddClauseIfNotExists(clause.Select{ Distinct: tx.Statement.Distinct, Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}}, @@ -675,9 +675,9 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { } ctx := tx.Statement.Context - if _, ok := ctx.Deadline(); !ok { - if db.Config.DefaultTransactionTimeout > 0 { - ctx, _ = context.WithTimeout(ctx, db.Config.DefaultTransactionTimeout) + if db.DefaultTransactionTimeout > 0 { + if _, ok := ctx.Deadline(); !ok { + ctx, _ = context.WithTimeout(ctx, db.DefaultTransactionTimeout) } } diff --git a/runtime-monitor/vendor/gorm.io/gorm/generics.go b/runtime-monitor/vendor/gorm.io/gorm/generics.go index ad2d063f..166d1520 100644 --- a/runtime-monitor/vendor/gorm.io/gorm/generics.go +++ b/runtime-monitor/vendor/gorm.io/gorm/generics.go @@ -3,12 +3,15 @@ package gorm import ( "context" "database/sql" + "errors" "fmt" + "reflect" "sort" "strings" "gorm.io/gorm/clause" "gorm.io/gorm/logger" + "gorm.io/gorm/schema" ) type result struct { @@ -35,10 +38,34 @@ type Interface[T any] interface { } type CreateInterface[T any] interface { - ChainInterface[T] + ExecInterface[T] + // chain methods available at start; Select/Omit keep CreateInterface to allow Create chaining + Scopes(scopes ...func(db *Statement)) ChainInterface[T] + Where(query interface{}, args ...interface{}) ChainInterface[T] + Not(query interface{}, args ...interface{}) ChainInterface[T] + Or(query interface{}, args ...interface{}) ChainInterface[T] + Limit(offset int) ChainInterface[T] + Offset(offset int) ChainInterface[T] + Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] + Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] + Select(query string, args ...interface{}) CreateInterface[T] + Omit(columns ...string) CreateInterface[T] + MapColumns(m map[string]string) ChainInterface[T] + Distinct(args ...interface{}) ChainInterface[T] + Group(name string) ChainInterface[T] + Having(query interface{}, args ...interface{}) ChainInterface[T] + Order(value interface{}) ChainInterface[T] + Build(builder clause.Builder) + + Delete(ctx context.Context) (rowsAffected int, err error) + Update(ctx context.Context, name string, value any) (rowsAffected int, err error) + Updates(ctx context.Context, t T) (rowsAffected int, err error) + Count(ctx context.Context, column string) (result int64, err error) + Table(name string, args ...interface{}) CreateInterface[T] Create(ctx context.Context, r *T) error CreateInBatches(ctx context.Context, r *[]T, batchSize int) error + Set(assignments ...clause.Assigner) SetCreateOrUpdateInterface[T] } type ChainInterface[T any] interface { @@ -58,15 +85,28 @@ type ChainInterface[T any] interface { Group(name string) ChainInterface[T] Having(query interface{}, args ...interface{}) ChainInterface[T] Order(value interface{}) ChainInterface[T] + Set(assignments ...clause.Assigner) SetUpdateOnlyInterface[T] Build(builder clause.Builder) + Table(name string, args ...interface{}) ChainInterface[T] Delete(ctx context.Context) (rowsAffected int, err error) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) Updates(ctx context.Context, t T) (rowsAffected int, err error) Count(ctx context.Context, column string) (result int64, err error) } +// SetUpdateOnlyInterface is returned by Set after chaining; only Update is allowed +type SetUpdateOnlyInterface[T any] interface { + Update(ctx context.Context) (rowsAffected int, err error) +} + +// SetCreateOrUpdateInterface is returned by Set at start; Create or Update are allowed +type SetCreateOrUpdateInterface[T any] interface { + Create(ctx context.Context) error + Update(ctx context.Context) (rowsAffected int, err error) +} + type ExecInterface[T any] interface { Scan(ctx context.Context, r interface{}) error First(context.Context) (T, error) @@ -142,13 +182,15 @@ func (c *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] { return execG[T]{g: &g[T]{ db: c.db, ops: append(c.ops, func(db *DB) *DB { - return db.Raw(sql, values...) + var r T + return db.Model(r).Raw(sql, values...) }), }} } func (c *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error { - return c.apply(ctx).Exec(sql, values...).Error + var r T + return c.apply(ctx).Model(r).Exec(sql, values...).Error } type createG[T any] struct { @@ -161,6 +203,22 @@ func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] { })} } +func (c createG[T]) Select(query string, args ...interface{}) CreateInterface[T] { + return createG[T]{c.with(func(db *DB) *DB { + return db.Select(query, args...) + })} +} + +func (c createG[T]) Omit(columns ...string) CreateInterface[T] { + return createG[T]{c.with(func(db *DB) *DB { + return db.Omit(columns...) + })} +} + +func (c createG[T]) Set(assignments ...clause.Assigner) SetCreateOrUpdateInterface[T] { + return c.processSet(assignments...) +} + func (c createG[T]) Create(ctx context.Context, r *T) error { return c.g.apply(ctx).Create(r).Error } @@ -187,6 +245,12 @@ func (c chainG[T]) with(v op) chainG[T] { } } +func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] { + return c.with(func(db *DB) *DB { + return db.Table(name, args...) + }) +} + func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] { return c.with(func(db *DB) *DB { for _, fc := range scopes { @@ -196,12 +260,6 @@ func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] { }) } -func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] { - return c.with(func(db *DB) *DB { - return db.Table(name, args...) - }) -} - func (c chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] { return c.with(func(db *DB) *DB { return db.Where(query, args...) @@ -388,6 +446,10 @@ func (c chainG[T]) MapColumns(m map[string]string) ChainInterface[T] { }) } +func (c chainG[T]) Set(assignments ...clause.Assigner) SetUpdateOnlyInterface[T] { + return c.processSet(assignments...) +} + func (c chainG[T]) Distinct(args ...interface{}) ChainInterface[T] { return c.with(func(db *DB) *DB { return db.Distinct(args...) @@ -425,12 +487,12 @@ func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) err relation, ok := db.Statement.Schema.Relationships.Relations[association] if !ok { if preloadFields := strings.Split(association, "."); len(preloadFields) > 1 { - relationships := db.Statement.Schema.Relationships + relationships := &db.Statement.Schema.Relationships for _, field := range preloadFields { var ok bool relation, ok = relationships.Relations[field] if ok { - relationships = relation.FieldSchema.Relationships + relationships = &relation.FieldSchema.Relationships } else { db.AddError(fmt.Errorf("relation %s not found", association)) return nil @@ -567,7 +629,7 @@ func (g execG[T]) First(ctx context.Context) (T, error) { func (g execG[T]) Scan(ctx context.Context, result interface{}) error { var r T - err := g.g.apply(ctx).Model(r).Find(&result).Error + err := g.g.apply(ctx).Model(r).Find(result).Error return err } @@ -603,3 +665,242 @@ func (g execG[T]) Row(ctx context.Context) *sql.Row { func (g execG[T]) Rows(ctx context.Context) (*sql.Rows, error) { return g.g.apply(ctx).Rows() } + +func (c chainG[T]) processSet(items ...clause.Assigner) setCreateOrUpdateG[T] { + var ( + assigns []clause.Assignment + assocOps []clause.Association + ) + + for _, item := range items { + // Check if it's an AssociationAssigner + if assocAssigner, ok := item.(clause.AssociationAssigner); ok { + assocOps = append(assocOps, assocAssigner.AssociationAssignments()...) + } else { + assigns = append(assigns, item.Assignments()...) + } + } + + return setCreateOrUpdateG[T]{ + c: c, + assigns: assigns, + assocOps: assocOps, + } +} + +// setCreateOrUpdateG[T] is a struct that holds operations to be executed in a batch. +// It supports regular assignments and association operations. +type setCreateOrUpdateG[T any] struct { + c chainG[T] + assigns []clause.Assignment + assocOps []clause.Association +} + +func (s setCreateOrUpdateG[T]) Update(ctx context.Context) (rowsAffected int, err error) { + // Execute association operations + for _, assocOp := range s.assocOps { + if err := s.executeAssociationOperation(ctx, assocOp); err != nil { + return 0, err + } + } + + // Execute assignment operations + if len(s.assigns) > 0 { + var r T + res := s.c.g.apply(ctx).Model(r).Clauses(clause.Set(s.assigns)).Updates(map[string]interface{}{}) + return int(res.RowsAffected), res.Error + } + + return 0, nil +} + +func (s setCreateOrUpdateG[T]) Create(ctx context.Context) error { + // Execute association operations + for _, assocOp := range s.assocOps { + if err := s.executeAssociationOperation(ctx, assocOp); err != nil { + return err + } + } + + // Execute assignment operations + if len(s.assigns) > 0 { + data := make(map[string]interface{}, len(s.assigns)) + for _, a := range s.assigns { + data[a.Column.Name] = a.Value + } + var r T + return s.c.g.apply(ctx).Model(r).Create(data).Error + } + + return nil +} + +// executeAssociationOperation executes an association operation +func (s setCreateOrUpdateG[T]) executeAssociationOperation(ctx context.Context, op clause.Association) error { + var r T + base := s.c.g.apply(ctx).Model(r) + + switch op.Type { + case clause.OpCreate: + return s.handleAssociationCreate(ctx, base, op) + case clause.OpUnlink, clause.OpDelete, clause.OpUpdate: + return s.handleAssociation(ctx, base, op) + default: + return fmt.Errorf("unknown association operation type: %v", op.Type) + } +} + +func (s setCreateOrUpdateG[T]) handleAssociationCreate(ctx context.Context, base *DB, op clause.Association) error { + if len(op.Set) > 0 { + return s.handleAssociationForOwners(base, ctx, func(owner T, assoc *Association) error { + data := make(map[string]interface{}, len(op.Set)) + for _, a := range op.Set { + data[a.Column.Name] = a.Value + } + return assoc.Append(data) + }, op.Association) + } + + return s.handleAssociationForOwners(base, ctx, func(owner T, assoc *Association) error { + return assoc.Append(op.Values...) + }, op.Association) +} + +// handleAssociationForOwners is a helper function that handles associations for all owners +func (s setCreateOrUpdateG[T]) handleAssociationForOwners(base *DB, ctx context.Context, handler func(owner T, association *Association) error, associationName string) error { + var owners []T + if err := base.Find(&owners).Error; err != nil { + return err + } + + for _, owner := range owners { + assoc := base.Session(&Session{NewDB: true, Context: ctx}).Model(&owner).Association(associationName) + if assoc.Error != nil { + return assoc.Error + } + + if err := handler(owner, assoc); err != nil { + return err + } + } + return nil +} + +func (s setCreateOrUpdateG[T]) handleAssociation(ctx context.Context, base *DB, op clause.Association) error { + assoc := base.Association(op.Association) + if assoc.Error != nil { + return assoc.Error + } + + var ( + rel = assoc.Relationship + assocModel = reflect.New(rel.FieldSchema.ModelType).Interface() + fkNil = map[string]any{} + setMap = make(map[string]any, len(op.Set)) + ownerPKNames []string + ownerFKNames []string + primaryColumns []any + foreignColumns []any + ) + + for _, a := range op.Set { + setMap[a.Column.Name] = a.Value + } + + for _, ref := range rel.References { + fkNil[ref.ForeignKey.DBName] = nil + + if ref.OwnPrimaryKey && ref.PrimaryKey != nil { + ownerPKNames = append(ownerPKNames, ref.PrimaryKey.DBName) + primaryColumns = append(primaryColumns, clause.Column{Name: ref.PrimaryKey.DBName}) + foreignColumns = append(foreignColumns, clause.Column{Name: ref.ForeignKey.DBName}) + } else if !ref.OwnPrimaryKey && ref.PrimaryKey != nil { + ownerFKNames = append(ownerFKNames, ref.ForeignKey.DBName) + primaryColumns = append(primaryColumns, clause.Column{Name: ref.PrimaryKey.DBName}) + } + } + + assocDB := s.c.g.db.Session(&Session{NewDB: true, Context: ctx}).Model(assocModel).Where(op.Conditions) + + switch rel.Type { + case schema.HasOne, schema.HasMany: + assocDB = assocDB.Where("? IN (?)", foreignColumns, base.Select(ownerPKNames)) + switch op.Type { + case clause.OpUnlink: + return assocDB.Updates(fkNil).Error + case clause.OpDelete: + return assocDB.Delete(assocModel).Error + case clause.OpUpdate: + return assocDB.Updates(setMap).Error + } + case schema.BelongsTo: + switch op.Type { + case clause.OpDelete: + return base.Transaction(func(tx *DB) error { + assocDB.Statement.ConnPool = tx.Statement.ConnPool + base.Statement.ConnPool = tx.Statement.ConnPool + + if err := assocDB.Where("? IN (?)", primaryColumns, base.Select(ownerFKNames)).Delete(assocModel).Error; err != nil { + return err + } + return base.Updates(fkNil).Error + }) + case clause.OpUnlink: + return base.Updates(fkNil).Error + case clause.OpUpdate: + return assocDB.Where("? IN (?)", primaryColumns, base.Select(ownerFKNames)).Updates(setMap).Error + } + case schema.Many2Many: + joinModel := reflect.New(rel.JoinTable.ModelType).Interface() + joinDB := base.Session(&Session{NewDB: true, Context: ctx}).Model(joinModel) + + // EXISTS owners: owners.pk = join.owner_fk for all owner refs + ownersExists := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.Schema.Table).Select("1") + for _, ref := range rel.References { + if ref.OwnPrimaryKey && ref.PrimaryKey != nil { + ownersExists = ownersExists.Where(clause.Eq{ + Column: clause.Column{Table: rel.Schema.Table, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + }) + } + } + + // EXISTS related: related.pk = join.rel_fk for all related refs, plus optional conditions + relatedExists := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.FieldSchema.Table).Select("1") + for _, ref := range rel.References { + if !ref.OwnPrimaryKey && ref.PrimaryKey != nil { + relatedExists = relatedExists.Where(clause.Eq{ + Column: clause.Column{Table: rel.FieldSchema.Table, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + }) + } + } + relatedExists = relatedExists.Where(op.Conditions) + + switch op.Type { + case clause.OpUnlink, clause.OpDelete: + joinDB = joinDB.Where("EXISTS (?)", ownersExists) + if len(op.Conditions) > 0 { + joinDB = joinDB.Where("EXISTS (?)", relatedExists) + } + return joinDB.Delete(nil).Error + case clause.OpUpdate: + // Update related table rows that have join rows matching owners + relatedDB := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.FieldSchema.Table).Where(op.Conditions) + + // correlated join subquery: join.rel_fk = related.pk AND EXISTS owners + joinSub := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.JoinTable.Table).Select("1") + for _, ref := range rel.References { + if !ref.OwnPrimaryKey && ref.PrimaryKey != nil { + joinSub = joinSub.Where(clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: rel.FieldSchema.Table, Name: ref.PrimaryKey.DBName}, + }) + } + } + joinSub = joinSub.Where("EXISTS (?)", ownersExists) + return relatedDB.Where("EXISTS (?)", joinSub).Updates(setMap).Error + } + } + return errors.New("unsupported relationship") +} diff --git a/runtime-monitor/vendor/gorm.io/gorm/gorm.go b/runtime-monitor/vendor/gorm.io/gorm/gorm.go index 67889262..a209bb09 100644 --- a/runtime-monitor/vendor/gorm.io/gorm/gorm.go +++ b/runtime-monitor/vendor/gorm.io/gorm/gorm.go @@ -23,6 +23,7 @@ type Config struct { // You can disable it by setting `SkipDefaultTransaction` to true SkipDefaultTransaction bool DefaultTransactionTimeout time.Duration + DefaultContextTimeout time.Duration // NamingStrategy tables, columns naming strategy NamingStrategy schema.Namer @@ -137,6 +138,14 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { return isConfig && !isConfig2 }) + if len(opts) > 0 { + if c, ok := opts[0].(*Config); ok { + config = c + } else { + opts = append([]Option{config}, opts...) + } + } + var skipAfterInitialize bool for _, opt := range opts { if opt != nil { diff --git a/runtime-monitor/vendor/gorm.io/gorm/logger/slog.go b/runtime-monitor/vendor/gorm.io/gorm/logger/slog.go new file mode 100644 index 00000000..27c74683 --- /dev/null +++ b/runtime-monitor/vendor/gorm.io/gorm/logger/slog.go @@ -0,0 +1,116 @@ +//go:build go1.21 + +package logger + +import ( + "context" + "errors" + "fmt" + "log/slog" + "time" + + "gorm.io/gorm/utils" +) + +type slogLogger struct { + Logger *slog.Logger + LogLevel LogLevel + SlowThreshold time.Duration + Parameterized bool + Colorful bool // Ignored in slog + IgnoreRecordNotFoundError bool +} + +func NewSlogLogger(logger *slog.Logger, config Config) Interface { + return &slogLogger{ + Logger: logger, + LogLevel: config.LogLevel, + SlowThreshold: config.SlowThreshold, + Parameterized: config.ParameterizedQueries, + IgnoreRecordNotFoundError: config.IgnoreRecordNotFoundError, + } +} + +func (l *slogLogger) LogMode(level LogLevel) Interface { + newLogger := *l + newLogger.LogLevel = level + return &newLogger +} + +func (l *slogLogger) Info(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Info { + l.log(ctx, slog.LevelInfo, msg, slog.Any("data", data)) + } +} + +func (l *slogLogger) Warn(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Warn { + l.log(ctx, slog.LevelWarn, msg, slog.Any("data", data)) + } +} + +func (l *slogLogger) Error(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Error { + l.log(ctx, slog.LevelError, msg, slog.Any("data", data)) + } +} + +func (l *slogLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + if l.LogLevel <= Silent { + return + } + + elapsed := time.Since(begin) + sql, rows := fc() + fields := []slog.Attr{ + slog.String("duration", fmt.Sprintf("%.3fms", float64(elapsed.Nanoseconds())/1e6)), + slog.String("sql", sql), + } + + if rows != -1 { + fields = append(fields, slog.Int64("rows", rows)) + } + + switch { + case err != nil && (!l.IgnoreRecordNotFoundError || !errors.Is(err, ErrRecordNotFound)): + fields = append(fields, slog.String("error", err.Error())) + l.log(ctx, slog.LevelError, "SQL executed", slog.Attr{ + Key: "trace", + Value: slog.GroupValue(fields...), + }) + + case l.SlowThreshold != 0 && elapsed > l.SlowThreshold: + l.log(ctx, slog.LevelWarn, "SQL executed", slog.Attr{ + Key: "trace", + Value: slog.GroupValue(fields...), + }) + + case l.LogLevel >= Info: + l.log(ctx, slog.LevelInfo, "SQL executed", slog.Attr{ + Key: "trace", + Value: slog.GroupValue(fields...), + }) + } +} + +func (l *slogLogger) log(ctx context.Context, level slog.Level, msg string, args ...any) { + if ctx == nil { + ctx = context.Background() + } + + if !l.Logger.Enabled(ctx, level) { + return + } + + r := slog.NewRecord(time.Now(), level, msg, utils.CallerFrame().PC) + r.Add(args...) + _ = l.Logger.Handler().Handle(ctx, r) +} + +// ParamsFilter filter params +func (l *slogLogger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) { + if l.Parameterized { + return sql, nil + } + return sql, params +} diff --git a/runtime-monitor/vendor/gorm.io/gorm/migrator/migrator.go b/runtime-monitor/vendor/gorm.io/gorm/migrator/migrator.go index cec4e30f..35107d57 100644 --- a/runtime-monitor/vendor/gorm.io/gorm/migrator/migrator.go +++ b/runtime-monitor/vendor/gorm.io/gorm/migrator/migrator.go @@ -474,7 +474,6 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy // found, smart migrate fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)) realDataType := strings.ToLower(columnType.DatabaseTypeName()) - var ( alterColumn bool isSameType = fullDataType == realDataType @@ -513,8 +512,19 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } } } + } - // check precision + // check precision + if realDataType == "decimal" || realDataType == "numeric" && + regexp.MustCompile(realDataType+`\(.*\)`).FindString(fullDataType) != "" { // if realDataType has no precision,ignore + precision, scale, ok := columnType.DecimalSize() + if ok { + if !strings.HasPrefix(fullDataType, fmt.Sprintf("%s(%d,%d)", realDataType, precision, scale)) && + !strings.HasPrefix(fullDataType, fmt.Sprintf("%s(%d)", realDataType, precision)) { + alterColumn = true + } + } + } else { if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) { alterColumn = true @@ -550,6 +560,10 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy v1, _ := strconv.ParseBool(dv) v2, _ := strconv.ParseBool(field.DefaultValue) alterColumn = v1 != v2 + case schema.String: + if dv != field.DefaultValue && dv != strings.Trim(field.DefaultValue, "'\"") { + alterColumn = true + } default: alterColumn = dv != field.DefaultValue } diff --git a/runtime-monitor/vendor/gorm.io/gorm/schema/field.go b/runtime-monitor/vendor/gorm.io/gorm/schema/field.go index a6ff1a72..de797402 100644 --- a/runtime-monitor/vendor/gorm.io/gorm/schema/field.go +++ b/runtime-monitor/vendor/gorm.io/gorm/schema/field.go @@ -448,16 +448,17 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } // create valuer, setter when parse struct -func (field *Field) setupValuerAndSetter() { +func (field *Field) setupValuerAndSetter(modelType reflect.Type) { // Setup NewValuePool field.setupNewValuePool() // ValueOf returns field's value and if it is zero fieldIndex := field.StructField.Index[0] switch { - case len(field.StructField.Index) == 1 && fieldIndex > 0: - field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) { - fieldValue := reflect.Indirect(value).Field(fieldIndex) + case len(field.StructField.Index) == 1 && fieldIndex >= 0: + field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { + v = reflect.Indirect(v) + fieldValue := v.Field(fieldIndex) return fieldValue.Interface(), fieldValue.IsZero() } default: @@ -504,9 +505,10 @@ func (field *Field) setupValuerAndSetter() { // ReflectValueOf returns field's reflect value switch { - case len(field.StructField.Index) == 1 && fieldIndex > 0: - field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value { - return reflect.Indirect(value).Field(fieldIndex) + case len(field.StructField.Index) == 1 && fieldIndex >= 0: + field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { + v = reflect.Indirect(v) + return v.Field(fieldIndex) } default: field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { diff --git a/runtime-monitor/vendor/gorm.io/gorm/schema/relationship.go b/runtime-monitor/vendor/gorm.io/gorm/schema/relationship.go index f1ace924..0535bba4 100644 --- a/runtime-monitor/vendor/gorm.io/gorm/schema/relationship.go +++ b/runtime-monitor/vendor/gorm.io/gorm/schema/relationship.go @@ -75,9 +75,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { } ) - cacheStore := schema.cacheStore - - if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil { + if relation.FieldSchema, err = getOrParse(fieldValue, schema.cacheStore, schema.namer); err != nil { schema.err = fmt.Errorf("failed to parse field: %s, error: %w", field.Name, err) return nil } @@ -147,6 +145,9 @@ func hasPolymorphicRelation(tagSettings map[string]string) bool { } func (schema *Schema) setRelation(relation *Relationship) { + schema.Relationships.Mux.Lock() + defer schema.Relationships.Mux.Unlock() + // set non-embedded relation if rel := schema.Relationships.Relations[relation.Name]; rel != nil { if len(rel.Field.BindNames) > 1 { @@ -590,6 +591,10 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu // build references for idx, foreignField := range foreignFields { // use same data type for foreign keys + schema.Relationships.Mux.Lock() + if schema != foreignField.Schema { + foreignField.Schema.Relationships.Mux.Lock() + } if copyableDataType(primaryFields[idx].DataType) { foreignField.DataType = primaryFields[idx].DataType } @@ -597,6 +602,10 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu if foreignField.Size == 0 { foreignField.Size = primaryFields[idx].Size } + schema.Relationships.Mux.Unlock() + if schema != foreignField.Schema { + foreignField.Schema.Relationships.Mux.Unlock() + } relation.References = append(relation.References, &Reference{ PrimaryKey: primaryFields[idx], diff --git a/runtime-monitor/vendor/gorm.io/gorm/schema/schema.go b/runtime-monitor/vendor/gorm.io/gorm/schema/schema.go index db236797..09697a7a 100644 --- a/runtime-monitor/vendor/gorm.io/gorm/schema/schema.go +++ b/runtime-monitor/vendor/gorm.io/gorm/schema/schema.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "go/ast" + "path" "reflect" "strings" "sync" @@ -59,14 +60,14 @@ type Schema struct { cacheStore *sync.Map } -func (schema Schema) String() string { +func (schema *Schema) String() string { if schema.ModelType.Name() == "" { return fmt.Sprintf("%s(%s)", schema.Name, schema.Table) } return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), schema.ModelType.Name()) } -func (schema Schema) MakeSlice() reflect.Value { +func (schema *Schema) MakeSlice() reflect.Value { slice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(schema.ModelType)), 0, 20) results := reflect.New(slice.Type()) results.Elem().Set(slice) @@ -74,13 +75,23 @@ func (schema Schema) MakeSlice() reflect.Value { return results } -func (schema Schema) LookUpField(name string) *Field { +func (schema *Schema) LookUpField(name string) *Field { if field, ok := schema.FieldsByDBName[name]; ok { return field } if field, ok := schema.FieldsByName[name]; ok { return field } + + // Lookup field using namer-driven ColumnName + if schema.namer == nil { + return nil + } + namerColumnName := schema.namer.ColumnName(schema.Table, name) + if field, ok := schema.FieldsByDBName[namerColumnName]; ok { + return field + } + return nil } @@ -92,10 +103,7 @@ func (schema Schema) LookUpField(name string) *Field { // } // ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID") // } -func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field { - if len(bindNames) == 0 { - return nil - } +func (schema *Schema) LookUpFieldByBindName(bindNames []string, name string) *Field { for i := len(bindNames) - 1; i >= 0; i-- { find := strings.Join(bindNames[:i], ".") + "." + name if field, ok := schema.FieldsByBindName[find]; ok { @@ -113,6 +121,14 @@ type TablerWithNamer interface { TableName(Namer) string } +var callbackTypes = []callbackType{ + callbackTypeBeforeCreate, callbackTypeAfterCreate, + callbackTypeBeforeUpdate, callbackTypeAfterUpdate, + callbackTypeBeforeSave, callbackTypeAfterSave, + callbackTypeBeforeDelete, callbackTypeAfterDelete, + callbackTypeAfterFind, +} + // Parse get data type from dialector func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { return ParseWithSpecialTableName(dest, cacheStore, namer, "") @@ -124,34 +140,33 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } - value := reflect.ValueOf(dest) - if value.Kind() == reflect.Ptr && value.IsNil() { - value = reflect.New(value.Type().Elem()) - } - modelType := reflect.Indirect(value).Type() - - if modelType.Kind() == reflect.Interface { - modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type() - } - - for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType := reflect.ValueOf(dest).Type() + if modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() } if modelType.Kind() != reflect.Struct { - if modelType.PkgPath() == "" { - return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + if modelType.Kind() == reflect.Interface { + modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type() + } + + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType.Kind() != reflect.Struct { + if modelType.PkgPath() == "" { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } - return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } // Cache the Schema for performance, // Use the modelType or modelType + schemaTable (if it present) as cache key. - var schemaCacheKey interface{} + var schemaCacheKey interface{} = modelType if specialTableName != "" { schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName) - } else { - schemaCacheKey = modelType } // Load exist schema cache, return if exists @@ -162,28 +177,29 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam return s, s.err } + var tableName string modelValue := reflect.New(modelType) - tableName := namer.TableName(modelType.Name()) - if tabler, ok := modelValue.Interface().(Tabler); ok { + if specialTableName != "" { + tableName = specialTableName + } else if en, ok := namer.(embeddedNamer); ok { + tableName = en.Table + } else if tabler, ok := modelValue.Interface().(Tabler); ok { tableName = tabler.TableName() - } - if tabler, ok := modelValue.Interface().(TablerWithNamer); ok { + } else if tabler, ok := modelValue.Interface().(TablerWithNamer); ok { tableName = tabler.TableName(namer) - } - if en, ok := namer.(embeddedNamer); ok { - tableName = en.Table - } - if specialTableName != "" && specialTableName != tableName { - tableName = specialTableName + } else { + tableName = namer.TableName(modelType.Name()) } schema := &Schema{ Name: modelType.Name(), ModelType: modelType, Table: tableName, - FieldsByName: map[string]*Field{}, - FieldsByBindName: map[string]*Field{}, - FieldsByDBName: map[string]*Field{}, + DBNames: make([]string, 0, 10), + Fields: make([]*Field, 0, 10), + FieldsByName: make(map[string]*Field, 10), + FieldsByBindName: make(map[string]*Field, 10), + FieldsByDBName: make(map[string]*Field, 10), Relationships: Relationships{Relations: map[string]*Relationship{}}, cacheStore: cacheStore, namer: namer, @@ -227,8 +243,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam schema.FieldsByBindName[bindName] = field if v != nil && v.PrimaryKey { + // remove the existing primary key field for idx, f := range schema.PrimaryFields { - if f == v { + if f.DBName == v.DBName { schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...) } } @@ -247,7 +264,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam schema.FieldsByBindName[bindName] = field } - field.setupValuerAndSetter() + field.setupValuerAndSetter(modelType) } prioritizedPrimaryField := schema.LookUpField("id") @@ -283,10 +300,37 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName) } + _, embedded := schema.cacheStore.Load(embeddedCacheKey) + relationshipFields := []*Field{} for _, field := range schema.Fields { if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil { schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) } + + if !embedded { + if field.DataType == "" && field.GORMDataType == "" && (field.Creatable || field.Updatable || field.Readable) { + relationshipFields = append(relationshipFields, field) + schema.FieldsByName[field.Name] = field + schema.FieldsByBindName[field.BindName()] = field + } + + fieldValue := reflect.New(field.IndirectFieldType).Interface() + if fc, ok := fieldValue.(CreateClausesInterface); ok { + field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) + } + + if fc, ok := fieldValue.(QueryClausesInterface); ok { + field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) + } + + if fc, ok := fieldValue.(UpdateClausesInterface); ok { + field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) + } + + if fc, ok := fieldValue.(DeleteClausesInterface); ok { + field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) + } + } } if field := schema.PrioritizedPrimaryField; field != nil { @@ -303,24 +347,6 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } } - callbackTypes := []callbackType{ - callbackTypeBeforeCreate, callbackTypeAfterCreate, - callbackTypeBeforeUpdate, callbackTypeAfterUpdate, - callbackTypeBeforeSave, callbackTypeAfterSave, - callbackTypeBeforeDelete, callbackTypeAfterDelete, - callbackTypeAfterFind, - } - for _, cbName := range callbackTypes { - if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() { - switch methodValue.Type().String() { - case "func(*gorm.DB) error": // TODO hack - reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true) - default: - logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName) - } - } - } - // Cache the schema if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded { s := v.(*Schema) @@ -336,84 +362,47 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } }() - if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { - for _, field := range schema.Fields { - if field.DataType == "" && field.GORMDataType == "" && (field.Creatable || field.Updatable || field.Readable) { - if schema.parseRelation(field); schema.err != nil { - return schema, schema.err + for _, cbName := range callbackTypes { + if methodValue := modelValue.MethodByName(string(cbName)); methodValue.IsValid() { + switch methodValue.Type().String() { + case "func(*gorm.DB) error": + expectedPkgPath := path.Dir(reflect.TypeOf(schema).Elem().PkgPath()) + if inVarPkg := methodValue.Type().In(0).Elem().PkgPath(); inVarPkg == expectedPkgPath { + reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true) } else { - schema.FieldsByName[field.Name] = field - schema.FieldsByBindName[field.BindName()] = field + logger.Default.Warn(context.Background(), "In model %v, the hook function `%v(*gorm.DB) error` has an incorrect parameter type. The expected parameter type is `%v`, but the provided type is `%v`.", schema, cbName, expectedPkgPath, inVarPkg) + // PASS } + default: + logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName) } + } + } - fieldValue := reflect.New(field.IndirectFieldType) - fieldInterface := fieldValue.Interface() - if fc, ok := fieldInterface.(CreateClausesInterface); ok { - field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) - } - - if fc, ok := fieldInterface.(QueryClausesInterface); ok { - field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) - } - - if fc, ok := fieldInterface.(UpdateClausesInterface); ok { - field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) - } - - if fc, ok := fieldInterface.(DeleteClausesInterface); ok { - field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) - } + // parse relationships + for _, field := range relationshipFields { + if schema.parseRelation(field); schema.err != nil { + return schema, schema.err } } return schema, schema.err } -// This unrolling is needed to show to the compiler the exact set of methods -// that can be used on the modelType. -// Prior to go1.22 any use of MethodByName would cause the linker to -// abandon dead code elimination for the entire binary. -// As of go1.22 the compiler supports one special case of a string constant -// being passed to MethodByName. For enterprise customers or those building -// large binaries, this gives a significant reduction in binary size. -// https://github.com/golang/go/issues/62257 -func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value { - switch cbType { - case callbackTypeBeforeCreate: - return modelType.MethodByName(string(callbackTypeBeforeCreate)) - case callbackTypeAfterCreate: - return modelType.MethodByName(string(callbackTypeAfterCreate)) - case callbackTypeBeforeUpdate: - return modelType.MethodByName(string(callbackTypeBeforeUpdate)) - case callbackTypeAfterUpdate: - return modelType.MethodByName(string(callbackTypeAfterUpdate)) - case callbackTypeBeforeSave: - return modelType.MethodByName(string(callbackTypeBeforeSave)) - case callbackTypeAfterSave: - return modelType.MethodByName(string(callbackTypeAfterSave)) - case callbackTypeBeforeDelete: - return modelType.MethodByName(string(callbackTypeBeforeDelete)) - case callbackTypeAfterDelete: - return modelType.MethodByName(string(callbackTypeAfterDelete)) - case callbackTypeAfterFind: - return modelType.MethodByName(string(callbackTypeAfterFind)) - default: - return reflect.ValueOf(nil) - } -} - func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { modelType := reflect.ValueOf(dest).Type() - for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { - modelType = modelType.Elem() - } if modelType.Kind() != reflect.Struct { - if modelType.PkgPath() == "" { - return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType.Kind() != reflect.Struct { + if modelType.PkgPath() == "" { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } - return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } if v, ok := cacheStore.Load(modelType); ok { diff --git a/runtime-monitor/vendor/gorm.io/gorm/schema/serializer.go b/runtime-monitor/vendor/gorm.io/gorm/schema/serializer.go index 0fafbcba..d774a8ed 100644 --- a/runtime-monitor/vendor/gorm.io/gorm/schema/serializer.go +++ b/runtime-monitor/vendor/gorm.io/gorm/schema/serializer.go @@ -8,6 +8,7 @@ import ( "encoding/gob" "encoding/json" "fmt" + "math" "reflect" "strings" "sync" @@ -127,16 +128,31 @@ func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect. // Value implements serializer interface func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) { rv := reflect.ValueOf(fieldValue) - switch v := fieldValue.(type) { - case int64, int, uint, uint64, int32, uint32, int16, uint16: - result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC() - case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16: + switch fieldValue.(type) { + case int, int8, int16, int32, int64: + result = time.Unix(rv.Int(), 0).UTC() + case uint, uint8, uint16, uint32, uint64: + if uv := rv.Uint(); uv > math.MaxInt64 { + err = fmt.Errorf("integer overflow conversion uint64(%d) -> int64", uv) + } else { + result = time.Unix(int64(uv), 0).UTC() //nolint:gosec + } + case *int, *int8, *int16, *int32, *int64: if rv.IsZero() { return nil, nil } - result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC() + result = time.Unix(rv.Elem().Int(), 0).UTC() + case *uint, *uint8, *uint16, *uint32, *uint64: + if rv.IsZero() { + return nil, nil + } + if uv := rv.Elem().Uint(); uv > math.MaxInt64 { + err = fmt.Errorf("integer overflow conversion uint64(%d) -> int64", uv) + } else { + result = time.Unix(int64(uv), 0).UTC() //nolint:gosec + } default: - err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v) + err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", fieldValue) } return } diff --git a/runtime-monitor/vendor/gorm.io/gorm/schema/utils.go b/runtime-monitor/vendor/gorm.io/gorm/schema/utils.go index fa1c65d4..86305d7b 100644 --- a/runtime-monitor/vendor/gorm.io/gorm/schema/utils.go +++ b/runtime-monitor/vendor/gorm.io/gorm/schema/utils.go @@ -17,25 +17,23 @@ func ParseTagSetting(str string, sep string) map[string]string { settings := map[string]string{} names := strings.Split(str, sep) + var parsedNames []string for i := 0; i < len(names); i++ { - j := i - if len(names[j]) > 0 { - for { - if names[j][len(names[j])-1] == '\\' { - i++ - names[j] = names[j][0:len(names[j])-1] + sep + names[i] - names[i] = "" - } else { - break - } - } + s := names[i] + for strings.HasSuffix(s, "\\") && i+1 < len(names) { + i++ + s = s[:len(s)-1] + sep + names[i] } + parsedNames = append(parsedNames, s) + } - values := strings.Split(names[j], ":") + for _, tag := range parsedNames { + values := strings.Split(tag, ":") k := strings.TrimSpace(strings.ToUpper(values[0])) - if len(values) >= 2 { - settings[k] = strings.Join(values[1:], ":") + val := strings.Join(values[1:], ":") + val = strings.ReplaceAll(val, `\"`, `"`) + settings[k] = val } else if k != "" { settings[k] = k } @@ -121,6 +119,17 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, } switch reflectValue.Kind() { + case reflect.Map: + results = [][]interface{}{make([]interface{}, len(fields))} + for idx, field := range fields { + mapValue := reflectValue.MapIndex(reflect.ValueOf(field.DBName)) + if mapValue.IsZero() { + mapValue = reflectValue.MapIndex(reflect.ValueOf(field.Name)) + } + results[0][idx] = mapValue.Interface() + } + + dataResults[utils.ToStringKey(results[0]...)] = []reflect.Value{reflectValue} case reflect.Struct: results = [][]interface{}{make([]interface{}, len(fields))} diff --git a/runtime-monitor/vendor/gorm.io/gorm/statement.go b/runtime-monitor/vendor/gorm.io/gorm/statement.go index c6183724..736087d7 100644 --- a/runtime-monitor/vendor/gorm.io/gorm/statement.go +++ b/runtime-monitor/vendor/gorm.io/gorm/statement.go @@ -96,7 +96,9 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { if v.Name == clause.CurrentTable { if stmt.TableExpr != nil { stmt.TableExpr.Build(stmt) - } else { + } else if stmt.Table != "" { + write(v.Raw, stmt.Table) + } else if stmt.AddError(stmt.Parse(stmt.Model)) == nil { write(v.Raw, stmt.Table) } } else { @@ -334,6 +336,8 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] switch v := arg.(type) { case clause.Expression: conds = append(conds, v) + case []clause.Expression: + conds = append(conds, v...) case *DB: v.executeScopes() @@ -341,7 +345,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] if where, ok := cs.Expression.(clause.Where); ok { if len(where.Exprs) == 1 { if orConds, ok := where.Exprs[0].(clause.OrConditions); ok { - where.Exprs[0] = clause.AndConditions(orConds) + if len(orConds.Exprs) == 1 { + where.Exprs[0] = clause.AndConditions(orConds) + } } } conds = append(conds, clause.And(where.Exprs...)) @@ -362,6 +368,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] for _, key := range keys { column := clause.Column{Name: key, Table: curTable} + if strings.Contains(key, ".") { + column = clause.Column{Name: key} + } conds = append(conds, clause.Eq{Column: column, Value: v[key]}) } case map[string]interface{}: @@ -374,6 +383,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] for _, key := range keys { reflectValue := reflect.Indirect(reflect.ValueOf(v[key])) column := clause.Column{Name: key, Table: curTable} + if strings.Contains(key, ".") { + column = clause.Column{Name: key} + } switch reflectValue.Kind() { case reflect.Slice, reflect.Array: if _, ok := v[key].(driver.Valuer); ok { @@ -650,12 +662,15 @@ func (stmt *Statement) Changed(fields ...string) bool { for destValue.Kind() == reflect.Ptr { destValue = destValue.Elem() } - - changedValue, zero := field.ValueOf(stmt.Context, destValue) - if v { - return !utils.AssertEqual(changedValue, fieldValue) + if descSchema, err := schema.Parse(stmt.Dest, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { + if destField := descSchema.LookUpField(field.DBName); destField != nil { + changedValue, zero := destField.ValueOf(stmt.Context, destValue) + if v { + return !utils.AssertEqual(changedValue, fieldValue) + } + return !zero && !utils.AssertEqual(changedValue, fieldValue) + } } - return !zero && !utils.AssertEqual(changedValue, fieldValue) } } return false diff --git a/runtime-monitor/vendor/gorm.io/gorm/utils/utils.go b/runtime-monitor/vendor/gorm.io/gorm/utils/utils.go index fc615d73..7e59264b 100644 --- a/runtime-monitor/vendor/gorm.io/gorm/utils/utils.go +++ b/runtime-monitor/vendor/gorm.io/gorm/utils/utils.go @@ -30,8 +30,12 @@ func sourceDir(file string) string { return filepath.ToSlash(s) + "/" } -// FileWithLineNum return the file name and line number of the current file -func FileWithLineNum() string { +// CallerFrame retrieves the first relevant stack frame outside of GORM's internal implementation files. +// It skips: +// - GORM's core source files (identified by gormSourceDir prefix) +// - Exclude test files (*_test.go) +// - go-gorm/gen's Generated files (*.gen.go) +func CallerFrame() runtime.Frame { pcs := [13]uintptr{} // the third caller usually from gorm internal len := runtime.Callers(3, pcs[:]) @@ -41,14 +45,24 @@ func FileWithLineNum() string { frame, _ := frames.Next() if (!strings.HasPrefix(frame.File, gormSourceDir) || strings.HasSuffix(frame.File, "_test.go")) && !strings.HasSuffix(frame.File, ".gen.go") { - return string(strconv.AppendInt(append([]byte(frame.File), ':'), int64(frame.Line), 10)) + return frame } } + return runtime.Frame{} +} + +// FileWithLineNum return the file name and line number of the current file +func FileWithLineNum() string { + frame := CallerFrame() + if frame.PC != 0 { + return string(strconv.AppendInt(append([]byte(frame.File), ':'), int64(frame.Line), 10)) + } + return "" } -func IsValidDBNameChar(c rune) bool { +func IsInvalidDBNameChar(c rune) bool { return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@' } diff --git a/runtime-monitor/vendor/modules.txt b/runtime-monitor/vendor/modules.txt index 0efaadd1..8c1fb7b9 100644 --- a/runtime-monitor/vendor/modules.txt +++ b/runtime-monitor/vendor/modules.txt @@ -633,7 +633,7 @@ github.com/jackc/pgx/v5/pgproto3 github.com/jackc/pgx/v5/pgtype github.com/jackc/pgx/v5/pgxpool github.com/jackc/pgx/v5/stdlib -# github.com/jackc/puddle/v2 v2.2.1 +# github.com/jackc/puddle/v2 v2.2.2 ## explicit; go 1.19 github.com/jackc/puddle/v2 github.com/jackc/puddle/v2/internal/genstack @@ -1483,10 +1483,10 @@ gopkg.in/yaml.v2 # gopkg.in/yaml.v3 v3.0.1 ## explicit gopkg.in/yaml.v3 -# gorm.io/driver/postgres v1.5.9 -## explicit; go 1.19 +# gorm.io/driver/postgres v1.6.0 +## explicit; go 1.20 gorm.io/driver/postgres -# gorm.io/gorm v1.30.0 +# gorm.io/gorm v1.31.1 ## explicit; go 1.18 gorm.io/gorm gorm.io/gorm/callbacks