diff --git a/configstore/configstore.go b/configstore/configstore.go index 98f52c8..58d9352 100644 --- a/configstore/configstore.go +++ b/configstore/configstore.go @@ -1,248 +1,261 @@ -/* -Package configstore handles the configuration of WACE. The -configuration file is parsed, checked for errors and loaded into -memory -*/ -package configstore - -import ( - "fmt" - "io/ioutil" - "os" - - lg "github.com/tilsor/ModSecIntl_logging/logging" -) - -// ModelPluginType is an enum listing the parts of a request or -// response that a model plugin can handle. -type ModelPluginType int - -const ( - RequestHeaders ModelPluginType = iota - RequestBody - AllRequest - ResponseHeaders - ResponseBody - AllResponse - Everything -) - -// String returns the string representation of a model plugin type -func (t ModelPluginType) String() string { - switch t { - case RequestHeaders: - return "RequestHeaders" - case RequestBody: - return "RequestBody" - case AllRequest: - return "AllRequest" - case ResponseHeaders: - return "ResponseHeaders" - case ResponseBody: - return "ResponseBody" - case AllResponse: - return "AllResponse" - default: - return "Everything" - } -} - -// StringToPluginType converts a string to the corresponding model plugin type -func StringToPluginType(textType string) (ModelPluginType, error) { - switch textType { - case "RequestHeaders": - return RequestHeaders, nil - case "RequestBody": - return RequestBody, nil - case "AllRequest": - return AllRequest, nil - case "ResponseHeaders": - return ResponseHeaders, nil - case "ResponseBody": - return ResponseBody, nil - case "AllResponse": - return AllResponse, nil - case "Everything": - return Everything, nil - } - return -1, fmt.Errorf("invalid plugin type %s", textType) -} - -// ModelPluginConfig stores the configuration of a model plugin -type modelPluginConfig struct { - ID string - Path string - Weight float64 - Threshold float64 - Params map[string]string - PluginType ModelPluginType - Mode string - Remote bool -} - -// DecisionPluginConfig stores the configuration of a decision plugin -type decisionPluginConfig struct { - ID string - Path string - WAFweight float64 - DecisionBalance float64 - Params map[string]string -} - -// ConfigStore stores all wacecore configuration from the config file. -type ConfigStore struct { - ModelPlugins map[string]modelPluginConfig - DecisionPlugins map[string]decisionPluginConfig - LogPath string - LogLevel lg.LogLevel - NatsURL string - ApplicationId string -} - -var config *ConfigStore - -// Get returns or creates the unique instance of configstore -func Get() *ConfigStore { - if config == nil { - config = new(ConfigStore) - } - return config -} - -type configFileModelPlugin struct { - ID string - Path string - Weight float64 - Threshold float64 - Params map[string]string - PluginType string `yaml:"plugintype"` - Mode string - Remote bool -} - -type configFileDecisionPlugin struct { - ID string - Path string - wafweight float64 - decisionbalance float64 - Params map[string]string -} - -type ConfigFileData struct { - Logpath string - Loglevel string - Modelplugins []configFileModelPlugin - Decisionplugins []configFileDecisionPlugin - NatsURL string -} - -// IsAsync returns true if the model plugin is async -func (c *ConfigStore) IsAsync(modelID string) bool { - return c.ModelPlugins[modelID].Mode == "async" -} - -// CheckLogging verifies if the log path is valid -func checkLogging(inConf ConfigFileData) error { - // check logpath - if inConf.Logpath == "" { - return fmt.Errorf("log path empty") - } - _, err := os.Stat(inConf.Logpath) - if err != nil { // check if log file does not exists already - // Attempt to create dummy file - var d []byte - err = ioutil.WriteFile(inConf.Logpath, d, 0644) - if err == nil { - err = os.Remove(inConf.Logpath) // delete it - } - } - return err -} - -// CheckConfig verifies if the configuration read from the config file -// is correct. -func checkConfig(inConf ConfigFileData) error { - err := checkLogging(inConf) - if err != nil { - return fmt.Errorf("invalid log path %s: %v", inConf.Logpath, err) - } - - // check modelplugins - for _, modelP := range inConf.Modelplugins { - - if modelP.Path != "" { - if _, err := os.Stat(modelP.Path); err != nil { - return fmt.Errorf("%s plugin path %s: %v", modelP.ID, modelP.Path, err) - } - } else { - return fmt.Errorf("%s plugin path is empty, please provide a valid path", modelP.ID) - } - if modelP.PluginType == "" { - return fmt.Errorf("%s plugin type cannot be empty, please provide a valid type", modelP.ID) - } - // fmt.Printf("modelP.Type: %s\n", modelP.Type) - } - // check decisionplugins - for _, decisionP := range inConf.Decisionplugins { - - if decisionP.Path != "" { - if _, err := os.Stat(decisionP.Path); err != nil { - return fmt.Errorf("%s plugin path %s cannot be opened: %v", decisionP.ID, decisionP.Path, err) - } - } else { - return fmt.Errorf("%s plugin path is empty, please provide a valid path", decisionP.ID) - } - } - - return nil -} - -// SetConfig sets the configuration of WACE from the configuration file -func (cs *ConfigStore) SetConfig(inConf ConfigFileData) error { - err := checkConfig(inConf) - if err != nil { - return err - } - - cs.LogPath = inConf.Logpath - cs.LogLevel, err = lg.StringToLogLevel(inConf.Loglevel) - if err != nil { - return err - } - - cs.ModelPlugins = make(map[string]modelPluginConfig) - for _, modelP := range inConf.Modelplugins { - var modelConfig modelPluginConfig - modelConfig.ID = modelP.ID - modelConfig.Path = modelP.Path - modelConfig.Weight = modelP.Weight - modelConfig.Threshold = modelP.Threshold - modelConfig.Params = modelP.Params - modelConfig.PluginType, err = StringToPluginType(modelP.PluginType) - modelConfig.Mode = modelP.Mode - modelConfig.Remote = modelP.Remote - if err != nil { - return err - } - cs.ModelPlugins[modelConfig.ID] = modelConfig - } - - cs.DecisionPlugins = make(map[string]decisionPluginConfig) - for _, decisionP := range inConf.Decisionplugins { - var decisionConfig decisionPluginConfig - decisionConfig.ID = decisionP.ID - decisionConfig.Path = decisionP.Path - decisionConfig.WAFweight = decisionP.wafweight - decisionConfig.DecisionBalance = decisionP.decisionbalance - decisionConfig.Params = decisionP.Params - cs.DecisionPlugins[decisionConfig.ID] = decisionConfig - } - - if inConf.NatsURL != "" { - cs.NatsURL = inConf.NatsURL - } else { - cs.NatsURL = "localhost:4222" - } - - return nil -} \ No newline at end of file +/* +Package configstore handles the configuration of WACE. The +configuration file is parsed, checked for errors and loaded into +memory +*/ +package configstore + +import ( + "fmt" + "os" + + "github.com/tilsor/ModSecIntl_logging/logging" +) + +// ModelPluginType is an enum listing the parts of a request or +// response that a model plugin can handle. +type ModelPluginType int + +const ( + RequestHeaders ModelPluginType = iota + RequestBody + AllRequest + ResponseHeaders + ResponseBody + AllResponse + Everything +) + +// String returns the string representation of a model plugin type +func (t ModelPluginType) String() string { + switch t { + case RequestHeaders: + return "RequestHeaders" + case RequestBody: + return "RequestBody" + case AllRequest: + return "AllRequest" + case ResponseHeaders: + return "ResponseHeaders" + case ResponseBody: + return "ResponseBody" + case AllResponse: + return "AllResponse" + default: + return "Everything" + } +} + +// StringToPluginType converts a string to the corresponding model plugin type +func StringToPluginType(textType string) (ModelPluginType, error) { + switch textType { + case "RequestHeaders": + return RequestHeaders, nil + case "RequestBody": + return RequestBody, nil + case "AllRequest": + return AllRequest, nil + case "ResponseHeaders": + return ResponseHeaders, nil + case "ResponseBody": + return ResponseBody, nil + case "AllResponse": + return AllResponse, nil + case "Everything": + return Everything, nil + } + return -1, fmt.Errorf("invalid plugin type %s", textType) +} + +// ModelPluginConfig stores the configuration of a model plugin +type modelPluginConfig struct { + ID string + Path string + Weight float64 + Threshold float64 + Params map[string]string + PluginType ModelPluginType + Mode string + Remote bool +} + +// DecisionPluginConfig stores the configuration of a decision plugin +type decisionPluginConfig struct { + ID string + Path string + WAFweight float64 + DecisionBalance float64 + Params map[string]string +} + +// ConfigStore stores all wacecore configuration from the config file. +type ConfigStore struct { + ModelPlugins map[string]modelPluginConfig + DecisionPlugins map[string]decisionPluginConfig + LogPath string + LogLevel logging.LogLevel + NatsURL string + ApplicationId string +} + +var config *ConfigStore + +// Create and returns the unique instance of configstore if it does not exist previously, in other case returns error +func New() (*ConfigStore, error) { + if config != nil { + return nil, fmt.Errorf("ConfigStore: an instance already exists") + } + config = new(ConfigStore) + return config, nil +} + +// Get returns the unique instance of configstore +func Get() (*ConfigStore, error) { + if config == nil { + return nil, fmt.Errorf("ConfigStore: Configuration was not loaded") + } + return config, nil +} + +// Clean remove the references to the stored instance of configstore +func Clean() { + config = nil +} + +type configFileModelPlugin struct { + ID string + Path string + Weight float64 + Threshold float64 + Params map[string]string + PluginType string `yaml:"plugintype"` + Mode string + Remote bool +} + +type configFileDecisionPlugin struct { + ID string + Path string + wafweight float64 + decisionbalance float64 + Params map[string]string +} + +type ConfigFileData struct { + Logpath string + Loglevel string + Modelplugins []configFileModelPlugin + Decisionplugins []configFileDecisionPlugin + NatsURL string +} + +// IsAsync returns true if the model plugin is async +func (c *ConfigStore) IsAsync(modelID string) bool { + return c.ModelPlugins[modelID].Mode == "async" +} + +// CheckLogging verifies if the log path is valid +func checkLogging(inConf ConfigFileData) error { + // check logpath + if inConf.Logpath == "" { + return fmt.Errorf("log path empty") + } + _, err := os.Stat(inConf.Logpath) + if err != nil { // check if log file does not exists already + // Attempt to create dummy file + var d []byte + err = os.WriteFile(inConf.Logpath, d, 0644) + if err == nil { + err = os.Remove(inConf.Logpath) // delete it + } + } + return err +} + +// CheckConfig verifies if the configuration read from the config file +// is correct. +func checkConfig(inConf ConfigFileData) error { + err := checkLogging(inConf) + if err != nil { + return fmt.Errorf("invalid log path %s: %v", inConf.Logpath, err) + } + + // check modelplugins + for _, modelP := range inConf.Modelplugins { + + if modelP.Path != "" { + if _, err := os.Stat(modelP.Path); err != nil { + return fmt.Errorf("%s plugin path %s: %v", modelP.ID, modelP.Path, err) + } + } else { + return fmt.Errorf("%s plugin path is empty, please provide a valid path", modelP.ID) + } + if modelP.PluginType == "" { + return fmt.Errorf("%s plugin type cannot be empty, please provide a valid type", modelP.ID) + } + // fmt.Printf("modelP.Type: %s\n", modelP.Type) + } + // check decisionplugins + for _, decisionP := range inConf.Decisionplugins { + + if decisionP.Path != "" { + if _, err := os.Stat(decisionP.Path); err != nil { + return fmt.Errorf("%s plugin path %s cannot be opened: %v", decisionP.ID, decisionP.Path, err) + } + } else { + return fmt.Errorf("%s plugin path is empty, please provide a valid path", decisionP.ID) + } + } + + return nil +} + +// SetConfig sets the configuration of WACE from the configuration file +func (cs *ConfigStore) SetConfig(inConf ConfigFileData) error { + err := checkConfig(inConf) + if err != nil { + return err + } + + cs.LogPath = inConf.Logpath + cs.LogLevel, err = logging.StringToLogLevel(inConf.Loglevel) + if err != nil { + return err + } + + cs.ModelPlugins = make(map[string]modelPluginConfig) + for _, modelP := range inConf.Modelplugins { + var modelConfig modelPluginConfig + modelConfig.ID = modelP.ID + modelConfig.Path = modelP.Path + modelConfig.Weight = modelP.Weight + modelConfig.Threshold = modelP.Threshold + modelConfig.Params = modelP.Params + modelConfig.PluginType, err = StringToPluginType(modelP.PluginType) + modelConfig.Mode = modelP.Mode + modelConfig.Remote = modelP.Remote + if err != nil { + return err + } + cs.ModelPlugins[modelConfig.ID] = modelConfig + } + + cs.DecisionPlugins = make(map[string]decisionPluginConfig) + for _, decisionP := range inConf.Decisionplugins { + var decisionConfig decisionPluginConfig + decisionConfig.ID = decisionP.ID + decisionConfig.Path = decisionP.Path + decisionConfig.WAFweight = decisionP.wafweight + decisionConfig.DecisionBalance = decisionP.decisionbalance + decisionConfig.Params = decisionP.Params + cs.DecisionPlugins[decisionConfig.ID] = decisionConfig + } + + if inConf.NatsURL != "" { + cs.NatsURL = inConf.NatsURL + } else { + cs.NatsURL = "localhost:4222" + } + + return nil +} diff --git a/configstore/configstore_test.go b/configstore/configstore_test.go index c5ed707..93b3d97 100644 --- a/configstore/configstore_test.go +++ b/configstore/configstore_test.go @@ -1,267 +1,277 @@ -package configstore - -import ( - "fmt" - "os" - "testing" - - "gopkg.in/yaml.v3" -) - -var validConfig = []byte(`--- -logpath: "/dev/stderr" -loglevel: "DEBUG" -modelplugins: - - id: "trivial" - path: "../testdata/plugins/model/trivial.so" - weight: 1 - threshold: 0.5 - params: - d: "sds" - b: "dnid" - e: "dofnno" - plugintype: "RequestHeaders" - mode: "sync" - - id: "trivial2" - path: "../testdata/plugins/model/trivial2.so" - weight: 2 - threshold: 0.1 - params: - a: "sdsds" - b: "sdfjdnid" - c: "kfoskdofnno" - plugintype: "RequestHeaders" -decisionplugins: - - id: "test" - path: "../testdata/plugins/decision/test.so" - wafweight: 0.5 - decisionbalance: 0.5 - params: - ssdaf: "sdsds" - dsfb: "sdfjdnid" - csfd: "kfoskdofnno" -`) - -func initialize(configuration []byte) error { - cs := Get() - var aux ConfigFileData - err := yaml.Unmarshal(configuration, &aux) - if err != nil { - return err - } - err = cs.SetConfig(aux) - if err != nil { - return err - } - return nil -} - -func TestLoadConfigYamlEmpty(t *testing.T) { - - err := initialize([]byte(`---`)) - if err == nil { - t.Errorf("empty config does not return error") - } -} - -func TestLoadConfigYamlValid(t *testing.T) { - - err := initialize(validConfig) - if err != nil { - t.Errorf("valid config returned error: %v", err) - } -} - -func TestLoadConfigYamlInvalid(t *testing.T) { - - err := initialize([]byte(`()=)(/&/()~@#~½¬{[{½¬½---sfdjlskjfs#@~sjdfa`)) - - if err == nil { - t.Errorf("invalid config does not return error") - } -} - -func TestLoadConfigYamlLogLevel(t *testing.T) { - - values := []string{ - "a", - "4", - "0", - } - - for _, v := range values { - config := `--- -logpath: "/dev/null" -loglevel: ` + v - err := initialize([]byte(config)) - if err == nil { - t.Errorf("invalid log level %v does not return error", v) - } - } -} - -func TestLoadConfigYamlPluginType(t *testing.T) { - cs := Get() - - err := initialize([]byte(`--- -loglevel: ERROR -logpath: /dev/null -modelplugins: - - id: "testplugin" - path: "../testdata/plugins/model/trivial.so" - plugintype: InvalidPluginType -`)) - if err == nil { - t.Errorf("invalid plugin type does not return error") - } - - err = initialize([]byte(`--- -loglevel: ERROR -logpath: /dev/null -modelplugins: - - id: "testplugin" - path: "../testdata/plugins/model/trivial.so" - plugintype: "" -`)) - if err == nil { - t.Errorf("empty plugin type does not return error") - } - - err = initialize([]byte(`--- -loglevel: ERROR -logpath: /dev/null -modelplugins: - - id: "testplugin" - path: "../testdata/plugins/model/nonexistent.so" - plugintype: "RequestHeaders" -`)) - if err == nil { - t.Errorf("nonexistent model plugin path does not return error") - } - - err = initialize([]byte(`--- -loglevel: ERROR -logpath: /dev/null -modelplugins: - - id: "testplugin" - path: "" - plugintype: "RequestHeaders" -`)) - if err == nil { - t.Errorf("empty plugin path does not return error") - } - - err = initialize([]byte(`--- -loglevel: ERROR -logpath: /dev/null -decisionplugins: - - id: "test" - path: "" -`)) - if err == nil { - t.Errorf("empty decision plugin path does not return error") - } - - err = initialize([]byte(`--- -loglevel: ERROR -logpath: /dev/null -decisionplugins: - - id: "testplugin" - path: "../testdata/plugins/decision/nonexistent.so" -`)) - if err == nil { - t.Errorf("nonexistent decision plugin path does not return error") - } - - values := []string{ - "RequestHeaders", - "RequestBody", - "AllRequest", - "ResponseHeaders", - "ResponseBody", - "AllResponse", - "Everything", - } - - for _, v := range values { - config := `--- -loglevel: ERROR -logpath: /dev/null -modelplugins: - - id: "testplugin" - path: "../testdata/plugins/model/trivial.so" - plugintype: "` + v + `" -` - err = initialize([]byte(config)) - if err != nil { - t.Errorf("Plugin type %s returns error: %v", v, err) - } - - if fmt.Sprint(cs.ModelPlugins["testplugin"].PluginType) != v { - t.Errorf("Stored plugin type is %v, expected %v", cs.ModelPlugins["testplugin"].PluginType, v) - } - } -} - -// func TestLoadConfig(t *testing.T) { -// cs := Get() - -// err := cs.LoadConfig("") -// if err == nil { -// t.Errorf("empty config file path does not return error") -// } - -// err = cs.LoadConfig("/dev/null") -// if err == nil { -// t.Errorf("empty config file contents does not return error") -// } - -// tmpFile, err := ioutil.TempFile(os.TempDir(), "configstore_test-") -// if err != nil { -// t.Errorf("cannot create temporary file: %v", err) -// } -// defer os.Remove(tmpFile.Name()) - -// if _, err = tmpFile.Write(validConfig); err != nil { -// t.Errorf("failed to write to temporary file: %v", err) -// } -// err = cs.LoadConfig(tmpFile.Name()) -// if err != nil { -// t.Errorf("valid config file returned error: %v", err) -// } -// } - -func TestInvalidLogging(t *testing.T) { - - err := initialize([]byte(`--- -loglevel: INVALIDLOGLEVEL -logpath: /dev/null -`)) - if err == nil { - t.Errorf("invalid log level does not return error") - } - - if _, err = os.Stat("./configstore_test.log"); err == nil { - err = os.Remove("./configstore_test.log") - if err != nil { - t.Errorf("could not remove ./configstore_test.log") - } - } - - err = initialize([]byte(`--- -loglevel: ERROR -logpath: ./configstore_test.log`)) - - if err != nil { - t.Errorf("Error loading config with nonexistent file: %v", err) - } - - err = initialize([]byte(`--- -loglevel: ERROR -logpath: /usr/configstore_test.log`)) - - if err == nil { - t.Errorf("non existent log file in directory without permissions does not rise error") - } - -} +package configstore + +import ( + "fmt" + "os" + "testing" + + "gopkg.in/yaml.v3" +) + +var validConfig = []byte(`--- +logpath: "/dev/stderr" +loglevel: "DEBUG" +modelplugins: + - id: "trivial" + path: "../testdata/plugins/model/trivial.so" + weight: 1 + threshold: 0.5 + params: + d: "sds" + b: "dnid" + e: "dofnno" + plugintype: "RequestHeaders" + mode: "sync" + - id: "trivial2" + path: "../testdata/plugins/model/trivial2.so" + weight: 2 + threshold: 0.1 + params: + a: "sdsds" + b: "sdfjdnid" + c: "kfoskdofnno" + plugintype: "RequestHeaders" +decisionplugins: + - id: "test" + path: "../testdata/plugins/decision/test.so" + wafweight: 0.5 + decisionbalance: 0.5 + params: + ssdaf: "sdsds" + dsfb: "sdfjdnid" + csfd: "kfoskdofnno" +`) + +func initialize(configuration []byte) error { + cs, err := Get() + if err != nil { + return err + } + var aux ConfigFileData + err = yaml.Unmarshal(configuration, &aux) + if err != nil { + return err + } + err = cs.SetConfig(aux) + if err != nil { + return err + } + return nil +} + +func TestLoadConfigYamlEmpty(t *testing.T) { + _, err := New() + if err != nil { + t.Error(err) + } + + defer Clean() + + err = initialize([]byte(`---`)) + if err == nil { + t.Error("empty config does not return error") + } +} + +func TestLoadConfigYamlValid(t *testing.T) { + _, err := New() + if err != nil { + t.Error(err) + } + + defer Clean() + + err = initialize(validConfig) + if err != nil { + t.Errorf("valid config returned error: %v", err) + } +} + +func TestLoadConfigYamlInvalid(t *testing.T) { + _, err := New() + if err != nil { + t.Error(err) + } + + defer Clean() + + err = initialize([]byte(`()=)(/&/()~@#~½¬{[{½¬½---sfdjlskjfs#@~sjdfa`)) + + if err == nil { + t.Error("invalid config does not return error") + } +} + +func TestLoadConfigYamlLogLevel(t *testing.T) { + _, err := New() + if err != nil { + t.Error(err) + } + + defer Clean() + + values := []string{ + "a", + "4", + "0", + } + + for _, v := range values { + config := `--- +logpath: "/dev/null" +loglevel: ` + v + err = initialize([]byte(config)) + if err == nil { + t.Errorf("invalid log level %v does not return error", v) + } + } +} + +func TestLoadConfigYamlPluginType(t *testing.T) { + cs, err := New() + if err != nil { + t.Error(err) + } + + defer Clean() + + err = initialize([]byte(`--- +loglevel: ERROR +logpath: /dev/null +modelplugins: + - id: "testplugin" + path: "../testdata/plugins/model/trivial.so" + plugintype: InvalidPluginType +`)) + if err == nil { + t.Errorf("invalid plugin type does not return error") + } + + err = initialize([]byte(`--- +loglevel: ERROR +logpath: /dev/null +modelplugins: + - id: "testplugin" + path: "../testdata/plugins/model/trivial.so" + plugintype: "" +`)) + if err == nil { + t.Errorf("empty plugin type does not return error") + } + + err = initialize([]byte(`--- +loglevel: ERROR +logpath: /dev/null +modelplugins: + - id: "testplugin" + path: "../testdata/plugins/model/nonexistent.so" + plugintype: "RequestHeaders" +`)) + if err == nil { + t.Errorf("nonexistent model plugin path does not return error") + } + + err = initialize([]byte(`--- +loglevel: ERROR +logpath: /dev/null +modelplugins: + - id: "testplugin" + path: "" + plugintype: "RequestHeaders" +`)) + if err == nil { + t.Errorf("empty plugin path does not return error") + } + + err = initialize([]byte(`--- +loglevel: ERROR +logpath: /dev/null +decisionplugins: + - id: "test" + path: "" +`)) + if err == nil { + t.Errorf("empty decision plugin path does not return error") + } + + err = initialize([]byte(`--- +loglevel: ERROR +logpath: /dev/null +decisionplugins: + - id: "testplugin" + path: "../testdata/plugins/decision/nonexistent.so" +`)) + if err == nil { + t.Errorf("nonexistent decision plugin path does not return error") + } + + values := []string{ + "RequestHeaders", + "RequestBody", + "AllRequest", + "ResponseHeaders", + "ResponseBody", + "AllResponse", + "Everything", + } + + for _, v := range values { + config := `--- +loglevel: ERROR +logpath: /dev/null +modelplugins: + - id: "testplugin" + path: "../testdata/plugins/model/trivial.so" + plugintype: "` + v + `" +` + err = initialize([]byte(config)) + if err != nil { + t.Errorf("Plugin type %s returns error: %v", v, err) + } + + if fmt.Sprint(cs.ModelPlugins["testplugin"].PluginType) != v { + t.Errorf("Stored plugin type is %v, expected %v", cs.ModelPlugins["testplugin"].PluginType, v) + } + } +} + +func TestInvalidLogging(t *testing.T) { + _, err := New() + if err != nil { + t.Error(err) + } + + defer Clean() + + err = initialize([]byte(`--- +loglevel: INVALIDLOGLEVEL +logpath: /dev/null +`)) + if err == nil { + t.Errorf("invalid log level does not return error") + } + + if _, err = os.Stat("./configstore_test.log"); err == nil { + err = os.Remove("./configstore_test.log") + if err != nil { + t.Errorf("could not remove ./configstore_test.log") + } + } + + err = initialize([]byte(`--- +loglevel: ERROR +logpath: ./configstore_test.log`)) + + if err != nil { + t.Errorf("Error loading config with nonexistent file: %v", err) + } + + err = initialize([]byte(`--- +loglevel: ERROR +logpath: /usr/configstore_test.log`)) + + if err == nil { + t.Errorf("non existent log file in directory without permissions does not rise error") + } + +} diff --git a/go.mod b/go.mod index fa6770e..aabc361 100644 --- a/go.mod +++ b/go.mod @@ -1,26 +1,27 @@ module github.com/tilsor/ModSecIntl_wace_lib -go 1.22.9 +go 1.26.2 require ( - github.com/nats-io/nats.go v1.38.0 + github.com/nats-io/nats.go v1.51.0 github.com/tilsor/ModSecIntl_logging v1.0.1 - go.opentelemetry.io/otel v1.34.0 - go.opentelemetry.io/otel/metric v1.34.0 - go.opentelemetry.io/otel/sdk/metric v1.34.0 + go.opentelemetry.io/otel v1.43.0 + go.opentelemetry.io/otel/metric v1.43.0 + go.opentelemetry.io/otel/sdk/metric v1.43.0 gopkg.in/yaml.v3 v3.0.1 ) require ( - github.com/go-logr/logr v1.4.2 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/google/uuid v1.6.0 // indirect - github.com/klauspost/compress v1.17.9 // indirect - github.com/nats-io/nkeys v0.4.9 // indirect + github.com/klauspost/compress v1.18.5 // indirect + github.com/nats-io/nkeys v0.4.15 // indirect github.com/nats-io/nuid v1.0.1 // indirect - go.opentelemetry.io/auto/sdk v1.1.0 // indirect - go.opentelemetry.io/otel/sdk v1.34.0 // indirect - go.opentelemetry.io/otel/trace v1.34.0 // indirect - golang.org/x/crypto v0.31.0 // indirect - golang.org/x/sys v0.29.0 // indirect + go.opentelemetry.io/auto/sdk v1.2.1 // indirect + go.opentelemetry.io/otel/sdk v1.43.0 // indirect + go.opentelemetry.io/otel/trace v1.43.0 // indirect + golang.org/x/crypto v0.50.0 // indirect + golang.org/x/sys v0.43.0 // indirect ) diff --git a/go.sum b/go.sum index ecc9057..7b259f6 100644 --- a/go.sum +++ b/go.sum @@ -1,50 +1,52 @@ +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= -github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= -github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= -github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE= +github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/nats-io/nats.go v1.38.0 h1:A7P+g7Wjp4/NWqDOOP/K6hfhr54DvdDQUznt5JFg9XA= -github.com/nats-io/nats.go v1.38.0/go.mod h1:IGUM++TwokGnXPs82/wCuiHS02/aKrdYUQkU8If6yjw= -github.com/nats-io/nkeys v0.4.9 h1:qe9Faq2Gxwi6RZnZMXfmGMZkg3afLLOtrU+gDZJ35b0= -github.com/nats-io/nkeys v0.4.9/go.mod h1:jcMqs+FLG+W5YO36OX6wFIFcmpdAns+w1Wm6D3I/evE= +github.com/nats-io/nats.go v1.51.0 h1:ByW84XTz6W03GSSsygsZcA+xgKK8vPGaa/FCAAEHnAI= +github.com/nats-io/nats.go v1.51.0/go.mod h1:26HypzazeOkyO3/mqd1zZd53STJN0EjCYF9Uy2ZOBno= +github.com/nats-io/nkeys v0.4.15 h1:JACV5jRVO9V856KOapQ7x+EY8Jo3qw1vJt/9Jpwzkk4= +github.com/nats-io/nkeys v0.4.15/go.mod h1:CpMchTXC9fxA5zrMo4KpySxNjiDVvr8ANOSZdiNfUrs= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= -github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tilsor/ModSecIntl_logging v1.0.1 h1:wFd3SxJPUU5JxX2UlrsH0Ef/m9a18d/VRo4xVCtCxVM= github.com/tilsor/ModSecIntl_logging v1.0.1/go.mod h1:9RrpYmS4v/wYIiiYXzDW6Lqr8Xb8wq3ejpHi8jmQsyo= -go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= -go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= -go.opentelemetry.io/otel v1.34.0 h1:zRLXxLCgL1WyKsPVrgbSdMN4c0FMkDAskSTQP+0hdUY= -go.opentelemetry.io/otel v1.34.0/go.mod h1:OWFPOQ+h4G8xpyjgqo4SxJYdDQ/qmRH+wivy7zzx9oI= -go.opentelemetry.io/otel/metric v1.34.0 h1:+eTR3U0MyfWjRDhmFMxe2SsW64QrZ84AOhvqS7Y+PoQ= -go.opentelemetry.io/otel/metric v1.34.0/go.mod h1:CEDrp0fy2D0MvkXE+dPV7cMi8tWZwX3dmaIhwPOaqHE= -go.opentelemetry.io/otel/sdk v1.34.0 h1:95zS4k/2GOy069d321O8jWgYsW3MzVV+KuSPKp7Wr1A= -go.opentelemetry.io/otel/sdk v1.34.0/go.mod h1:0e/pNiaMAqaykJGKbi+tSjWfNNHMTxoC9qANsCzbyxU= -go.opentelemetry.io/otel/sdk/metric v1.34.0 h1:5CeK9ujjbFVL5c1PhLuStg1wxA7vQv7ce1EK0Gyvahk= -go.opentelemetry.io/otel/sdk/metric v1.34.0/go.mod h1:jQ/r8Ze28zRKoNRdkjCZxfs6YvBTG1+YIqyFVFYec5w= -go.opentelemetry.io/otel/trace v1.34.0 h1:+ouXS2V8Rd4hp4580a8q23bg0azF2nI8cqLYnC8mh/k= -go.opentelemetry.io/otel/trace v1.34.0/go.mod h1:Svm7lSjQD7kG7KJ/MUHPVXSDGz2OX4h0M2jHBhmSfRE= -golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= -golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= -golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= -golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I= +go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0= +go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM= +go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY= +go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg= +go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg= +go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw= +go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A= +go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A= +go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0= +golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= +golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/pluginmanager/pluginmanager.go b/pluginmanager/pluginmanager.go index 21b4534..e7204ce 100644 --- a/pluginmanager/pluginmanager.go +++ b/pluginmanager/pluginmanager.go @@ -1,459 +1,496 @@ -/* -Package pluginmanager handles the communication with the model and -decision plugins -*/ -package pluginmanager - -import ( - "encoding/json" - "fmt" - "plugin" - "sync" - - cf "github.com/tilsor/ModSecIntl_wace_lib/configstore" - "go.opentelemetry.io/otel/metric" - - "github.com/nats-io/nats.go" - lg "github.com/tilsor/ModSecIntl_logging/logging" -) - -// ResultData maps the model plugin ID with the corresponding analysis result. -type ModelResults struct { - ProbAttack float64 `json:"probattack"` - Data map[string]interface{} `json:"data"` -} - -// ModelInput is the struct that contains the input data for the model plugin -type ModelInput struct { - TransactionId string `json:"transactionId"` - Payload string `json:"payload"` -} - -// DecisionInput is the struct that contains the input data for the decision plugin -type DecisionInput struct { - TransactionId string - Results map[string]ModelResults - ModelWeight map[string]float64 - WAFdata map[string]string -} - -// ModelTransmitionResults is the struct that contains the results of the model plugin -type ModelTransmitionResults struct { - TransactionId string `json:"transactionId"` - ModelResults `json:",inline"` - Error error `json:"error"` -} - -// modelPlugin is the struct that stores the model plugin and its type -type modelPlugin struct { - p *plugin.Plugin - pluginType cf.ModelPluginType -} - -// decisionPlugin is the struct that stores the decision plugin -type decisionPlugin struct { - p *plugin.Plugin -} - -// ModelStatus stores whether there was an error while processing a -// request (response) by the modelID model plugin -type ModelStatus struct { - ModelID string - ProbAttack float64 - Err error -} - -// PluginManager is the main plugin struct storing information of -// every plugin execution. -type PluginManager struct { - modelPlugins map[string]modelPlugin - modelProcessFunc map[string]func(ModelInput) (ModelResults, error) - decisionCheckFunc map[string]func(DecisionInput) (bool, error) - decisionPlugins map[string]decisionPlugin - results sync.Map - channelsMutex sync.Mutex - syncModelsChannels sync.Map - asyncModelsChannels sync.Map - natConn *nats.Conn -} - -// New creates a new PluginManager instance. -func New(meter metric.Meter) *PluginManager { - pm := new(PluginManager) - conf := cf.Get() - logger := lg.Get() - logger.Printf(lg.DEBUG, "Connecting to NATS server at %s", conf.NatsURL) - - nc, err := nats.Connect(conf.NatsURL) - - if err != nil { - logger.Printf(lg.ERROR, "Failed to connect to NATS server") - } - - pm.natConn = nc - - // Loading of model plugins - pm.modelPlugins = make(map[string]modelPlugin) - pm.modelProcessFunc = make(map[string]func(ModelInput) (ModelResults, error)) - for _, data := range conf.ModelPlugins { - tp, err := plugin.Open(data.Path) - if err != nil { - logger.Printf(lg.WARN, "| %s | cannot load plugin: %v", data.ID, err) - continue - } - if data.Mode == "async" || conf.ModelPlugins[data.ID].Remote { - f, err := tp.Lookup("InitPluginAsync") - if err != nil { - logger.Printf(lg.WARN, "| %s | cannot load plugin: %v", data.ID, err) - continue - } - initPlugin, ok := f.(func(map[string]string, metric.Meter, func(func(ModelInput) (ModelResults, error))) error) - if !ok { - logger.Printf(lg.WARN, "| %s | cannot load plugin: invalid InitPluginAsync function type", data.ID) - continue - } - err = initPlugin(data.Params, meter, func(modelProcess func(ModelInput) (ModelResults, error)) { - ModelProcessHandler(data.ID, modelProcess) - }) - if err != nil { - logger.Printf(lg.WARN, "| %s | cannot load plugin: %v", data.ID, err) - continue - } - go pm.ModelResultsHandler(data.ID) - } else { - f, err := tp.Lookup("InitPlugin") - if err != nil { - logger.Printf(lg.WARN, "| %s | cannot load plugin: %v", data.ID, err) - continue - } - initPlugin, ok := f.(func(map[string]string, metric.Meter) error) - if !ok { - logger.Printf(lg.WARN, "| %s | cannot load plugin: invalid InitPlugin function type", data.ID) - continue - } - err = initPlugin(data.Params, meter) - procFunc, err := tp.Lookup("Process") - if err != nil { - logger.Printf(lg.WARN, "| %s | cannot load plugin: cannot load Process function", data.ID) - continue - } - process, ok := procFunc.(func(ModelInput) (ModelResults, error)) - if !ok { - logger.Printf(lg.WARN, "| %s | cannot load plugin: invalid Process function type", data.ID) - continue - } - pm.modelProcessFunc[data.ID] = process - } - modelPluginLoaded := modelPlugin{tp, data.PluginType} - pm.modelPlugins[data.ID] = modelPluginLoaded - logger.Printf(lg.INFO, "| %s | plugin loaded", data.ID) - } - - pm.decisionPlugins = make(map[string]decisionPlugin) - pm.decisionCheckFunc = make(map[string]func(DecisionInput) (bool, error)) - // Loading of decision plugins - for _, data := range conf.DecisionPlugins { - tp, err := plugin.Open(data.Path) - if err != nil { - logger.Printf(lg.WARN, "| %s | cannot load plugin: %v", data.ID, err) - continue - } - f, err := tp.Lookup("InitPlugin") - if err != nil { - logger.Printf(lg.WARN, "| %s | cannot load plugin: %v", data.ID, err) - continue - } - initPlugin, ok := f.(func(map[string]string, metric.Meter) error) - if !ok { - logger.Printf(lg.WARN, "| %s | cannot load plugin: invalid InitPlugin function type", data.ID) - continue - } - err = initPlugin(data.Params, meter) - if err != nil { - logger.Printf(lg.WARN, "| %s | cannot load plugin: %v", data.ID, err) - continue - } - cR, err := tp.Lookup("CheckResults") - if err != nil { - logger.Printf(lg.ERROR, "| %s | cannot load plugin check results function: %v", data.ID, err) - continue - } - checkResults, ok := cR.(func(DecisionInput) (bool, error)) - if !ok { - logger.Printf(lg.ERROR, "| %s | CheckResults lookup failed for plugin: invalid function type", data.ID) - continue - } - pm.decisionCheckFunc[data.ID] = checkResults - decisionPluginLoaded := decisionPlugin{tp} - pm.decisionPlugins[data.ID] = decisionPluginLoaded - } - return pm -} - -// InitTransaction initializes the transaction with the given ID -func (p *PluginManager) InitTransaction(transactionId string) { - p.results.Store(transactionId, new(sync.Map)) -} - -// CloseTransaction closes the transaction with the given ID -// removing all sync model data -func (p *PluginManager) CloseTransaction(transactionId string) { - logger := lg.Get() - transactionMap, ok := p.syncModelsChannels.Load(transactionId) - if !ok { - logger.TPrintf(lg.ERROR, transactionId, "Transaction %s not found", transactionId) - } else { - transactionMap.(*sync.Map).Range(func(key, value interface{}) bool { - ch := value.(chan ModelStatus) - close(ch) - for range ch {} - transactionMap.(*sync.Map).Delete(key) - return true - }) - p.syncModelsChannels.Delete(transactionId) - resultsMap, ok := p.results.Load(transactionId) - if !ok { - logger.TPrintf(lg.ERROR, transactionId, "Results for transaction %s not found", transactionId) - } else { - resultsMap.(*sync.Map).Range(func(key, value interface{}) bool { - resultsMap.(*sync.Map).Delete(key) - return true - }) - } - p.results.Delete(transactionId) - } -} - -// AddModelChannel adds a channel to result channel map -func (p *PluginManager) AddModelChannel(transactionId string, t cf.ModelPluginType, modelPlugStatus chan ModelStatus, modelType string) { - typeModel := new(sync.Map) - var value interface{} - if modelType == "sync" { - value, _ = p.syncModelsChannels.LoadOrStore(transactionId, typeModel) - } else { - value, _ = p.asyncModelsChannels.LoadOrStore(transactionId, typeModel) - } - value.(*sync.Map).Store(t.String(), modelPlugStatus) -} - -// RemoveModelChannel removes a channel from the result channel map -func (p *PluginManager) RemoveAsyncModelChannel(transactionId string, t cf.ModelPluginType) { - typeModel, ok := p.asyncModelsChannels.Load(transactionId) - if ok { - channelMap := typeModel.(*sync.Map) - ch, channelOk := channelMap.Load(t.String()) - - if channelOk { - close(ch.(chan ModelStatus)) - for range ch.(chan ModelStatus) {} - channelMap.Delete(t.String()) - } - - remainChannels := 0 - typeModel.(*sync.Map).Range(func(key, value interface{}) bool { - remainChannels++ - return true - }) - if remainChannels == 0 { - p.asyncModelsChannels.Delete(transactionId) - } - } else { - logger := lg.Get() - logger.TPrintf(lg.ERROR, transactionId, "Transaction %s not found when trying to remove async model channel", transactionId) - } -} - -// AddToQueue adds a payload to the model queue -func (p *PluginManager) AddToQueue(modelId, transactionId, payload string) error { - payloadToSend := &ModelInput{ - TransactionId: transactionId, - Payload: payload, - } - - jsonPayload, err := json.Marshal(payloadToSend) - - if err != nil { - return err - } - - return p.natConn.Publish(modelId, jsonPayload) -} - -// Process is in charge of calling the model plugin with id modelID -func (p *PluginManager) Process(modelID, transactionId, payload string, t cf.ModelPluginType, modelPlugStatus chan ModelStatus) { - conf := cf.Get() - - mp, exists := p.modelPlugins[modelID] - if !exists { - modelPlugStatus <- ModelStatus{ModelID: modelID, Err: fmt.Errorf("model plugin not found")} - return - } - - // check if the plugin is capable of analyzing the indicated part of the transaction - if mp.pluginType != t { - modelPlugStatus <- ModelStatus{ModelID: modelID, - Err: fmt.Errorf("plugin type %v cannot process a request with incompatible type %v", mp.pluginType, t)} - return - } - - process := p.modelProcessFunc[modelID] - - if conf.ModelPlugins[modelID].Mode == "async" { - modelPlugStatus <- ModelStatus{ModelID: modelID, Err: fmt.Errorf("model plugin is async")} - return - } else { - res, err := process(ModelInput{TransactionId: transactionId, Payload: payload}) - // res, err := process(transactionId, payload) - - if err != nil { - modelPlugStatus <- ModelStatus{ModelID: modelID, Err: err} - return - } - // store the results - resultSyncMap, ok := p.results.Load(transactionId) - if !ok { - modelPlugStatus <- ModelStatus{ModelID: modelID, Err: fmt.Errorf("transaction results not found")} - return - } - resultSyncMap.(*sync.Map).Store(modelID, res) - modelPlugStatus <- ModelStatus{ModelID: modelID, ProbAttack: res.ProbAttack, Err: nil} - } -} - -// CheckResult is in charge of calling the decision plugin with id decisionID over the -// transaction with id transactID -func (p *PluginManager) CheckResult(transactionId, decisionId string, wafParams map[string]string) (bool, error) { - logger := lg.Get() - - checkResults, ok := p.decisionCheckFunc[decisionId] - if !ok { - return false, fmt.Errorf("decision plugin not found") - } - - transactionResults, ok := p.results.Load(transactionId) - if !ok { - return false, fmt.Errorf("transaction results not found") - } - - configStore := cf.Get() - - modelResultMap := make(map[string]ModelResults) - modelWeightMap := make(map[string]float64) - transactionResults.(*sync.Map).Range(func(key, value interface{}) bool { - modelResultMap[key.(string)] = value.(ModelResults) - modelWeightMap[key.(string)] = configStore.ModelPlugins[key.(string)].Weight - return true - }) - - res, err := checkResults(DecisionInput{TransactionId: transactionId, Results: modelResultMap, ModelWeight: modelWeightMap, WAFdata: wafParams}) - logger.TPrintf(lg.INFO, transactionId, "%s | transaction checked. Block: %t ", decisionId, res) - - return res, err -} - -// ModelResultsHandler listens for messages on the model results queue -func (p *PluginManager) ModelResultsHandler(modelId string) { - logger := lg.Get() - conf := cf.Get() - - sub, err := p.natConn.Subscribe(modelId+"/results", func(msg *nats.Msg) { - go func(msg nats.Msg) { - data := &ModelTransmitionResults{} - err := json.Unmarshal(msg.Data, data) - if err != nil { - logger.Printf(lg.ERROR, "Model: %s | Failed to parse JSON payload", modelId) - } else { - var channel interface{} - var ok bool - if conf.ModelPlugins[modelId].Mode == "async" { - channel, ok = p.asyncModelsChannels.Load(data.TransactionId) - } else { - channel, ok = p.syncModelsChannels.Load(data.TransactionId) - } - if !ok { - logger.TPrintf(lg.ERROR, data.TransactionId, " Model %s | Transaction not found", modelId) - } else { - modelChannel, ok := channel.(*sync.Map).Load(conf.ModelPlugins[modelId].PluginType.String()) - if !ok { - logger.Printf(lg.ERROR, "Model %s not found", modelId) - } else { - if data.Error != nil { - modelChannel.(chan ModelStatus) <- ModelStatus{ModelID: modelId, Err: data.Error} - } else { - if conf.ModelPlugins[modelId].Mode != "async" { - // store the results - resultSyncMap, ok := p.results.Load(data.TransactionId) - if !ok { - modelChannel.(chan ModelStatus) <- ModelStatus{ModelID: modelId, Err: fmt.Errorf("transaction results not found")} - return - } - modelResult := ModelResults{ProbAttack: data.ProbAttack, Data: data.Data} - resultSyncMap.(*sync.Map).Store(modelId, modelResult) - } - modelChannel.(chan ModelStatus) <- ModelStatus{ModelID: modelId, ProbAttack: data.ProbAttack, Err: nil} - } - } - } - } - }(*msg) - }) - - if err != nil { - logger.Printf(lg.ERROR, "Model: %s | Failed to subscribe to model queue | %s", modelId, err.Error()) - return - } - - logger.Printf(lg.INFO, "Model: %s | Listening for messages on model results queue", modelId) - - defer sub.Unsubscribe() - defer p.natConn.Drain() - - select {} -} - -// ModelProcessHandler listens for messages on the model queue -func ModelProcessHandler(modelId string, modelProcess func(ModelInput) (ModelResults, error)) { - logger := lg.Get() - logger.Printf(lg.INFO, "Model: %s | Starting model process handler", modelId) - conf := cf.Get() - - nc, err := nats.Connect(conf.NatsURL) - - if err != nil { - logger.Printf(lg.ERROR, "Model: %s | Failed to connect to NATS server", modelId) - return - } - - _, err = nc.Subscribe(modelId, func(msg *nats.Msg) { - go func(msg nats.Msg) { - data := &ModelInput{} - err := json.Unmarshal(msg.Data, data) - if err != nil { - logger.Printf(lg.ERROR, "Model: %s | Failed to parse JSON payload", modelId) - } else { - res, err := modelProcess(*data) - modelResult := ModelResults{ProbAttack: res.ProbAttack, Data: res.Data} - payloadToSend := &ModelTransmitionResults{ - TransactionId: data.TransactionId, - ModelResults: modelResult, - Error: err, - } - - jsonPayload, err := json.Marshal(payloadToSend) - - if err != nil { - logger.Printf(lg.ERROR, "Model: %s | Failed to parse JSON payload", modelId) - } - - nc.Publish(modelId+"/results", jsonPayload) - } - }(*msg) - }) - - if err != nil { - logger.Printf(lg.ERROR, "Model: %s | Failed to subscribe to model queue | %s", modelId, err.Error()) - return - } - - logger.Printf(lg.INFO, "Model: %s | Listening for messages on model queue", modelId) -} +/* +Package pluginmanager handles the communication with the model and +decision plugins +*/ +package pluginmanager + +import ( + "encoding/json" + "fmt" + "plugin" + "sync" + + "github.com/tilsor/ModSecIntl_wace_lib/configstore" + "go.opentelemetry.io/otel/metric" + + "github.com/nats-io/nats.go" + "github.com/tilsor/ModSecIntl_logging/logging" +) + +// ResultData maps the model plugin ID with the corresponding analysis result. +type ModelResults struct { + ProbAttack float64 `json:"probattack"` + Data map[string]interface{} `json:"data"` +} + +type HTTPHeader struct { + Key string + Value string +} + +type HTTPPayload struct { + URI string + Method string + HTTPVersion string + RequestHeaders []HTTPHeader + RequestBody string + ResponseProtocol string + ResponseCode int + ResponseHeaders []HTTPHeader + ResponseBody string +} + +// ModelInput is the struct that contains the input data for the model plugin +type ModelInput struct { + TransactionId string `json:"transactionId"` + Payload HTTPPayload `json:"payload"` +} + +// DecisionInput is the struct that contains the input data for the decision plugin +type DecisionInput struct { + TransactionId string + Results map[string]ModelResults + ModelWeight map[string]float64 + WAFdata map[string]string +} + +// ModelTransmitionResults is the struct that contains the results of the model plugin +type ModelTransmitionResults struct { + TransactionId string `json:"transactionId"` + ModelResults `json:",inline"` + Error error `json:"error"` +} + +// modelPlugin is the struct that stores the model plugin and its type +type modelPlugin struct { + p *plugin.Plugin + pluginType configstore.ModelPluginType +} + +// decisionPlugin is the struct that stores the decision plugin +type decisionPlugin struct { + p *plugin.Plugin +} + +// ModelStatus stores whether there was an error while processing a +// request (response) by the modelID model plugin +type ModelStatus struct { + ModelID string + ProbAttack float64 + Err error +} + +// PluginManager is the main plugin struct storing information of +// every plugin execution. +type PluginManager struct { + modelPlugins map[string]modelPlugin + modelProcessFunc map[string]func(ModelInput) (ModelResults, error) + decisionCheckFunc map[string]func(DecisionInput) (bool, error) + decisionPlugins map[string]decisionPlugin + results sync.Map + channelsMutex sync.Mutex + syncModelsChannels sync.Map + asyncModelsChannels sync.Map + natConn *nats.Conn +} + +// New creates a new PluginManager instance. +func New(meter metric.Meter) (*PluginManager, error) { + pm := new(PluginManager) + conf, err := configstore.Get() + if err != nil { + return nil, err + } + logger := logging.Get() + logger.Printf(logging.DEBUG, "Connecting to NATS server at %s", conf.NatsURL) + + nc, err := nats.Connect(conf.NatsURL) + + if err != nil { + logger.Printf(logging.ERROR, "Failed to connect to NATS server") + } + + pm.natConn = nc + + // Loading of model plugins + pm.modelPlugins = make(map[string]modelPlugin) + pm.modelProcessFunc = make(map[string]func(ModelInput) (ModelResults, error)) + for _, data := range conf.ModelPlugins { + tp, err := plugin.Open(data.Path) + if err != nil { + logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) + continue + } + if data.Mode == "async" || conf.ModelPlugins[data.ID].Remote { + f, err := tp.Lookup("InitPluginAsync") + if err != nil { + logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) + continue + } + initPlugin, ok := f.(func(map[string]string, metric.Meter, func(func(ModelInput) (ModelResults, error))) error) + if !ok { + logger.Printf(logging.WARN, "| %s | cannot load plugin: invalid InitPluginAsync function type", data.ID) + continue + } + err = initPlugin(data.Params, meter, func(modelProcess func(ModelInput) (ModelResults, error)) { + ModelProcessHandler(data.ID, modelProcess) + }) + if err != nil { + logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) + continue + } + go pm.ModelResultsHandler(data.ID) + } else { + f, err := tp.Lookup("InitPlugin") + if err != nil { + logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) + continue + } + initPlugin, ok := f.(func(map[string]string, metric.Meter) error) + if !ok { + logger.Printf(logging.WARN, "| %s | cannot load plugin: invalid InitPlugin function type", data.ID) + continue + } + err = initPlugin(data.Params, meter) + procFunc, err := tp.Lookup("Process") + if err != nil { + logger.Printf(logging.WARN, "| %s | cannot load plugin: cannot load Process function", data.ID) + continue + } + process, ok := procFunc.(func(ModelInput) (ModelResults, error)) + if !ok { + logger.Printf(logging.WARN, "| %s | cannot load plugin: invalid Process function type", data.ID) + continue + } + pm.modelProcessFunc[data.ID] = process + } + modelPluginLoaded := modelPlugin{tp, data.PluginType} + pm.modelPlugins[data.ID] = modelPluginLoaded + logger.Printf(logging.INFO, "| %s | plugin loaded", data.ID) + } + + pm.decisionPlugins = make(map[string]decisionPlugin) + pm.decisionCheckFunc = make(map[string]func(DecisionInput) (bool, error)) + // Loading of decision plugins + for _, data := range conf.DecisionPlugins { + tp, err := plugin.Open(data.Path) + if err != nil { + logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) + continue + } + f, err := tp.Lookup("InitPlugin") + if err != nil { + logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) + continue + } + initPlugin, ok := f.(func(map[string]string, metric.Meter) error) + if !ok { + logger.Printf(logging.WARN, "| %s | cannot load plugin: invalid InitPlugin function type", data.ID) + continue + } + err = initPlugin(data.Params, meter) + if err != nil { + logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) + continue + } + cR, err := tp.Lookup("CheckResults") + if err != nil { + logger.Printf(logging.ERROR, "| %s | cannot load plugin check results function: %v", data.ID, err) + continue + } + checkResults, ok := cR.(func(DecisionInput) (bool, error)) + if !ok { + logger.Printf(logging.ERROR, "| %s | CheckResults lookup failed for plugin: invalid function type", data.ID) + continue + } + pm.decisionCheckFunc[data.ID] = checkResults + decisionPluginLoaded := decisionPlugin{tp} + pm.decisionPlugins[data.ID] = decisionPluginLoaded + } + return pm, nil +} + +// InitTransaction initializes the transaction with the given ID +func (p *PluginManager) InitTransaction(transactionId string) { + p.results.Store(transactionId, new(sync.Map)) +} + +// CloseTransaction closes the transaction with the given ID +// removing all sync model data +func (p *PluginManager) CloseTransaction(transactionId string) { + logger := logging.Get() + transactionMap, ok := p.syncModelsChannels.Load(transactionId) + if !ok { + logger.TPrintf(logging.ERROR, transactionId, "Transaction %s not found", transactionId) + } else { + transactionMap.(*sync.Map).Range(func(key, value interface{}) bool { + ch := value.(chan ModelStatus) + close(ch) + for range ch { + } + transactionMap.(*sync.Map).Delete(key) + return true + }) + p.syncModelsChannels.Delete(transactionId) + resultsMap, ok := p.results.Load(transactionId) + if !ok { + logger.TPrintf(logging.ERROR, transactionId, "Results for transaction %s not found", transactionId) + } else { + resultsMap.(*sync.Map).Range(func(key, value interface{}) bool { + resultsMap.(*sync.Map).Delete(key) + return true + }) + } + p.results.Delete(transactionId) + } +} + +// AddModelChannel adds a channel to result channel map +func (p *PluginManager) AddModelChannel(transactionId string, t configstore.ModelPluginType, modelPlugStatus chan ModelStatus, modelType string) { + typeModel := new(sync.Map) + var value interface{} + if modelType == "sync" { + value, _ = p.syncModelsChannels.LoadOrStore(transactionId, typeModel) + } else { + value, _ = p.asyncModelsChannels.LoadOrStore(transactionId, typeModel) + } + value.(*sync.Map).Store(t.String(), modelPlugStatus) +} + +// RemoveModelChannel removes a channel from the result channel map +func (p *PluginManager) RemoveAsyncModelChannel(transactionId string, t configstore.ModelPluginType) { + typeModel, ok := p.asyncModelsChannels.Load(transactionId) + if ok { + channelMap := typeModel.(*sync.Map) + ch, channelOk := channelMap.Load(t.String()) + + if channelOk { + close(ch.(chan ModelStatus)) + for range ch.(chan ModelStatus) { + } + channelMap.Delete(t.String()) + } + + remainChannels := 0 + typeModel.(*sync.Map).Range(func(key, value interface{}) bool { + remainChannels++ + return true + }) + if remainChannels == 0 { + p.asyncModelsChannels.Delete(transactionId) + } + } else { + logger := logging.Get() + logger.TPrintf(logging.ERROR, transactionId, "Transaction %s not found when trying to remove async model channel", transactionId) + } +} + +// AddToQueue adds a payload to the model queue +func (p *PluginManager) AddToQueue(modelID, transactionID string, payload HTTPPayload) error { + payloadToSend := &ModelInput{ + TransactionId: transactionID, + Payload: payload, + } + + jsonPayload, err := json.Marshal(payloadToSend) + + if err != nil { + return err + } + + return p.natConn.Publish(modelID, jsonPayload) +} + +// Process is in charge of calling the model plugin with id modelID +func (p *PluginManager) Process(modelID, transactionId string, payload HTTPPayload, t configstore.ModelPluginType, modelPlugStatus chan ModelStatus) error { + conf, err := configstore.Get() + if err != nil { + return err + } + + mp, exists := p.modelPlugins[modelID] + if !exists { + modelPlugStatus <- ModelStatus{ModelID: modelID, Err: fmt.Errorf("model plugin not found")} + return nil + } + + // check if the plugin is capable of analyzing the indicated part of the transaction + if mp.pluginType != t { + modelPlugStatus <- ModelStatus{ModelID: modelID, + Err: fmt.Errorf("plugin type %v cannot process a request with incompatible type %v", mp.pluginType, t)} + return nil + } + + process := p.modelProcessFunc[modelID] + + if conf.ModelPlugins[modelID].Mode == "async" { + modelPlugStatus <- ModelStatus{ModelID: modelID, Err: fmt.Errorf("model plugin is async")} + return nil + } else { + res, err := process(ModelInput{TransactionId: transactionId, Payload: payload}) + // res, err := process(transactionId, payload) + + if err != nil { + modelPlugStatus <- ModelStatus{ModelID: modelID, Err: err} + return nil + } + // store the results + resultSyncMap, ok := p.results.Load(transactionId) + if !ok { + modelPlugStatus <- ModelStatus{ModelID: modelID, Err: fmt.Errorf("transaction results not found")} + return nil + } + resultSyncMap.(*sync.Map).Store(modelID, res) + modelPlugStatus <- ModelStatus{ModelID: modelID, ProbAttack: res.ProbAttack, Err: nil} + } + return nil +} + +// CheckResult is in charge of calling the decision plugin with id decisionID over the +// transaction with id transactID +func (p *PluginManager) CheckResult(transactionId, decisionId string, wafParams map[string]string) (bool, error) { + logger := logging.Get() + + checkResults, ok := p.decisionCheckFunc[decisionId] + if !ok { + return false, fmt.Errorf("decision plugin not found") + } + + transactionResults, ok := p.results.Load(transactionId) + if !ok { + return false, fmt.Errorf("transaction results not found") + } + + cs, err := configstore.Get() + if err != nil { + return false, nil + } + + modelResultMap := make(map[string]ModelResults) + modelWeightMap := make(map[string]float64) + transactionResults.(*sync.Map).Range(func(key, value interface{}) bool { + modelResultMap[key.(string)] = value.(ModelResults) + modelWeightMap[key.(string)] = cs.ModelPlugins[key.(string)].Weight + return true + }) + + res, err := checkResults(DecisionInput{TransactionId: transactionId, Results: modelResultMap, ModelWeight: modelWeightMap, WAFdata: wafParams}) + logger.TPrintf(logging.INFO, transactionId, "%s | transaction checked. Block: %t ", decisionId, res) + + return res, err +} + +// ModelResultsHandler listens for messages on the model results queue +func (p *PluginManager) ModelResultsHandler(modelId string) error { + logger := logging.Get() + cs, err := configstore.Get() + if err != nil { + return err + } + + sub, err := p.natConn.Subscribe(modelId+"/results", func(msg *nats.Msg) { + go func(msg nats.Msg) { + data := &ModelTransmitionResults{} + err := json.Unmarshal(msg.Data, data) + if err != nil { + logger.Printf(logging.ERROR, "Model: %s | Failed to parse JSON payload", modelId) + } else { + var channel interface{} + var ok bool + if cs.ModelPlugins[modelId].Mode == "async" { + channel, ok = p.asyncModelsChannels.Load(data.TransactionId) + } else { + channel, ok = p.syncModelsChannels.Load(data.TransactionId) + } + if !ok { + logger.TPrintf(logging.ERROR, data.TransactionId, " Model %s | Transaction not found", modelId) + } else { + modelChannel, ok := channel.(*sync.Map).Load(cs.ModelPlugins[modelId].PluginType.String()) + if !ok { + logger.Printf(logging.ERROR, "Model %s not found", modelId) + } else { + if data.Error != nil { + modelChannel.(chan ModelStatus) <- ModelStatus{ModelID: modelId, Err: data.Error} + } else { + if cs.ModelPlugins[modelId].Mode != "async" { + // store the results + resultSyncMap, ok := p.results.Load(data.TransactionId) + if !ok { + modelChannel.(chan ModelStatus) <- ModelStatus{ModelID: modelId, Err: fmt.Errorf("transaction results not found")} + return + } + modelResult := ModelResults{ProbAttack: data.ProbAttack, Data: data.Data} + resultSyncMap.(*sync.Map).Store(modelId, modelResult) + } + modelChannel.(chan ModelStatus) <- ModelStatus{ModelID: modelId, ProbAttack: data.ProbAttack, Err: nil} + } + } + } + } + }(*msg) + }) + + if err != nil { + logger.Printf(logging.ERROR, "Model: %s | Failed to subscribe to model queue | %s", modelId, err.Error()) + return err + } + + logger.Printf(logging.INFO, "Model: %s | Listening for messages on model results queue", modelId) + + defer sub.Unsubscribe() + defer p.natConn.Drain() + + select {} + +} + +// ModelProcessHandler listens for messages on the model queue +func ModelProcessHandler(modelId string, modelProcess func(ModelInput) (ModelResults, error)) error { + logger := logging.Get() + logger.Printf(logging.INFO, "Model: %s | Starting model process handler", modelId) + cs, err := configstore.Get() + if err != nil { + return err + } + + nc, err := nats.Connect(cs.NatsURL) + + if err != nil { + logger.Printf(logging.ERROR, "Model: %s | Failed to connect to NATS server", modelId) + return err + } + + _, err = nc.Subscribe(modelId, func(msg *nats.Msg) { + go func(msg nats.Msg) { + data := &ModelInput{} + err := json.Unmarshal(msg.Data, data) + if err != nil { + logger.Printf(logging.ERROR, "Model: %s | Failed to parse JSON payload", modelId) + } else { + res, err := modelProcess(*data) + modelResult := ModelResults{ProbAttack: res.ProbAttack, Data: res.Data} + payloadToSend := &ModelTransmitionResults{ + TransactionId: data.TransactionId, + ModelResults: modelResult, + Error: err, + } + + jsonPayload, err := json.Marshal(payloadToSend) + + if err != nil { + logger.Printf(logging.ERROR, "Model: %s | Failed to parse JSON payload", modelId) + } + + nc.Publish(modelId+"/results", jsonPayload) + } + }(*msg) + }) + + if err != nil { + logger.Printf(logging.ERROR, "Model: %s | Failed to subscribe to model queue | %s", modelId, err.Error()) + return err + } + + logger.Printf(logging.INFO, "Model: %s | Listening for messages on model queue", modelId) + return nil +} diff --git a/pluginmanager/pluginmanager_test.go b/pluginmanager/pluginmanager_test.go index a00efa0..2834794 100644 --- a/pluginmanager/pluginmanager_test.go +++ b/pluginmanager/pluginmanager_test.go @@ -1,368 +1,371 @@ -package pluginmanager - -import ( - "math/rand" - "time" - - cf "github.com/tilsor/ModSecIntl_wace_lib/configstore" - "go.opentelemetry.io/otel/sdk/metric" - "gopkg.in/yaml.v3" - - lg "github.com/tilsor/ModSecIntl_logging/logging" -) - -var baseConfig = `--- -logpath: "/tmp/wacetmp.log" -loglevel: "WARN" -` - -var trivialPlugin = ` - id: "trivial" - path: "../testdata/plugins/model/trivial.so" - weight: 1 - params: - param1: "first value" - param2: "second value" - param3: "third value" - plugintype: "Everything" - mode: sync -` - -var testPlugin = ` - id: "test" - path: "../testdata/plugins/decision/test.so" - wafweight: 0.5 - decisionbalance: 0.5 - params: - test1: "test" - test2: "testtest" - test3: "testtesttest" -` - -func generateRandomID() string { - letters := "1234567890ABCDEF" - id := "" - for i := 0; i < 16; i++ { - id += string(letters[rand.Intn(len(letters))]) - } - - return id -} - -var provider = metric.NewMeterProvider() -var testMeter = provider.Meter("example-meter") - -func initilize(configuration []byte) error { - var aux cf.ConfigFileData - err := yaml.Unmarshal(configuration, &aux) - if err != nil { - return err - } - err = cf.Get().SetConfig(aux) - if err != nil { - return err - } - logger := lg.Get() - - conf := cf.Get() - err = logger.LoadLogger(conf.LogPath, conf.LogLevel) - if err != nil { - return err - - } - return nil -} - -func init() { - rand.Seed(time.Now().UnixNano()) - - logger := lg.Get() - err := logger.LoadLogger("/dev/null", lg.ERROR) - if err != nil { - panic("Error loading logger") - } -} - -// func TestPluginInit(t *testing.T) { -// cases := []struct{ id, conf string }{ -// // {"invalid_path", ` - id: "invalid_path" -// // path: "../testdata/plugins/model/nonexistent.so" -// // plugintype: "AllRequest" -// // `}, -// {"no_init", ` - id: "no_init" -// path: "../testdata/plugins/model/no_init.so" -// plugintype: "AllRequest" -// `}, -// {"wrong_init", ` - id: "wrong_init" -// path: "../testdata/plugins/model/wrong_init.so" -// plugintype: "AllRequest" -// `}, -// {"error_init", ` - id: "error_init" -// path: "../testdata/plugins/model/error_init.so" -// plugintype: "AllRequest" -// `}, -// } - -// // Test model plugin initialization -// for _, c := range cases { -// config := baseConfig + "modelplugins:\n" + trivialPlugin + c.conf - -// err := initilize([]byte(config)) -// if err != nil { -// t.Errorf("Error loading config: %v", err) -// } -// plugins := New(testMeter) -// if _, exists := plugins.modelPlugins["trivial"]; !exists { -// t.Errorf("trivial plugin not loaded") -// } -// if _, exists := plugins.modelPlugins[c.id]; exists { -// t.Errorf(c.id + " should not load") -// } -// } - -// // Test decision plugin initialization -// for _, c := range cases { -// config := baseConfig + "modelplugins:\n" + trivialPlugin + "decisionplugins:\n" + testPlugin + c.conf - -// err := initilize([]byte(config)) -// if err != nil { -// t.Errorf("Error loading config: %v", err) -// } -// plugins := New(testMeter) -// if _, exists := plugins.decisionPlugins["test"]; !exists { -// t.Errorf("test plugin not loaded") -// } -// if _, exists := plugins.decisionPlugins[c.id]; exists { -// t.Errorf(c.id + " should not load") -// } -// } - -// } - -// func TestPluginParams(t *testing.T) { -// config := baseConfig + "modelplugins:\n" + trivialPlugin + "decisionplugins:\n" + testPlugin - -// err := initilize([]byte(config)) -// if err != nil { -// t.Errorf("Error loading config: %v", err) -// } - -// var buf bytes.Buffer -// logger := lg.Get() -// err = logger.LoadLoggerWriter(&buf, lg.INFO) -// if err != nil { -// t.Errorf("Error loading logger: %v", err) -// } - -// plugins := New(testMeter) - -// if !strings.Contains(buf.String(), "[trivial:InitPlugin] map[param1:first value param2:second value param3:third value]") { -// t.Errorf("trivial plugin did not initialize correctly, got: %v, expected: %v", buf.String(), "[trivial:InitPlugin] map[param1:first value param2:second value param3:third value]") -// } -// if !strings.Contains(buf.String(), "[test:InitPlugin] map[test1:test test2:testtest test3:testtesttest]") { -// t.Errorf("test plugin did not initialize correctly") -// } - -// transactionID := generateRandomID() -// modelPlugStatus := make(chan ModelStatus) -// go plugins.Process("trivial", transactionID, "test request1", cf.AllRequest, modelPlugStatus) -// <-modelPlugStatus -// if !strings.Contains(buf.String(), "[trivial:ProcessRequest] \"test request1\"") { -// t.Errorf("trivial plugin did not analyze request") -// } - -// go plugins.Process("trivial", transactionID, "test response1", cf.AllResponse, modelPlugStatus) -// <-modelPlugStatus -// if !strings.Contains(buf.String(), "[trivial:ProcessResponse] \"test response1\"") { -// t.Errorf("trivial plugin did not analyze response") -// } - -// _, err = plugins.CheckResult(transactionID, "test", map[string]string{"anomalyscore": "100", "inboundthreshold": "10"}) -// if err != nil { -// t.Errorf("Error checking result: %v", err) -// } -// if !strings.Contains(buf.String(), "[test:CheckResults]") { -// t.Errorf("test plugin did not execute correctly") -// } -// if !strings.Contains(buf.String(), "modelRes: map[trivial:") { -// t.Errorf("trivial result is not stored in modelRes") -// } -// if !strings.Contains(buf.String(), "modelWeight: map[trivial:1]") { -// t.Errorf("trivial weight is not stored in modelWeight") -// } -// if !strings.Contains(buf.String(), "modelThres: map[trivial:0.5]") { -// t.Errorf("trivial threshold is not stored in modelWeight") -// } -// if !strings.Contains(buf.String(), "wafData: map[anomalyscore:100 inboundthreshold:10]") { -// t.Errorf("waf params are not stored in wafData") -// } -// } - -// func TestPluginType(t *testing.T) { -// cases := []struct { -// id string -// pluginType, requestType cf.ModelPluginType -// executes bool -// }{ -// {"req_headers-req_headers", cf.RequestHeaders, cf.RequestHeaders, true}, -// {"req_headers-resp_headers", cf.RequestHeaders, cf.ResponseHeaders, false}, -// {"req_headers-all_req", cf.RequestHeaders, cf.AllRequest, false}, -// {"all_req-req_headers", cf.AllRequest, cf.RequestHeaders, false}, -// {"all_req-all_resp", cf.AllRequest, cf.AllResponse, false}, - -// {"resp_headers-resp_headers", cf.ResponseHeaders, cf.ResponseHeaders, true}, -// {"resp_headers-req_headers", cf.ResponseHeaders, cf.RequestHeaders, false}, -// {"resp_headers-all_resp", cf.ResponseHeaders, cf.AllResponse, false}, -// {"all_resp-resp_headers", cf.AllResponse, cf.ResponseHeaders, false}, -// {"all_resp-all_req", cf.AllResponse, cf.AllRequest, false}, - -// {"everything-req_headers", cf.Everything, cf.RequestHeaders, true}, -// {"everything-all_req", cf.Everything, cf.AllRequest, true}, -// {"everything-resp_body", cf.Everything, cf.ResponseBody, true}, -// {"everything-all_resp", cf.Everything, cf.AllResponse, true}, -// } - -// for _, c := range cases { -// config := baseConfig + "modelplugins:\n" + -// " - id: \"" + c.id + "\"\n" + -// " path: \"../testdata/plugins/model/trivial.so\"\n" + -// " plugintype: \"" + c.pluginType.String() + "\"\n" - -// err := initilize([]byte(config)) -// if err != nil { -// t.Errorf("Error loading config: %v", err) -// } - -// old := log.Writer() -// var buf bytes.Buffer -// log.SetOutput(&buf) -// defer log.SetOutput(old) - -// plugins := New(testMeter) - -// transactionID := generateRandomID() -// modelPlugStatus := make(chan ModelStatus) -// switch c.requestType { -// case cf.RequestHeaders, cf.RequestBody, cf.AllRequest: -// go plugins.Process(c.id, transactionID, "test request", c.requestType, modelPlugStatus) -// <-modelPlugStatus -// if strings.Contains(buf.String(), "[trivial:ProcessRequest] \"test request\"") != c.executes { -// t.Errorf("case %s: expected to run trivial plugin: %v", c.id, c.executes) -// } -// if _, exists := plugins.results.Load(transactionID); exists != c.executes { -// t.Errorf("case %s: expected to store results: %v", c.id, c.executes) -// } -// case cf.ResponseHeaders, cf.ResponseBody, cf.AllResponse: -// go plugins.Process(c.id, transactionID, "test response", c.requestType, modelPlugStatus) -// <-modelPlugStatus -// if strings.Contains(buf.String(), "[trivial:ProcessResponse] \"test response\"") != c.executes { -// t.Errorf("case %s: expected to run trivial plugin: %v", c.id, c.executes) -// } -// if _, exists := plugins.results.Load(transactionID); exists != c.executes { -// t.Errorf("case %s: expected to store results: %v", c.id, c.executes) -// } -// } - -// } -// } - -// func TestProcessRequestInvalid(t *testing.T) { -// cases := []struct{ id, conf string }{ -// {"no_req", ` - id: "no_req" -// path: "../testdata/plugins/model/no_req.so" -// plugintype: "Everything" -// `}, -// {"wrong_req", ` - id: "wrong_req" -// path: "../testdata/plugins/model/wrong_req.so" -// plugintype: "Everything" -// `}, -// {"error_req", ` - id: "error_req" -// path: "../testdata/plugins/model/error_req.so" -// plugintype: "Everything" -// `}, -// } - -// // Test model plugin initialization -// for _, c := range cases { -// config := baseConfig + "modelplugins:\n" + trivialPlugin + c.conf - -// err := initilize([]byte(config)) -// if err != nil { -// t.Errorf("Error loading config: %v", err) -// } -// plugins := New(testMeter) - -// transactionID := generateRandomID() -// modelPlugStatus := make(chan ModelStatus) -// go plugins.Process(c.id, transactionID, "test request", cf.AllRequest, modelPlugStatus) -// <-modelPlugStatus -// go plugins.Process(c.id, transactionID, "test response", cf.AllResponse, modelPlugStatus) -// <-modelPlugStatus - -// if _, exists := plugins.results.Load(transactionID); exists { -// t.Errorf("invalid test %s stored a result", c.id) -// } -// } - -// config := baseConfig + "modelplugins:\n" + trivialPlugin - -// err := initilize([]byte(config)) -// if err != nil { -// t.Errorf("Error loading config: %v", err) -// } -// plugins := New(testMeter) - -// transactionID := generateRandomID() -// modelPlugStatus := make(chan ModelStatus) -// go plugins.Process("nonexistent", transactionID, "test request", cf.AllRequest, modelPlugStatus) -// <-modelPlugStatus -// go plugins.Process("nonexistent", transactionID, "test response", cf.AllResponse, modelPlugStatus) -// <-modelPlugStatus - -// if _, exists := plugins.results.Load(transactionID); exists { -// t.Errorf("nonexistent test stored a result") -// } - -// } - -// func TestCheckResultInvalid(t *testing.T) { -// cases := []struct{ id, conf string }{ -// {"no_check", ` - id: "no_check" -// path: "../testdata/plugins/decision/no_check.so" -// `}, -// {"wrong_check", ` - id: "wrong_check" -// path: "../testdata/plugins/decision/wrong_check.so" -// `}, -// {"error_check", ` - id: "error_check" -// path: "../testdata/plugins/decision/error_check.so" -// `}, -// } - -// // Test model plugin initialization -// for _, c := range cases { -// config := baseConfig + "modelplugins:\n" + trivialPlugin + "decisionplugins:\n" + c.conf - -// err := initilize([]byte(config)) -// if err != nil { -// t.Errorf("Error loading config: %v", err) -// } -// plugins := New(testMeter) - -// _, err = plugins.CheckResult(generateRandomID(), c.id, make(map[string]string)) -// if err == nil { -// t.Errorf("invalid CheckResult %s did not rise an error", c.id) -// } -// } - -// config := baseConfig + "modelplugins:\n" + trivialPlugin + "decisionplugins:\n" + testPlugin - -// err := initilize([]byte(config)) -// if err != nil { -// t.Errorf("Error loading config: %v", err) -// } -// plugins := New(testMeter) - -// _, err = plugins.CheckResult(generateRandomID(), "nonexitent", make(map[string]string)) -// if err == nil { -// t.Errorf("nonexistent plugin did not rise an error") -// } - -// } +package pluginmanager + +import ( + "math/rand" + "time" + + "github.com/tilsor/ModSecIntl_wace_lib/configstore" + "go.opentelemetry.io/otel/sdk/metric" + "gopkg.in/yaml.v3" + + "github.com/tilsor/ModSecIntl_logging/logging" +) + +var baseConfig = `--- +logpath: "/tmp/wacetmp.log" +loglevel: "WARN" +` + +var trivialPlugin = ` - id: "trivial" + path: "../testdata/plugins/model/trivial.so" + weight: 1 + params: + param1: "first value" + param2: "second value" + param3: "third value" + plugintype: "Everything" + mode: sync +` + +var testPlugin = ` - id: "test" + path: "../testdata/plugins/decision/test.so" + wafweight: 0.5 + decisionbalance: 0.5 + params: + test1: "test" + test2: "testtest" + test3: "testtesttest" +` + +func generateRandomID() string { + letters := "1234567890ABCDEF" + id := "" + for i := 0; i < 16; i++ { + id += string(letters[rand.Intn(len(letters))]) + } + + return id +} + +var provider = metric.NewMeterProvider() +var testMeter = provider.Meter("example-meter") + +func initilize(configuration []byte) error { + var aux configstore.ConfigFileData + err := yaml.Unmarshal(configuration, &aux) + if err != nil { + return err + } + cs, err := configstore.Get() + if err != nil { + return err + } + err = cs.SetConfig(aux) + if err != nil { + return err + } + logger := logging.Get() + + err = logger.LoadLogger(cs.LogPath, cs.LogLevel) + if err != nil { + return err + + } + return nil +} + +func init() { + rand.Seed(time.Now().UnixNano()) + + logger := logging.Get() + err := logger.LoadLogger("/dev/null", logging.ERROR) + if err != nil { + panic("Error loading logger") + } +} + +// func TestPluginInit(t *testing.T) { +// cases := []struct{ id, conf string }{ +// // {"invalid_path", ` - id: "invalid_path" +// // path: "../testdata/plugins/model/nonexistent.so" +// // plugintype: "AllRequest" +// // `}, +// {"no_init", ` - id: "no_init" +// path: "../testdata/plugins/model/no_init.so" +// plugintype: "AllRequest" +// `}, +// {"wrong_init", ` - id: "wrong_init" +// path: "../testdata/plugins/model/wrong_init.so" +// plugintype: "AllRequest" +// `}, +// {"error_init", ` - id: "error_init" +// path: "../testdata/plugins/model/error_init.so" +// plugintype: "AllRequest" +// `}, +// } + +// // Test model plugin initialization +// for _, c := range cases { +// config := baseConfig + "modelplugins:\n" + trivialPlugin + c.conf + +// err := initilize([]byte(config)) +// if err != nil { +// t.Errorf("Error loading config: %v", err) +// } +// plugins := New(testMeter) +// if _, exists := plugins.modelPlugins["trivial"]; !exists { +// t.Errorf("trivial plugin not loaded") +// } +// if _, exists := plugins.modelPlugins[c.id]; exists { +// t.Errorf(c.id + " should not load") +// } +// } + +// // Test decision plugin initialization +// for _, c := range cases { +// config := baseConfig + "modelplugins:\n" + trivialPlugin + "decisionplugins:\n" + testPlugin + c.conf + +// err := initilize([]byte(config)) +// if err != nil { +// t.Errorf("Error loading config: %v", err) +// } +// plugins := New(testMeter) +// if _, exists := plugins.decisionPlugins["test"]; !exists { +// t.Errorf("test plugin not loaded") +// } +// if _, exists := plugins.decisionPlugins[c.id]; exists { +// t.Errorf(c.id + " should not load") +// } +// } + +// } + +// func TestPluginParams(t *testing.T) { +// config := baseConfig + "modelplugins:\n" + trivialPlugin + "decisionplugins:\n" + testPlugin + +// err := initilize([]byte(config)) +// if err != nil { +// t.Errorf("Error loading config: %v", err) +// } + +// var buf bytes.Buffer +// logger := lg.Get() +// err = logger.LoadLoggerWriter(&buf, lg.INFO) +// if err != nil { +// t.Errorf("Error loading logger: %v", err) +// } + +// plugins := New(testMeter) + +// if !strings.Contains(buf.String(), "[trivial:InitPlugin] map[param1:first value param2:second value param3:third value]") { +// t.Errorf("trivial plugin did not initialize correctly, got: %v, expected: %v", buf.String(), "[trivial:InitPlugin] map[param1:first value param2:second value param3:third value]") +// } +// if !strings.Contains(buf.String(), "[test:InitPlugin] map[test1:test test2:testtest test3:testtesttest]") { +// t.Errorf("test plugin did not initialize correctly") +// } + +// transactionID := generateRandomID() +// modelPlugStatus := make(chan ModelStatus) +// go plugins.Process("trivial", transactionID, "test request1", cf.AllRequest, modelPlugStatus) +// <-modelPlugStatus +// if !strings.Contains(buf.String(), "[trivial:ProcessRequest] \"test request1\"") { +// t.Errorf("trivial plugin did not analyze request") +// } + +// go plugins.Process("trivial", transactionID, "test response1", cf.AllResponse, modelPlugStatus) +// <-modelPlugStatus +// if !strings.Contains(buf.String(), "[trivial:ProcessResponse] \"test response1\"") { +// t.Errorf("trivial plugin did not analyze response") +// } + +// _, err = plugins.CheckResult(transactionID, "test", map[string]string{"anomalyscore": "100", "inboundthreshold": "10"}) +// if err != nil { +// t.Errorf("Error checking result: %v", err) +// } +// if !strings.Contains(buf.String(), "[test:CheckResults]") { +// t.Errorf("test plugin did not execute correctly") +// } +// if !strings.Contains(buf.String(), "modelRes: map[trivial:") { +// t.Errorf("trivial result is not stored in modelRes") +// } +// if !strings.Contains(buf.String(), "modelWeight: map[trivial:1]") { +// t.Errorf("trivial weight is not stored in modelWeight") +// } +// if !strings.Contains(buf.String(), "modelThres: map[trivial:0.5]") { +// t.Errorf("trivial threshold is not stored in modelWeight") +// } +// if !strings.Contains(buf.String(), "wafData: map[anomalyscore:100 inboundthreshold:10]") { +// t.Errorf("waf params are not stored in wafData") +// } +// } + +// func TestPluginType(t *testing.T) { +// cases := []struct { +// id string +// pluginType, requestType cf.ModelPluginType +// executes bool +// }{ +// {"req_headers-req_headers", cf.RequestHeaders, cf.RequestHeaders, true}, +// {"req_headers-resp_headers", cf.RequestHeaders, cf.ResponseHeaders, false}, +// {"req_headers-all_req", cf.RequestHeaders, cf.AllRequest, false}, +// {"all_req-req_headers", cf.AllRequest, cf.RequestHeaders, false}, +// {"all_req-all_resp", cf.AllRequest, cf.AllResponse, false}, + +// {"resp_headers-resp_headers", cf.ResponseHeaders, cf.ResponseHeaders, true}, +// {"resp_headers-req_headers", cf.ResponseHeaders, cf.RequestHeaders, false}, +// {"resp_headers-all_resp", cf.ResponseHeaders, cf.AllResponse, false}, +// {"all_resp-resp_headers", cf.AllResponse, cf.ResponseHeaders, false}, +// {"all_resp-all_req", cf.AllResponse, cf.AllRequest, false}, + +// {"everything-req_headers", cf.Everything, cf.RequestHeaders, true}, +// {"everything-all_req", cf.Everything, cf.AllRequest, true}, +// {"everything-resp_body", cf.Everything, cf.ResponseBody, true}, +// {"everything-all_resp", cf.Everything, cf.AllResponse, true}, +// } + +// for _, c := range cases { +// config := baseConfig + "modelplugins:\n" + +// " - id: \"" + c.id + "\"\n" + +// " path: \"../testdata/plugins/model/trivial.so\"\n" + +// " plugintype: \"" + c.pluginType.String() + "\"\n" + +// err := initilize([]byte(config)) +// if err != nil { +// t.Errorf("Error loading config: %v", err) +// } + +// old := log.Writer() +// var buf bytes.Buffer +// log.SetOutput(&buf) +// defer log.SetOutput(old) + +// plugins := New(testMeter) + +// transactionID := generateRandomID() +// modelPlugStatus := make(chan ModelStatus) +// switch c.requestType { +// case cf.RequestHeaders, cf.RequestBody, cf.AllRequest: +// go plugins.Process(c.id, transactionID, "test request", c.requestType, modelPlugStatus) +// <-modelPlugStatus +// if strings.Contains(buf.String(), "[trivial:ProcessRequest] \"test request\"") != c.executes { +// t.Errorf("case %s: expected to run trivial plugin: %v", c.id, c.executes) +// } +// if _, exists := plugins.results.Load(transactionID); exists != c.executes { +// t.Errorf("case %s: expected to store results: %v", c.id, c.executes) +// } +// case cf.ResponseHeaders, cf.ResponseBody, cf.AllResponse: +// go plugins.Process(c.id, transactionID, "test response", c.requestType, modelPlugStatus) +// <-modelPlugStatus +// if strings.Contains(buf.String(), "[trivial:ProcessResponse] \"test response\"") != c.executes { +// t.Errorf("case %s: expected to run trivial plugin: %v", c.id, c.executes) +// } +// if _, exists := plugins.results.Load(transactionID); exists != c.executes { +// t.Errorf("case %s: expected to store results: %v", c.id, c.executes) +// } +// } + +// } +// } + +// func TestProcessRequestInvalid(t *testing.T) { +// cases := []struct{ id, conf string }{ +// {"no_req", ` - id: "no_req" +// path: "../testdata/plugins/model/no_req.so" +// plugintype: "Everything" +// `}, +// {"wrong_req", ` - id: "wrong_req" +// path: "../testdata/plugins/model/wrong_req.so" +// plugintype: "Everything" +// `}, +// {"error_req", ` - id: "error_req" +// path: "../testdata/plugins/model/error_req.so" +// plugintype: "Everything" +// `}, +// } + +// // Test model plugin initialization +// for _, c := range cases { +// config := baseConfig + "modelplugins:\n" + trivialPlugin + c.conf + +// err := initilize([]byte(config)) +// if err != nil { +// t.Errorf("Error loading config: %v", err) +// } +// plugins := New(testMeter) + +// transactionID := generateRandomID() +// modelPlugStatus := make(chan ModelStatus) +// go plugins.Process(c.id, transactionID, "test request", cf.AllRequest, modelPlugStatus) +// <-modelPlugStatus +// go plugins.Process(c.id, transactionID, "test response", cf.AllResponse, modelPlugStatus) +// <-modelPlugStatus + +// if _, exists := plugins.results.Load(transactionID); exists { +// t.Errorf("invalid test %s stored a result", c.id) +// } +// } + +// config := baseConfig + "modelplugins:\n" + trivialPlugin + +// err := initilize([]byte(config)) +// if err != nil { +// t.Errorf("Error loading config: %v", err) +// } +// plugins := New(testMeter) + +// transactionID := generateRandomID() +// modelPlugStatus := make(chan ModelStatus) +// go plugins.Process("nonexistent", transactionID, "test request", cf.AllRequest, modelPlugStatus) +// <-modelPlugStatus +// go plugins.Process("nonexistent", transactionID, "test response", cf.AllResponse, modelPlugStatus) +// <-modelPlugStatus + +// if _, exists := plugins.results.Load(transactionID); exists { +// t.Errorf("nonexistent test stored a result") +// } + +// } + +// func TestCheckResultInvalid(t *testing.T) { +// cases := []struct{ id, conf string }{ +// {"no_check", ` - id: "no_check" +// path: "../testdata/plugins/decision/no_check.so" +// `}, +// {"wrong_check", ` - id: "wrong_check" +// path: "../testdata/plugins/decision/wrong_check.so" +// `}, +// {"error_check", ` - id: "error_check" +// path: "../testdata/plugins/decision/error_check.so" +// `}, +// } + +// // Test model plugin initialization +// for _, c := range cases { +// config := baseConfig + "modelplugins:\n" + trivialPlugin + "decisionplugins:\n" + c.conf + +// err := initilize([]byte(config)) +// if err != nil { +// t.Errorf("Error loading config: %v", err) +// } +// plugins := New(testMeter) + +// _, err = plugins.CheckResult(generateRandomID(), c.id, make(map[string]string)) +// if err == nil { +// t.Errorf("invalid CheckResult %s did not rise an error", c.id) +// } +// } + +// config := baseConfig + "modelplugins:\n" + trivialPlugin + "decisionplugins:\n" + testPlugin + +// err := initilize([]byte(config)) +// if err != nil { +// t.Errorf("Error loading config: %v", err) +// } +// plugins := New(testMeter) + +// _, err = plugins.CheckResult(generateRandomID(), "nonexitent", make(map[string]string)) +// if err == nil { +// t.Errorf("nonexistent plugin did not rise an error") +// } + +// } diff --git a/testdata/plugins/model/trivial.go b/testdata/plugins/model/trivial.go index 228b75c..046ff50 100644 --- a/testdata/plugins/model/trivial.go +++ b/testdata/plugins/model/trivial.go @@ -34,7 +34,7 @@ func InitPluginAsync(params map[string]string, meter metric.Meter, natsManager f func Process(input pm.ModelInput) (pm.ModelResults, error) { logger := lg.Get() - logger.TPrintf(lg.WARN, input.TransactionId, "[trivial:Process] \"%s\"\n", input.Payload) + logger.TPrintf(lg.WARN, input.TransactionId, "[trivial:Process] \"%v\"\n", input.Payload) result := pm.ModelResults{ ProbAttack: 0.0, Data: make(map[string]interface{}), diff --git a/testdata/plugins/model/trivial2.go b/testdata/plugins/model/trivial2.go index f04218c..7b9ed66 100644 --- a/testdata/plugins/model/trivial2.go +++ b/testdata/plugins/model/trivial2.go @@ -34,7 +34,7 @@ func InitPluginAsync(params map[string]string, meter metric.Meter, natsManager f func Process(input pm.ModelInput) (pm.ModelResults, error) { logger := lg.Get() - logger.TPrintf(lg.WARN, input.TransactionId, "[trivial2:Proccess] \"%s\"\n", input.Payload) + logger.TPrintf(lg.WARN, input.TransactionId, "[trivial2:Proccess] \"%v\"\n", input.Payload) result := pm.ModelResults{ ProbAttack: 1.0, Data: make(map[string]interface{}), diff --git a/testdata/plugins/model/trivial_async.go b/testdata/plugins/model/trivial_async.go index 3896fbc..9e38210 100644 --- a/testdata/plugins/model/trivial_async.go +++ b/testdata/plugins/model/trivial_async.go @@ -49,7 +49,7 @@ func InitPluginAsync(params map[string]string, meter metric.Meter, natsManager f func Process(input pm.ModelInput) (pm.ModelResults, error) { time.Sleep(time.Duration(sleepTime) * time.Second) logger := lg.Get() - logger.TPrintf(lg.WARN, input.TransactionId, "[trivial_async:Process] \"%s\"\n", input.Payload) + logger.TPrintf(lg.WARN, input.TransactionId, "[trivial_async:Process] \"%v\"\n", input.Payload) result := pm.ModelResults{ ProbAttack: 0.0, Data: make(map[string]interface{}), diff --git a/testdata/plugins/model/trivial_async2.go b/testdata/plugins/model/trivial_async2.go index 86c0243..79a0810 100644 --- a/testdata/plugins/model/trivial_async2.go +++ b/testdata/plugins/model/trivial_async2.go @@ -49,7 +49,7 @@ func InitPluginAsync(params map[string]string, meter metric.Meter, natsManager f func Process(input pm.ModelInput) (pm.ModelResults, error) { time.Sleep(time.Duration(sleepTime) * time.Second) logger := lg.Get() - logger.TPrintf(lg.WARN, input.TransactionId, "[trivial_async2:Process] \"%s\"\n", input.Payload) + logger.TPrintf(lg.WARN, input.TransactionId, "[trivial_async2:Process] \"%v\"\n", input.Payload) result := pm.ModelResults{ ProbAttack: 1.0, Data: make(map[string]interface{}), diff --git a/wacecore.go b/wacecore.go index 3d5388b..82f315b 100644 --- a/wacecore.go +++ b/wacecore.go @@ -1,265 +1,282 @@ -/* -The main package of WACE. -*/ -package wace - -import ( - "fmt" - "os" - "strings" - "sync" - "sync/atomic" - "time" - - cf "github.com/tilsor/ModSecIntl_wace_lib/configstore" - - pm "github.com/tilsor/ModSecIntl_wace_lib/pluginmanager" - - lg "github.com/tilsor/ModSecIntl_logging/logging" - - "context" - - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/metric" -) - -var plugins *pm.PluginManager -var ctx = context.Background() -var meter metric.Meter - -// transactionSync is a struct to syncronize the analysis of a given -// transaction. Each time callPlugins is executed, the counter is -// incremented. At the end of each callPlugins execution, a message is -// sent through the channel, to signal checkTransaction that it has -// finished analyzing the request. checkTransaction waits for Counter -// number of messages in the channel, before calling the decision -// plugin and sending the result to the client. -type transactionSync struct { - Channel chan string - Counter int64 -} - -var ( - // Sync map witg channels to receive a notification when all plugins finish - // processing a transaction - analysisMap sync.Map -) - -// addTransactionAnalysis adds a transaction to the analysis map. If the -// transaction already exists, it increments the counter of the transaction -// by one. -func addTransactionAnalysis(transactionID string) { - tSync := transactionSync{ - Channel: make(chan string), - Counter: 1, - } - value, loaded := analysisMap.LoadOrStore(transactionID, &tSync) - if loaded { - atomic.AddInt64(&value.(*transactionSync).Counter, 1) - } -} - -// callPlugins calls the model plugins in the given list, with the given input. -// It waits for all the synchronous model plugins to finish, and sends the -// result to the client. The asynchronous model plugins are executed in parallel -func callPlugins(input string, models []string, t cf.ModelPluginType, transactionId string) { - logger := lg.Get() - - // channel to receive the status of the execution of the analysis - // of all the model plugins executed - modelPlugStatus := make(chan pm.ModelStatus) - asyncModelPlugStatus := make(chan pm.ModelStatus) - - plugins.AddModelChannel(transactionId, t, asyncModelPlugStatus, "async") - plugins.AddModelChannel(transactionId, t, modelPlugStatus, "sync") - - conf := cf.Get() - - syncCounter := 0 - asyncCounter := 0 - - startTime := time.Now() - - for _, id := range models { - logger.TPrintf(lg.DEBUG, transactionId, "%s | calling from core", id) - if _, ok := conf.ModelPlugins[id]; !ok { - logger.TPrintf(lg.ERROR, transactionId, "core | model plugin %s not found", id) - } else { - if conf.ModelPlugins[id].PluginType != t { - logger.TPrintf(lg.ERROR, transactionId, "core | model plugin %s is not of type %s", id, t) - } else { - if conf.IsAsync(id) { - asyncCounter++ - go plugins.AddToQueue(id, transactionId, input) - } else { - if conf.ModelPlugins[id].Remote { - go plugins.AddToQueue(id, transactionId, input) - } else { - go plugins.Process(id, transactionId, input, t, modelPlugStatus) - } - syncCounter++ - } - } - } - } - - go func() { - logger.TPrintf(lg.DEBUG, transactionId, "core | waiting for %d async model plugins to finish", asyncCounter) - wg := sync.WaitGroup{} - wg.Add(asyncCounter) - for i := 0; i < asyncCounter; i++ { - // Await for the execution of the async model plugins - logger.TPrintf(lg.DEBUG, transactionId, "core | Waiting for async model plugin %d...", i+1) - status := <-asyncModelPlugStatus - if status.Err == nil { - logger.TPrintf(lg.DEBUG, transactionId, "%s async | success. Result: %.5f", status.ModelID, status.ProbAttack) - histogramMeter, err := meter.Int64Histogram("wace.model.duration.nanoseconds") - if err != nil { - logger.TPrintf(lg.WARN, transactionId, "core | failed to record duration metric: %v", err.Error()) - } - histogramMeter.Record(ctx, time.Since(startTime).Nanoseconds(), metric.WithAttributes( - attribute.String("model_id", status.ModelID), - attribute.String("model_mode", "async"), - attribute.Float64("attack_probability", status.ProbAttack))) - } else { - logger.TPrintf(lg.WARN, transactionId, "%s | %v", status.ModelID, status.Err) - } - wg.Done() - } - wg.Wait() - plugins.RemoveAsyncModelChannel(transactionId, t) - }() - - logger.TPrintf(lg.DEBUG, transactionId, "core | waiting for %d sync model plugins to finish", syncCounter) - for i := 0; i < syncCounter; i++ { - // Await for the execution of the model plugins - logger.TPrintf(lg.DEBUG, transactionId, "core | Waiting for sync model plugin %d...", i+1) - status := <-modelPlugStatus - if status.Err == nil { - logger.TPrintf(lg.DEBUG, transactionId, "%s sync | success. Result: %.5f", status.ModelID, status.ProbAttack) - - histogramMeter, err := meter.Int64Histogram("wace.model.duration.nanoseconds") - if err != nil { - logger.TPrintf(lg.WARN, transactionId, "core | failed to record duration metric: %v", err.Error()) - } - histogramMeter.Record(ctx, time.Since(startTime).Nanoseconds(), metric.WithAttributes( - attribute.String("model_id", status.ModelID), - attribute.String("model_mode", "sync"), - attribute.Float64("attack_probability", status.ProbAttack))) - } else { - logger.TPrintf(lg.WARN, transactionId, "%s | %v", status.ModelID, status.Err) - } - } - - value, ok := analysisMap.Load(transactionId) - if !ok { - logger.TPrintf(lg.ERROR, transactionId, "core | could not find transaction %s in analysis map", transactionId) - return - } - analysisChan := value.(*transactionSync).Channel - analysisChan <- "done" -} - -// InitTransaction initializes a transaction with the given id -func InitTransaction(transactionId string) { - logger := lg.Get() - logger.StartTransaction(transactionId) - logger.TPrintf(lg.DEBUG, transactionId, "core | initializing transaction") - tSync := transactionSync{ - Channel: make(chan string), - Counter: 0, - } - analysisMap.Store(transactionId, &tSync) - plugins.InitTransaction(transactionId) -} - -// Analyze calls the model plugins with the given payload and models -func Analyze(modelsTypeAsString, transactionId, payload string, models []string) error { - if len(models) > 0 { - logger := lg.Get() - modelsType, err := cf.StringToPluginType(modelsTypeAsString) - if err != nil { - logger.TPrintf(lg.ERROR, transactionId, "core | %s is not a valid type", modelsTypeAsString) - return err - } - logger.TPrintf(lg.DEBUG, transactionId, "core | analyzing %s: [%s...]", modelsTypeAsString, strings.Split(payload, "\n")[0]) - addTransactionAnalysis(transactionId) - go callPlugins(payload, models, modelsType, transactionId) - } - return nil -} - -// CheckTransaction checks the result of the analysis of the transaction -// with the given id and decision plugin -func CheckTransaction(transactionID, decisionPlugin string, wafParams map[string]string) (bool, error) { - logger := lg.Get() - logger.TPrintf(lg.DEBUG, transactionID, "core | checking transaction") - - value, exists := analysisMap.Load(transactionID) - - if !exists { - return false, fmt.Errorf("transaction with id %s does not exist", transactionID) - } - - sync := value.(*transactionSync) - - logger.TPrintln(lg.DEBUG, transactionID, "core | waiting for all models to finish...") - - for i := 0; i < int(sync.Counter); i++ { - <-sync.Channel - } - sync.Counter = 0 - - logger.TPrintln(lg.DEBUG, transactionID, "core | done, checking data...") - res, err := plugins.CheckResult(transactionID, decisionPlugin, wafParams) - - if err == nil { - logger.TPrintf(lg.DEBUG, transactionID, "core | transaction checked successfully. Blocking transaction: %t", res) - - if res { - metric, err := meter.Int64Counter("wace.client.request.blocked.total", metric.WithDescription(decisionPlugin)) - if err != nil { - logger.TPrintf(lg.WARN, transactionID, "core | failed to record blocked request metric: %v", err.Error()) - } - metric.Add(ctx, 1) - } - } else { - logger.TPrintf(lg.ERROR, transactionID, "core | could not check transaction: %v", err) - } - return res, err -} - -// CloseTransaction closes the transaction with the given id -// removing the transaction sync model results -func CloseTransaction(transactionID string) { - plugins.CloseTransaction(transactionID) - value, ok := analysisMap.Load(transactionID) - logger := lg.Get() - - if !ok { - logger.TPrintf(lg.ERROR, transactionID, "Analysis for transaction %s not found", transactionID) - } else { - close(value.(*transactionSync).Channel) - for range value.(*transactionSync).Channel {} - analysisMap.Delete(transactionID) - } -} - -// Init initializes the WACE core with the given metric meter -func Init(met metric.Meter) { - logger := lg.Get() - conf := cf.Get() - meter = met - - err := logger.LoadLogger(conf.LogPath, conf.LogLevel) - if err != nil { - logger.Printf(lg.ERROR, "ERROR: could not open wace log file: %v", err) - os.Exit(1) - - } - logger.Printf(lg.DEBUG, "Writing logs to %s from now", conf.LogPath) - - logger.Println(lg.DEBUG, "Loading plugin manager...") - plugins = pm.New(met) - logger.Println(lg.DEBUG, "Plugin manager loaded") -} +/* +The main package of WACE. +*/ +package wace + +import ( + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/tilsor/ModSecIntl_wace_lib/configstore" + + "github.com/tilsor/ModSecIntl_wace_lib/pluginmanager" + + "github.com/tilsor/ModSecIntl_logging/logging" + + "context" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" +) + +var plugins *pluginmanager.PluginManager +var ctx = context.Background() +var meter metric.Meter + +// transactionSync is a struct to syncronize the analysis of a given +// transaction. Each time callPlugins is executed, the counter is +// incremented. At the end of each callPlugins execution, a message is +// sent through the channel, to signal checkTransaction that it has +// finished analyzing the request. checkTransaction waits for Counter +// number of messages in the channel, before calling the decision +// plugin and sending the result to the client. +type transactionSync struct { + Channel chan string + Counter int64 +} + +var ( + // Sync map witg channels to receive a notification when all plugins finish + // processing a transaction + analysisMap sync.Map +) + +// addTransactionAnalysis adds a transaction to the analysis map. If the +// transaction already exists, it increments the counter of the transaction +// by one. +func addTransactionAnalysis(transactionID string) { + tSync := transactionSync{ + Channel: make(chan string), + Counter: 1, + } + value, loaded := analysisMap.LoadOrStore(transactionID, &tSync) + if loaded { + atomic.AddInt64(&value.(*transactionSync).Counter, 1) + } +} + +// callPlugins calls the model plugins in the given list, with the given input. +// It waits for all the synchronous model plugins to finish, and sends the +// result to the client. The asynchronous model plugins are executed in parallel +func callPlugins(input pluginmanager.HTTPPayload, models []string, t configstore.ModelPluginType, transactionID string) error { + logger := logging.Get() + + // channel to receive the status of the execution of the analysis + // of all the model plugins executed + modelPluginStatus := make(chan pluginmanager.ModelStatus) + asyncModelPluginStatus := make(chan pluginmanager.ModelStatus) + + plugins.AddModelChannel(transactionID, t, asyncModelPluginStatus, "async") + plugins.AddModelChannel(transactionID, t, modelPluginStatus, "sync") + + conf, err := configstore.Get() + if err != nil { + return err + } + + syncCounter := 0 + asyncCounter := 0 + + startTime := time.Now() + + for _, id := range models { + logger.TPrintf(logging.DEBUG, transactionID, "%s | calling from core", id) + if _, ok := conf.ModelPlugins[id]; !ok { + logger.TPrintf(logging.ERROR, transactionID, "core | model plugin %s not found", id) + } else { + if conf.ModelPlugins[id].PluginType != t { + logger.TPrintf(logging.ERROR, transactionID, "core | model plugin %s is not of type %s", id, t) + } else { + if conf.IsAsync(id) { + asyncCounter++ + go plugins.AddToQueue(id, transactionID, input) + } else { + if conf.ModelPlugins[id].Remote { + go plugins.AddToQueue(id, transactionID, input) + } else { + go plugins.Process(id, transactionID, input, t, modelPluginStatus) + } + syncCounter++ + } + } + } + } + + go func() { + logger.TPrintf(logging.DEBUG, transactionID, "core | waiting for %d async model plugins to finish", asyncCounter) + wg := sync.WaitGroup{} + wg.Add(asyncCounter) + for i := 0; i < asyncCounter; i++ { + // Await for the execution of the async model plugins + logger.TPrintf(logging.DEBUG, transactionID, "core | Waiting for async model plugin %d...", i+1) + status := <-asyncModelPluginStatus + if status.Err == nil { + logger.TPrintf(logging.DEBUG, transactionID, "%s async | success. Result: %.5f", status.ModelID, status.ProbAttack) + histogramMeter, err := meter.Int64Histogram("wace.model.duration.nanoseconds") + if err != nil { + logger.TPrintf(logging.WARN, transactionID, "core | failed to record duration metric: %v", err.Error()) + } + histogramMeter.Record(ctx, time.Since(startTime).Nanoseconds(), metric.WithAttributes( + attribute.String("model_id", status.ModelID), + attribute.String("model_mode", "async"), + attribute.Float64("attack_probability", status.ProbAttack))) + } else { + logger.TPrintf(logging.WARN, transactionID, "%s | %v", status.ModelID, status.Err) + } + wg.Done() + } + wg.Wait() + plugins.RemoveAsyncModelChannel(transactionID, t) + }() + + logger.TPrintf(logging.DEBUG, transactionID, "core | waiting for %d sync model plugins to finish", syncCounter) + for i := 0; i < syncCounter; i++ { + // Await for the execution of the model plugins + logger.TPrintf(logging.DEBUG, transactionID, "core | Waiting for sync model plugin %d...", i+1) + status := <-modelPluginStatus + if status.Err == nil { + logger.TPrintf(logging.DEBUG, transactionID, "%s sync | success. Result: %.5f", status.ModelID, status.ProbAttack) + + histogramMeter, err := meter.Int64Histogram("wace.model.duration.nanoseconds") + if err != nil { + logger.TPrintf(logging.WARN, transactionID, "core | failed to record duration metric: %v", err.Error()) + } + histogramMeter.Record(ctx, time.Since(startTime).Nanoseconds(), metric.WithAttributes( + attribute.String("model_id", status.ModelID), + attribute.String("model_mode", "sync"), + attribute.Float64("attack_probability", status.ProbAttack))) + } else { + logger.TPrintf(logging.WARN, transactionID, "%s | %v", status.ModelID, status.Err) + } + } + + value, ok := analysisMap.Load(transactionID) + if !ok { + logger.TPrintf(logging.ERROR, transactionID, "core | could not find transaction %s in analysis map", transactionID) + return fmt.Errorf("core | could not find transaction %s in analysis map", transactionID) + } + analysisChan := value.(*transactionSync).Channel + analysisChan <- "done" + return nil +} + +// InitTransaction initializes a transaction with the given id +func InitTransaction(transactionId string) { + logger := logging.Get() + logger.StartTransaction(transactionId) + logger.TPrintf(logging.DEBUG, transactionId, "core | initializing transaction") + tSync := transactionSync{ + Channel: make(chan string), + Counter: 0, + } + analysisMap.Store(transactionId, &tSync) + plugins.InitTransaction(transactionId) +} + +// Analyze calls the model plugins with the given payload and models +func Analyze(modelsTypeAsString, transactionId string, payload pluginmanager.HTTPPayload, models []string) error { + if len(models) > 0 { + logger := logging.Get() + modelsType, err := configstore.StringToPluginType(modelsTypeAsString) + if err != nil { + logger.TPrintf(logging.ERROR, transactionId, "core | %s is not a valid type", modelsTypeAsString) + return err + } + logger.TPrintf(logging.DEBUG, transactionId, "core | analyzing %s: [%v...]", modelsTypeAsString, payload) + addTransactionAnalysis(transactionId) + go callPlugins(payload, models, modelsType, transactionId) + } + return nil +} + +// CheckTransaction checks the result of the analysis of the transaction +// with the given id and decision plugin +func CheckTransaction(transactionID, decisionPlugin string, wafParams map[string]string) (bool, error) { + logger := logging.Get() + logger.TPrintf(logging.DEBUG, transactionID, "core | checking transaction") + + value, exists := analysisMap.Load(transactionID) + + if !exists { + return false, fmt.Errorf("transaction with id %s does not exist", transactionID) + } + + sync := value.(*transactionSync) + + logger.TPrintln(logging.DEBUG, transactionID, "core | waiting for all models to finish...") + + for i := 0; i < int(sync.Counter); i++ { + <-sync.Channel + } + sync.Counter = 0 + + logger.TPrintln(logging.DEBUG, transactionID, "core | done, checking data...") + res, err := plugins.CheckResult(transactionID, decisionPlugin, wafParams) + + if err == nil { + logger.TPrintf(logging.DEBUG, transactionID, "core | transaction checked successfully. Blocking transaction: %t", res) + + if res { + metric, err := meter.Int64Counter("wace.client.request.blocked.total", metric.WithDescription(decisionPlugin)) + if err != nil { + logger.TPrintf(logging.WARN, transactionID, "core | failed to record blocked request metric: %v", err.Error()) + } + metric.Add(ctx, 1) + } + } else { + logger.TPrintf(logging.ERROR, transactionID, "core | could not check transaction: %v", err) + } + return res, err +} + +// CloseTransaction closes the transaction with the given id +// removing the transaction sync model results +func CloseTransaction(transactionID string) { + plugins.CloseTransaction(transactionID) + value, ok := analysisMap.Load(transactionID) + logger := logging.Get() + + if !ok { + logger.TPrintf(logging.ERROR, transactionID, "Analysis for transaction %s not found", transactionID) + } else { + close(value.(*transactionSync).Channel) + for range value.(*transactionSync).Channel { + } + analysisMap.Delete(transactionID) + } +} + +// Init initializes the WACE core with the given metric meter +func Init(met metric.Meter, conf configstore.ConfigFileData) error { + logger := logging.Get() + + cs, err := configstore.New() + if err != nil { + return err + } + + err = cs.SetConfig(conf) + if err != nil { + return err + } + + meter = met + + err = logger.LoadLogger(cs.LogPath, cs.LogLevel) + if err != nil { + logger.Printf(logging.ERROR, "ERROR: could not open wace log file: %v", err) + return err + } + logger.Printf(logging.DEBUG, "Writing logs to %s from now", cs.LogPath) + + logger.Println(logging.DEBUG, "Loading plugin manager...") + plugins, err = pluginmanager.New(met) + if err != nil { + return err + } + logger.Println(logging.DEBUG, "Plugin manager loaded") + + return nil +} diff --git a/wacecore_test.go b/wacecore_test.go index 1c2a34a..5115dae 100644 --- a/wacecore_test.go +++ b/wacecore_test.go @@ -1,533 +1,565 @@ -package wace - -import ( - "math/rand" - "strconv" - "strings" - "testing" - "time" - - cf "github.com/tilsor/ModSecIntl_wace_lib/configstore" - "go.opentelemetry.io/otel/sdk/metric" - - "gopkg.in/yaml.v3" -) - -var requestLine = "POST /cgi-bin/process.cgi HTTP/1.1\n" -var requestHeaders = `User-Agent: Mozilla/4.0 (compatible; MSIE5.01; Windows NT) -Host: www.tutorialspoint.com -Content-Type: application/x-www-form-urlencoded -Content-Length: length -Accept-Language: en-us -Accept-Encoding: gzip, deflate -Connection: Keep-Alive -` - -var requestBody = "licenseID=string&content=string&/paramsXML=string\n" -var wholeRequest = requestLine + requestHeaders + "\n" + requestBody - -var responseLine = "HTTP/1.1 200 OK\n" -var responseHeaders = `Date: Mon, 27 Jul 2009 12:28:53 GMT -Server: Apache/2.2.14 (Win32) -Last-Modified: Wed, 22 Jul 2009 19:15:56 GMT -Content-Length: 88 -Content-Type: text/html -Connection: Closed -` -var responseBody = ` -
-