From 23dc945ae0ccd0b64d35466fba0e920f1d2993f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20de=20Le=C3=B3n?= Date: Tue, 28 Apr 2026 17:13:16 -0300 Subject: [PATCH 1/6] chore: update fucntions to return errors --- configstore/configstore.go | 40 ++++++++----- configstore/configstore_test.go | 88 +++++++++++++++------------- pluginmanager/pluginmanager.go | 80 +++++++++++++++---------- pluginmanager/pluginmanager_test.go | 9 ++- wacecore.go | 52 ++++++++++------- wacecore_test.go | 91 +++++++++++++++++++++-------- 6 files changed, 229 insertions(+), 131 deletions(-) diff --git a/configstore/configstore.go b/configstore/configstore.go index 98f52c8..c0dfc21 100644 --- a/configstore/configstore.go +++ b/configstore/configstore.go @@ -76,8 +76,8 @@ type modelPluginConfig struct { Threshold float64 Params map[string]string PluginType ModelPluginType - Mode string - Remote bool + Mode string + Remote bool } // DecisionPluginConfig stores the configuration of a decision plugin @@ -95,18 +95,32 @@ type ConfigStore struct { DecisionPlugins map[string]decisionPluginConfig LogPath string LogLevel lg.LogLevel - NatsURL string - ApplicationId string + NatsURL string + ApplicationId string } var config *ConfigStore -// Get returns or creates the unique instance of configstore -func Get() *ConfigStore { +// Create and returns the unique instance of configstore if it does not exist previously, in other case returns error +func New() (*ConfigStore, error) { + if config != nil { + return nil, fmt.Errorf("ConfigStore already exists") + } + config = new(ConfigStore) + return config, nil +} + +// Get returns the unique instance of configstore +func Get() (*ConfigStore, error) { if config == nil { - config = new(ConfigStore) + return nil, fmt.Errorf("Configuration was not loaded") } - return config + return config, nil +} + +// Clean remove the references to the stored instance of configstore +func Clean() { + config = nil } type configFileModelPlugin struct { @@ -116,8 +130,8 @@ type configFileModelPlugin struct { Threshold float64 Params map[string]string PluginType string `yaml:"plugintype"` - Mode string - Remote bool + Mode string + Remote bool } type configFileDecisionPlugin struct { @@ -133,7 +147,7 @@ type ConfigFileData struct { Loglevel string Modelplugins []configFileModelPlugin Decisionplugins []configFileDecisionPlugin - NatsURL string + NatsURL string } // IsAsync returns true if the model plugin is async @@ -243,6 +257,6 @@ func (cs *ConfigStore) SetConfig(inConf ConfigFileData) error { } else { cs.NatsURL = "localhost:4222" } - + return nil -} \ No newline at end of file +} diff --git a/configstore/configstore_test.go b/configstore/configstore_test.go index c5ed707..67abf0f 100644 --- a/configstore/configstore_test.go +++ b/configstore/configstore_test.go @@ -43,9 +43,12 @@ decisionplugins: `) func initialize(configuration []byte) error { - cs := Get() + cs, err := Get() + if err != nil { + return err + } var aux ConfigFileData - err := yaml.Unmarshal(configuration, &aux) + err = yaml.Unmarshal(configuration, &aux) if err != nil { return err } @@ -57,31 +60,55 @@ func initialize(configuration []byte) error { } func TestLoadConfigYamlEmpty(t *testing.T) { + _, err := New() + if err != nil { + t.Error(err) + } + + defer Clean() - err := initialize([]byte(`---`)) + err = initialize([]byte(`---`)) if err == nil { - t.Errorf("empty config does not return error") + t.Error("empty config does not return error") } } func TestLoadConfigYamlValid(t *testing.T) { + _, err := New() + if err != nil { + t.Error(err) + } + + defer Clean() - err := initialize(validConfig) + err = initialize(validConfig) if err != nil { t.Errorf("valid config returned error: %v", err) } } func TestLoadConfigYamlInvalid(t *testing.T) { + _, err := New() + if err != nil { + t.Error(err) + } + + defer Clean() - err := initialize([]byte(`()=)(/&/()~@#~½¬{[{½¬½---sfdjlskjfs#@~sjdfa`)) + err = initialize([]byte(`()=)(/&/()~@#~½¬{[{½¬½---sfdjlskjfs#@~sjdfa`)) if err == nil { - t.Errorf("invalid config does not return error") + t.Error("invalid config does not return error") } } func TestLoadConfigYamlLogLevel(t *testing.T) { + _, err := New() + if err != nil { + t.Error(err) + } + + defer Clean() values := []string{ "a", @@ -93,7 +120,7 @@ func TestLoadConfigYamlLogLevel(t *testing.T) { config := `--- logpath: "/dev/null" loglevel: ` + v - err := initialize([]byte(config)) + err = initialize([]byte(config)) if err == nil { t.Errorf("invalid log level %v does not return error", v) } @@ -101,9 +128,14 @@ loglevel: ` + v } func TestLoadConfigYamlPluginType(t *testing.T) { - cs := Get() + cs, err := New() + if err != nil { + t.Error(err) + } + + defer Clean() - err := initialize([]byte(`--- + err = initialize([]byte(`--- loglevel: ERROR logpath: /dev/null modelplugins: @@ -203,37 +235,15 @@ modelplugins: } } -// func TestLoadConfig(t *testing.T) { -// cs := Get() - -// err := cs.LoadConfig("") -// if err == nil { -// t.Errorf("empty config file path does not return error") -// } - -// err = cs.LoadConfig("/dev/null") -// if err == nil { -// t.Errorf("empty config file contents does not return error") -// } - -// tmpFile, err := ioutil.TempFile(os.TempDir(), "configstore_test-") -// if err != nil { -// t.Errorf("cannot create temporary file: %v", err) -// } -// defer os.Remove(tmpFile.Name()) - -// if _, err = tmpFile.Write(validConfig); err != nil { -// t.Errorf("failed to write to temporary file: %v", err) -// } -// err = cs.LoadConfig(tmpFile.Name()) -// if err != nil { -// t.Errorf("valid config file returned error: %v", err) -// } -// } - func TestInvalidLogging(t *testing.T) { + _, err := New() + if err != nil { + t.Error(err) + } + + defer Clean() - err := initialize([]byte(`--- + err = initialize([]byte(`--- loglevel: INVALIDLOGLEVEL logpath: /dev/null `)) diff --git a/pluginmanager/pluginmanager.go b/pluginmanager/pluginmanager.go index 21b4534..2b3e46c 100644 --- a/pluginmanager/pluginmanager.go +++ b/pluginmanager/pluginmanager.go @@ -78,9 +78,12 @@ type PluginManager struct { } // New creates a new PluginManager instance. -func New(meter metric.Meter) *PluginManager { +func New(meter metric.Meter) (*PluginManager, error) { pm := new(PluginManager) - conf := cf.Get() + conf, err := cf.Get() + if err != nil { + return nil, err + } logger := lg.Get() logger.Printf(lg.DEBUG, "Connecting to NATS server at %s", conf.NatsURL) @@ -187,7 +190,7 @@ func New(meter metric.Meter) *PluginManager { decisionPluginLoaded := decisionPlugin{tp} pm.decisionPlugins[data.ID] = decisionPluginLoaded } - return pm + return pm, nil } // InitTransaction initializes the transaction with the given ID @@ -205,8 +208,9 @@ func (p *PluginManager) CloseTransaction(transactionId string) { } else { transactionMap.(*sync.Map).Range(func(key, value interface{}) bool { ch := value.(chan ModelStatus) - close(ch) - for range ch {} + close(ch) + for range ch { + } transactionMap.(*sync.Map).Delete(key) return true }) @@ -241,13 +245,14 @@ func (p *PluginManager) RemoveAsyncModelChannel(transactionId string, t cf.Model typeModel, ok := p.asyncModelsChannels.Load(transactionId) if ok { channelMap := typeModel.(*sync.Map) - ch, channelOk := channelMap.Load(t.String()) + ch, channelOk := channelMap.Load(t.String()) - if channelOk { + if channelOk { close(ch.(chan ModelStatus)) - for range ch.(chan ModelStatus) {} - channelMap.Delete(t.String()) - } + for range ch.(chan ModelStatus) { + } + channelMap.Delete(t.String()) + } remainChannels := 0 typeModel.(*sync.Map).Range(func(key, value interface{}) bool { @@ -280,44 +285,48 @@ func (p *PluginManager) AddToQueue(modelId, transactionId, payload string) error } // Process is in charge of calling the model plugin with id modelID -func (p *PluginManager) Process(modelID, transactionId, payload string, t cf.ModelPluginType, modelPlugStatus chan ModelStatus) { - conf := cf.Get() +func (p *PluginManager) Process(modelID, transactionId, payload string, t cf.ModelPluginType, modelPlugStatus chan ModelStatus) error { + conf, err := cf.Get() + if err != nil { + return err + } mp, exists := p.modelPlugins[modelID] if !exists { modelPlugStatus <- ModelStatus{ModelID: modelID, Err: fmt.Errorf("model plugin not found")} - return + return nil } // check if the plugin is capable of analyzing the indicated part of the transaction if mp.pluginType != t { modelPlugStatus <- ModelStatus{ModelID: modelID, Err: fmt.Errorf("plugin type %v cannot process a request with incompatible type %v", mp.pluginType, t)} - return + return nil } process := p.modelProcessFunc[modelID] if conf.ModelPlugins[modelID].Mode == "async" { modelPlugStatus <- ModelStatus{ModelID: modelID, Err: fmt.Errorf("model plugin is async")} - return + return nil } else { res, err := process(ModelInput{TransactionId: transactionId, Payload: payload}) // res, err := process(transactionId, payload) if err != nil { modelPlugStatus <- ModelStatus{ModelID: modelID, Err: err} - return + return nil } // store the results resultSyncMap, ok := p.results.Load(transactionId) if !ok { modelPlugStatus <- ModelStatus{ModelID: modelID, Err: fmt.Errorf("transaction results not found")} - return + return nil } resultSyncMap.(*sync.Map).Store(modelID, res) modelPlugStatus <- ModelStatus{ModelID: modelID, ProbAttack: res.ProbAttack, Err: nil} } + return nil } // CheckResult is in charge of calling the decision plugin with id decisionID over the @@ -335,13 +344,16 @@ func (p *PluginManager) CheckResult(transactionId, decisionId string, wafParams return false, fmt.Errorf("transaction results not found") } - configStore := cf.Get() + cs, err := cf.Get() + if err != nil { + return false, nil + } modelResultMap := make(map[string]ModelResults) modelWeightMap := make(map[string]float64) transactionResults.(*sync.Map).Range(func(key, value interface{}) bool { modelResultMap[key.(string)] = value.(ModelResults) - modelWeightMap[key.(string)] = configStore.ModelPlugins[key.(string)].Weight + modelWeightMap[key.(string)] = cs.ModelPlugins[key.(string)].Weight return true }) @@ -352,9 +364,12 @@ func (p *PluginManager) CheckResult(transactionId, decisionId string, wafParams } // ModelResultsHandler listens for messages on the model results queue -func (p *PluginManager) ModelResultsHandler(modelId string) { +func (p *PluginManager) ModelResultsHandler(modelId string) error { logger := lg.Get() - conf := cf.Get() + cs, err := cf.Get() + if err != nil { + return err + } sub, err := p.natConn.Subscribe(modelId+"/results", func(msg *nats.Msg) { go func(msg nats.Msg) { @@ -365,7 +380,7 @@ func (p *PluginManager) ModelResultsHandler(modelId string) { } else { var channel interface{} var ok bool - if conf.ModelPlugins[modelId].Mode == "async" { + if cs.ModelPlugins[modelId].Mode == "async" { channel, ok = p.asyncModelsChannels.Load(data.TransactionId) } else { channel, ok = p.syncModelsChannels.Load(data.TransactionId) @@ -373,14 +388,14 @@ func (p *PluginManager) ModelResultsHandler(modelId string) { if !ok { logger.TPrintf(lg.ERROR, data.TransactionId, " Model %s | Transaction not found", modelId) } else { - modelChannel, ok := channel.(*sync.Map).Load(conf.ModelPlugins[modelId].PluginType.String()) + modelChannel, ok := channel.(*sync.Map).Load(cs.ModelPlugins[modelId].PluginType.String()) if !ok { logger.Printf(lg.ERROR, "Model %s not found", modelId) } else { if data.Error != nil { modelChannel.(chan ModelStatus) <- ModelStatus{ModelID: modelId, Err: data.Error} } else { - if conf.ModelPlugins[modelId].Mode != "async" { + if cs.ModelPlugins[modelId].Mode != "async" { // store the results resultSyncMap, ok := p.results.Load(data.TransactionId) if !ok { @@ -400,7 +415,7 @@ func (p *PluginManager) ModelResultsHandler(modelId string) { if err != nil { logger.Printf(lg.ERROR, "Model: %s | Failed to subscribe to model queue | %s", modelId, err.Error()) - return + return err } logger.Printf(lg.INFO, "Model: %s | Listening for messages on model results queue", modelId) @@ -409,19 +424,23 @@ func (p *PluginManager) ModelResultsHandler(modelId string) { defer p.natConn.Drain() select {} + } // ModelProcessHandler listens for messages on the model queue -func ModelProcessHandler(modelId string, modelProcess func(ModelInput) (ModelResults, error)) { +func ModelProcessHandler(modelId string, modelProcess func(ModelInput) (ModelResults, error)) error { logger := lg.Get() logger.Printf(lg.INFO, "Model: %s | Starting model process handler", modelId) - conf := cf.Get() + cs, err := cf.Get() + if err != nil { + return err + } - nc, err := nats.Connect(conf.NatsURL) + nc, err := nats.Connect(cs.NatsURL) if err != nil { logger.Printf(lg.ERROR, "Model: %s | Failed to connect to NATS server", modelId) - return + return err } _, err = nc.Subscribe(modelId, func(msg *nats.Msg) { @@ -452,8 +471,9 @@ func ModelProcessHandler(modelId string, modelProcess func(ModelInput) (ModelRes if err != nil { logger.Printf(lg.ERROR, "Model: %s | Failed to subscribe to model queue | %s", modelId, err.Error()) - return + return err } logger.Printf(lg.INFO, "Model: %s | Listening for messages on model queue", modelId) + return nil } diff --git a/pluginmanager/pluginmanager_test.go b/pluginmanager/pluginmanager_test.go index a00efa0..c6e6f48 100644 --- a/pluginmanager/pluginmanager_test.go +++ b/pluginmanager/pluginmanager_test.go @@ -56,14 +56,17 @@ func initilize(configuration []byte) error { if err != nil { return err } - err = cf.Get().SetConfig(aux) + cs, err := cf.Get() + if err != nil { + return err + } + err = cs.SetConfig(aux) if err != nil { return err } logger := lg.Get() - conf := cf.Get() - err = logger.LoadLogger(conf.LogPath, conf.LogLevel) + err = logger.LoadLogger(cs.LogPath, cs.LogLevel) if err != nil { return err diff --git a/wacecore.go b/wacecore.go index 3d5388b..c41a0f8 100644 --- a/wacecore.go +++ b/wacecore.go @@ -5,7 +5,6 @@ package wace import ( "fmt" - "os" "strings" "sync" "sync/atomic" @@ -62,7 +61,7 @@ func addTransactionAnalysis(transactionID string) { // callPlugins calls the model plugins in the given list, with the given input. // It waits for all the synchronous model plugins to finish, and sends the // result to the client. The asynchronous model plugins are executed in parallel -func callPlugins(input string, models []string, t cf.ModelPluginType, transactionId string) { +func callPlugins(input string, models []string, t cf.ModelPluginType, transactionId string) error { logger := lg.Get() // channel to receive the status of the execution of the analysis @@ -73,7 +72,10 @@ func callPlugins(input string, models []string, t cf.ModelPluginType, transactio plugins.AddModelChannel(transactionId, t, asyncModelPlugStatus, "async") plugins.AddModelChannel(transactionId, t, modelPlugStatus, "sync") - conf := cf.Get() + conf, err := cf.Get() + if err != nil { + return err + } syncCounter := 0 asyncCounter := 0 @@ -154,10 +156,11 @@ func callPlugins(input string, models []string, t cf.ModelPluginType, transactio value, ok := analysisMap.Load(transactionId) if !ok { logger.TPrintf(lg.ERROR, transactionId, "core | could not find transaction %s in analysis map", transactionId) - return + return fmt.Errorf("core | could not find transaction %s in analysis map", transactionId) } analysisChan := value.(*transactionSync).Channel analysisChan <- "done" + return nil } // InitTransaction initializes a transaction with the given id @@ -191,40 +194,40 @@ func Analyze(modelsTypeAsString, transactionId, payload string, models []string) // CheckTransaction checks the result of the analysis of the transaction // with the given id and decision plugin -func CheckTransaction(transactionID, decisionPlugin string, wafParams map[string]string) (bool, error) { +func CheckTransaction(transactionId, decisionPlugin string, wafParams map[string]string) (bool, error) { logger := lg.Get() - logger.TPrintf(lg.DEBUG, transactionID, "core | checking transaction") + logger.TPrintf(lg.DEBUG, transactionId, "core | checking transaction") - value, exists := analysisMap.Load(transactionID) + value, exists := analysisMap.Load(transactionId) if !exists { - return false, fmt.Errorf("transaction with id %s does not exist", transactionID) + return false, fmt.Errorf("transaction with id %s does not exist", transactionId) } sync := value.(*transactionSync) - logger.TPrintln(lg.DEBUG, transactionID, "core | waiting for all models to finish...") + logger.TPrintln(lg.DEBUG, transactionId, "core | waiting for all models to finish...") for i := 0; i < int(sync.Counter); i++ { <-sync.Channel } sync.Counter = 0 - logger.TPrintln(lg.DEBUG, transactionID, "core | done, checking data...") - res, err := plugins.CheckResult(transactionID, decisionPlugin, wafParams) + logger.TPrintln(lg.DEBUG, transactionId, "core | done, checking data...") + res, err := plugins.CheckResult(transactionId, decisionPlugin, wafParams) if err == nil { - logger.TPrintf(lg.DEBUG, transactionID, "core | transaction checked successfully. Blocking transaction: %t", res) + logger.TPrintf(lg.DEBUG, transactionId, "core | transaction checked successfully. Blocking transaction: %t", res) if res { metric, err := meter.Int64Counter("wace.client.request.blocked.total", metric.WithDescription(decisionPlugin)) if err != nil { - logger.TPrintf(lg.WARN, transactionID, "core | failed to record blocked request metric: %v", err.Error()) + logger.TPrintf(lg.WARN, transactionId, "core | failed to record blocked request metric: %v", err.Error()) } metric.Add(ctx, 1) } } else { - logger.TPrintf(lg.ERROR, transactionID, "core | could not check transaction: %v", err) + logger.TPrintf(lg.ERROR, transactionId, "core | could not check transaction: %v", err) } return res, err } @@ -235,31 +238,36 @@ func CloseTransaction(transactionID string) { plugins.CloseTransaction(transactionID) value, ok := analysisMap.Load(transactionID) logger := lg.Get() - + if !ok { logger.TPrintf(lg.ERROR, transactionID, "Analysis for transaction %s not found", transactionID) } else { close(value.(*transactionSync).Channel) - for range value.(*transactionSync).Channel {} + for range value.(*transactionSync).Channel { + } analysisMap.Delete(transactionID) } } // Init initializes the WACE core with the given metric meter -func Init(met metric.Meter) { +func Init(met metric.Meter) error { logger := lg.Get() - conf := cf.Get() + conf, err := cf.Get() meter = met - err := logger.LoadLogger(conf.LogPath, conf.LogLevel) + err = logger.LoadLogger(conf.LogPath, conf.LogLevel) if err != nil { logger.Printf(lg.ERROR, "ERROR: could not open wace log file: %v", err) - os.Exit(1) - + return err } logger.Printf(lg.DEBUG, "Writing logs to %s from now", conf.LogPath) logger.Println(lg.DEBUG, "Loading plugin manager...") - plugins = pm.New(met) + plugins, err = pm.New(met) + if err != nil { + return err + } logger.Println(lg.DEBUG, "Plugin manager loaded") + + return nil } diff --git a/wacecore_test.go b/wacecore_test.go index 1c2a34a..0309a03 100644 --- a/wacecore_test.go +++ b/wacecore_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - cf "github.com/tilsor/ModSecIntl_wace_lib/configstore" + "github.com/tilsor/ModSecIntl_wace_lib/configstore" "go.opentelemetry.io/otel/sdk/metric" "gopkg.in/yaml.v3" @@ -237,16 +237,23 @@ var provider = metric.NewMeterProvider() var testMeter = provider.Meter("example-meter") func initilize(configuration []byte) error { - var aux cf.ConfigFileData + var aux configstore.ConfigFileData err := yaml.Unmarshal(configuration, &aux) if err != nil { return err } - err = cf.Get().SetConfig(aux) + cs, err := configstore.Get() + if err != nil { + return err + } + err = cs.SetConfig(aux) + if err != nil { + return err + } + err = Init(testMeter) if err != nil { return err } - Init(testMeter) return nil } @@ -261,7 +268,13 @@ func generateRandomID() string { } func TestAnalyzeRequestInParts(t *testing.T) { - err := initilize(configAllModels) + _, err := configstore.New() + if err != nil { + t.Error(err) + } + + defer configstore.Clean() + err = initilize(configAllModels) if err != nil { t.Errorf("Error initing test: %v", err) } @@ -288,7 +301,13 @@ func TestAnalyzeRequestInParts(t *testing.T) { } func TestAnalyzeWholeRequest(t *testing.T) { - err := initilize(configAllModels) + _, err := configstore.New() + if err != nil { + t.Error(err) + } + + defer configstore.Clean() + err = initilize(configAllModels) if err != nil { t.Errorf("Error initing test: %v", err) } @@ -311,7 +330,13 @@ func TestAnalyzeWholeRequest(t *testing.T) { } func TestAnalyzeResponseInParts(t *testing.T) { - err := initilize(configAllModels) + _, err := configstore.New() + if err != nil { + t.Error(err) + } + + defer configstore.Clean() + err = initilize(configAllModels) if err != nil { t.Errorf("Error initing test: %v", err) } @@ -338,7 +363,13 @@ func TestAnalyzeResponseInParts(t *testing.T) { } func TestAnalyzeWholeResponse(t *testing.T) { - err := initilize(configAllModels) + _, err := configstore.New() + if err != nil { + t.Error(err) + } + + defer configstore.Clean() + err = initilize(configAllModels) if err != nil { t.Errorf("Error initing test: %v", err) } @@ -361,13 +392,16 @@ func TestAnalyzeWholeResponse(t *testing.T) { } func TestAnalyzeRequestInPartsAsync(t *testing.T) { - var aux cf.ConfigFileData - err := yaml.Unmarshal(configAsync, &aux) + _, err := configstore.New() + if err != nil { + t.Error(err) + } + + defer configstore.Clean() + err = initilize(configAsync) if err != nil { t.Errorf("Error initing test: %v", err) } - err = cf.Get().SetConfig(aux) - Init(testMeter) transactionID := generateRandomID() InitTransaction(transactionID) @@ -395,13 +429,16 @@ func TestCheckInvalidTransaction(t *testing.T) { } func TestCheckAttackTransaction(t *testing.T) { - var aux cf.ConfigFileData - err := yaml.Unmarshal(configSyncNoRemote, &aux) + _, err := configstore.New() + if err != nil { + t.Error(err) + } + + defer configstore.Clean() + err = initilize(configSyncNoRemote) if err != nil { t.Errorf("Error initing test: %v", err) } - err = cf.Get().SetConfig(aux) - Init(testMeter) transactionID := generateRandomID() InitTransaction(transactionID) @@ -476,13 +513,16 @@ func TestCheckAttackTransaction(t *testing.T) { func BenchmarkTrivial(b *testing.B) { - var aux cf.ConfigFileData - err := yaml.Unmarshal(configSyncNoRemote, &aux) + _, err := configstore.New() + if err != nil { + b.Error(err) + } + + defer configstore.Clean() + err = initilize(configSyncNoRemote) if err != nil { b.Errorf("Error initing test: %v", err) } - err = cf.Get().SetConfig(aux) - Init(testMeter) wafParams := make(map[string]string) auxString := "COMBINED_SCORE=0,HTTP=0,LFI=0,PHPI=0,RCE=0,RFI=0,SESS=0,SQLI=0,XSS=0,inbound_blocking=0,inbound_detection=0,inbound_per_pl=0-0-0-0,inbound_threshold=5,outbound_blocking=0,outbound_detection=0,outbound_per_pl=0-0-0-0,outbound_threshold=4,phase=2" for _, score := range strings.Split(auxString, ",") { @@ -504,13 +544,16 @@ func BenchmarkTrivial(b *testing.B) { } func BenchmarkTrivialFullNATS(b *testing.B) { - var aux cf.ConfigFileData - err := yaml.Unmarshal(configSyncRemote, &aux) + _, err := configstore.New() + if err != nil { + b.Error(err) + } + + defer configstore.Clean() + err = initilize(configSyncRemote) if err != nil { b.Errorf("Error initing test: %v", err) } - err = cf.Get().SetConfig(aux) - Init(testMeter) time.Sleep(2 * time.Millisecond) wafParams := make(map[string]string) auxString := "COMBINED_SCORE=0,HTTP=0,LFI=0,PHPI=0,RCE=0,RFI=0,SESS=0,SQLI=0,XSS=0,inbound_blocking=0,inbound_detection=0,inbound_per_pl=0-0-0-0,inbound_threshold=5,outbound_blocking=0,outbound_detection=0,outbound_per_pl=0-0-0-0,outbound_threshold=4,phase=2" From 12b00d36d2cefe19c7424665fc2174164b6e7d70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20de=20Le=C3=B3n?= Date: Tue, 28 Apr 2026 17:19:12 -0300 Subject: [PATCH 2/6] chore: update pkg names --- configstore/configstore.go | 9 ++- pluginmanager/pluginmanager.go | 98 ++++++++++++++--------------- pluginmanager/pluginmanager_test.go | 14 ++--- wacecore.go | 90 +++++++++++++------------- 4 files changed, 105 insertions(+), 106 deletions(-) diff --git a/configstore/configstore.go b/configstore/configstore.go index c0dfc21..26001f6 100644 --- a/configstore/configstore.go +++ b/configstore/configstore.go @@ -7,10 +7,9 @@ package configstore import ( "fmt" - "io/ioutil" "os" - lg "github.com/tilsor/ModSecIntl_logging/logging" + "github.com/tilsor/ModSecIntl_logging/logging" ) // ModelPluginType is an enum listing the parts of a request or @@ -94,7 +93,7 @@ type ConfigStore struct { ModelPlugins map[string]modelPluginConfig DecisionPlugins map[string]decisionPluginConfig LogPath string - LogLevel lg.LogLevel + LogLevel logging.LogLevel NatsURL string ApplicationId string } @@ -165,7 +164,7 @@ func checkLogging(inConf ConfigFileData) error { if err != nil { // check if log file does not exists already // Attempt to create dummy file var d []byte - err = ioutil.WriteFile(inConf.Logpath, d, 0644) + err = os.WriteFile(inConf.Logpath, d, 0644) if err == nil { err = os.Remove(inConf.Logpath) // delete it } @@ -219,7 +218,7 @@ func (cs *ConfigStore) SetConfig(inConf ConfigFileData) error { } cs.LogPath = inConf.Logpath - cs.LogLevel, err = lg.StringToLogLevel(inConf.Loglevel) + cs.LogLevel, err = logging.StringToLogLevel(inConf.Loglevel) if err != nil { return err } diff --git a/pluginmanager/pluginmanager.go b/pluginmanager/pluginmanager.go index 2b3e46c..a6a319b 100644 --- a/pluginmanager/pluginmanager.go +++ b/pluginmanager/pluginmanager.go @@ -10,11 +10,11 @@ import ( "plugin" "sync" - cf "github.com/tilsor/ModSecIntl_wace_lib/configstore" + "github.com/tilsor/ModSecIntl_wace_lib/configstore" "go.opentelemetry.io/otel/metric" "github.com/nats-io/nats.go" - lg "github.com/tilsor/ModSecIntl_logging/logging" + "github.com/tilsor/ModSecIntl_logging/logging" ) // ResultData maps the model plugin ID with the corresponding analysis result. @@ -47,7 +47,7 @@ type ModelTransmitionResults struct { // modelPlugin is the struct that stores the model plugin and its type type modelPlugin struct { p *plugin.Plugin - pluginType cf.ModelPluginType + pluginType configstore.ModelPluginType } // decisionPlugin is the struct that stores the decision plugin @@ -80,17 +80,17 @@ type PluginManager struct { // New creates a new PluginManager instance. func New(meter metric.Meter) (*PluginManager, error) { pm := new(PluginManager) - conf, err := cf.Get() + conf, err := configstore.Get() if err != nil { return nil, err } - logger := lg.Get() - logger.Printf(lg.DEBUG, "Connecting to NATS server at %s", conf.NatsURL) + logger := logging.Get() + logger.Printf(logging.DEBUG, "Connecting to NATS server at %s", conf.NatsURL) nc, err := nats.Connect(conf.NatsURL) if err != nil { - logger.Printf(lg.ERROR, "Failed to connect to NATS server") + logger.Printf(logging.ERROR, "Failed to connect to NATS server") } pm.natConn = nc @@ -101,55 +101,55 @@ func New(meter metric.Meter) (*PluginManager, error) { for _, data := range conf.ModelPlugins { tp, err := plugin.Open(data.Path) if err != nil { - logger.Printf(lg.WARN, "| %s | cannot load plugin: %v", data.ID, err) + logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) continue } if data.Mode == "async" || conf.ModelPlugins[data.ID].Remote { f, err := tp.Lookup("InitPluginAsync") if err != nil { - logger.Printf(lg.WARN, "| %s | cannot load plugin: %v", data.ID, err) + logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) continue } initPlugin, ok := f.(func(map[string]string, metric.Meter, func(func(ModelInput) (ModelResults, error))) error) if !ok { - logger.Printf(lg.WARN, "| %s | cannot load plugin: invalid InitPluginAsync function type", data.ID) + logger.Printf(logging.WARN, "| %s | cannot load plugin: invalid InitPluginAsync function type", data.ID) continue } err = initPlugin(data.Params, meter, func(modelProcess func(ModelInput) (ModelResults, error)) { ModelProcessHandler(data.ID, modelProcess) }) if err != nil { - logger.Printf(lg.WARN, "| %s | cannot load plugin: %v", data.ID, err) + logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) continue } go pm.ModelResultsHandler(data.ID) } else { f, err := tp.Lookup("InitPlugin") if err != nil { - logger.Printf(lg.WARN, "| %s | cannot load plugin: %v", data.ID, err) + logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) continue } initPlugin, ok := f.(func(map[string]string, metric.Meter) error) if !ok { - logger.Printf(lg.WARN, "| %s | cannot load plugin: invalid InitPlugin function type", data.ID) + logger.Printf(logging.WARN, "| %s | cannot load plugin: invalid InitPlugin function type", data.ID) continue } err = initPlugin(data.Params, meter) procFunc, err := tp.Lookup("Process") if err != nil { - logger.Printf(lg.WARN, "| %s | cannot load plugin: cannot load Process function", data.ID) + logger.Printf(logging.WARN, "| %s | cannot load plugin: cannot load Process function", data.ID) continue } process, ok := procFunc.(func(ModelInput) (ModelResults, error)) if !ok { - logger.Printf(lg.WARN, "| %s | cannot load plugin: invalid Process function type", data.ID) + logger.Printf(logging.WARN, "| %s | cannot load plugin: invalid Process function type", data.ID) continue } pm.modelProcessFunc[data.ID] = process } modelPluginLoaded := modelPlugin{tp, data.PluginType} pm.modelPlugins[data.ID] = modelPluginLoaded - logger.Printf(lg.INFO, "| %s | plugin loaded", data.ID) + logger.Printf(logging.INFO, "| %s | plugin loaded", data.ID) } pm.decisionPlugins = make(map[string]decisionPlugin) @@ -158,32 +158,32 @@ func New(meter metric.Meter) (*PluginManager, error) { for _, data := range conf.DecisionPlugins { tp, err := plugin.Open(data.Path) if err != nil { - logger.Printf(lg.WARN, "| %s | cannot load plugin: %v", data.ID, err) + logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) continue } f, err := tp.Lookup("InitPlugin") if err != nil { - logger.Printf(lg.WARN, "| %s | cannot load plugin: %v", data.ID, err) + logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) continue } initPlugin, ok := f.(func(map[string]string, metric.Meter) error) if !ok { - logger.Printf(lg.WARN, "| %s | cannot load plugin: invalid InitPlugin function type", data.ID) + logger.Printf(logging.WARN, "| %s | cannot load plugin: invalid InitPlugin function type", data.ID) continue } err = initPlugin(data.Params, meter) if err != nil { - logger.Printf(lg.WARN, "| %s | cannot load plugin: %v", data.ID, err) + logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) continue } cR, err := tp.Lookup("CheckResults") if err != nil { - logger.Printf(lg.ERROR, "| %s | cannot load plugin check results function: %v", data.ID, err) + logger.Printf(logging.ERROR, "| %s | cannot load plugin check results function: %v", data.ID, err) continue } checkResults, ok := cR.(func(DecisionInput) (bool, error)) if !ok { - logger.Printf(lg.ERROR, "| %s | CheckResults lookup failed for plugin: invalid function type", data.ID) + logger.Printf(logging.ERROR, "| %s | CheckResults lookup failed for plugin: invalid function type", data.ID) continue } pm.decisionCheckFunc[data.ID] = checkResults @@ -201,10 +201,10 @@ func (p *PluginManager) InitTransaction(transactionId string) { // CloseTransaction closes the transaction with the given ID // removing all sync model data func (p *PluginManager) CloseTransaction(transactionId string) { - logger := lg.Get() + logger := logging.Get() transactionMap, ok := p.syncModelsChannels.Load(transactionId) if !ok { - logger.TPrintf(lg.ERROR, transactionId, "Transaction %s not found", transactionId) + logger.TPrintf(logging.ERROR, transactionId, "Transaction %s not found", transactionId) } else { transactionMap.(*sync.Map).Range(func(key, value interface{}) bool { ch := value.(chan ModelStatus) @@ -217,7 +217,7 @@ func (p *PluginManager) CloseTransaction(transactionId string) { p.syncModelsChannels.Delete(transactionId) resultsMap, ok := p.results.Load(transactionId) if !ok { - logger.TPrintf(lg.ERROR, transactionId, "Results for transaction %s not found", transactionId) + logger.TPrintf(logging.ERROR, transactionId, "Results for transaction %s not found", transactionId) } else { resultsMap.(*sync.Map).Range(func(key, value interface{}) bool { resultsMap.(*sync.Map).Delete(key) @@ -229,7 +229,7 @@ func (p *PluginManager) CloseTransaction(transactionId string) { } // AddModelChannel adds a channel to result channel map -func (p *PluginManager) AddModelChannel(transactionId string, t cf.ModelPluginType, modelPlugStatus chan ModelStatus, modelType string) { +func (p *PluginManager) AddModelChannel(transactionId string, t configstore.ModelPluginType, modelPlugStatus chan ModelStatus, modelType string) { typeModel := new(sync.Map) var value interface{} if modelType == "sync" { @@ -241,7 +241,7 @@ func (p *PluginManager) AddModelChannel(transactionId string, t cf.ModelPluginTy } // RemoveModelChannel removes a channel from the result channel map -func (p *PluginManager) RemoveAsyncModelChannel(transactionId string, t cf.ModelPluginType) { +func (p *PluginManager) RemoveAsyncModelChannel(transactionId string, t configstore.ModelPluginType) { typeModel, ok := p.asyncModelsChannels.Load(transactionId) if ok { channelMap := typeModel.(*sync.Map) @@ -263,8 +263,8 @@ func (p *PluginManager) RemoveAsyncModelChannel(transactionId string, t cf.Model p.asyncModelsChannels.Delete(transactionId) } } else { - logger := lg.Get() - logger.TPrintf(lg.ERROR, transactionId, "Transaction %s not found when trying to remove async model channel", transactionId) + logger := logging.Get() + logger.TPrintf(logging.ERROR, transactionId, "Transaction %s not found when trying to remove async model channel", transactionId) } } @@ -285,8 +285,8 @@ func (p *PluginManager) AddToQueue(modelId, transactionId, payload string) error } // Process is in charge of calling the model plugin with id modelID -func (p *PluginManager) Process(modelID, transactionId, payload string, t cf.ModelPluginType, modelPlugStatus chan ModelStatus) error { - conf, err := cf.Get() +func (p *PluginManager) Process(modelID, transactionId, payload string, t configstore.ModelPluginType, modelPlugStatus chan ModelStatus) error { + conf, err := configstore.Get() if err != nil { return err } @@ -332,7 +332,7 @@ func (p *PluginManager) Process(modelID, transactionId, payload string, t cf.Mod // CheckResult is in charge of calling the decision plugin with id decisionID over the // transaction with id transactID func (p *PluginManager) CheckResult(transactionId, decisionId string, wafParams map[string]string) (bool, error) { - logger := lg.Get() + logger := logging.Get() checkResults, ok := p.decisionCheckFunc[decisionId] if !ok { @@ -344,7 +344,7 @@ func (p *PluginManager) CheckResult(transactionId, decisionId string, wafParams return false, fmt.Errorf("transaction results not found") } - cs, err := cf.Get() + cs, err := configstore.Get() if err != nil { return false, nil } @@ -358,15 +358,15 @@ func (p *PluginManager) CheckResult(transactionId, decisionId string, wafParams }) res, err := checkResults(DecisionInput{TransactionId: transactionId, Results: modelResultMap, ModelWeight: modelWeightMap, WAFdata: wafParams}) - logger.TPrintf(lg.INFO, transactionId, "%s | transaction checked. Block: %t ", decisionId, res) + logger.TPrintf(logging.INFO, transactionId, "%s | transaction checked. Block: %t ", decisionId, res) return res, err } // ModelResultsHandler listens for messages on the model results queue func (p *PluginManager) ModelResultsHandler(modelId string) error { - logger := lg.Get() - cs, err := cf.Get() + logger := logging.Get() + cs, err := configstore.Get() if err != nil { return err } @@ -376,7 +376,7 @@ func (p *PluginManager) ModelResultsHandler(modelId string) error { data := &ModelTransmitionResults{} err := json.Unmarshal(msg.Data, data) if err != nil { - logger.Printf(lg.ERROR, "Model: %s | Failed to parse JSON payload", modelId) + logger.Printf(logging.ERROR, "Model: %s | Failed to parse JSON payload", modelId) } else { var channel interface{} var ok bool @@ -386,11 +386,11 @@ func (p *PluginManager) ModelResultsHandler(modelId string) error { channel, ok = p.syncModelsChannels.Load(data.TransactionId) } if !ok { - logger.TPrintf(lg.ERROR, data.TransactionId, " Model %s | Transaction not found", modelId) + logger.TPrintf(logging.ERROR, data.TransactionId, " Model %s | Transaction not found", modelId) } else { modelChannel, ok := channel.(*sync.Map).Load(cs.ModelPlugins[modelId].PluginType.String()) if !ok { - logger.Printf(lg.ERROR, "Model %s not found", modelId) + logger.Printf(logging.ERROR, "Model %s not found", modelId) } else { if data.Error != nil { modelChannel.(chan ModelStatus) <- ModelStatus{ModelID: modelId, Err: data.Error} @@ -414,11 +414,11 @@ func (p *PluginManager) ModelResultsHandler(modelId string) error { }) if err != nil { - logger.Printf(lg.ERROR, "Model: %s | Failed to subscribe to model queue | %s", modelId, err.Error()) + logger.Printf(logging.ERROR, "Model: %s | Failed to subscribe to model queue | %s", modelId, err.Error()) return err } - logger.Printf(lg.INFO, "Model: %s | Listening for messages on model results queue", modelId) + logger.Printf(logging.INFO, "Model: %s | Listening for messages on model results queue", modelId) defer sub.Unsubscribe() defer p.natConn.Drain() @@ -429,9 +429,9 @@ func (p *PluginManager) ModelResultsHandler(modelId string) error { // ModelProcessHandler listens for messages on the model queue func ModelProcessHandler(modelId string, modelProcess func(ModelInput) (ModelResults, error)) error { - logger := lg.Get() - logger.Printf(lg.INFO, "Model: %s | Starting model process handler", modelId) - cs, err := cf.Get() + logger := logging.Get() + logger.Printf(logging.INFO, "Model: %s | Starting model process handler", modelId) + cs, err := configstore.Get() if err != nil { return err } @@ -439,7 +439,7 @@ func ModelProcessHandler(modelId string, modelProcess func(ModelInput) (ModelRes nc, err := nats.Connect(cs.NatsURL) if err != nil { - logger.Printf(lg.ERROR, "Model: %s | Failed to connect to NATS server", modelId) + logger.Printf(logging.ERROR, "Model: %s | Failed to connect to NATS server", modelId) return err } @@ -448,7 +448,7 @@ func ModelProcessHandler(modelId string, modelProcess func(ModelInput) (ModelRes data := &ModelInput{} err := json.Unmarshal(msg.Data, data) if err != nil { - logger.Printf(lg.ERROR, "Model: %s | Failed to parse JSON payload", modelId) + logger.Printf(logging.ERROR, "Model: %s | Failed to parse JSON payload", modelId) } else { res, err := modelProcess(*data) modelResult := ModelResults{ProbAttack: res.ProbAttack, Data: res.Data} @@ -461,7 +461,7 @@ func ModelProcessHandler(modelId string, modelProcess func(ModelInput) (ModelRes jsonPayload, err := json.Marshal(payloadToSend) if err != nil { - logger.Printf(lg.ERROR, "Model: %s | Failed to parse JSON payload", modelId) + logger.Printf(logging.ERROR, "Model: %s | Failed to parse JSON payload", modelId) } nc.Publish(modelId+"/results", jsonPayload) @@ -470,10 +470,10 @@ func ModelProcessHandler(modelId string, modelProcess func(ModelInput) (ModelRes }) if err != nil { - logger.Printf(lg.ERROR, "Model: %s | Failed to subscribe to model queue | %s", modelId, err.Error()) + logger.Printf(logging.ERROR, "Model: %s | Failed to subscribe to model queue | %s", modelId, err.Error()) return err } - logger.Printf(lg.INFO, "Model: %s | Listening for messages on model queue", modelId) + logger.Printf(logging.INFO, "Model: %s | Listening for messages on model queue", modelId) return nil } diff --git a/pluginmanager/pluginmanager_test.go b/pluginmanager/pluginmanager_test.go index c6e6f48..1d52d43 100644 --- a/pluginmanager/pluginmanager_test.go +++ b/pluginmanager/pluginmanager_test.go @@ -4,11 +4,11 @@ import ( "math/rand" "time" - cf "github.com/tilsor/ModSecIntl_wace_lib/configstore" + "github.com/tilsor/ModSecIntl_wace_lib/configstore" "go.opentelemetry.io/otel/sdk/metric" "gopkg.in/yaml.v3" - lg "github.com/tilsor/ModSecIntl_logging/logging" + "github.com/tilsor/ModSecIntl_logging/logging" ) var baseConfig = `--- @@ -51,12 +51,12 @@ var provider = metric.NewMeterProvider() var testMeter = provider.Meter("example-meter") func initilize(configuration []byte) error { - var aux cf.ConfigFileData + var aux configstore.ConfigFileData err := yaml.Unmarshal(configuration, &aux) if err != nil { return err } - cs, err := cf.Get() + cs, err := configstore.Get() if err != nil { return err } @@ -64,7 +64,7 @@ func initilize(configuration []byte) error { if err != nil { return err } - logger := lg.Get() + logger := logging.Get() err = logger.LoadLogger(cs.LogPath, cs.LogLevel) if err != nil { @@ -77,8 +77,8 @@ func initilize(configuration []byte) error { func init() { rand.Seed(time.Now().UnixNano()) - logger := lg.Get() - err := logger.LoadLogger("/dev/null", lg.ERROR) + logger := logging.Get() + err := logger.LoadLogger("/dev/null", logging.ERROR) if err != nil { panic("Error loading logger") } diff --git a/wacecore.go b/wacecore.go index c41a0f8..995dd99 100644 --- a/wacecore.go +++ b/wacecore.go @@ -10,11 +10,11 @@ import ( "sync/atomic" "time" - cf "github.com/tilsor/ModSecIntl_wace_lib/configstore" + "github.com/tilsor/ModSecIntl_wace_lib/configstore" - pm "github.com/tilsor/ModSecIntl_wace_lib/pluginmanager" + "github.com/tilsor/ModSecIntl_wace_lib/pluginmanager" - lg "github.com/tilsor/ModSecIntl_logging/logging" + "github.com/tilsor/ModSecIntl_logging/logging" "context" @@ -22,7 +22,7 @@ import ( "go.opentelemetry.io/otel/metric" ) -var plugins *pm.PluginManager +var plugins *pluginmanager.PluginManager var ctx = context.Background() var meter metric.Meter @@ -61,18 +61,18 @@ func addTransactionAnalysis(transactionID string) { // callPlugins calls the model plugins in the given list, with the given input. // It waits for all the synchronous model plugins to finish, and sends the // result to the client. The asynchronous model plugins are executed in parallel -func callPlugins(input string, models []string, t cf.ModelPluginType, transactionId string) error { - logger := lg.Get() +func callPlugins(input string, models []string, t configstore.ModelPluginType, transactionId string) error { + logger := logging.Get() // channel to receive the status of the execution of the analysis // of all the model plugins executed - modelPlugStatus := make(chan pm.ModelStatus) - asyncModelPlugStatus := make(chan pm.ModelStatus) + modelPlugStatus := make(chan pluginmanager.ModelStatus) + asyncModelPlugStatus := make(chan pluginmanager.ModelStatus) plugins.AddModelChannel(transactionId, t, asyncModelPlugStatus, "async") plugins.AddModelChannel(transactionId, t, modelPlugStatus, "sync") - conf, err := cf.Get() + conf, err := configstore.Get() if err != nil { return err } @@ -83,12 +83,12 @@ func callPlugins(input string, models []string, t cf.ModelPluginType, transactio startTime := time.Now() for _, id := range models { - logger.TPrintf(lg.DEBUG, transactionId, "%s | calling from core", id) + logger.TPrintf(logging.DEBUG, transactionId, "%s | calling from core", id) if _, ok := conf.ModelPlugins[id]; !ok { - logger.TPrintf(lg.ERROR, transactionId, "core | model plugin %s not found", id) + logger.TPrintf(logging.ERROR, transactionId, "core | model plugin %s not found", id) } else { if conf.ModelPlugins[id].PluginType != t { - logger.TPrintf(lg.ERROR, transactionId, "core | model plugin %s is not of type %s", id, t) + logger.TPrintf(logging.ERROR, transactionId, "core | model plugin %s is not of type %s", id, t) } else { if conf.IsAsync(id) { asyncCounter++ @@ -106,25 +106,25 @@ func callPlugins(input string, models []string, t cf.ModelPluginType, transactio } go func() { - logger.TPrintf(lg.DEBUG, transactionId, "core | waiting for %d async model plugins to finish", asyncCounter) + logger.TPrintf(logging.DEBUG, transactionId, "core | waiting for %d async model plugins to finish", asyncCounter) wg := sync.WaitGroup{} wg.Add(asyncCounter) for i := 0; i < asyncCounter; i++ { // Await for the execution of the async model plugins - logger.TPrintf(lg.DEBUG, transactionId, "core | Waiting for async model plugin %d...", i+1) + logger.TPrintf(logging.DEBUG, transactionId, "core | Waiting for async model plugin %d...", i+1) status := <-asyncModelPlugStatus if status.Err == nil { - logger.TPrintf(lg.DEBUG, transactionId, "%s async | success. Result: %.5f", status.ModelID, status.ProbAttack) + logger.TPrintf(logging.DEBUG, transactionId, "%s async | success. Result: %.5f", status.ModelID, status.ProbAttack) histogramMeter, err := meter.Int64Histogram("wace.model.duration.nanoseconds") if err != nil { - logger.TPrintf(lg.WARN, transactionId, "core | failed to record duration metric: %v", err.Error()) + logger.TPrintf(logging.WARN, transactionId, "core | failed to record duration metric: %v", err.Error()) } histogramMeter.Record(ctx, time.Since(startTime).Nanoseconds(), metric.WithAttributes( attribute.String("model_id", status.ModelID), attribute.String("model_mode", "async"), attribute.Float64("attack_probability", status.ProbAttack))) } else { - logger.TPrintf(lg.WARN, transactionId, "%s | %v", status.ModelID, status.Err) + logger.TPrintf(logging.WARN, transactionId, "%s | %v", status.ModelID, status.Err) } wg.Done() } @@ -132,30 +132,30 @@ func callPlugins(input string, models []string, t cf.ModelPluginType, transactio plugins.RemoveAsyncModelChannel(transactionId, t) }() - logger.TPrintf(lg.DEBUG, transactionId, "core | waiting for %d sync model plugins to finish", syncCounter) + logger.TPrintf(logging.DEBUG, transactionId, "core | waiting for %d sync model plugins to finish", syncCounter) for i := 0; i < syncCounter; i++ { // Await for the execution of the model plugins - logger.TPrintf(lg.DEBUG, transactionId, "core | Waiting for sync model plugin %d...", i+1) + logger.TPrintf(logging.DEBUG, transactionId, "core | Waiting for sync model plugin %d...", i+1) status := <-modelPlugStatus if status.Err == nil { - logger.TPrintf(lg.DEBUG, transactionId, "%s sync | success. Result: %.5f", status.ModelID, status.ProbAttack) + logger.TPrintf(logging.DEBUG, transactionId, "%s sync | success. Result: %.5f", status.ModelID, status.ProbAttack) histogramMeter, err := meter.Int64Histogram("wace.model.duration.nanoseconds") if err != nil { - logger.TPrintf(lg.WARN, transactionId, "core | failed to record duration metric: %v", err.Error()) + logger.TPrintf(logging.WARN, transactionId, "core | failed to record duration metric: %v", err.Error()) } histogramMeter.Record(ctx, time.Since(startTime).Nanoseconds(), metric.WithAttributes( attribute.String("model_id", status.ModelID), attribute.String("model_mode", "sync"), attribute.Float64("attack_probability", status.ProbAttack))) } else { - logger.TPrintf(lg.WARN, transactionId, "%s | %v", status.ModelID, status.Err) + logger.TPrintf(logging.WARN, transactionId, "%s | %v", status.ModelID, status.Err) } } value, ok := analysisMap.Load(transactionId) if !ok { - logger.TPrintf(lg.ERROR, transactionId, "core | could not find transaction %s in analysis map", transactionId) + logger.TPrintf(logging.ERROR, transactionId, "core | could not find transaction %s in analysis map", transactionId) return fmt.Errorf("core | could not find transaction %s in analysis map", transactionId) } analysisChan := value.(*transactionSync).Channel @@ -165,9 +165,9 @@ func callPlugins(input string, models []string, t cf.ModelPluginType, transactio // InitTransaction initializes a transaction with the given id func InitTransaction(transactionId string) { - logger := lg.Get() + logger := logging.Get() logger.StartTransaction(transactionId) - logger.TPrintf(lg.DEBUG, transactionId, "core | initializing transaction") + logger.TPrintf(logging.DEBUG, transactionId, "core | initializing transaction") tSync := transactionSync{ Channel: make(chan string), Counter: 0, @@ -179,13 +179,13 @@ func InitTransaction(transactionId string) { // Analyze calls the model plugins with the given payload and models func Analyze(modelsTypeAsString, transactionId, payload string, models []string) error { if len(models) > 0 { - logger := lg.Get() - modelsType, err := cf.StringToPluginType(modelsTypeAsString) + logger := logging.Get() + modelsType, err := configstore.StringToPluginType(modelsTypeAsString) if err != nil { - logger.TPrintf(lg.ERROR, transactionId, "core | %s is not a valid type", modelsTypeAsString) + logger.TPrintf(logging.ERROR, transactionId, "core | %s is not a valid type", modelsTypeAsString) return err } - logger.TPrintf(lg.DEBUG, transactionId, "core | analyzing %s: [%s...]", modelsTypeAsString, strings.Split(payload, "\n")[0]) + logger.TPrintf(logging.DEBUG, transactionId, "core | analyzing %s: [%s...]", modelsTypeAsString, strings.Split(payload, "\n")[0]) addTransactionAnalysis(transactionId) go callPlugins(payload, models, modelsType, transactionId) } @@ -195,8 +195,8 @@ func Analyze(modelsTypeAsString, transactionId, payload string, models []string) // CheckTransaction checks the result of the analysis of the transaction // with the given id and decision plugin func CheckTransaction(transactionId, decisionPlugin string, wafParams map[string]string) (bool, error) { - logger := lg.Get() - logger.TPrintf(lg.DEBUG, transactionId, "core | checking transaction") + logger := logging.Get() + logger.TPrintf(logging.DEBUG, transactionId, "core | checking transaction") value, exists := analysisMap.Load(transactionId) @@ -206,28 +206,28 @@ func CheckTransaction(transactionId, decisionPlugin string, wafParams map[string sync := value.(*transactionSync) - logger.TPrintln(lg.DEBUG, transactionId, "core | waiting for all models to finish...") + logger.TPrintln(logging.DEBUG, transactionId, "core | waiting for all models to finish...") for i := 0; i < int(sync.Counter); i++ { <-sync.Channel } sync.Counter = 0 - logger.TPrintln(lg.DEBUG, transactionId, "core | done, checking data...") + logger.TPrintln(logging.DEBUG, transactionId, "core | done, checking data...") res, err := plugins.CheckResult(transactionId, decisionPlugin, wafParams) if err == nil { - logger.TPrintf(lg.DEBUG, transactionId, "core | transaction checked successfully. Blocking transaction: %t", res) + logger.TPrintf(logging.DEBUG, transactionId, "core | transaction checked successfully. Blocking transaction: %t", res) if res { metric, err := meter.Int64Counter("wace.client.request.blocked.total", metric.WithDescription(decisionPlugin)) if err != nil { - logger.TPrintf(lg.WARN, transactionId, "core | failed to record blocked request metric: %v", err.Error()) + logger.TPrintf(logging.WARN, transactionId, "core | failed to record blocked request metric: %v", err.Error()) } metric.Add(ctx, 1) } } else { - logger.TPrintf(lg.ERROR, transactionId, "core | could not check transaction: %v", err) + logger.TPrintf(logging.ERROR, transactionId, "core | could not check transaction: %v", err) } return res, err } @@ -237,10 +237,10 @@ func CheckTransaction(transactionId, decisionPlugin string, wafParams map[string func CloseTransaction(transactionID string) { plugins.CloseTransaction(transactionID) value, ok := analysisMap.Load(transactionID) - logger := lg.Get() + logger := logging.Get() if !ok { - logger.TPrintf(lg.ERROR, transactionID, "Analysis for transaction %s not found", transactionID) + logger.TPrintf(logging.ERROR, transactionID, "Analysis for transaction %s not found", transactionID) } else { close(value.(*transactionSync).Channel) for range value.(*transactionSync).Channel { @@ -251,23 +251,23 @@ func CloseTransaction(transactionID string) { // Init initializes the WACE core with the given metric meter func Init(met metric.Meter) error { - logger := lg.Get() - conf, err := cf.Get() + logger := logging.Get() + conf, err := configstore.Get() meter = met err = logger.LoadLogger(conf.LogPath, conf.LogLevel) if err != nil { - logger.Printf(lg.ERROR, "ERROR: could not open wace log file: %v", err) + logger.Printf(logging.ERROR, "ERROR: could not open wace log file: %v", err) return err } - logger.Printf(lg.DEBUG, "Writing logs to %s from now", conf.LogPath) + logger.Printf(logging.DEBUG, "Writing logs to %s from now", conf.LogPath) - logger.Println(lg.DEBUG, "Loading plugin manager...") - plugins, err = pm.New(met) + logger.Println(logging.DEBUG, "Loading plugin manager...") + plugins, err = pluginmanager.New(met) if err != nil { return err } - logger.Println(lg.DEBUG, "Plugin manager loaded") + logger.Println(logging.DEBUG, "Plugin manager loaded") return nil } From ff70d97611505ce3a5b5c504ab101cd85e0c791c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20de=20Le=C3=B3n?= Date: Tue, 28 Apr 2026 17:37:04 -0300 Subject: [PATCH 3/6] chore: update dependencies --- go.mod | 27 +++++++++++++------------ go.sum | 62 ++++++++++++++++++++++++++++++---------------------------- 2 files changed, 46 insertions(+), 43 deletions(-) diff --git a/go.mod b/go.mod index fa6770e..aabc361 100644 --- a/go.mod +++ b/go.mod @@ -1,26 +1,27 @@ module github.com/tilsor/ModSecIntl_wace_lib -go 1.22.9 +go 1.26.2 require ( - github.com/nats-io/nats.go v1.38.0 + github.com/nats-io/nats.go v1.51.0 github.com/tilsor/ModSecIntl_logging v1.0.1 - go.opentelemetry.io/otel v1.34.0 - go.opentelemetry.io/otel/metric v1.34.0 - go.opentelemetry.io/otel/sdk/metric v1.34.0 + go.opentelemetry.io/otel v1.43.0 + go.opentelemetry.io/otel/metric v1.43.0 + go.opentelemetry.io/otel/sdk/metric v1.43.0 gopkg.in/yaml.v3 v3.0.1 ) require ( - github.com/go-logr/logr v1.4.2 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/google/uuid v1.6.0 // indirect - github.com/klauspost/compress v1.17.9 // indirect - github.com/nats-io/nkeys v0.4.9 // indirect + github.com/klauspost/compress v1.18.5 // indirect + github.com/nats-io/nkeys v0.4.15 // indirect github.com/nats-io/nuid v1.0.1 // indirect - go.opentelemetry.io/auto/sdk v1.1.0 // indirect - go.opentelemetry.io/otel/sdk v1.34.0 // indirect - go.opentelemetry.io/otel/trace v1.34.0 // indirect - golang.org/x/crypto v0.31.0 // indirect - golang.org/x/sys v0.29.0 // indirect + go.opentelemetry.io/auto/sdk v1.2.1 // indirect + go.opentelemetry.io/otel/sdk v1.43.0 // indirect + go.opentelemetry.io/otel/trace v1.43.0 // indirect + golang.org/x/crypto v0.50.0 // indirect + golang.org/x/sys v0.43.0 // indirect ) diff --git a/go.sum b/go.sum index ecc9057..7b259f6 100644 --- a/go.sum +++ b/go.sum @@ -1,50 +1,52 @@ +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= -github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= -github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= -github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE= +github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/nats-io/nats.go v1.38.0 h1:A7P+g7Wjp4/NWqDOOP/K6hfhr54DvdDQUznt5JFg9XA= -github.com/nats-io/nats.go v1.38.0/go.mod h1:IGUM++TwokGnXPs82/wCuiHS02/aKrdYUQkU8If6yjw= -github.com/nats-io/nkeys v0.4.9 h1:qe9Faq2Gxwi6RZnZMXfmGMZkg3afLLOtrU+gDZJ35b0= -github.com/nats-io/nkeys v0.4.9/go.mod h1:jcMqs+FLG+W5YO36OX6wFIFcmpdAns+w1Wm6D3I/evE= +github.com/nats-io/nats.go v1.51.0 h1:ByW84XTz6W03GSSsygsZcA+xgKK8vPGaa/FCAAEHnAI= +github.com/nats-io/nats.go v1.51.0/go.mod h1:26HypzazeOkyO3/mqd1zZd53STJN0EjCYF9Uy2ZOBno= +github.com/nats-io/nkeys v0.4.15 h1:JACV5jRVO9V856KOapQ7x+EY8Jo3qw1vJt/9Jpwzkk4= +github.com/nats-io/nkeys v0.4.15/go.mod h1:CpMchTXC9fxA5zrMo4KpySxNjiDVvr8ANOSZdiNfUrs= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= -github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tilsor/ModSecIntl_logging v1.0.1 h1:wFd3SxJPUU5JxX2UlrsH0Ef/m9a18d/VRo4xVCtCxVM= github.com/tilsor/ModSecIntl_logging v1.0.1/go.mod h1:9RrpYmS4v/wYIiiYXzDW6Lqr8Xb8wq3ejpHi8jmQsyo= -go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= -go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= -go.opentelemetry.io/otel v1.34.0 h1:zRLXxLCgL1WyKsPVrgbSdMN4c0FMkDAskSTQP+0hdUY= -go.opentelemetry.io/otel v1.34.0/go.mod h1:OWFPOQ+h4G8xpyjgqo4SxJYdDQ/qmRH+wivy7zzx9oI= -go.opentelemetry.io/otel/metric v1.34.0 h1:+eTR3U0MyfWjRDhmFMxe2SsW64QrZ84AOhvqS7Y+PoQ= -go.opentelemetry.io/otel/metric v1.34.0/go.mod h1:CEDrp0fy2D0MvkXE+dPV7cMi8tWZwX3dmaIhwPOaqHE= -go.opentelemetry.io/otel/sdk v1.34.0 h1:95zS4k/2GOy069d321O8jWgYsW3MzVV+KuSPKp7Wr1A= -go.opentelemetry.io/otel/sdk v1.34.0/go.mod h1:0e/pNiaMAqaykJGKbi+tSjWfNNHMTxoC9qANsCzbyxU= -go.opentelemetry.io/otel/sdk/metric v1.34.0 h1:5CeK9ujjbFVL5c1PhLuStg1wxA7vQv7ce1EK0Gyvahk= -go.opentelemetry.io/otel/sdk/metric v1.34.0/go.mod h1:jQ/r8Ze28zRKoNRdkjCZxfs6YvBTG1+YIqyFVFYec5w= -go.opentelemetry.io/otel/trace v1.34.0 h1:+ouXS2V8Rd4hp4580a8q23bg0azF2nI8cqLYnC8mh/k= -go.opentelemetry.io/otel/trace v1.34.0/go.mod h1:Svm7lSjQD7kG7KJ/MUHPVXSDGz2OX4h0M2jHBhmSfRE= -golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= -golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= -golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= -golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I= +go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0= +go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM= +go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY= +go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg= +go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg= +go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw= +go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A= +go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A= +go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0= +golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= +golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= From 44795c6e7419109adc31033f8ffeda781b810198 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20de=20Le=C3=B3n?= Date: Wed, 29 Apr 2026 12:14:34 -0300 Subject: [PATCH 4/6] feat: add HTTPPayload type to function signatures --- pluginmanager/pluginmanager.go | 29 ++++++-- testdata/plugins/model/trivial.go | 2 +- testdata/plugins/model/trivial2.go | 2 +- testdata/plugins/model/trivial_async.go | 2 +- testdata/plugins/model/trivial_async2.go | 2 +- wacecore.go | 59 ++++++++-------- wacecore_test.go | 90 ++++++++++++++++-------- 7 files changed, 118 insertions(+), 68 deletions(-) diff --git a/pluginmanager/pluginmanager.go b/pluginmanager/pluginmanager.go index a6a319b..641e4f9 100644 --- a/pluginmanager/pluginmanager.go +++ b/pluginmanager/pluginmanager.go @@ -23,10 +23,27 @@ type ModelResults struct { Data map[string]interface{} `json:"data"` } +type HTTPHeader struct { + Key string + Value string +} + +type HTTPPayload struct { + URI string + Method string + HTTPVersion string + RequestHeaders []HTTPHeader + RequestBody string + ResponseProtocol string + ResponseCode int + ResponseHeaders []HTTPHeader + ResponseBody string +} + // ModelInput is the struct that contains the input data for the model plugin type ModelInput struct { - TransactionId string `json:"transactionId"` - Payload string `json:"payload"` + TransactionId string `json:"transactionId"` + Payload HTTPPayload `json:"payload"` } // DecisionInput is the struct that contains the input data for the decision plugin @@ -269,9 +286,9 @@ func (p *PluginManager) RemoveAsyncModelChannel(transactionId string, t configst } // AddToQueue adds a payload to the model queue -func (p *PluginManager) AddToQueue(modelId, transactionId, payload string) error { +func (p *PluginManager) AddToQueue(modelID, transactionID string, payload HTTPPayload) error { payloadToSend := &ModelInput{ - TransactionId: transactionId, + TransactionId: transactionID, Payload: payload, } @@ -281,11 +298,11 @@ func (p *PluginManager) AddToQueue(modelId, transactionId, payload string) error return err } - return p.natConn.Publish(modelId, jsonPayload) + return p.natConn.Publish(modelID, jsonPayload) } // Process is in charge of calling the model plugin with id modelID -func (p *PluginManager) Process(modelID, transactionId, payload string, t configstore.ModelPluginType, modelPlugStatus chan ModelStatus) error { +func (p *PluginManager) Process(modelID, transactionId string, payload HTTPPayload, t configstore.ModelPluginType, modelPlugStatus chan ModelStatus) error { conf, err := configstore.Get() if err != nil { return err diff --git a/testdata/plugins/model/trivial.go b/testdata/plugins/model/trivial.go index 228b75c..046ff50 100644 --- a/testdata/plugins/model/trivial.go +++ b/testdata/plugins/model/trivial.go @@ -34,7 +34,7 @@ func InitPluginAsync(params map[string]string, meter metric.Meter, natsManager f func Process(input pm.ModelInput) (pm.ModelResults, error) { logger := lg.Get() - logger.TPrintf(lg.WARN, input.TransactionId, "[trivial:Process] \"%s\"\n", input.Payload) + logger.TPrintf(lg.WARN, input.TransactionId, "[trivial:Process] \"%v\"\n", input.Payload) result := pm.ModelResults{ ProbAttack: 0.0, Data: make(map[string]interface{}), diff --git a/testdata/plugins/model/trivial2.go b/testdata/plugins/model/trivial2.go index f04218c..7b9ed66 100644 --- a/testdata/plugins/model/trivial2.go +++ b/testdata/plugins/model/trivial2.go @@ -34,7 +34,7 @@ func InitPluginAsync(params map[string]string, meter metric.Meter, natsManager f func Process(input pm.ModelInput) (pm.ModelResults, error) { logger := lg.Get() - logger.TPrintf(lg.WARN, input.TransactionId, "[trivial2:Proccess] \"%s\"\n", input.Payload) + logger.TPrintf(lg.WARN, input.TransactionId, "[trivial2:Proccess] \"%v\"\n", input.Payload) result := pm.ModelResults{ ProbAttack: 1.0, Data: make(map[string]interface{}), diff --git a/testdata/plugins/model/trivial_async.go b/testdata/plugins/model/trivial_async.go index 3896fbc..9e38210 100644 --- a/testdata/plugins/model/trivial_async.go +++ b/testdata/plugins/model/trivial_async.go @@ -49,7 +49,7 @@ func InitPluginAsync(params map[string]string, meter metric.Meter, natsManager f func Process(input pm.ModelInput) (pm.ModelResults, error) { time.Sleep(time.Duration(sleepTime) * time.Second) logger := lg.Get() - logger.TPrintf(lg.WARN, input.TransactionId, "[trivial_async:Process] \"%s\"\n", input.Payload) + logger.TPrintf(lg.WARN, input.TransactionId, "[trivial_async:Process] \"%v\"\n", input.Payload) result := pm.ModelResults{ ProbAttack: 0.0, Data: make(map[string]interface{}), diff --git a/testdata/plugins/model/trivial_async2.go b/testdata/plugins/model/trivial_async2.go index 86c0243..79a0810 100644 --- a/testdata/plugins/model/trivial_async2.go +++ b/testdata/plugins/model/trivial_async2.go @@ -49,7 +49,7 @@ func InitPluginAsync(params map[string]string, meter metric.Meter, natsManager f func Process(input pm.ModelInput) (pm.ModelResults, error) { time.Sleep(time.Duration(sleepTime) * time.Second) logger := lg.Get() - logger.TPrintf(lg.WARN, input.TransactionId, "[trivial_async2:Process] \"%s\"\n", input.Payload) + logger.TPrintf(lg.WARN, input.TransactionId, "[trivial_async2:Process] \"%v\"\n", input.Payload) result := pm.ModelResults{ ProbAttack: 1.0, Data: make(map[string]interface{}), diff --git a/wacecore.go b/wacecore.go index 995dd99..d22c3bd 100644 --- a/wacecore.go +++ b/wacecore.go @@ -5,7 +5,6 @@ package wace import ( "fmt" - "strings" "sync" "sync/atomic" "time" @@ -61,16 +60,16 @@ func addTransactionAnalysis(transactionID string) { // callPlugins calls the model plugins in the given list, with the given input. // It waits for all the synchronous model plugins to finish, and sends the // result to the client. The asynchronous model plugins are executed in parallel -func callPlugins(input string, models []string, t configstore.ModelPluginType, transactionId string) error { +func callPlugins(input pluginmanager.HTTPPayload, models []string, t configstore.ModelPluginType, transactionID string) error { logger := logging.Get() // channel to receive the status of the execution of the analysis // of all the model plugins executed - modelPlugStatus := make(chan pluginmanager.ModelStatus) - asyncModelPlugStatus := make(chan pluginmanager.ModelStatus) + modelPluginStatus := make(chan pluginmanager.ModelStatus) + asyncModelPluginStatus := make(chan pluginmanager.ModelStatus) - plugins.AddModelChannel(transactionId, t, asyncModelPlugStatus, "async") - plugins.AddModelChannel(transactionId, t, modelPlugStatus, "sync") + plugins.AddModelChannel(transactionID, t, asyncModelPluginStatus, "async") + plugins.AddModelChannel(transactionID, t, modelPluginStatus, "sync") conf, err := configstore.Get() if err != nil { @@ -83,21 +82,21 @@ func callPlugins(input string, models []string, t configstore.ModelPluginType, t startTime := time.Now() for _, id := range models { - logger.TPrintf(logging.DEBUG, transactionId, "%s | calling from core", id) + logger.TPrintf(logging.DEBUG, transactionID, "%s | calling from core", id) if _, ok := conf.ModelPlugins[id]; !ok { - logger.TPrintf(logging.ERROR, transactionId, "core | model plugin %s not found", id) + logger.TPrintf(logging.ERROR, transactionID, "core | model plugin %s not found", id) } else { if conf.ModelPlugins[id].PluginType != t { - logger.TPrintf(logging.ERROR, transactionId, "core | model plugin %s is not of type %s", id, t) + logger.TPrintf(logging.ERROR, transactionID, "core | model plugin %s is not of type %s", id, t) } else { if conf.IsAsync(id) { asyncCounter++ - go plugins.AddToQueue(id, transactionId, input) + go plugins.AddToQueue(id, transactionID, input) } else { if conf.ModelPlugins[id].Remote { - go plugins.AddToQueue(id, transactionId, input) + go plugins.AddToQueue(id, transactionID, input) } else { - go plugins.Process(id, transactionId, input, t, modelPlugStatus) + go plugins.Process(id, transactionID, input, t, modelPluginStatus) } syncCounter++ } @@ -106,57 +105,57 @@ func callPlugins(input string, models []string, t configstore.ModelPluginType, t } go func() { - logger.TPrintf(logging.DEBUG, transactionId, "core | waiting for %d async model plugins to finish", asyncCounter) + logger.TPrintf(logging.DEBUG, transactionID, "core | waiting for %d async model plugins to finish", asyncCounter) wg := sync.WaitGroup{} wg.Add(asyncCounter) for i := 0; i < asyncCounter; i++ { // Await for the execution of the async model plugins - logger.TPrintf(logging.DEBUG, transactionId, "core | Waiting for async model plugin %d...", i+1) - status := <-asyncModelPlugStatus + logger.TPrintf(logging.DEBUG, transactionID, "core | Waiting for async model plugin %d...", i+1) + status := <-asyncModelPluginStatus if status.Err == nil { - logger.TPrintf(logging.DEBUG, transactionId, "%s async | success. Result: %.5f", status.ModelID, status.ProbAttack) + logger.TPrintf(logging.DEBUG, transactionID, "%s async | success. Result: %.5f", status.ModelID, status.ProbAttack) histogramMeter, err := meter.Int64Histogram("wace.model.duration.nanoseconds") if err != nil { - logger.TPrintf(logging.WARN, transactionId, "core | failed to record duration metric: %v", err.Error()) + logger.TPrintf(logging.WARN, transactionID, "core | failed to record duration metric: %v", err.Error()) } histogramMeter.Record(ctx, time.Since(startTime).Nanoseconds(), metric.WithAttributes( attribute.String("model_id", status.ModelID), attribute.String("model_mode", "async"), attribute.Float64("attack_probability", status.ProbAttack))) } else { - logger.TPrintf(logging.WARN, transactionId, "%s | %v", status.ModelID, status.Err) + logger.TPrintf(logging.WARN, transactionID, "%s | %v", status.ModelID, status.Err) } wg.Done() } wg.Wait() - plugins.RemoveAsyncModelChannel(transactionId, t) + plugins.RemoveAsyncModelChannel(transactionID, t) }() - logger.TPrintf(logging.DEBUG, transactionId, "core | waiting for %d sync model plugins to finish", syncCounter) + logger.TPrintf(logging.DEBUG, transactionID, "core | waiting for %d sync model plugins to finish", syncCounter) for i := 0; i < syncCounter; i++ { // Await for the execution of the model plugins - logger.TPrintf(logging.DEBUG, transactionId, "core | Waiting for sync model plugin %d...", i+1) - status := <-modelPlugStatus + logger.TPrintf(logging.DEBUG, transactionID, "core | Waiting for sync model plugin %d...", i+1) + status := <-modelPluginStatus if status.Err == nil { - logger.TPrintf(logging.DEBUG, transactionId, "%s sync | success. Result: %.5f", status.ModelID, status.ProbAttack) + logger.TPrintf(logging.DEBUG, transactionID, "%s sync | success. Result: %.5f", status.ModelID, status.ProbAttack) histogramMeter, err := meter.Int64Histogram("wace.model.duration.nanoseconds") if err != nil { - logger.TPrintf(logging.WARN, transactionId, "core | failed to record duration metric: %v", err.Error()) + logger.TPrintf(logging.WARN, transactionID, "core | failed to record duration metric: %v", err.Error()) } histogramMeter.Record(ctx, time.Since(startTime).Nanoseconds(), metric.WithAttributes( attribute.String("model_id", status.ModelID), attribute.String("model_mode", "sync"), attribute.Float64("attack_probability", status.ProbAttack))) } else { - logger.TPrintf(logging.WARN, transactionId, "%s | %v", status.ModelID, status.Err) + logger.TPrintf(logging.WARN, transactionID, "%s | %v", status.ModelID, status.Err) } } - value, ok := analysisMap.Load(transactionId) + value, ok := analysisMap.Load(transactionID) if !ok { - logger.TPrintf(logging.ERROR, transactionId, "core | could not find transaction %s in analysis map", transactionId) - return fmt.Errorf("core | could not find transaction %s in analysis map", transactionId) + logger.TPrintf(logging.ERROR, transactionID, "core | could not find transaction %s in analysis map", transactionID) + return fmt.Errorf("core | could not find transaction %s in analysis map", transactionID) } analysisChan := value.(*transactionSync).Channel analysisChan <- "done" @@ -177,7 +176,7 @@ func InitTransaction(transactionId string) { } // Analyze calls the model plugins with the given payload and models -func Analyze(modelsTypeAsString, transactionId, payload string, models []string) error { +func Analyze(modelsTypeAsString, transactionId string, payload pluginmanager.HTTPPayload, models []string) error { if len(models) > 0 { logger := logging.Get() modelsType, err := configstore.StringToPluginType(modelsTypeAsString) @@ -185,7 +184,7 @@ func Analyze(modelsTypeAsString, transactionId, payload string, models []string) logger.TPrintf(logging.ERROR, transactionId, "core | %s is not a valid type", modelsTypeAsString) return err } - logger.TPrintf(logging.DEBUG, transactionId, "core | analyzing %s: [%s...]", modelsTypeAsString, strings.Split(payload, "\n")[0]) + logger.TPrintf(logging.DEBUG, transactionId, "core | analyzing %s: [%v...]", modelsTypeAsString, payload) addTransactionAnalysis(transactionId) go callPlugins(payload, models, modelsType, transactionId) } diff --git a/wacecore_test.go b/wacecore_test.go index 0309a03..367e1d6 100644 --- a/wacecore_test.go +++ b/wacecore_test.go @@ -8,39 +8,73 @@ import ( "time" "github.com/tilsor/ModSecIntl_wace_lib/configstore" + "github.com/tilsor/ModSecIntl_wace_lib/pluginmanager" "go.opentelemetry.io/otel/sdk/metric" "gopkg.in/yaml.v3" ) -var requestLine = "POST /cgi-bin/process.cgi HTTP/1.1\n" -var requestHeaders = `User-Agent: Mozilla/4.0 (compatible; MSIE5.01; Windows NT) -Host: www.tutorialspoint.com -Content-Type: application/x-www-form-urlencoded -Content-Length: length -Accept-Language: en-us -Accept-Encoding: gzip, deflate -Connection: Keep-Alive -` +var requestURI = "/cgi-bin/process.cgi" +var requestMethod = "POST" +var requestVersion = "HTTP/1.1" + +// var requestLine = "POST /cgi-bin/process.cgi HTTP/1.1\n" + +var requestHeaders = []pluginmanager.HTTPHeader{ + {Key: "User-Agent", Value: "Mozilla/4.0 (compatible; MSIE5.01; Windows NT)"}, + {Key: "Host", Value: "www.tutorialspoint.com"}, + {Key: "Content-Type", Value: "application/x-www-form-urlencoded"}, + {Key: "Content-Length", Value: "length"}, + {Key: "Accept-Language", Value: "en-us"}, + {Key: "Accept-Encoding", Value: "gzip, deflate"}, + {Key: "Connection", Value: "Keep-Alive"}, +} + +var requestHeadersPayload = pluginmanager.HTTPPayload{ + URI: requestURI, + Method: requestMethod, + HTTPVersion: requestVersion, +} var requestBody = "licenseID=string&content=string&/paramsXML=string\n" -var wholeRequest = requestLine + requestHeaders + "\n" + requestBody - -var responseLine = "HTTP/1.1 200 OK\n" -var responseHeaders = `Date: Mon, 27 Jul 2009 12:28:53 GMT -Server: Apache/2.2.14 (Win32) -Last-Modified: Wed, 22 Jul 2009 19:15:56 GMT -Content-Length: 88 -Content-Type: text/html -Connection: Closed -` +var wholeRequest = pluginmanager.HTTPPayload{ + URI: requestURI, + Method: requestMethod, + HTTPVersion: requestVersion, + RequestBody: requestBody, +} + +// var wholeRequest = requestLine + requestHeaders + "\n" + requestBody +var responseCode = 200 +var responseProto = "HTTP/1.1" +var responseHeaders = []pluginmanager.HTTPHeader{ + {Key: "Date", Value: "Mon, 27 Jul 2009 12:28:53 GMT"}, + {Key: "Server", Value: "Apache/2.2.14 (Win32)"}, + {Key: "Last-Modified", Value: "Wed, 22 Jul 2009 19:15:56 GMT"}, + {Key: "Content-Length", Value: "88"}, + {Key: "Content-Type", Value: "text/html"}, + {Key: "Connection", Value: "Closed"}, +} + +var responseHeadersPayload = pluginmanager.HTTPPayload{ + ResponseProtocol: responseProto, + ResponseCode: responseCode, + ResponseHeaders: responseHeaders, +} + var responseBody = `

Hello, World!

` -var wholeResponse = responseLine + responseHeaders + "\n" + responseBody + +var wholeResponse = pluginmanager.HTTPPayload{ + ResponseProtocol: responseProto, + ResponseCode: responseCode, + ResponseHeaders: responseHeaders, + ResponseBody: responseBody, +} var config = []byte(`--- logpath: "/dev/null" @@ -283,11 +317,11 @@ func TestAnalyzeRequestInParts(t *testing.T) { InitTransaction(transactionID) - res := Analyze("RequestHeaders", transactionID, requestLine+"\n"+requestHeaders, []string{"trivialRequestHeaders"}) + res := Analyze("RequestHeaders", transactionID, requestHeadersPayload, []string{"trivialRequestHeaders"}) if res != nil { t.Errorf("Error: Analyze RequestHeaders: %s", res.Error()) } - res = Analyze("RequestBody", transactionID, requestBody, []string{"trivialRequestBody"}) + res = Analyze("RequestBody", transactionID, pluginmanager.HTTPPayload{ResponseBody: requestBody}, []string{"trivialRequestBody"}) if res != nil { t.Errorf("Error: Analyze RequestBody: %s", res.Error()) } @@ -345,11 +379,11 @@ func TestAnalyzeResponseInParts(t *testing.T) { InitTransaction(transactionID) - res := Analyze("ResponseHeaders", transactionID, responseLine+"\n"+responseHeaders, []string{"trivialResponseHeaders"}) + res := Analyze("ResponseHeaders", transactionID, responseHeadersPayload, []string{"trivialResponseHeaders"}) if res != nil { t.Errorf("Error: Analyze ResponseHeaders: %s", res.Error()) } - res = Analyze("ResponseBody", transactionID, responseBody, []string{"trivialResponseBody"}) + res = Analyze("ResponseBody", transactionID, pluginmanager.HTTPPayload{ResponseBody: responseBody}, []string{"trivialResponseBody"}) if res != nil { t.Errorf("Error: Analyze ResponseBody: %s", res.Error()) } @@ -406,7 +440,7 @@ func TestAnalyzeRequestInPartsAsync(t *testing.T) { InitTransaction(transactionID) - res := Analyze("RequestHeaders", transactionID, requestLine+"\n"+requestHeaders, []string{"trivial", "trivial2"}) + res := Analyze("RequestHeaders", transactionID, requestHeadersPayload, []string{"trivial", "trivial2"}) if res != nil { t.Errorf("Error: Analyze RequestHeaders: %s", res.Error()) } @@ -450,7 +484,7 @@ func TestCheckAttackTransaction(t *testing.T) { wafParams[scoreParts[0]] = scoreParts[1] } - err = Analyze("RequestHeaders", transactionID, requestLine+"\n"+requestHeaders, []string{"trivial", "trivial2", "trivial3"}) + err = Analyze("RequestHeaders", transactionID, requestHeadersPayload, []string{"trivial", "trivial2", "trivial3"}) if err != nil { t.Errorf("Error: Analyze RequestHeaders: %s", err.Error()) } @@ -533,7 +567,7 @@ func BenchmarkTrivial(b *testing.B) { transactionId := strconv.Itoa(i) InitTransaction(transactionId) - Analyze("RequestHeaders", transactionId, "Request line and headers\n", []string{"trivial", "trivial2"}) + Analyze("RequestHeaders", transactionId, pluginmanager.HTTPPayload{URI: "Request line and headers\n"}, []string{"trivial", "trivial2"}) _, err := CheckTransaction(transactionId, "simple", wafParams) if err != nil { @@ -565,7 +599,7 @@ func BenchmarkTrivialFullNATS(b *testing.B) { transactionId := generateRandomID() InitTransaction(transactionId) - Analyze("RequestHeaders", transactionId, "Request line and headers\n", []string{"trivial", "trivial2"}) + Analyze("RequestHeaders", transactionId, pluginmanager.HTTPPayload{URI: "Request line and headers\n"}, []string{"trivial", "trivial2"}) _, err := CheckTransaction(transactionId, "simple", wafParams) if err != nil { From 7a4bed1ef9ee49de9df02c52279ec6da1480c776 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20de=20Le=C3=B3n?= Date: Wed, 29 Apr 2026 12:16:57 -0300 Subject: [PATCH 5/6] chore: gofmt excecution --- configstore/configstore.go | 522 ++++++------ configstore/configstore_test.go | 554 ++++++------ pluginmanager/pluginmanager.go | 992 +++++++++++----------- pluginmanager/pluginmanager_test.go | 742 ++++++++-------- wacecore.go | 544 ++++++------ wacecore_test.go | 1220 +++++++++++++-------------- 6 files changed, 2287 insertions(+), 2287 deletions(-) diff --git a/configstore/configstore.go b/configstore/configstore.go index 26001f6..f2f4df6 100644 --- a/configstore/configstore.go +++ b/configstore/configstore.go @@ -1,261 +1,261 @@ -/* -Package configstore handles the configuration of WACE. The -configuration file is parsed, checked for errors and loaded into -memory -*/ -package configstore - -import ( - "fmt" - "os" - - "github.com/tilsor/ModSecIntl_logging/logging" -) - -// ModelPluginType is an enum listing the parts of a request or -// response that a model plugin can handle. -type ModelPluginType int - -const ( - RequestHeaders ModelPluginType = iota - RequestBody - AllRequest - ResponseHeaders - ResponseBody - AllResponse - Everything -) - -// String returns the string representation of a model plugin type -func (t ModelPluginType) String() string { - switch t { - case RequestHeaders: - return "RequestHeaders" - case RequestBody: - return "RequestBody" - case AllRequest: - return "AllRequest" - case ResponseHeaders: - return "ResponseHeaders" - case ResponseBody: - return "ResponseBody" - case AllResponse: - return "AllResponse" - default: - return "Everything" - } -} - -// StringToPluginType converts a string to the corresponding model plugin type -func StringToPluginType(textType string) (ModelPluginType, error) { - switch textType { - case "RequestHeaders": - return RequestHeaders, nil - case "RequestBody": - return RequestBody, nil - case "AllRequest": - return AllRequest, nil - case "ResponseHeaders": - return ResponseHeaders, nil - case "ResponseBody": - return ResponseBody, nil - case "AllResponse": - return AllResponse, nil - case "Everything": - return Everything, nil - } - return -1, fmt.Errorf("invalid plugin type %s", textType) -} - -// ModelPluginConfig stores the configuration of a model plugin -type modelPluginConfig struct { - ID string - Path string - Weight float64 - Threshold float64 - Params map[string]string - PluginType ModelPluginType - Mode string - Remote bool -} - -// DecisionPluginConfig stores the configuration of a decision plugin -type decisionPluginConfig struct { - ID string - Path string - WAFweight float64 - DecisionBalance float64 - Params map[string]string -} - -// ConfigStore stores all wacecore configuration from the config file. -type ConfigStore struct { - ModelPlugins map[string]modelPluginConfig - DecisionPlugins map[string]decisionPluginConfig - LogPath string - LogLevel logging.LogLevel - NatsURL string - ApplicationId string -} - -var config *ConfigStore - -// Create and returns the unique instance of configstore if it does not exist previously, in other case returns error -func New() (*ConfigStore, error) { - if config != nil { - return nil, fmt.Errorf("ConfigStore already exists") - } - config = new(ConfigStore) - return config, nil -} - -// Get returns the unique instance of configstore -func Get() (*ConfigStore, error) { - if config == nil { - return nil, fmt.Errorf("Configuration was not loaded") - } - return config, nil -} - -// Clean remove the references to the stored instance of configstore -func Clean() { - config = nil -} - -type configFileModelPlugin struct { - ID string - Path string - Weight float64 - Threshold float64 - Params map[string]string - PluginType string `yaml:"plugintype"` - Mode string - Remote bool -} - -type configFileDecisionPlugin struct { - ID string - Path string - wafweight float64 - decisionbalance float64 - Params map[string]string -} - -type ConfigFileData struct { - Logpath string - Loglevel string - Modelplugins []configFileModelPlugin - Decisionplugins []configFileDecisionPlugin - NatsURL string -} - -// IsAsync returns true if the model plugin is async -func (c *ConfigStore) IsAsync(modelID string) bool { - return c.ModelPlugins[modelID].Mode == "async" -} - -// CheckLogging verifies if the log path is valid -func checkLogging(inConf ConfigFileData) error { - // check logpath - if inConf.Logpath == "" { - return fmt.Errorf("log path empty") - } - _, err := os.Stat(inConf.Logpath) - if err != nil { // check if log file does not exists already - // Attempt to create dummy file - var d []byte - err = os.WriteFile(inConf.Logpath, d, 0644) - if err == nil { - err = os.Remove(inConf.Logpath) // delete it - } - } - return err -} - -// CheckConfig verifies if the configuration read from the config file -// is correct. -func checkConfig(inConf ConfigFileData) error { - err := checkLogging(inConf) - if err != nil { - return fmt.Errorf("invalid log path %s: %v", inConf.Logpath, err) - } - - // check modelplugins - for _, modelP := range inConf.Modelplugins { - - if modelP.Path != "" { - if _, err := os.Stat(modelP.Path); err != nil { - return fmt.Errorf("%s plugin path %s: %v", modelP.ID, modelP.Path, err) - } - } else { - return fmt.Errorf("%s plugin path is empty, please provide a valid path", modelP.ID) - } - if modelP.PluginType == "" { - return fmt.Errorf("%s plugin type cannot be empty, please provide a valid type", modelP.ID) - } - // fmt.Printf("modelP.Type: %s\n", modelP.Type) - } - // check decisionplugins - for _, decisionP := range inConf.Decisionplugins { - - if decisionP.Path != "" { - if _, err := os.Stat(decisionP.Path); err != nil { - return fmt.Errorf("%s plugin path %s cannot be opened: %v", decisionP.ID, decisionP.Path, err) - } - } else { - return fmt.Errorf("%s plugin path is empty, please provide a valid path", decisionP.ID) - } - } - - return nil -} - -// SetConfig sets the configuration of WACE from the configuration file -func (cs *ConfigStore) SetConfig(inConf ConfigFileData) error { - err := checkConfig(inConf) - if err != nil { - return err - } - - cs.LogPath = inConf.Logpath - cs.LogLevel, err = logging.StringToLogLevel(inConf.Loglevel) - if err != nil { - return err - } - - cs.ModelPlugins = make(map[string]modelPluginConfig) - for _, modelP := range inConf.Modelplugins { - var modelConfig modelPluginConfig - modelConfig.ID = modelP.ID - modelConfig.Path = modelP.Path - modelConfig.Weight = modelP.Weight - modelConfig.Threshold = modelP.Threshold - modelConfig.Params = modelP.Params - modelConfig.PluginType, err = StringToPluginType(modelP.PluginType) - modelConfig.Mode = modelP.Mode - modelConfig.Remote = modelP.Remote - if err != nil { - return err - } - cs.ModelPlugins[modelConfig.ID] = modelConfig - } - - cs.DecisionPlugins = make(map[string]decisionPluginConfig) - for _, decisionP := range inConf.Decisionplugins { - var decisionConfig decisionPluginConfig - decisionConfig.ID = decisionP.ID - decisionConfig.Path = decisionP.Path - decisionConfig.WAFweight = decisionP.wafweight - decisionConfig.DecisionBalance = decisionP.decisionbalance - decisionConfig.Params = decisionP.Params - cs.DecisionPlugins[decisionConfig.ID] = decisionConfig - } - - if inConf.NatsURL != "" { - cs.NatsURL = inConf.NatsURL - } else { - cs.NatsURL = "localhost:4222" - } - - return nil -} +/* +Package configstore handles the configuration of WACE. The +configuration file is parsed, checked for errors and loaded into +memory +*/ +package configstore + +import ( + "fmt" + "os" + + "github.com/tilsor/ModSecIntl_logging/logging" +) + +// ModelPluginType is an enum listing the parts of a request or +// response that a model plugin can handle. +type ModelPluginType int + +const ( + RequestHeaders ModelPluginType = iota + RequestBody + AllRequest + ResponseHeaders + ResponseBody + AllResponse + Everything +) + +// String returns the string representation of a model plugin type +func (t ModelPluginType) String() string { + switch t { + case RequestHeaders: + return "RequestHeaders" + case RequestBody: + return "RequestBody" + case AllRequest: + return "AllRequest" + case ResponseHeaders: + return "ResponseHeaders" + case ResponseBody: + return "ResponseBody" + case AllResponse: + return "AllResponse" + default: + return "Everything" + } +} + +// StringToPluginType converts a string to the corresponding model plugin type +func StringToPluginType(textType string) (ModelPluginType, error) { + switch textType { + case "RequestHeaders": + return RequestHeaders, nil + case "RequestBody": + return RequestBody, nil + case "AllRequest": + return AllRequest, nil + case "ResponseHeaders": + return ResponseHeaders, nil + case "ResponseBody": + return ResponseBody, nil + case "AllResponse": + return AllResponse, nil + case "Everything": + return Everything, nil + } + return -1, fmt.Errorf("invalid plugin type %s", textType) +} + +// ModelPluginConfig stores the configuration of a model plugin +type modelPluginConfig struct { + ID string + Path string + Weight float64 + Threshold float64 + Params map[string]string + PluginType ModelPluginType + Mode string + Remote bool +} + +// DecisionPluginConfig stores the configuration of a decision plugin +type decisionPluginConfig struct { + ID string + Path string + WAFweight float64 + DecisionBalance float64 + Params map[string]string +} + +// ConfigStore stores all wacecore configuration from the config file. +type ConfigStore struct { + ModelPlugins map[string]modelPluginConfig + DecisionPlugins map[string]decisionPluginConfig + LogPath string + LogLevel logging.LogLevel + NatsURL string + ApplicationId string +} + +var config *ConfigStore + +// Create and returns the unique instance of configstore if it does not exist previously, in other case returns error +func New() (*ConfigStore, error) { + if config != nil { + return nil, fmt.Errorf("ConfigStore already exists") + } + config = new(ConfigStore) + return config, nil +} + +// Get returns the unique instance of configstore +func Get() (*ConfigStore, error) { + if config == nil { + return nil, fmt.Errorf("Configuration was not loaded") + } + return config, nil +} + +// Clean remove the references to the stored instance of configstore +func Clean() { + config = nil +} + +type configFileModelPlugin struct { + ID string + Path string + Weight float64 + Threshold float64 + Params map[string]string + PluginType string `yaml:"plugintype"` + Mode string + Remote bool +} + +type configFileDecisionPlugin struct { + ID string + Path string + wafweight float64 + decisionbalance float64 + Params map[string]string +} + +type ConfigFileData struct { + Logpath string + Loglevel string + Modelplugins []configFileModelPlugin + Decisionplugins []configFileDecisionPlugin + NatsURL string +} + +// IsAsync returns true if the model plugin is async +func (c *ConfigStore) IsAsync(modelID string) bool { + return c.ModelPlugins[modelID].Mode == "async" +} + +// CheckLogging verifies if the log path is valid +func checkLogging(inConf ConfigFileData) error { + // check logpath + if inConf.Logpath == "" { + return fmt.Errorf("log path empty") + } + _, err := os.Stat(inConf.Logpath) + if err != nil { // check if log file does not exists already + // Attempt to create dummy file + var d []byte + err = os.WriteFile(inConf.Logpath, d, 0644) + if err == nil { + err = os.Remove(inConf.Logpath) // delete it + } + } + return err +} + +// CheckConfig verifies if the configuration read from the config file +// is correct. +func checkConfig(inConf ConfigFileData) error { + err := checkLogging(inConf) + if err != nil { + return fmt.Errorf("invalid log path %s: %v", inConf.Logpath, err) + } + + // check modelplugins + for _, modelP := range inConf.Modelplugins { + + if modelP.Path != "" { + if _, err := os.Stat(modelP.Path); err != nil { + return fmt.Errorf("%s plugin path %s: %v", modelP.ID, modelP.Path, err) + } + } else { + return fmt.Errorf("%s plugin path is empty, please provide a valid path", modelP.ID) + } + if modelP.PluginType == "" { + return fmt.Errorf("%s plugin type cannot be empty, please provide a valid type", modelP.ID) + } + // fmt.Printf("modelP.Type: %s\n", modelP.Type) + } + // check decisionplugins + for _, decisionP := range inConf.Decisionplugins { + + if decisionP.Path != "" { + if _, err := os.Stat(decisionP.Path); err != nil { + return fmt.Errorf("%s plugin path %s cannot be opened: %v", decisionP.ID, decisionP.Path, err) + } + } else { + return fmt.Errorf("%s plugin path is empty, please provide a valid path", decisionP.ID) + } + } + + return nil +} + +// SetConfig sets the configuration of WACE from the configuration file +func (cs *ConfigStore) SetConfig(inConf ConfigFileData) error { + err := checkConfig(inConf) + if err != nil { + return err + } + + cs.LogPath = inConf.Logpath + cs.LogLevel, err = logging.StringToLogLevel(inConf.Loglevel) + if err != nil { + return err + } + + cs.ModelPlugins = make(map[string]modelPluginConfig) + for _, modelP := range inConf.Modelplugins { + var modelConfig modelPluginConfig + modelConfig.ID = modelP.ID + modelConfig.Path = modelP.Path + modelConfig.Weight = modelP.Weight + modelConfig.Threshold = modelP.Threshold + modelConfig.Params = modelP.Params + modelConfig.PluginType, err = StringToPluginType(modelP.PluginType) + modelConfig.Mode = modelP.Mode + modelConfig.Remote = modelP.Remote + if err != nil { + return err + } + cs.ModelPlugins[modelConfig.ID] = modelConfig + } + + cs.DecisionPlugins = make(map[string]decisionPluginConfig) + for _, decisionP := range inConf.Decisionplugins { + var decisionConfig decisionPluginConfig + decisionConfig.ID = decisionP.ID + decisionConfig.Path = decisionP.Path + decisionConfig.WAFweight = decisionP.wafweight + decisionConfig.DecisionBalance = decisionP.decisionbalance + decisionConfig.Params = decisionP.Params + cs.DecisionPlugins[decisionConfig.ID] = decisionConfig + } + + if inConf.NatsURL != "" { + cs.NatsURL = inConf.NatsURL + } else { + cs.NatsURL = "localhost:4222" + } + + return nil +} diff --git a/configstore/configstore_test.go b/configstore/configstore_test.go index 67abf0f..93b3d97 100644 --- a/configstore/configstore_test.go +++ b/configstore/configstore_test.go @@ -1,277 +1,277 @@ -package configstore - -import ( - "fmt" - "os" - "testing" - - "gopkg.in/yaml.v3" -) - -var validConfig = []byte(`--- -logpath: "/dev/stderr" -loglevel: "DEBUG" -modelplugins: - - id: "trivial" - path: "../testdata/plugins/model/trivial.so" - weight: 1 - threshold: 0.5 - params: - d: "sds" - b: "dnid" - e: "dofnno" - plugintype: "RequestHeaders" - mode: "sync" - - id: "trivial2" - path: "../testdata/plugins/model/trivial2.so" - weight: 2 - threshold: 0.1 - params: - a: "sdsds" - b: "sdfjdnid" - c: "kfoskdofnno" - plugintype: "RequestHeaders" -decisionplugins: - - id: "test" - path: "../testdata/plugins/decision/test.so" - wafweight: 0.5 - decisionbalance: 0.5 - params: - ssdaf: "sdsds" - dsfb: "sdfjdnid" - csfd: "kfoskdofnno" -`) - -func initialize(configuration []byte) error { - cs, err := Get() - if err != nil { - return err - } - var aux ConfigFileData - err = yaml.Unmarshal(configuration, &aux) - if err != nil { - return err - } - err = cs.SetConfig(aux) - if err != nil { - return err - } - return nil -} - -func TestLoadConfigYamlEmpty(t *testing.T) { - _, err := New() - if err != nil { - t.Error(err) - } - - defer Clean() - - err = initialize([]byte(`---`)) - if err == nil { - t.Error("empty config does not return error") - } -} - -func TestLoadConfigYamlValid(t *testing.T) { - _, err := New() - if err != nil { - t.Error(err) - } - - defer Clean() - - err = initialize(validConfig) - if err != nil { - t.Errorf("valid config returned error: %v", err) - } -} - -func TestLoadConfigYamlInvalid(t *testing.T) { - _, err := New() - if err != nil { - t.Error(err) - } - - defer Clean() - - err = initialize([]byte(`()=)(/&/()~@#~½¬{[{½¬½---sfdjlskjfs#@~sjdfa`)) - - if err == nil { - t.Error("invalid config does not return error") - } -} - -func TestLoadConfigYamlLogLevel(t *testing.T) { - _, err := New() - if err != nil { - t.Error(err) - } - - defer Clean() - - values := []string{ - "a", - "4", - "0", - } - - for _, v := range values { - config := `--- -logpath: "/dev/null" -loglevel: ` + v - err = initialize([]byte(config)) - if err == nil { - t.Errorf("invalid log level %v does not return error", v) - } - } -} - -func TestLoadConfigYamlPluginType(t *testing.T) { - cs, err := New() - if err != nil { - t.Error(err) - } - - defer Clean() - - err = initialize([]byte(`--- -loglevel: ERROR -logpath: /dev/null -modelplugins: - - id: "testplugin" - path: "../testdata/plugins/model/trivial.so" - plugintype: InvalidPluginType -`)) - if err == nil { - t.Errorf("invalid plugin type does not return error") - } - - err = initialize([]byte(`--- -loglevel: ERROR -logpath: /dev/null -modelplugins: - - id: "testplugin" - path: "../testdata/plugins/model/trivial.so" - plugintype: "" -`)) - if err == nil { - t.Errorf("empty plugin type does not return error") - } - - err = initialize([]byte(`--- -loglevel: ERROR -logpath: /dev/null -modelplugins: - - id: "testplugin" - path: "../testdata/plugins/model/nonexistent.so" - plugintype: "RequestHeaders" -`)) - if err == nil { - t.Errorf("nonexistent model plugin path does not return error") - } - - err = initialize([]byte(`--- -loglevel: ERROR -logpath: /dev/null -modelplugins: - - id: "testplugin" - path: "" - plugintype: "RequestHeaders" -`)) - if err == nil { - t.Errorf("empty plugin path does not return error") - } - - err = initialize([]byte(`--- -loglevel: ERROR -logpath: /dev/null -decisionplugins: - - id: "test" - path: "" -`)) - if err == nil { - t.Errorf("empty decision plugin path does not return error") - } - - err = initialize([]byte(`--- -loglevel: ERROR -logpath: /dev/null -decisionplugins: - - id: "testplugin" - path: "../testdata/plugins/decision/nonexistent.so" -`)) - if err == nil { - t.Errorf("nonexistent decision plugin path does not return error") - } - - values := []string{ - "RequestHeaders", - "RequestBody", - "AllRequest", - "ResponseHeaders", - "ResponseBody", - "AllResponse", - "Everything", - } - - for _, v := range values { - config := `--- -loglevel: ERROR -logpath: /dev/null -modelplugins: - - id: "testplugin" - path: "../testdata/plugins/model/trivial.so" - plugintype: "` + v + `" -` - err = initialize([]byte(config)) - if err != nil { - t.Errorf("Plugin type %s returns error: %v", v, err) - } - - if fmt.Sprint(cs.ModelPlugins["testplugin"].PluginType) != v { - t.Errorf("Stored plugin type is %v, expected %v", cs.ModelPlugins["testplugin"].PluginType, v) - } - } -} - -func TestInvalidLogging(t *testing.T) { - _, err := New() - if err != nil { - t.Error(err) - } - - defer Clean() - - err = initialize([]byte(`--- -loglevel: INVALIDLOGLEVEL -logpath: /dev/null -`)) - if err == nil { - t.Errorf("invalid log level does not return error") - } - - if _, err = os.Stat("./configstore_test.log"); err == nil { - err = os.Remove("./configstore_test.log") - if err != nil { - t.Errorf("could not remove ./configstore_test.log") - } - } - - err = initialize([]byte(`--- -loglevel: ERROR -logpath: ./configstore_test.log`)) - - if err != nil { - t.Errorf("Error loading config with nonexistent file: %v", err) - } - - err = initialize([]byte(`--- -loglevel: ERROR -logpath: /usr/configstore_test.log`)) - - if err == nil { - t.Errorf("non existent log file in directory without permissions does not rise error") - } - -} +package configstore + +import ( + "fmt" + "os" + "testing" + + "gopkg.in/yaml.v3" +) + +var validConfig = []byte(`--- +logpath: "/dev/stderr" +loglevel: "DEBUG" +modelplugins: + - id: "trivial" + path: "../testdata/plugins/model/trivial.so" + weight: 1 + threshold: 0.5 + params: + d: "sds" + b: "dnid" + e: "dofnno" + plugintype: "RequestHeaders" + mode: "sync" + - id: "trivial2" + path: "../testdata/plugins/model/trivial2.so" + weight: 2 + threshold: 0.1 + params: + a: "sdsds" + b: "sdfjdnid" + c: "kfoskdofnno" + plugintype: "RequestHeaders" +decisionplugins: + - id: "test" + path: "../testdata/plugins/decision/test.so" + wafweight: 0.5 + decisionbalance: 0.5 + params: + ssdaf: "sdsds" + dsfb: "sdfjdnid" + csfd: "kfoskdofnno" +`) + +func initialize(configuration []byte) error { + cs, err := Get() + if err != nil { + return err + } + var aux ConfigFileData + err = yaml.Unmarshal(configuration, &aux) + if err != nil { + return err + } + err = cs.SetConfig(aux) + if err != nil { + return err + } + return nil +} + +func TestLoadConfigYamlEmpty(t *testing.T) { + _, err := New() + if err != nil { + t.Error(err) + } + + defer Clean() + + err = initialize([]byte(`---`)) + if err == nil { + t.Error("empty config does not return error") + } +} + +func TestLoadConfigYamlValid(t *testing.T) { + _, err := New() + if err != nil { + t.Error(err) + } + + defer Clean() + + err = initialize(validConfig) + if err != nil { + t.Errorf("valid config returned error: %v", err) + } +} + +func TestLoadConfigYamlInvalid(t *testing.T) { + _, err := New() + if err != nil { + t.Error(err) + } + + defer Clean() + + err = initialize([]byte(`()=)(/&/()~@#~½¬{[{½¬½---sfdjlskjfs#@~sjdfa`)) + + if err == nil { + t.Error("invalid config does not return error") + } +} + +func TestLoadConfigYamlLogLevel(t *testing.T) { + _, err := New() + if err != nil { + t.Error(err) + } + + defer Clean() + + values := []string{ + "a", + "4", + "0", + } + + for _, v := range values { + config := `--- +logpath: "/dev/null" +loglevel: ` + v + err = initialize([]byte(config)) + if err == nil { + t.Errorf("invalid log level %v does not return error", v) + } + } +} + +func TestLoadConfigYamlPluginType(t *testing.T) { + cs, err := New() + if err != nil { + t.Error(err) + } + + defer Clean() + + err = initialize([]byte(`--- +loglevel: ERROR +logpath: /dev/null +modelplugins: + - id: "testplugin" + path: "../testdata/plugins/model/trivial.so" + plugintype: InvalidPluginType +`)) + if err == nil { + t.Errorf("invalid plugin type does not return error") + } + + err = initialize([]byte(`--- +loglevel: ERROR +logpath: /dev/null +modelplugins: + - id: "testplugin" + path: "../testdata/plugins/model/trivial.so" + plugintype: "" +`)) + if err == nil { + t.Errorf("empty plugin type does not return error") + } + + err = initialize([]byte(`--- +loglevel: ERROR +logpath: /dev/null +modelplugins: + - id: "testplugin" + path: "../testdata/plugins/model/nonexistent.so" + plugintype: "RequestHeaders" +`)) + if err == nil { + t.Errorf("nonexistent model plugin path does not return error") + } + + err = initialize([]byte(`--- +loglevel: ERROR +logpath: /dev/null +modelplugins: + - id: "testplugin" + path: "" + plugintype: "RequestHeaders" +`)) + if err == nil { + t.Errorf("empty plugin path does not return error") + } + + err = initialize([]byte(`--- +loglevel: ERROR +logpath: /dev/null +decisionplugins: + - id: "test" + path: "" +`)) + if err == nil { + t.Errorf("empty decision plugin path does not return error") + } + + err = initialize([]byte(`--- +loglevel: ERROR +logpath: /dev/null +decisionplugins: + - id: "testplugin" + path: "../testdata/plugins/decision/nonexistent.so" +`)) + if err == nil { + t.Errorf("nonexistent decision plugin path does not return error") + } + + values := []string{ + "RequestHeaders", + "RequestBody", + "AllRequest", + "ResponseHeaders", + "ResponseBody", + "AllResponse", + "Everything", + } + + for _, v := range values { + config := `--- +loglevel: ERROR +logpath: /dev/null +modelplugins: + - id: "testplugin" + path: "../testdata/plugins/model/trivial.so" + plugintype: "` + v + `" +` + err = initialize([]byte(config)) + if err != nil { + t.Errorf("Plugin type %s returns error: %v", v, err) + } + + if fmt.Sprint(cs.ModelPlugins["testplugin"].PluginType) != v { + t.Errorf("Stored plugin type is %v, expected %v", cs.ModelPlugins["testplugin"].PluginType, v) + } + } +} + +func TestInvalidLogging(t *testing.T) { + _, err := New() + if err != nil { + t.Error(err) + } + + defer Clean() + + err = initialize([]byte(`--- +loglevel: INVALIDLOGLEVEL +logpath: /dev/null +`)) + if err == nil { + t.Errorf("invalid log level does not return error") + } + + if _, err = os.Stat("./configstore_test.log"); err == nil { + err = os.Remove("./configstore_test.log") + if err != nil { + t.Errorf("could not remove ./configstore_test.log") + } + } + + err = initialize([]byte(`--- +loglevel: ERROR +logpath: ./configstore_test.log`)) + + if err != nil { + t.Errorf("Error loading config with nonexistent file: %v", err) + } + + err = initialize([]byte(`--- +loglevel: ERROR +logpath: /usr/configstore_test.log`)) + + if err == nil { + t.Errorf("non existent log file in directory without permissions does not rise error") + } + +} diff --git a/pluginmanager/pluginmanager.go b/pluginmanager/pluginmanager.go index 641e4f9..e7204ce 100644 --- a/pluginmanager/pluginmanager.go +++ b/pluginmanager/pluginmanager.go @@ -1,496 +1,496 @@ -/* -Package pluginmanager handles the communication with the model and -decision plugins -*/ -package pluginmanager - -import ( - "encoding/json" - "fmt" - "plugin" - "sync" - - "github.com/tilsor/ModSecIntl_wace_lib/configstore" - "go.opentelemetry.io/otel/metric" - - "github.com/nats-io/nats.go" - "github.com/tilsor/ModSecIntl_logging/logging" -) - -// ResultData maps the model plugin ID with the corresponding analysis result. -type ModelResults struct { - ProbAttack float64 `json:"probattack"` - Data map[string]interface{} `json:"data"` -} - -type HTTPHeader struct { - Key string - Value string -} - -type HTTPPayload struct { - URI string - Method string - HTTPVersion string - RequestHeaders []HTTPHeader - RequestBody string - ResponseProtocol string - ResponseCode int - ResponseHeaders []HTTPHeader - ResponseBody string -} - -// ModelInput is the struct that contains the input data for the model plugin -type ModelInput struct { - TransactionId string `json:"transactionId"` - Payload HTTPPayload `json:"payload"` -} - -// DecisionInput is the struct that contains the input data for the decision plugin -type DecisionInput struct { - TransactionId string - Results map[string]ModelResults - ModelWeight map[string]float64 - WAFdata map[string]string -} - -// ModelTransmitionResults is the struct that contains the results of the model plugin -type ModelTransmitionResults struct { - TransactionId string `json:"transactionId"` - ModelResults `json:",inline"` - Error error `json:"error"` -} - -// modelPlugin is the struct that stores the model plugin and its type -type modelPlugin struct { - p *plugin.Plugin - pluginType configstore.ModelPluginType -} - -// decisionPlugin is the struct that stores the decision plugin -type decisionPlugin struct { - p *plugin.Plugin -} - -// ModelStatus stores whether there was an error while processing a -// request (response) by the modelID model plugin -type ModelStatus struct { - ModelID string - ProbAttack float64 - Err error -} - -// PluginManager is the main plugin struct storing information of -// every plugin execution. -type PluginManager struct { - modelPlugins map[string]modelPlugin - modelProcessFunc map[string]func(ModelInput) (ModelResults, error) - decisionCheckFunc map[string]func(DecisionInput) (bool, error) - decisionPlugins map[string]decisionPlugin - results sync.Map - channelsMutex sync.Mutex - syncModelsChannels sync.Map - asyncModelsChannels sync.Map - natConn *nats.Conn -} - -// New creates a new PluginManager instance. -func New(meter metric.Meter) (*PluginManager, error) { - pm := new(PluginManager) - conf, err := configstore.Get() - if err != nil { - return nil, err - } - logger := logging.Get() - logger.Printf(logging.DEBUG, "Connecting to NATS server at %s", conf.NatsURL) - - nc, err := nats.Connect(conf.NatsURL) - - if err != nil { - logger.Printf(logging.ERROR, "Failed to connect to NATS server") - } - - pm.natConn = nc - - // Loading of model plugins - pm.modelPlugins = make(map[string]modelPlugin) - pm.modelProcessFunc = make(map[string]func(ModelInput) (ModelResults, error)) - for _, data := range conf.ModelPlugins { - tp, err := plugin.Open(data.Path) - if err != nil { - logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) - continue - } - if data.Mode == "async" || conf.ModelPlugins[data.ID].Remote { - f, err := tp.Lookup("InitPluginAsync") - if err != nil { - logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) - continue - } - initPlugin, ok := f.(func(map[string]string, metric.Meter, func(func(ModelInput) (ModelResults, error))) error) - if !ok { - logger.Printf(logging.WARN, "| %s | cannot load plugin: invalid InitPluginAsync function type", data.ID) - continue - } - err = initPlugin(data.Params, meter, func(modelProcess func(ModelInput) (ModelResults, error)) { - ModelProcessHandler(data.ID, modelProcess) - }) - if err != nil { - logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) - continue - } - go pm.ModelResultsHandler(data.ID) - } else { - f, err := tp.Lookup("InitPlugin") - if err != nil { - logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) - continue - } - initPlugin, ok := f.(func(map[string]string, metric.Meter) error) - if !ok { - logger.Printf(logging.WARN, "| %s | cannot load plugin: invalid InitPlugin function type", data.ID) - continue - } - err = initPlugin(data.Params, meter) - procFunc, err := tp.Lookup("Process") - if err != nil { - logger.Printf(logging.WARN, "| %s | cannot load plugin: cannot load Process function", data.ID) - continue - } - process, ok := procFunc.(func(ModelInput) (ModelResults, error)) - if !ok { - logger.Printf(logging.WARN, "| %s | cannot load plugin: invalid Process function type", data.ID) - continue - } - pm.modelProcessFunc[data.ID] = process - } - modelPluginLoaded := modelPlugin{tp, data.PluginType} - pm.modelPlugins[data.ID] = modelPluginLoaded - logger.Printf(logging.INFO, "| %s | plugin loaded", data.ID) - } - - pm.decisionPlugins = make(map[string]decisionPlugin) - pm.decisionCheckFunc = make(map[string]func(DecisionInput) (bool, error)) - // Loading of decision plugins - for _, data := range conf.DecisionPlugins { - tp, err := plugin.Open(data.Path) - if err != nil { - logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) - continue - } - f, err := tp.Lookup("InitPlugin") - if err != nil { - logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) - continue - } - initPlugin, ok := f.(func(map[string]string, metric.Meter) error) - if !ok { - logger.Printf(logging.WARN, "| %s | cannot load plugin: invalid InitPlugin function type", data.ID) - continue - } - err = initPlugin(data.Params, meter) - if err != nil { - logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) - continue - } - cR, err := tp.Lookup("CheckResults") - if err != nil { - logger.Printf(logging.ERROR, "| %s | cannot load plugin check results function: %v", data.ID, err) - continue - } - checkResults, ok := cR.(func(DecisionInput) (bool, error)) - if !ok { - logger.Printf(logging.ERROR, "| %s | CheckResults lookup failed for plugin: invalid function type", data.ID) - continue - } - pm.decisionCheckFunc[data.ID] = checkResults - decisionPluginLoaded := decisionPlugin{tp} - pm.decisionPlugins[data.ID] = decisionPluginLoaded - } - return pm, nil -} - -// InitTransaction initializes the transaction with the given ID -func (p *PluginManager) InitTransaction(transactionId string) { - p.results.Store(transactionId, new(sync.Map)) -} - -// CloseTransaction closes the transaction with the given ID -// removing all sync model data -func (p *PluginManager) CloseTransaction(transactionId string) { - logger := logging.Get() - transactionMap, ok := p.syncModelsChannels.Load(transactionId) - if !ok { - logger.TPrintf(logging.ERROR, transactionId, "Transaction %s not found", transactionId) - } else { - transactionMap.(*sync.Map).Range(func(key, value interface{}) bool { - ch := value.(chan ModelStatus) - close(ch) - for range ch { - } - transactionMap.(*sync.Map).Delete(key) - return true - }) - p.syncModelsChannels.Delete(transactionId) - resultsMap, ok := p.results.Load(transactionId) - if !ok { - logger.TPrintf(logging.ERROR, transactionId, "Results for transaction %s not found", transactionId) - } else { - resultsMap.(*sync.Map).Range(func(key, value interface{}) bool { - resultsMap.(*sync.Map).Delete(key) - return true - }) - } - p.results.Delete(transactionId) - } -} - -// AddModelChannel adds a channel to result channel map -func (p *PluginManager) AddModelChannel(transactionId string, t configstore.ModelPluginType, modelPlugStatus chan ModelStatus, modelType string) { - typeModel := new(sync.Map) - var value interface{} - if modelType == "sync" { - value, _ = p.syncModelsChannels.LoadOrStore(transactionId, typeModel) - } else { - value, _ = p.asyncModelsChannels.LoadOrStore(transactionId, typeModel) - } - value.(*sync.Map).Store(t.String(), modelPlugStatus) -} - -// RemoveModelChannel removes a channel from the result channel map -func (p *PluginManager) RemoveAsyncModelChannel(transactionId string, t configstore.ModelPluginType) { - typeModel, ok := p.asyncModelsChannels.Load(transactionId) - if ok { - channelMap := typeModel.(*sync.Map) - ch, channelOk := channelMap.Load(t.String()) - - if channelOk { - close(ch.(chan ModelStatus)) - for range ch.(chan ModelStatus) { - } - channelMap.Delete(t.String()) - } - - remainChannels := 0 - typeModel.(*sync.Map).Range(func(key, value interface{}) bool { - remainChannels++ - return true - }) - if remainChannels == 0 { - p.asyncModelsChannels.Delete(transactionId) - } - } else { - logger := logging.Get() - logger.TPrintf(logging.ERROR, transactionId, "Transaction %s not found when trying to remove async model channel", transactionId) - } -} - -// AddToQueue adds a payload to the model queue -func (p *PluginManager) AddToQueue(modelID, transactionID string, payload HTTPPayload) error { - payloadToSend := &ModelInput{ - TransactionId: transactionID, - Payload: payload, - } - - jsonPayload, err := json.Marshal(payloadToSend) - - if err != nil { - return err - } - - return p.natConn.Publish(modelID, jsonPayload) -} - -// Process is in charge of calling the model plugin with id modelID -func (p *PluginManager) Process(modelID, transactionId string, payload HTTPPayload, t configstore.ModelPluginType, modelPlugStatus chan ModelStatus) error { - conf, err := configstore.Get() - if err != nil { - return err - } - - mp, exists := p.modelPlugins[modelID] - if !exists { - modelPlugStatus <- ModelStatus{ModelID: modelID, Err: fmt.Errorf("model plugin not found")} - return nil - } - - // check if the plugin is capable of analyzing the indicated part of the transaction - if mp.pluginType != t { - modelPlugStatus <- ModelStatus{ModelID: modelID, - Err: fmt.Errorf("plugin type %v cannot process a request with incompatible type %v", mp.pluginType, t)} - return nil - } - - process := p.modelProcessFunc[modelID] - - if conf.ModelPlugins[modelID].Mode == "async" { - modelPlugStatus <- ModelStatus{ModelID: modelID, Err: fmt.Errorf("model plugin is async")} - return nil - } else { - res, err := process(ModelInput{TransactionId: transactionId, Payload: payload}) - // res, err := process(transactionId, payload) - - if err != nil { - modelPlugStatus <- ModelStatus{ModelID: modelID, Err: err} - return nil - } - // store the results - resultSyncMap, ok := p.results.Load(transactionId) - if !ok { - modelPlugStatus <- ModelStatus{ModelID: modelID, Err: fmt.Errorf("transaction results not found")} - return nil - } - resultSyncMap.(*sync.Map).Store(modelID, res) - modelPlugStatus <- ModelStatus{ModelID: modelID, ProbAttack: res.ProbAttack, Err: nil} - } - return nil -} - -// CheckResult is in charge of calling the decision plugin with id decisionID over the -// transaction with id transactID -func (p *PluginManager) CheckResult(transactionId, decisionId string, wafParams map[string]string) (bool, error) { - logger := logging.Get() - - checkResults, ok := p.decisionCheckFunc[decisionId] - if !ok { - return false, fmt.Errorf("decision plugin not found") - } - - transactionResults, ok := p.results.Load(transactionId) - if !ok { - return false, fmt.Errorf("transaction results not found") - } - - cs, err := configstore.Get() - if err != nil { - return false, nil - } - - modelResultMap := make(map[string]ModelResults) - modelWeightMap := make(map[string]float64) - transactionResults.(*sync.Map).Range(func(key, value interface{}) bool { - modelResultMap[key.(string)] = value.(ModelResults) - modelWeightMap[key.(string)] = cs.ModelPlugins[key.(string)].Weight - return true - }) - - res, err := checkResults(DecisionInput{TransactionId: transactionId, Results: modelResultMap, ModelWeight: modelWeightMap, WAFdata: wafParams}) - logger.TPrintf(logging.INFO, transactionId, "%s | transaction checked. Block: %t ", decisionId, res) - - return res, err -} - -// ModelResultsHandler listens for messages on the model results queue -func (p *PluginManager) ModelResultsHandler(modelId string) error { - logger := logging.Get() - cs, err := configstore.Get() - if err != nil { - return err - } - - sub, err := p.natConn.Subscribe(modelId+"/results", func(msg *nats.Msg) { - go func(msg nats.Msg) { - data := &ModelTransmitionResults{} - err := json.Unmarshal(msg.Data, data) - if err != nil { - logger.Printf(logging.ERROR, "Model: %s | Failed to parse JSON payload", modelId) - } else { - var channel interface{} - var ok bool - if cs.ModelPlugins[modelId].Mode == "async" { - channel, ok = p.asyncModelsChannels.Load(data.TransactionId) - } else { - channel, ok = p.syncModelsChannels.Load(data.TransactionId) - } - if !ok { - logger.TPrintf(logging.ERROR, data.TransactionId, " Model %s | Transaction not found", modelId) - } else { - modelChannel, ok := channel.(*sync.Map).Load(cs.ModelPlugins[modelId].PluginType.String()) - if !ok { - logger.Printf(logging.ERROR, "Model %s not found", modelId) - } else { - if data.Error != nil { - modelChannel.(chan ModelStatus) <- ModelStatus{ModelID: modelId, Err: data.Error} - } else { - if cs.ModelPlugins[modelId].Mode != "async" { - // store the results - resultSyncMap, ok := p.results.Load(data.TransactionId) - if !ok { - modelChannel.(chan ModelStatus) <- ModelStatus{ModelID: modelId, Err: fmt.Errorf("transaction results not found")} - return - } - modelResult := ModelResults{ProbAttack: data.ProbAttack, Data: data.Data} - resultSyncMap.(*sync.Map).Store(modelId, modelResult) - } - modelChannel.(chan ModelStatus) <- ModelStatus{ModelID: modelId, ProbAttack: data.ProbAttack, Err: nil} - } - } - } - } - }(*msg) - }) - - if err != nil { - logger.Printf(logging.ERROR, "Model: %s | Failed to subscribe to model queue | %s", modelId, err.Error()) - return err - } - - logger.Printf(logging.INFO, "Model: %s | Listening for messages on model results queue", modelId) - - defer sub.Unsubscribe() - defer p.natConn.Drain() - - select {} - -} - -// ModelProcessHandler listens for messages on the model queue -func ModelProcessHandler(modelId string, modelProcess func(ModelInput) (ModelResults, error)) error { - logger := logging.Get() - logger.Printf(logging.INFO, "Model: %s | Starting model process handler", modelId) - cs, err := configstore.Get() - if err != nil { - return err - } - - nc, err := nats.Connect(cs.NatsURL) - - if err != nil { - logger.Printf(logging.ERROR, "Model: %s | Failed to connect to NATS server", modelId) - return err - } - - _, err = nc.Subscribe(modelId, func(msg *nats.Msg) { - go func(msg nats.Msg) { - data := &ModelInput{} - err := json.Unmarshal(msg.Data, data) - if err != nil { - logger.Printf(logging.ERROR, "Model: %s | Failed to parse JSON payload", modelId) - } else { - res, err := modelProcess(*data) - modelResult := ModelResults{ProbAttack: res.ProbAttack, Data: res.Data} - payloadToSend := &ModelTransmitionResults{ - TransactionId: data.TransactionId, - ModelResults: modelResult, - Error: err, - } - - jsonPayload, err := json.Marshal(payloadToSend) - - if err != nil { - logger.Printf(logging.ERROR, "Model: %s | Failed to parse JSON payload", modelId) - } - - nc.Publish(modelId+"/results", jsonPayload) - } - }(*msg) - }) - - if err != nil { - logger.Printf(logging.ERROR, "Model: %s | Failed to subscribe to model queue | %s", modelId, err.Error()) - return err - } - - logger.Printf(logging.INFO, "Model: %s | Listening for messages on model queue", modelId) - return nil -} +/* +Package pluginmanager handles the communication with the model and +decision plugins +*/ +package pluginmanager + +import ( + "encoding/json" + "fmt" + "plugin" + "sync" + + "github.com/tilsor/ModSecIntl_wace_lib/configstore" + "go.opentelemetry.io/otel/metric" + + "github.com/nats-io/nats.go" + "github.com/tilsor/ModSecIntl_logging/logging" +) + +// ResultData maps the model plugin ID with the corresponding analysis result. +type ModelResults struct { + ProbAttack float64 `json:"probattack"` + Data map[string]interface{} `json:"data"` +} + +type HTTPHeader struct { + Key string + Value string +} + +type HTTPPayload struct { + URI string + Method string + HTTPVersion string + RequestHeaders []HTTPHeader + RequestBody string + ResponseProtocol string + ResponseCode int + ResponseHeaders []HTTPHeader + ResponseBody string +} + +// ModelInput is the struct that contains the input data for the model plugin +type ModelInput struct { + TransactionId string `json:"transactionId"` + Payload HTTPPayload `json:"payload"` +} + +// DecisionInput is the struct that contains the input data for the decision plugin +type DecisionInput struct { + TransactionId string + Results map[string]ModelResults + ModelWeight map[string]float64 + WAFdata map[string]string +} + +// ModelTransmitionResults is the struct that contains the results of the model plugin +type ModelTransmitionResults struct { + TransactionId string `json:"transactionId"` + ModelResults `json:",inline"` + Error error `json:"error"` +} + +// modelPlugin is the struct that stores the model plugin and its type +type modelPlugin struct { + p *plugin.Plugin + pluginType configstore.ModelPluginType +} + +// decisionPlugin is the struct that stores the decision plugin +type decisionPlugin struct { + p *plugin.Plugin +} + +// ModelStatus stores whether there was an error while processing a +// request (response) by the modelID model plugin +type ModelStatus struct { + ModelID string + ProbAttack float64 + Err error +} + +// PluginManager is the main plugin struct storing information of +// every plugin execution. +type PluginManager struct { + modelPlugins map[string]modelPlugin + modelProcessFunc map[string]func(ModelInput) (ModelResults, error) + decisionCheckFunc map[string]func(DecisionInput) (bool, error) + decisionPlugins map[string]decisionPlugin + results sync.Map + channelsMutex sync.Mutex + syncModelsChannels sync.Map + asyncModelsChannels sync.Map + natConn *nats.Conn +} + +// New creates a new PluginManager instance. +func New(meter metric.Meter) (*PluginManager, error) { + pm := new(PluginManager) + conf, err := configstore.Get() + if err != nil { + return nil, err + } + logger := logging.Get() + logger.Printf(logging.DEBUG, "Connecting to NATS server at %s", conf.NatsURL) + + nc, err := nats.Connect(conf.NatsURL) + + if err != nil { + logger.Printf(logging.ERROR, "Failed to connect to NATS server") + } + + pm.natConn = nc + + // Loading of model plugins + pm.modelPlugins = make(map[string]modelPlugin) + pm.modelProcessFunc = make(map[string]func(ModelInput) (ModelResults, error)) + for _, data := range conf.ModelPlugins { + tp, err := plugin.Open(data.Path) + if err != nil { + logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) + continue + } + if data.Mode == "async" || conf.ModelPlugins[data.ID].Remote { + f, err := tp.Lookup("InitPluginAsync") + if err != nil { + logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) + continue + } + initPlugin, ok := f.(func(map[string]string, metric.Meter, func(func(ModelInput) (ModelResults, error))) error) + if !ok { + logger.Printf(logging.WARN, "| %s | cannot load plugin: invalid InitPluginAsync function type", data.ID) + continue + } + err = initPlugin(data.Params, meter, func(modelProcess func(ModelInput) (ModelResults, error)) { + ModelProcessHandler(data.ID, modelProcess) + }) + if err != nil { + logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) + continue + } + go pm.ModelResultsHandler(data.ID) + } else { + f, err := tp.Lookup("InitPlugin") + if err != nil { + logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) + continue + } + initPlugin, ok := f.(func(map[string]string, metric.Meter) error) + if !ok { + logger.Printf(logging.WARN, "| %s | cannot load plugin: invalid InitPlugin function type", data.ID) + continue + } + err = initPlugin(data.Params, meter) + procFunc, err := tp.Lookup("Process") + if err != nil { + logger.Printf(logging.WARN, "| %s | cannot load plugin: cannot load Process function", data.ID) + continue + } + process, ok := procFunc.(func(ModelInput) (ModelResults, error)) + if !ok { + logger.Printf(logging.WARN, "| %s | cannot load plugin: invalid Process function type", data.ID) + continue + } + pm.modelProcessFunc[data.ID] = process + } + modelPluginLoaded := modelPlugin{tp, data.PluginType} + pm.modelPlugins[data.ID] = modelPluginLoaded + logger.Printf(logging.INFO, "| %s | plugin loaded", data.ID) + } + + pm.decisionPlugins = make(map[string]decisionPlugin) + pm.decisionCheckFunc = make(map[string]func(DecisionInput) (bool, error)) + // Loading of decision plugins + for _, data := range conf.DecisionPlugins { + tp, err := plugin.Open(data.Path) + if err != nil { + logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) + continue + } + f, err := tp.Lookup("InitPlugin") + if err != nil { + logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) + continue + } + initPlugin, ok := f.(func(map[string]string, metric.Meter) error) + if !ok { + logger.Printf(logging.WARN, "| %s | cannot load plugin: invalid InitPlugin function type", data.ID) + continue + } + err = initPlugin(data.Params, meter) + if err != nil { + logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) + continue + } + cR, err := tp.Lookup("CheckResults") + if err != nil { + logger.Printf(logging.ERROR, "| %s | cannot load plugin check results function: %v", data.ID, err) + continue + } + checkResults, ok := cR.(func(DecisionInput) (bool, error)) + if !ok { + logger.Printf(logging.ERROR, "| %s | CheckResults lookup failed for plugin: invalid function type", data.ID) + continue + } + pm.decisionCheckFunc[data.ID] = checkResults + decisionPluginLoaded := decisionPlugin{tp} + pm.decisionPlugins[data.ID] = decisionPluginLoaded + } + return pm, nil +} + +// InitTransaction initializes the transaction with the given ID +func (p *PluginManager) InitTransaction(transactionId string) { + p.results.Store(transactionId, new(sync.Map)) +} + +// CloseTransaction closes the transaction with the given ID +// removing all sync model data +func (p *PluginManager) CloseTransaction(transactionId string) { + logger := logging.Get() + transactionMap, ok := p.syncModelsChannels.Load(transactionId) + if !ok { + logger.TPrintf(logging.ERROR, transactionId, "Transaction %s not found", transactionId) + } else { + transactionMap.(*sync.Map).Range(func(key, value interface{}) bool { + ch := value.(chan ModelStatus) + close(ch) + for range ch { + } + transactionMap.(*sync.Map).Delete(key) + return true + }) + p.syncModelsChannels.Delete(transactionId) + resultsMap, ok := p.results.Load(transactionId) + if !ok { + logger.TPrintf(logging.ERROR, transactionId, "Results for transaction %s not found", transactionId) + } else { + resultsMap.(*sync.Map).Range(func(key, value interface{}) bool { + resultsMap.(*sync.Map).Delete(key) + return true + }) + } + p.results.Delete(transactionId) + } +} + +// AddModelChannel adds a channel to result channel map +func (p *PluginManager) AddModelChannel(transactionId string, t configstore.ModelPluginType, modelPlugStatus chan ModelStatus, modelType string) { + typeModel := new(sync.Map) + var value interface{} + if modelType == "sync" { + value, _ = p.syncModelsChannels.LoadOrStore(transactionId, typeModel) + } else { + value, _ = p.asyncModelsChannels.LoadOrStore(transactionId, typeModel) + } + value.(*sync.Map).Store(t.String(), modelPlugStatus) +} + +// RemoveModelChannel removes a channel from the result channel map +func (p *PluginManager) RemoveAsyncModelChannel(transactionId string, t configstore.ModelPluginType) { + typeModel, ok := p.asyncModelsChannels.Load(transactionId) + if ok { + channelMap := typeModel.(*sync.Map) + ch, channelOk := channelMap.Load(t.String()) + + if channelOk { + close(ch.(chan ModelStatus)) + for range ch.(chan ModelStatus) { + } + channelMap.Delete(t.String()) + } + + remainChannels := 0 + typeModel.(*sync.Map).Range(func(key, value interface{}) bool { + remainChannels++ + return true + }) + if remainChannels == 0 { + p.asyncModelsChannels.Delete(transactionId) + } + } else { + logger := logging.Get() + logger.TPrintf(logging.ERROR, transactionId, "Transaction %s not found when trying to remove async model channel", transactionId) + } +} + +// AddToQueue adds a payload to the model queue +func (p *PluginManager) AddToQueue(modelID, transactionID string, payload HTTPPayload) error { + payloadToSend := &ModelInput{ + TransactionId: transactionID, + Payload: payload, + } + + jsonPayload, err := json.Marshal(payloadToSend) + + if err != nil { + return err + } + + return p.natConn.Publish(modelID, jsonPayload) +} + +// Process is in charge of calling the model plugin with id modelID +func (p *PluginManager) Process(modelID, transactionId string, payload HTTPPayload, t configstore.ModelPluginType, modelPlugStatus chan ModelStatus) error { + conf, err := configstore.Get() + if err != nil { + return err + } + + mp, exists := p.modelPlugins[modelID] + if !exists { + modelPlugStatus <- ModelStatus{ModelID: modelID, Err: fmt.Errorf("model plugin not found")} + return nil + } + + // check if the plugin is capable of analyzing the indicated part of the transaction + if mp.pluginType != t { + modelPlugStatus <- ModelStatus{ModelID: modelID, + Err: fmt.Errorf("plugin type %v cannot process a request with incompatible type %v", mp.pluginType, t)} + return nil + } + + process := p.modelProcessFunc[modelID] + + if conf.ModelPlugins[modelID].Mode == "async" { + modelPlugStatus <- ModelStatus{ModelID: modelID, Err: fmt.Errorf("model plugin is async")} + return nil + } else { + res, err := process(ModelInput{TransactionId: transactionId, Payload: payload}) + // res, err := process(transactionId, payload) + + if err != nil { + modelPlugStatus <- ModelStatus{ModelID: modelID, Err: err} + return nil + } + // store the results + resultSyncMap, ok := p.results.Load(transactionId) + if !ok { + modelPlugStatus <- ModelStatus{ModelID: modelID, Err: fmt.Errorf("transaction results not found")} + return nil + } + resultSyncMap.(*sync.Map).Store(modelID, res) + modelPlugStatus <- ModelStatus{ModelID: modelID, ProbAttack: res.ProbAttack, Err: nil} + } + return nil +} + +// CheckResult is in charge of calling the decision plugin with id decisionID over the +// transaction with id transactID +func (p *PluginManager) CheckResult(transactionId, decisionId string, wafParams map[string]string) (bool, error) { + logger := logging.Get() + + checkResults, ok := p.decisionCheckFunc[decisionId] + if !ok { + return false, fmt.Errorf("decision plugin not found") + } + + transactionResults, ok := p.results.Load(transactionId) + if !ok { + return false, fmt.Errorf("transaction results not found") + } + + cs, err := configstore.Get() + if err != nil { + return false, nil + } + + modelResultMap := make(map[string]ModelResults) + modelWeightMap := make(map[string]float64) + transactionResults.(*sync.Map).Range(func(key, value interface{}) bool { + modelResultMap[key.(string)] = value.(ModelResults) + modelWeightMap[key.(string)] = cs.ModelPlugins[key.(string)].Weight + return true + }) + + res, err := checkResults(DecisionInput{TransactionId: transactionId, Results: modelResultMap, ModelWeight: modelWeightMap, WAFdata: wafParams}) + logger.TPrintf(logging.INFO, transactionId, "%s | transaction checked. Block: %t ", decisionId, res) + + return res, err +} + +// ModelResultsHandler listens for messages on the model results queue +func (p *PluginManager) ModelResultsHandler(modelId string) error { + logger := logging.Get() + cs, err := configstore.Get() + if err != nil { + return err + } + + sub, err := p.natConn.Subscribe(modelId+"/results", func(msg *nats.Msg) { + go func(msg nats.Msg) { + data := &ModelTransmitionResults{} + err := json.Unmarshal(msg.Data, data) + if err != nil { + logger.Printf(logging.ERROR, "Model: %s | Failed to parse JSON payload", modelId) + } else { + var channel interface{} + var ok bool + if cs.ModelPlugins[modelId].Mode == "async" { + channel, ok = p.asyncModelsChannels.Load(data.TransactionId) + } else { + channel, ok = p.syncModelsChannels.Load(data.TransactionId) + } + if !ok { + logger.TPrintf(logging.ERROR, data.TransactionId, " Model %s | Transaction not found", modelId) + } else { + modelChannel, ok := channel.(*sync.Map).Load(cs.ModelPlugins[modelId].PluginType.String()) + if !ok { + logger.Printf(logging.ERROR, "Model %s not found", modelId) + } else { + if data.Error != nil { + modelChannel.(chan ModelStatus) <- ModelStatus{ModelID: modelId, Err: data.Error} + } else { + if cs.ModelPlugins[modelId].Mode != "async" { + // store the results + resultSyncMap, ok := p.results.Load(data.TransactionId) + if !ok { + modelChannel.(chan ModelStatus) <- ModelStatus{ModelID: modelId, Err: fmt.Errorf("transaction results not found")} + return + } + modelResult := ModelResults{ProbAttack: data.ProbAttack, Data: data.Data} + resultSyncMap.(*sync.Map).Store(modelId, modelResult) + } + modelChannel.(chan ModelStatus) <- ModelStatus{ModelID: modelId, ProbAttack: data.ProbAttack, Err: nil} + } + } + } + } + }(*msg) + }) + + if err != nil { + logger.Printf(logging.ERROR, "Model: %s | Failed to subscribe to model queue | %s", modelId, err.Error()) + return err + } + + logger.Printf(logging.INFO, "Model: %s | Listening for messages on model results queue", modelId) + + defer sub.Unsubscribe() + defer p.natConn.Drain() + + select {} + +} + +// ModelProcessHandler listens for messages on the model queue +func ModelProcessHandler(modelId string, modelProcess func(ModelInput) (ModelResults, error)) error { + logger := logging.Get() + logger.Printf(logging.INFO, "Model: %s | Starting model process handler", modelId) + cs, err := configstore.Get() + if err != nil { + return err + } + + nc, err := nats.Connect(cs.NatsURL) + + if err != nil { + logger.Printf(logging.ERROR, "Model: %s | Failed to connect to NATS server", modelId) + return err + } + + _, err = nc.Subscribe(modelId, func(msg *nats.Msg) { + go func(msg nats.Msg) { + data := &ModelInput{} + err := json.Unmarshal(msg.Data, data) + if err != nil { + logger.Printf(logging.ERROR, "Model: %s | Failed to parse JSON payload", modelId) + } else { + res, err := modelProcess(*data) + modelResult := ModelResults{ProbAttack: res.ProbAttack, Data: res.Data} + payloadToSend := &ModelTransmitionResults{ + TransactionId: data.TransactionId, + ModelResults: modelResult, + Error: err, + } + + jsonPayload, err := json.Marshal(payloadToSend) + + if err != nil { + logger.Printf(logging.ERROR, "Model: %s | Failed to parse JSON payload", modelId) + } + + nc.Publish(modelId+"/results", jsonPayload) + } + }(*msg) + }) + + if err != nil { + logger.Printf(logging.ERROR, "Model: %s | Failed to subscribe to model queue | %s", modelId, err.Error()) + return err + } + + logger.Printf(logging.INFO, "Model: %s | Listening for messages on model queue", modelId) + return nil +} diff --git a/pluginmanager/pluginmanager_test.go b/pluginmanager/pluginmanager_test.go index 1d52d43..2834794 100644 --- a/pluginmanager/pluginmanager_test.go +++ b/pluginmanager/pluginmanager_test.go @@ -1,371 +1,371 @@ -package pluginmanager - -import ( - "math/rand" - "time" - - "github.com/tilsor/ModSecIntl_wace_lib/configstore" - "go.opentelemetry.io/otel/sdk/metric" - "gopkg.in/yaml.v3" - - "github.com/tilsor/ModSecIntl_logging/logging" -) - -var baseConfig = `--- -logpath: "/tmp/wacetmp.log" -loglevel: "WARN" -` - -var trivialPlugin = ` - id: "trivial" - path: "../testdata/plugins/model/trivial.so" - weight: 1 - params: - param1: "first value" - param2: "second value" - param3: "third value" - plugintype: "Everything" - mode: sync -` - -var testPlugin = ` - id: "test" - path: "../testdata/plugins/decision/test.so" - wafweight: 0.5 - decisionbalance: 0.5 - params: - test1: "test" - test2: "testtest" - test3: "testtesttest" -` - -func generateRandomID() string { - letters := "1234567890ABCDEF" - id := "" - for i := 0; i < 16; i++ { - id += string(letters[rand.Intn(len(letters))]) - } - - return id -} - -var provider = metric.NewMeterProvider() -var testMeter = provider.Meter("example-meter") - -func initilize(configuration []byte) error { - var aux configstore.ConfigFileData - err := yaml.Unmarshal(configuration, &aux) - if err != nil { - return err - } - cs, err := configstore.Get() - if err != nil { - return err - } - err = cs.SetConfig(aux) - if err != nil { - return err - } - logger := logging.Get() - - err = logger.LoadLogger(cs.LogPath, cs.LogLevel) - if err != nil { - return err - - } - return nil -} - -func init() { - rand.Seed(time.Now().UnixNano()) - - logger := logging.Get() - err := logger.LoadLogger("/dev/null", logging.ERROR) - if err != nil { - panic("Error loading logger") - } -} - -// func TestPluginInit(t *testing.T) { -// cases := []struct{ id, conf string }{ -// // {"invalid_path", ` - id: "invalid_path" -// // path: "../testdata/plugins/model/nonexistent.so" -// // plugintype: "AllRequest" -// // `}, -// {"no_init", ` - id: "no_init" -// path: "../testdata/plugins/model/no_init.so" -// plugintype: "AllRequest" -// `}, -// {"wrong_init", ` - id: "wrong_init" -// path: "../testdata/plugins/model/wrong_init.so" -// plugintype: "AllRequest" -// `}, -// {"error_init", ` - id: "error_init" -// path: "../testdata/plugins/model/error_init.so" -// plugintype: "AllRequest" -// `}, -// } - -// // Test model plugin initialization -// for _, c := range cases { -// config := baseConfig + "modelplugins:\n" + trivialPlugin + c.conf - -// err := initilize([]byte(config)) -// if err != nil { -// t.Errorf("Error loading config: %v", err) -// } -// plugins := New(testMeter) -// if _, exists := plugins.modelPlugins["trivial"]; !exists { -// t.Errorf("trivial plugin not loaded") -// } -// if _, exists := plugins.modelPlugins[c.id]; exists { -// t.Errorf(c.id + " should not load") -// } -// } - -// // Test decision plugin initialization -// for _, c := range cases { -// config := baseConfig + "modelplugins:\n" + trivialPlugin + "decisionplugins:\n" + testPlugin + c.conf - -// err := initilize([]byte(config)) -// if err != nil { -// t.Errorf("Error loading config: %v", err) -// } -// plugins := New(testMeter) -// if _, exists := plugins.decisionPlugins["test"]; !exists { -// t.Errorf("test plugin not loaded") -// } -// if _, exists := plugins.decisionPlugins[c.id]; exists { -// t.Errorf(c.id + " should not load") -// } -// } - -// } - -// func TestPluginParams(t *testing.T) { -// config := baseConfig + "modelplugins:\n" + trivialPlugin + "decisionplugins:\n" + testPlugin - -// err := initilize([]byte(config)) -// if err != nil { -// t.Errorf("Error loading config: %v", err) -// } - -// var buf bytes.Buffer -// logger := lg.Get() -// err = logger.LoadLoggerWriter(&buf, lg.INFO) -// if err != nil { -// t.Errorf("Error loading logger: %v", err) -// } - -// plugins := New(testMeter) - -// if !strings.Contains(buf.String(), "[trivial:InitPlugin] map[param1:first value param2:second value param3:third value]") { -// t.Errorf("trivial plugin did not initialize correctly, got: %v, expected: %v", buf.String(), "[trivial:InitPlugin] map[param1:first value param2:second value param3:third value]") -// } -// if !strings.Contains(buf.String(), "[test:InitPlugin] map[test1:test test2:testtest test3:testtesttest]") { -// t.Errorf("test plugin did not initialize correctly") -// } - -// transactionID := generateRandomID() -// modelPlugStatus := make(chan ModelStatus) -// go plugins.Process("trivial", transactionID, "test request1", cf.AllRequest, modelPlugStatus) -// <-modelPlugStatus -// if !strings.Contains(buf.String(), "[trivial:ProcessRequest] \"test request1\"") { -// t.Errorf("trivial plugin did not analyze request") -// } - -// go plugins.Process("trivial", transactionID, "test response1", cf.AllResponse, modelPlugStatus) -// <-modelPlugStatus -// if !strings.Contains(buf.String(), "[trivial:ProcessResponse] \"test response1\"") { -// t.Errorf("trivial plugin did not analyze response") -// } - -// _, err = plugins.CheckResult(transactionID, "test", map[string]string{"anomalyscore": "100", "inboundthreshold": "10"}) -// if err != nil { -// t.Errorf("Error checking result: %v", err) -// } -// if !strings.Contains(buf.String(), "[test:CheckResults]") { -// t.Errorf("test plugin did not execute correctly") -// } -// if !strings.Contains(buf.String(), "modelRes: map[trivial:") { -// t.Errorf("trivial result is not stored in modelRes") -// } -// if !strings.Contains(buf.String(), "modelWeight: map[trivial:1]") { -// t.Errorf("trivial weight is not stored in modelWeight") -// } -// if !strings.Contains(buf.String(), "modelThres: map[trivial:0.5]") { -// t.Errorf("trivial threshold is not stored in modelWeight") -// } -// if !strings.Contains(buf.String(), "wafData: map[anomalyscore:100 inboundthreshold:10]") { -// t.Errorf("waf params are not stored in wafData") -// } -// } - -// func TestPluginType(t *testing.T) { -// cases := []struct { -// id string -// pluginType, requestType cf.ModelPluginType -// executes bool -// }{ -// {"req_headers-req_headers", cf.RequestHeaders, cf.RequestHeaders, true}, -// {"req_headers-resp_headers", cf.RequestHeaders, cf.ResponseHeaders, false}, -// {"req_headers-all_req", cf.RequestHeaders, cf.AllRequest, false}, -// {"all_req-req_headers", cf.AllRequest, cf.RequestHeaders, false}, -// {"all_req-all_resp", cf.AllRequest, cf.AllResponse, false}, - -// {"resp_headers-resp_headers", cf.ResponseHeaders, cf.ResponseHeaders, true}, -// {"resp_headers-req_headers", cf.ResponseHeaders, cf.RequestHeaders, false}, -// {"resp_headers-all_resp", cf.ResponseHeaders, cf.AllResponse, false}, -// {"all_resp-resp_headers", cf.AllResponse, cf.ResponseHeaders, false}, -// {"all_resp-all_req", cf.AllResponse, cf.AllRequest, false}, - -// {"everything-req_headers", cf.Everything, cf.RequestHeaders, true}, -// {"everything-all_req", cf.Everything, cf.AllRequest, true}, -// {"everything-resp_body", cf.Everything, cf.ResponseBody, true}, -// {"everything-all_resp", cf.Everything, cf.AllResponse, true}, -// } - -// for _, c := range cases { -// config := baseConfig + "modelplugins:\n" + -// " - id: \"" + c.id + "\"\n" + -// " path: \"../testdata/plugins/model/trivial.so\"\n" + -// " plugintype: \"" + c.pluginType.String() + "\"\n" - -// err := initilize([]byte(config)) -// if err != nil { -// t.Errorf("Error loading config: %v", err) -// } - -// old := log.Writer() -// var buf bytes.Buffer -// log.SetOutput(&buf) -// defer log.SetOutput(old) - -// plugins := New(testMeter) - -// transactionID := generateRandomID() -// modelPlugStatus := make(chan ModelStatus) -// switch c.requestType { -// case cf.RequestHeaders, cf.RequestBody, cf.AllRequest: -// go plugins.Process(c.id, transactionID, "test request", c.requestType, modelPlugStatus) -// <-modelPlugStatus -// if strings.Contains(buf.String(), "[trivial:ProcessRequest] \"test request\"") != c.executes { -// t.Errorf("case %s: expected to run trivial plugin: %v", c.id, c.executes) -// } -// if _, exists := plugins.results.Load(transactionID); exists != c.executes { -// t.Errorf("case %s: expected to store results: %v", c.id, c.executes) -// } -// case cf.ResponseHeaders, cf.ResponseBody, cf.AllResponse: -// go plugins.Process(c.id, transactionID, "test response", c.requestType, modelPlugStatus) -// <-modelPlugStatus -// if strings.Contains(buf.String(), "[trivial:ProcessResponse] \"test response\"") != c.executes { -// t.Errorf("case %s: expected to run trivial plugin: %v", c.id, c.executes) -// } -// if _, exists := plugins.results.Load(transactionID); exists != c.executes { -// t.Errorf("case %s: expected to store results: %v", c.id, c.executes) -// } -// } - -// } -// } - -// func TestProcessRequestInvalid(t *testing.T) { -// cases := []struct{ id, conf string }{ -// {"no_req", ` - id: "no_req" -// path: "../testdata/plugins/model/no_req.so" -// plugintype: "Everything" -// `}, -// {"wrong_req", ` - id: "wrong_req" -// path: "../testdata/plugins/model/wrong_req.so" -// plugintype: "Everything" -// `}, -// {"error_req", ` - id: "error_req" -// path: "../testdata/plugins/model/error_req.so" -// plugintype: "Everything" -// `}, -// } - -// // Test model plugin initialization -// for _, c := range cases { -// config := baseConfig + "modelplugins:\n" + trivialPlugin + c.conf - -// err := initilize([]byte(config)) -// if err != nil { -// t.Errorf("Error loading config: %v", err) -// } -// plugins := New(testMeter) - -// transactionID := generateRandomID() -// modelPlugStatus := make(chan ModelStatus) -// go plugins.Process(c.id, transactionID, "test request", cf.AllRequest, modelPlugStatus) -// <-modelPlugStatus -// go plugins.Process(c.id, transactionID, "test response", cf.AllResponse, modelPlugStatus) -// <-modelPlugStatus - -// if _, exists := plugins.results.Load(transactionID); exists { -// t.Errorf("invalid test %s stored a result", c.id) -// } -// } - -// config := baseConfig + "modelplugins:\n" + trivialPlugin - -// err := initilize([]byte(config)) -// if err != nil { -// t.Errorf("Error loading config: %v", err) -// } -// plugins := New(testMeter) - -// transactionID := generateRandomID() -// modelPlugStatus := make(chan ModelStatus) -// go plugins.Process("nonexistent", transactionID, "test request", cf.AllRequest, modelPlugStatus) -// <-modelPlugStatus -// go plugins.Process("nonexistent", transactionID, "test response", cf.AllResponse, modelPlugStatus) -// <-modelPlugStatus - -// if _, exists := plugins.results.Load(transactionID); exists { -// t.Errorf("nonexistent test stored a result") -// } - -// } - -// func TestCheckResultInvalid(t *testing.T) { -// cases := []struct{ id, conf string }{ -// {"no_check", ` - id: "no_check" -// path: "../testdata/plugins/decision/no_check.so" -// `}, -// {"wrong_check", ` - id: "wrong_check" -// path: "../testdata/plugins/decision/wrong_check.so" -// `}, -// {"error_check", ` - id: "error_check" -// path: "../testdata/plugins/decision/error_check.so" -// `}, -// } - -// // Test model plugin initialization -// for _, c := range cases { -// config := baseConfig + "modelplugins:\n" + trivialPlugin + "decisionplugins:\n" + c.conf - -// err := initilize([]byte(config)) -// if err != nil { -// t.Errorf("Error loading config: %v", err) -// } -// plugins := New(testMeter) - -// _, err = plugins.CheckResult(generateRandomID(), c.id, make(map[string]string)) -// if err == nil { -// t.Errorf("invalid CheckResult %s did not rise an error", c.id) -// } -// } - -// config := baseConfig + "modelplugins:\n" + trivialPlugin + "decisionplugins:\n" + testPlugin - -// err := initilize([]byte(config)) -// if err != nil { -// t.Errorf("Error loading config: %v", err) -// } -// plugins := New(testMeter) - -// _, err = plugins.CheckResult(generateRandomID(), "nonexitent", make(map[string]string)) -// if err == nil { -// t.Errorf("nonexistent plugin did not rise an error") -// } - -// } +package pluginmanager + +import ( + "math/rand" + "time" + + "github.com/tilsor/ModSecIntl_wace_lib/configstore" + "go.opentelemetry.io/otel/sdk/metric" + "gopkg.in/yaml.v3" + + "github.com/tilsor/ModSecIntl_logging/logging" +) + +var baseConfig = `--- +logpath: "/tmp/wacetmp.log" +loglevel: "WARN" +` + +var trivialPlugin = ` - id: "trivial" + path: "../testdata/plugins/model/trivial.so" + weight: 1 + params: + param1: "first value" + param2: "second value" + param3: "third value" + plugintype: "Everything" + mode: sync +` + +var testPlugin = ` - id: "test" + path: "../testdata/plugins/decision/test.so" + wafweight: 0.5 + decisionbalance: 0.5 + params: + test1: "test" + test2: "testtest" + test3: "testtesttest" +` + +func generateRandomID() string { + letters := "1234567890ABCDEF" + id := "" + for i := 0; i < 16; i++ { + id += string(letters[rand.Intn(len(letters))]) + } + + return id +} + +var provider = metric.NewMeterProvider() +var testMeter = provider.Meter("example-meter") + +func initilize(configuration []byte) error { + var aux configstore.ConfigFileData + err := yaml.Unmarshal(configuration, &aux) + if err != nil { + return err + } + cs, err := configstore.Get() + if err != nil { + return err + } + err = cs.SetConfig(aux) + if err != nil { + return err + } + logger := logging.Get() + + err = logger.LoadLogger(cs.LogPath, cs.LogLevel) + if err != nil { + return err + + } + return nil +} + +func init() { + rand.Seed(time.Now().UnixNano()) + + logger := logging.Get() + err := logger.LoadLogger("/dev/null", logging.ERROR) + if err != nil { + panic("Error loading logger") + } +} + +// func TestPluginInit(t *testing.T) { +// cases := []struct{ id, conf string }{ +// // {"invalid_path", ` - id: "invalid_path" +// // path: "../testdata/plugins/model/nonexistent.so" +// // plugintype: "AllRequest" +// // `}, +// {"no_init", ` - id: "no_init" +// path: "../testdata/plugins/model/no_init.so" +// plugintype: "AllRequest" +// `}, +// {"wrong_init", ` - id: "wrong_init" +// path: "../testdata/plugins/model/wrong_init.so" +// plugintype: "AllRequest" +// `}, +// {"error_init", ` - id: "error_init" +// path: "../testdata/plugins/model/error_init.so" +// plugintype: "AllRequest" +// `}, +// } + +// // Test model plugin initialization +// for _, c := range cases { +// config := baseConfig + "modelplugins:\n" + trivialPlugin + c.conf + +// err := initilize([]byte(config)) +// if err != nil { +// t.Errorf("Error loading config: %v", err) +// } +// plugins := New(testMeter) +// if _, exists := plugins.modelPlugins["trivial"]; !exists { +// t.Errorf("trivial plugin not loaded") +// } +// if _, exists := plugins.modelPlugins[c.id]; exists { +// t.Errorf(c.id + " should not load") +// } +// } + +// // Test decision plugin initialization +// for _, c := range cases { +// config := baseConfig + "modelplugins:\n" + trivialPlugin + "decisionplugins:\n" + testPlugin + c.conf + +// err := initilize([]byte(config)) +// if err != nil { +// t.Errorf("Error loading config: %v", err) +// } +// plugins := New(testMeter) +// if _, exists := plugins.decisionPlugins["test"]; !exists { +// t.Errorf("test plugin not loaded") +// } +// if _, exists := plugins.decisionPlugins[c.id]; exists { +// t.Errorf(c.id + " should not load") +// } +// } + +// } + +// func TestPluginParams(t *testing.T) { +// config := baseConfig + "modelplugins:\n" + trivialPlugin + "decisionplugins:\n" + testPlugin + +// err := initilize([]byte(config)) +// if err != nil { +// t.Errorf("Error loading config: %v", err) +// } + +// var buf bytes.Buffer +// logger := lg.Get() +// err = logger.LoadLoggerWriter(&buf, lg.INFO) +// if err != nil { +// t.Errorf("Error loading logger: %v", err) +// } + +// plugins := New(testMeter) + +// if !strings.Contains(buf.String(), "[trivial:InitPlugin] map[param1:first value param2:second value param3:third value]") { +// t.Errorf("trivial plugin did not initialize correctly, got: %v, expected: %v", buf.String(), "[trivial:InitPlugin] map[param1:first value param2:second value param3:third value]") +// } +// if !strings.Contains(buf.String(), "[test:InitPlugin] map[test1:test test2:testtest test3:testtesttest]") { +// t.Errorf("test plugin did not initialize correctly") +// } + +// transactionID := generateRandomID() +// modelPlugStatus := make(chan ModelStatus) +// go plugins.Process("trivial", transactionID, "test request1", cf.AllRequest, modelPlugStatus) +// <-modelPlugStatus +// if !strings.Contains(buf.String(), "[trivial:ProcessRequest] \"test request1\"") { +// t.Errorf("trivial plugin did not analyze request") +// } + +// go plugins.Process("trivial", transactionID, "test response1", cf.AllResponse, modelPlugStatus) +// <-modelPlugStatus +// if !strings.Contains(buf.String(), "[trivial:ProcessResponse] \"test response1\"") { +// t.Errorf("trivial plugin did not analyze response") +// } + +// _, err = plugins.CheckResult(transactionID, "test", map[string]string{"anomalyscore": "100", "inboundthreshold": "10"}) +// if err != nil { +// t.Errorf("Error checking result: %v", err) +// } +// if !strings.Contains(buf.String(), "[test:CheckResults]") { +// t.Errorf("test plugin did not execute correctly") +// } +// if !strings.Contains(buf.String(), "modelRes: map[trivial:") { +// t.Errorf("trivial result is not stored in modelRes") +// } +// if !strings.Contains(buf.String(), "modelWeight: map[trivial:1]") { +// t.Errorf("trivial weight is not stored in modelWeight") +// } +// if !strings.Contains(buf.String(), "modelThres: map[trivial:0.5]") { +// t.Errorf("trivial threshold is not stored in modelWeight") +// } +// if !strings.Contains(buf.String(), "wafData: map[anomalyscore:100 inboundthreshold:10]") { +// t.Errorf("waf params are not stored in wafData") +// } +// } + +// func TestPluginType(t *testing.T) { +// cases := []struct { +// id string +// pluginType, requestType cf.ModelPluginType +// executes bool +// }{ +// {"req_headers-req_headers", cf.RequestHeaders, cf.RequestHeaders, true}, +// {"req_headers-resp_headers", cf.RequestHeaders, cf.ResponseHeaders, false}, +// {"req_headers-all_req", cf.RequestHeaders, cf.AllRequest, false}, +// {"all_req-req_headers", cf.AllRequest, cf.RequestHeaders, false}, +// {"all_req-all_resp", cf.AllRequest, cf.AllResponse, false}, + +// {"resp_headers-resp_headers", cf.ResponseHeaders, cf.ResponseHeaders, true}, +// {"resp_headers-req_headers", cf.ResponseHeaders, cf.RequestHeaders, false}, +// {"resp_headers-all_resp", cf.ResponseHeaders, cf.AllResponse, false}, +// {"all_resp-resp_headers", cf.AllResponse, cf.ResponseHeaders, false}, +// {"all_resp-all_req", cf.AllResponse, cf.AllRequest, false}, + +// {"everything-req_headers", cf.Everything, cf.RequestHeaders, true}, +// {"everything-all_req", cf.Everything, cf.AllRequest, true}, +// {"everything-resp_body", cf.Everything, cf.ResponseBody, true}, +// {"everything-all_resp", cf.Everything, cf.AllResponse, true}, +// } + +// for _, c := range cases { +// config := baseConfig + "modelplugins:\n" + +// " - id: \"" + c.id + "\"\n" + +// " path: \"../testdata/plugins/model/trivial.so\"\n" + +// " plugintype: \"" + c.pluginType.String() + "\"\n" + +// err := initilize([]byte(config)) +// if err != nil { +// t.Errorf("Error loading config: %v", err) +// } + +// old := log.Writer() +// var buf bytes.Buffer +// log.SetOutput(&buf) +// defer log.SetOutput(old) + +// plugins := New(testMeter) + +// transactionID := generateRandomID() +// modelPlugStatus := make(chan ModelStatus) +// switch c.requestType { +// case cf.RequestHeaders, cf.RequestBody, cf.AllRequest: +// go plugins.Process(c.id, transactionID, "test request", c.requestType, modelPlugStatus) +// <-modelPlugStatus +// if strings.Contains(buf.String(), "[trivial:ProcessRequest] \"test request\"") != c.executes { +// t.Errorf("case %s: expected to run trivial plugin: %v", c.id, c.executes) +// } +// if _, exists := plugins.results.Load(transactionID); exists != c.executes { +// t.Errorf("case %s: expected to store results: %v", c.id, c.executes) +// } +// case cf.ResponseHeaders, cf.ResponseBody, cf.AllResponse: +// go plugins.Process(c.id, transactionID, "test response", c.requestType, modelPlugStatus) +// <-modelPlugStatus +// if strings.Contains(buf.String(), "[trivial:ProcessResponse] \"test response\"") != c.executes { +// t.Errorf("case %s: expected to run trivial plugin: %v", c.id, c.executes) +// } +// if _, exists := plugins.results.Load(transactionID); exists != c.executes { +// t.Errorf("case %s: expected to store results: %v", c.id, c.executes) +// } +// } + +// } +// } + +// func TestProcessRequestInvalid(t *testing.T) { +// cases := []struct{ id, conf string }{ +// {"no_req", ` - id: "no_req" +// path: "../testdata/plugins/model/no_req.so" +// plugintype: "Everything" +// `}, +// {"wrong_req", ` - id: "wrong_req" +// path: "../testdata/plugins/model/wrong_req.so" +// plugintype: "Everything" +// `}, +// {"error_req", ` - id: "error_req" +// path: "../testdata/plugins/model/error_req.so" +// plugintype: "Everything" +// `}, +// } + +// // Test model plugin initialization +// for _, c := range cases { +// config := baseConfig + "modelplugins:\n" + trivialPlugin + c.conf + +// err := initilize([]byte(config)) +// if err != nil { +// t.Errorf("Error loading config: %v", err) +// } +// plugins := New(testMeter) + +// transactionID := generateRandomID() +// modelPlugStatus := make(chan ModelStatus) +// go plugins.Process(c.id, transactionID, "test request", cf.AllRequest, modelPlugStatus) +// <-modelPlugStatus +// go plugins.Process(c.id, transactionID, "test response", cf.AllResponse, modelPlugStatus) +// <-modelPlugStatus + +// if _, exists := plugins.results.Load(transactionID); exists { +// t.Errorf("invalid test %s stored a result", c.id) +// } +// } + +// config := baseConfig + "modelplugins:\n" + trivialPlugin + +// err := initilize([]byte(config)) +// if err != nil { +// t.Errorf("Error loading config: %v", err) +// } +// plugins := New(testMeter) + +// transactionID := generateRandomID() +// modelPlugStatus := make(chan ModelStatus) +// go plugins.Process("nonexistent", transactionID, "test request", cf.AllRequest, modelPlugStatus) +// <-modelPlugStatus +// go plugins.Process("nonexistent", transactionID, "test response", cf.AllResponse, modelPlugStatus) +// <-modelPlugStatus + +// if _, exists := plugins.results.Load(transactionID); exists { +// t.Errorf("nonexistent test stored a result") +// } + +// } + +// func TestCheckResultInvalid(t *testing.T) { +// cases := []struct{ id, conf string }{ +// {"no_check", ` - id: "no_check" +// path: "../testdata/plugins/decision/no_check.so" +// `}, +// {"wrong_check", ` - id: "wrong_check" +// path: "../testdata/plugins/decision/wrong_check.so" +// `}, +// {"error_check", ` - id: "error_check" +// path: "../testdata/plugins/decision/error_check.so" +// `}, +// } + +// // Test model plugin initialization +// for _, c := range cases { +// config := baseConfig + "modelplugins:\n" + trivialPlugin + "decisionplugins:\n" + c.conf + +// err := initilize([]byte(config)) +// if err != nil { +// t.Errorf("Error loading config: %v", err) +// } +// plugins := New(testMeter) + +// _, err = plugins.CheckResult(generateRandomID(), c.id, make(map[string]string)) +// if err == nil { +// t.Errorf("invalid CheckResult %s did not rise an error", c.id) +// } +// } + +// config := baseConfig + "modelplugins:\n" + trivialPlugin + "decisionplugins:\n" + testPlugin + +// err := initilize([]byte(config)) +// if err != nil { +// t.Errorf("Error loading config: %v", err) +// } +// plugins := New(testMeter) + +// _, err = plugins.CheckResult(generateRandomID(), "nonexitent", make(map[string]string)) +// if err == nil { +// t.Errorf("nonexistent plugin did not rise an error") +// } + +// } diff --git a/wacecore.go b/wacecore.go index d22c3bd..f3111a2 100644 --- a/wacecore.go +++ b/wacecore.go @@ -1,272 +1,272 @@ -/* -The main package of WACE. -*/ -package wace - -import ( - "fmt" - "sync" - "sync/atomic" - "time" - - "github.com/tilsor/ModSecIntl_wace_lib/configstore" - - "github.com/tilsor/ModSecIntl_wace_lib/pluginmanager" - - "github.com/tilsor/ModSecIntl_logging/logging" - - "context" - - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/metric" -) - -var plugins *pluginmanager.PluginManager -var ctx = context.Background() -var meter metric.Meter - -// transactionSync is a struct to syncronize the analysis of a given -// transaction. Each time callPlugins is executed, the counter is -// incremented. At the end of each callPlugins execution, a message is -// sent through the channel, to signal checkTransaction that it has -// finished analyzing the request. checkTransaction waits for Counter -// number of messages in the channel, before calling the decision -// plugin and sending the result to the client. -type transactionSync struct { - Channel chan string - Counter int64 -} - -var ( - // Sync map witg channels to receive a notification when all plugins finish - // processing a transaction - analysisMap sync.Map -) - -// addTransactionAnalysis adds a transaction to the analysis map. If the -// transaction already exists, it increments the counter of the transaction -// by one. -func addTransactionAnalysis(transactionID string) { - tSync := transactionSync{ - Channel: make(chan string), - Counter: 1, - } - value, loaded := analysisMap.LoadOrStore(transactionID, &tSync) - if loaded { - atomic.AddInt64(&value.(*transactionSync).Counter, 1) - } -} - -// callPlugins calls the model plugins in the given list, with the given input. -// It waits for all the synchronous model plugins to finish, and sends the -// result to the client. The asynchronous model plugins are executed in parallel -func callPlugins(input pluginmanager.HTTPPayload, models []string, t configstore.ModelPluginType, transactionID string) error { - logger := logging.Get() - - // channel to receive the status of the execution of the analysis - // of all the model plugins executed - modelPluginStatus := make(chan pluginmanager.ModelStatus) - asyncModelPluginStatus := make(chan pluginmanager.ModelStatus) - - plugins.AddModelChannel(transactionID, t, asyncModelPluginStatus, "async") - plugins.AddModelChannel(transactionID, t, modelPluginStatus, "sync") - - conf, err := configstore.Get() - if err != nil { - return err - } - - syncCounter := 0 - asyncCounter := 0 - - startTime := time.Now() - - for _, id := range models { - logger.TPrintf(logging.DEBUG, transactionID, "%s | calling from core", id) - if _, ok := conf.ModelPlugins[id]; !ok { - logger.TPrintf(logging.ERROR, transactionID, "core | model plugin %s not found", id) - } else { - if conf.ModelPlugins[id].PluginType != t { - logger.TPrintf(logging.ERROR, transactionID, "core | model plugin %s is not of type %s", id, t) - } else { - if conf.IsAsync(id) { - asyncCounter++ - go plugins.AddToQueue(id, transactionID, input) - } else { - if conf.ModelPlugins[id].Remote { - go plugins.AddToQueue(id, transactionID, input) - } else { - go plugins.Process(id, transactionID, input, t, modelPluginStatus) - } - syncCounter++ - } - } - } - } - - go func() { - logger.TPrintf(logging.DEBUG, transactionID, "core | waiting for %d async model plugins to finish", asyncCounter) - wg := sync.WaitGroup{} - wg.Add(asyncCounter) - for i := 0; i < asyncCounter; i++ { - // Await for the execution of the async model plugins - logger.TPrintf(logging.DEBUG, transactionID, "core | Waiting for async model plugin %d...", i+1) - status := <-asyncModelPluginStatus - if status.Err == nil { - logger.TPrintf(logging.DEBUG, transactionID, "%s async | success. Result: %.5f", status.ModelID, status.ProbAttack) - histogramMeter, err := meter.Int64Histogram("wace.model.duration.nanoseconds") - if err != nil { - logger.TPrintf(logging.WARN, transactionID, "core | failed to record duration metric: %v", err.Error()) - } - histogramMeter.Record(ctx, time.Since(startTime).Nanoseconds(), metric.WithAttributes( - attribute.String("model_id", status.ModelID), - attribute.String("model_mode", "async"), - attribute.Float64("attack_probability", status.ProbAttack))) - } else { - logger.TPrintf(logging.WARN, transactionID, "%s | %v", status.ModelID, status.Err) - } - wg.Done() - } - wg.Wait() - plugins.RemoveAsyncModelChannel(transactionID, t) - }() - - logger.TPrintf(logging.DEBUG, transactionID, "core | waiting for %d sync model plugins to finish", syncCounter) - for i := 0; i < syncCounter; i++ { - // Await for the execution of the model plugins - logger.TPrintf(logging.DEBUG, transactionID, "core | Waiting for sync model plugin %d...", i+1) - status := <-modelPluginStatus - if status.Err == nil { - logger.TPrintf(logging.DEBUG, transactionID, "%s sync | success. Result: %.5f", status.ModelID, status.ProbAttack) - - histogramMeter, err := meter.Int64Histogram("wace.model.duration.nanoseconds") - if err != nil { - logger.TPrintf(logging.WARN, transactionID, "core | failed to record duration metric: %v", err.Error()) - } - histogramMeter.Record(ctx, time.Since(startTime).Nanoseconds(), metric.WithAttributes( - attribute.String("model_id", status.ModelID), - attribute.String("model_mode", "sync"), - attribute.Float64("attack_probability", status.ProbAttack))) - } else { - logger.TPrintf(logging.WARN, transactionID, "%s | %v", status.ModelID, status.Err) - } - } - - value, ok := analysisMap.Load(transactionID) - if !ok { - logger.TPrintf(logging.ERROR, transactionID, "core | could not find transaction %s in analysis map", transactionID) - return fmt.Errorf("core | could not find transaction %s in analysis map", transactionID) - } - analysisChan := value.(*transactionSync).Channel - analysisChan <- "done" - return nil -} - -// InitTransaction initializes a transaction with the given id -func InitTransaction(transactionId string) { - logger := logging.Get() - logger.StartTransaction(transactionId) - logger.TPrintf(logging.DEBUG, transactionId, "core | initializing transaction") - tSync := transactionSync{ - Channel: make(chan string), - Counter: 0, - } - analysisMap.Store(transactionId, &tSync) - plugins.InitTransaction(transactionId) -} - -// Analyze calls the model plugins with the given payload and models -func Analyze(modelsTypeAsString, transactionId string, payload pluginmanager.HTTPPayload, models []string) error { - if len(models) > 0 { - logger := logging.Get() - modelsType, err := configstore.StringToPluginType(modelsTypeAsString) - if err != nil { - logger.TPrintf(logging.ERROR, transactionId, "core | %s is not a valid type", modelsTypeAsString) - return err - } - logger.TPrintf(logging.DEBUG, transactionId, "core | analyzing %s: [%v...]", modelsTypeAsString, payload) - addTransactionAnalysis(transactionId) - go callPlugins(payload, models, modelsType, transactionId) - } - return nil -} - -// CheckTransaction checks the result of the analysis of the transaction -// with the given id and decision plugin -func CheckTransaction(transactionId, decisionPlugin string, wafParams map[string]string) (bool, error) { - logger := logging.Get() - logger.TPrintf(logging.DEBUG, transactionId, "core | checking transaction") - - value, exists := analysisMap.Load(transactionId) - - if !exists { - return false, fmt.Errorf("transaction with id %s does not exist", transactionId) - } - - sync := value.(*transactionSync) - - logger.TPrintln(logging.DEBUG, transactionId, "core | waiting for all models to finish...") - - for i := 0; i < int(sync.Counter); i++ { - <-sync.Channel - } - sync.Counter = 0 - - logger.TPrintln(logging.DEBUG, transactionId, "core | done, checking data...") - res, err := plugins.CheckResult(transactionId, decisionPlugin, wafParams) - - if err == nil { - logger.TPrintf(logging.DEBUG, transactionId, "core | transaction checked successfully. Blocking transaction: %t", res) - - if res { - metric, err := meter.Int64Counter("wace.client.request.blocked.total", metric.WithDescription(decisionPlugin)) - if err != nil { - logger.TPrintf(logging.WARN, transactionId, "core | failed to record blocked request metric: %v", err.Error()) - } - metric.Add(ctx, 1) - } - } else { - logger.TPrintf(logging.ERROR, transactionId, "core | could not check transaction: %v", err) - } - return res, err -} - -// CloseTransaction closes the transaction with the given id -// removing the transaction sync model results -func CloseTransaction(transactionID string) { - plugins.CloseTransaction(transactionID) - value, ok := analysisMap.Load(transactionID) - logger := logging.Get() - - if !ok { - logger.TPrintf(logging.ERROR, transactionID, "Analysis for transaction %s not found", transactionID) - } else { - close(value.(*transactionSync).Channel) - for range value.(*transactionSync).Channel { - } - analysisMap.Delete(transactionID) - } -} - -// Init initializes the WACE core with the given metric meter -func Init(met metric.Meter) error { - logger := logging.Get() - conf, err := configstore.Get() - meter = met - - err = logger.LoadLogger(conf.LogPath, conf.LogLevel) - if err != nil { - logger.Printf(logging.ERROR, "ERROR: could not open wace log file: %v", err) - return err - } - logger.Printf(logging.DEBUG, "Writing logs to %s from now", conf.LogPath) - - logger.Println(logging.DEBUG, "Loading plugin manager...") - plugins, err = pluginmanager.New(met) - if err != nil { - return err - } - logger.Println(logging.DEBUG, "Plugin manager loaded") - - return nil -} +/* +The main package of WACE. +*/ +package wace + +import ( + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/tilsor/ModSecIntl_wace_lib/configstore" + + "github.com/tilsor/ModSecIntl_wace_lib/pluginmanager" + + "github.com/tilsor/ModSecIntl_logging/logging" + + "context" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" +) + +var plugins *pluginmanager.PluginManager +var ctx = context.Background() +var meter metric.Meter + +// transactionSync is a struct to syncronize the analysis of a given +// transaction. Each time callPlugins is executed, the counter is +// incremented. At the end of each callPlugins execution, a message is +// sent through the channel, to signal checkTransaction that it has +// finished analyzing the request. checkTransaction waits for Counter +// number of messages in the channel, before calling the decision +// plugin and sending the result to the client. +type transactionSync struct { + Channel chan string + Counter int64 +} + +var ( + // Sync map witg channels to receive a notification when all plugins finish + // processing a transaction + analysisMap sync.Map +) + +// addTransactionAnalysis adds a transaction to the analysis map. If the +// transaction already exists, it increments the counter of the transaction +// by one. +func addTransactionAnalysis(transactionID string) { + tSync := transactionSync{ + Channel: make(chan string), + Counter: 1, + } + value, loaded := analysisMap.LoadOrStore(transactionID, &tSync) + if loaded { + atomic.AddInt64(&value.(*transactionSync).Counter, 1) + } +} + +// callPlugins calls the model plugins in the given list, with the given input. +// It waits for all the synchronous model plugins to finish, and sends the +// result to the client. The asynchronous model plugins are executed in parallel +func callPlugins(input pluginmanager.HTTPPayload, models []string, t configstore.ModelPluginType, transactionID string) error { + logger := logging.Get() + + // channel to receive the status of the execution of the analysis + // of all the model plugins executed + modelPluginStatus := make(chan pluginmanager.ModelStatus) + asyncModelPluginStatus := make(chan pluginmanager.ModelStatus) + + plugins.AddModelChannel(transactionID, t, asyncModelPluginStatus, "async") + plugins.AddModelChannel(transactionID, t, modelPluginStatus, "sync") + + conf, err := configstore.Get() + if err != nil { + return err + } + + syncCounter := 0 + asyncCounter := 0 + + startTime := time.Now() + + for _, id := range models { + logger.TPrintf(logging.DEBUG, transactionID, "%s | calling from core", id) + if _, ok := conf.ModelPlugins[id]; !ok { + logger.TPrintf(logging.ERROR, transactionID, "core | model plugin %s not found", id) + } else { + if conf.ModelPlugins[id].PluginType != t { + logger.TPrintf(logging.ERROR, transactionID, "core | model plugin %s is not of type %s", id, t) + } else { + if conf.IsAsync(id) { + asyncCounter++ + go plugins.AddToQueue(id, transactionID, input) + } else { + if conf.ModelPlugins[id].Remote { + go plugins.AddToQueue(id, transactionID, input) + } else { + go plugins.Process(id, transactionID, input, t, modelPluginStatus) + } + syncCounter++ + } + } + } + } + + go func() { + logger.TPrintf(logging.DEBUG, transactionID, "core | waiting for %d async model plugins to finish", asyncCounter) + wg := sync.WaitGroup{} + wg.Add(asyncCounter) + for i := 0; i < asyncCounter; i++ { + // Await for the execution of the async model plugins + logger.TPrintf(logging.DEBUG, transactionID, "core | Waiting for async model plugin %d...", i+1) + status := <-asyncModelPluginStatus + if status.Err == nil { + logger.TPrintf(logging.DEBUG, transactionID, "%s async | success. Result: %.5f", status.ModelID, status.ProbAttack) + histogramMeter, err := meter.Int64Histogram("wace.model.duration.nanoseconds") + if err != nil { + logger.TPrintf(logging.WARN, transactionID, "core | failed to record duration metric: %v", err.Error()) + } + histogramMeter.Record(ctx, time.Since(startTime).Nanoseconds(), metric.WithAttributes( + attribute.String("model_id", status.ModelID), + attribute.String("model_mode", "async"), + attribute.Float64("attack_probability", status.ProbAttack))) + } else { + logger.TPrintf(logging.WARN, transactionID, "%s | %v", status.ModelID, status.Err) + } + wg.Done() + } + wg.Wait() + plugins.RemoveAsyncModelChannel(transactionID, t) + }() + + logger.TPrintf(logging.DEBUG, transactionID, "core | waiting for %d sync model plugins to finish", syncCounter) + for i := 0; i < syncCounter; i++ { + // Await for the execution of the model plugins + logger.TPrintf(logging.DEBUG, transactionID, "core | Waiting for sync model plugin %d...", i+1) + status := <-modelPluginStatus + if status.Err == nil { + logger.TPrintf(logging.DEBUG, transactionID, "%s sync | success. Result: %.5f", status.ModelID, status.ProbAttack) + + histogramMeter, err := meter.Int64Histogram("wace.model.duration.nanoseconds") + if err != nil { + logger.TPrintf(logging.WARN, transactionID, "core | failed to record duration metric: %v", err.Error()) + } + histogramMeter.Record(ctx, time.Since(startTime).Nanoseconds(), metric.WithAttributes( + attribute.String("model_id", status.ModelID), + attribute.String("model_mode", "sync"), + attribute.Float64("attack_probability", status.ProbAttack))) + } else { + logger.TPrintf(logging.WARN, transactionID, "%s | %v", status.ModelID, status.Err) + } + } + + value, ok := analysisMap.Load(transactionID) + if !ok { + logger.TPrintf(logging.ERROR, transactionID, "core | could not find transaction %s in analysis map", transactionID) + return fmt.Errorf("core | could not find transaction %s in analysis map", transactionID) + } + analysisChan := value.(*transactionSync).Channel + analysisChan <- "done" + return nil +} + +// InitTransaction initializes a transaction with the given id +func InitTransaction(transactionId string) { + logger := logging.Get() + logger.StartTransaction(transactionId) + logger.TPrintf(logging.DEBUG, transactionId, "core | initializing transaction") + tSync := transactionSync{ + Channel: make(chan string), + Counter: 0, + } + analysisMap.Store(transactionId, &tSync) + plugins.InitTransaction(transactionId) +} + +// Analyze calls the model plugins with the given payload and models +func Analyze(modelsTypeAsString, transactionId string, payload pluginmanager.HTTPPayload, models []string) error { + if len(models) > 0 { + logger := logging.Get() + modelsType, err := configstore.StringToPluginType(modelsTypeAsString) + if err != nil { + logger.TPrintf(logging.ERROR, transactionId, "core | %s is not a valid type", modelsTypeAsString) + return err + } + logger.TPrintf(logging.DEBUG, transactionId, "core | analyzing %s: [%v...]", modelsTypeAsString, payload) + addTransactionAnalysis(transactionId) + go callPlugins(payload, models, modelsType, transactionId) + } + return nil +} + +// CheckTransaction checks the result of the analysis of the transaction +// with the given id and decision plugin +func CheckTransaction(transactionID, decisionPlugin string, wafParams map[string]string) (bool, error) { + logger := logging.Get() + logger.TPrintf(logging.DEBUG, transactionID, "core | checking transaction") + + value, exists := analysisMap.Load(transactionID) + + if !exists { + return false, fmt.Errorf("transaction with id %s does not exist", transactionID) + } + + sync := value.(*transactionSync) + + logger.TPrintln(logging.DEBUG, transactionID, "core | waiting for all models to finish...") + + for i := 0; i < int(sync.Counter); i++ { + <-sync.Channel + } + sync.Counter = 0 + + logger.TPrintln(logging.DEBUG, transactionID, "core | done, checking data...") + res, err := plugins.CheckResult(transactionID, decisionPlugin, wafParams) + + if err == nil { + logger.TPrintf(logging.DEBUG, transactionID, "core | transaction checked successfully. Blocking transaction: %t", res) + + if res { + metric, err := meter.Int64Counter("wace.client.request.blocked.total", metric.WithDescription(decisionPlugin)) + if err != nil { + logger.TPrintf(logging.WARN, transactionID, "core | failed to record blocked request metric: %v", err.Error()) + } + metric.Add(ctx, 1) + } + } else { + logger.TPrintf(logging.ERROR, transactionID, "core | could not check transaction: %v", err) + } + return res, err +} + +// CloseTransaction closes the transaction with the given id +// removing the transaction sync model results +func CloseTransaction(transactionID string) { + plugins.CloseTransaction(transactionID) + value, ok := analysisMap.Load(transactionID) + logger := logging.Get() + + if !ok { + logger.TPrintf(logging.ERROR, transactionID, "Analysis for transaction %s not found", transactionID) + } else { + close(value.(*transactionSync).Channel) + for range value.(*transactionSync).Channel { + } + analysisMap.Delete(transactionID) + } +} + +// Init initializes the WACE core with the given metric meter +func Init(met metric.Meter) error { + logger := logging.Get() + conf, err := configstore.Get() + meter = met + + err = logger.LoadLogger(conf.LogPath, conf.LogLevel) + if err != nil { + logger.Printf(logging.ERROR, "ERROR: could not open wace log file: %v", err) + return err + } + logger.Printf(logging.DEBUG, "Writing logs to %s from now", conf.LogPath) + + logger.Println(logging.DEBUG, "Loading plugin manager...") + plugins, err = pluginmanager.New(met) + if err != nil { + return err + } + logger.Println(logging.DEBUG, "Plugin manager loaded") + + return nil +} diff --git a/wacecore_test.go b/wacecore_test.go index 367e1d6..2b1d5fa 100644 --- a/wacecore_test.go +++ b/wacecore_test.go @@ -1,610 +1,610 @@ -package wace - -import ( - "math/rand" - "strconv" - "strings" - "testing" - "time" - - "github.com/tilsor/ModSecIntl_wace_lib/configstore" - "github.com/tilsor/ModSecIntl_wace_lib/pluginmanager" - "go.opentelemetry.io/otel/sdk/metric" - - "gopkg.in/yaml.v3" -) - -var requestURI = "/cgi-bin/process.cgi" -var requestMethod = "POST" -var requestVersion = "HTTP/1.1" - -// var requestLine = "POST /cgi-bin/process.cgi HTTP/1.1\n" - -var requestHeaders = []pluginmanager.HTTPHeader{ - {Key: "User-Agent", Value: "Mozilla/4.0 (compatible; MSIE5.01; Windows NT)"}, - {Key: "Host", Value: "www.tutorialspoint.com"}, - {Key: "Content-Type", Value: "application/x-www-form-urlencoded"}, - {Key: "Content-Length", Value: "length"}, - {Key: "Accept-Language", Value: "en-us"}, - {Key: "Accept-Encoding", Value: "gzip, deflate"}, - {Key: "Connection", Value: "Keep-Alive"}, -} - -var requestHeadersPayload = pluginmanager.HTTPPayload{ - URI: requestURI, - Method: requestMethod, - HTTPVersion: requestVersion, -} - -var requestBody = "licenseID=string&content=string&/paramsXML=string\n" -var wholeRequest = pluginmanager.HTTPPayload{ - URI: requestURI, - Method: requestMethod, - HTTPVersion: requestVersion, - RequestBody: requestBody, -} - -// var wholeRequest = requestLine + requestHeaders + "\n" + requestBody -var responseCode = 200 -var responseProto = "HTTP/1.1" -var responseHeaders = []pluginmanager.HTTPHeader{ - {Key: "Date", Value: "Mon, 27 Jul 2009 12:28:53 GMT"}, - {Key: "Server", Value: "Apache/2.2.14 (Win32)"}, - {Key: "Last-Modified", Value: "Wed, 22 Jul 2009 19:15:56 GMT"}, - {Key: "Content-Length", Value: "88"}, - {Key: "Content-Type", Value: "text/html"}, - {Key: "Connection", Value: "Closed"}, -} - -var responseHeadersPayload = pluginmanager.HTTPPayload{ - ResponseProtocol: responseProto, - ResponseCode: responseCode, - ResponseHeaders: responseHeaders, -} - -var responseBody = ` - -

Hello, World!

- - -` - -var wholeResponse = pluginmanager.HTTPPayload{ - ResponseProtocol: responseProto, - ResponseCode: responseCode, - ResponseHeaders: responseHeaders, - ResponseBody: responseBody, -} - -var config = []byte(`--- -logpath: "/dev/null" -loglevel: DEBUG -modelplugins: - - id: "trivial" - path: "testdata/plugins/model/trivial.so" - weight: 1 - params: - d: "sds" - b: "dnid" - e: "dofnno" - # plugintype: "RequestHeaders" - plugintype: "Everything" - - id: "trivial2" - path: "testdata/plugins/model/trivial2.so" - weight: 2 - params: - a: "sdsds" - b: "sdfjdnid" - c: "kfoskdofnno" - plugintype: "Everything" -decisionplugins: - - id: "simple" - path: "testdata/plugins/decision/simple.so" - wafweight: 0.5 - decisionbalance: 0.5 -`) - -var configAllModels = []byte(`--- -logpath: "/dev/null" -#The level of debug, the valid options are - ERRO, WARN, INFO, DEBUG -loglevel: "WARN" - -#The model plugins configuration -modelplugins: - - id: "trivialRequestHeaders" - plugintype: RequestHeaders - path: "testdata/plugins/model/trivial.so" - weight: 0.1 - mode: sync - - id: "trivialRequestBody" - plugintype: RequestBody - path: "testdata/plugins/model/trivial.so" - weight: 0.1 - mode: sync - - id: "trivialAllRequest" - plugintype: AllRequest - path: "testdata/plugins/model/trivial.so" - weight: 0.1 - mode: sync - - id: "trivialResponseHeaders" - plugintype: ResponseHeaders - path: "testdata/plugins/model/trivial.so" - weight: 0.1 - mode: sync - - id: "trivialResponseBody" - plugintype: ResponseBody - path: "testdata/plugins/model/trivial.so" - weight: 0.1 - mode: sync - - id: "trivialAllResponse" - plugintype: AllResponse - path: "testdata/plugins/model/trivial.so" - weight: 0.1 - mode: sync - -#The decision plugin configuration -decisionplugins: - - id: "simple" - path: "testdata/plugins/decision/simple.so" -# wafweight: 0.5 - decisionbalance: 0.1 -`) - -var configSyncNoRemote = []byte(`--- -logpath: "/dev/null" -#The level of debug, the valid options are - ERRO, WARN, INFO, DEBUG -loglevel: "WARN" - -#The model plugins configuration -modelplugins: - - id: "trivial" - plugintype: RequestHeaders - path: "testdata/plugins/model/trivial.so" - weight: 1 - mode: sync - - id: "trivial2" - plugintype: RequestHeaders - path: "testdata/plugins/model/trivial2.so" - weight: 2 - mode: sync - -#The decision plugin configuration -decisionplugins: - - id: "simple" - path: "testdata/plugins/decision/simple.so" -# wafweight: 0.5 - decisionbalance: 0.1 -`) - -var configSyncRemote = []byte(`--- -logpath: "/dev/null" -#The level of debug, the valid options are - ERRO, WARN, INFO, DEBUG -loglevel: "WARN" - -#The model plugins configuration -modelplugins: - - id: "trivial" - plugintype: RequestHeaders - path: "testdata/plugins/model/trivial.so" - weight: 1 - mode: sync - remote: true - - id: "trivial2" - plugintype: RequestHeaders - path: "testdata/plugins/model/trivial2.so" - weight: 2 - mode: sync - remote: true -#The decision plugin configuration -decisionplugins: - - id: "simple" - path: "testdata/plugins/decision/simple.so" -# wafweight: 0.5 - decisionbalance: 0.1 -`) - -var configAsync = []byte(`--- -logpath: "/dev/null" -#The level of debug, the valid options are - ERRO, WARN, INFO, DEBUG -loglevel: "WARN" - -#The model plugins configuration -modelplugins: - - id: "trivial" - plugintype: RequestHeaders - path: "testdata/plugins/model/trivial.so" - weight: 1 - mode: async - - id: "trivial2" - plugintype: RequestHeaders - path: "testdata/plugins/model/trivial2.so" - weight: 2 - mode: async -#The decision plugin configuration -decisionplugins: - - id: "simple" - path: "testdata/plugins/decision/simple.so" -# wafweight: 0.5 - decisionbalance: 0.1 -`) - -// var configRoberta = []byte(`--- -// logpath: "/dev/null" -// loglevel: DEBUG -// listenport: "50051" -// modelplugins: -// - id: "trivial" -// path: "testdata/plugins/model/trivial.so" -// weight: 1 -// threshold: 0.5 -// params: -// d: "sds" -// b: "dnid" -// e: "dofnno" -// # plugintype: "RequestHeaders" -// plugintype: "Everything" -// - id: "trivial2" -// path: "testdata/plugins/model/trivial2.so" -// weight: 2 -// threshold: 0.1 -// params: -// a: "sdsds" -// b: "sdfjdnid" -// c: "kfoskdofnno" -// plugintype: "Everything" -// - id: "roberta" -// path: "testdata/plugins/model/roberta.so" -// weight: 1 -// threshold: 0.5 -// params: -// url: "localhost:9999" -// distance_threshold: -0.02 -// plugintype: "AllRequest" -// decisionplugins: -// - id: "simple" -// path: "testdata/plugins/decision/simple.so" -// wafweight: 0.5 -// decisionbalance: 0.5 -// `) - -var provider = metric.NewMeterProvider() -var testMeter = provider.Meter("example-meter") - -func initilize(configuration []byte) error { - var aux configstore.ConfigFileData - err := yaml.Unmarshal(configuration, &aux) - if err != nil { - return err - } - cs, err := configstore.Get() - if err != nil { - return err - } - err = cs.SetConfig(aux) - if err != nil { - return err - } - err = Init(testMeter) - if err != nil { - return err - } - return nil -} - -func generateRandomID() string { - letters := "1234567890ABCDEF" - id := "" - for i := 0; i < 16; i++ { - id += string(letters[rand.Intn(len(letters))]) - } - - return id -} - -func TestAnalyzeRequestInParts(t *testing.T) { - _, err := configstore.New() - if err != nil { - t.Error(err) - } - - defer configstore.Clean() - err = initilize(configAllModels) - if err != nil { - t.Errorf("Error initing test: %v", err) - } - - transactionID := generateRandomID() - - InitTransaction(transactionID) - - res := Analyze("RequestHeaders", transactionID, requestHeadersPayload, []string{"trivialRequestHeaders"}) - if res != nil { - t.Errorf("Error: Analyze RequestHeaders: %s", res.Error()) - } - res = Analyze("RequestBody", transactionID, pluginmanager.HTTPPayload{ResponseBody: requestBody}, []string{"trivialRequestBody"}) - if res != nil { - t.Errorf("Error: Analyze RequestBody: %s", res.Error()) - } - - _, err = CheckTransaction(transactionID, "simple", make(map[string]string)) - if err != nil { - t.Errorf("Error: CheckTransaction: %s", err.Error()) - } - - CloseTransaction(transactionID) -} - -func TestAnalyzeWholeRequest(t *testing.T) { - _, err := configstore.New() - if err != nil { - t.Error(err) - } - - defer configstore.Clean() - err = initilize(configAllModels) - if err != nil { - t.Errorf("Error initing test: %v", err) - } - - transactionID := generateRandomID() - - InitTransaction(transactionID) - - res := Analyze("AllRequest", transactionID, wholeRequest, []string{"trivialAllRequest"}) - if res != nil { - t.Errorf("Error: Analyze AllRequest: %s", res.Error()) - } - - _, err = CheckTransaction(transactionID, "simple", make(map[string]string)) - if err != nil { - t.Errorf("Error: CheckTransaction: %s", err.Error()) - } - - CloseTransaction(transactionID) -} - -func TestAnalyzeResponseInParts(t *testing.T) { - _, err := configstore.New() - if err != nil { - t.Error(err) - } - - defer configstore.Clean() - err = initilize(configAllModels) - if err != nil { - t.Errorf("Error initing test: %v", err) - } - - transactionID := generateRandomID() - - InitTransaction(transactionID) - - res := Analyze("ResponseHeaders", transactionID, responseHeadersPayload, []string{"trivialResponseHeaders"}) - if res != nil { - t.Errorf("Error: Analyze ResponseHeaders: %s", res.Error()) - } - res = Analyze("ResponseBody", transactionID, pluginmanager.HTTPPayload{ResponseBody: responseBody}, []string{"trivialResponseBody"}) - if res != nil { - t.Errorf("Error: Analyze ResponseBody: %s", res.Error()) - } - - _, err = CheckTransaction(transactionID, "simple", make(map[string]string)) - if err != nil { - t.Errorf("Error: CheckTransaction: %s", err.Error()) - } - - CloseTransaction(transactionID) -} - -func TestAnalyzeWholeResponse(t *testing.T) { - _, err := configstore.New() - if err != nil { - t.Error(err) - } - - defer configstore.Clean() - err = initilize(configAllModels) - if err != nil { - t.Errorf("Error initing test: %v", err) - } - - transactionID := generateRandomID() - - InitTransaction(transactionID) - - res := Analyze("AllResponse", transactionID, wholeResponse, []string{"trivialAllResponse"}) - if res != nil { - t.Errorf("Error: Analyze AllResponse: %s", res.Error()) - } - - _, err = CheckTransaction(transactionID, "simple", make(map[string]string)) - if err != nil { - t.Errorf("Error: CheckTransaction: %s", err.Error()) - } - - CloseTransaction(transactionID) -} - -func TestAnalyzeRequestInPartsAsync(t *testing.T) { - _, err := configstore.New() - if err != nil { - t.Error(err) - } - - defer configstore.Clean() - err = initilize(configAsync) - if err != nil { - t.Errorf("Error initing test: %v", err) - } - transactionID := generateRandomID() - - InitTransaction(transactionID) - - res := Analyze("RequestHeaders", transactionID, requestHeadersPayload, []string{"trivial", "trivial2"}) - if res != nil { - t.Errorf("Error: Analyze RequestHeaders: %s", res.Error()) - } - - _, err = CheckTransaction(transactionID, "simple", make(map[string]string)) - if err != nil { - t.Errorf("Error: CheckTransaction: %s", err.Error()) - } - - CloseTransaction(transactionID) - - time.Sleep(10 * time.Millisecond) -} - -func TestCheckInvalidTransaction(t *testing.T) { - _, err := CheckTransaction("INEXISTENT", "simple", make(map[string]string)) - if err == nil { - t.Errorf("Error: CheckTransaction with inexistent transaction does not rise an error") - } -} - -func TestCheckAttackTransaction(t *testing.T) { - _, err := configstore.New() - if err != nil { - t.Error(err) - } - - defer configstore.Clean() - err = initilize(configSyncNoRemote) - if err != nil { - t.Errorf("Error initing test: %v", err) - } - transactionID := generateRandomID() - - InitTransaction(transactionID) - - wafParams := make(map[string]string) - auxString := "COMBINED_SCORE=0,HTTP=0,LFI=0,PHPI=0,RCE=0,RFI=0,SESS=0,SQLI=0,XSS=0,inbound_blocking=20,inbound_detection=0,inbound_per_pl=0-0-0-0,inbound_threshold=5,outbound_blocking=0,outbound_detection=0,outbound_per_pl=0-0-0-0,outbound_threshold=4,phase=2" - for _, score := range strings.Split(auxString, ",") { - scoreParts := strings.Split(score, "=") - wafParams[scoreParts[0]] = scoreParts[1] - } - - err = Analyze("RequestHeaders", transactionID, requestHeadersPayload, []string{"trivial", "trivial2", "trivial3"}) - if err != nil { - t.Errorf("Error: Analyze RequestHeaders: %s", err.Error()) - } - - res, err := CheckTransaction(transactionID, "simple", wafParams) - if err != nil { - t.Errorf("Error: CheckTransaction: %s", err.Error()) - } - if !res { - t.Errorf("Error: CheckTransaction: transaction should be blocked") - } - - CloseTransaction(transactionID) -} - -// func TestAnalyzeStress(t *testing.T) { -// for i := 0; i < 1000; i++ { -// transactionID := generateRandomID() -// AnalyzeRequest(transactionID, wholeRequest, []string{"trivial", "trivial2"}) -// _, err := CheckTransaction(transactionID, "simple", make(map[string]string)) -// if err != nil { -// t.Errorf("checkTransaction error: %v", err) -// } -// } - -// } - -// func processRequest(models []string) error { -// transactionID := generateRandomID() - -// res := AnalyzeRequest(transactionID, wholeRequest, models) -// if res != 0 { -// return errors.New("analyzeRequest returned non-zero") -// } - -// _, err := CheckTransaction(transactionID, "simple", -// map[string]string{"anomalyscore": "1", -// "inboundthreshold": "100"}) -// return err -// } - -// func TestRoberta(t *testing.T) { -// conf := cf.Get() -// err := conf.LoadConfigYaml(configRoberta) -// if err != nil { -// panic("Error loading config: " + err.Error()) -// } - -// err = processRequest([]string{"roberta"}) -// if err != nil { -// t.Errorf("callRoberta error: %v", err) -// } -// } - -// func BenchmarkRoberta(b *testing.B) { -// for i := 0; i < b.N; i++ { -// processRequest([]string{"roberta"}) -// } -// } - -func BenchmarkTrivial(b *testing.B) { - - _, err := configstore.New() - if err != nil { - b.Error(err) - } - - defer configstore.Clean() - err = initilize(configSyncNoRemote) - if err != nil { - b.Errorf("Error initing test: %v", err) - } - wafParams := make(map[string]string) - auxString := "COMBINED_SCORE=0,HTTP=0,LFI=0,PHPI=0,RCE=0,RFI=0,SESS=0,SQLI=0,XSS=0,inbound_blocking=0,inbound_detection=0,inbound_per_pl=0-0-0-0,inbound_threshold=5,outbound_blocking=0,outbound_detection=0,outbound_per_pl=0-0-0-0,outbound_threshold=4,phase=2" - for _, score := range strings.Split(auxString, ",") { - scoreParts := strings.Split(score, "=") - wafParams[scoreParts[0]] = scoreParts[1] - } - for i := 0; i < b.N; i++ { - transactionId := strconv.Itoa(i) - InitTransaction(transactionId) - - Analyze("RequestHeaders", transactionId, pluginmanager.HTTPPayload{URI: "Request line and headers\n"}, []string{"trivial", "trivial2"}) - - _, err := CheckTransaction(transactionId, "simple", wafParams) - if err != nil { - b.Errorf("Error checking transaction: %v", err) - } - CloseTransaction(transactionId) - } -} - -func BenchmarkTrivialFullNATS(b *testing.B) { - _, err := configstore.New() - if err != nil { - b.Error(err) - } - - defer configstore.Clean() - err = initilize(configSyncRemote) - if err != nil { - b.Errorf("Error initing test: %v", err) - } - time.Sleep(2 * time.Millisecond) - wafParams := make(map[string]string) - auxString := "COMBINED_SCORE=0,HTTP=0,LFI=0,PHPI=0,RCE=0,RFI=0,SESS=0,SQLI=0,XSS=0,inbound_blocking=0,inbound_detection=0,inbound_per_pl=0-0-0-0,inbound_threshold=5,outbound_blocking=0,outbound_detection=0,outbound_per_pl=0-0-0-0,outbound_threshold=4,phase=2" - for _, score := range strings.Split(auxString, ",") { - scoreParts := strings.Split(score, "=") - wafParams[scoreParts[0]] = scoreParts[1] - } - for i := 0; i < b.N; i++ { - transactionId := generateRandomID() - InitTransaction(transactionId) - - Analyze("RequestHeaders", transactionId, pluginmanager.HTTPPayload{URI: "Request line and headers\n"}, []string{"trivial", "trivial2"}) - - _, err := CheckTransaction(transactionId, "simple", wafParams) - if err != nil { - b.Errorf("Error checking transaction: %v", err) - } - CloseTransaction(transactionId) - } -} +package wace + +import ( + "math/rand" + "strconv" + "strings" + "testing" + "time" + + "github.com/tilsor/ModSecIntl_wace_lib/configstore" + "github.com/tilsor/ModSecIntl_wace_lib/pluginmanager" + "go.opentelemetry.io/otel/sdk/metric" + + "gopkg.in/yaml.v3" +) + +var requestURI = "/cgi-bin/process.cgi" +var requestMethod = "POST" +var requestVersion = "HTTP/1.1" + +// var requestLine = "POST /cgi-bin/process.cgi HTTP/1.1\n" + +var requestHeaders = []pluginmanager.HTTPHeader{ + {Key: "User-Agent", Value: "Mozilla/4.0 (compatible; MSIE5.01; Windows NT)"}, + {Key: "Host", Value: "www.tutorialspoint.com"}, + {Key: "Content-Type", Value: "application/x-www-form-urlencoded"}, + {Key: "Content-Length", Value: "length"}, + {Key: "Accept-Language", Value: "en-us"}, + {Key: "Accept-Encoding", Value: "gzip, deflate"}, + {Key: "Connection", Value: "Keep-Alive"}, +} + +var requestHeadersPayload = pluginmanager.HTTPPayload{ + URI: requestURI, + Method: requestMethod, + HTTPVersion: requestVersion, +} + +var requestBody = "licenseID=string&content=string&/paramsXML=string\n" +var wholeRequest = pluginmanager.HTTPPayload{ + URI: requestURI, + Method: requestMethod, + HTTPVersion: requestVersion, + RequestBody: requestBody, +} + +// var wholeRequest = requestLine + requestHeaders + "\n" + requestBody +var responseCode = 200 +var responseProto = "HTTP/1.1" +var responseHeaders = []pluginmanager.HTTPHeader{ + {Key: "Date", Value: "Mon, 27 Jul 2009 12:28:53 GMT"}, + {Key: "Server", Value: "Apache/2.2.14 (Win32)"}, + {Key: "Last-Modified", Value: "Wed, 22 Jul 2009 19:15:56 GMT"}, + {Key: "Content-Length", Value: "88"}, + {Key: "Content-Type", Value: "text/html"}, + {Key: "Connection", Value: "Closed"}, +} + +var responseHeadersPayload = pluginmanager.HTTPPayload{ + ResponseProtocol: responseProto, + ResponseCode: responseCode, + ResponseHeaders: responseHeaders, +} + +var responseBody = ` + +

Hello, World!

+ + +` + +var wholeResponse = pluginmanager.HTTPPayload{ + ResponseProtocol: responseProto, + ResponseCode: responseCode, + ResponseHeaders: responseHeaders, + ResponseBody: responseBody, +} + +var config = []byte(`--- +logpath: "/dev/null" +loglevel: DEBUG +modelplugins: + - id: "trivial" + path: "testdata/plugins/model/trivial.so" + weight: 1 + params: + d: "sds" + b: "dnid" + e: "dofnno" + # plugintype: "RequestHeaders" + plugintype: "Everything" + - id: "trivial2" + path: "testdata/plugins/model/trivial2.so" + weight: 2 + params: + a: "sdsds" + b: "sdfjdnid" + c: "kfoskdofnno" + plugintype: "Everything" +decisionplugins: + - id: "simple" + path: "testdata/plugins/decision/simple.so" + wafweight: 0.5 + decisionbalance: 0.5 +`) + +var configAllModels = []byte(`--- +logpath: "/dev/null" +#The level of debug, the valid options are - ERRO, WARN, INFO, DEBUG +loglevel: "WARN" + +#The model plugins configuration +modelplugins: + - id: "trivialRequestHeaders" + plugintype: RequestHeaders + path: "testdata/plugins/model/trivial.so" + weight: 0.1 + mode: sync + - id: "trivialRequestBody" + plugintype: RequestBody + path: "testdata/plugins/model/trivial.so" + weight: 0.1 + mode: sync + - id: "trivialAllRequest" + plugintype: AllRequest + path: "testdata/plugins/model/trivial.so" + weight: 0.1 + mode: sync + - id: "trivialResponseHeaders" + plugintype: ResponseHeaders + path: "testdata/plugins/model/trivial.so" + weight: 0.1 + mode: sync + - id: "trivialResponseBody" + plugintype: ResponseBody + path: "testdata/plugins/model/trivial.so" + weight: 0.1 + mode: sync + - id: "trivialAllResponse" + plugintype: AllResponse + path: "testdata/plugins/model/trivial.so" + weight: 0.1 + mode: sync + +#The decision plugin configuration +decisionplugins: + - id: "simple" + path: "testdata/plugins/decision/simple.so" +# wafweight: 0.5 + decisionbalance: 0.1 +`) + +var configSyncNoRemote = []byte(`--- +logpath: "/dev/null" +#The level of debug, the valid options are - ERRO, WARN, INFO, DEBUG +loglevel: "WARN" + +#The model plugins configuration +modelplugins: + - id: "trivial" + plugintype: RequestHeaders + path: "testdata/plugins/model/trivial.so" + weight: 1 + mode: sync + - id: "trivial2" + plugintype: RequestHeaders + path: "testdata/plugins/model/trivial2.so" + weight: 2 + mode: sync + +#The decision plugin configuration +decisionplugins: + - id: "simple" + path: "testdata/plugins/decision/simple.so" +# wafweight: 0.5 + decisionbalance: 0.1 +`) + +var configSyncRemote = []byte(`--- +logpath: "/dev/null" +#The level of debug, the valid options are - ERRO, WARN, INFO, DEBUG +loglevel: "WARN" + +#The model plugins configuration +modelplugins: + - id: "trivial" + plugintype: RequestHeaders + path: "testdata/plugins/model/trivial.so" + weight: 1 + mode: sync + remote: true + - id: "trivial2" + plugintype: RequestHeaders + path: "testdata/plugins/model/trivial2.so" + weight: 2 + mode: sync + remote: true +#The decision plugin configuration +decisionplugins: + - id: "simple" + path: "testdata/plugins/decision/simple.so" +# wafweight: 0.5 + decisionbalance: 0.1 +`) + +var configAsync = []byte(`--- +logpath: "/dev/null" +#The level of debug, the valid options are - ERRO, WARN, INFO, DEBUG +loglevel: "WARN" + +#The model plugins configuration +modelplugins: + - id: "trivial" + plugintype: RequestHeaders + path: "testdata/plugins/model/trivial.so" + weight: 1 + mode: async + - id: "trivial2" + plugintype: RequestHeaders + path: "testdata/plugins/model/trivial2.so" + weight: 2 + mode: async +#The decision plugin configuration +decisionplugins: + - id: "simple" + path: "testdata/plugins/decision/simple.so" +# wafweight: 0.5 + decisionbalance: 0.1 +`) + +// var configRoberta = []byte(`--- +// logpath: "/dev/null" +// loglevel: DEBUG +// listenport: "50051" +// modelplugins: +// - id: "trivial" +// path: "testdata/plugins/model/trivial.so" +// weight: 1 +// threshold: 0.5 +// params: +// d: "sds" +// b: "dnid" +// e: "dofnno" +// # plugintype: "RequestHeaders" +// plugintype: "Everything" +// - id: "trivial2" +// path: "testdata/plugins/model/trivial2.so" +// weight: 2 +// threshold: 0.1 +// params: +// a: "sdsds" +// b: "sdfjdnid" +// c: "kfoskdofnno" +// plugintype: "Everything" +// - id: "roberta" +// path: "testdata/plugins/model/roberta.so" +// weight: 1 +// threshold: 0.5 +// params: +// url: "localhost:9999" +// distance_threshold: -0.02 +// plugintype: "AllRequest" +// decisionplugins: +// - id: "simple" +// path: "testdata/plugins/decision/simple.so" +// wafweight: 0.5 +// decisionbalance: 0.5 +// `) + +var provider = metric.NewMeterProvider() +var testMeter = provider.Meter("example-meter") + +func initilize(configuration []byte) error { + var aux configstore.ConfigFileData + err := yaml.Unmarshal(configuration, &aux) + if err != nil { + return err + } + cs, err := configstore.Get() + if err != nil { + return err + } + err = cs.SetConfig(aux) + if err != nil { + return err + } + err = Init(testMeter) + if err != nil { + return err + } + return nil +} + +func generateRandomID() string { + letters := "1234567890ABCDEF" + id := "" + for i := 0; i < 16; i++ { + id += string(letters[rand.Intn(len(letters))]) + } + + return id +} + +func TestAnalyzeRequestInParts(t *testing.T) { + _, err := configstore.New() + if err != nil { + t.Error(err) + } + + defer configstore.Clean() + err = initilize(configAllModels) + if err != nil { + t.Errorf("Error initing test: %v", err) + } + + transactionID := generateRandomID() + + InitTransaction(transactionID) + + res := Analyze("RequestHeaders", transactionID, requestHeadersPayload, []string{"trivialRequestHeaders"}) + if res != nil { + t.Errorf("Error: Analyze RequestHeaders: %s", res.Error()) + } + res = Analyze("RequestBody", transactionID, pluginmanager.HTTPPayload{ResponseBody: requestBody}, []string{"trivialRequestBody"}) + if res != nil { + t.Errorf("Error: Analyze RequestBody: %s", res.Error()) + } + + _, err = CheckTransaction(transactionID, "simple", make(map[string]string)) + if err != nil { + t.Errorf("Error: CheckTransaction: %s", err.Error()) + } + + CloseTransaction(transactionID) +} + +func TestAnalyzeWholeRequest(t *testing.T) { + _, err := configstore.New() + if err != nil { + t.Error(err) + } + + defer configstore.Clean() + err = initilize(configAllModels) + if err != nil { + t.Errorf("Error initing test: %v", err) + } + + transactionID := generateRandomID() + + InitTransaction(transactionID) + + res := Analyze("AllRequest", transactionID, wholeRequest, []string{"trivialAllRequest"}) + if res != nil { + t.Errorf("Error: Analyze AllRequest: %s", res.Error()) + } + + _, err = CheckTransaction(transactionID, "simple", make(map[string]string)) + if err != nil { + t.Errorf("Error: CheckTransaction: %s", err.Error()) + } + + CloseTransaction(transactionID) +} + +func TestAnalyzeResponseInParts(t *testing.T) { + _, err := configstore.New() + if err != nil { + t.Error(err) + } + + defer configstore.Clean() + err = initilize(configAllModels) + if err != nil { + t.Errorf("Error initing test: %v", err) + } + + transactionID := generateRandomID() + + InitTransaction(transactionID) + + res := Analyze("ResponseHeaders", transactionID, responseHeadersPayload, []string{"trivialResponseHeaders"}) + if res != nil { + t.Errorf("Error: Analyze ResponseHeaders: %s", res.Error()) + } + res = Analyze("ResponseBody", transactionID, pluginmanager.HTTPPayload{ResponseBody: responseBody}, []string{"trivialResponseBody"}) + if res != nil { + t.Errorf("Error: Analyze ResponseBody: %s", res.Error()) + } + + _, err = CheckTransaction(transactionID, "simple", make(map[string]string)) + if err != nil { + t.Errorf("Error: CheckTransaction: %s", err.Error()) + } + + CloseTransaction(transactionID) +} + +func TestAnalyzeWholeResponse(t *testing.T) { + _, err := configstore.New() + if err != nil { + t.Error(err) + } + + defer configstore.Clean() + err = initilize(configAllModels) + if err != nil { + t.Errorf("Error initing test: %v", err) + } + + transactionID := generateRandomID() + + InitTransaction(transactionID) + + res := Analyze("AllResponse", transactionID, wholeResponse, []string{"trivialAllResponse"}) + if res != nil { + t.Errorf("Error: Analyze AllResponse: %s", res.Error()) + } + + _, err = CheckTransaction(transactionID, "simple", make(map[string]string)) + if err != nil { + t.Errorf("Error: CheckTransaction: %s", err.Error()) + } + + CloseTransaction(transactionID) +} + +func TestAnalyzeRequestInPartsAsync(t *testing.T) { + _, err := configstore.New() + if err != nil { + t.Error(err) + } + + defer configstore.Clean() + err = initilize(configAsync) + if err != nil { + t.Errorf("Error initing test: %v", err) + } + transactionID := generateRandomID() + + InitTransaction(transactionID) + + res := Analyze("RequestHeaders", transactionID, requestHeadersPayload, []string{"trivial", "trivial2"}) + if res != nil { + t.Errorf("Error: Analyze RequestHeaders: %s", res.Error()) + } + + _, err = CheckTransaction(transactionID, "simple", make(map[string]string)) + if err != nil { + t.Errorf("Error: CheckTransaction: %s", err.Error()) + } + + CloseTransaction(transactionID) + + time.Sleep(10 * time.Millisecond) +} + +func TestCheckInvalidTransaction(t *testing.T) { + _, err := CheckTransaction("INEXISTENT", "simple", make(map[string]string)) + if err == nil { + t.Errorf("Error: CheckTransaction with inexistent transaction does not rise an error") + } +} + +func TestCheckAttackTransaction(t *testing.T) { + _, err := configstore.New() + if err != nil { + t.Error(err) + } + + defer configstore.Clean() + err = initilize(configSyncNoRemote) + if err != nil { + t.Errorf("Error initing test: %v", err) + } + transactionID := generateRandomID() + + InitTransaction(transactionID) + + wafParams := make(map[string]string) + auxString := "COMBINED_SCORE=0,HTTP=0,LFI=0,PHPI=0,RCE=0,RFI=0,SESS=0,SQLI=0,XSS=0,inbound_blocking=20,inbound_detection=0,inbound_per_pl=0-0-0-0,inbound_threshold=5,outbound_blocking=0,outbound_detection=0,outbound_per_pl=0-0-0-0,outbound_threshold=4,phase=2" + for _, score := range strings.Split(auxString, ",") { + scoreParts := strings.Split(score, "=") + wafParams[scoreParts[0]] = scoreParts[1] + } + + err = Analyze("RequestHeaders", transactionID, requestHeadersPayload, []string{"trivial", "trivial2", "trivial3"}) + if err != nil { + t.Errorf("Error: Analyze RequestHeaders: %s", err.Error()) + } + + res, err := CheckTransaction(transactionID, "simple", wafParams) + if err != nil { + t.Errorf("Error: CheckTransaction: %s", err.Error()) + } + if !res { + t.Errorf("Error: CheckTransaction: transaction should be blocked") + } + + CloseTransaction(transactionID) +} + +// func TestAnalyzeStress(t *testing.T) { +// for i := 0; i < 1000; i++ { +// transactionID := generateRandomID() +// AnalyzeRequest(transactionID, wholeRequest, []string{"trivial", "trivial2"}) +// _, err := CheckTransaction(transactionID, "simple", make(map[string]string)) +// if err != nil { +// t.Errorf("checkTransaction error: %v", err) +// } +// } + +// } + +// func processRequest(models []string) error { +// transactionID := generateRandomID() + +// res := AnalyzeRequest(transactionID, wholeRequest, models) +// if res != 0 { +// return errors.New("analyzeRequest returned non-zero") +// } + +// _, err := CheckTransaction(transactionID, "simple", +// map[string]string{"anomalyscore": "1", +// "inboundthreshold": "100"}) +// return err +// } + +// func TestRoberta(t *testing.T) { +// conf := cf.Get() +// err := conf.LoadConfigYaml(configRoberta) +// if err != nil { +// panic("Error loading config: " + err.Error()) +// } + +// err = processRequest([]string{"roberta"}) +// if err != nil { +// t.Errorf("callRoberta error: %v", err) +// } +// } + +// func BenchmarkRoberta(b *testing.B) { +// for i := 0; i < b.N; i++ { +// processRequest([]string{"roberta"}) +// } +// } + +func BenchmarkTrivial(b *testing.B) { + + _, err := configstore.New() + if err != nil { + b.Error(err) + } + + defer configstore.Clean() + err = initilize(configSyncNoRemote) + if err != nil { + b.Errorf("Error initing test: %v", err) + } + wafParams := make(map[string]string) + auxString := "COMBINED_SCORE=0,HTTP=0,LFI=0,PHPI=0,RCE=0,RFI=0,SESS=0,SQLI=0,XSS=0,inbound_blocking=0,inbound_detection=0,inbound_per_pl=0-0-0-0,inbound_threshold=5,outbound_blocking=0,outbound_detection=0,outbound_per_pl=0-0-0-0,outbound_threshold=4,phase=2" + for _, score := range strings.Split(auxString, ",") { + scoreParts := strings.Split(score, "=") + wafParams[scoreParts[0]] = scoreParts[1] + } + for i := 0; i < b.N; i++ { + transactionId := strconv.Itoa(i) + InitTransaction(transactionId) + + Analyze("RequestHeaders", transactionId, pluginmanager.HTTPPayload{URI: "Request line and headers\n"}, []string{"trivial", "trivial2"}) + + _, err := CheckTransaction(transactionId, "simple", wafParams) + if err != nil { + b.Errorf("Error checking transaction: %v", err) + } + CloseTransaction(transactionId) + } +} + +func BenchmarkTrivialFullNATS(b *testing.B) { + _, err := configstore.New() + if err != nil { + b.Error(err) + } + + defer configstore.Clean() + err = initilize(configSyncRemote) + if err != nil { + b.Errorf("Error initing test: %v", err) + } + time.Sleep(2 * time.Millisecond) + wafParams := make(map[string]string) + auxString := "COMBINED_SCORE=0,HTTP=0,LFI=0,PHPI=0,RCE=0,RFI=0,SESS=0,SQLI=0,XSS=0,inbound_blocking=0,inbound_detection=0,inbound_per_pl=0-0-0-0,inbound_threshold=5,outbound_blocking=0,outbound_detection=0,outbound_per_pl=0-0-0-0,outbound_threshold=4,phase=2" + for _, score := range strings.Split(auxString, ",") { + scoreParts := strings.Split(score, "=") + wafParams[scoreParts[0]] = scoreParts[1] + } + for i := 0; i < b.N; i++ { + transactionId := generateRandomID() + InitTransaction(transactionId) + + Analyze("RequestHeaders", transactionId, pluginmanager.HTTPPayload{URI: "Request line and headers\n"}, []string{"trivial", "trivial2"}) + + _, err := CheckTransaction(transactionId, "simple", wafParams) + if err != nil { + b.Errorf("Error checking transaction: %v", err) + } + CloseTransaction(transactionId) + } +} From c92b3faa51f86acdf089b2a803b61d387c962a37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20de=20Le=C3=B3n?= Date: Thu, 30 Apr 2026 13:45:23 -0300 Subject: [PATCH 6/6] test: update test cases --- configstore/configstore.go | 4 +-- wacecore.go | 18 +++++++--- wacecore_test.go | 71 +++++++------------------------------- 3 files changed, 29 insertions(+), 64 deletions(-) diff --git a/configstore/configstore.go b/configstore/configstore.go index f2f4df6..58d9352 100644 --- a/configstore/configstore.go +++ b/configstore/configstore.go @@ -103,7 +103,7 @@ var config *ConfigStore // Create and returns the unique instance of configstore if it does not exist previously, in other case returns error func New() (*ConfigStore, error) { if config != nil { - return nil, fmt.Errorf("ConfigStore already exists") + return nil, fmt.Errorf("ConfigStore: an instance already exists") } config = new(ConfigStore) return config, nil @@ -112,7 +112,7 @@ func New() (*ConfigStore, error) { // Get returns the unique instance of configstore func Get() (*ConfigStore, error) { if config == nil { - return nil, fmt.Errorf("Configuration was not loaded") + return nil, fmt.Errorf("ConfigStore: Configuration was not loaded") } return config, nil } diff --git a/wacecore.go b/wacecore.go index f3111a2..82f315b 100644 --- a/wacecore.go +++ b/wacecore.go @@ -249,17 +249,27 @@ func CloseTransaction(transactionID string) { } // Init initializes the WACE core with the given metric meter -func Init(met metric.Meter) error { +func Init(met metric.Meter, conf configstore.ConfigFileData) error { logger := logging.Get() - conf, err := configstore.Get() + + cs, err := configstore.New() + if err != nil { + return err + } + + err = cs.SetConfig(conf) + if err != nil { + return err + } + meter = met - err = logger.LoadLogger(conf.LogPath, conf.LogLevel) + err = logger.LoadLogger(cs.LogPath, cs.LogLevel) if err != nil { logger.Printf(logging.ERROR, "ERROR: could not open wace log file: %v", err) return err } - logger.Printf(logging.DEBUG, "Writing logs to %s from now", conf.LogPath) + logger.Printf(logging.DEBUG, "Writing logs to %s from now", cs.LogPath) logger.Println(logging.DEBUG, "Loading plugin manager...") plugins, err = pluginmanager.New(met) diff --git a/wacecore_test.go b/wacecore_test.go index 2b1d5fa..5115dae 100644 --- a/wacecore_test.go +++ b/wacecore_test.go @@ -276,15 +276,7 @@ func initilize(configuration []byte) error { if err != nil { return err } - cs, err := configstore.Get() - if err != nil { - return err - } - err = cs.SetConfig(aux) - if err != nil { - return err - } - err = Init(testMeter) + err = Init(testMeter, aux) if err != nil { return err } @@ -302,13 +294,8 @@ func generateRandomID() string { } func TestAnalyzeRequestInParts(t *testing.T) { - _, err := configstore.New() - if err != nil { - t.Error(err) - } - + err := initilize(configAllModels) defer configstore.Clean() - err = initilize(configAllModels) if err != nil { t.Errorf("Error initing test: %v", err) } @@ -335,13 +322,8 @@ func TestAnalyzeRequestInParts(t *testing.T) { } func TestAnalyzeWholeRequest(t *testing.T) { - _, err := configstore.New() - if err != nil { - t.Error(err) - } - + err := initilize(configAllModels) defer configstore.Clean() - err = initilize(configAllModels) if err != nil { t.Errorf("Error initing test: %v", err) } @@ -364,13 +346,8 @@ func TestAnalyzeWholeRequest(t *testing.T) { } func TestAnalyzeResponseInParts(t *testing.T) { - _, err := configstore.New() - if err != nil { - t.Error(err) - } - + err := initilize(configAllModels) defer configstore.Clean() - err = initilize(configAllModels) if err != nil { t.Errorf("Error initing test: %v", err) } @@ -397,13 +374,8 @@ func TestAnalyzeResponseInParts(t *testing.T) { } func TestAnalyzeWholeResponse(t *testing.T) { - _, err := configstore.New() - if err != nil { - t.Error(err) - } - + err := initilize(configAllModels) defer configstore.Clean() - err = initilize(configAllModels) if err != nil { t.Errorf("Error initing test: %v", err) } @@ -426,16 +398,12 @@ func TestAnalyzeWholeResponse(t *testing.T) { } func TestAnalyzeRequestInPartsAsync(t *testing.T) { - _, err := configstore.New() - if err != nil { - t.Error(err) - } - + err := initilize(configAsync) defer configstore.Clean() - err = initilize(configAsync) if err != nil { t.Errorf("Error initing test: %v", err) } + transactionID := generateRandomID() InitTransaction(transactionID) @@ -463,16 +431,12 @@ func TestCheckInvalidTransaction(t *testing.T) { } func TestCheckAttackTransaction(t *testing.T) { - _, err := configstore.New() - if err != nil { - t.Error(err) - } - + err := initilize(configSyncNoRemote) defer configstore.Clean() - err = initilize(configSyncNoRemote) if err != nil { t.Errorf("Error initing test: %v", err) } + transactionID := generateRandomID() InitTransaction(transactionID) @@ -546,17 +510,12 @@ func TestCheckAttackTransaction(t *testing.T) { // } func BenchmarkTrivial(b *testing.B) { - - _, err := configstore.New() - if err != nil { - b.Error(err) - } - + err := initilize(configSyncNoRemote) defer configstore.Clean() - err = initilize(configSyncNoRemote) if err != nil { b.Errorf("Error initing test: %v", err) } + wafParams := make(map[string]string) auxString := "COMBINED_SCORE=0,HTTP=0,LFI=0,PHPI=0,RCE=0,RFI=0,SESS=0,SQLI=0,XSS=0,inbound_blocking=0,inbound_detection=0,inbound_per_pl=0-0-0-0,inbound_threshold=5,outbound_blocking=0,outbound_detection=0,outbound_per_pl=0-0-0-0,outbound_threshold=4,phase=2" for _, score := range strings.Split(auxString, ",") { @@ -578,16 +537,12 @@ func BenchmarkTrivial(b *testing.B) { } func BenchmarkTrivialFullNATS(b *testing.B) { - _, err := configstore.New() - if err != nil { - b.Error(err) - } - + err := initilize(configSyncRemote) defer configstore.Clean() - err = initilize(configSyncRemote) if err != nil { b.Errorf("Error initing test: %v", err) } + time.Sleep(2 * time.Millisecond) wafParams := make(map[string]string) auxString := "COMBINED_SCORE=0,HTTP=0,LFI=0,PHPI=0,RCE=0,RFI=0,SESS=0,SQLI=0,XSS=0,inbound_blocking=0,inbound_detection=0,inbound_per_pl=0-0-0-0,inbound_threshold=5,outbound_blocking=0,outbound_detection=0,outbound_per_pl=0-0-0-0,outbound_threshold=4,phase=2"