diff --git a/configstore/configstore.go b/configstore/configstore.go index 58d9352..5e88005 100644 --- a/configstore/configstore.go +++ b/configstore/configstore.go @@ -81,11 +81,9 @@ type modelPluginConfig struct { // DecisionPluginConfig stores the configuration of a decision plugin type decisionPluginConfig struct { - ID string - Path string - WAFweight float64 - DecisionBalance float64 - Params map[string]string + ID string + Path string + Params map[string]string } // ConfigStore stores all wacecore configuration from the config file. @@ -134,11 +132,9 @@ type configFileModelPlugin struct { } type configFileDecisionPlugin struct { - ID string - Path string - wafweight float64 - decisionbalance float64 - Params map[string]string + ID string + Path string + Params map[string]string } type ConfigFileData struct { @@ -245,8 +241,6 @@ func (cs *ConfigStore) SetConfig(inConf ConfigFileData) error { var decisionConfig decisionPluginConfig decisionConfig.ID = decisionP.ID decisionConfig.Path = decisionP.Path - decisionConfig.WAFweight = decisionP.wafweight - decisionConfig.DecisionBalance = decisionP.decisionbalance decisionConfig.Params = decisionP.Params cs.DecisionPlugins[decisionConfig.ID] = decisionConfig } diff --git a/configstore/configstore_test.go b/configstore/configstore_test.go index 93b3d97..d27eced 100644 --- a/configstore/configstore_test.go +++ b/configstore/configstore_test.go @@ -103,175 +103,475 @@ func TestLoadConfigYamlInvalid(t *testing.T) { } func TestLoadConfigYamlLogLevel(t *testing.T) { - _, err := New() - if err != nil { - t.Error(err) - } - - defer Clean() - - values := []string{ - "a", - "4", - "0", + tests := []struct { + level string + wantErr bool + }{ + {"a", true}, + {"4", true}, + {"0", true}, + {"DEBUG", false}, + {"INFO", false}, + {"WARN", false}, + {"ERROR", false}, } - for _, v := range values { - config := `--- -logpath: "/dev/null" -loglevel: ` + v - err = initialize([]byte(config)) - if err == nil { - t.Errorf("invalid log level %v does not return error", v) - } + for _, tt := range tests { + t.Run(tt.level, func(t *testing.T) { + _, err := New() + if err != nil { + t.Fatal(err) + } + defer Clean() + + config := "---\nlogpath: \"/dev/null\"\nloglevel: " + tt.level + err = initialize([]byte(config)) + if (err != nil) != tt.wantErr { + if tt.wantErr { + t.Errorf("log level %q should return error but did not", tt.level) + } else { + t.Errorf("log level %q returned unexpected error: %v", tt.level, err) + } + } + }) } } func TestLoadConfigYamlPluginType(t *testing.T) { - cs, err := New() - if err != nil { - t.Error(err) - } - - defer Clean() - - err = initialize([]byte(`--- + tests := []struct { + name string + config string + wantErr bool + wantType string + }{ + { + name: "invalid plugin type", + config: `--- loglevel: ERROR logpath: /dev/null modelplugins: - id: "testplugin" path: "../testdata/plugins/model/trivial.so" plugintype: InvalidPluginType -`)) - if err == nil { - t.Errorf("invalid plugin type does not return error") - } - - err = initialize([]byte(`--- +`, + wantErr: true, + }, + { + name: "empty plugin type", + config: `--- loglevel: ERROR logpath: /dev/null modelplugins: - id: "testplugin" path: "../testdata/plugins/model/trivial.so" plugintype: "" -`)) - if err == nil { - t.Errorf("empty plugin type does not return error") - } - - err = initialize([]byte(`--- +`, + wantErr: true, + }, + { + name: "nonexistent model plugin path", + config: `--- loglevel: ERROR logpath: /dev/null modelplugins: - id: "testplugin" path: "../testdata/plugins/model/nonexistent.so" plugintype: "RequestHeaders" -`)) - if err == nil { - t.Errorf("nonexistent model plugin path does not return error") - } - - err = initialize([]byte(`--- +`, + wantErr: true, + }, + { + name: "empty model plugin path", + config: `--- loglevel: ERROR logpath: /dev/null modelplugins: - id: "testplugin" path: "" plugintype: "RequestHeaders" -`)) - if err == nil { - t.Errorf("empty plugin path does not return error") - } - - err = initialize([]byte(`--- +`, + wantErr: true, + }, + { + name: "empty decision plugin path", + config: `--- loglevel: ERROR logpath: /dev/null decisionplugins: - id: "test" path: "" -`)) - if err == nil { - t.Errorf("empty decision plugin path does not return error") - } - - err = initialize([]byte(`--- +`, + wantErr: true, + }, + { + name: "nonexistent decision plugin path", + config: `--- loglevel: ERROR logpath: /dev/null decisionplugins: - id: "testplugin" path: "../testdata/plugins/decision/nonexistent.so" -`)) - if err == nil { - t.Errorf("nonexistent decision plugin path does not return error") - } - - values := []string{ - "RequestHeaders", - "RequestBody", - "AllRequest", - "ResponseHeaders", - "ResponseBody", - "AllResponse", - "Everything", - } - - for _, v := range values { - config := `--- +`, + wantErr: true, + }, + { + name: "valid RequestHeaders", + config: `--- +loglevel: ERROR +logpath: /dev/null +modelplugins: + - id: "testplugin" + path: "../testdata/plugins/model/trivial.so" + plugintype: "RequestHeaders" +`, + wantType: "RequestHeaders", + }, + { + name: "valid RequestBody", + config: `--- +loglevel: ERROR +logpath: /dev/null +modelplugins: + - id: "testplugin" + path: "../testdata/plugins/model/trivial.so" + plugintype: "RequestBody" +`, + wantType: "RequestBody", + }, + { + name: "valid AllRequest", + config: `--- +loglevel: ERROR +logpath: /dev/null +modelplugins: + - id: "testplugin" + path: "../testdata/plugins/model/trivial.so" + plugintype: "AllRequest" +`, + wantType: "AllRequest", + }, + { + name: "valid ResponseHeaders", + config: `--- loglevel: ERROR logpath: /dev/null modelplugins: - id: "testplugin" path: "../testdata/plugins/model/trivial.so" - plugintype: "` + v + `" -` - err = initialize([]byte(config)) - if err != nil { - t.Errorf("Plugin type %s returns error: %v", v, err) - } - - if fmt.Sprint(cs.ModelPlugins["testplugin"].PluginType) != v { - t.Errorf("Stored plugin type is %v, expected %v", cs.ModelPlugins["testplugin"].PluginType, v) - } + plugintype: "ResponseHeaders" +`, + wantType: "ResponseHeaders", + }, + { + name: "valid ResponseBody", + config: `--- +loglevel: ERROR +logpath: /dev/null +modelplugins: + - id: "testplugin" + path: "../testdata/plugins/model/trivial.so" + plugintype: "ResponseBody" +`, + wantType: "ResponseBody", + }, + { + name: "valid AllResponse", + config: `--- +loglevel: ERROR +logpath: /dev/null +modelplugins: + - id: "testplugin" + path: "../testdata/plugins/model/trivial.so" + plugintype: "AllResponse" +`, + wantType: "AllResponse", + }, + { + name: "valid Everything", + config: `--- +loglevel: ERROR +logpath: /dev/null +modelplugins: + - id: "testplugin" + path: "../testdata/plugins/model/trivial.so" + plugintype: "Everything" +`, + wantType: "Everything", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cs, err := New() + if err != nil { + t.Fatal(err) + } + defer Clean() + + err = initialize([]byte(tt.config)) + if (err != nil) != tt.wantErr { + if tt.wantErr { + t.Errorf("expected error but got none") + } else { + t.Errorf("unexpected error: %v", err) + } + return + } + if tt.wantType != "" { + if got := fmt.Sprint(cs.ModelPlugins["testplugin"].PluginType); got != tt.wantType { + t.Errorf("plugin type = %q, want %q", got, tt.wantType) + } + } + }) } } func TestInvalidLogging(t *testing.T) { - _, err := New() + tests := []struct { + name string + config string + cleanup func(t *testing.T) + wantErr bool + }{ + { + name: "invalid log level", + config: `--- +loglevel: INVALIDLOGLEVEL +logpath: /dev/null +`, + wantErr: true, + }, + { + name: "writable log path", + config: `--- +loglevel: ERROR +logpath: ./configstore_test.log`, + cleanup: func(t *testing.T) { + if _, err := os.Stat("./configstore_test.log"); err == nil { + if err := os.Remove("./configstore_test.log"); err != nil { + t.Errorf("could not remove ./configstore_test.log") + } + } + }, + wantErr: false, + }, + { + name: "inaccessible log path", + config: `--- +loglevel: ERROR +logpath: /usr/configstore_test.log`, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := New() + if err != nil { + t.Fatal(err) + } + defer Clean() + if tt.cleanup != nil { + defer tt.cleanup(t) + } + + err = initialize([]byte(tt.config)) + if (err != nil) != tt.wantErr { + if tt.wantErr { + t.Errorf("expected error but got none") + } else { + t.Errorf("unexpected error: %v", err) + } + } + }) + } +} + +func TestNewGetCleanLifecycle(t *testing.T) { + cs1, err := New() if err != nil { - t.Error(err) + t.Fatalf("New() failed: %v", err) + } + + cs2, err := Get() + if err != nil { + t.Fatalf("Get() after New() failed: %v", err) + } + if cs1 != cs2 { + t.Errorf("Get() returned a different instance than New()") } + Clean() + + _, err = Get() + if err == nil { + t.Errorf("Get() after Clean() should return error") + } +} + +func TestNewDuplicate(t *testing.T) { + _, err := New() + if err != nil { + t.Fatalf("first New() failed: %v", err) + } defer Clean() - err = initialize([]byte(`--- -loglevel: INVALIDLOGLEVEL -logpath: /dev/null -`)) + _, err = New() if err == nil { - t.Errorf("invalid log level does not return error") + t.Errorf("second New() should return error when instance already exists") } +} - if _, err = os.Stat("./configstore_test.log"); err == nil { - err = os.Remove("./configstore_test.log") - if err != nil { - t.Errorf("could not remove ./configstore_test.log") - } +func TestStringToPluginType(t *testing.T) { + tests := []struct { + name string + input string + want ModelPluginType + wantErr bool + }{ + {"RequestHeaders", "RequestHeaders", RequestHeaders, false}, + {"RequestBody", "RequestBody", RequestBody, false}, + {"AllRequest", "AllRequest", AllRequest, false}, + {"ResponseHeaders", "ResponseHeaders", ResponseHeaders, false}, + {"ResponseBody", "ResponseBody", ResponseBody, false}, + {"AllResponse", "AllResponse", AllResponse, false}, + {"Everything", "Everything", Everything, false}, + {"invalid value", "invalid", 0, true}, + {"empty string", "", 0, true}, + {"wrong case", "requestheaders", 0, true}, } - err = initialize([]byte(`--- -loglevel: ERROR -logpath: ./configstore_test.log`)) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := StringToPluginType(tt.input) + if (err != nil) != tt.wantErr { + if tt.wantErr { + t.Errorf("StringToPluginType(%q) should return error but did not", tt.input) + } else { + t.Errorf("StringToPluginType(%q) returned unexpected error: %v", tt.input, err) + } + return + } + if !tt.wantErr && got != tt.want { + t.Errorf("StringToPluginType(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} - if err != nil { - t.Errorf("Error loading config with nonexistent file: %v", err) +func TestModelPluginTypeString(t *testing.T) { + tests := []struct { + pluginType ModelPluginType + want string + }{ + {RequestHeaders, "RequestHeaders"}, + {RequestBody, "RequestBody"}, + {AllRequest, "AllRequest"}, + {ResponseHeaders, "ResponseHeaders"}, + {ResponseBody, "ResponseBody"}, + {AllResponse, "AllResponse"}, + {Everything, "Everything"}, } - err = initialize([]byte(`--- + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + if got := tt.pluginType.String(); got != tt.want { + t.Errorf("ModelPluginType(%d).String() = %q, want %q", int(tt.pluginType), got, tt.want) + } + }) + } +} + +func TestIsAsync(t *testing.T) { + tests := []struct { + name string + mode string + wantAsync bool + }{ + {"sync mode", "sync", false}, + {"async mode", "async", true}, + {"empty mode defaults to sync", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cs, err := New() + if err != nil { + t.Fatal(err) + } + defer Clean() + + config := fmt.Sprintf(`--- loglevel: ERROR -logpath: /usr/configstore_test.log`)) +logpath: /dev/null +modelplugins: + - id: "testplugin" + path: "../testdata/plugins/model/trivial.so" + plugintype: "RequestHeaders" + mode: "%s" +`, tt.mode) + if err := initialize([]byte(config)); err != nil { + t.Fatalf("initialize failed: %v", err) + } + + if got := cs.IsAsync("testplugin"); got != tt.wantAsync { + t.Errorf("IsAsync with mode %q = %v, want %v", tt.mode, got, tt.wantAsync) + } + }) + } +} +func TestGetBeforeNew(t *testing.T) { + // ensure clean state + Clean() + + _, err := Get() if err == nil { - t.Errorf("non existent log file in directory without permissions does not rise error") + t.Error("Get() before New() should return error") + } +} + + +func TestNatsURL(t *testing.T) { + tests := []struct { + name string + config string + wantURL string + }{ + { + name: "defaults to localhost:4222", + config: `--- +loglevel: ERROR +logpath: /dev/null +`, + wantURL: "localhost:4222", + }, + { + name: "stores custom URL", + config: `--- +loglevel: ERROR +logpath: /dev/null +natsurl: "nats.example.com:4222" +`, + wantURL: "nats.example.com:4222", + }, } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cs, err := New() + if err != nil { + t.Fatal(err) + } + defer Clean() + + if err := initialize([]byte(tt.config)); err != nil { + t.Fatalf("initialize failed: %v", err) + } + + if cs.NatsURL != tt.wantURL { + t.Errorf("NatsURL = %q, want %q", cs.NatsURL, tt.wantURL) + } + }) + } } diff --git a/pluginmanager/pluginmanager.go b/pluginmanager/pluginmanager.go index e7204ce..3c63503 100644 --- a/pluginmanager/pluginmanager.go +++ b/pluginmanager/pluginmanager.go @@ -65,11 +65,15 @@ type ModelTransmitionResults struct { type modelPlugin struct { p *plugin.Plugin pluginType configstore.ModelPluginType + process func(ModelInput) (ModelResults, error) + reload func(map[string]string, metric.Meter) error } // decisionPlugin is the struct that stores the decision plugin type decisionPlugin struct { - p *plugin.Plugin + p *plugin.Plugin + checkResults func(DecisionInput) (bool, error) + reload func(map[string]string, metric.Meter) error } // ModelStatus stores whether there was an error while processing a @@ -84,8 +88,6 @@ type ModelStatus struct { // every plugin execution. type PluginManager struct { modelPlugins map[string]modelPlugin - modelProcessFunc map[string]func(ModelInput) (ModelResults, error) - decisionCheckFunc map[string]func(DecisionInput) (bool, error) decisionPlugins map[string]decisionPlugin results sync.Map channelsMutex sync.Mutex @@ -94,6 +96,19 @@ type PluginManager struct { natConn *nats.Conn } +const ( + // model plugin function names + modelInitFunctionName = "InitPlugin" + modelInitAsyncFunctionName = "InitPluginAsync" + modelProcessFunctionName = "Process" + modelReloadFunction = "ReloadPlugin" + + // decision plugin function names + decisionInitFunctionName = "InitPlugin" + decisionCheckFuncionName = "CheckResults" + decisionReloadFunctionName = "ReloadPlugin" +) + // New creates a new PluginManager instance. func New(meter metric.Meter) (*PluginManager, error) { pm := new(PluginManager) @@ -112,102 +127,180 @@ func New(meter metric.Meter) (*PluginManager, error) { pm.natConn = nc - // Loading of model plugins pm.modelPlugins = make(map[string]modelPlugin) - pm.modelProcessFunc = make(map[string]func(ModelInput) (ModelResults, error)) + pm.loadModelPlugins(meter) + + pm.decisionPlugins = make(map[string]decisionPlugin) + pm.loadDecisionPlugins(meter) + + return pm, nil +} + +// Reload reloads the configuration for all already-loaded plugins and loads any +// newly added plugins from the current configstore state. +func (pm *PluginManager) Reload(meter metric.Meter) error { + if err := pm.loadModelPlugins(meter); err != nil { + return err + } + return pm.loadDecisionPlugins(meter) +} + +// loadModelPlugins load new Plugins and reload their configuration if they previously existed +func (pm *PluginManager) loadModelPlugins(meter metric.Meter) error { + conf, err := configstore.Get() + if err != nil { + return err + } + logger := logging.Get() + + // Load plugin models for _, data := range conf.ModelPlugins { - tp, err := plugin.Open(data.Path) - if err != nil { - logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) - continue - } - if data.Mode == "async" || conf.ModelPlugins[data.ID].Remote { - f, err := tp.Lookup("InitPluginAsync") + mp, found := pm.modelPlugins[data.ID] + if !found { + p, err := plugin.Open(data.Path) if err != nil { - logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) + logger.Printf(logging.WARN, "| %s | cannot open plugin: %v", data.ID, err) + continue + } + var processFunc func(ModelInput) (ModelResults, error) + // TODO: change mode to bool + if data.Mode == "async" || conf.ModelPlugins[data.ID].Remote { + f, err := p.Lookup(modelInitAsyncFunctionName) + if err != nil { + logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) + continue + } + initPlugin, ok := f.(func(map[string]string, metric.Meter, func(func(ModelInput) (ModelResults, error))) error) + if !ok { + logger.Printf(logging.WARN, "| %s | cannot load plugin: invalid %s function type", data.ID, modelInitAsyncFunctionName) + continue + } + + // plugin initialization + err = initPlugin(data.Params, meter, func(modelProcess func(ModelInput) (ModelResults, error)) { + ModelProcessHandler(data.ID, modelProcess) + }) + if err != nil { + logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) + continue + } + go pm.ModelResultsHandler(data.ID) + } else { + f, err := p.Lookup(modelInitFunctionName) + if err != nil { + logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) + continue + } + initPlugin, ok := f.(func(map[string]string, metric.Meter) error) + if !ok { + logger.Printf(logging.WARN, "| %s | cannot load plugin: invalid %s function type", data.ID, modelInitFunctionName) + continue + } + + // plugin initialization + err = initPlugin(data.Params, meter) + procFunc, err := p.Lookup(modelProcessFunctionName) + if err != nil { + logger.Printf(logging.WARN, "| %s | cannot load plugin: cannot load %s function", data.ID, modelProcessFunctionName) + continue + } + processFunc, ok = procFunc.(func(ModelInput) (ModelResults, error)) + if !ok { + logger.Printf(logging.WARN, "| %s | cannot load plugin: invalid %s function type", data.ID, modelProcessFunctionName) + continue + } + } + rFun, err := p.Lookup(modelReloadFunction) + if err != nil { + logger.Printf(logging.WARN, "| %s | cannot load plugin: cannot load %s function", data.ID, modelReloadFunction) continue } - initPlugin, ok := f.(func(map[string]string, metric.Meter, func(func(ModelInput) (ModelResults, error))) error) + reload, ok := rFun.(func(map[string]string, metric.Meter) error) if !ok { - logger.Printf(logging.WARN, "| %s | cannot load plugin: invalid InitPluginAsync function type", data.ID) + logger.Printf(logging.WARN, "| %s | cannot load plugin: invalid %s function type", data.ID, modelReloadFunction) continue } - err = initPlugin(data.Params, meter, func(modelProcess func(ModelInput) (ModelResults, error)) { - ModelProcessHandler(data.ID, modelProcess) - }) + modelPluginLoaded := modelPlugin{p, data.PluginType, processFunc, reload} + pm.modelPlugins[data.ID] = modelPluginLoaded + logger.Printf(logging.INFO, "| %s | plugin loaded", data.ID) + } else { + err = mp.reload(data.Params, meter) + if err != nil { + logger.Printf(logging.WARN, "| %s | cannot reload plugin: %s", data.ID, err.Error()) + continue + } + } + } + return nil +} + +// loadDecisionPlugins load new Plugins and reload their configuration if they previously existed +func (pm *PluginManager) loadDecisionPlugins(meter metric.Meter) error { + conf, err := configstore.Get() + if err != nil { + return err + } + logger := logging.Get() + + // Load decision plugins + for _, data := range conf.DecisionPlugins { + dp, found := pm.decisionPlugins[data.ID] + if !found { + p, err := plugin.Open(data.Path) if err != nil { logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) continue } - go pm.ModelResultsHandler(data.ID) - } else { - f, err := tp.Lookup("InitPlugin") + f, err := p.Lookup(decisionInitFunctionName) if err != nil { logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) continue } initPlugin, ok := f.(func(map[string]string, metric.Meter) error) if !ok { - logger.Printf(logging.WARN, "| %s | cannot load plugin: invalid InitPlugin function type", data.ID) + logger.Printf(logging.WARN, "| %s | cannot load plugin: invalid %s function type", data.ID, decisionInitFunctionName) continue } + + // plugin initialization err = initPlugin(data.Params, meter) - procFunc, err := tp.Lookup("Process") if err != nil { - logger.Printf(logging.WARN, "| %s | cannot load plugin: cannot load Process function", data.ID) + logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) + continue + } + checkFunc, err := p.Lookup(decisionCheckFuncionName) + if err != nil { + logger.Printf(logging.ERROR, "| %s | cannot load plugin %s function: %v", data.ID, decisionCheckFuncionName, err) continue } - process, ok := procFunc.(func(ModelInput) (ModelResults, error)) + checkResults, ok := checkFunc.(func(DecisionInput) (bool, error)) if !ok { - logger.Printf(logging.WARN, "| %s | cannot load plugin: invalid Process function type", data.ID) + logger.Printf(logging.ERROR, "| %s | %s lookup failed for plugin: invalid function type", data.ID, decisionCheckFuncionName) continue } - pm.modelProcessFunc[data.ID] = process - } - modelPluginLoaded := modelPlugin{tp, data.PluginType} - pm.modelPlugins[data.ID] = modelPluginLoaded - logger.Printf(logging.INFO, "| %s | plugin loaded", data.ID) - } - pm.decisionPlugins = make(map[string]decisionPlugin) - pm.decisionCheckFunc = make(map[string]func(DecisionInput) (bool, error)) - // Loading of decision plugins - for _, data := range conf.DecisionPlugins { - tp, err := plugin.Open(data.Path) - if err != nil { - logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) - continue - } - f, err := tp.Lookup("InitPlugin") - if err != nil { - logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) - continue - } - initPlugin, ok := f.(func(map[string]string, metric.Meter) error) - if !ok { - logger.Printf(logging.WARN, "| %s | cannot load plugin: invalid InitPlugin function type", data.ID) - continue - } - err = initPlugin(data.Params, meter) - if err != nil { - logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) - continue - } - cR, err := tp.Lookup("CheckResults") - if err != nil { - logger.Printf(logging.ERROR, "| %s | cannot load plugin check results function: %v", data.ID, err) - continue - } - checkResults, ok := cR.(func(DecisionInput) (bool, error)) - if !ok { - logger.Printf(logging.ERROR, "| %s | CheckResults lookup failed for plugin: invalid function type", data.ID) - continue + reloadFunc, err := p.Lookup(decisionReloadFunctionName) + if err != nil { + logger.Printf(logging.ERROR, "| %s | cannot load plugin %s function: %v", data.ID, decisionReloadFunctionName, err) + continue + } + reload, ok := reloadFunc.(func(map[string]string, metric.Meter) error) + if !ok { + logger.Printf(logging.ERROR, "| %s | %s lookup failed for plugin: invalid function type", data.ID, decisionReloadFunctionName) + continue + } + + decisionPluginLoaded := decisionPlugin{p, checkResults, reload} + pm.decisionPlugins[data.ID] = decisionPluginLoaded + } else { + err = dp.reload(data.Params, meter) + if err != nil { + logger.Printf(logging.WARN, "| %s | cannot reload plugin: %s", data.ID, err.Error()) + continue + } } - pm.decisionCheckFunc[data.ID] = checkResults - decisionPluginLoaded := decisionPlugin{tp} - pm.decisionPlugins[data.ID] = decisionPluginLoaded } - return pm, nil + return nil } // InitTransaction initializes the transaction with the given ID @@ -321,14 +414,16 @@ func (p *PluginManager) Process(modelID, transactionId string, payload HTTPPaylo return nil } - process := p.modelProcessFunc[modelID] + mp, ok := p.modelPlugins[modelID] + if !ok { + return fmt.Errorf("Model plugin %s not found", modelID) + } if conf.ModelPlugins[modelID].Mode == "async" { modelPlugStatus <- ModelStatus{ModelID: modelID, Err: fmt.Errorf("model plugin is async")} return nil } else { - res, err := process(ModelInput{TransactionId: transactionId, Payload: payload}) - // res, err := process(transactionId, payload) + res, err := mp.process(ModelInput{TransactionId: transactionId, Payload: payload}) if err != nil { modelPlugStatus <- ModelStatus{ModelID: modelID, Err: err} @@ -351,7 +446,7 @@ func (p *PluginManager) Process(modelID, transactionId string, payload HTTPPaylo func (p *PluginManager) CheckResult(transactionId, decisionId string, wafParams map[string]string) (bool, error) { logger := logging.Get() - checkResults, ok := p.decisionCheckFunc[decisionId] + dp, ok := p.decisionPlugins[decisionId] if !ok { return false, fmt.Errorf("decision plugin not found") } @@ -374,7 +469,7 @@ func (p *PluginManager) CheckResult(transactionId, decisionId string, wafParams return true }) - res, err := checkResults(DecisionInput{TransactionId: transactionId, Results: modelResultMap, ModelWeight: modelWeightMap, WAFdata: wafParams}) + res, err := dp.checkResults(DecisionInput{TransactionId: transactionId, Results: modelResultMap, ModelWeight: modelWeightMap, WAFdata: wafParams}) logger.TPrintf(logging.INFO, transactionId, "%s | transaction checked. Block: %t ", decisionId, res) return res, err diff --git a/pluginmanager/pluginmanager_test.go b/pluginmanager/pluginmanager_test.go index 2834794..f14edfb 100644 --- a/pluginmanager/pluginmanager_test.go +++ b/pluginmanager/pluginmanager_test.go @@ -1,28 +1,39 @@ -package pluginmanager +package pluginmanager_test import ( "math/rand" + "testing" "time" + "github.com/tilsor/ModSecIntl_logging/logging" "github.com/tilsor/ModSecIntl_wace_lib/configstore" + "github.com/tilsor/ModSecIntl_wace_lib/pluginmanager" "go.opentelemetry.io/otel/sdk/metric" "gopkg.in/yaml.v3" - - "github.com/tilsor/ModSecIntl_logging/logging" ) var baseConfig = `--- -logpath: "/tmp/wacetmp.log" +logpath: "/dev/null" loglevel: "WARN" ` var trivialPlugin = ` - id: "trivial" path: "../testdata/plugins/model/trivial.so" weight: 1 - params: - param1: "first value" - param2: "second value" - param3: "third value" + plugintype: "Everything" + mode: sync +` + +var trivial2Plugin = ` - id: "trivial2" + path: "../testdata/plugins/model/trivial2.so" + weight: 1 + plugintype: "Everything" + mode: sync +` + +var errorReqPlugin = ` - id: "error_req" + path: "../testdata/plugins/model/error_req.so" + weight: 1 plugintype: "Everything" mode: sync ` @@ -31,341 +42,489 @@ var testPlugin = ` - id: "test" path: "../testdata/plugins/decision/test.so" wafweight: 0.5 decisionbalance: 0.5 +` + +var simplePlugin = ` - id: "simple" + path: "../testdata/plugins/decision/simple.so" + decisionbalance: 0.5 +` + +// Model plugins that fail to load (silently dropped by pluginmanager) +var noInitPlugin = ` - id: "no_init" + path: "../testdata/plugins/model/no_init.so" + weight: 1 + plugintype: "Everything" + mode: sync +` + +var wrongInitPlugin = ` - id: "wrong_init" + path: "../testdata/plugins/model/wrong_init.so" + weight: 1 + plugintype: "Everything" + mode: sync +` + +var errorInitPlugin = ` - id: "error_init" + path: "../testdata/plugins/model/error_init.so" + weight: 1 + plugintype: "Everything" + mode: sync +` + +var noReqPlugin = ` - id: "no_req" + path: "../testdata/plugins/model/no_req.so" + weight: 1 + plugintype: "Everything" + mode: sync +` + +var wrongReqPlugin = ` - id: "wrong_req" + path: "../testdata/plugins/model/wrong_req.so" + weight: 1 + plugintype: "Everything" + mode: sync +` + +// paramPlugin returns whatever float64 is stored in params["result"]. +// Its ReloadPlugin updates that value, so the output changes after a Reload. +var paramPlugin = ` - id: "param" + path: "../testdata/plugins/model/param.so" + weight: 1 + plugintype: "Everything" + mode: sync params: - test1: "test" - test2: "testtest" - test3: "testtesttest" + result: "0.3" ` +// Decision plugins that fail to load (wrong InitPlugin signature) +var noCheckPlugin = ` - id: "no_check" + path: "../testdata/plugins/decision/no_check.so" + decisionbalance: 0.5 +` + +var wrongCheckPlugin = ` - id: "wrong_check" + path: "../testdata/plugins/decision/wrong_check.so" + decisionbalance: 0.5 +` + +var provider = metric.NewMeterProvider() +var testMeter = provider.Meter("pluginmanager-test-meter") + +func init() { + rand.Seed(time.Now().UnixNano()) + logger := logging.Get() + if err := logger.LoadLogger("/dev/null", logging.ERROR); err != nil { + panic("Error loading logger: " + err.Error()) + } +} + func generateRandomID() string { letters := "1234567890ABCDEF" - id := "" - for i := 0; i < 16; i++ { - id += string(letters[rand.Intn(len(letters))]) + id := make([]byte, 16) + for i := range id { + id[i] = letters[rand.Intn(len(letters))] } - - return id + return string(id) } -var provider = metric.NewMeterProvider() -var testMeter = provider.Meter("example-meter") +// setupPluginManager creates a fresh ConfigStore from the given YAML config, +// returns an initialised PluginManager, and registers configstore.Clean as a +// test cleanup function. +func setupPluginManager(t *testing.T, configuration []byte) *pluginmanager.PluginManager { + t.Helper() + configstore.Clean() + cs, err := configstore.New() + if err != nil { + t.Fatalf("configstore.New() failed: %v", err) + } + t.Cleanup(configstore.Clean) -func initilize(configuration []byte) error { var aux configstore.ConfigFileData - err := yaml.Unmarshal(configuration, &aux) - if err != nil { - return err + if err := yaml.Unmarshal(configuration, &aux); err != nil { + t.Fatalf("yaml.Unmarshal failed: %v", err) } - cs, err := configstore.Get() - if err != nil { - return err + if err := cs.SetConfig(aux); err != nil { + t.Fatalf("SetConfig failed: %v", err) } - err = cs.SetConfig(aux) + logger := logging.Get() + if err := logger.LoadLogger(cs.LogPath, cs.LogLevel); err != nil { + t.Fatalf("LoadLogger failed: %v", err) + } + pm, err := pluginmanager.New(testMeter) if err != nil { - return err + t.Fatalf("pluginmanager.New() failed: %v", err) } - logger := logging.Get() + return pm +} + +func TestPluginManagerNew(t *testing.T) { + config := []byte(baseConfig + "modelplugins:\n" + trivialPlugin + "decisionplugins:\n" + testPlugin) + pm := setupPluginManager(t, config) + if pm == nil { + t.Fatal("New() returned nil plugin manager") + } +} + +func TestPluginManagerProcessSync(t *testing.T) { + tests := []struct { + name string + modelConf string + modelID string + wantProb float64 + wantErr bool + }{ + { + name: "trivial returns zero probability", + modelConf: trivialPlugin, + modelID: "trivial", + wantProb: 0.0, + }, + { + name: "trivial2 returns full attack probability", + modelConf: trivial2Plugin, + modelID: "trivial2", + wantProb: 1.0, + }, + { + name: "error_req plugin reports error via channel", + modelConf: errorReqPlugin, + modelID: "error_req", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := []byte(baseConfig + "modelplugins:\n" + tt.modelConf) + pm := setupPluginManager(t, config) + + txID := generateRandomID() + pm.InitTransaction(txID) + defer pm.CloseTransaction(txID) + + ch := make(chan pluginmanager.ModelStatus, 1) + go pm.Process(tt.modelID, txID, pluginmanager.HTTPPayload{URI: "/test"}, configstore.Everything, ch) + status := <-ch + + if (status.Err != nil) != tt.wantErr { + if tt.wantErr { + t.Errorf("Process(%q) expected error but got none", tt.modelID) + } else { + t.Errorf("Process(%q) unexpected error: %v", tt.modelID, status.Err) + } + } + if !tt.wantErr && status.ProbAttack != tt.wantProb { + t.Errorf("Process(%q) ProbAttack = %f, want %f", tt.modelID, status.ProbAttack, tt.wantProb) + } + }) + } +} + +func TestPluginManagerProcessNonexistentPlugin(t *testing.T) { + config := []byte(baseConfig + "modelplugins:\n" + trivialPlugin) + pm := setupPluginManager(t, config) + + txID := generateRandomID() + pm.InitTransaction(txID) + defer pm.CloseTransaction(txID) + + ch := make(chan pluginmanager.ModelStatus, 1) + go pm.Process("nonexistent", txID, pluginmanager.HTTPPayload{}, configstore.Everything, ch) + status := <-ch + if status.Err == nil { + t.Error("Process with nonexistent plugin ID should return error via channel") + } +} + +func TestPluginManagerCheckResult(t *testing.T) { + tests := []struct { + name string + modelConf string + modelID string + wafParams map[string]string + wantBlock bool + }{ + { + name: "trivial (prob=0.0) does not block even with alerting WAF", + modelConf: trivialPlugin, + modelID: "trivial", + wafParams: map[string]string{"inbound_blocking": "20", "inbound_threshold": "5"}, + wantBlock: false, + }, + { + name: "trivial2 (prob=1.0) blocks when WAF also alerts", + modelConf: trivial2Plugin, + modelID: "trivial2", + wafParams: map[string]string{"inbound_blocking": "20", "inbound_threshold": "5"}, + wantBlock: true, + }, + { + name: "trivial2 does not block with empty WAF data", + modelConf: trivial2Plugin, + modelID: "trivial2", + wafParams: make(map[string]string), + wantBlock: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := []byte(baseConfig + "modelplugins:\n" + tt.modelConf + "decisionplugins:\n" + simplePlugin) + pm := setupPluginManager(t, config) + + txID := generateRandomID() + pm.InitTransaction(txID) + defer pm.CloseTransaction(txID) + + ch := make(chan pluginmanager.ModelStatus, 1) + go pm.Process(tt.modelID, txID, pluginmanager.HTTPPayload{URI: "/test"}, configstore.Everything, ch) + <-ch + + result, err := pm.CheckResult(txID, "simple", tt.wafParams) + if err != nil { + t.Fatalf("CheckResult error: %v", err) + } + if result != tt.wantBlock { + t.Errorf("CheckResult = %v, want %v", result, tt.wantBlock) + } + }) + } +} + +func TestPluginManagerCheckResultNonexistentDecision(t *testing.T) { + config := []byte(baseConfig + "modelplugins:\n" + trivialPlugin + "decisionplugins:\n" + testPlugin) + pm := setupPluginManager(t, config) + + txID := generateRandomID() + pm.InitTransaction(txID) + defer pm.CloseTransaction(txID) - err = logger.LoadLogger(cs.LogPath, cs.LogLevel) + _, err := pm.CheckResult(txID, "nonexistent", make(map[string]string)) + if err == nil { + t.Error("CheckResult with nonexistent decision plugin should return error") + } +} + +func TestPluginManagerTransactionLifecycle(t *testing.T) { + config := []byte(baseConfig + "modelplugins:\n" + trivialPlugin + "decisionplugins:\n" + testPlugin) + pm := setupPluginManager(t, config) + + txID := generateRandomID() + pm.InitTransaction(txID) + + ch := make(chan pluginmanager.ModelStatus, 1) + go pm.Process("trivial", txID, pluginmanager.HTTPPayload{URI: "/test"}, configstore.Everything, ch) + status := <-ch + if status.Err != nil { + t.Fatalf("Process error: %v", status.Err) + } + + // test plugin blocks when anomalyscore >= inboundthreshold + result, err := pm.CheckResult(txID, "test", map[string]string{"anomalyscore": "100", "inboundthreshold": "10"}) if err != nil { - return err + t.Fatalf("CheckResult error: %v", err) + } + if !result { + t.Error("expected transaction to be blocked (anomalyscore 100 >= inboundthreshold 10)") + } + pm.CloseTransaction(txID) +} + +// TestPluginManagerLoadModelFailures checks that New() succeeds even when model +// plugins fail to load (due to missing/wrong Init or Process symbols), and that +// those plugins are not available for processing. +func TestPluginManagerLoadModelFailures(t *testing.T) { + tests := []struct { + name string + modelConf string + modelID string + }{ + {"no InitPlugin symbol", noInitPlugin, "no_init"}, + {"wrong InitPlugin signature", wrongInitPlugin, "wrong_init"}, + {"InitPlugin returns error", errorInitPlugin, "error_init"}, + {"no Process symbol", noReqPlugin, "no_req"}, + {"wrong Process signature", wrongReqPlugin, "wrong_req"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := []byte(baseConfig + "modelplugins:\n" + tt.modelConf) + pm := setupPluginManager(t, config) + if pm == nil { + t.Fatal("New() returned nil — expected success even with bad plugin") + } + + // The bad plugin must have been dropped: Process should return an error. + txID := generateRandomID() + pm.InitTransaction(txID) + defer pm.CloseTransaction(txID) + + ch := make(chan pluginmanager.ModelStatus, 1) + go pm.Process(tt.modelID, txID, pluginmanager.HTTPPayload{}, configstore.Everything, ch) + status := <-ch + if status.Err == nil { + t.Errorf("Process(%q): expected error (plugin should not have been loaded)", tt.modelID) + } + }) } - return nil } -func init() { - rand.Seed(time.Now().UnixNano()) +// TestPluginManagerLoadDecisionFailures checks that New() succeeds even when +// decision plugins fail to load (wrong InitPlugin signature), and that those +// plugins are not available for CheckResult. +func TestPluginManagerLoadDecisionFailures(t *testing.T) { + tests := []struct { + name string + decConf string + decisionID string + }{ + {"no CheckResults symbol", noCheckPlugin, "no_check"}, + {"wrong CheckResults signature", wrongCheckPlugin, "wrong_check"}, + } - logger := logging.Get() - err := logger.LoadLogger("/dev/null", logging.ERROR) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := []byte(baseConfig + "modelplugins:\n" + trivialPlugin + "decisionplugins:\n" + tt.decConf) + pm := setupPluginManager(t, config) + if pm == nil { + t.Fatal("New() returned nil — expected success even with bad decision plugin") + } + + txID := generateRandomID() + pm.InitTransaction(txID) + defer pm.CloseTransaction(txID) + + _, err := pm.CheckResult(txID, tt.decisionID, make(map[string]string)) + if err == nil { + t.Errorf("CheckResult(%q): expected error (plugin should not have been loaded)", tt.decisionID) + } + }) + } +} + +// TestPluginManagerReload exercises the already-loaded-plugin branch in +// loadModelPlugins / loadDecisionPlugins (the `found == true` path). +func TestPluginManagerReload(t *testing.T) { + config := []byte(baseConfig + "modelplugins:\n" + trivialPlugin + "decisionplugins:\n" + simplePlugin) + pm := setupPluginManager(t, config) + + if err := pm.Reload(testMeter); err != nil { + t.Fatalf("Reload() returned error: %v", err) + } + + // Plugin must still be functional after reload. + txID := generateRandomID() + pm.InitTransaction(txID) + defer pm.CloseTransaction(txID) + + ch := make(chan pluginmanager.ModelStatus, 1) + go pm.Process("trivial", txID, pluginmanager.HTTPPayload{URI: "/test"}, configstore.Everything, ch) + status := <-ch + if status.Err != nil { + t.Errorf("Process after Reload: unexpected error: %v", status.Err) + } + if status.ProbAttack != 0.0 { + t.Errorf("Process after Reload: ProbAttack = %f, want 0.0", status.ProbAttack) + } +} + +// TestPluginManagerProcessTypeMismatch verifies that Process sends an error +// when the plugin's registered type does not match the requested type. +func TestPluginManagerProcessTypeMismatch(t *testing.T) { + // Configure trivial as RequestHeaders type. + conf := baseConfig + `modelplugins: + - id: "trivial" + path: "../testdata/plugins/model/trivial.so" + weight: 1 + plugintype: "RequestHeaders" + mode: sync +` + pm := setupPluginManager(t, []byte(conf)) + + txID := generateRandomID() + pm.InitTransaction(txID) + defer pm.CloseTransaction(txID) + + // Pass Everything — does not match the registered RequestHeaders type. + ch := make(chan pluginmanager.ModelStatus, 1) + go pm.Process("trivial", txID, pluginmanager.HTTPPayload{URI: "/test"}, configstore.Everything, ch) + status := <-ch + if status.Err == nil { + t.Error("Process with mismatched plugin type should return error via channel") + } +} + +// TestPluginManagerReloadChangesOutput verifies that after Reload the plugin +// picks up new params and returns a different ProbAttack value. +// param.so is configured with result=0.3; after updating the configstore to +// result=0.8 and calling Reload, Process must return 0.8. +func TestPluginManagerReloadChangesOutput(t *testing.T) { + config := []byte(baseConfig + "modelplugins:\n" + paramPlugin) + pm := setupPluginManager(t, config) + + runProcess := func(wantProb float64) { + t.Helper() + txID := generateRandomID() + pm.InitTransaction(txID) + defer pm.CloseTransaction(txID) + + ch := make(chan pluginmanager.ModelStatus, 1) + go pm.Process("param", txID, pluginmanager.HTTPPayload{URI: "/test"}, configstore.Everything, ch) + status := <-ch + if status.Err != nil { + t.Errorf("Process: unexpected error: %v", status.Err) + return + } + if status.ProbAttack != wantProb { + t.Errorf("ProbAttack = %f, want %f", status.ProbAttack, wantProb) + } + } + + runProcess(0.3) + + // Update configstore so Reload picks up the new params. + updatedConfig := baseConfig + `modelplugins: + - id: "param" + path: "../testdata/plugins/model/param.so" + weight: 1 + plugintype: "Everything" + mode: sync + params: + result: "0.8" +` + cs, err := configstore.Get() if err != nil { - panic("Error loading logger") + t.Fatalf("configstore.Get: %v", err) + } + var aux configstore.ConfigFileData + if err := yaml.Unmarshal([]byte(updatedConfig), &aux); err != nil { + t.Fatalf("yaml.Unmarshal: %v", err) + } + if err := cs.SetConfig(aux); err != nil { + t.Fatalf("SetConfig with updated params: %v", err) } + + if err := pm.Reload(testMeter); err != nil { + t.Fatalf("Reload: %v", err) + } + + runProcess(0.8) } -// func TestPluginInit(t *testing.T) { -// cases := []struct{ id, conf string }{ -// // {"invalid_path", ` - id: "invalid_path" -// // path: "../testdata/plugins/model/nonexistent.so" -// // plugintype: "AllRequest" -// // `}, -// {"no_init", ` - id: "no_init" -// path: "../testdata/plugins/model/no_init.so" -// plugintype: "AllRequest" -// `}, -// {"wrong_init", ` - id: "wrong_init" -// path: "../testdata/plugins/model/wrong_init.so" -// plugintype: "AllRequest" -// `}, -// {"error_init", ` - id: "error_init" -// path: "../testdata/plugins/model/error_init.so" -// plugintype: "AllRequest" -// `}, -// } - -// // Test model plugin initialization -// for _, c := range cases { -// config := baseConfig + "modelplugins:\n" + trivialPlugin + c.conf - -// err := initilize([]byte(config)) -// if err != nil { -// t.Errorf("Error loading config: %v", err) -// } -// plugins := New(testMeter) -// if _, exists := plugins.modelPlugins["trivial"]; !exists { -// t.Errorf("trivial plugin not loaded") -// } -// if _, exists := plugins.modelPlugins[c.id]; exists { -// t.Errorf(c.id + " should not load") -// } -// } - -// // Test decision plugin initialization -// for _, c := range cases { -// config := baseConfig + "modelplugins:\n" + trivialPlugin + "decisionplugins:\n" + testPlugin + c.conf - -// err := initilize([]byte(config)) -// if err != nil { -// t.Errorf("Error loading config: %v", err) -// } -// plugins := New(testMeter) -// if _, exists := plugins.decisionPlugins["test"]; !exists { -// t.Errorf("test plugin not loaded") -// } -// if _, exists := plugins.decisionPlugins[c.id]; exists { -// t.Errorf(c.id + " should not load") -// } -// } - -// } - -// func TestPluginParams(t *testing.T) { -// config := baseConfig + "modelplugins:\n" + trivialPlugin + "decisionplugins:\n" + testPlugin - -// err := initilize([]byte(config)) -// if err != nil { -// t.Errorf("Error loading config: %v", err) -// } - -// var buf bytes.Buffer -// logger := lg.Get() -// err = logger.LoadLoggerWriter(&buf, lg.INFO) -// if err != nil { -// t.Errorf("Error loading logger: %v", err) -// } - -// plugins := New(testMeter) - -// if !strings.Contains(buf.String(), "[trivial:InitPlugin] map[param1:first value param2:second value param3:third value]") { -// t.Errorf("trivial plugin did not initialize correctly, got: %v, expected: %v", buf.String(), "[trivial:InitPlugin] map[param1:first value param2:second value param3:third value]") -// } -// if !strings.Contains(buf.String(), "[test:InitPlugin] map[test1:test test2:testtest test3:testtesttest]") { -// t.Errorf("test plugin did not initialize correctly") -// } - -// transactionID := generateRandomID() -// modelPlugStatus := make(chan ModelStatus) -// go plugins.Process("trivial", transactionID, "test request1", cf.AllRequest, modelPlugStatus) -// <-modelPlugStatus -// if !strings.Contains(buf.String(), "[trivial:ProcessRequest] \"test request1\"") { -// t.Errorf("trivial plugin did not analyze request") -// } - -// go plugins.Process("trivial", transactionID, "test response1", cf.AllResponse, modelPlugStatus) -// <-modelPlugStatus -// if !strings.Contains(buf.String(), "[trivial:ProcessResponse] \"test response1\"") { -// t.Errorf("trivial plugin did not analyze response") -// } - -// _, err = plugins.CheckResult(transactionID, "test", map[string]string{"anomalyscore": "100", "inboundthreshold": "10"}) -// if err != nil { -// t.Errorf("Error checking result: %v", err) -// } -// if !strings.Contains(buf.String(), "[test:CheckResults]") { -// t.Errorf("test plugin did not execute correctly") -// } -// if !strings.Contains(buf.String(), "modelRes: map[trivial:") { -// t.Errorf("trivial result is not stored in modelRes") -// } -// if !strings.Contains(buf.String(), "modelWeight: map[trivial:1]") { -// t.Errorf("trivial weight is not stored in modelWeight") -// } -// if !strings.Contains(buf.String(), "modelThres: map[trivial:0.5]") { -// t.Errorf("trivial threshold is not stored in modelWeight") -// } -// if !strings.Contains(buf.String(), "wafData: map[anomalyscore:100 inboundthreshold:10]") { -// t.Errorf("waf params are not stored in wafData") -// } -// } - -// func TestPluginType(t *testing.T) { -// cases := []struct { -// id string -// pluginType, requestType cf.ModelPluginType -// executes bool -// }{ -// {"req_headers-req_headers", cf.RequestHeaders, cf.RequestHeaders, true}, -// {"req_headers-resp_headers", cf.RequestHeaders, cf.ResponseHeaders, false}, -// {"req_headers-all_req", cf.RequestHeaders, cf.AllRequest, false}, -// {"all_req-req_headers", cf.AllRequest, cf.RequestHeaders, false}, -// {"all_req-all_resp", cf.AllRequest, cf.AllResponse, false}, - -// {"resp_headers-resp_headers", cf.ResponseHeaders, cf.ResponseHeaders, true}, -// {"resp_headers-req_headers", cf.ResponseHeaders, cf.RequestHeaders, false}, -// {"resp_headers-all_resp", cf.ResponseHeaders, cf.AllResponse, false}, -// {"all_resp-resp_headers", cf.AllResponse, cf.ResponseHeaders, false}, -// {"all_resp-all_req", cf.AllResponse, cf.AllRequest, false}, - -// {"everything-req_headers", cf.Everything, cf.RequestHeaders, true}, -// {"everything-all_req", cf.Everything, cf.AllRequest, true}, -// {"everything-resp_body", cf.Everything, cf.ResponseBody, true}, -// {"everything-all_resp", cf.Everything, cf.AllResponse, true}, -// } - -// for _, c := range cases { -// config := baseConfig + "modelplugins:\n" + -// " - id: \"" + c.id + "\"\n" + -// " path: \"../testdata/plugins/model/trivial.so\"\n" + -// " plugintype: \"" + c.pluginType.String() + "\"\n" - -// err := initilize([]byte(config)) -// if err != nil { -// t.Errorf("Error loading config: %v", err) -// } - -// old := log.Writer() -// var buf bytes.Buffer -// log.SetOutput(&buf) -// defer log.SetOutput(old) - -// plugins := New(testMeter) - -// transactionID := generateRandomID() -// modelPlugStatus := make(chan ModelStatus) -// switch c.requestType { -// case cf.RequestHeaders, cf.RequestBody, cf.AllRequest: -// go plugins.Process(c.id, transactionID, "test request", c.requestType, modelPlugStatus) -// <-modelPlugStatus -// if strings.Contains(buf.String(), "[trivial:ProcessRequest] \"test request\"") != c.executes { -// t.Errorf("case %s: expected to run trivial plugin: %v", c.id, c.executes) -// } -// if _, exists := plugins.results.Load(transactionID); exists != c.executes { -// t.Errorf("case %s: expected to store results: %v", c.id, c.executes) -// } -// case cf.ResponseHeaders, cf.ResponseBody, cf.AllResponse: -// go plugins.Process(c.id, transactionID, "test response", c.requestType, modelPlugStatus) -// <-modelPlugStatus -// if strings.Contains(buf.String(), "[trivial:ProcessResponse] \"test response\"") != c.executes { -// t.Errorf("case %s: expected to run trivial plugin: %v", c.id, c.executes) -// } -// if _, exists := plugins.results.Load(transactionID); exists != c.executes { -// t.Errorf("case %s: expected to store results: %v", c.id, c.executes) -// } -// } - -// } -// } - -// func TestProcessRequestInvalid(t *testing.T) { -// cases := []struct{ id, conf string }{ -// {"no_req", ` - id: "no_req" -// path: "../testdata/plugins/model/no_req.so" -// plugintype: "Everything" -// `}, -// {"wrong_req", ` - id: "wrong_req" -// path: "../testdata/plugins/model/wrong_req.so" -// plugintype: "Everything" -// `}, -// {"error_req", ` - id: "error_req" -// path: "../testdata/plugins/model/error_req.so" -// plugintype: "Everything" -// `}, -// } - -// // Test model plugin initialization -// for _, c := range cases { -// config := baseConfig + "modelplugins:\n" + trivialPlugin + c.conf - -// err := initilize([]byte(config)) -// if err != nil { -// t.Errorf("Error loading config: %v", err) -// } -// plugins := New(testMeter) - -// transactionID := generateRandomID() -// modelPlugStatus := make(chan ModelStatus) -// go plugins.Process(c.id, transactionID, "test request", cf.AllRequest, modelPlugStatus) -// <-modelPlugStatus -// go plugins.Process(c.id, transactionID, "test response", cf.AllResponse, modelPlugStatus) -// <-modelPlugStatus - -// if _, exists := plugins.results.Load(transactionID); exists { -// t.Errorf("invalid test %s stored a result", c.id) -// } -// } - -// config := baseConfig + "modelplugins:\n" + trivialPlugin - -// err := initilize([]byte(config)) -// if err != nil { -// t.Errorf("Error loading config: %v", err) -// } -// plugins := New(testMeter) - -// transactionID := generateRandomID() -// modelPlugStatus := make(chan ModelStatus) -// go plugins.Process("nonexistent", transactionID, "test request", cf.AllRequest, modelPlugStatus) -// <-modelPlugStatus -// go plugins.Process("nonexistent", transactionID, "test response", cf.AllResponse, modelPlugStatus) -// <-modelPlugStatus - -// if _, exists := plugins.results.Load(transactionID); exists { -// t.Errorf("nonexistent test stored a result") -// } - -// } - -// func TestCheckResultInvalid(t *testing.T) { -// cases := []struct{ id, conf string }{ -// {"no_check", ` - id: "no_check" -// path: "../testdata/plugins/decision/no_check.so" -// `}, -// {"wrong_check", ` - id: "wrong_check" -// path: "../testdata/plugins/decision/wrong_check.so" -// `}, -// {"error_check", ` - id: "error_check" -// path: "../testdata/plugins/decision/error_check.so" -// `}, -// } - -// // Test model plugin initialization -// for _, c := range cases { -// config := baseConfig + "modelplugins:\n" + trivialPlugin + "decisionplugins:\n" + c.conf - -// err := initilize([]byte(config)) -// if err != nil { -// t.Errorf("Error loading config: %v", err) -// } -// plugins := New(testMeter) - -// _, err = plugins.CheckResult(generateRandomID(), c.id, make(map[string]string)) -// if err == nil { -// t.Errorf("invalid CheckResult %s did not rise an error", c.id) -// } -// } - -// config := baseConfig + "modelplugins:\n" + trivialPlugin + "decisionplugins:\n" + testPlugin - -// err := initilize([]byte(config)) -// if err != nil { -// t.Errorf("Error loading config: %v", err) -// } -// plugins := New(testMeter) - -// _, err = plugins.CheckResult(generateRandomID(), "nonexitent", make(map[string]string)) -// if err == nil { -// t.Errorf("nonexistent plugin did not rise an error") -// } - -// } +// TestPluginManagerProcessWithoutTransaction verifies that Process sends an +// error when the transaction was never initialised (results map is absent). +func TestPluginManagerProcessWithoutTransaction(t *testing.T) { + config := []byte(baseConfig + "modelplugins:\n" + trivialPlugin) + pm := setupPluginManager(t, config) + + // Deliberately skip pm.InitTransaction so there is no results entry. + txID := generateRandomID() + + ch := make(chan pluginmanager.ModelStatus, 1) + go pm.Process("trivial", txID, pluginmanager.HTTPPayload{URI: "/test"}, configstore.Everything, ch) + status := <-ch + if status.Err == nil { + t.Error("Process without InitTransaction should return error via channel") + } +} diff --git a/testdata/plugins/Makefile b/testdata/plugins/Makefile index 2943c3e..25ada92 100644 --- a/testdata/plugins/Makefile +++ b/testdata/plugins/Makefile @@ -1,10 +1,18 @@ all: model/trivial.so model/trivial2.so \ model/no_init.so model/wrong_init.so model/error_init.so \ model/no_req.so model/wrong_req.so model/error_req.so \ + model/param.so \ decision/test.so \ decision/no_check.so decision/wrong_check.so decision/error_check.so \ decision/simple.so +# Pass COVER=1 to build plugins instrumented for coverage: +# make -B COVER=1 +# This is required when running: go test -cover ./... +ifdef COVER +FLAGS += -cover +endif + %.so: %.go go build $(FLAGS) -buildmode=plugin -o $@ $< diff --git a/testdata/plugins/decision/error_check.go b/testdata/plugins/decision/error_check.go index dfa3e5e..9c7e198 100644 --- a/testdata/plugins/decision/error_check.go +++ b/testdata/plugins/decision/error_check.go @@ -7,6 +7,7 @@ import ( "errors" lg "github.com/tilsor/ModSecIntl_logging/logging" + "go.opentelemetry.io/otel/metric" ) // InitPlugin intitalizes the plugins (does nothing in this case) @@ -21,3 +22,8 @@ func InitPlugin(params map[string]string) error { func CheckResults(transactionID string, modelRes map[string]float64, modelWeight map[string]float64, modelThres map[string]float64, wafData map[string]string) (bool, error) { return false, errors.New("Some error") } + +// ReloadPlugin reload the plugin +func ReloadPlugin(params map[string]string, meter metric.Meter) error { + return nil +} diff --git a/testdata/plugins/decision/no_check.go b/testdata/plugins/decision/no_check.go index 13d8516..fc88e1e 100644 --- a/testdata/plugins/decision/no_check.go +++ b/testdata/plugins/decision/no_check.go @@ -5,6 +5,7 @@ package main import ( lg "github.com/tilsor/ModSecIntl_logging/logging" + "go.opentelemetry.io/otel/metric" ) // InitPlugin intitalizes the plugins (does nothing in this case) @@ -13,3 +14,8 @@ func InitPlugin(params map[string]string) error { logger.Printf(lg.WARN, "[simple:InitPlugin] %v\n", params) return nil } + +// ReloadPlugin reload the plugin +func ReloadPlugin(params map[string]string, meter metric.Meter) error { + return nil +} diff --git a/testdata/plugins/decision/simple.go b/testdata/plugins/decision/simple.go index dc0e1f2..9143158 100644 --- a/testdata/plugins/decision/simple.go +++ b/testdata/plugins/decision/simple.go @@ -90,3 +90,8 @@ func CheckResults(decisionInput pm.DecisionInput) (bool, error) { // } // return false, nil // } + +// ReloadPlugin reload the plugin +func ReloadPlugin(params map[string]string, meter metric.Meter) error { + return nil +} diff --git a/testdata/plugins/decision/test.go b/testdata/plugins/decision/test.go index daaf564..350b7cc 100644 --- a/testdata/plugins/decision/test.go +++ b/testdata/plugins/decision/test.go @@ -47,3 +47,8 @@ func CheckResults(decisionInput pm.DecisionInput) (bool, error) { } return false, nil } + +// ReloadPlugin reload the plugin +func ReloadPlugin(params map[string]string, meter metric.Meter) error { + return nil +} diff --git a/testdata/plugins/decision/weighted_sum.go b/testdata/plugins/decision/weighted_sum.go index a8ea7a5..c987aa8 100644 --- a/testdata/plugins/decision/weighted_sum.go +++ b/testdata/plugins/decision/weighted_sum.go @@ -88,3 +88,8 @@ func CheckResults(decisionInput pm.DecisionInput) (bool, error) { logger.TPrintf(lg.DEBUG, decisionInput.TransactionId, "weighted_sum | weighted sum: %v threshold: %v", weightedSum, threshold) return weightedSum > threshold, nil } + +// ReloadPlugin reload the plugin +func ReloadPlugin(params map[string]string, meter metric.Meter) error { + return nil +} diff --git a/testdata/plugins/decision/wrong_check.go b/testdata/plugins/decision/wrong_check.go index d5f897b..461a344 100644 --- a/testdata/plugins/decision/wrong_check.go +++ b/testdata/plugins/decision/wrong_check.go @@ -5,6 +5,7 @@ package main import ( lg "github.com/tilsor/ModSecIntl_logging/logging" + "go.opentelemetry.io/otel/metric" ) // InitPlugin intitalizes the plugins (does nothing in this case) @@ -19,3 +20,8 @@ func InitPlugin(params map[string]string) error { func CheckResults() (bool, error) { return false, nil } + +// ReloadPlugin reload the plugin +func ReloadPlugin(params map[string]string, meter metric.Meter) error { + return nil +} diff --git a/testdata/plugins/model/error_init.go b/testdata/plugins/model/error_init.go index a10211d..d4645e4 100644 --- a/testdata/plugins/model/error_init.go +++ b/testdata/plugins/model/error_init.go @@ -23,3 +23,8 @@ func Process(input pm.ModelInput) (pm.ModelResults, error) { } return result, errors.New("Some error") } + +// ReloadPlugin reload the plugin +func ReloadPlugin(params map[string]string, meter metric.Meter) error { + return nil +} diff --git a/testdata/plugins/model/error_req.go b/testdata/plugins/model/error_req.go index c3a0ffd..ef77472 100644 --- a/testdata/plugins/model/error_req.go +++ b/testdata/plugins/model/error_req.go @@ -23,3 +23,8 @@ func Process(input pm.ModelInput) (pm.ModelResults, error) { } return result, errors.New("Some error") } + +// ReloadPlugin reload the plugin +func ReloadPlugin(params map[string]string, meter metric.Meter) error { + return nil +} diff --git a/testdata/plugins/model/no_init.go b/testdata/plugins/model/no_init.go index b3c7c35..9ca99dd 100644 --- a/testdata/plugins/model/no_init.go +++ b/testdata/plugins/model/no_init.go @@ -3,7 +3,10 @@ package main -import pm "github.com/tilsor/ModSecIntl_wace_lib/pluginmanager" +import ( + pm "github.com/tilsor/ModSecIntl_wace_lib/pluginmanager" + "go.opentelemetry.io/otel/metric" +) // Process always returns 0 probability of attack func Process(input pm.ModelInput) (pm.ModelResults, error) { @@ -13,3 +16,8 @@ func Process(input pm.ModelInput) (pm.ModelResults, error) { } return result, nil } + +// ReloadPlugin reload the plugin +func ReloadPlugin(params map[string]string, meter metric.Meter) error { + return nil +} diff --git a/testdata/plugins/model/no_req.go b/testdata/plugins/model/no_req.go index 86953f7..554b7f7 100644 --- a/testdata/plugins/model/no_req.go +++ b/testdata/plugins/model/no_req.go @@ -8,3 +8,8 @@ import "go.opentelemetry.io/otel/metric" func InitPlugin(params map[string]string, meter metric.Meter) error { return nil } + +// ReloadPlugin reload the plugin +func ReloadPlugin(params map[string]string, meter metric.Meter) error { + return nil +} diff --git a/testdata/plugins/model/param.go b/testdata/plugins/model/param.go new file mode 100644 index 0000000..b8871fa --- /dev/null +++ b/testdata/plugins/model/param.go @@ -0,0 +1,67 @@ +package main + +import ( + "context" + "fmt" + "strconv" + + lg "github.com/tilsor/ModSecIntl_logging/logging" + pm "github.com/tilsor/ModSecIntl_wace_lib/pluginmanager" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" +) + +var result float64 + +// InitPlugin reads the "result" param and sets the probability that Process will return. +func InitPlugin(params map[string]string, meter metric.Meter) error { + logger := lg.Get() + logger.Printf(lg.WARN, "[param:InitPlugin] %v\n", params) + resultString, ok := params["result"] + if !ok { + return fmt.Errorf("result parameter not found") + } + var err error + result, err = strconv.ParseFloat(resultString, 64) + if err != nil { + return fmt.Errorf("error parsing result parameter: %v", err) + } + ctx := context.Background() + pluginCounter, err := meter.Int64Counter("plugin_register") + if err != nil { + return err + } + pluginCounter.Add(ctx, 1, metric.WithAttributes(attribute.String("plugin_name", "param"), attribute.String("plugin_type", "model"))) + return nil +} + +func InitPluginAsync(params map[string]string, meter metric.Meter, natsManager func(func(pm.ModelInput) (pm.ModelResults, error))) error { + InitPlugin(params, meter) + natsManager(Process) + return nil +} + +func Process(input pm.ModelInput) (pm.ModelResults, error) { + logger := lg.Get() + logger.TPrintf(lg.WARN, input.TransactionId, "[param:Process] \"%v\"\n", input.Payload) + return pm.ModelResults{ + ProbAttack: result, + Data: make(map[string]interface{}), + }, nil +} + +// ReloadPlugin updates the probability returned by Process from the new params. +func ReloadPlugin(params map[string]string, meter metric.Meter) error { + logger := lg.Get() + logger.Printf(lg.WARN, "[param:ReloadPlugin] %v\n", params) + resultString, ok := params["result"] + if !ok { + return fmt.Errorf("result parameter not found") + } + var err error + result, err = strconv.ParseFloat(resultString, 64) + if err != nil { + return fmt.Errorf("error parsing result parameter: %v", err) + } + return nil +} diff --git a/testdata/plugins/model/trivial.go b/testdata/plugins/model/trivial.go index 046ff50..9e1315d 100644 --- a/testdata/plugins/model/trivial.go +++ b/testdata/plugins/model/trivial.go @@ -41,3 +41,10 @@ func Process(input pm.ModelInput) (pm.ModelResults, error) { } return result, nil } + +// ReloadPlugin reload the plugin (does nothing in this case) +func ReloadPlugin(params map[string]string, meter metric.Meter) error { + logger := lg.Get() + logger.Printf(lg.WARN, "[trivial:ReloadPlugin] %v\n", params) + return nil +} diff --git a/testdata/plugins/model/trivial2.go b/testdata/plugins/model/trivial2.go index 7b9ed66..c60470b 100644 --- a/testdata/plugins/model/trivial2.go +++ b/testdata/plugins/model/trivial2.go @@ -41,3 +41,10 @@ func Process(input pm.ModelInput) (pm.ModelResults, error) { } return result, nil } + +// ReloadPlugin reload the plugin +func ReloadPlugin(params map[string]string, meter metric.Meter) error { + logger := lg.Get() + logger.Printf(lg.WARN, "[trivial2:ReloadPlugin] %v\n", params) + return nil +} diff --git a/testdata/plugins/model/trivial_async.go b/testdata/plugins/model/trivial_async.go index 9e38210..b473ca4 100644 --- a/testdata/plugins/model/trivial_async.go +++ b/testdata/plugins/model/trivial_async.go @@ -56,3 +56,7 @@ func Process(input pm.ModelInput) (pm.ModelResults, error) { } return result, nil } + +// ReloadPlugin reload the pluginfunc ReloadPlugin(params map[string]string, meter metric.Meter) error { + return nil +} diff --git a/testdata/plugins/model/trivial_async2.go b/testdata/plugins/model/trivial_async2.go index 79a0810..1d4eecc 100644 --- a/testdata/plugins/model/trivial_async2.go +++ b/testdata/plugins/model/trivial_async2.go @@ -56,3 +56,8 @@ func Process(input pm.ModelInput) (pm.ModelResults, error) { } return result, nil } + +// ReloadPlugin reload the plugin +func ReloadPlugin(params map[string]string, meter metric.Meter) error { + return nil +} diff --git a/testdata/plugins/model/wrong_init.go b/testdata/plugins/model/wrong_init.go index 51937b7..52596a8 100644 --- a/testdata/plugins/model/wrong_init.go +++ b/testdata/plugins/model/wrong_init.go @@ -3,6 +3,8 @@ package main +import "go.opentelemetry.io/otel/metric" + // InitPlugin intitalizes the plugins (does nothing in this case) func InitPlugin() error { return nil @@ -12,3 +14,8 @@ func InitPlugin() error { func Process() (float64, error) { return 0.0, nil } + +// ReloadPlugin reload the plugin +func ReloadPlugin(params map[string]string, meter metric.Meter) error { + return nil +} diff --git a/testdata/plugins/model/wrong_req.go b/testdata/plugins/model/wrong_req.go index 7d83724..e85f7aa 100644 --- a/testdata/plugins/model/wrong_req.go +++ b/testdata/plugins/model/wrong_req.go @@ -14,3 +14,8 @@ func InitPlugin(params map[string]string, meter metric.Meter) error { func Process() (float64, error) { return 0.0, nil } + +// ReloadPlugin reload the plugin +func ReloadPlugin(params map[string]string, meter metric.Meter) error { + return nil +} diff --git a/wacecore.go b/wacecore.go index 82f315b..88a217f 100644 --- a/wacecore.go +++ b/wacecore.go @@ -248,6 +248,24 @@ func CloseTransaction(transactionID string) { } } +// Reload applies a new configuration and reloads all plugins. +func Reload(met metric.Meter, conf configstore.ConfigFileData) error { + logger := logging.Get() + + cs, err := configstore.Get() + if err != nil { + return err + } + if err = cs.SetConfig(conf); err != nil { + return err + } + if err = logger.LoadLogger(cs.LogPath, cs.LogLevel); err != nil { + return err + } + meter = met + return plugins.Reload(met) +} + // Init initializes the WACE core with the given metric meter func Init(met metric.Meter, conf configstore.ConfigFileData) error { logger := logging.Get() diff --git a/wacecore_test.go b/wacecore_test.go index 5115dae..9f9fd67 100644 --- a/wacecore_test.go +++ b/wacecore_test.go @@ -1,6 +1,7 @@ package wace import ( + "fmt" "math/rand" "strconv" "strings" @@ -31,9 +32,10 @@ var requestHeaders = []pluginmanager.HTTPHeader{ } var requestHeadersPayload = pluginmanager.HTTPPayload{ - URI: requestURI, - Method: requestMethod, - HTTPVersion: requestVersion, + URI: requestURI, + Method: requestMethod, + HTTPVersion: requestVersion, + RequestHeaders: requestHeaders, } var requestBody = "licenseID=string&content=string&/paramsXML=string\n" @@ -228,49 +230,10 @@ decisionplugins: decisionbalance: 0.1 `) -// var configRoberta = []byte(`--- -// logpath: "/dev/null" -// loglevel: DEBUG -// listenport: "50051" -// modelplugins: -// - id: "trivial" -// path: "testdata/plugins/model/trivial.so" -// weight: 1 -// threshold: 0.5 -// params: -// d: "sds" -// b: "dnid" -// e: "dofnno" -// # plugintype: "RequestHeaders" -// plugintype: "Everything" -// - id: "trivial2" -// path: "testdata/plugins/model/trivial2.so" -// weight: 2 -// threshold: 0.1 -// params: -// a: "sdsds" -// b: "sdfjdnid" -// c: "kfoskdofnno" -// plugintype: "Everything" -// - id: "roberta" -// path: "testdata/plugins/model/roberta.so" -// weight: 1 -// threshold: 0.5 -// params: -// url: "localhost:9999" -// distance_threshold: -0.02 -// plugintype: "AllRequest" -// decisionplugins: -// - id: "simple" -// path: "testdata/plugins/decision/simple.so" -// wafweight: 0.5 -// decisionbalance: 0.5 -// `) - var provider = metric.NewMeterProvider() var testMeter = provider.Meter("example-meter") -func initilize(configuration []byte) error { +func initialize(configuration []byte) error { var aux configstore.ConfigFileData err := yaml.Unmarshal(configuration, &aux) if err != nil { @@ -293,8 +256,101 @@ func generateRandomID() string { return id } -func TestAnalyzeRequestInParts(t *testing.T) { - err := initilize(configAllModels) +func TestAnalyze(t *testing.T) { + type step struct { + payloadType string + payload pluginmanager.HTTPPayload + plugins []string + } + tests := []struct { + name string + config []byte + steps []step + postDelay time.Duration + }{ + { + name: "request in parts", + config: configAllModels, + steps: []step{ + {"RequestHeaders", requestHeadersPayload, []string{"trivialRequestHeaders"}}, + {"RequestBody", pluginmanager.HTTPPayload{RequestBody: requestBody}, []string{"trivialRequestBody"}}, + }, + }, + { + name: "whole request", + config: configAllModels, + steps: []step{ + {"AllRequest", wholeRequest, []string{"trivialAllRequest"}}, + }, + }, + { + name: "response in parts", + config: configAllModels, + steps: []step{ + {"ResponseHeaders", responseHeadersPayload, []string{"trivialResponseHeaders"}}, + {"ResponseBody", pluginmanager.HTTPPayload{ResponseBody: responseBody}, []string{"trivialResponseBody"}}, + }, + }, + { + name: "whole response", + config: configAllModels, + steps: []step{ + {"AllResponse", wholeResponse, []string{"trivialAllResponse"}}, + }, + }, + { + name: "request in parts async", + config: configAsync, + steps: []step{{"RequestHeaders", requestHeadersPayload, []string{"trivial", "trivial2"}}}, + postDelay: 10 * time.Millisecond, + }, + { + name: "empty models list is a no-op", + config: configAllModels, + steps: []step{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := initialize(tt.config) + defer configstore.Clean() + if err != nil { + t.Fatalf("Error initing test: %v", err) + } + + transactionID := generateRandomID() + InitTransaction(transactionID) + + for _, s := range tt.steps { + if err := Analyze(s.payloadType, transactionID, s.payload, s.plugins); err != nil { + t.Errorf("Analyze %s: %v", s.payloadType, err) + } + } + + _, err = CheckTransaction(transactionID, "simple", make(map[string]string)) + if err != nil { + t.Errorf("CheckTransaction: %v", err) + } + + CloseTransaction(transactionID) + + if tt.postDelay > 0 { + time.Sleep(tt.postDelay) + } + }) + } +} + +func TestCheckInvalidTransaction(t *testing.T) { + _, err := CheckTransaction("INEXISTENT", "simple", make(map[string]string)) + if err == nil { + t.Errorf("Error: CheckTransaction with inexistent transaction does not rise an error") + } +} + +func TestCheckAttackTransaction(t *testing.T) { + err := initialize(configSyncNoRemote) defer configstore.Clean() if err != nil { t.Errorf("Error initing test: %v", err) @@ -304,213 +360,318 @@ func TestAnalyzeRequestInParts(t *testing.T) { InitTransaction(transactionID) - res := Analyze("RequestHeaders", transactionID, requestHeadersPayload, []string{"trivialRequestHeaders"}) - if res != nil { - t.Errorf("Error: Analyze RequestHeaders: %s", res.Error()) + wafParams := make(map[string]string) + auxString := "COMBINED_SCORE=0,HTTP=0,LFI=0,PHPI=0,RCE=0,RFI=0,SESS=0,SQLI=0,XSS=0,inbound_blocking=20,inbound_detection=0,inbound_per_pl=0-0-0-0,inbound_threshold=5,outbound_blocking=0,outbound_detection=0,outbound_per_pl=0-0-0-0,outbound_threshold=4,phase=2" + for _, score := range strings.Split(auxString, ",") { + scoreParts := strings.Split(score, "=") + wafParams[scoreParts[0]] = scoreParts[1] } - res = Analyze("RequestBody", transactionID, pluginmanager.HTTPPayload{ResponseBody: requestBody}, []string{"trivialRequestBody"}) - if res != nil { - t.Errorf("Error: Analyze RequestBody: %s", res.Error()) + + err = Analyze("RequestHeaders", transactionID, requestHeadersPayload, []string{"trivial", "trivial2", "trivial3"}) + if err != nil { + t.Errorf("Error: Analyze RequestHeaders: %s", err.Error()) } - _, err = CheckTransaction(transactionID, "simple", make(map[string]string)) + res, err := CheckTransaction(transactionID, "simple", wafParams) if err != nil { t.Errorf("Error: CheckTransaction: %s", err.Error()) } + if !res { + t.Errorf("Error: CheckTransaction: transaction should be blocked") + } CloseTransaction(transactionID) } -func TestAnalyzeWholeRequest(t *testing.T) { - err := initilize(configAllModels) +func TestAnalyzeInvalidType(t *testing.T) { + err := initialize(configAllModels) defer configstore.Clean() if err != nil { - t.Errorf("Error initing test: %v", err) + t.Fatalf("Error initing test: %v", err) } transactionID := generateRandomID() - InitTransaction(transactionID) + defer CloseTransaction(transactionID) - res := Analyze("AllRequest", transactionID, wholeRequest, []string{"trivialAllRequest"}) - if res != nil { - t.Errorf("Error: Analyze AllRequest: %s", res.Error()) - } - - _, err = CheckTransaction(transactionID, "simple", make(map[string]string)) - if err != nil { - t.Errorf("Error: CheckTransaction: %s", err.Error()) + err = Analyze("InvalidType", transactionID, requestHeadersPayload, []string{"trivialRequestHeaders"}) + if err == nil { + t.Errorf("Analyze with invalid type should return error") } - - CloseTransaction(transactionID) } -func TestAnalyzeResponseInParts(t *testing.T) { - err := initilize(configAllModels) - defer configstore.Clean() +// TestInitDuplicate covers the configstore.New() error branch in Init: calling +// Init a second time without Clean in between must return an error. +func TestInitDuplicate(t *testing.T) { + err := initialize(config) if err != nil { - t.Errorf("Error initing test: %v", err) + t.Fatalf("first initialize: %v", err) } + defer configstore.Clean() - transactionID := generateRandomID() - - InitTransaction(transactionID) + err = initialize(config) + if err == nil { + t.Error("second Init without Clean should return error") + } +} - res := Analyze("ResponseHeaders", transactionID, responseHeadersPayload, []string{"trivialResponseHeaders"}) - if res != nil { - t.Errorf("Error: Analyze ResponseHeaders: %s", res.Error()) +// TestInitInvalidConfig covers the SetConfig error branch in Init: a config +// referencing a nonexistent plugin path must cause Init to return an error. +func TestInitInvalidConfig(t *testing.T) { + badConfig := []byte(`--- +logpath: "/dev/null" +loglevel: "ERROR" +modelplugins: + - id: "missing" + path: "testdata/plugins/model/does_not_exist.so" + plugintype: "Everything" +`) + var aux configstore.ConfigFileData + if err := yaml.Unmarshal(badConfig, &aux); err != nil { + t.Fatalf("yaml.Unmarshal: %v", err) } - res = Analyze("ResponseBody", transactionID, pluginmanager.HTTPPayload{ResponseBody: responseBody}, []string{"trivialResponseBody"}) - if res != nil { - t.Errorf("Error: Analyze ResponseBody: %s", res.Error()) + err := Init(testMeter, aux) + configstore.Clean() + if err == nil { + t.Error("Init with nonexistent plugin path should return error") } +} - _, err = CheckTransaction(transactionID, "simple", make(map[string]string)) +func TestCloseNonexistentTransaction(t *testing.T) { + err := initialize(configAllModels) + defer configstore.Clean() if err != nil { - t.Errorf("Error: CheckTransaction: %s", err.Error()) + t.Fatalf("Error initing test: %v", err) } - CloseTransaction(transactionID) + // should log an error but not panic + CloseTransaction("NONEXISTENT") } -func TestAnalyzeWholeResponse(t *testing.T) { - err := initilize(configAllModels) +func TestCheckNonexistentDecisionPlugin(t *testing.T) { + err := initialize(configAllModels) defer configstore.Clean() if err != nil { - t.Errorf("Error initing test: %v", err) + t.Fatalf("Error initing test: %v", err) } transactionID := generateRandomID() - InitTransaction(transactionID) + defer CloseTransaction(transactionID) - res := Analyze("AllResponse", transactionID, wholeResponse, []string{"trivialAllResponse"}) - if res != nil { - t.Errorf("Error: Analyze AllResponse: %s", res.Error()) + _, err = CheckTransaction(transactionID, "nonexistent_plugin", make(map[string]string)) + if err == nil { + t.Errorf("CheckTransaction with nonexistent decision plugin should return error") } +} - _, err = CheckTransaction(transactionID, "simple", make(map[string]string)) - if err != nil { - t.Errorf("Error: CheckTransaction: %s", err.Error()) +// parseWAFParams parses a comma-separated "key=value" string into a map. +func parseWAFParams(s string) map[string]string { + params := make(map[string]string) + for _, pair := range strings.Split(s, ",") { + parts := strings.SplitN(pair, "=", 2) + if len(parts) == 2 { + params[parts[0]] = parts[1] + } } + return params +} - CloseTransaction(transactionID) +func TestCheckTransactionResult(t *testing.T) { + blockingWAF := parseWAFParams("inbound_blocking=20,inbound_threshold=5") + noAlertWAF := parseWAFParams("inbound_blocking=0,inbound_threshold=5") + + tests := []struct { + name string + config []byte + models []string + wafParams map[string]string + wantBlock bool + }{ + { + name: "trivial2 (prob=1.0) with alerting WAF blocks", + config: configSyncNoRemote, + models: []string{"trivial2"}, + wafParams: blockingWAF, + wantBlock: true, + }, + { + name: "trivial (prob=0.0) with alerting WAF does not block", + config: configSyncNoRemote, + models: []string{"trivial"}, + wafParams: blockingWAF, + wantBlock: false, + }, + { + name: "trivial2 with non-alerting WAF does not block", + config: configSyncNoRemote, + models: []string{"trivial2"}, + wafParams: noAlertWAF, + wantBlock: false, + }, + { + name: "empty models list never blocks", + config: configSyncNoRemote, + models: []string{}, + wafParams: blockingWAF, + wantBlock: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := initialize(tt.config) + defer configstore.Clean() + if err != nil { + t.Fatalf("initialize: %v", err) + } + + txID := generateRandomID() + InitTransaction(txID) + defer CloseTransaction(txID) + + if len(tt.models) > 0 { + if err := Analyze("RequestHeaders", txID, requestHeadersPayload, tt.models); err != nil { + t.Fatalf("Analyze: %v", err) + } + } + + blocked, err := CheckTransaction(txID, "simple", tt.wafParams) + if err != nil { + t.Fatalf("CheckTransaction: %v", err) + } + if blocked != tt.wantBlock { + t.Errorf("blocked = %v, want %v", blocked, tt.wantBlock) + } + }) + } } -func TestAnalyzeRequestInPartsAsync(t *testing.T) { - err := initilize(configAsync) +func TestAnalyzeMultiPhase(t *testing.T) { + err := initialize(configAllModels) defer configstore.Clean() if err != nil { - t.Errorf("Error initing test: %v", err) + t.Fatalf("initialize: %v", err) } - transactionID := generateRandomID() - - InitTransaction(transactionID) + txID := generateRandomID() + InitTransaction(txID) + defer CloseTransaction(txID) - res := Analyze("RequestHeaders", transactionID, requestHeadersPayload, []string{"trivial", "trivial2"}) - if res != nil { - t.Errorf("Error: Analyze RequestHeaders: %s", res.Error()) + phases := []struct { + payloadType string + payload pluginmanager.HTTPPayload + models []string + }{ + {"RequestHeaders", requestHeadersPayload, []string{"trivialRequestHeaders"}}, + {"RequestBody", pluginmanager.HTTPPayload{RequestBody: requestBody}, []string{"trivialRequestBody"}}, + {"ResponseHeaders", responseHeadersPayload, []string{"trivialResponseHeaders"}}, + {"ResponseBody", pluginmanager.HTTPPayload{ResponseBody: responseBody}, []string{"trivialResponseBody"}}, } - _, err = CheckTransaction(transactionID, "simple", make(map[string]string)) - if err != nil { - t.Errorf("Error: CheckTransaction: %s", err.Error()) + for _, p := range phases { + if err := Analyze(p.payloadType, txID, p.payload, p.models); err != nil { + t.Errorf("Analyze(%s): %v", p.payloadType, err) + } } - CloseTransaction(transactionID) - - time.Sleep(10 * time.Millisecond) -} - -func TestCheckInvalidTransaction(t *testing.T) { - _, err := CheckTransaction("INEXISTENT", "simple", make(map[string]string)) - if err == nil { - t.Errorf("Error: CheckTransaction with inexistent transaction does not rise an error") + _, err = CheckTransaction(txID, "simple", make(map[string]string)) + if err != nil { + t.Errorf("CheckTransaction after multi-phase analysis: %v", err) } } -func TestCheckAttackTransaction(t *testing.T) { - err := initilize(configSyncNoRemote) +func TestConcurrentTransactions(t *testing.T) { + err := initialize(configSyncNoRemote) defer configstore.Clean() if err != nil { - t.Errorf("Error initing test: %v", err) + t.Fatalf("initialize: %v", err) } - transactionID := generateRandomID() + wafParams := parseWAFParams("inbound_blocking=20,inbound_threshold=5") - InitTransaction(transactionID) + const goroutines = 20 + errs := make(chan error, goroutines) - wafParams := make(map[string]string) - auxString := "COMBINED_SCORE=0,HTTP=0,LFI=0,PHPI=0,RCE=0,RFI=0,SESS=0,SQLI=0,XSS=0,inbound_blocking=20,inbound_detection=0,inbound_per_pl=0-0-0-0,inbound_threshold=5,outbound_blocking=0,outbound_detection=0,outbound_per_pl=0-0-0-0,outbound_threshold=4,phase=2" - for _, score := range strings.Split(auxString, ",") { - scoreParts := strings.Split(score, "=") - wafParams[scoreParts[0]] = scoreParts[1] + for i := 0; i < goroutines; i++ { + go func() { + txID := generateRandomID() + InitTransaction(txID) + + if err := Analyze("RequestHeaders", txID, requestHeadersPayload, []string{"trivial", "trivial2"}); err != nil { + errs <- fmt.Errorf("Analyze: %w", err) + CloseTransaction(txID) + return + } + + if _, err := CheckTransaction(txID, "simple", wafParams); err != nil { + errs <- fmt.Errorf("CheckTransaction: %w", err) + CloseTransaction(txID) + return + } + + CloseTransaction(txID) + errs <- nil + }() } - err = Analyze("RequestHeaders", transactionID, requestHeadersPayload, []string{"trivial", "trivial2", "trivial3"}) - if err != nil { - t.Errorf("Error: Analyze RequestHeaders: %s", err.Error()) + for i := 0; i < goroutines; i++ { + if err := <-errs; err != nil { + t.Errorf("concurrent transaction error: %v", err) + } } +} - res, err := CheckTransaction(transactionID, "simple", wafParams) - if err != nil { - t.Errorf("Error: CheckTransaction: %s", err.Error()) +// configParamWith returns a YAML config using param.so with the given result value. +func configParamWith(result string) []byte { + return []byte(`--- +logpath: "/dev/null" +loglevel: "WARN" +modelplugins: + - id: "param" + path: "testdata/plugins/model/param.so" + weight: 1 + plugintype: "Everything" + mode: sync + params: + result: "` + result + `" +decisionplugins: + - id: "simple" + path: "testdata/plugins/decision/simple.so" + decisionbalance: 0.5 +`) +} + +// TestReload verifies that Reload succeeds and that transactions still work +// correctly after it. +func TestReload(t *testing.T) { + if err := initialize(configParamWith("0.3")); err != nil { + t.Fatalf("initialize: %v", err) } - if !res { - t.Errorf("Error: CheckTransaction: transaction should be blocked") + defer configstore.Clean() + + var newConf configstore.ConfigFileData + if err := yaml.Unmarshal(configParamWith("0.8"), &newConf); err != nil { + t.Fatalf("yaml.Unmarshal: %v", err) + } + if err := Reload(testMeter, newConf); err != nil { + t.Fatalf("Reload: %v", err) } - CloseTransaction(transactionID) + // Transactions must still complete successfully after a reload. + txID := generateRandomID() + InitTransaction(txID) + defer CloseTransaction(txID) + if err := Analyze("Everything", txID, pluginmanager.HTTPPayload{URI: "/test"}, []string{"param"}); err != nil { + t.Fatalf("Analyze after Reload: %v", err) + } + if _, err := CheckTransaction(txID, "simple", make(map[string]string)); err != nil { + t.Fatalf("CheckTransaction after Reload: %v", err) + } } -// func TestAnalyzeStress(t *testing.T) { -// for i := 0; i < 1000; i++ { -// transactionID := generateRandomID() -// AnalyzeRequest(transactionID, wholeRequest, []string{"trivial", "trivial2"}) -// _, err := CheckTransaction(transactionID, "simple", make(map[string]string)) -// if err != nil { -// t.Errorf("checkTransaction error: %v", err) -// } -// } - -// } - -// func processRequest(models []string) error { -// transactionID := generateRandomID() - -// res := AnalyzeRequest(transactionID, wholeRequest, models) -// if res != 0 { -// return errors.New("analyzeRequest returned non-zero") -// } - -// _, err := CheckTransaction(transactionID, "simple", -// map[string]string{"anomalyscore": "1", -// "inboundthreshold": "100"}) -// return err -// } - -// func TestRoberta(t *testing.T) { -// conf := cf.Get() -// err := conf.LoadConfigYaml(configRoberta) -// if err != nil { -// panic("Error loading config: " + err.Error()) -// } - -// err = processRequest([]string{"roberta"}) -// if err != nil { -// t.Errorf("callRoberta error: %v", err) -// } -// } - -// func BenchmarkRoberta(b *testing.B) { -// for i := 0; i < b.N; i++ { -// processRequest([]string{"roberta"}) -// } -// } - func BenchmarkTrivial(b *testing.B) { - err := initilize(configSyncNoRemote) + err := initialize(configSyncNoRemote) defer configstore.Clean() if err != nil { b.Errorf("Error initing test: %v", err) @@ -537,7 +698,7 @@ func BenchmarkTrivial(b *testing.B) { } func BenchmarkTrivialFullNATS(b *testing.B) { - err := initilize(configSyncRemote) + err := initialize(configSyncRemote) defer configstore.Clean() if err != nil { b.Errorf("Error initing test: %v", err)