diff --git a/simple_typing_application/key_monitor/factory.py b/simple_typing_application/key_monitor/factory.py index 9559115..aaceb18 100644 --- a/simple_typing_application/key_monitor/factory.py +++ b/simple_typing_application/key_monitor/factory.py @@ -24,7 +24,7 @@ def _select_class_and_config_model(key_monitor_type: EKeyMonitorType) -> tuple[t def create_key_monitor( key_monitor_type: EKeyMonitorType, - dict_config: dict[str, str | float | int | bool | None | dict | list], + config: BaseKeyMonitorConfigModel, logger: Logger = getLogger(__name__), ) -> BaseKeyMonitor: # select key monitor class and config model @@ -37,7 +37,10 @@ def create_key_monitor( # create key monitor logger.debug(f"create {key_monitor_cls.__name__}") - key_monitor_config: BaseKeyMonitorConfigModel = key_monitor_config_model(**dict_config) # noqa + if isinstance(config, key_monitor_config_model): + key_monitor_config: BaseKeyMonitorConfigModel = config + else: + key_monitor_config: BaseKeyMonitorConfigModel = key_monitor_config_model(**config.model_dump()) # noqa key_monitor: BaseKeyMonitor = key_monitor_cls(**key_monitor_config.model_dump()) # type: ignore # noqa return key_monitor diff --git a/simple_typing_application/models/config_models/general_config_model.py b/simple_typing_application/models/config_models/general_config_model.py index 6da2d28..2edf4d3 100644 --- a/simple_typing_application/models/config_models/general_config_model.py +++ b/simple_typing_application/models/config_models/general_config_model.py @@ -1,9 +1,21 @@ from __future__ import annotations from pydantic import BaseModel -from .key_monitor_config_model import BaseKeyMonitorConfigModel -from .sentence_generator_config_model import BaseSentenceGeneratorConfigModel -from .user_interface_config_model import BaseUserInterfaceConfigModel +from .key_monitor_config_model import ( + BaseKeyMonitorConfigModel, + PynputBasedKeyMonitorConfigModel, + SSHKeyboardBasedKeyMonitorConfigModel, +) +from .sentence_generator_config_model import ( + BaseSentenceGeneratorConfigModel, + HuggingfaceSentenceGeneratorConfigModel, + OpenAISentenceGeneratorConfigModel, + StaticSentenceGeneratorConfigModel, +) +from .user_interface_config_model import ( + BaseUserInterfaceConfigModel, + ConsoleUserInterfaceConfigModel, +) from ...const.key_monitor import EKeyMonitorType from ...const.sentence_generator import ESentenceGeneratorType from ...const.user_interface import EUserInterfaceType @@ -11,16 +23,21 @@ class ConfigModel(BaseModel): sentence_generator_type: ESentenceGeneratorType = ESentenceGeneratorType.OPENAI # noqa - sentence_generator_config: dict[str, str | float | int | bool | None | dict | list] = ( - BaseSentenceGeneratorConfigModel().model_dump() - ) # noqa + sentence_generator_config: ( + BaseSentenceGeneratorConfigModel + | HuggingfaceSentenceGeneratorConfigModel + | OpenAISentenceGeneratorConfigModel + | StaticSentenceGeneratorConfigModel + ) = BaseSentenceGeneratorConfigModel() user_interface_type: EUserInterfaceType = EUserInterfaceType.CONSOLE - user_interface_config: dict[str, str | float | int | None | dict | list] = ( - BaseUserInterfaceConfigModel().model_dump() - ) # noqa + user_interface_config: BaseUserInterfaceConfigModel | ConsoleUserInterfaceConfigModel = ( + BaseUserInterfaceConfigModel() + ) key_monitor_type: EKeyMonitorType = EKeyMonitorType.PYNPUT - key_monitor_config: dict[str, str | float | int | None | dict | list] = BaseKeyMonitorConfigModel().model_dump() # noqa + key_monitor_config: ( + BaseKeyMonitorConfigModel | PynputBasedKeyMonitorConfigModel | SSHKeyboardBasedKeyMonitorConfigModel + ) = BaseKeyMonitorConfigModel() record_direc: str = "./record" diff --git a/simple_typing_application/sentence_generator/factory.py b/simple_typing_application/sentence_generator/factory.py index 9cf39bc..1cc92f9 100644 --- a/simple_typing_application/sentence_generator/factory.py +++ b/simple_typing_application/sentence_generator/factory.py @@ -35,7 +35,7 @@ def _select_class_and_config_model(sentence_generator_type: ESentenceGeneratorTy def create_sentence_generator( sentence_generator_type: ESentenceGeneratorType, - dict_config: dict[str, str | float | int | bool | None | dict | list], + config: BaseSentenceGeneratorConfigModel, logger: Logger = getLogger(__name__), ) -> BaseSentenceGenerator: # select sentence generator class and config model @@ -50,7 +50,10 @@ def create_sentence_generator( # create sentence generator logger.debug(f"create {sentence_generator_cls.__name__}") - sentence_generator_config: BaseSentenceGeneratorConfigModel = sentence_generator_config_model(**dict_config) # noqa + if isinstance(config, sentence_generator_config_model): + sentence_generator_config: BaseSentenceGeneratorConfigModel = config + else: + sentence_generator_config: BaseSentenceGeneratorConfigModel = sentence_generator_config_model(**config.model_dump()) # noqa sentence_generator: BaseSentenceGenerator = sentence_generator_cls(**sentence_generator_config.model_dump()) # type: ignore # noqa return sentence_generator diff --git a/simple_typing_application/ui/factory.py b/simple_typing_application/ui/factory.py index 0ad5ea6..3f937af 100644 --- a/simple_typing_application/ui/factory.py +++ b/simple_typing_application/ui/factory.py @@ -19,7 +19,7 @@ def _select_class_and_config_model(user_interface_type: EUserInterfaceType) -> t def create_user_interface( user_interface_type: EUserInterfaceType, - dict_config: dict[str, str | float | int | bool | None | dict | list], + config: BaseUserInterfaceConfigModel, logger: Logger = getLogger(__name__), ) -> BaseUserInterface: # select user interface class and config model @@ -32,7 +32,10 @@ def create_user_interface( # create user interface logger.debug(f"create {user_interface_cls.__name__}") - user_interface_config: BaseUserInterfaceConfigModel = user_interface_config_model(**dict_config) # noqa + if isinstance(config, user_interface_config_model): + user_interface_config: BaseUserInterfaceConfigModel = config + else: + user_interface_config: BaseUserInterfaceConfigModel = user_interface_config_model(**config.model_dump()) # noqa user_interface: BaseUserInterface = user_interface_cls(**user_interface_config.model_dump()) # type: ignore # noqa return user_interface diff --git a/tests/key_monitor/test_factory.py b/tests/key_monitor/test_factory.py index b13d683..7e8c58d 100644 --- a/tests/key_monitor/test_factory.py +++ b/tests/key_monitor/test_factory.py @@ -3,6 +3,7 @@ from simple_typing_application.const.key_monitor import EKeyMonitorType # noqa from simple_typing_application.models.config_models.key_monitor_config_model import ( # noqa + BaseKeyMonitorConfigModel, SSHKeyboardBasedKeyMonitorConfigModel, PynputBasedKeyMonitorConfigModel, ) @@ -38,26 +39,29 @@ def test_select_class_and_config_model_raise_value_error(): @pytest.mark.parametrize( - "key_monitor_type, key_monitor_config_dict, expected_class", + "key_monitor_type, key_monitor_config, expected_class", [ ( EKeyMonitorType.PYNPUT, - PynputBasedKeyMonitorConfigModel().model_dump(), + PynputBasedKeyMonitorConfigModel(), PynputBasedKeyMonitor, ), ( EKeyMonitorType.PYNPUT, - {}, + BaseKeyMonitorConfigModel(), PynputBasedKeyMonitor, ), - (EKeyMonitorType.SSHKEYBOARD, SSHKeyboardBasedKeyMonitorConfigModel().model_dump(), SSHKeyboardBasedKeyMonitor), # noqa + ( + EKeyMonitorType.SSHKEYBOARD, + SSHKeyboardBasedKeyMonitorConfigModel(), + SSHKeyboardBasedKeyMonitor, + ), # noqa ], ) def test_create_key_monitor( key_monitor_type: EKeyMonitorType, - key_monitor_config_dict: dict[str, str | float | int | bool | None | dict | list], # noqa + key_monitor_config: BaseKeyMonitorConfigModel, expected_class: type, - mocker, ): # mock # for PynputBasedKeyMonitor @@ -66,7 +70,7 @@ def test_create_key_monitor( # execute key_monitor = create_key_monitor( key_monitor_type, - key_monitor_config_dict, + key_monitor_config, ) # assert diff --git a/tests/sentence_generator/test_factory.py b/tests/sentence_generator/test_factory.py index a4434e7..e35c3cf 100644 --- a/tests/sentence_generator/test_factory.py +++ b/tests/sentence_generator/test_factory.py @@ -14,8 +14,9 @@ from simple_typing_application.const.sentence_generator import ESentenceGeneratorType # noqa from simple_typing_application.models.config_models.sentence_generator_config_model import ( # noqa - OpenAISentenceGeneratorConfigModel, + BaseSentenceGeneratorConfigModel, HuggingfaceSentenceGeneratorConfigModel, + OpenAISentenceGeneratorConfigModel, StaticSentenceGeneratorConfigModel, ) from simple_typing_application.sentence_generator.factory import ( @@ -67,26 +68,26 @@ def test_select_class_and_config_model_raise_value_error(): @pytest.mark.parametrize( - "sentence_generator_type, sentence_generator_config_dict, expected_class", + "sentence_generator_type, sentence_generator_config, expected_class", [ ( ESentenceGeneratorType.OPENAI, - OpenAISentenceGeneratorConfigModel().model_dump(), + OpenAISentenceGeneratorConfigModel(), OpenaiSentenceGenerator, ), ( ESentenceGeneratorType.STATIC, - StaticSentenceGeneratorConfigModel(text_kana_map={}).model_dump(), + StaticSentenceGeneratorConfigModel(text_kana_map={}), StaticSentenceGenerator, ), ( ESentenceGeneratorType.OPENAI, - {}, + BaseSentenceGeneratorConfigModel(), OpenaiSentenceGenerator, ), ( ESentenceGeneratorType.STATIC, - {}, + BaseSentenceGeneratorConfigModel(), StaticSentenceGenerator, ), ] @@ -94,12 +95,12 @@ def test_select_class_and_config_model_raise_value_error(): [ ( ESentenceGeneratorType.HUGGINGFACE, - HuggingfaceSentenceGeneratorConfigModel().model_dump(), + HuggingfaceSentenceGeneratorConfigModel(), HuggingfaceSentenceGenerator, ), ( ESentenceGeneratorType.HUGGINGFACE, - {}, + BaseSentenceGeneratorConfigModel(), HuggingfaceSentenceGenerator, ), ] @@ -109,7 +110,7 @@ def test_select_class_and_config_model_raise_value_error(): ) def test_create_sentence_generator( sentence_generator_type: ESentenceGeneratorType, - sentence_generator_config_dict: dict[str, str | float | int | bool | None | dict | list], # noqa + sentence_generator_config: BaseSentenceGeneratorConfigModel, expected_class: type, mocker, ): @@ -131,7 +132,7 @@ def test_create_sentence_generator( # execute sentence_generator = create_sentence_generator( sentence_generator_type, - sentence_generator_config_dict, + sentence_generator_config, ) # assert diff --git a/tests/test_config.py b/tests/test_config.py index 747cdba..bc44035 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -37,10 +37,14 @@ def test_load_config_json(mocker): ) # run - actual = load_config(path) + actual: ConfigModel = load_config(path) + + # postprocess # assert - assert actual == expected + actual_dic = actual.model_dump(exclude={"sentence_generator_config": {"openai_api_key"}}) + expected_dic = expected.model_dump(exclude={"sentence_generator_config": {"openai_api_key"}}) + assert actual_dic == expected_dic def test_load_config_json_not_found(mocker): diff --git a/tests/ui/test_factory.py b/tests/ui/test_factory.py index efc7989..ffb6341 100644 --- a/tests/ui/test_factory.py +++ b/tests/ui/test_factory.py @@ -3,6 +3,7 @@ from simple_typing_application.const.user_interface import EUserInterfaceType from simple_typing_application.models.config_models.user_interface_config_model import ( # noqa + BaseUserInterfaceConfigModel, ConsoleUserInterfaceConfigModel, ) from simple_typing_application.ui.cui import ConsoleUserInterface @@ -42,25 +43,24 @@ def test_select_class_and_config_model_raise_value_error(): @pytest.mark.parametrize( - "user_interface_type, user_interface_config_dict, expected_class", + "user_interface_type, user_interface_config, expected_class", [ ( EUserInterfaceType.CONSOLE, - ConsoleUserInterfaceConfigModel().model_dump(), + ConsoleUserInterfaceConfigModel(), ConsoleUserInterface, ), ( EUserInterfaceType.CONSOLE, - {}, + BaseUserInterfaceConfigModel(), ConsoleUserInterface, ), ], ) def test_create_user_interface( user_interface_type: EUserInterfaceType, - user_interface_config_dict: dict[str, str | float | int | bool | None | dict | list], # noqa + user_interface_config: BaseUserInterfaceConfigModel, expected_class: type, - mocker, ): # mock # for ConsoleUserInterface @@ -69,7 +69,7 @@ def test_create_user_interface( # execute user_interface = create_user_interface( user_interface_type, - user_interface_config_dict, + user_interface_config, ) # assert