Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions src/detectmatelibrary/common/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from tools.logging import logger, setup_logging

from typing import Any, Dict, Tuple, List
from enum import Enum


setup_logging()
Expand Down Expand Up @@ -38,12 +39,32 @@ def postprocess(
return data if not is_byte else data.serialize()


class TrainState(Enum):
DEFAULT = 0
STOP_TRAINING = 1
KEEP_TRAINING = 2

def describe(self) -> str:
descriptions = [
"Follow default training behavior.",
"Force stop training.",
"Keep training regardless of default behavior."
]

return descriptions[self.value]


class CoreConfig(BasicConfig):
start_id: int = 10
data_use_training: int | None = None


def do_training(config: CoreConfig, index: int) -> bool:
def do_training(config: CoreConfig, index: int, train_state: TrainState) -> bool:
if train_state == TrainState.STOP_TRAINING:
return False
elif train_state == TrainState.KEEP_TRAINING:
return True

return config.data_use_training is not None and config.data_use_training > index


Expand All @@ -65,6 +86,7 @@ def __init__(
self.data_buffer = DataBuffer(args_buffer)
self.id_generator = SimpleIDGenerator(self.config.start_id)
self.data_used_train = 0
self.train_state: TrainState = TrainState.DEFAULT

def __repr__(self) -> str:
return f"<{self.type_}> {self.name}: {self.config}"
Expand All @@ -86,7 +108,7 @@ def process(self, data: BaseSchema | bytes) -> BaseSchema | bytes | None:
if (data_buffered := self.data_buffer.add(data)) is None: # type: ignore
return None

if do_training(config=self.config, index=self.data_used_train):
if do_training(config=self.config, index=self.data_used_train, train_state=self.train_state):
self.data_used_train += 1
logger.info(f"<<{self.name}>> use data for training")
self.train(input_=data_buffered)
Expand Down
37 changes: 36 additions & 1 deletion tests/test_common/test_core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from detectmatelibrary.common.core import CoreConfig, CoreComponent, TrainState
from detectmatelibrary.common._config import BasicConfig
from detectmatelibrary.common.core import CoreConfig, CoreComponent

from detectmatelibrary.utils.data_buffer import ArgsBuffer

import detectmatelibrary.schemas._op as op_schemas
import detectmatelibrary.schemas as schemas

Expand Down Expand Up @@ -181,3 +183,36 @@ def test_training(self) -> None:
"hostname": "test_hostname"
})
assert expected == log

def test_training_force_stop(self) -> None:
component = MockComponentWithTraining(name="Dummy5")

for i in range(10):
if i == 2:
component.train_state = TrainState.STOP_TRAINING
component.process(
schemas.LogSchema({
"__version__": "1.0.0",
"logID": i,
"logSource": "test",
"hostname": "test_hostname"
})
)

assert len(component.train_data) == 2

def test_training_keep_training(self) -> None:
component = MockComponentWithTraining(name="Dummy6")
component.train_state = TrainState.KEEP_TRAINING

for i in range(10):
component.process(
schemas.LogSchema({
"__version__": "1.0.0",
"logID": i,
"logSource": "test",
"hostname": "test_hostname"
})
)

assert len(component.train_data) == 10