diff --git a/src/detectmatelibrary/common/core.py b/src/detectmatelibrary/common/core.py index b95e362..73c3321 100644 --- a/src/detectmatelibrary/common/core.py +++ b/src/detectmatelibrary/common/core.py @@ -8,6 +8,7 @@ from tools.logging import logger, setup_logging from typing import Any, Dict, Tuple, List +from enum import Enum setup_logging() @@ -38,12 +39,32 @@ def postprocess( return data if not is_byte else data.serialize() +class TrainState(Enum): + DEFAULT = 0 + STOP_TRAINING = 1 + KEEP_TRAINING = 2 + + def describe(self) -> str: + descriptions = [ + "Follow default training behavior.", + "Force stop training.", + "Keep training regardless of default behavior." + ] + + return descriptions[self.value] + + class CoreConfig(BasicConfig): start_id: int = 10 data_use_training: int | None = None -def do_training(config: CoreConfig, index: int) -> bool: +def do_training(config: CoreConfig, index: int, train_state: TrainState) -> bool: + if train_state == TrainState.STOP_TRAINING: + return False + elif train_state == TrainState.KEEP_TRAINING: + return True + return config.data_use_training is not None and config.data_use_training > index @@ -65,6 +86,7 @@ def __init__( self.data_buffer = DataBuffer(args_buffer) self.id_generator = SimpleIDGenerator(self.config.start_id) self.data_used_train = 0 + self.train_state: TrainState = TrainState.DEFAULT def __repr__(self) -> str: return f"<{self.type_}> {self.name}: {self.config}" @@ -86,7 +108,7 @@ def process(self, data: BaseSchema | bytes) -> BaseSchema | bytes | None: if (data_buffered := self.data_buffer.add(data)) is None: # type: ignore return None - if do_training(config=self.config, index=self.data_used_train): + if do_training(config=self.config, index=self.data_used_train, train_state=self.train_state): self.data_used_train += 1 logger.info(f"<<{self.name}>> use data for training") self.train(input_=data_buffered) diff --git a/tests/test_common/test_core.py b/tests/test_common/test_core.py index 57ff4cd..f63e39a 100644 --- a/tests/test_common/test_core.py +++ b/tests/test_common/test_core.py @@ -1,6 +1,8 @@ +from detectmatelibrary.common.core import CoreConfig, CoreComponent, TrainState from detectmatelibrary.common._config import BasicConfig -from detectmatelibrary.common.core import CoreConfig, CoreComponent + from detectmatelibrary.utils.data_buffer import ArgsBuffer + import detectmatelibrary.schemas._op as op_schemas import detectmatelibrary.schemas as schemas @@ -181,3 +183,36 @@ def test_training(self) -> None: "hostname": "test_hostname" }) assert expected == log + + def test_training_force_stop(self) -> None: + component = MockComponentWithTraining(name="Dummy5") + + for i in range(10): + if i == 2: + component.train_state = TrainState.STOP_TRAINING + component.process( + schemas.LogSchema({ + "__version__": "1.0.0", + "logID": i, + "logSource": "test", + "hostname": "test_hostname" + }) + ) + + assert len(component.train_data) == 2 + + def test_training_keep_training(self) -> None: + component = MockComponentWithTraining(name="Dummy6") + component.train_state = TrainState.KEEP_TRAINING + + for i in range(10): + component.process( + schemas.LogSchema({ + "__version__": "1.0.0", + "logID": i, + "logSource": "test", + "hostname": "test_hostname" + }) + ) + + assert len(component.train_data) == 10