diff --git a/configstore/configstore.go b/configstore/configstore.go index 5e88005..573a372 100644 --- a/configstore/configstore.go +++ b/configstore/configstore.go @@ -67,16 +67,23 @@ func StringToPluginType(textType string) (ModelPluginType, error) { return -1, fmt.Errorf("invalid plugin type %s", textType) } +type TrainingData struct { + MaxSamples int `yaml:"max_samples"` + ResultFilePath string `yaml:"result_file_path"` +} + // ModelPluginConfig stores the configuration of a model plugin type modelPluginConfig struct { - ID string - Path string - Weight float64 - Threshold float64 - Params map[string]string - PluginType ModelPluginType - Mode string - Remote bool + ID string + Path string + Weight float64 + Threshold float64 + Params map[string]string + PluginType ModelPluginType + async bool + remote bool + training bool + TrainingData TrainingData } // DecisionPluginConfig stores the configuration of a decision plugin @@ -121,14 +128,16 @@ func Clean() { } type configFileModelPlugin struct { - ID string - Path string - Weight float64 - Threshold float64 - Params map[string]string - PluginType string `yaml:"plugintype"` - Mode string - Remote bool + ID string + Path string + Weight float64 + Threshold float64 + Params map[string]string + PluginType string `yaml:"plugintype"` + Async bool + Remote bool + Training bool + TrainingData TrainingData `yaml:"training_data"` } type configFileDecisionPlugin struct { @@ -147,7 +156,17 @@ type ConfigFileData struct { // IsAsync returns true if the model plugin is async func (c *ConfigStore) IsAsync(modelID string) bool { - return c.ModelPlugins[modelID].Mode == "async" + return c.ModelPlugins[modelID].async +} + +// IsRemote returns true if the model plugin is remote +func (c *ConfigStore) IsRemote(modelID string) bool { + return c.ModelPlugins[modelID].remote +} + +// IsInTraining returns true if the model plugin is in training mode (collecting data) +func (c *ConfigStore) IsInTraining(modelID string) bool { + return c.ModelPlugins[modelID].training } // CheckLogging verifies if the log path is valid @@ -178,7 +197,6 @@ func checkConfig(inConf ConfigFileData) error { // check modelplugins for _, modelP := range inConf.Modelplugins { - if modelP.Path != "" { if _, err := os.Stat(modelP.Path); err != nil { return fmt.Errorf("%s plugin path %s: %v", modelP.ID, modelP.Path, err) @@ -189,7 +207,15 @@ func checkConfig(inConf ConfigFileData) error { if modelP.PluginType == "" { return fmt.Errorf("%s plugin type cannot be empty, please provide a valid type", modelP.ID) } - // fmt.Printf("modelP.Type: %s\n", modelP.Type) + if modelP.Training && modelP.Async { + return fmt.Errorf("model %s plugin cannot be in training mode and async mode at the same time", modelP.ID) + } + if modelP.Training && modelP.Remote { + return fmt.Errorf("model %s: remote training mode is not supported", modelP.ID) + } + if modelP.Training && modelP.TrainingData.MaxSamples == 0 { + return fmt.Errorf("model %s: max sample count should be greater than 0", modelP.ID) + } } // check decisionplugins for _, decisionP := range inConf.Decisionplugins { @@ -228,8 +254,10 @@ func (cs *ConfigStore) SetConfig(inConf ConfigFileData) error { modelConfig.Threshold = modelP.Threshold modelConfig.Params = modelP.Params modelConfig.PluginType, err = StringToPluginType(modelP.PluginType) - modelConfig.Mode = modelP.Mode - modelConfig.Remote = modelP.Remote + modelConfig.async = modelP.Async + modelConfig.remote = modelP.Remote + modelConfig.training = modelP.Training + modelConfig.TrainingData = modelP.TrainingData if err != nil { return err } @@ -245,11 +273,7 @@ func (cs *ConfigStore) SetConfig(inConf ConfigFileData) error { cs.DecisionPlugins[decisionConfig.ID] = decisionConfig } - if inConf.NatsURL != "" { - cs.NatsURL = inConf.NatsURL - } else { - cs.NatsURL = "localhost:4222" - } + cs.NatsURL = inConf.NatsURL return nil } diff --git a/configstore/configstore_test.go b/configstore/configstore_test.go index d27eced..01f18e4 100644 --- a/configstore/configstore_test.go +++ b/configstore/configstore_test.go @@ -485,12 +485,11 @@ func TestModelPluginTypeString(t *testing.T) { func TestIsAsync(t *testing.T) { tests := []struct { name string - mode string + async bool wantAsync bool }{ - {"sync mode", "sync", false}, - {"async mode", "async", true}, - {"empty mode defaults to sync", "", false}, + {"no async field defaults to sync", false, false}, + {"async: true is async", true, true}, } for _, tt := range tests { @@ -508,14 +507,14 @@ modelplugins: - id: "testplugin" path: "../testdata/plugins/model/trivial.so" plugintype: "RequestHeaders" - mode: "%s" -`, tt.mode) + async: %v +`, tt.async) if err := initialize([]byte(config)); err != nil { t.Fatalf("initialize failed: %v", err) } if got := cs.IsAsync("testplugin"); got != tt.wantAsync { - t.Errorf("IsAsync with mode %q = %v, want %v", tt.mode, got, tt.wantAsync) + t.Errorf("IsAsync with async=%v = %v, want %v", tt.async, got, tt.wantAsync) } }) } @@ -531,6 +530,150 @@ func TestGetBeforeNew(t *testing.T) { } } +func TestIsInTraining(t *testing.T) { + tests := []struct { + name string + training bool + wantTraining bool + }{ + {"no training field defaults to false", false, false}, + {"training: true enables training mode", true, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cs, err := New() + if err != nil { + t.Fatal(err) + } + defer Clean() + + trainingSection := "" + if tt.training { + trainingSection = "\n training_data:\n max_samples: 10" + } + config := fmt.Sprintf(`--- +loglevel: ERROR +logpath: /dev/null +modelplugins: + - id: "testplugin" + path: "../testdata/plugins/model/trivial.so" + plugintype: "RequestHeaders" + training: %v%s +`, tt.training, trainingSection) + if err := initialize([]byte(config)); err != nil { + t.Fatalf("initialize failed: %v", err) + } + + if got := cs.IsInTraining("testplugin"); got != tt.wantTraining { + t.Errorf("IsInTraining with training=%v = %v, want %v", tt.training, got, tt.wantTraining) + } + }) + } +} + +func TestTrainingDataConfig(t *testing.T) { + tests := []struct { + name string + config string + wantErr bool + wantMaxSamples int + wantPath string + }{ + { + name: "training with zero max_samples returns error", + config: `--- +loglevel: ERROR +logpath: /dev/null +modelplugins: + - id: "testplugin" + path: "../testdata/plugins/model/trivial.so" + plugintype: "RequestHeaders" + training: true +`, + wantErr: true, + }, + { + name: "training and async are mutually exclusive", + config: `--- +loglevel: ERROR +logpath: /dev/null +modelplugins: + - id: "testplugin" + path: "../testdata/plugins/model/trivial.so" + plugintype: "RequestHeaders" + training: true + async: true + training_data: + max_samples: 10 +`, + wantErr: true, + }, + { + name: "training and remote are mutually exclusive", + config: `--- +loglevel: ERROR +logpath: /dev/null +modelplugins: + - id: "testplugin" + path: "../testdata/plugins/model/trivial.so" + plugintype: "RequestHeaders" + training: true + remote: true + training_data: + max_samples: 10 +`, + wantErr: true, + }, + { + name: "valid training config stores TrainingData correctly", + config: `--- +loglevel: ERROR +logpath: /dev/null +modelplugins: + - id: "testplugin" + path: "../testdata/plugins/model/trivial.so" + plugintype: "RequestHeaders" + training: true + training_data: + max_samples: 42 + result_file_path: "/dev/null" +`, + wantErr: false, + wantMaxSamples: 42, + wantPath: "/dev/null", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cs, err := New() + if err != nil { + t.Fatal(err) + } + defer Clean() + + err = initialize([]byte(tt.config)) + if (err != nil) != tt.wantErr { + if tt.wantErr { + t.Errorf("expected error but got none") + } else { + t.Errorf("unexpected error: %v", err) + } + return + } + if !tt.wantErr { + got := cs.ModelPlugins["testplugin"].TrainingData + if got.MaxSamples != tt.wantMaxSamples { + t.Errorf("MaxSamples = %d, want %d", got.MaxSamples, tt.wantMaxSamples) + } + if got.ResultFilePath != tt.wantPath { + t.Errorf("ResultsFilePath = %q, want %q", got.ResultFilePath, tt.wantPath) + } + } + }) + } +} func TestNatsURL(t *testing.T) { tests := []struct { @@ -539,12 +682,12 @@ func TestNatsURL(t *testing.T) { wantURL string }{ { - name: "defaults to localhost:4222", + name: "empty string when natsurl not set", config: `--- loglevel: ERROR logpath: /dev/null `, - wantURL: "localhost:4222", + wantURL: "", }, { name: "stores custom URL", diff --git a/magefile.go b/magefiles/magefile.go similarity index 100% rename from magefile.go rename to magefiles/magefile.go diff --git a/pluginmanager/pluginmanager.go b/pluginmanager/pluginmanager.go index 34cd431..89463c1 100644 --- a/pluginmanager/pluginmanager.go +++ b/pluginmanager/pluginmanager.go @@ -5,6 +5,7 @@ decision plugins package pluginmanager import ( + "context" "encoding/json" "fmt" "plugin" @@ -27,10 +28,13 @@ type ModelTransmitionResults struct { // modelPlugin is the struct that stores the model plugin and its type type modelPlugin struct { - p *plugin.Plugin - pluginType configstore.ModelPluginType - process func(waceapi.ModelInput) (waceapi.ModelResults, error) - reload func(map[string]string, metric.Meter) error + p *plugin.Plugin + pluginType configstore.ModelPluginType + process func(waceapi.ModelInput) (waceapi.ModelResults, error) + reload func(map[string]string, metric.Meter) error + trainingChannel chan waceapi.ModelResults + trainingCtx context.Context + trainingCancel context.CancelFunc } // decisionPlugin is the struct that stores the decision plugin @@ -83,13 +87,16 @@ func New(meter metric.Meter) (*PluginManager, error) { logger := logging.Get() logger.Printf(logging.DEBUG, "Connecting to NATS server at %s", conf.NatsURL) - nc, err := nats.Connect(conf.NatsURL) + if conf.NatsURL != "" { + nc, err := nats.Connect(conf.NatsURL) - if err != nil { - logger.Printf(logging.ERROR, "Failed to connect to NATS server") - } + if err != nil { + logger.Printf(logging.ERROR, "Failed to connect to NATS server") + return nil, err + } - pm.natConn = nc + pm.natConn = nc + } pm.modelPlugins = make(map[string]modelPlugin) pm.loadModelPlugins(meter) @@ -118,6 +125,7 @@ func (pm *PluginManager) loadModelPlugins(meter metric.Meter) error { logger := logging.Get() // Load plugin models + // TODO: remove data from old plugins in a best effort approach for _, data := range conf.ModelPlugins { mp, found := pm.modelPlugins[data.ID] if !found { @@ -127,8 +135,7 @@ func (pm *PluginManager) loadModelPlugins(meter metric.Meter) error { continue } var processFunc func(waceapi.ModelInput) (waceapi.ModelResults, error) - // TODO: change mode to bool - if data.Mode == "async" || conf.ModelPlugins[data.ID].Remote { + if conf.IsAsync(data.ID) || conf.IsRemote(data.ID) { f, err := p.Lookup(modelInitAsyncFunctionName) if err != nil { logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err) @@ -184,7 +191,15 @@ func (pm *PluginManager) loadModelPlugins(meter metric.Meter) error { logger.Printf(logging.WARN, "| %s | cannot load plugin: invalid %s function type", data.ID, modelReloadFunction) continue } - modelPluginLoaded := modelPlugin{p, data.PluginType, processFunc, reload} + var trainingChannel chan waceapi.ModelResults + var trainingCtx context.Context + var trainingCancel context.CancelFunc + if conf.IsInTraining(data.ID) { + trainingChannel = make(chan waceapi.ModelResults) + trainingCtx, trainingCancel = context.WithCancel(context.Background()) + go pm.handleTrainingModel(data.ID, data.TrainingData, trainingCtx, trainingCancel, trainingChannel) + } + modelPluginLoaded := modelPlugin{p, data.PluginType, processFunc, reload, trainingChannel, trainingCtx, trainingCancel} pm.modelPlugins[data.ID] = modelPluginLoaded logger.Printf(logging.INFO, "| %s | plugin loaded", data.ID) } else { @@ -193,6 +208,9 @@ func (pm *PluginManager) loadModelPlugins(meter metric.Meter) error { logger.Printf(logging.WARN, "| %s | cannot reload plugin: %s", data.ID, err.Error()) continue } + if !conf.IsInTraining(data.ID) && mp.trainingCancel != nil { + mp.trainingCancel() + } } } return nil @@ -358,51 +376,45 @@ func (p *PluginManager) AddToQueue(modelID, transactionID string, payload waceap return p.natConn.Publish(modelID, jsonPayload) } -// Process is in charge of calling the model plugin with id modelID -func (p *PluginManager) Process(modelID, transactionId string, payload waceapi.HTTPPayload, t configstore.ModelPluginType, modelPlugStatus chan ModelStatus) error { +func (p *PluginManager) modelProcess(modelID string, mp modelPlugin, payload waceapi.ModelInput, t configstore.ModelPluginType) (waceapi.ModelResults, error) { + // check if the plugin is capable of analyzing the indicated part of the transaction + if mp.pluginType != t { + return waceapi.ModelResults{}, fmt.Errorf("plugin type %v cannot process a request with incompatible type %v", mp.pluginType, t) + } + conf, err := configstore.Get() if err != nil { - return err + return waceapi.ModelResults{}, err } + if conf.IsAsync(modelID) { + return waceapi.ModelResults{}, fmt.Errorf("model plugin is async") + } + return mp.process(payload) +} + +// Process is in charge of calling the model plugin with id modelID +func (p *PluginManager) Process(modelID, transactionID string, payload waceapi.HTTPPayload, t configstore.ModelPluginType, modelPlugStatus chan ModelStatus) { mp, exists := p.modelPlugins[modelID] if !exists { - modelPlugStatus <- ModelStatus{ModelID: modelID, Err: fmt.Errorf("model plugin not found")} - return nil + modelPlugStatus <- ModelStatus{ModelID: modelID, Err: fmt.Errorf("Model plugin %s not found", modelID)} + return } - // check if the plugin is capable of analyzing the indicated part of the transaction - if mp.pluginType != t { - modelPlugStatus <- ModelStatus{ModelID: modelID, - Err: fmt.Errorf("plugin type %v cannot process a request with incompatible type %v", mp.pluginType, t)} - return nil - } + res, err := p.modelProcess(modelID, mp, waceapi.ModelInput{TransactionId: transactionID, Payload: payload}, t) - mp, ok := p.modelPlugins[modelID] - if !ok { - return fmt.Errorf("Model plugin %s not found", modelID) + if err != nil { + modelPlugStatus <- ModelStatus{ModelID: modelID, Err: err} + return } - - if conf.ModelPlugins[modelID].Mode == "async" { - modelPlugStatus <- ModelStatus{ModelID: modelID, Err: fmt.Errorf("model plugin is async")} - return nil - } else { - res, err := mp.process(waceapi.ModelInput{TransactionId: transactionId, Payload: payload}) - - if err != nil { - modelPlugStatus <- ModelStatus{ModelID: modelID, Err: err} - return nil - } - // store the results - resultSyncMap, ok := p.results.Load(transactionId) - if !ok { - modelPlugStatus <- ModelStatus{ModelID: modelID, Err: fmt.Errorf("transaction results not found")} - return nil - } - resultSyncMap.(*sync.Map).Store(modelID, res) - modelPlugStatus <- ModelStatus{ModelID: modelID, ProbAttack: res.ProbAttack, Err: nil} + // store the results + resultSyncMap, ok := p.results.Load(transactionID) + if !ok { + modelPlugStatus <- ModelStatus{ModelID: modelID, Err: fmt.Errorf("transaction results not found")} + return } - return nil + resultSyncMap.(*sync.Map).Store(modelID, res) + modelPlugStatus <- ModelStatus{ModelID: modelID, ProbAttack: res.ProbAttack, Err: nil} } // CheckResult is in charge of calling the decision plugin with id decisionID over the @@ -440,48 +452,48 @@ func (p *PluginManager) CheckResult(transactionId, decisionId string, wafParams } // ModelResultsHandler listens for messages on the model results queue -func (p *PluginManager) ModelResultsHandler(modelId string) error { +func (p *PluginManager) ModelResultsHandler(modelID string) error { logger := logging.Get() cs, err := configstore.Get() if err != nil { return err } - sub, err := p.natConn.Subscribe(modelId+"/results", func(msg *nats.Msg) { + sub, err := p.natConn.Subscribe(modelID+"/results", func(msg *nats.Msg) { go func(msg nats.Msg) { data := &ModelTransmitionResults{} err := json.Unmarshal(msg.Data, data) if err != nil { - logger.Printf(logging.ERROR, "Model: %s | Failed to parse JSON payload", modelId) + logger.Printf(logging.ERROR, "Model: %s | Failed to parse JSON payload", modelID) } else { var channel interface{} var ok bool - if cs.ModelPlugins[modelId].Mode == "async" { + if cs.IsAsync(modelID) { channel, ok = p.asyncModelsChannels.Load(data.TransactionId) } else { channel, ok = p.syncModelsChannels.Load(data.TransactionId) } if !ok { - logger.TPrintf(logging.ERROR, data.TransactionId, " Model %s | Transaction not found", modelId) + logger.TPrintf(logging.ERROR, data.TransactionId, " Model %s | Transaction not found", modelID) } else { - modelChannel, ok := channel.(*sync.Map).Load(cs.ModelPlugins[modelId].PluginType.String()) + modelChannel, ok := channel.(*sync.Map).Load(cs.ModelPlugins[modelID].PluginType.String()) if !ok { - logger.Printf(logging.ERROR, "Model %s not found", modelId) + logger.Printf(logging.ERROR, "Model %s not found", modelID) } else { if data.Error != nil { - modelChannel.(chan ModelStatus) <- ModelStatus{ModelID: modelId, Err: data.Error} + modelChannel.(chan ModelStatus) <- ModelStatus{ModelID: modelID, Err: data.Error} } else { - if cs.ModelPlugins[modelId].Mode != "async" { + if !cs.IsAsync(modelID) { // store the results resultSyncMap, ok := p.results.Load(data.TransactionId) if !ok { - modelChannel.(chan ModelStatus) <- ModelStatus{ModelID: modelId, Err: fmt.Errorf("transaction results not found")} + modelChannel.(chan ModelStatus) <- ModelStatus{ModelID: modelID, Err: fmt.Errorf("transaction results not found")} return } modelResult := waceapi.ModelResults{ProbAttack: data.ProbAttack, Data: data.Data} - resultSyncMap.(*sync.Map).Store(modelId, modelResult) + resultSyncMap.(*sync.Map).Store(modelID, modelResult) } - modelChannel.(chan ModelStatus) <- ModelStatus{ModelID: modelId, ProbAttack: data.ProbAttack, Err: nil} + modelChannel.(chan ModelStatus) <- ModelStatus{ModelID: modelID, ProbAttack: data.ProbAttack, Err: nil} } } } @@ -490,11 +502,11 @@ func (p *PluginManager) ModelResultsHandler(modelId string) error { }) if err != nil { - logger.Printf(logging.ERROR, "Model: %s | Failed to subscribe to model queue | %s", modelId, err.Error()) + logger.Printf(logging.ERROR, "Model: %s | Failed to subscribe to model queue | %s", modelID, err.Error()) return err } - logger.Printf(logging.INFO, "Model: %s | Listening for messages on model results queue", modelId) + logger.Printf(logging.INFO, "Model: %s | Listening for messages on model results queue", modelID) defer sub.Unsubscribe() defer p.natConn.Drain() diff --git a/pluginmanager/pluginmanager_test.go b/pluginmanager/pluginmanager_test.go index ce1af1d..1f6e5e3 100644 --- a/pluginmanager/pluginmanager_test.go +++ b/pluginmanager/pluginmanager_test.go @@ -553,7 +553,7 @@ func TestPluginManagerProcessAsyncPlugin(t *testing.T) { path: "../testdata/plugins/model/trivial.so" weight: 1 plugintype: "Everything" - mode: async + async: true ` cs, err := configstore.Get() if err != nil { @@ -594,6 +594,190 @@ func TestPluginManagerCheckResultWithoutTransaction(t *testing.T) { } } +var trivialTrainingPlugin = ` - id: "trivial" + path: "../testdata/plugins/model/trivial.so" + weight: 1 + plugintype: "Everything" + training: true + training_data: + max_samples: 3 + result_file_path: "/dev/null" +` + +// TestPluginManagerTrainingPluginLoaded verifies that a training plugin is +// loaded with its channel, context, and cancel function all initialised and +// that the context is live immediately after load. +func TestPluginManagerTrainingPluginLoaded(t *testing.T) { + config := []byte(baseConfig + "modelplugins:\n" + trivialTrainingPlugin) + pm := setupPluginManager(t, config) + + mp, ok := pm.modelPlugins["trivial"] + if !ok { + t.Fatal("training plugin was not loaded into modelPlugins") + } + if mp.trainingChannel == nil { + t.Error("trainingChannel should not be nil for a training plugin") + } + if mp.trainingCtx == nil { + t.Error("trainingCtx should not be nil for a training plugin") + } + if mp.trainingCancel == nil { + t.Error("trainingCancel should not be nil for a training plugin") + } + select { + case <-mp.trainingCtx.Done(): + t.Error("trainingCtx should not be done immediately after load") + default: + } +} + +// TestPluginManagerProcessTrainingExhaustsMaxSamples sends maxSamples results +// through ProcessTraining and verifies that the goroutine exits (calling +// defer cancel()) once it has consumed all of them. +func TestPluginManagerProcessTrainingExhaustsMaxSamples(t *testing.T) { + config := []byte(baseConfig + "modelplugins:\n" + trivialTrainingPlugin) + pm := setupPluginManager(t, config) + mp := pm.modelPlugins["trivial"] + + txID := generateRandomID() + + for i := 0; i < 3; i++ { + go pm.ProcessTraining("trivial", txID, waceapi.HTTPPayload{URI: "/test"}, configstore.Everything) + } + + select { + case <-mp.trainingCtx.Done(): + // goroutine exited after maxSamples and called defer cancel() + case <-time.After(2 * time.Second): + t.Error("training goroutine did not exit after exhausting maxSamples") + } +} + +// TestPluginManagerProcessTrainingNonexistent verifies that ProcessTraining +// returns immediately when the model ID does not exist, without blocking. +func TestPluginManagerProcessTrainingNonexistent(t *testing.T) { + config := []byte(baseConfig + "modelplugins:\n" + trivialTrainingPlugin) + pm := setupPluginManager(t, config) + + done := make(chan struct{}) + go func() { + pm.ProcessTraining("nonexistent", generateRandomID(), waceapi.HTTPPayload{URI: "/test"}, configstore.Everything) + close(done) + }() + select { + case <-done: + case <-time.After(time.Second): + t.Error("ProcessTraining with nonexistent plugin should return immediately") + } +} + +// TestPluginManagerProcessTrainingAfterCancel verifies that ProcessTraining +// does not block when the training context has already been cancelled. +func TestPluginManagerProcessTrainingAfterCancel(t *testing.T) { + config := []byte(baseConfig + "modelplugins:\n" + trivialTrainingPlugin) + pm := setupPluginManager(t, config) + mp := pm.modelPlugins["trivial"] + + mp.trainingCancel() + + done := make(chan struct{}) + go func() { + pm.ProcessTraining("trivial", generateRandomID(), waceapi.HTTPPayload{URI: "/test"}, configstore.Everything) + close(done) + }() + select { + case <-done: + case <-time.After(time.Second): + t.Error("ProcessTraining should not block after context cancellation") + } +} + +// TestPluginManagerTrainingReloadDisablesTraining verifies that Reload cancels +// the training goroutine when the new config has training disabled for the plugin. +func TestPluginManagerTrainingReloadDisablesTraining(t *testing.T) { + config := []byte(baseConfig + "modelplugins:\n" + trivialTrainingPlugin) + pm := setupPluginManager(t, config) + mp := pm.modelPlugins["trivial"] + + select { + case <-mp.trainingCtx.Done(): + t.Fatal("trainingCtx should be alive before Reload") + default: + } + + disabledConfig := baseConfig + `modelplugins: + - id: "trivial" + path: "../testdata/plugins/model/trivial.so" + weight: 1 + plugintype: "Everything" +` + cs, err := configstore.Get() + if err != nil { + t.Fatalf("configstore.Get: %v", err) + } + var aux configstore.ConfigFileData + if err := yaml.Unmarshal([]byte(disabledConfig), &aux); err != nil { + t.Fatalf("yaml.Unmarshal: %v", err) + } + if err := cs.SetConfig(aux); err != nil { + t.Fatalf("SetConfig: %v", err) + } + if err := pm.Reload(testMeter); err != nil { + t.Fatalf("Reload: %v", err) + } + + select { + case <-mp.trainingCtx.Done(): + // goroutine was cancelled by Reload + case <-time.After(time.Second): + t.Error("trainingCtx should be cancelled after Reload with training disabled") + } +} + +// TestPluginManagerTrainingResultNotUsedInDecision verifies that a result +// produced by ProcessTraining is stored in the training channel only and never +// reaches p.results, so CheckResult cannot use it to influence a decision. +// +// trivial2 always returns ProbAttack=1.0. If its training result leaked into +// p.results, CheckResult would block the transaction; it must not. +func TestPluginManagerTrainingResultNotUsedInDecision(t *testing.T) { + conf := baseConfig + `modelplugins: + - id: "trivial2" + path: "../testdata/plugins/model/trivial2.so" + weight: 1 + plugintype: "Everything" + training: true + training_data: + max_samples: 5 + result_file_path: "/dev/null" +decisionplugins: +` + simplePlugin + pm := setupPluginManager(t, []byte(conf)) + + txID := generateRandomID() + pm.InitTransaction(txID) + defer pm.CloseTransaction(txID) + + done := make(chan struct{}) + go func() { + pm.ProcessTraining("trivial2", txID, waceapi.HTTPPayload{URI: "/test"}, configstore.Everything) + close(done) + }() + <-done // wait until the result has been handed off to the training goroutine + + // WAF params that would cause a block if trivial2's prob=1.0 reached the decision. + result, err := pm.CheckResult(txID, "simple", map[string]string{ + "inbound_blocking": "20", + "inbound_threshold": "5", + }) + if err != nil { + t.Fatalf("CheckResult error: %v", err) + } + if result { + t.Error("CheckResult blocked — training model result must not feed into the decision plugin") + } +} + // TestPluginManagerProcessWithoutTransaction verifies that Process sends an // error when the transaction was never initialised (results map is absent). func TestPluginManagerProcessWithoutTransaction(t *testing.T) { diff --git a/pluginmanager/training_models.go b/pluginmanager/training_models.go new file mode 100644 index 0000000..ea0f9a3 --- /dev/null +++ b/pluginmanager/training_models.go @@ -0,0 +1,83 @@ +package pluginmanager + +import ( + "bufio" + "context" + "encoding/json" + "os" + + "github.com/tilsor/ModSecIntl_logging/logging" + "github.com/tilsor/ModSecIntl_wace_lib/configstore" + "github.com/tilsor/ModSecIntl_wace_lib/waceapi" +) + +func (p *PluginManager) handleTrainingModel(modelID string, td configstore.TrainingData, ctx context.Context, cancel context.CancelFunc, tc chan waceapi.ModelResults) { + defer cancel() + logger := logging.Get() + logger.Printf(logging.INFO, "Model %s | Handling training data\n", modelID) + + f, err := os.OpenFile(td.ResultFilePath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0644) + if err != nil { + logger.Printf(logging.ERROR, "Model %s | Error handling data: %s", modelID, err.Error()) + return + } + defer f.Close() + + lineCount := 0 + scanner := bufio.NewScanner(f) + for scanner.Scan() { + lineCount++ + } + if err := scanner.Err(); err != nil { + logger.Printf(logging.ERROR, "Model %s | Error handling data: %s", modelID, err.Error()) + return + } + + encoder := json.NewEncoder(f) + + if lineCount >= td.MaxSamples { + logger.Printf(logging.INFO, "Model %s | The maximum number of samples has already been written.\n", modelID) + return + } else { + logger.Printf(logging.INFO, "Model %s | Previously amount of samples written %d\n", modelID, lineCount) + } + + for i := 0; i < td.MaxSamples-lineCount; i++ { + select { + case data := <-tc: + logger.Printf(logging.DEBUG, "Model %s | Recieved data %v", modelID, data) + if err := encoder.Encode(data); err != nil { + logger.Printf(logging.ERROR, "Model %s | Error writing data: %s", modelID, err.Error()) + return + } + case <-ctx.Done(): + logger.Printf(logging.DEBUG, "Model %s | Training cancelled\n", modelID) + return + } + } + + logger.Printf(logging.INFO, "Model %s | Data collection for training completed.\n", modelID) +} + +// ProcessTraining is in charge of calling the model plugin with id modelID +func (p *PluginManager) ProcessTraining(modelID, transactionID string, payload waceapi.HTTPPayload, t configstore.ModelPluginType) { + logger := logging.Get() + + mp, exists := p.modelPlugins[modelID] + if !exists { + logger.TPrintf(logging.ERROR, transactionID, "Model %s not found", modelID) + return + } + + res, err := p.modelProcess(modelID, mp, waceapi.ModelInput{TransactionId: transactionID, Payload: payload, TrainingMode: true}, t) + if err != nil { + logger.TPrintf(logging.ERROR, transactionID, "Error processing model %s: %s", modelID, err.Error()) + return + } + + select { + case mp.trainingChannel <- res: + case <-mp.trainingCtx.Done(): + logger.TPrintf(logging.DEBUG, transactionID, "training cancelled for model %s, dropping result", modelID) + } +} diff --git a/pluginmanager/training_models_test.go b/pluginmanager/training_models_test.go new file mode 100644 index 0000000..c3848b5 --- /dev/null +++ b/pluginmanager/training_models_test.go @@ -0,0 +1,181 @@ +package pluginmanager + +import ( + "bufio" + "context" + "encoding/json" + "os" + "testing" + "time" + + "github.com/tilsor/ModSecIntl_wace_lib/configstore" + "github.com/tilsor/ModSecIntl_wace_lib/waceapi" +) + +func countFileLines(t *testing.T, path string) int { + t.Helper() + f, err := os.Open(path) + if err != nil { + t.Fatalf("countFileLines: open %s: %v", path, err) + } + defer f.Close() + sc := bufio.NewScanner(f) + n := 0 + for sc.Scan() { + n++ + } + return n +} + +func prefillResultsFile(t *testing.T, path string, count int) { + t.Helper() + f, err := os.Create(path) + if err != nil { + t.Fatalf("prefillResultsFile: %v", err) + } + defer f.Close() + enc := json.NewEncoder(f) + for i := 0; i < count; i++ { + if err := enc.Encode(waceapi.ModelResults{ProbAttack: float64(i) * 0.1}); err != nil { + t.Fatalf("prefillResultsFile encode: %v", err) + } + } +} + +// TestHandleTrainingModelWritesData verifies that handleTrainingModel writes +// each received result as a JSON line and exits after consuming MaxSamples. +func TestHandleTrainingModelWritesData(t *testing.T) { + path := t.TempDir() + "/results.ndjson" + td := configstore.TrainingData{MaxSamples: 3, ResultFilePath: path} + ctx, cancel := context.WithCancel(context.Background()) + tc := make(chan waceapi.ModelResults) + go (&PluginManager{}).handleTrainingModel("test", td, ctx, cancel, tc) + + for i := 0; i < 3; i++ { + tc <- waceapi.ModelResults{ProbAttack: float64(i) * 0.5} + } + select { + case <-ctx.Done(): + case <-time.After(2 * time.Second): + t.Fatal("goroutine did not exit after consuming MaxSamples") + } + + if n := countFileLines(t, path); n != 3 { + t.Errorf("file has %d lines, want 3", n) + } + + f, err := os.Open(path) + if err != nil { + t.Fatalf("open results: %v", err) + } + defer f.Close() + dec := json.NewDecoder(f) + for i := 0; i < 3; i++ { + var r waceapi.ModelResults + if err := dec.Decode(&r); err != nil { + t.Fatalf("decode record %d: %v", i, err) + } + if want := float64(i) * 0.5; r.ProbAttack != want { + t.Errorf("record %d: ProbAttack=%f, want %f", i, r.ProbAttack, want) + } + } +} + +// TestHandleTrainingModelAlreadyFull verifies that handleTrainingModel exits +// immediately without reading from the channel when the file already contains +// MaxSamples lines. +func TestHandleTrainingModelAlreadyFull(t *testing.T) { + path := t.TempDir() + "/results.ndjson" + prefillResultsFile(t, path, 5) + + td := configstore.TrainingData{MaxSamples: 5, ResultFilePath: path} + ctx, cancel := context.WithCancel(context.Background()) + tc := make(chan waceapi.ModelResults) + go (&PluginManager{}).handleTrainingModel("test", td, ctx, cancel, tc) + + select { + case <-ctx.Done(): + case <-time.After(time.Second): + t.Fatal("goroutine did not exit immediately when file was already full") + } + if n := countFileLines(t, path); n != 5 { + t.Errorf("file has %d lines after full-file run, want 5 (no new writes)", n) + } +} + +// TestHandleTrainingModelResumesFromExisting verifies that handleTrainingModel +// counts existing lines and reads only the remaining samples needed, reaching +// MaxSamples total without re-writing previously collected data. +func TestHandleTrainingModelResumesFromExisting(t *testing.T) { + path := t.TempDir() + "/results.ndjson" + prefillResultsFile(t, path, 2) + + td := configstore.TrainingData{MaxSamples: 5, ResultFilePath: path} + ctx, cancel := context.WithCancel(context.Background()) + tc := make(chan waceapi.ModelResults) + go (&PluginManager{}).handleTrainingModel("test", td, ctx, cancel, tc) + + for i := 0; i < 3; i++ { + tc <- waceapi.ModelResults{ProbAttack: float64(i)} + } + select { + case <-ctx.Done(): + case <-time.After(2 * time.Second): + t.Fatal("goroutine did not exit after consuming remaining samples") + } + if n := countFileLines(t, path); n != 5 { + t.Errorf("file has %d lines, want 5 (2 existing + 3 new)", n) + } +} + +// TestHandleTrainingModelCancellation verifies that handleTrainingModel stops +// reading from the channel and exits when the context is cancelled before +// MaxSamples are consumed. +func TestHandleTrainingModelCancellation(t *testing.T) { + path := t.TempDir() + "/results.ndjson" + td := configstore.TrainingData{MaxSamples: 10, ResultFilePath: path} + ctx, cancel := context.WithCancel(context.Background()) + tc := make(chan waceapi.ModelResults) + done := make(chan struct{}) + go func() { + (&PluginManager{}).handleTrainingModel("test", td, ctx, cancel, tc) + close(done) + }() + + tc <- waceapi.ModelResults{ProbAttack: 0.1} + tc <- waceapi.ModelResults{ProbAttack: 0.2} + cancel() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("goroutine did not exit after context cancellation") + } + if n := countFileLines(t, path); n != 2 { + t.Errorf("file has %d lines, want 2 (written before cancellation)", n) + } +} + +// TestHandleTrainingModelFileOpenError verifies that handleTrainingModel +// returns gracefully (without panicking) when the result file cannot be opened. +func TestHandleTrainingModelFileOpenError(t *testing.T) { + td := configstore.TrainingData{MaxSamples: 3, ResultFilePath: "/nonexistent/dir/results.ndjson"} + ctx, cancel := context.WithCancel(context.Background()) + tc := make(chan waceapi.ModelResults) + done := make(chan struct{}) + go func() { + (&PluginManager{}).handleTrainingModel("test", td, ctx, cancel, tc) + close(done) + }() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("handleTrainingModel did not return on file open error") + } + select { + case <-ctx.Done(): + default: + t.Error("cancel should have been called when returning on file error") + } +} diff --git a/testdata/plugins/model/trivial.go b/testdata/plugins/model/trivial.go index 9e65bd0..6c784cd 100644 --- a/testdata/plugins/model/trivial.go +++ b/testdata/plugins/model/trivial.go @@ -37,7 +37,7 @@ func Process(input waceapi.ModelInput) (waceapi.ModelResults, error) { logger.TPrintf(lg.WARN, input.TransactionId, "[trivial:Process] \"%v\"\n", input.Payload) result := waceapi.ModelResults{ ProbAttack: 0.0, - Data: make(map[string]interface{}), + Data: input, } return result, nil } diff --git a/waceapi/waceapi.go b/waceapi/waceapi.go index 28180b9..1099cc5 100644 --- a/waceapi/waceapi.go +++ b/waceapi/waceapi.go @@ -1,8 +1,8 @@ package waceapi type ModelResults struct { - ProbAttack float64 `json:"probattack"` - Data map[string]interface{} `json:"data"` + ProbAttack float64 `json:"probattack"` + Data any `json:"data"` } type HTTPHeader struct { @@ -26,6 +26,7 @@ type HTTPPayload struct { type ModelInput struct { TransactionId string `json:"transactionId"` Payload HTTPPayload `json:"payload"` + TrainingMode bool `json:"trainingMode"` } // DecisionInput is the struct that contains the input data for the decision plugin diff --git a/wacecore.go b/wacecore.go index a41d6e4..310b9c6 100644 --- a/wacecore.go +++ b/wacecore.go @@ -86,23 +86,22 @@ func callPlugins(input waceapi.HTTPPayload, models []string, t configstore.Model logger.TPrintf(logging.DEBUG, transactionID, "%s | calling from core", id) if _, ok := conf.ModelPlugins[id]; !ok { logger.TPrintf(logging.ERROR, transactionID, "core | model plugin %s not found", id) + } else if conf.ModelPlugins[id].PluginType != t { + logger.TPrintf(logging.ERROR, transactionID, "core | model plugin %s is not of type %s", id, t) + } else if conf.IsAsync(id) { + asyncCounter++ + go plugins.AddToQueue(id, transactionID, input) + } else if conf.IsInTraining(id) { + go plugins.ProcessTraining(id, transactionID, input, t) } else { - if conf.ModelPlugins[id].PluginType != t { - logger.TPrintf(logging.ERROR, transactionID, "core | model plugin %s is not of type %s", id, t) + if conf.IsRemote(id) { + go plugins.AddToQueue(id, transactionID, input) } else { - if conf.IsAsync(id) { - asyncCounter++ - go plugins.AddToQueue(id, transactionID, input) - } else { - if conf.ModelPlugins[id].Remote { - go plugins.AddToQueue(id, transactionID, input) - } else { - go plugins.Process(id, transactionID, input, t, modelPluginStatus) - } - syncCounter++ - } + go plugins.Process(id, transactionID, input, t, modelPluginStatus) } + syncCounter++ } + } go func() { diff --git a/wacecore_test.go b/wacecore_test.go index 27a2f07..7727631 100644 --- a/wacecore_test.go +++ b/wacecore_test.go @@ -216,12 +216,12 @@ modelplugins: plugintype: RequestHeaders path: "testdata/plugins/model/trivial.so" weight: 1 - mode: async + async: true - id: "trivial2" plugintype: RequestHeaders path: "testdata/plugins/model/trivial2.so" weight: 2 - mode: async + async: true #The decision plugin configuration decisionplugins: - id: "simple"