Skip to content
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ require (
github.com/apache/cassandra-gocql-driver/v2 v2.0.0
github.com/datastax/go-cassandra-native-protocol v0.0.0-20260130100129-9d5b43677a33
github.com/google/uuid v1.1.1
github.com/hashicorp/golang-lru/v2 v2.0.7
github.com/jpillora/backoff v1.0.0
github.com/kelseyhightower/envconfig v1.4.0
github.com/mcuadros/go-defaults v1.2.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA=
github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
Expand Down
1 change: 1 addition & 0 deletions integration-tests/setup/testcluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,7 @@ func NewTestConfig(originHost string, targetHost string) *config.Config {

conf.ProxyMaxClientConnections = 1000
conf.ProxyMaxStreamIds = 2048
conf.ProxyMaxPreparedStatementCacheSize = 10000

conf.RequestResponseMaxWorkers = -1
conf.WriteMaxWorkers = -1
Expand Down
11 changes: 6 additions & 5 deletions proxy/pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,12 @@ type Config struct {

// Proxy bucket

ProxyListenAddress string `default:"localhost" split_words:"true" yaml:"proxy_listen_address"`
ProxyListenPort int `default:"14002" split_words:"true" yaml:"proxy_listen_port"`
ProxyRequestTimeoutMs int `default:"10000" split_words:"true" yaml:"proxy_request_timeout_ms"`
ProxyMaxClientConnections int `default:"1000" split_words:"true" yaml:"proxy_max_client_connections"`
ProxyMaxStreamIds int `default:"2048" split_words:"true" yaml:"proxy_max_stream_ids"`
ProxyListenAddress string `default:"localhost" split_words:"true" yaml:"proxy_listen_address"`
ProxyListenPort int `default:"14002" split_words:"true" yaml:"proxy_listen_port"`
ProxyRequestTimeoutMs int `default:"10000" split_words:"true" yaml:"proxy_request_timeout_ms"`
ProxyMaxClientConnections int `default:"1000" split_words:"true" yaml:"proxy_max_client_connections"`
ProxyMaxStreamIds int `default:"2048" split_words:"true" yaml:"proxy_max_stream_ids"`
ProxyMaxPreparedStatementCacheSize int `default:"10000" split_words:"true" yaml:"proxy_max_prepared_statement_cache_size"`

ProxyTlsCaPath string `split_words:"true" yaml:"proxy_tls_ca_path"`
ProxyTlsCertPath string `split_words:"true" yaml:"proxy_tls_cert_path"`
Expand Down
16 changes: 8 additions & 8 deletions proxy/pkg/zdmproxy/cqlparser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,14 @@ func TestInspectFrame(t *testing.T) {
targetPreparedId: []byte("LOCAL"),
prepareRequestInfo: NewPrepareRequestInfo(NewInterceptedRequestInfo(local, newStarSelectClause()), nil, false, "SELECT * FROM system.local", ""),
}
psCache := NewPreparedStatementCache()
psCache.cache["BOTH"] = bothCacheEntry
psCache.cache["ORIGIN"] = originCacheEntry
psCache.cache["TARGET"] = targetCacheEntry
psCache.interceptedCache["PEERS"] = peersCacheEntry
psCache.interceptedCache["PEERS_KS"] = peersKsCacheEntry
psCache.interceptedCache["LOCAL"] = localCacheEntry
psCache.interceptedCache["LOCAL_KS"] = localKsCacheEntry
psCache := createPSCacheForTests(t)
psCache.cache.Add("BOTH", bothCacheEntry)
psCache.cache.Add("ORIGIN", originCacheEntry)
psCache.cache.Add("TARGET", targetCacheEntry)
psCache.interceptedCache.Add("PEERS", peersCacheEntry)
psCache.interceptedCache.Add("PEERS_KS", peersKsCacheEntry)
psCache.interceptedCache.Add("LOCAL", localCacheEntry)
psCache.interceptedCache.Add("LOCAL_KS", localKsCacheEntry)
mh := newFakeMetricHandler()
km := ""
primaryClusterTarget := common.ClusterTypeTarget
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func getGeneralParamsForTests(t *testing.T) params {
require.Nil(t, err)

return params{
psCache: NewPreparedStatementCache(),
psCache: createPSCacheForTests(t),
mh: newFakeMetricHandler(),
kn: "",
primaryCluster: common.ClusterTypeOrigin,
Expand All @@ -38,6 +38,12 @@ func getGeneralParamsForTests(t *testing.T) params {
}
}

func createPSCacheForTests(t *testing.T) *PreparedStatementCache {
psCache, err := NewPreparedStatementCache(1000)
require.Nil(t, err)
return psCache
}

func buildQueryMessageForTests(queryString string) *message.Query {
var defaultTimestamp int64 = 1647023221311969
var serialConsistency = primitive.ConsistencyLevelLocalSerial
Expand Down
5 changes: 4 additions & 1 deletion proxy/pkg/zdmproxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,10 @@ func (p *ZdmProxy) initializeGlobalStructures() error {
p.globalClientHandlersWg = &sync.WaitGroup{}
p.clientHandlersShutdownRequestCtx, p.clientHandlersShutdownRequestCancelFn = context.WithCancel(context.Background())

p.PreparedStatementCache = NewPreparedStatementCache()
p.PreparedStatementCache, err = NewPreparedStatementCache(p.Conf.ProxyMaxPreparedStatementCacheSize)
if err != nil {
return err
}

p.controlConnShutdownCtx, p.controlConnCancelFn = context.WithCancel(context.Background())
p.controlConnShutdownWg = &sync.WaitGroup{}
Expand Down
71 changes: 50 additions & 21 deletions proxy/pkg/zdmproxy/pscache.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,57 @@ package zdmproxy
import (
"encoding/hex"
"fmt"
"sync"

"github.com/datastax/go-cassandra-native-protocol/message"
"github.com/hashicorp/golang-lru/v2/simplelru"
log "github.com/sirupsen/logrus"
"sync"
)

type PreparedStatementCache struct {
cache map[string]PreparedData // Map containing the prepared queries (raw bytes) keyed on prepareId
index map[string]string // Map that can be used as an index to look up origin prepareIds by target prepareId
cache *simplelru.LRU[string, PreparedData] // Map containing the prepared queries (raw bytes) keyed on prepareId
index map[string]string // Map that can be used as an index to look up origin prepareIds by target prepareId

interceptedCache map[string]PreparedData // Map containing the prepared queries for intercepted requests
interceptedCache *simplelru.LRU[string, PreparedData] // Map containing the prepared queries for intercepted requests

lock *sync.RWMutex
}

func NewPreparedStatementCache() *PreparedStatementCache {
func NewPreparedStatementCache(maxSize int) (*PreparedStatementCache, error) {
indexMap := make(map[string]string)

cache, err := simplelru.NewLRU[string, PreparedData](maxSize, func(key string, value PreparedData) {
// this is called by LRU.Add() so we already have a lock here
delete(indexMap, string(value.GetTargetPreparedId()))
})
if err != nil {
return nil, fmt.Errorf("error initializing the PreparedStatementCache cache map: %v", err)
}

interceptedCache, err := simplelru.NewLRU[string, PreparedData](maxSize, nil)
if err != nil {
return nil, fmt.Errorf("error initializing the PreparedStatementCache interceptedCache map: %v", err)
}

return &PreparedStatementCache{
cache: make(map[string]PreparedData),
index: make(map[string]string),
interceptedCache: make(map[string]PreparedData),
cache: cache,
index: indexMap,
interceptedCache: interceptedCache,
lock: &sync.RWMutex{},
}
}, nil
}

func (psc PreparedStatementCache) GetPreparedStatementCacheSize() float64 {
psc.lock.RLock()
defer psc.lock.RUnlock()

return float64(len(psc.cache) + len(psc.interceptedCache))
cacheLen := psc.cache.Len()
interceptedCacheLen := psc.interceptedCache.Len()

log.Debugf("PS Cache Size: %v, PS Intercepted Size: %v, PS Index Size: %v.",
cacheLen, interceptedCacheLen, len(psc.index))

return float64(cacheLen + interceptedCacheLen)
}

func (psc *PreparedStatementCache) Store(
Expand All @@ -42,7 +65,7 @@ func (psc *PreparedStatementCache) Store(
psc.lock.Lock()
defer psc.lock.Unlock()

psc.cache[originPrepareIdStr] = NewPreparedData(originPreparedResult, targetPreparedResult, prepareRequestInfo)
psc.cache.Add(originPrepareIdStr, NewPreparedData(originPreparedResult, targetPreparedResult, prepareRequestInfo))
psc.index[targetPrepareIdStr] = originPrepareIdStr

log.Debugf("Storing PS cache entry: {OriginPreparedId=%v, TargetPreparedId: %v, RequestInfo: %v}",
Expand All @@ -55,33 +78,39 @@ func (psc *PreparedStatementCache) StoreIntercepted(preparedResult *message.Prep
defer psc.lock.Unlock()

preparedData := NewPreparedData(preparedResult, preparedResult, prepareRequestInfo)
psc.interceptedCache[prepareIdStr] = preparedData
psc.interceptedCache.Add(prepareIdStr, preparedData)

log.Debugf("Storing intercepted PS cache entry: {PreparedId=%v, RequestInfo: %v}",
hex.EncodeToString(preparedResult.PreparedQueryId), prepareRequestInfo)
}

func (psc *PreparedStatementCache) Get(originPreparedId []byte) (PreparedData, bool) {
psc.lock.RLock()
defer psc.lock.RUnlock()
data, ok := psc.cache[string(originPreparedId)]
if !ok {
data, ok = psc.interceptedCache[string(originPreparedId)]
psc.lock.Lock()
defer psc.lock.Unlock()
data, ok := psc.cache.Get(string(originPreparedId))
if ok {
return data, true
}

data, ok = psc.interceptedCache.Get(string(originPreparedId))
if ok {
return data, true
}
return data, ok

return nil, false
}

func (psc *PreparedStatementCache) GetByTargetPreparedId(targetPreparedId []byte) (PreparedData, bool) {
psc.lock.RLock()
defer psc.lock.RUnlock()
psc.lock.Lock()
defer psc.lock.Unlock()

originPreparedId, ok := psc.index[string(targetPreparedId)]
if !ok {
// Don't bother attempting a lookup on the intercepted cache because this method should only be used to handle UNPREPARED responses
return nil, false
}

data, ok := psc.cache[originPreparedId]
data, ok := psc.cache.Get(originPreparedId)
if !ok {
log.Errorf("Could not get prepared data by target id even though there is an entry on the index map. "+
"This is most likely a bug. OriginPreparedId = %v, TargetPreparedId = %v", originPreparedId, targetPreparedId)
Expand Down
Loading
Loading