diff --git a/.github/workflows/go.yaml b/.github/workflows/go.yaml index a27b3ba..de4e22a 100644 --- a/.github/workflows/go.yaml +++ b/.github/workflows/go.yaml @@ -21,10 +21,5 @@ jobs: with: go-version: '1.22.2' - - name: Make - run: | - cd testdata/plugins - make - - name: Test - run: go test -v ./... + run: go run mage.go test diff --git a/go.mod b/go.mod index aabc361..dee7938 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/tilsor/ModSecIntl_wace_lib go 1.26.2 require ( + github.com/magefile/mage v1.17.2 github.com/nats-io/nats.go v1.51.0 github.com/tilsor/ModSecIntl_logging v1.0.1 go.opentelemetry.io/otel v1.43.0 diff --git a/go.sum b/go.sum index 7b259f6..d93fb52 100644 --- a/go.sum +++ b/go.sum @@ -17,6 +17,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/magefile/mage v1.17.2 h1:fyXVu1eadI8Ap1HCCNgEhJ5McIWiYhLR8uol64ZZc40= +github.com/magefile/mage v1.17.2/go.mod h1:Yj51kqllmsgFpvvSzgrZPK9WtluG3kUhFaBUVLo4feA= github.com/nats-io/nats.go v1.51.0 h1:ByW84XTz6W03GSSsygsZcA+xgKK8vPGaa/FCAAEHnAI= github.com/nats-io/nats.go v1.51.0/go.mod h1:26HypzazeOkyO3/mqd1zZd53STJN0EjCYF9Uy2ZOBno= github.com/nats-io/nkeys v0.4.15 h1:JACV5jRVO9V856KOapQ7x+EY8Jo3qw1vJt/9Jpwzkk4= diff --git a/mage.go b/mage.go new file mode 100644 index 0000000..ab1e54e --- /dev/null +++ b/mage.go @@ -0,0 +1,16 @@ +//go:build ignore +// +build ignore + +// Entrypoint to mage for running without needing to install the command. +// https://magefile.org/zeroinstall/ +package main + +import ( + "os" + + "github.com/magefile/mage/mage" +) + +func main() { + os.Exit(mage.Main()) +} diff --git a/magefile.go b/magefile.go new file mode 100644 index 0000000..a656f4f --- /dev/null +++ b/magefile.go @@ -0,0 +1,81 @@ +//go:build mage + +package main + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/magefile/mage/mg" + "github.com/magefile/mage/sh" +) + +const pluginDir = "testdata/plugins" + +// Plugins builds all test plugins without coverage instrumentation. +func Plugins() error { + return buildPlugins() +} + +// pluginsCover builds all test plugins with coverage instrumentation. +func pluginsCover() error { + return buildPlugins("-cover") +} + +// buildPlugins compiles every .go file under testdata/plugins/model and +// testdata/plugins/decision into a .so plugin. Pass "-cover" to instrument +// for coverage. +func buildPlugins(extraFlags ...string) error { + for _, dir := range []string{"model", "decision"} { + sources, err := filepath.Glob(filepath.Join(pluginDir, dir, "*.go")) + if err != nil { + return err + } + for _, src := range sources { + out := strings.TrimSuffix(src, ".go") + ".so" + args := []string{"build", "-buildmode=plugin"} + args = append(args, extraFlags...) + args = append(args, "-o", out, src) + fmt.Printf("building %s\n", out) + if err := sh.RunV("go", args...); err != nil { + return err + } + } + } + return nil +} + +// Test builds the plugins and runs the full test suite. +func Test() error { + mg.Deps(Plugins) + return sh.RunV("go", "test", "./...", "-v", "-count=1") +} + +// TestCoverage builds coverage-instrumented plugins and runs the test suite +// with coverage reporting across all packages. +func TestCoverage() error { + mg.Deps(pluginsCover) + return sh.RunV("go", "test", "-cover", "./...", "-v", "-count=1", "-coverprofile=coverage.out") +} + +// Clean removes all compiled plugin .so files. +func Clean() error { + for _, pattern := range []string{ + filepath.Join(pluginDir, "model", "*.so"), + filepath.Join(pluginDir, "decision", "*.so"), + } { + matches, err := filepath.Glob(pattern) + if err != nil { + return err + } + for _, f := range matches { + fmt.Printf("removing %s\n", f) + if err := os.Remove(f); err != nil { + return err + } + } + } + return nil +} diff --git a/pluginmanager/pluginmanager.go b/pluginmanager/pluginmanager.go index 3c63503..34cd431 100644 --- a/pluginmanager/pluginmanager.go +++ b/pluginmanager/pluginmanager.go @@ -11,68 +11,32 @@ import ( "sync" "github.com/tilsor/ModSecIntl_wace_lib/configstore" + "github.com/tilsor/ModSecIntl_wace_lib/waceapi" "go.opentelemetry.io/otel/metric" "github.com/nats-io/nats.go" "github.com/tilsor/ModSecIntl_logging/logging" ) -// ResultData maps the model plugin ID with the corresponding analysis result. -type ModelResults struct { - ProbAttack float64 `json:"probattack"` - Data map[string]interface{} `json:"data"` -} - -type HTTPHeader struct { - Key string - Value string -} - -type HTTPPayload struct { - URI string - Method string - HTTPVersion string - RequestHeaders []HTTPHeader - RequestBody string - ResponseProtocol string - ResponseCode int - ResponseHeaders []HTTPHeader - ResponseBody string -} - -// ModelInput is the struct that contains the input data for the model plugin -type ModelInput struct { - TransactionId string `json:"transactionId"` - Payload HTTPPayload `json:"payload"` -} - -// DecisionInput is the struct that contains the input data for the decision plugin -type DecisionInput struct { - TransactionId string - Results map[string]ModelResults - ModelWeight map[string]float64 - WAFdata map[string]string -} - // ModelTransmitionResults is the struct that contains the results of the model plugin type ModelTransmitionResults struct { - TransactionId string `json:"transactionId"` - ModelResults `json:",inline"` - Error error `json:"error"` + TransactionId string `json:"transactionId"` + waceapi.ModelResults `json:",inline"` + Error error `json:"error"` } // modelPlugin is the struct that stores the model plugin and its type type modelPlugin struct { p *plugin.Plugin pluginType configstore.ModelPluginType - process func(ModelInput) (ModelResults, error) + process func(waceapi.ModelInput) (waceapi.ModelResults, error) reload func(map[string]string, metric.Meter) error } // decisionPlugin is the struct that stores the decision plugin type decisionPlugin struct { p *plugin.Plugin - checkResults func(DecisionInput) (bool, error) + checkResults func(waceapi.DecisionInput) (bool, error) reload func(map[string]string, metric.Meter) error } @@ -162,7 +126,7 @@ func (pm *PluginManager) loadModelPlugins(meter metric.Meter) error { logger.Printf(logging.WARN, "| %s | cannot open plugin: %v", data.ID, err) continue } - var processFunc func(ModelInput) (ModelResults, error) + var processFunc func(waceapi.ModelInput) (waceapi.ModelResults, error) // TODO: change mode to bool if data.Mode == "async" || conf.ModelPlugins[data.ID].Remote { f, err := p.Lookup(modelInitAsyncFunctionName) @@ -170,14 +134,14 @@ func (pm *PluginManager) loadModelPlugins(meter metric.Meter) error { logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) continue } - initPlugin, ok := f.(func(map[string]string, metric.Meter, func(func(ModelInput) (ModelResults, error))) error) + initPlugin, ok := f.(func(map[string]string, metric.Meter, func(func(waceapi.ModelInput) (waceapi.ModelResults, error))) error) if !ok { logger.Printf(logging.WARN, "| %s | cannot load plugin: invalid %s function type", data.ID, modelInitAsyncFunctionName) continue } // plugin initialization - err = initPlugin(data.Params, meter, func(modelProcess func(ModelInput) (ModelResults, error)) { + err = initPlugin(data.Params, meter, func(modelProcess func(waceapi.ModelInput) (waceapi.ModelResults, error)) { ModelProcessHandler(data.ID, modelProcess) }) if err != nil { @@ -204,7 +168,7 @@ func (pm *PluginManager) loadModelPlugins(meter metric.Meter) error { logger.Printf(logging.WARN, "| %s | cannot load plugin: cannot load %s function", data.ID, modelProcessFunctionName) continue } - processFunc, ok = procFunc.(func(ModelInput) (ModelResults, error)) + processFunc, ok = procFunc.(func(waceapi.ModelInput) (waceapi.ModelResults, error)) if !ok { logger.Printf(logging.WARN, "| %s | cannot load plugin: invalid %s function type", data.ID, modelProcessFunctionName) continue @@ -273,7 +237,7 @@ func (pm *PluginManager) loadDecisionPlugins(meter metric.Meter) error { logger.Printf(logging.ERROR, "| %s | cannot load plugin %s function: %v", data.ID, decisionCheckFuncionName, err) continue } - checkResults, ok := checkFunc.(func(DecisionInput) (bool, error)) + checkResults, ok := checkFunc.(func(waceapi.DecisionInput) (bool, error)) if !ok { logger.Printf(logging.ERROR, "| %s | %s lookup failed for plugin: invalid function type", data.ID, decisionCheckFuncionName) continue @@ -379,8 +343,8 @@ func (p *PluginManager) RemoveAsyncModelChannel(transactionId string, t configst } // AddToQueue adds a payload to the model queue -func (p *PluginManager) AddToQueue(modelID, transactionID string, payload HTTPPayload) error { - payloadToSend := &ModelInput{ +func (p *PluginManager) AddToQueue(modelID, transactionID string, payload waceapi.HTTPPayload) error { + payloadToSend := &waceapi.ModelInput{ TransactionId: transactionID, Payload: payload, } @@ -395,7 +359,7 @@ func (p *PluginManager) AddToQueue(modelID, transactionID string, payload HTTPPa } // Process is in charge of calling the model plugin with id modelID -func (p *PluginManager) Process(modelID, transactionId string, payload HTTPPayload, t configstore.ModelPluginType, modelPlugStatus chan ModelStatus) error { +func (p *PluginManager) Process(modelID, transactionId string, payload waceapi.HTTPPayload, t configstore.ModelPluginType, modelPlugStatus chan ModelStatus) error { conf, err := configstore.Get() if err != nil { return err @@ -423,7 +387,7 @@ func (p *PluginManager) Process(modelID, transactionId string, payload HTTPPaylo modelPlugStatus <- ModelStatus{ModelID: modelID, Err: fmt.Errorf("model plugin is async")} return nil } else { - res, err := mp.process(ModelInput{TransactionId: transactionId, Payload: payload}) + res, err := mp.process(waceapi.ModelInput{TransactionId: transactionId, Payload: payload}) if err != nil { modelPlugStatus <- ModelStatus{ModelID: modelID, Err: err} @@ -461,15 +425,15 @@ func (p *PluginManager) CheckResult(transactionId, decisionId string, wafParams return false, nil } - modelResultMap := make(map[string]ModelResults) + modelResultMap := make(map[string]waceapi.ModelResults) modelWeightMap := make(map[string]float64) transactionResults.(*sync.Map).Range(func(key, value interface{}) bool { - modelResultMap[key.(string)] = value.(ModelResults) + modelResultMap[key.(string)] = value.(waceapi.ModelResults) modelWeightMap[key.(string)] = cs.ModelPlugins[key.(string)].Weight return true }) - res, err := dp.checkResults(DecisionInput{TransactionId: transactionId, Results: modelResultMap, ModelWeight: modelWeightMap, WAFdata: wafParams}) + res, err := dp.checkResults(waceapi.DecisionInput{TransactionId: transactionId, Results: modelResultMap, ModelWeight: modelWeightMap, WAFdata: wafParams}) logger.TPrintf(logging.INFO, transactionId, "%s | transaction checked. Block: %t ", decisionId, res) return res, err @@ -514,7 +478,7 @@ func (p *PluginManager) ModelResultsHandler(modelId string) error { modelChannel.(chan ModelStatus) <- ModelStatus{ModelID: modelId, Err: fmt.Errorf("transaction results not found")} return } - modelResult := ModelResults{ProbAttack: data.ProbAttack, Data: data.Data} + modelResult := waceapi.ModelResults{ProbAttack: data.ProbAttack, Data: data.Data} resultSyncMap.(*sync.Map).Store(modelId, modelResult) } modelChannel.(chan ModelStatus) <- ModelStatus{ModelID: modelId, ProbAttack: data.ProbAttack, Err: nil} @@ -540,7 +504,7 @@ func (p *PluginManager) ModelResultsHandler(modelId string) error { } // ModelProcessHandler listens for messages on the model queue -func ModelProcessHandler(modelId string, modelProcess func(ModelInput) (ModelResults, error)) error { +func ModelProcessHandler(modelId string, modelProcess func(waceapi.ModelInput) (waceapi.ModelResults, error)) error { logger := logging.Get() logger.Printf(logging.INFO, "Model: %s | Starting model process handler", modelId) cs, err := configstore.Get() @@ -557,13 +521,13 @@ func ModelProcessHandler(modelId string, modelProcess func(ModelInput) (ModelRes _, err = nc.Subscribe(modelId, func(msg *nats.Msg) { go func(msg nats.Msg) { - data := &ModelInput{} + data := &waceapi.ModelInput{} err := json.Unmarshal(msg.Data, data) if err != nil { logger.Printf(logging.ERROR, "Model: %s | Failed to parse JSON payload", modelId) } else { res, err := modelProcess(*data) - modelResult := ModelResults{ProbAttack: res.ProbAttack, Data: res.Data} + modelResult := waceapi.ModelResults{ProbAttack: res.ProbAttack, Data: res.Data} payloadToSend := &ModelTransmitionResults{ TransactionId: data.TransactionId, ModelResults: modelResult, diff --git a/pluginmanager/pluginmanager_test.go b/pluginmanager/pluginmanager_test.go index f14edfb..ce1af1d 100644 --- a/pluginmanager/pluginmanager_test.go +++ b/pluginmanager/pluginmanager_test.go @@ -1,4 +1,4 @@ -package pluginmanager_test +package pluginmanager import ( "math/rand" @@ -7,7 +7,7 @@ import ( "github.com/tilsor/ModSecIntl_logging/logging" "github.com/tilsor/ModSecIntl_wace_lib/configstore" - "github.com/tilsor/ModSecIntl_wace_lib/pluginmanager" + "github.com/tilsor/ModSecIntl_wace_lib/waceapi" "go.opentelemetry.io/otel/sdk/metric" "gopkg.in/yaml.v3" ) @@ -130,7 +130,7 @@ func generateRandomID() string { // setupPluginManager creates a fresh ConfigStore from the given YAML config, // returns an initialised PluginManager, and registers configstore.Clean as a // test cleanup function. -func setupPluginManager(t *testing.T, configuration []byte) *pluginmanager.PluginManager { +func setupPluginManager(t *testing.T, configuration []byte) *PluginManager { t.Helper() configstore.Clean() cs, err := configstore.New() @@ -150,9 +150,9 @@ func setupPluginManager(t *testing.T, configuration []byte) *pluginmanager.Plugi if err := logger.LoadLogger(cs.LogPath, cs.LogLevel); err != nil { t.Fatalf("LoadLogger failed: %v", err) } - pm, err := pluginmanager.New(testMeter) + pm, err := New(testMeter) if err != nil { - t.Fatalf("pluginmanager.New() failed: %v", err) + t.Fatalf("New() failed: %v", err) } return pm } @@ -202,8 +202,8 @@ func TestPluginManagerProcessSync(t *testing.T) { pm.InitTransaction(txID) defer pm.CloseTransaction(txID) - ch := make(chan pluginmanager.ModelStatus, 1) - go pm.Process(tt.modelID, txID, pluginmanager.HTTPPayload{URI: "/test"}, configstore.Everything, ch) + ch := make(chan ModelStatus, 1) + go pm.Process(tt.modelID, txID, waceapi.HTTPPayload{URI: "/test"}, configstore.Everything, ch) status := <-ch if (status.Err != nil) != tt.wantErr { @@ -228,8 +228,8 @@ func TestPluginManagerProcessNonexistentPlugin(t *testing.T) { pm.InitTransaction(txID) defer pm.CloseTransaction(txID) - ch := make(chan pluginmanager.ModelStatus, 1) - go pm.Process("nonexistent", txID, pluginmanager.HTTPPayload{}, configstore.Everything, ch) + ch := make(chan ModelStatus, 1) + go pm.Process("nonexistent", txID, waceapi.HTTPPayload{}, configstore.Everything, ch) status := <-ch if status.Err == nil { t.Error("Process with nonexistent plugin ID should return error via channel") @@ -276,8 +276,8 @@ func TestPluginManagerCheckResult(t *testing.T) { pm.InitTransaction(txID) defer pm.CloseTransaction(txID) - ch := make(chan pluginmanager.ModelStatus, 1) - go pm.Process(tt.modelID, txID, pluginmanager.HTTPPayload{URI: "/test"}, configstore.Everything, ch) + ch := make(chan ModelStatus, 1) + go pm.Process(tt.modelID, txID, waceapi.HTTPPayload{URI: "/test"}, configstore.Everything, ch) <-ch result, err := pm.CheckResult(txID, "simple", tt.wafParams) @@ -312,8 +312,8 @@ func TestPluginManagerTransactionLifecycle(t *testing.T) { txID := generateRandomID() pm.InitTransaction(txID) - ch := make(chan pluginmanager.ModelStatus, 1) - go pm.Process("trivial", txID, pluginmanager.HTTPPayload{URI: "/test"}, configstore.Everything, ch) + ch := make(chan ModelStatus, 1) + go pm.Process("trivial", txID, waceapi.HTTPPayload{URI: "/test"}, configstore.Everything, ch) status := <-ch if status.Err != nil { t.Fatalf("Process error: %v", status.Err) @@ -360,8 +360,8 @@ func TestPluginManagerLoadModelFailures(t *testing.T) { pm.InitTransaction(txID) defer pm.CloseTransaction(txID) - ch := make(chan pluginmanager.ModelStatus, 1) - go pm.Process(tt.modelID, txID, pluginmanager.HTTPPayload{}, configstore.Everything, ch) + ch := make(chan ModelStatus, 1) + go pm.Process(tt.modelID, txID, waceapi.HTTPPayload{}, configstore.Everything, ch) status := <-ch if status.Err == nil { t.Errorf("Process(%q): expected error (plugin should not have been loaded)", tt.modelID) @@ -418,8 +418,8 @@ func TestPluginManagerReload(t *testing.T) { pm.InitTransaction(txID) defer pm.CloseTransaction(txID) - ch := make(chan pluginmanager.ModelStatus, 1) - go pm.Process("trivial", txID, pluginmanager.HTTPPayload{URI: "/test"}, configstore.Everything, ch) + ch := make(chan ModelStatus, 1) + go pm.Process("trivial", txID, waceapi.HTTPPayload{URI: "/test"}, configstore.Everything, ch) status := <-ch if status.Err != nil { t.Errorf("Process after Reload: unexpected error: %v", status.Err) @@ -447,8 +447,8 @@ func TestPluginManagerProcessTypeMismatch(t *testing.T) { defer pm.CloseTransaction(txID) // Pass Everything — does not match the registered RequestHeaders type. - ch := make(chan pluginmanager.ModelStatus, 1) - go pm.Process("trivial", txID, pluginmanager.HTTPPayload{URI: "/test"}, configstore.Everything, ch) + ch := make(chan ModelStatus, 1) + go pm.Process("trivial", txID, waceapi.HTTPPayload{URI: "/test"}, configstore.Everything, ch) status := <-ch if status.Err == nil { t.Error("Process with mismatched plugin type should return error via channel") @@ -469,8 +469,8 @@ func TestPluginManagerReloadChangesOutput(t *testing.T) { pm.InitTransaction(txID) defer pm.CloseTransaction(txID) - ch := make(chan pluginmanager.ModelStatus, 1) - go pm.Process("param", txID, pluginmanager.HTTPPayload{URI: "/test"}, configstore.Everything, ch) + ch := make(chan ModelStatus, 1) + go pm.Process("param", txID, waceapi.HTTPPayload{URI: "/test"}, configstore.Everything, ch) status := <-ch if status.Err != nil { t.Errorf("Process: unexpected error: %v", status.Err) @@ -512,6 +512,88 @@ func TestPluginManagerReloadChangesOutput(t *testing.T) { runProcess(0.8) } +// TestPluginManagerAddModelChannelAndClose exercises AddModelChannel (sync path) +// and the CloseTransaction sync-cleanup branch, which only runs when +// syncModelsChannels has an entry for the transaction. +func TestPluginManagerAddModelChannelAndClose(t *testing.T) { + config := []byte(baseConfig + "modelplugins:\n" + trivialPlugin) + pm := setupPluginManager(t, config) + + txID := generateRandomID() + pm.InitTransaction(txID) + + ch := make(chan ModelStatus) + pm.AddModelChannel(txID, configstore.Everything, ch, "sync") + + // CloseTransaction must close the registered channel and clean up maps. + pm.CloseTransaction(txID) + + // A closed channel returns immediately with ok=false. + select { + case _, ok := <-ch: + if ok { + t.Error("expected channel to be closed by CloseTransaction") + } + default: + t.Error("CloseTransaction should have closed the registered channel") + } +} + +// TestPluginManagerProcessAsyncPlugin verifies that Process sends an error when +// the configstore marks the plugin as async, even though it is present in the +// in-process plugin map. +func TestPluginManagerProcessAsyncPlugin(t *testing.T) { + // Load trivial as sync so it ends up in pm.modelPlugins. + config := []byte(baseConfig + "modelplugins:\n" + trivialPlugin) + pm := setupPluginManager(t, config) + + // Update configstore to mark the plugin as async without reloading pm. + asyncConf := baseConfig + `modelplugins: + - id: "trivial" + path: "../testdata/plugins/model/trivial.so" + weight: 1 + plugintype: "Everything" + mode: async +` + cs, err := configstore.Get() + if err != nil { + t.Fatalf("configstore.Get: %v", err) + } + var aux configstore.ConfigFileData + if err := yaml.Unmarshal([]byte(asyncConf), &aux); err != nil { + t.Fatalf("yaml.Unmarshal: %v", err) + } + if err := cs.SetConfig(aux); err != nil { + t.Fatalf("SetConfig: %v", err) + } + + txID := generateRandomID() + pm.InitTransaction(txID) + defer pm.CloseTransaction(txID) + + ch := make(chan ModelStatus, 1) + go pm.Process("trivial", txID, waceapi.HTTPPayload{URI: "/test"}, configstore.Everything, ch) + status := <-ch + if status.Err == nil { + t.Error("Process on async-configured plugin should return error via channel") + } +} + +// TestPluginManagerCheckResultWithoutTransaction verifies that CheckResult +// returns an error when InitTransaction was never called (results map absent). +func TestPluginManagerCheckResultWithoutTransaction(t *testing.T) { + config := []byte(baseConfig + "modelplugins:\n" + trivialPlugin + "decisionplugins:\n" + simplePlugin) + pm := setupPluginManager(t, config) + + // Deliberately skip pm.InitTransaction — no results entry exists. + txID := generateRandomID() + + _, err := pm.CheckResult(txID, "simple", make(map[string]string)) + if err == nil { + t.Error("CheckResult without InitTransaction should return error") + } +} + // TestPluginManagerProcessWithoutTransaction verifies that Process sends an // error when the transaction was never initialised (results map is absent). func TestPluginManagerProcessWithoutTransaction(t *testing.T) { @@ -521,8 +603,8 @@ func TestPluginManagerProcessWithoutTransaction(t *testing.T) { // Deliberately skip pm.InitTransaction so there is no results entry. txID := generateRandomID() - ch := make(chan pluginmanager.ModelStatus, 1) - go pm.Process("trivial", txID, pluginmanager.HTTPPayload{URI: "/test"}, configstore.Everything, ch) + ch := make(chan ModelStatus, 1) + go pm.Process("trivial", txID, waceapi.HTTPPayload{URI: "/test"}, configstore.Everything, ch) status := <-ch if status.Err == nil { t.Error("Process without InitTransaction should return error via channel") diff --git a/testdata/plugins/Makefile b/testdata/plugins/Makefile deleted file mode 100644 index 25ada92..0000000 --- a/testdata/plugins/Makefile +++ /dev/null @@ -1,20 +0,0 @@ -all: model/trivial.so model/trivial2.so \ - model/no_init.so model/wrong_init.so model/error_init.so \ - model/no_req.so model/wrong_req.so model/error_req.so \ - model/param.so \ - decision/test.so \ - decision/no_check.so decision/wrong_check.so decision/error_check.so \ - decision/simple.so - -# Pass COVER=1 to build plugins instrumented for coverage: -# make -B COVER=1 -# This is required when running: go test -cover ./... -ifdef COVER -FLAGS += -cover -endif - -%.so: %.go - go build $(FLAGS) -buildmode=plugin -o $@ $< - -clean: - rm -f model/*.so decision/*.so diff --git a/testdata/plugins/decision/simple.go b/testdata/plugins/decision/simple.go index 9143158..59b2a89 100644 --- a/testdata/plugins/decision/simple.go +++ b/testdata/plugins/decision/simple.go @@ -8,7 +8,7 @@ import ( "strconv" lg "github.com/tilsor/ModSecIntl_logging/logging" - pm "github.com/tilsor/ModSecIntl_wace_lib/pluginmanager" + "github.com/tilsor/ModSecIntl_wace_lib/waceapi" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" ) @@ -24,7 +24,7 @@ func InitPlugin(params map[string]string, meter metric.Meter) error { return nil } -func CheckResults(decisionInput pm.DecisionInput) (bool, error) { +func CheckResults(decisionInput waceapi.DecisionInput) (bool, error) { logger := lg.Get() var totalModelW float64 = 0 var modelDetectionCount int = 0 diff --git a/testdata/plugins/decision/test.go b/testdata/plugins/decision/test.go index 350b7cc..80e7aea 100644 --- a/testdata/plugins/decision/test.go +++ b/testdata/plugins/decision/test.go @@ -7,7 +7,7 @@ import ( "strconv" lg "github.com/tilsor/ModSecIntl_logging/logging" - pm "github.com/tilsor/ModSecIntl_wace_lib/pluginmanager" + "github.com/tilsor/ModSecIntl_wace_lib/waceapi" "go.opentelemetry.io/otel/metric" ) @@ -22,7 +22,7 @@ func InitPlugin(params map[string]string, meter metric.Meter) error { // CheckResults returns true (block traffic) if WAF says so, and false // in other case. // func CheckResults(transactionID string, modelRes map[string]float64, modelWeight map[string]float64, modelThres map[string]float64, wafData map[string]string) (bool, error) { -func CheckResults(decisionInput pm.DecisionInput) (bool, error) { +func CheckResults(decisionInput waceapi.DecisionInput) (bool, error) { logger := lg.Get() modelRes := decisionInput.Results diff --git a/testdata/plugins/decision/weighted_sum.go b/testdata/plugins/decision/weighted_sum.go index c987aa8..6e34961 100644 --- a/testdata/plugins/decision/weighted_sum.go +++ b/testdata/plugins/decision/weighted_sum.go @@ -8,7 +8,7 @@ import ( "strconv" lg "github.com/tilsor/ModSecIntl_logging/logging" - pm "github.com/tilsor/ModSecIntl_wace_lib/pluginmanager" + "github.com/tilsor/ModSecIntl_wace_lib/waceapi" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" ) @@ -47,7 +47,7 @@ func InitPlugin(params map[string]string, meter metric.Meter) error { return nil } -func CheckResults(decisionInput pm.DecisionInput) (bool, error) { +func CheckResults(decisionInput waceapi.DecisionInput) (bool, error) { var weightedSum float64 = 0 var weightsSum float64 = 0 for key, value := range decisionInput.Results { diff --git a/testdata/plugins/model/error_init.go b/testdata/plugins/model/error_init.go index d4645e4..397f1ae 100644 --- a/testdata/plugins/model/error_init.go +++ b/testdata/plugins/model/error_init.go @@ -6,7 +6,7 @@ package main import ( "errors" - pm "github.com/tilsor/ModSecIntl_wace_lib/pluginmanager" + "github.com/tilsor/ModSecIntl_wace_lib/waceapi" "go.opentelemetry.io/otel/metric" ) @@ -16,8 +16,8 @@ func InitPlugin(params map[string]string, meter metric.Meter) error { } // Process always returns 0 probability of attack -func Process(input pm.ModelInput) (pm.ModelResults, error) { - result := pm.ModelResults{ +func Process(input waceapi.ModelInput) (waceapi.ModelResults, error) { + result := waceapi.ModelResults{ ProbAttack: 0.0, Data: make(map[string]interface{}), } diff --git a/testdata/plugins/model/error_req.go b/testdata/plugins/model/error_req.go index ef77472..fb7afed 100644 --- a/testdata/plugins/model/error_req.go +++ b/testdata/plugins/model/error_req.go @@ -6,7 +6,7 @@ package main import ( "errors" - pm "github.com/tilsor/ModSecIntl_wace_lib/pluginmanager" + "github.com/tilsor/ModSecIntl_wace_lib/waceapi" "go.opentelemetry.io/otel/metric" ) @@ -16,8 +16,8 @@ func InitPlugin(params map[string]string, meter metric.Meter) error { } // Process always returns 0 probability of attack -func Process(input pm.ModelInput) (pm.ModelResults, error) { - result := pm.ModelResults{ +func Process(input waceapi.ModelInput) (waceapi.ModelResults, error) { + result := waceapi.ModelResults{ ProbAttack: 0.0, Data: make(map[string]interface{}), } diff --git a/testdata/plugins/model/no_init.go b/testdata/plugins/model/no_init.go index 9ca99dd..96a0495 100644 --- a/testdata/plugins/model/no_init.go +++ b/testdata/plugins/model/no_init.go @@ -4,13 +4,13 @@ package main import ( - pm "github.com/tilsor/ModSecIntl_wace_lib/pluginmanager" + "github.com/tilsor/ModSecIntl_wace_lib/waceapi" "go.opentelemetry.io/otel/metric" ) // Process always returns 0 probability of attack -func Process(input pm.ModelInput) (pm.ModelResults, error) { - result := pm.ModelResults{ +func Process(input waceapi.ModelInput) (waceapi.ModelResults, error) { + result := waceapi.ModelResults{ ProbAttack: 0.0, Data: make(map[string]interface{}), } diff --git a/testdata/plugins/model/param.go b/testdata/plugins/model/param.go index b8871fa..a9a7862 100644 --- a/testdata/plugins/model/param.go +++ b/testdata/plugins/model/param.go @@ -6,7 +6,7 @@ import ( "strconv" lg "github.com/tilsor/ModSecIntl_logging/logging" - pm "github.com/tilsor/ModSecIntl_wace_lib/pluginmanager" + "github.com/tilsor/ModSecIntl_wace_lib/waceapi" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" ) @@ -35,16 +35,16 @@ func InitPlugin(params map[string]string, meter metric.Meter) error { return nil } -func InitPluginAsync(params map[string]string, meter metric.Meter, natsManager func(func(pm.ModelInput) (pm.ModelResults, error))) error { +func InitPluginAsync(params map[string]string, meter metric.Meter, natsManager func(func(waceapi.ModelInput) (waceapi.ModelResults, error))) error { InitPlugin(params, meter) natsManager(Process) return nil } -func Process(input pm.ModelInput) (pm.ModelResults, error) { +func Process(input waceapi.ModelInput) (waceapi.ModelResults, error) { logger := lg.Get() logger.TPrintf(lg.WARN, input.TransactionId, "[param:Process] \"%v\"\n", input.Payload) - return pm.ModelResults{ + return waceapi.ModelResults{ ProbAttack: result, Data: make(map[string]interface{}), }, nil diff --git a/testdata/plugins/model/trivial.go b/testdata/plugins/model/trivial.go index 9e1315d..9e65bd0 100644 --- a/testdata/plugins/model/trivial.go +++ b/testdata/plugins/model/trivial.go @@ -7,7 +7,7 @@ import ( "context" lg "github.com/tilsor/ModSecIntl_logging/logging" - pm "github.com/tilsor/ModSecIntl_wace_lib/pluginmanager" + "github.com/tilsor/ModSecIntl_wace_lib/waceapi" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" ) @@ -26,16 +26,16 @@ func InitPlugin(params map[string]string, meter metric.Meter) error { return nil } -func InitPluginAsync(params map[string]string, meter metric.Meter, natsManager func(func(pm.ModelInput) (pm.ModelResults, error))) error { +func InitPluginAsync(params map[string]string, meter metric.Meter, natsManager func(func(waceapi.ModelInput) (waceapi.ModelResults, error))) error { InitPlugin(params, meter) natsManager(Process) return nil } -func Process(input pm.ModelInput) (pm.ModelResults, error) { +func Process(input waceapi.ModelInput) (waceapi.ModelResults, error) { logger := lg.Get() logger.TPrintf(lg.WARN, input.TransactionId, "[trivial:Process] \"%v\"\n", input.Payload) - result := pm.ModelResults{ + result := waceapi.ModelResults{ ProbAttack: 0.0, Data: make(map[string]interface{}), } diff --git a/testdata/plugins/model/trivial2.go b/testdata/plugins/model/trivial2.go index c60470b..3554d11 100644 --- a/testdata/plugins/model/trivial2.go +++ b/testdata/plugins/model/trivial2.go @@ -7,7 +7,7 @@ import ( "context" lg "github.com/tilsor/ModSecIntl_logging/logging" - pm "github.com/tilsor/ModSecIntl_wace_lib/pluginmanager" + "github.com/tilsor/ModSecIntl_wace_lib/waceapi" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" ) @@ -26,16 +26,16 @@ func InitPlugin(params map[string]string, meter metric.Meter) error { return nil } -func InitPluginAsync(params map[string]string, meter metric.Meter, natsManager func(func(pm.ModelInput) (pm.ModelResults, error))) error { +func InitPluginAsync(params map[string]string, meter metric.Meter, natsManager func(func(waceapi.ModelInput) (waceapi.ModelResults, error))) error { InitPlugin(params, meter) natsManager(Process) return nil } -func Process(input pm.ModelInput) (pm.ModelResults, error) { +func Process(input waceapi.ModelInput) (waceapi.ModelResults, error) { logger := lg.Get() logger.TPrintf(lg.WARN, input.TransactionId, "[trivial2:Proccess] \"%v\"\n", input.Payload) - result := pm.ModelResults{ + result := waceapi.ModelResults{ ProbAttack: 1.0, Data: make(map[string]interface{}), } diff --git a/testdata/plugins/model/trivial_async.go b/testdata/plugins/model/trivial_async.go index b473ca4..d0171a1 100644 --- a/testdata/plugins/model/trivial_async.go +++ b/testdata/plugins/model/trivial_async.go @@ -10,7 +10,7 @@ import ( "time" lg "github.com/tilsor/ModSecIntl_logging/logging" - pm "github.com/tilsor/ModSecIntl_wace_lib/pluginmanager" + "github.com/tilsor/ModSecIntl_wace_lib/waceapi" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" ) @@ -40,23 +40,24 @@ func InitPlugin(params map[string]string, meter metric.Meter) error { return nil } -func InitPluginAsync(params map[string]string, meter metric.Meter, natsManager func(func(pm.ModelInput) (pm.ModelResults, error))) error { +func InitPluginAsync(params map[string]string, meter metric.Meter, natsManager func(func(waceapi.ModelInput) (waceapi.ModelResults, error))) error { InitPlugin(params, meter) natsManager(Process) return nil } -func Process(input pm.ModelInput) (pm.ModelResults, error) { +func Process(input waceapi.ModelInput) (waceapi.ModelResults, error) { time.Sleep(time.Duration(sleepTime) * time.Second) logger := lg.Get() logger.TPrintf(lg.WARN, input.TransactionId, "[trivial_async:Process] \"%v\"\n", input.Payload) - result := pm.ModelResults{ + result := waceapi.ModelResults{ ProbAttack: 0.0, Data: make(map[string]interface{}), } return result, nil } -// ReloadPlugin reload the pluginfunc ReloadPlugin(params map[string]string, meter metric.Meter) error { +// ReloadPlugin reload the plugin +func ReloadPlugin(params map[string]string, meter metric.Meter) error { return nil } diff --git a/testdata/plugins/model/trivial_async2.go b/testdata/plugins/model/trivial_async2.go index 1d4eecc..e023023 100644 --- a/testdata/plugins/model/trivial_async2.go +++ b/testdata/plugins/model/trivial_async2.go @@ -10,7 +10,7 @@ import ( "time" lg "github.com/tilsor/ModSecIntl_logging/logging" - pm "github.com/tilsor/ModSecIntl_wace_lib/pluginmanager" + "github.com/tilsor/ModSecIntl_wace_lib/waceapi" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" ) @@ -40,17 +40,17 @@ func InitPlugin(params map[string]string, meter metric.Meter) error { return nil } -func InitPluginAsync(params map[string]string, meter metric.Meter, natsManager func(func(pm.ModelInput) (pm.ModelResults, error))) error { +func InitPluginAsync(params map[string]string, meter metric.Meter, natsManager func(func(waceapi.ModelInput) (waceapi.ModelResults, error))) error { InitPlugin(params, meter) natsManager(Process) return nil } -func Process(input pm.ModelInput) (pm.ModelResults, error) { +func Process(input waceapi.ModelInput) (waceapi.ModelResults, error) { time.Sleep(time.Duration(sleepTime) * time.Second) logger := lg.Get() logger.TPrintf(lg.WARN, input.TransactionId, "[trivial_async2:Process] \"%v\"\n", input.Payload) - result := pm.ModelResults{ + result := waceapi.ModelResults{ ProbAttack: 1.0, Data: make(map[string]interface{}), } diff --git a/waceapi/waceapi.go b/waceapi/waceapi.go new file mode 100644 index 0000000..28180b9 --- /dev/null +++ b/waceapi/waceapi.go @@ -0,0 +1,37 @@ +package waceapi + +type ModelResults struct { + ProbAttack float64 `json:"probattack"` + Data map[string]interface{} `json:"data"` +} + +type HTTPHeader struct { + Key string + Value string +} + +type HTTPPayload struct { + URI string + Method string + HTTPVersion string + RequestHeaders []HTTPHeader + RequestBody string + ResponseProtocol string + ResponseCode int + ResponseHeaders []HTTPHeader + ResponseBody string +} + +// ModelInput is the struct that contains the input data for the model plugin +type ModelInput struct { + TransactionId string `json:"transactionId"` + Payload HTTPPayload `json:"payload"` +} + +// DecisionInput is the struct that contains the input data for the decision plugin +type DecisionInput struct { + TransactionId string + Results map[string]ModelResults + ModelWeight map[string]float64 + WAFdata map[string]string +} diff --git a/wacecore.go b/wacecore.go index 88a217f..a41d6e4 100644 --- a/wacecore.go +++ b/wacecore.go @@ -10,6 +10,7 @@ import ( "time" "github.com/tilsor/ModSecIntl_wace_lib/configstore" + "github.com/tilsor/ModSecIntl_wace_lib/waceapi" "github.com/tilsor/ModSecIntl_wace_lib/pluginmanager" @@ -60,7 +61,7 @@ func addTransactionAnalysis(transactionID string) { // callPlugins calls the model plugins in the given list, with the given input. // It waits for all the synchronous model plugins to finish, and sends the // result to the client. The asynchronous model plugins are executed in parallel -func callPlugins(input pluginmanager.HTTPPayload, models []string, t configstore.ModelPluginType, transactionID string) error { +func callPlugins(input waceapi.HTTPPayload, models []string, t configstore.ModelPluginType, transactionID string) error { logger := logging.Get() // channel to receive the status of the execution of the analysis @@ -176,7 +177,7 @@ func InitTransaction(transactionId string) { } // Analyze calls the model plugins with the given payload and models -func Analyze(modelsTypeAsString, transactionId string, payload pluginmanager.HTTPPayload, models []string) error { +func Analyze(modelsTypeAsString, transactionId string, payload waceapi.HTTPPayload, models []string) error { if len(models) > 0 { logger := logging.Get() modelsType, err := configstore.StringToPluginType(modelsTypeAsString) diff --git a/wacecore_test.go b/wacecore_test.go index 9f9fd67..27a2f07 100644 --- a/wacecore_test.go +++ b/wacecore_test.go @@ -9,7 +9,7 @@ import ( "time" "github.com/tilsor/ModSecIntl_wace_lib/configstore" - "github.com/tilsor/ModSecIntl_wace_lib/pluginmanager" + "github.com/tilsor/ModSecIntl_wace_lib/waceapi" "go.opentelemetry.io/otel/sdk/metric" "gopkg.in/yaml.v3" @@ -21,7 +21,7 @@ var requestVersion = "HTTP/1.1" // var requestLine = "POST /cgi-bin/process.cgi HTTP/1.1\n" -var requestHeaders = []pluginmanager.HTTPHeader{ +var requestHeaders = []waceapi.HTTPHeader{ {Key: "User-Agent", Value: "Mozilla/4.0 (compatible; MSIE5.01; Windows NT)"}, {Key: "Host", Value: "www.tutorialspoint.com"}, {Key: "Content-Type", Value: "application/x-www-form-urlencoded"}, @@ -31,7 +31,7 @@ var requestHeaders = []pluginmanager.HTTPHeader{ {Key: "Connection", Value: "Keep-Alive"}, } -var requestHeadersPayload = pluginmanager.HTTPPayload{ +var requestHeadersPayload = waceapi.HTTPPayload{ URI: requestURI, Method: requestMethod, HTTPVersion: requestVersion, @@ -39,7 +39,7 @@ var requestHeadersPayload = pluginmanager.HTTPPayload{ } var requestBody = "licenseID=string&content=string&/paramsXML=string\n" -var wholeRequest = pluginmanager.HTTPPayload{ +var wholeRequest = waceapi.HTTPPayload{ URI: requestURI, Method: requestMethod, HTTPVersion: requestVersion, @@ -49,7 +49,7 @@ var wholeRequest = pluginmanager.HTTPPayload{ // var wholeRequest = requestLine + requestHeaders + "\n" + requestBody var responseCode = 200 var responseProto = "HTTP/1.1" -var responseHeaders = []pluginmanager.HTTPHeader{ +var responseHeaders = []waceapi.HTTPHeader{ {Key: "Date", Value: "Mon, 27 Jul 2009 12:28:53 GMT"}, {Key: "Server", Value: "Apache/2.2.14 (Win32)"}, {Key: "Last-Modified", Value: "Wed, 22 Jul 2009 19:15:56 GMT"}, @@ -58,7 +58,7 @@ var responseHeaders = []pluginmanager.HTTPHeader{ {Key: "Connection", Value: "Closed"}, } -var responseHeadersPayload = pluginmanager.HTTPPayload{ +var responseHeadersPayload = waceapi.HTTPPayload{ ResponseProtocol: responseProto, ResponseCode: responseCode, ResponseHeaders: responseHeaders, @@ -71,7 +71,7 @@ var responseBody = ` ` -var wholeResponse = pluginmanager.HTTPPayload{ +var wholeResponse = waceapi.HTTPPayload{ ResponseProtocol: responseProto, ResponseCode: responseCode, ResponseHeaders: responseHeaders, @@ -259,7 +259,7 @@ func generateRandomID() string { func TestAnalyze(t *testing.T) { type step struct { payloadType string - payload pluginmanager.HTTPPayload + payload waceapi.HTTPPayload plugins []string } tests := []struct { @@ -273,7 +273,7 @@ func TestAnalyze(t *testing.T) { config: configAllModels, steps: []step{ {"RequestHeaders", requestHeadersPayload, []string{"trivialRequestHeaders"}}, - {"RequestBody", pluginmanager.HTTPPayload{RequestBody: requestBody}, []string{"trivialRequestBody"}}, + {"RequestBody", waceapi.HTTPPayload{RequestBody: requestBody}, []string{"trivialRequestBody"}}, }, }, { @@ -288,7 +288,7 @@ func TestAnalyze(t *testing.T) { config: configAllModels, steps: []step{ {"ResponseHeaders", responseHeadersPayload, []string{"trivialResponseHeaders"}}, - {"ResponseBody", pluginmanager.HTTPPayload{ResponseBody: responseBody}, []string{"trivialResponseBody"}}, + {"ResponseBody", waceapi.HTTPPayload{ResponseBody: responseBody}, []string{"trivialResponseBody"}}, }, }, { @@ -560,13 +560,13 @@ func TestAnalyzeMultiPhase(t *testing.T) { phases := []struct { payloadType string - payload pluginmanager.HTTPPayload + payload waceapi.HTTPPayload models []string }{ {"RequestHeaders", requestHeadersPayload, []string{"trivialRequestHeaders"}}, - {"RequestBody", pluginmanager.HTTPPayload{RequestBody: requestBody}, []string{"trivialRequestBody"}}, + {"RequestBody", waceapi.HTTPPayload{RequestBody: requestBody}, []string{"trivialRequestBody"}}, {"ResponseHeaders", responseHeadersPayload, []string{"trivialResponseHeaders"}}, - {"ResponseBody", pluginmanager.HTTPPayload{ResponseBody: responseBody}, []string{"trivialResponseBody"}}, + {"ResponseBody", waceapi.HTTPPayload{ResponseBody: responseBody}, []string{"trivialResponseBody"}}, } for _, p := range phases { @@ -662,7 +662,7 @@ func TestReload(t *testing.T) { txID := generateRandomID() InitTransaction(txID) defer CloseTransaction(txID) - if err := Analyze("Everything", txID, pluginmanager.HTTPPayload{URI: "/test"}, []string{"param"}); err != nil { + if err := Analyze("Everything", txID, waceapi.HTTPPayload{URI: "/test"}, []string{"param"}); err != nil { t.Fatalf("Analyze after Reload: %v", err) } if _, err := CheckTransaction(txID, "simple", make(map[string]string)); err != nil { @@ -687,7 +687,7 @@ func BenchmarkTrivial(b *testing.B) { transactionId := strconv.Itoa(i) InitTransaction(transactionId) - Analyze("RequestHeaders", transactionId, pluginmanager.HTTPPayload{URI: "Request line and headers\n"}, []string{"trivial", "trivial2"}) + Analyze("RequestHeaders", transactionId, waceapi.HTTPPayload{URI: "Request line and headers\n"}, []string{"trivial", "trivial2"}) _, err := CheckTransaction(transactionId, "simple", wafParams) if err != nil { @@ -715,7 +715,7 @@ func BenchmarkTrivialFullNATS(b *testing.B) { transactionId := generateRandomID() InitTransaction(transactionId) - Analyze("RequestHeaders", transactionId, pluginmanager.HTTPPayload{URI: "Request line and headers\n"}, []string{"trivial", "trivial2"}) + Analyze("RequestHeaders", transactionId, waceapi.HTTPPayload{URI: "Request line and headers\n"}, []string{"trivial", "trivial2"}) _, err := CheckTransaction(transactionId, "simple", wafParams) if err != nil {