Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 50 additions & 26 deletions configstore/configstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,23 @@ func StringToPluginType(textType string) (ModelPluginType, error) {
return -1, fmt.Errorf("invalid plugin type %s", textType)
}

type TrainingData struct {
MaxSamples int `yaml:"max_samples"`
ResultFilePath string `yaml:"result_file_path"`
}

// ModelPluginConfig stores the configuration of a model plugin
type modelPluginConfig struct {
ID string
Path string
Weight float64
Threshold float64
Params map[string]string
PluginType ModelPluginType
Mode string
Remote bool
ID string
Path string
Weight float64
Threshold float64
Params map[string]string
PluginType ModelPluginType
async bool
remote bool
training bool
TrainingData TrainingData
}

// DecisionPluginConfig stores the configuration of a decision plugin
Expand Down Expand Up @@ -121,14 +128,16 @@ func Clean() {
}

type configFileModelPlugin struct {
ID string
Path string
Weight float64
Threshold float64
Params map[string]string
PluginType string `yaml:"plugintype"`
Mode string
Remote bool
ID string
Path string
Weight float64
Threshold float64
Params map[string]string
PluginType string `yaml:"plugintype"`
Async bool
Remote bool
Training bool
TrainingData TrainingData `yaml:"training_data"`
}

type configFileDecisionPlugin struct {
Expand All @@ -147,7 +156,17 @@ type ConfigFileData struct {

// IsAsync returns true if the model plugin is async
func (c *ConfigStore) IsAsync(modelID string) bool {
return c.ModelPlugins[modelID].Mode == "async"
return c.ModelPlugins[modelID].async
}

// IsRemote returns true if the model plugin is remote
func (c *ConfigStore) IsRemote(modelID string) bool {
return c.ModelPlugins[modelID].remote
}

// IsInTraining returns true if the model plugin is in training mode (collecting data)
func (c *ConfigStore) IsInTraining(modelID string) bool {
return c.ModelPlugins[modelID].training
}

// CheckLogging verifies if the log path is valid
Expand Down Expand Up @@ -178,7 +197,6 @@ func checkConfig(inConf ConfigFileData) error {

// check modelplugins
for _, modelP := range inConf.Modelplugins {

if modelP.Path != "" {
if _, err := os.Stat(modelP.Path); err != nil {
return fmt.Errorf("%s plugin path %s: %v", modelP.ID, modelP.Path, err)
Expand All @@ -189,7 +207,15 @@ func checkConfig(inConf ConfigFileData) error {
if modelP.PluginType == "" {
return fmt.Errorf("%s plugin type cannot be empty, please provide a valid type", modelP.ID)
}
// fmt.Printf("modelP.Type: %s\n", modelP.Type)
if modelP.Training && modelP.Async {
return fmt.Errorf("model %s plugin cannot be in training mode and async mode at the same time", modelP.ID)
}
if modelP.Training && modelP.Remote {
return fmt.Errorf("model %s: remote training mode is not supported", modelP.ID)
}
if modelP.Training && modelP.TrainingData.MaxSamples == 0 {
return fmt.Errorf("model %s: max sample count should be greater than 0", modelP.ID)
}
}
// check decisionplugins
for _, decisionP := range inConf.Decisionplugins {
Expand Down Expand Up @@ -228,8 +254,10 @@ func (cs *ConfigStore) SetConfig(inConf ConfigFileData) error {
modelConfig.Threshold = modelP.Threshold
modelConfig.Params = modelP.Params
modelConfig.PluginType, err = StringToPluginType(modelP.PluginType)
modelConfig.Mode = modelP.Mode
modelConfig.Remote = modelP.Remote
modelConfig.async = modelP.Async
modelConfig.remote = modelP.Remote
modelConfig.training = modelP.Training
modelConfig.TrainingData = modelP.TrainingData
if err != nil {
return err
}
Expand All @@ -245,11 +273,7 @@ func (cs *ConfigStore) SetConfig(inConf ConfigFileData) error {
cs.DecisionPlugins[decisionConfig.ID] = decisionConfig
}

if inConf.NatsURL != "" {
cs.NatsURL = inConf.NatsURL
} else {
cs.NatsURL = "localhost:4222"
}
cs.NatsURL = inConf.NatsURL

return nil
}
161 changes: 152 additions & 9 deletions configstore/configstore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -485,12 +485,11 @@ func TestModelPluginTypeString(t *testing.T) {
func TestIsAsync(t *testing.T) {
tests := []struct {
name string
mode string
async bool
wantAsync bool
}{
{"sync mode", "sync", false},
{"async mode", "async", true},
{"empty mode defaults to sync", "", false},
{"no async field defaults to sync", false, false},
{"async: true is async", true, true},
}

for _, tt := range tests {
Expand All @@ -508,14 +507,14 @@ modelplugins:
- id: "testplugin"
path: "../testdata/plugins/model/trivial.so"
plugintype: "RequestHeaders"
mode: "%s"
`, tt.mode)
async: %v
`, tt.async)
if err := initialize([]byte(config)); err != nil {
t.Fatalf("initialize failed: %v", err)
}

if got := cs.IsAsync("testplugin"); got != tt.wantAsync {
t.Errorf("IsAsync with mode %q = %v, want %v", tt.mode, got, tt.wantAsync)
t.Errorf("IsAsync with async=%v = %v, want %v", tt.async, got, tt.wantAsync)
}
})
}
Expand All @@ -531,6 +530,150 @@ func TestGetBeforeNew(t *testing.T) {
}
}

func TestIsInTraining(t *testing.T) {
tests := []struct {
name string
training bool
wantTraining bool
}{
{"no training field defaults to false", false, false},
{"training: true enables training mode", true, true},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cs, err := New()
if err != nil {
t.Fatal(err)
}
defer Clean()

trainingSection := ""
if tt.training {
trainingSection = "\n training_data:\n max_samples: 10"
}
config := fmt.Sprintf(`---
loglevel: ERROR
logpath: /dev/null
modelplugins:
- id: "testplugin"
path: "../testdata/plugins/model/trivial.so"
plugintype: "RequestHeaders"
training: %v%s
`, tt.training, trainingSection)
if err := initialize([]byte(config)); err != nil {
t.Fatalf("initialize failed: %v", err)
}

if got := cs.IsInTraining("testplugin"); got != tt.wantTraining {
t.Errorf("IsInTraining with training=%v = %v, want %v", tt.training, got, tt.wantTraining)
}
})
}
}

func TestTrainingDataConfig(t *testing.T) {
tests := []struct {
name string
config string
wantErr bool
wantMaxSamples int
wantPath string
}{
{
name: "training with zero max_samples returns error",
config: `---
loglevel: ERROR
logpath: /dev/null
modelplugins:
- id: "testplugin"
path: "../testdata/plugins/model/trivial.so"
plugintype: "RequestHeaders"
training: true
`,
wantErr: true,
},
{
name: "training and async are mutually exclusive",
config: `---
loglevel: ERROR
logpath: /dev/null
modelplugins:
- id: "testplugin"
path: "../testdata/plugins/model/trivial.so"
plugintype: "RequestHeaders"
training: true
async: true
training_data:
max_samples: 10
`,
wantErr: true,
},
{
name: "training and remote are mutually exclusive",
config: `---
loglevel: ERROR
logpath: /dev/null
modelplugins:
- id: "testplugin"
path: "../testdata/plugins/model/trivial.so"
plugintype: "RequestHeaders"
training: true
remote: true
training_data:
max_samples: 10
`,
wantErr: true,
},
{
name: "valid training config stores TrainingData correctly",
config: `---
loglevel: ERROR
logpath: /dev/null
modelplugins:
- id: "testplugin"
path: "../testdata/plugins/model/trivial.so"
plugintype: "RequestHeaders"
training: true
training_data:
max_samples: 42
result_file_path: "/dev/null"
`,
wantErr: false,
wantMaxSamples: 42,
wantPath: "/dev/null",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cs, err := New()
if err != nil {
t.Fatal(err)
}
defer Clean()

err = initialize([]byte(tt.config))
if (err != nil) != tt.wantErr {
if tt.wantErr {
t.Errorf("expected error but got none")
} else {
t.Errorf("unexpected error: %v", err)
}
return
}
if !tt.wantErr {
got := cs.ModelPlugins["testplugin"].TrainingData
if got.MaxSamples != tt.wantMaxSamples {
t.Errorf("MaxSamples = %d, want %d", got.MaxSamples, tt.wantMaxSamples)
}
if got.ResultFilePath != tt.wantPath {
t.Errorf("ResultsFilePath = %q, want %q", got.ResultFilePath, tt.wantPath)
}
}
})
}
}

func TestNatsURL(t *testing.T) {
tests := []struct {
Expand All @@ -539,12 +682,12 @@ func TestNatsURL(t *testing.T) {
wantURL string
}{
{
name: "defaults to localhost:4222",
name: "empty string when natsurl not set",
config: `---
loglevel: ERROR
logpath: /dev/null
`,
wantURL: "localhost:4222",
wantURL: "",
},
{
name: "stores custom URL",
Expand Down
File renamed without changes.
Loading
Loading