Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions common/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,18 @@ type (
SerialConsistency string `yaml:"serialConsistency"`
}

// PasswordCommandConfig configures an external command to fetch the datastore password.
// The command's stdout is used as the password.
PasswordCommandConfig struct {
// Command is the path to the executable to run.
Command string `yaml:"command"`
// Args is the list of arguments to pass to the command.
Args []string `yaml:"args"`
// Timeout is the maximum duration to wait for the command to complete.
// Defaults to 30 seconds if unset.
Timeout time.Duration `yaml:"timeout"`
}

// SQL is the configuration for connecting to a SQL backed datastore
SQL struct {
// Connect is a function that returns a sql db connection. String based configuration is ignored if this is provided.
Expand All @@ -408,6 +420,11 @@ type (
User string `yaml:"user"`
// Password is the password corresponding to the user name
Password string `yaml:"password"`
// PasswordCommand executes an external command and uses its stdout as the password.
// Mutually exclusive with Password.
// If the command returns an expiring token (e.g. cloud IAM), set MaxConnLifetime
// to ensure connections are recycled before the token expires.
PasswordCommand *PasswordCommandConfig `yaml:"passwordCommand"`
// PluginName is the name of SQL plugin
PluginName string `yaml:"pluginName" validate:"nonzero"`
// DatabaseName is the name of SQL database to connect to
Expand Down
54 changes: 52 additions & 2 deletions common/config/persistence.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
package config

import (
"bytes"
"context"
"errors"
"fmt"
"os/exec"
"reflect"
"strings"
"time"

"github.com/gocql/gocql"
"go.temporal.io/server/common/persistence/visibility/store/elasticsearch/client"
Expand Down Expand Up @@ -170,8 +174,13 @@ func (ds *DataStore) Validate() error {
)
}

if ds.SQL != nil && ds.SQL.TaskScanPartitions == 0 {
ds.SQL.TaskScanPartitions = 1
if ds.SQL != nil {
if ds.SQL.TaskScanPartitions == 0 {
ds.SQL.TaskScanPartitions = 1
}
if err := ds.SQL.validate(); err != nil {
return err
}
}
if ds.Cassandra != nil {
if err := ds.Cassandra.validate(); err != nil {
Expand Down Expand Up @@ -271,3 +280,44 @@ func parseSerialConsistency(serialConsistency string) (gocql.SerialConsistency,
err := s.UnmarshalText([]byte(strings.ToUpper(serialConsistency)))
return s, err
}

func (c *SQL) validate() error {
Comment thread
simvlad marked this conversation as resolved.
if c.PasswordCommand != nil && c.Password != "" {
return errors.New("passwordCommand and password are mutually exclusive")
}
if c.PasswordCommand != nil && c.PasswordCommand.Command == "" {
return errors.New("passwordCommand.command must not be empty")
}
return nil
}

const (
defaultPasswordCommandTimeout = 30 * time.Second
passwordCommandWaitDelay = 5 * time.Second
)

// ResolvePassword returns the database password, either from the static Password
// field or by executing PasswordCommand. If neither is set, it returns an empty string.
func (c *SQL) ResolvePassword() (string, error) {
if c.PasswordCommand == nil {
return c.Password, nil
}
timeout := c.PasswordCommand.Timeout
if timeout == 0 {
timeout = defaultPasswordCommandTimeout
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
cmd := exec.CommandContext(ctx, c.PasswordCommand.Command, c.PasswordCommand.Args...) //nolint:gosec
// WaitDelay caps how long we block on the stdout pipe after the process is killed.
// Without it, a subprocess that inherits the pipe could keep it open indefinitely.
cmd.WaitDelay = passwordCommandWaitDelay
var stderr bytes.Buffer
cmd.Stderr = &stderr
out, err := cmd.Output()
if err != nil {
return "", fmt.Errorf("passwordCommand %q %v failed: %w (stderr: %s)",
c.PasswordCommand.Command, c.PasswordCommand.Args, err, stderr.String())
}
return strings.TrimRight(string(out), "\n\r"), nil
}
102 changes: 102 additions & 0 deletions common/config/persistence_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,112 @@ package config
import (
"reflect"
"testing"
"time"

"github.com/gocql/gocql"
"github.com/stretchr/testify/require"
)

func TestSQLValidate_MutualExclusivity(t *testing.T) {
cfg := &SQL{
Password: "static",
PasswordCommand: &PasswordCommandConfig{
Command: "echo",
Args: []string{"dynamic"},
},
}
err := cfg.validate()
require.ErrorContains(t, err, "mutually exclusive")
}

func TestSQLValidate_PasswordOnly(t *testing.T) {
cfg := &SQL{Password: "static"}
err := cfg.validate()
require.NoError(t, err)
}

func TestSQLValidate_PasswordCommandOnly(t *testing.T) {
cfg := &SQL{
PasswordCommand: &PasswordCommandConfig{
Command: "echo",
Args: []string{"dynamic"},
},
}
err := cfg.validate()
require.NoError(t, err)
}

func TestSQLValidate_PasswordCommandEmptyCommand(t *testing.T) {
cfg := &SQL{
PasswordCommand: &PasswordCommandConfig{},
}
err := cfg.validate()
require.ErrorContains(t, err, "passwordCommand.command must not be empty")
}

func TestSQLResolvePassword_Static(t *testing.T) {
cfg := &SQL{Password: "static-pass"}
pw, err := cfg.ResolvePassword()
require.NoError(t, err)
require.Equal(t, "static-pass", pw)
}

func TestSQLResolvePassword_EmptyWhenNothingSet(t *testing.T) {
cfg := &SQL{}
pw, err := cfg.ResolvePassword()
require.NoError(t, err)
require.Empty(t, pw)
}

func TestSQLResolvePassword_Command(t *testing.T) {
cfg := &SQL{
PasswordCommand: &PasswordCommandConfig{
Command: "echo",
Args: []string{"hello"},
},
}
pw, err := cfg.ResolvePassword()
require.NoError(t, err)
require.Equal(t, "hello", pw)
}

func TestSQLResolvePassword_CommandTrimsTrailingNewline(t *testing.T) {
cfg := &SQL{
PasswordCommand: &PasswordCommandConfig{
Command: "printf",
Args: []string{"hello\n\n"},
},
}
pw, err := cfg.ResolvePassword()
require.NoError(t, err)
require.Equal(t, "hello", pw)
}

func TestSQLResolvePassword_CommandFailure(t *testing.T) {
cfg := &SQL{
PasswordCommand: &PasswordCommandConfig{
Command: "false",
},
}
_, err := cfg.ResolvePassword()
require.ErrorContains(t, err, "passwordCommand")
}

func TestSQLResolvePassword_CommandTimeout(t *testing.T) {
cfg := &SQL{
PasswordCommand: &PasswordCommandConfig{
Command: "sleep",
Args: []string{"10"},
Timeout: 10 * time.Millisecond,
},
}
start := time.Now()
_, err := cfg.ResolvePassword()
elapsed := time.Since(start)
require.ErrorContains(t, err, "passwordCommand")
require.Less(t, elapsed, 5*time.Second, "command should have been killed by timeout")
}

func TestCassandraStoreConsistency_GetConsistency(t *testing.T) {
t.Parallel()

Expand Down
43 changes: 43 additions & 0 deletions common/persistence/sql/sqlplugin/connector.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package sqlplugin

import (
"context"
"database/sql/driver"
)

// RefreshingConnector is a driver.Connector that calls buildDSN on every
// Connect, so each new physical connection gets a fresh credential. This is
// used when passwordCommand is configured to fetch short-lived tokens.
type RefreshingConnector struct {
buildDSN func() (string, error)
newConnector func(dsn string) (driver.Connector, error)
driver driver.Driver
}

func NewRefreshingConnector(
buildDSN func() (string, error),
newConnector func(dsn string) (driver.Connector, error),
d driver.Driver,
) *RefreshingConnector {
return &RefreshingConnector{
buildDSN: buildDSN,
newConnector: newConnector,
driver: d,
}
}

func (c *RefreshingConnector) Connect(ctx context.Context) (driver.Conn, error) {
dsn, err := c.buildDSN()
if err != nil {
return nil, err
}
connector, err := c.newConnector(dsn)
if err != nil {
return nil, err
}
return connector.Connect(ctx)
}

func (c *RefreshingConnector) Driver() driver.Driver {
return c.driver
}
74 changes: 74 additions & 0 deletions common/persistence/sql/sqlplugin/connector_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package sqlplugin

import (
"context"
"database/sql/driver"
"errors"
"testing"

"github.com/stretchr/testify/require"
)

type fakeConn struct{}

func (fakeConn) Prepare(string) (driver.Stmt, error) { return nil, nil }
func (fakeConn) Close() error { return nil }
func (fakeConn) Begin() (driver.Tx, error) { return nil, nil }

type fakeConnector struct{ dsn string }

func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) { return fakeConn{}, nil }
func (c *fakeConnector) Driver() driver.Driver { return nil }

func TestRefreshingConnector_CallsBuildDSNOnEachConnect(t *testing.T) {
callCount := 0
c := NewRefreshingConnector(
func() (string, error) {
callCount++
return "dsn-" + string(rune('0'+callCount)), nil
},
func(dsn string) (driver.Connector, error) {
return &fakeConnector{dsn: dsn}, nil
},
nil,
)

for range 3 {
_, err := c.Connect(context.Background())
require.NoError(t, err)
}
require.Equal(t, 3, callCount)
}

func TestRefreshingConnector_PropagatesBuildDSNError(t *testing.T) {
expectedErr := errors.New("token expired")
c := NewRefreshingConnector(
func() (string, error) {
return "", expectedErr
},
func(dsn string) (driver.Connector, error) {
t.Fatal("NewConnector should not be called when BuildDSN fails")
return nil, nil
},
nil,
)

_, err := c.Connect(context.Background())
require.ErrorIs(t, err, expectedErr)
}

func TestRefreshingConnector_PropagatesNewConnectorError(t *testing.T) {
expectedErr := errors.New("bad dsn")
c := NewRefreshingConnector(
func() (string, error) {
return "some-dsn", nil
},
func(dsn string) (driver.Connector, error) {
return nil, expectedErr
},
nil,
)

_, err := c.Connect(context.Background())
require.ErrorIs(t, err, expectedErr)
}
Loading
Loading