Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions .github/workflows/go.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,5 @@ jobs:
with:
go-version: '1.22.2'

- name: Make
run: |
cd testdata/plugins
make

- name: Test
run: go test -v ./...
run: go run mage.go test
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module github.com/tilsor/ModSecIntl_wace_lib
go 1.26.2

require (
github.com/magefile/mage v1.17.2
github.com/nats-io/nats.go v1.51.0
github.com/tilsor/ModSecIntl_logging v1.0.1
go.opentelemetry.io/otel v1.43.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/magefile/mage v1.17.2 h1:fyXVu1eadI8Ap1HCCNgEhJ5McIWiYhLR8uol64ZZc40=
github.com/magefile/mage v1.17.2/go.mod h1:Yj51kqllmsgFpvvSzgrZPK9WtluG3kUhFaBUVLo4feA=
github.com/nats-io/nats.go v1.51.0 h1:ByW84XTz6W03GSSsygsZcA+xgKK8vPGaa/FCAAEHnAI=
github.com/nats-io/nats.go v1.51.0/go.mod h1:26HypzazeOkyO3/mqd1zZd53STJN0EjCYF9Uy2ZOBno=
github.com/nats-io/nkeys v0.4.15 h1:JACV5jRVO9V856KOapQ7x+EY8Jo3qw1vJt/9Jpwzkk4=
Expand Down
16 changes: 16 additions & 0 deletions mage.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
//go:build ignore
// +build ignore

// Entrypoint to mage for running without needing to install the command.
// https://magefile.org/zeroinstall/
package main

import (
"os"

"github.com/magefile/mage/mage"
)

func main() {
os.Exit(mage.Main())
}
81 changes: 81 additions & 0 deletions magefile.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
//go:build mage

package main

import (
"fmt"
"os"
"path/filepath"
"strings"

"github.com/magefile/mage/mg"
"github.com/magefile/mage/sh"
)

const pluginDir = "testdata/plugins"

// Plugins builds all test plugins without coverage instrumentation.
func Plugins() error {
return buildPlugins()
}

// pluginsCover builds all test plugins with coverage instrumentation.
func pluginsCover() error {
return buildPlugins("-cover")
}

// buildPlugins compiles every .go file under testdata/plugins/model and
// testdata/plugins/decision into a .so plugin. Pass "-cover" to instrument
// for coverage.
func buildPlugins(extraFlags ...string) error {
for _, dir := range []string{"model", "decision"} {
sources, err := filepath.Glob(filepath.Join(pluginDir, dir, "*.go"))
if err != nil {
return err
}
for _, src := range sources {
out := strings.TrimSuffix(src, ".go") + ".so"
args := []string{"build", "-buildmode=plugin"}
args = append(args, extraFlags...)
args = append(args, "-o", out, src)
fmt.Printf("building %s\n", out)
if err := sh.RunV("go", args...); err != nil {
return err
}
}
}
return nil
}

// Test builds the plugins and runs the full test suite.
func Test() error {
mg.Deps(Plugins)
return sh.RunV("go", "test", "./...", "-v", "-count=1")
}

// TestCoverage builds coverage-instrumented plugins and runs the test suite
// with coverage reporting across all packages.
func TestCoverage() error {
mg.Deps(pluginsCover)
return sh.RunV("go", "test", "-cover", "./...", "-v", "-count=1", "-coverprofile=coverage.out")
}

// Clean removes all compiled plugin .so files.
func Clean() error {
for _, pattern := range []string{
filepath.Join(pluginDir, "model", "*.so"),
filepath.Join(pluginDir, "decision", "*.so"),
} {
matches, err := filepath.Glob(pattern)
if err != nil {
return err
}
for _, f := range matches {
fmt.Printf("removing %s\n", f)
if err := os.Remove(f); err != nil {
return err
}
}
}
return nil
}
80 changes: 22 additions & 58 deletions pluginmanager/pluginmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,68 +11,32 @@ import (
"sync"

"github.com/tilsor/ModSecIntl_wace_lib/configstore"
"github.com/tilsor/ModSecIntl_wace_lib/waceapi"
"go.opentelemetry.io/otel/metric"

"github.com/nats-io/nats.go"
"github.com/tilsor/ModSecIntl_logging/logging"
)

// ResultData maps the model plugin ID with the corresponding analysis result.
type ModelResults struct {
ProbAttack float64 `json:"probattack"`
Data map[string]interface{} `json:"data"`
}

type HTTPHeader struct {
Key string
Value string
}

type HTTPPayload struct {
URI string
Method string
HTTPVersion string
RequestHeaders []HTTPHeader
RequestBody string
ResponseProtocol string
ResponseCode int
ResponseHeaders []HTTPHeader
ResponseBody string
}

// ModelInput is the struct that contains the input data for the model plugin
type ModelInput struct {
TransactionId string `json:"transactionId"`
Payload HTTPPayload `json:"payload"`
}

// DecisionInput is the struct that contains the input data for the decision plugin
type DecisionInput struct {
TransactionId string
Results map[string]ModelResults
ModelWeight map[string]float64
WAFdata map[string]string
}

// ModelTransmitionResults is the struct that contains the results of the model plugin
type ModelTransmitionResults struct {
TransactionId string `json:"transactionId"`
ModelResults `json:",inline"`
Error error `json:"error"`
TransactionId string `json:"transactionId"`
waceapi.ModelResults `json:",inline"`
Error error `json:"error"`
}

// modelPlugin is the struct that stores the model plugin and its type
type modelPlugin struct {
p *plugin.Plugin
pluginType configstore.ModelPluginType
process func(ModelInput) (ModelResults, error)
process func(waceapi.ModelInput) (waceapi.ModelResults, error)
reload func(map[string]string, metric.Meter) error
}

// decisionPlugin is the struct that stores the decision plugin
type decisionPlugin struct {
p *plugin.Plugin
checkResults func(DecisionInput) (bool, error)
checkResults func(waceapi.DecisionInput) (bool, error)
reload func(map[string]string, metric.Meter) error
}

Expand Down Expand Up @@ -162,22 +126,22 @@ func (pm *PluginManager) loadModelPlugins(meter metric.Meter) error {
logger.Printf(logging.WARN, "| %s | cannot open plugin: %v", data.ID, err)
continue
}
var processFunc func(ModelInput) (ModelResults, error)
var processFunc func(waceapi.ModelInput) (waceapi.ModelResults, error)
// TODO: change mode to bool
if data.Mode == "async" || conf.ModelPlugins[data.ID].Remote {
f, err := p.Lookup(modelInitAsyncFunctionName)
if err != nil {
logger.Printf(logging.WARN, "| %s | cannot load plugin: %v", data.ID, err)
continue
}
initPlugin, ok := f.(func(map[string]string, metric.Meter, func(func(ModelInput) (ModelResults, error))) error)
initPlugin, ok := f.(func(map[string]string, metric.Meter, func(func(waceapi.ModelInput) (waceapi.ModelResults, error))) error)
if !ok {
logger.Printf(logging.WARN, "| %s | cannot load plugin: invalid %s function type", data.ID, modelInitAsyncFunctionName)
continue
}

// plugin initialization
err = initPlugin(data.Params, meter, func(modelProcess func(ModelInput) (ModelResults, error)) {
err = initPlugin(data.Params, meter, func(modelProcess func(waceapi.ModelInput) (waceapi.ModelResults, error)) {
ModelProcessHandler(data.ID, modelProcess)
})
if err != nil {
Expand All @@ -204,7 +168,7 @@ func (pm *PluginManager) loadModelPlugins(meter metric.Meter) error {
logger.Printf(logging.WARN, "| %s | cannot load plugin: cannot load %s function", data.ID, modelProcessFunctionName)
continue
}
processFunc, ok = procFunc.(func(ModelInput) (ModelResults, error))
processFunc, ok = procFunc.(func(waceapi.ModelInput) (waceapi.ModelResults, error))
if !ok {
logger.Printf(logging.WARN, "| %s | cannot load plugin: invalid %s function type", data.ID, modelProcessFunctionName)
continue
Expand Down Expand Up @@ -273,7 +237,7 @@ func (pm *PluginManager) loadDecisionPlugins(meter metric.Meter) error {
logger.Printf(logging.ERROR, "| %s | cannot load plugin %s function: %v", data.ID, decisionCheckFuncionName, err)
continue
}
checkResults, ok := checkFunc.(func(DecisionInput) (bool, error))
checkResults, ok := checkFunc.(func(waceapi.DecisionInput) (bool, error))
if !ok {
logger.Printf(logging.ERROR, "| %s | %s lookup failed for plugin: invalid function type", data.ID, decisionCheckFuncionName)
continue
Expand Down Expand Up @@ -379,8 +343,8 @@ func (p *PluginManager) RemoveAsyncModelChannel(transactionId string, t configst
}

// AddToQueue adds a payload to the model queue
func (p *PluginManager) AddToQueue(modelID, transactionID string, payload HTTPPayload) error {
payloadToSend := &ModelInput{
func (p *PluginManager) AddToQueue(modelID, transactionID string, payload waceapi.HTTPPayload) error {
payloadToSend := &waceapi.ModelInput{
TransactionId: transactionID,
Payload: payload,
}
Expand All @@ -395,7 +359,7 @@ func (p *PluginManager) AddToQueue(modelID, transactionID string, payload HTTPPa
}

// Process is in charge of calling the model plugin with id modelID
func (p *PluginManager) Process(modelID, transactionId string, payload HTTPPayload, t configstore.ModelPluginType, modelPlugStatus chan ModelStatus) error {
func (p *PluginManager) Process(modelID, transactionId string, payload waceapi.HTTPPayload, t configstore.ModelPluginType, modelPlugStatus chan ModelStatus) error {
conf, err := configstore.Get()
if err != nil {
return err
Expand Down Expand Up @@ -423,7 +387,7 @@ func (p *PluginManager) Process(modelID, transactionId string, payload HTTPPaylo
modelPlugStatus <- ModelStatus{ModelID: modelID, Err: fmt.Errorf("model plugin is async")}
return nil
} else {
res, err := mp.process(ModelInput{TransactionId: transactionId, Payload: payload})
res, err := mp.process(waceapi.ModelInput{TransactionId: transactionId, Payload: payload})

if err != nil {
modelPlugStatus <- ModelStatus{ModelID: modelID, Err: err}
Expand Down Expand Up @@ -461,15 +425,15 @@ func (p *PluginManager) CheckResult(transactionId, decisionId string, wafParams
return false, nil
}

modelResultMap := make(map[string]ModelResults)
modelResultMap := make(map[string]waceapi.ModelResults)
modelWeightMap := make(map[string]float64)
transactionResults.(*sync.Map).Range(func(key, value interface{}) bool {
modelResultMap[key.(string)] = value.(ModelResults)
modelResultMap[key.(string)] = value.(waceapi.ModelResults)
modelWeightMap[key.(string)] = cs.ModelPlugins[key.(string)].Weight
return true
})

res, err := dp.checkResults(DecisionInput{TransactionId: transactionId, Results: modelResultMap, ModelWeight: modelWeightMap, WAFdata: wafParams})
res, err := dp.checkResults(waceapi.DecisionInput{TransactionId: transactionId, Results: modelResultMap, ModelWeight: modelWeightMap, WAFdata: wafParams})
logger.TPrintf(logging.INFO, transactionId, "%s | transaction checked. Block: %t ", decisionId, res)

return res, err
Expand Down Expand Up @@ -514,7 +478,7 @@ func (p *PluginManager) ModelResultsHandler(modelId string) error {
modelChannel.(chan ModelStatus) <- ModelStatus{ModelID: modelId, Err: fmt.Errorf("transaction results not found")}
return
}
modelResult := ModelResults{ProbAttack: data.ProbAttack, Data: data.Data}
modelResult := waceapi.ModelResults{ProbAttack: data.ProbAttack, Data: data.Data}
resultSyncMap.(*sync.Map).Store(modelId, modelResult)
}
modelChannel.(chan ModelStatus) <- ModelStatus{ModelID: modelId, ProbAttack: data.ProbAttack, Err: nil}
Expand All @@ -540,7 +504,7 @@ func (p *PluginManager) ModelResultsHandler(modelId string) error {
}

// ModelProcessHandler listens for messages on the model queue
func ModelProcessHandler(modelId string, modelProcess func(ModelInput) (ModelResults, error)) error {
func ModelProcessHandler(modelId string, modelProcess func(waceapi.ModelInput) (waceapi.ModelResults, error)) error {
logger := logging.Get()
logger.Printf(logging.INFO, "Model: %s | Starting model process handler", modelId)
cs, err := configstore.Get()
Expand All @@ -557,13 +521,13 @@ func ModelProcessHandler(modelId string, modelProcess func(ModelInput) (ModelRes

_, err = nc.Subscribe(modelId, func(msg *nats.Msg) {
go func(msg nats.Msg) {
data := &ModelInput{}
data := &waceapi.ModelInput{}
err := json.Unmarshal(msg.Data, data)
if err != nil {
logger.Printf(logging.ERROR, "Model: %s | Failed to parse JSON payload", modelId)
} else {
res, err := modelProcess(*data)
modelResult := ModelResults{ProbAttack: res.ProbAttack, Data: res.Data}
modelResult := waceapi.ModelResults{ProbAttack: res.ProbAttack, Data: res.Data}
payloadToSend := &ModelTransmitionResults{
TransactionId: data.TransactionId,
ModelResults: modelResult,
Expand Down
Loading
Loading