diff --git a/docs/tutorial/parameter-types/enum.md b/docs/tutorial/parameter-types/enum.md index 278908fc72..52362c10eb 100644 --- a/docs/tutorial/parameter-types/enum.md +++ b/docs/tutorial/parameter-types/enum.md @@ -145,3 +145,14 @@ Error: Invalid value for '--network': 'capsule' is not one of 'simple', 'conv', ``` + + +### Functional API + +In order to use an `Enum` created using the functional API, you need to create an enum with string values. + +You also need to supply the default value as a string (not the enum): + +{* docs_src/parameter_types/enum/tutorial005_an.py hl[6,13] *} + +Alternatively, you can create an `Enum` that extends both `str` and `Enum`. In Python 3.11+, there is `enum.StrEnum`. For Python 3.10 or earlier, there is the StrEnum package. diff --git a/docs_src/parameter_types/enum/tutorial005.py b/docs_src/parameter_types/enum/tutorial005.py new file mode 100644 index 0000000000..15d114c417 --- /dev/null +++ b/docs_src/parameter_types/enum/tutorial005.py @@ -0,0 +1,16 @@ +from enum import Enum + +import typer + +NeuralNetwork = Enum("NeuralNetwork", {k: k for k in ["simple", "conv", "lstm"]}) + +app = typer.Typer() + + +@app.command() +def main(network: NeuralNetwork = typer.Option("simple", case_sensitive=False)): + print(f"Training neural network of type: {network.value}") + + +if __name__ == "__main__": + app() diff --git a/docs_src/parameter_types/enum/tutorial005_an.py b/docs_src/parameter_types/enum/tutorial005_an.py new file mode 100644 index 0000000000..7f49d03785 --- /dev/null +++ b/docs_src/parameter_types/enum/tutorial005_an.py @@ -0,0 +1,19 @@ +from enum import Enum +from typing import Annotated + +import typer + +NeuralNetwork = Enum("NeuralNetwork", {k: k for k in ["simple", "conv", "lstm"]}) + +app = typer.Typer() + + +@app.command() +def main( + network: Annotated[NeuralNetwork, typer.Option(case_sensitive=False)] = "simple", +): + print(f"Training neural network of type: {network.value}") + + +if __name__ == "__main__": + app() diff --git a/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial005.py b/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial005.py new file mode 100644 index 0000000000..a646ed302d --- /dev/null +++ b/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial005.py @@ -0,0 +1,43 @@ +import subprocess +import sys + +from typer.testing import CliRunner + +from docs_src.parameter_types.enum import tutorial005 as mod + +runner = CliRunner() +app = mod.app + + +def test_help(): + result = runner.invoke(app, ["--help"]) + assert result.exit_code == 0 + assert "--network [simple|conv|lstm]" in result.output.replace(" ", "") + + +def test_main(): + result = runner.invoke(app, ["--network", "conv"]) + assert result.exit_code == 0 + assert "Training neural network of type: conv" in result.output + + +def test_invalid(): + result = runner.invoke(app, ["--network", "capsule"]) + assert result.exit_code != 0 + assert "Invalid value for '--network'" in result.output + assert ( + "invalid choice: capsule. (choose from" in result.output + or "'capsule' is not one of" in result.output + ) + assert "simple" in result.output + assert "conv" in result.output + assert "lstm" in result.output + + +def test_script(): + result = subprocess.run( + [sys.executable, "-m", "coverage", "run", mod.__file__, "--help"], + capture_output=True, + encoding="utf-8", + ) + assert "Usage" in result.stdout diff --git a/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial005_an.py b/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial005_an.py new file mode 100644 index 0000000000..2e5a805fc1 --- /dev/null +++ b/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial005_an.py @@ -0,0 +1,43 @@ +import subprocess +import sys + +from typer.testing import CliRunner + +from docs_src.parameter_types.enum import tutorial005_an as mod + +runner = CliRunner() +app = mod.app + + +def test_help(): + result = runner.invoke(app, ["--help"]) + assert result.exit_code == 0 + assert "--network [simple|conv|lstm]" in result.output.replace(" ", "") + + +def test_main(): + result = runner.invoke(app, ["--network", "conv"]) + assert result.exit_code == 0 + assert "Training neural network of type: conv" in result.output + + +def test_invalid(): + result = runner.invoke(app, ["--network", "capsule"]) + assert result.exit_code != 0 + assert "Invalid value for '--network'" in result.output + assert ( + "invalid choice: capsule. (choose from" in result.output + or "'capsule' is not one of" in result.output + ) + assert "simple" in result.output + assert "conv" in result.output + assert "lstm" in result.output + + +def test_script(): + result = subprocess.run( + [sys.executable, "-m", "coverage", "run", mod.__file__, "--help"], + capture_output=True, + encoding="utf-8", + ) + assert "Usage" in result.stdout