From 1dbfd47b4d75b18fff9e0f8dc5a2948e0b3d1579 Mon Sep 17 00:00:00 2001 From: Yannis Chatzikonstantinou Date: Fri, 2 Jan 2026 21:04:43 +0200 Subject: [PATCH 1/5] fix char bug and modularize code --- avlos/templates/fw_endpoints.c.jinja | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/avlos/templates/fw_endpoints.c.jinja b/avlos/templates/fw_endpoints.c.jinja index 0921718..fff41b3 100644 --- a/avlos/templates/fw_endpoints.c.jinja +++ b/avlos/templates/fw_endpoints.c.jinja @@ -18,8 +18,7 @@ {%- macro getter_char(attr) -%} if (AVLOS_CMD_READ == cmd) { - *buffer_len = {{ attr.getter_name }}((char *)buffer); - return AVLOS_RET_READ; + return _avlos_getter_string(buffer, buffer_len, {{ attr.getter_name }}); } {%- endmacro %} @@ -34,10 +33,7 @@ {%- macro setter_char(attr) -%} {% if attr.getter_name %}else {% endif %}if (AVLOS_CMD_WRITE == cmd) { - {{attr.dtype.c_name}} v; - memcpy(&v, buffer, sizeof(v)); - {{ attr.setter_name }}(v); - return AVLOS_RET_WRITE; + return _avlos_setter_string(buffer, {{ attr.setter_name }}); } {%- endmacro %} @@ -45,6 +41,16 @@ #include {{ include | as_include }} {%- endfor %} +static inline uint8_t _avlos_getter_string(uint8_t *buffer, uint8_t *buffer_len, uint8_t (*getter)(char*)) { + *buffer_len = getter((char *)buffer); + return AVLOS_RET_READ; +} + +static inline uint8_t _avlos_setter_string(const uint8_t *buffer, void (*setter)(const char*)) { + setter((const char *)buffer); + return AVLOS_RET_WRITE; +} + {% set comma = joiner(", ") %} uint8_t (*avlos_endpoints[{{ instance | endpoints | length }}])(uint8_t * buffer, uint8_t * buffer_len, Avlos_Command cmd) = { {%- for attr in instance | endpoints %}{{ comma() }}&avlos_{{attr.full_name | replace(".", "_") }}{%- endfor %} }; From 2073dbbdbc757b1daf770f8d9eb70a8c0f261d3a Mon Sep 17 00:00:00 2001 From: Yannis Chatzikonstantinou Date: Sat, 10 Jan 2026 14:25:21 +0200 Subject: [PATCH 2/5] add templating checks --- .clang-format | 10 ++ avlos/definitions/remote_attribute.py | 39 ++++++ avlos/definitions/remote_bitmask.py | 15 +++ avlos/definitions/remote_enum.py | 15 +++ avlos/definitions/remote_function.py | 5 + avlos/formatting.py | 55 ++++++++ avlos/generators/filters.py | 75 ++++++++--- avlos/generators/generator_c.py | 24 ++++ avlos/generators/generator_cpp.py | 23 ++++ avlos/templates/fw_endpoints.c.jinja | 8 +- avlos/templates/fw_endpoints.h.jinja | 4 +- avlos/validation.py | 183 ++++++++++++++++++++++++++ 12 files changed, 433 insertions(+), 23 deletions(-) create mode 100644 .clang-format create mode 100644 avlos/formatting.py create mode 100644 avlos/validation.py diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..3ef7579 --- /dev/null +++ b/.clang-format @@ -0,0 +1,10 @@ +--- +BasedOnStyle: LLVM +IndentWidth: 4 +ColumnLimit: 100 +AllowShortFunctionsOnASingleLine: Empty +AlignConsecutiveAssignments: false +PointerAlignment: Left +SpaceBeforeParens: Never +BreakBeforeBraces: Linux +IndentCaseLabels: false diff --git a/avlos/definitions/remote_attribute.py b/avlos/definitions/remote_attribute.py index 6c84e82..030b225 100644 --- a/avlos/definitions/remote_attribute.py +++ b/avlos/definitions/remote_attribute.py @@ -96,3 +96,42 @@ def str_dump(self): self.dtype.nickname, value ) + + @property + def getter_strategy(self) -> str: + """ + Determine the strategy for getter implementation. + + Returns: + 'string' for char[] types, 'byval' for all other types + """ + if self.dtype.c_name == "char[]": + return "string" + return "byval" + + @property + def setter_strategy(self) -> str: + """ + Determine the strategy for setter implementation. + + Returns: + 'string' for char[] types, 'byval' for all other types + """ + if self.dtype.c_name == "char[]": + return "string" + return "byval" + + @property + def endpoint_function_name(self) -> str: + """ + Get the C function name for this endpoint. + + Returns: + Function name in format 'avlos_parent_child_attribute' + """ + return "avlos_" + self.full_name.replace(".", "_") + + @property + def is_string_type(self) -> bool: + """Check if this attribute uses string/char[] type.""" + return self.dtype.c_name == "char[]" diff --git a/avlos/definitions/remote_bitmask.py b/avlos/definitions/remote_bitmask.py index 7c7aced..946f000 100644 --- a/avlos/definitions/remote_bitmask.py +++ b/avlos/definitions/remote_bitmask.py @@ -95,3 +95,18 @@ def str_dump(self): self.name, str(val) if val > 0 else "(no flags)", ) + + @property + def endpoint_function_name(self) -> str: + """Get the C function name for this endpoint.""" + return "avlos_" + self.full_name.replace(".", "_") + + @property + def getter_strategy(self) -> str: + """Bitmasks always use byval strategy.""" + return "byval" + + @property + def setter_strategy(self) -> str: + """Bitmasks always use byval strategy.""" + return "byval" diff --git a/avlos/definitions/remote_enum.py b/avlos/definitions/remote_enum.py index 02cb556..5b0b03b 100644 --- a/avlos/definitions/remote_enum.py +++ b/avlos/definitions/remote_enum.py @@ -114,3 +114,18 @@ def str_dump(self): """ val = self.get_value() return "{0}: {1}".format(self.name, str(val)) + + @property + def endpoint_function_name(self) -> str: + """Get the C function name for this endpoint.""" + return "avlos_" + self.full_name.replace(".", "_") + + @property + def getter_strategy(self) -> str: + """Enums always use byval strategy.""" + return "byval" + + @property + def setter_strategy(self) -> str: + """Enums always use byval strategy.""" + return "byval" diff --git a/avlos/definitions/remote_function.py b/avlos/definitions/remote_function.py index e5392a5..645e992 100644 --- a/avlos/definitions/remote_function.py +++ b/avlos/definitions/remote_function.py @@ -69,6 +69,11 @@ def str_dump(self): self.dtype.nickname, ) + @property + def endpoint_function_name(self) -> str: + """Get the C function name for this endpoint.""" + return "avlos_" + self.full_name.replace(".", "_") + class RemoteArgument: """ diff --git a/avlos/formatting.py b/avlos/formatting.py new file mode 100644 index 0000000..3437317 --- /dev/null +++ b/avlos/formatting.py @@ -0,0 +1,55 @@ +""" +Code formatting utilities for generated code. +""" +import subprocess +import shutil +import sys +from pathlib import Path + + +def is_clang_format_available() -> bool: + """Check if clang-format is installed on the system.""" + return shutil.which("clang-format") is not None + + +def format_c_code(file_path: str, style: str = "LLVM") -> bool: + """ + Format C/C++ code using clang-format. + + Args: + file_path: Path to file to format + style: clang-format style (LLVM, Google, Chromium, Mozilla, WebKit, Microsoft, GNU) + + Returns: + True if formatting succeeded, False if clang-format not available or failed + """ + if not is_clang_format_available(): + return False + + try: + result = subprocess.run( + ["clang-format", "-i", f"--style={style}", file_path], + capture_output=True, + timeout=10, + check=False + ) + return result.returncode == 0 + except (subprocess.TimeoutExpired, subprocess.SubprocessError): + return False + + +def format_files(file_paths: list, style: str = "LLVM") -> dict: + """ + Format multiple C/C++ files. + + Args: + file_paths: List of file paths to format + style: clang-format style + + Returns: + Dictionary mapping file path to success status + """ + results = {} + for path in file_paths: + results[path] = format_c_code(path, style) + return results diff --git a/avlos/generators/filters.py b/avlos/generators/filters.py index 8a6f317..1c98c7e 100644 --- a/avlos/generators/filters.py +++ b/avlos/generators/filters.py @@ -1,13 +1,23 @@ import os +from typing import List from copy import copy -def avlos_endpoints(input): +def avlos_endpoints(input) -> List: """ - Traverse remote dictionary and return list - of remote endpoints + Traverse remote dictionary and return list of remote endpoints. + + Recursively walks the tree of RemoteNode objects and collects all endpoint + objects (those with getter_name, setter_name, or caller_name). + + Args: + input: Root RemoteNode to traverse + + Returns: + Flat list of all endpoint objects found in the tree """ - def traverse_endpoint_list(ep_list, ep_out_list): + def traverse_endpoint_list(ep_list, ep_out_list: List) -> None: + """Helper function to recursively traverse endpoint tree.""" for ep in ep_list: if ( hasattr(ep, "getter_name") @@ -18,30 +28,50 @@ def traverse_endpoint_list(ep_list, ep_out_list): elif hasattr(ep, "remote_attributes"): traverse_endpoint_list(ep.remote_attributes.values(), ep_out_list) - ep_out_list = [] + ep_out_list: List = [] if hasattr(input, "remote_attributes"): traverse_endpoint_list(input.remote_attributes.values(), ep_out_list) return ep_out_list -def avlos_enum_eps(input): +def avlos_enum_eps(input) -> List: """ - Traverse remote dictionary and return a list of enum type endpoints + Traverse remote dictionary and return a list of enum type endpoints. + + Args: + input: Root RemoteNode to traverse + + Returns: + List of RemoteEnum objects """ return [ep for ep in avlos_endpoints(input) if hasattr(ep, "options")] -def avlos_bitmask_eps(input): +def avlos_bitmask_eps(input) -> List: """ - Traverse remote dictionary and return a list of bitmask type endpoints + Traverse remote dictionary and return a list of bitmask type endpoints. + + Args: + input: Root RemoteNode to traverse + + Returns: + List of RemoteBitmask objects """ return [ep for ep in avlos_endpoints(input) if hasattr(ep, "bitmask")] -def as_include(input): +def as_include(input: str) -> str: """ - Render a string as a C include, with opening - and closing braces or quotation marks + Render a string as a C include, with opening and closing braces or quotation marks. + + If the input already has proper include delimiters, returns unchanged. + Otherwise, wraps in angle brackets. + + Args: + input: Include path string + + Returns: + Properly formatted include directive (e.g., "" or '"myheader.h"') """ if input.startswith('"') and input.endswith('"'): return input @@ -50,16 +80,27 @@ def as_include(input): return "<" + input + ">" -def file_from_path(input): +def file_from_path(input: str) -> str: """ - Get the file string from a path string + Get the file string from a path string. + + Args: + input: File path + + Returns: + Base filename without directory path """ return os.path.basename(input) -def capitalize_first(input): +def capitalize_first(input: str) -> str: """ - Capitalize the first character of a - string, leaving the rest unchanged + Capitalize the first character of a string, leaving the rest unchanged. + + Args: + input: String to capitalize + + Returns: + String with first character capitalized """ return input[0].upper() + input[1:] diff --git a/avlos/generators/generator_c.py b/avlos/generators/generator_c.py index d69aff3..f608791 100644 --- a/avlos/generators/generator_c.py +++ b/avlos/generators/generator_c.py @@ -1,4 +1,5 @@ import os +import sys from jinja2 import Environment, PackageLoader, select_autoescape from avlos.generators.filters import ( avlos_endpoints, @@ -6,11 +7,19 @@ avlos_bitmask_eps, as_include, ) +from avlos.validation import validate_all, ValidationError +from avlos.formatting import format_c_code, is_clang_format_available env = Environment(loader=PackageLoader("avlos"), autoescape=select_autoescape()) def process(instance, config): + # Validate before generation + validation_errors = validate_all(instance) + if validation_errors: + error_msg = "Validation failed:\n" + "\n".join(f" - {err}" for err in validation_errors) + raise ValidationError(error_msg) + env.filters["endpoints"] = avlos_endpoints env.filters["enum_eps"] = avlos_enum_eps env.filters["bitmask_eps"] = avlos_bitmask_eps @@ -47,3 +56,18 @@ def process(instance, config): template.render(instance=instance, includes=includes), file=output_file, ) + + # Post-process with clang-format if available + format_style = config.get("format_style", "LLVM") + + generated_files = [ + config["paths"]["output_enums"], + config["paths"]["output_header"], + config["paths"]["output_impl"], + ] + + for file_path in generated_files: + success = format_c_code(file_path, format_style) + if not success and is_clang_format_available(): + # Only warn if clang-format is installed but failed + print(f"Warning: clang-format failed for {file_path}", file=sys.stderr) diff --git a/avlos/generators/generator_cpp.py b/avlos/generators/generator_cpp.py index e5d5e34..2f5d338 100644 --- a/avlos/generators/generator_cpp.py +++ b/avlos/generators/generator_cpp.py @@ -1,4 +1,5 @@ import os +import sys from pathlib import Path from jinja2 import Environment, PackageLoader, select_autoescape from avlos.generators.filters import ( @@ -7,11 +8,19 @@ file_from_path, capitalize_first, ) +from avlos.validation import validate_all, ValidationError +from avlos.formatting import format_c_code, is_clang_format_available env = Environment(loader=PackageLoader("avlos"), autoescape=select_autoescape()) def process(instance, config): + # Validate before generation + validation_errors = validate_all(instance) + if validation_errors: + error_msg = "Validation failed:\n" + "\n".join(f" - {err}" for err in validation_errors) + raise ValidationError(error_msg) + env.filters["enum_eps"] = avlos_enum_eps env.filters["bitmask_eps"] = avlos_bitmask_eps env.filters["file_from_path"] = file_from_path @@ -30,6 +39,8 @@ def process_helpers(instance, config): template.render(instance=instance), file=output_file, ) + # Format the generated file + format_c_code(file_path, config.get("format_style", "LLVM")) def process_header(instance, config): @@ -51,6 +62,9 @@ def process_header(instance, config): ), file=output_file, ) + # Format the generated file + format_c_code(file_path, config.get("format_style", "LLVM")) + for attr in instance.remote_attributes.values(): if hasattr(attr, "remote_attributes"): recurse_header(attr, config) @@ -69,6 +83,9 @@ def recurse_header(remote_object, config): template.render(instance=remote_object, helper_file=helper_file), file=output_file, ) + # Format the generated file + format_c_code(file_path, config.get("format_style", "LLVM")) + for attr in remote_object.remote_attributes.values(): if hasattr(attr, "remote_attributes"): recurse_header(attr, config) @@ -92,6 +109,9 @@ def process_impl(instance, config): ), file=output_file, ) + # Format the generated file + format_c_code(file_path, config.get("format_style", "LLVM")) + for attr in instance.remote_attributes.values(): if hasattr(attr, "remote_attributes"): recurse_impl(attr, config) @@ -106,6 +126,9 @@ def recurse_impl(remote_object, config): os.makedirs(os.path.dirname(config["paths"]["output_impl"]), exist_ok=True) with open(file_path, "w") as output_file: print(template.render(instance=remote_object), file=output_file) + # Format the generated file + format_c_code(file_path, config.get("format_style", "LLVM")) + for attr in remote_object.remote_attributes.values(): if hasattr(attr, "remote_attributes"): recurse_impl(attr, config) diff --git a/avlos/templates/fw_endpoints.c.jinja b/avlos/templates/fw_endpoints.c.jinja index fff41b3..e3902dd 100644 --- a/avlos/templates/fw_endpoints.c.jinja +++ b/avlos/templates/fw_endpoints.c.jinja @@ -52,7 +52,7 @@ static inline uint8_t _avlos_setter_string(const uint8_t *buffer, void (*setter) } {% set comma = joiner(", ") %} -uint8_t (*avlos_endpoints[{{ instance | endpoints | length }}])(uint8_t * buffer, uint8_t * buffer_len, Avlos_Command cmd) = { {%- for attr in instance | endpoints %}{{ comma() }}&avlos_{{attr.full_name | replace(".", "_") }}{%- endfor %} }; +uint8_t (*avlos_endpoints[{{ instance | endpoints | length }}])(uint8_t * buffer, uint8_t * buffer_len, Avlos_Command cmd) = { {%- for attr in instance | endpoints %}{{ comma() }}&{{attr.endpoint_function_name}}{%- endfor %} }; uint32_t _avlos_get_proto_hash(void) { @@ -61,11 +61,11 @@ uint32_t _avlos_get_proto_hash(void) {%- for attr in instance | endpoints %} -{% if attr.func_attr -%}{{attr.func_attr}} {% endif %}uint8_t avlos_{{attr.full_name | replace(".", "_") }}(uint8_t * buffer, uint8_t * buffer_len, Avlos_Command cmd) +{% if attr.func_attr -%}{{attr.func_attr}} {% endif %}uint8_t {{attr.endpoint_function_name}}(uint8_t * buffer, uint8_t * buffer_len, Avlos_Command cmd) { {%- if attr.getter_name %} - {%- if attr.dtype.c_name == "char[]" %} + {%- if attr.getter_strategy == "string" %} {{ getter_char(attr) }} {%- else %} {{ getter_byval(attr) }} @@ -74,7 +74,7 @@ uint32_t _avlos_get_proto_hash(void) {%- endif %} {%- if attr.setter_name %} - {%- if attr.dtype.c_name == "char[]" %} + {%- if attr.setter_strategy == "string" %} {{ setter_char(attr) }} {%- else %} {{ setter_byval(attr) }} diff --git a/avlos/templates/fw_endpoints.h.jinja b/avlos/templates/fw_endpoints.h.jinja index 5479a70..54f1bcb 100644 --- a/avlos/templates/fw_endpoints.h.jinja +++ b/avlos/templates/fw_endpoints.h.jinja @@ -19,7 +19,7 @@ extern uint32_t _avlos_get_proto_hash(void); {%- for attr in instance | endpoints %} /* -* avlos_{{attr.full_name | replace(".", "_") }} +* {{attr.endpoint_function_name}} * * {{ attr.summary }} * @@ -28,6 +28,6 @@ extern uint32_t _avlos_get_proto_hash(void); * @param buffer * @param buffer_len */ -uint8_t avlos_{{attr.full_name | replace(".", "_") }}(uint8_t * buffer, uint8_t * buffer_len, Avlos_Command cmd); +uint8_t {{attr.endpoint_function_name}}(uint8_t * buffer, uint8_t * buffer_len, Avlos_Command cmd); {%- endfor %} diff --git a/avlos/validation.py b/avlos/validation.py new file mode 100644 index 0000000..d0c5030 --- /dev/null +++ b/avlos/validation.py @@ -0,0 +1,183 @@ +""" +Pre-generation validation for Avlos code generation. +Validates C identifiers, detects conflicts, and ensures consistency. +""" +from typing import List +import re + +# C reserved words (C11 standard) +C_RESERVED_WORDS = { + 'auto', 'break', 'case', 'char', 'const', 'continue', 'default', 'do', + 'double', 'else', 'enum', 'extern', 'float', 'for', 'goto', 'if', + 'int', 'long', 'register', 'return', 'short', 'signed', 'sizeof', 'static', + 'struct', 'switch', 'typedef', 'union', 'unsigned', 'void', 'volatile', 'while', + '_Alignas', '_Alignof', '_Atomic', '_Bool', '_Complex', '_Generic', '_Imaginary', + '_Noreturn', '_Static_assert', '_Thread_local' +} + +C_IDENTIFIER_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$') + + +class ValidationError(Exception): + """Raised when validation fails.""" + pass + + +def validate_c_identifier(name: str, context: str = "") -> None: + """ + Validate that a name is a valid C identifier. + + Args: + name: Identifier to validate + context: Context string for error messages (e.g., "getter_name for motor.R") + + Raises: + ValidationError: If identifier is invalid + """ + if not C_IDENTIFIER_PATTERN.match(name): + ctx = f" ({context})" if context else "" + raise ValidationError( + f"Invalid C identifier '{name}'{ctx}. " + f"Must start with letter or underscore, contain only alphanumeric and underscore." + ) + + if name in C_RESERVED_WORDS: + ctx = f" ({context})" if context else "" + raise ValidationError( + f"Invalid C identifier '{name}'{ctx}. '{name}' is a C reserved word." + ) + + if len(name) > 63: + # C99 requires at least 63 significant characters for identifiers + ctx = f" ({context})" if context else "" + print(f"Warning: Identifier '{name}'{ctx} is very long ({len(name)} chars). " + f"Some compilers may truncate after 63 characters.") + + +def validate_endpoint_ids(instance) -> List[str]: + """ + Check for endpoint ID conflicts. + + Args: + instance: Root node to validate + + Returns: + List of error messages (empty if no conflicts) + """ + from avlos.generators.filters import avlos_endpoints + + errors = [] + ep_id_map = {} + + for ep in avlos_endpoints(instance): + ep_id = ep.ep_id + if ep_id in ep_id_map: + errors.append( + f"Duplicate endpoint ID {ep_id}: " + f"'{ep.full_name}' and '{ep_id_map[ep_id].full_name}'" + ) + else: + ep_id_map[ep_id] = ep + + return errors + + +def validate_function_names(instance) -> List[str]: + """ + Check for function name conflicts in generated C code. + + Args: + instance: Root node to validate + + Returns: + List of error messages (empty if no conflicts) + """ + from avlos.generators.filters import avlos_endpoints + + errors = [] + + # Check getter/setter/caller names are valid C identifiers + for ep in avlos_endpoints(instance): + if hasattr(ep, 'getter_name') and ep.getter_name: + try: + validate_c_identifier(ep.getter_name, f"getter for {ep.full_name}") + except ValidationError as e: + errors.append(str(e)) + + if hasattr(ep, 'setter_name') and ep.setter_name: + try: + validate_c_identifier(ep.setter_name, f"setter for {ep.full_name}") + except ValidationError as e: + errors.append(str(e)) + + if hasattr(ep, 'caller_name') and ep.caller_name: + try: + validate_c_identifier(ep.caller_name, f"caller for {ep.full_name}") + except ValidationError as e: + errors.append(str(e)) + + # Check for endpoint function name collisions + # (endpoint functions are named: avlos_{full_name with dots replaced by underscores}) + endpoint_names = {} + for ep in avlos_endpoints(instance): + ep_func_name = "avlos_" + ep.full_name.replace(".", "_") + if ep_func_name in endpoint_names: + errors.append( + f"Endpoint function name collision: '{ep_func_name}' " + f"generated from both '{ep.full_name}' and '{endpoint_names[ep_func_name]}'" + ) + else: + endpoint_names[ep_func_name] = ep.full_name + + return errors + + +def validate_names(instance) -> List[str]: + """ + Validate all names in the device tree are valid C identifiers. + + Args: + instance: Root node to validate + + Returns: + List of error messages (empty if all valid) + """ + errors = [] + + def traverse_nodes(node, path=""): + # Validate node name + current_path = f"{path}.{node.name}" if path else node.name + try: + # Node names become part of full_name which becomes C function name + # So they should be valid C identifier parts + validate_c_identifier(node.name, f"node name at {current_path}") + except ValidationError as e: + errors.append(str(e)) + + # Recursively check children + if hasattr(node, 'remote_attributes'): + for child in node.remote_attributes.values(): + traverse_nodes(child, current_path) + + traverse_nodes(instance) + return errors + + +def validate_all(instance) -> List[str]: + """ + Run all validations and return list of errors. + + Args: + instance: Root node to validate + + Returns: + List of all error messages (empty if validation passes) + """ + errors = [] + + # Collect all validation errors + errors.extend(validate_names(instance)) + errors.extend(validate_endpoint_ids(instance)) + errors.extend(validate_function_names(instance)) + + return errors From 2dd9b5e63151585dbf6058667d2ef7389b03bf1e Mon Sep 17 00:00:00 2001 From: Yannis Chatzikonstantinou Date: Sat, 10 Jan 2026 14:52:52 +0200 Subject: [PATCH 3/5] add tests --- avlos/datatypes.py | 8 + tests/test_data_model_properties.py | 291 ++++++++++++++++++ tests/test_templates.py | 450 ++++++++++++++++++++++++++++ tests/test_validation.py | 249 +++++++++++++++ 4 files changed, 998 insertions(+) create mode 100644 tests/test_data_model_properties.py create mode 100644 tests/test_templates.py create mode 100644 tests/test_validation.py diff --git a/avlos/datatypes.py b/avlos/datatypes.py index 7825d1b..274d9c5 100644 --- a/avlos/datatypes.py +++ b/avlos/datatypes.py @@ -54,6 +54,8 @@ def from_string(self, str_value): DataType.UINT16: "uint16_t", DataType.INT32: "int32_t", DataType.UINT32: "uint32_t", + DataType.INT64: "int64_t", + DataType.UINT64: "uint64_t", DataType.FLOAT: "float", DataType.DOUBLE: "double", DataType.STR: "char[]", @@ -68,6 +70,8 @@ def from_string(self, str_value): DataType.UINT16: int, DataType.INT32: int, DataType.UINT32: int, + DataType.INT64: int, + DataType.UINT64: int, DataType.FLOAT: float, DataType.DOUBLE: float, DataType.STR: str, @@ -82,6 +86,8 @@ def from_string(self, str_value): DataType.UINT16: 2, DataType.INT32: 4, DataType.UINT32: 4, + DataType.INT64: 8, + DataType.UINT64: 8, DataType.FLOAT: 4, DataType.DOUBLE: 8, DataType.STR: -1, @@ -97,6 +103,8 @@ def from_string(self, str_value): "uint16": DataType.UINT16, "int32": DataType.INT32, "uint32": DataType.UINT32, + "int64": DataType.INT64, + "uint64": DataType.UINT64, "float": DataType.FLOAT, "double": DataType.DOUBLE, "string": DataType.STR, diff --git a/tests/test_data_model_properties.py b/tests/test_data_model_properties.py new file mode 100644 index 0000000..165f4bd --- /dev/null +++ b/tests/test_data_model_properties.py @@ -0,0 +1,291 @@ +""" +Tests for data model properties added for code generation. +""" +import unittest +import yaml +from avlos.deserializer import deserialize +from avlos.definitions.remote_attribute import RemoteAttribute +from avlos.definitions.remote_function import RemoteFunction +from avlos.definitions.remote_enum import RemoteEnum +from avlos.definitions.remote_bitmask import RemoteBitmask +from avlos.datatypes import DataType + + +class TestDataModelProperties(unittest.TestCase): + """Test properties added to data model classes.""" + + def test_string_getter_strategy(self): + """Test that char[] types return 'string' getter strategy.""" + attr = RemoteAttribute( + name="nickname", + summary="Device nickname", + dtype=DataType.STR, + getter_name="get_nickname", + ) + + self.assertEqual(attr.getter_strategy, "string") + self.assertEqual(attr.setter_strategy, "string") + + def test_byval_getter_strategy_float(self): + """Test that float types return 'byval' getter strategy.""" + attr = RemoteAttribute( + name="voltage", + summary="Bus voltage", + dtype=DataType.FLOAT, + getter_name="get_voltage", + ) + + self.assertEqual(attr.getter_strategy, "byval") + self.assertEqual(attr.setter_strategy, "byval") + + def test_byval_getter_strategy_integers(self): + """Test that integer types return 'byval' getter strategy.""" + int_types = [ + DataType.UINT8, DataType.INT8, + DataType.UINT16, DataType.INT16, + DataType.UINT32, DataType.INT32, + DataType.UINT64, DataType.INT64, + ] + + for dtype in int_types: + attr = RemoteAttribute( + name="test_value", + summary="Test value", + dtype=dtype, + getter_name="get_value", + ) + + self.assertEqual(attr.getter_strategy, "byval", + f"{dtype} should use byval strategy") + self.assertEqual(attr.setter_strategy, "byval", + f"{dtype} should use byval strategy") + + def test_endpoint_function_name_simple(self): + """Test endpoint function name for simple attribute.""" + from avlos.mixins.named_node import NamedNode + + attr = RemoteAttribute( + name="voltage", + summary="Bus voltage", + dtype=DataType.FLOAT, + getter_name="get_voltage", + ) + # Set include_base_name so full_name returns the name when parent is None + attr.include_base_name = True + + self.assertEqual(attr.endpoint_function_name, "avlos_voltage") + + def test_endpoint_function_name_nested(self): + """Test endpoint function name for nested attribute.""" + import importlib.resources + + def_path_str = str( + importlib.resources.files("tests").joinpath("definition/good_device.yaml") + ) + + with open(def_path_str) as device_desc_stream: + obj = deserialize(yaml.safe_load(device_desc_stream)) + + # Find a nested attribute (e.g., motor.R) + if hasattr(obj, 'motor') and hasattr(obj.motor, 'R'): + attr = obj.motor.R + expected_name = "avlos_motor_R" + self.assertEqual(attr.endpoint_function_name, expected_name, + f"motor.R should generate {expected_name}") + + def test_is_string_type_true(self): + """Test is_string_type property for char[] type.""" + attr = RemoteAttribute( + name="name", + summary="Device name", + dtype=DataType.STR, + getter_name="get_name", + ) + + self.assertTrue(attr.is_string_type) + + def test_is_string_type_false(self): + """Test is_string_type property for non-string types.""" + attr = RemoteAttribute( + name="value", + summary="Numeric value", + dtype=DataType.FLOAT, + getter_name="get_value", + ) + + self.assertFalse(attr.is_string_type) + + def test_remote_function_endpoint_name(self): + """Test endpoint function name for RemoteFunction.""" + func = RemoteFunction( + name="reset", + summary="Reset the device", + caller_name="system_reset", + arguments=[], + dtype=DataType.VOID, + ) + # Set include_base_name so full_name returns the name when parent is None + func.include_base_name = True + + self.assertEqual(func.endpoint_function_name, "avlos_reset") + + def test_remote_function_endpoint_name_nested(self): + """Test endpoint function name for nested RemoteFunction.""" + import importlib.resources + + def_path_str = str( + importlib.resources.files("tests").joinpath("definition/good_device.yaml") + ) + + with open(def_path_str) as device_desc_stream: + obj = deserialize(yaml.safe_load(device_desc_stream)) + + # Find nested function (e.g., controller.set_pos_vel_setpoints) + if hasattr(obj, 'controller') and hasattr(obj.controller, 'set_pos_vel_setpoints'): + func = obj.controller.set_pos_vel_setpoints + expected_name = "avlos_controller_set_pos_vel_setpoints" + self.assertEqual(func.endpoint_function_name, expected_name) + + def test_enum_properties(self): + """Test that RemoteEnum has correct properties.""" + from enum import IntEnum + + class TestEnum(IntEnum): + OPTION_A = 0 + OPTION_B = 1 + OPTION_C = 2 + + enum_attr = RemoteEnum( + name="mode", + summary="Operating mode", + getter_name="get_mode", + setter_name="set_mode", + options=TestEnum, + ) + # Set include_base_name so full_name returns the name when parent is None + enum_attr.include_base_name = True + + self.assertEqual(enum_attr.getter_strategy, "byval") + self.assertEqual(enum_attr.setter_strategy, "byval") + self.assertEqual(enum_attr.endpoint_function_name, "avlos_mode") + + def test_bitmask_properties(self): + """Test that RemoteBitmask has correct properties.""" + from enum import IntFlag + + class TestFlags(IntFlag): + FLAG_A = 1 + FLAG_B = 2 + FLAG_C = 4 + + bitmask_attr = RemoteBitmask( + name="errors", + summary="Error flags", + getter_name="get_errors", + flags=TestFlags, + ) + # Set include_base_name so full_name returns the name when parent is None + bitmask_attr.include_base_name = True + + self.assertEqual(bitmask_attr.getter_strategy, "byval") + self.assertEqual(bitmask_attr.setter_strategy, "byval") + self.assertEqual(bitmask_attr.endpoint_function_name, "avlos_errors") + + def test_backward_compatibility_generated_code(self): + """Test that generated code is functionally equivalent to before refactoring.""" + import importlib.resources + from avlos.generators import generator_c + + def_path_str = str( + importlib.resources.files("tests").joinpath("definition/good_device.yaml") + ) + output_impl = str( + importlib.resources.files("tests").joinpath("outputs/test_backward_compat.c") + ) + + with open(def_path_str) as device_desc_stream: + obj = deserialize(yaml.safe_load(device_desc_stream)) + + config = { + "hash_string": "0x9e8dc7ac", + "paths": { + "output_enums": str( + importlib.resources.files("tests").joinpath("outputs/test_enum_compat.h") + ), + "output_header": str( + importlib.resources.files("tests").joinpath("outputs/test_header_compat.h") + ), + "output_impl": output_impl, + }, + } + + # Generate code + generator_c.process(obj, config) + + # Read generated code + with open(output_impl) as f: + generated_code = f.read() + + # Verify key patterns are present + self.assertIn("avlos_", generated_code, "Should have avlos_ prefixed functions") + self.assertIn("AVLOS_CMD_READ", generated_code, "Should handle read commands") + self.assertIn("AVLOS_CMD_WRITE", generated_code, "Should handle write commands") + self.assertIn("_avlos_getter_string", generated_code, + "Should have string getter helper") + self.assertIn("_avlos_setter_string", generated_code, + "Should have string setter helper") + + # Verify function declarations use properties + # (all endpoint functions should be present) + self.assertIn("uint8_t avlos_", generated_code) + + def test_all_endpoints_have_function_names(self): + """Test that all endpoints from good_device.yaml have endpoint_function_name.""" + import importlib.resources + from avlos.generators.filters import avlos_endpoints + + def_path_str = str( + importlib.resources.files("tests").joinpath("definition/good_device.yaml") + ) + + with open(def_path_str) as device_desc_stream: + obj = deserialize(yaml.safe_load(device_desc_stream)) + + endpoints = avlos_endpoints(obj) + + for ep in endpoints: + self.assertTrue(hasattr(ep, 'endpoint_function_name'), + f"Endpoint {ep.name} should have endpoint_function_name property") + func_name = ep.endpoint_function_name + self.assertTrue(func_name.startswith("avlos_"), + f"Endpoint function name should start with avlos_, got: {func_name}") + self.assertNotIn(".", func_name, + f"Endpoint function name should not contain dots, got: {func_name}") + + def test_getter_setter_strategy_consistency(self): + """Test that getter and setter strategies are consistent.""" + import importlib.resources + from avlos.generators.filters import avlos_endpoints + + def_path_str = str( + importlib.resources.files("tests").joinpath("definition/good_device.yaml") + ) + + with open(def_path_str) as device_desc_stream: + obj = deserialize(yaml.safe_load(device_desc_stream)) + + endpoints = avlos_endpoints(obj) + + for ep in endpoints: + # Skip functions (they don't have getter/setter strategies in the same way) + if hasattr(ep, 'caller_name') and not hasattr(ep, 'getter_name'): + continue + + if hasattr(ep, 'getter_strategy') and hasattr(ep, 'setter_strategy'): + # Getter and setter strategies should match for attributes + self.assertEqual(ep.getter_strategy, ep.setter_strategy, + f"Endpoint {ep.name} should have consistent getter/setter strategies") + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_templates.py b/tests/test_templates.py new file mode 100644 index 0000000..7fafa86 --- /dev/null +++ b/tests/test_templates.py @@ -0,0 +1,450 @@ +""" +Tests for Jinja2 templates and generated code patterns. +""" +import unittest +import yaml +import importlib.resources +from avlos.deserializer import deserialize +from avlos.generators import generator_c, generator_cpp +from avlos.datatypes import DataType + + +class TestTemplateMacros(unittest.TestCase): + """Test template macro behavior through generated code.""" + + def setUp(self): + """Set up test fixtures.""" + def_path_str = str( + importlib.resources.files("tests").joinpath("definition/good_device.yaml") + ) + + with open(def_path_str) as device_desc_stream: + self.device = deserialize(yaml.safe_load(device_desc_stream)) + + def test_char_array_getter_uses_helper(self): + """Test that char[] getter generates code using _avlos_getter_string.""" + output_impl = str( + importlib.resources.files("tests").joinpath("outputs/test_char_getter.c") + ) + + config = { + "hash_string": "0x9e8dc7ac", + "paths": { + "output_enums": str( + importlib.resources.files("tests").joinpath("outputs/test_char_enum.h") + ), + "output_header": str( + importlib.resources.files("tests").joinpath("outputs/test_char_header.h") + ), + "output_impl": output_impl, + }, + } + + generator_c.process(self.device, config) + + with open(output_impl) as f: + content = f.read() + + # Should contain the string helper function definition + self.assertIn("_avlos_getter_string", content, + "Should define _avlos_getter_string helper") + + # Should contain helper function signature + self.assertIn("uint8_t (*getter)(char*)", content, + "Helper should have correct signature") + + # Should call the helper in char[] endpoint functions + # (nickname is a char[] attribute in good_device.yaml) + self.assertIn("_avlos_getter_string(buffer, buffer_len, system_get_name)", content, + "Should use helper for char[] getter") + + def test_char_array_setter_uses_helper(self): + """Test that char[] setter generates code using _avlos_setter_string.""" + output_impl = str( + importlib.resources.files("tests").joinpath("outputs/test_char_setter.c") + ) + + config = { + "hash_string": "0x9e8dc7ac", + "paths": { + "output_enums": str( + importlib.resources.files("tests").joinpath("outputs/test_char_enum2.h") + ), + "output_header": str( + importlib.resources.files("tests").joinpath("outputs/test_char_header2.h") + ), + "output_impl": output_impl, + }, + } + + generator_c.process(self.device, config) + + with open(output_impl) as f: + content = f.read() + + # Should contain the string helper function definition + self.assertIn("_avlos_setter_string", content, + "Should define _avlos_setter_string helper") + + # Should contain helper function signature + self.assertIn("void (*setter)(const char*)", content, + "Helper should have correct signature") + + # Should call the helper in char[] endpoint functions + self.assertIn("_avlos_setter_string(buffer, system_set_name)", content, + "Should use helper for char[] setter") + + def test_numeric_getter_byval(self): + """Test that numeric getters use by-value pattern.""" + output_impl = str( + importlib.resources.files("tests").joinpath("outputs/test_numeric.c") + ) + + config = { + "hash_string": "0x9e8dc7ac", + "paths": { + "output_enums": str( + importlib.resources.files("tests").joinpath("outputs/test_numeric_enum.h") + ), + "output_header": str( + importlib.resources.files("tests").joinpath("outputs/test_numeric_header.h") + ), + "output_impl": output_impl, + }, + } + + generator_c.process(self.device, config) + + with open(output_impl) as f: + content = f.read() + + # Should contain memcpy pattern for by-value types + self.assertIn("memcpy(buffer, &v, sizeof(v))", content, + "Should use memcpy for by-value getters") + + # Should declare local variable for value + # (check for patterns like "float v;" or "uint32_t v;") + self.assertTrue( + "float v;" in content or "uint32_t v;" in content or "uint8_t v;" in content, + "Should declare local variable for value" + ) + + def test_void_function_no_return_value(self): + """Test that void return type functions don't generate return value code.""" + output_impl = str( + importlib.resources.files("tests").joinpath("outputs/test_void_func.c") + ) + + config = { + "hash_string": "0x9e8dc7ac", + "paths": { + "output_enums": str( + importlib.resources.files("tests").joinpath("outputs/test_void_enum.h") + ), + "output_header": str( + importlib.resources.files("tests").joinpath("outputs/test_void_header.h") + ), + "output_impl": output_impl, + }, + } + + generator_c.process(self.device, config) + + with open(output_impl) as f: + content = f.read() + + # Find the reset function (void return, no args) + if "avlos_reset" in content: + # Extract the reset function + start = content.find("uint8_t avlos_reset") + end = content.find("\n}", start) + 2 + reset_func = content[start:end] + + # Void functions should NOT have ret_val + self.assertNotIn("ret_val", reset_func, + "Void function should not have return value") + + # Should call function directly without assignment + self.assertIn("system_reset()", reset_func, + "Should call void function without assignment") + + def test_function_with_args_unpacks_buffer(self): + """Test that functions with arguments unpack from buffer.""" + output_impl = str( + importlib.resources.files("tests").joinpath("outputs/test_func_args.c") + ) + + config = { + "hash_string": "0x9e8dc7ac", + "paths": { + "output_enums": str( + importlib.resources.files("tests").joinpath("outputs/test_func_args_enum.h") + ), + "output_header": str( + importlib.resources.files("tests").joinpath("outputs/test_func_args_header.h") + ), + "output_impl": output_impl, + }, + } + + generator_c.process(self.device, config) + + with open(output_impl) as f: + content = f.read() + + # Should have offset tracking for multiple arguments + self.assertIn("uint8_t _offset = 0", content, + "Should track offset for argument unpacking") + + # Should unpack arguments with memcpy + self.assertIn("memcpy(&", content, + "Should use memcpy to unpack arguments") + + # Should increment offset + self.assertIn("_offset += sizeof(", content, + "Should increment offset for each argument") + + def test_all_data_types_generate(self): + """Test that all supported data types can be generated.""" + # Create a test YAML with all data types + yaml_content = """ + name: test_device + remote_attributes: + - name: u8_val + summary: uint8 value + dtype: uint8 + getter_name: get_u8 + - name: i8_val + summary: int8 value + dtype: int8 + getter_name: get_i8 + - name: u16_val + summary: uint16 value + dtype: uint16 + getter_name: get_u16 + - name: i16_val + summary: int16 value + dtype: int16 + getter_name: get_i16 + - name: u32_val + summary: uint32 value + dtype: uint32 + getter_name: get_u32 + - name: i32_val + summary: int32 value + dtype: int32 + getter_name: get_i32 + - name: u64_val + summary: uint64 value + dtype: uint64 + getter_name: get_u64 + - name: i64_val + summary: int64 value + dtype: int64 + getter_name: get_i64 + - name: float_val + summary: float value + dtype: float + getter_name: get_float + - name: double_val + summary: double value + dtype: double + getter_name: get_double + - name: str_val + summary: string value + dtype: string + getter_name: get_str + - name: bool_val + summary: bool value + dtype: bool + getter_name: get_bool + """ + + obj = deserialize(yaml.safe_load(yaml_content)) + + output_impl = str( + importlib.resources.files("tests").joinpath("outputs/test_all_types.c") + ) + + config = { + "hash_string": "0xdeadbeef", + "paths": { + "output_enums": str( + importlib.resources.files("tests").joinpath("outputs/test_all_types_enum.h") + ), + "output_header": str( + importlib.resources.files("tests").joinpath("outputs/test_all_types_header.h") + ), + "output_impl": output_impl, + }, + } + + # Should not raise any exceptions + generator_c.process(obj, config) + + with open(output_impl) as f: + content = f.read() + + # Verify all types are present + expected_types = [ + "uint8_t", "int8_t", "uint16_t", "int16_t", + "uint32_t", "int32_t", "uint64_t", "int64_t", + "float", "double", "bool" + ] + + for dtype in expected_types: + self.assertIn(dtype, content, + f"Generated code should contain {dtype}") + + # String types use helper functions, so check for that instead of "char[]" + self.assertIn("_avlos_getter_string", content, + "Generated code should contain string helper function") + + def test_func_attr_in_output(self): + """Test that func_attr (e.g., TM_RAMFUNC) appears in generated code.""" + output_impl = str( + importlib.resources.files("tests").joinpath("outputs/test_func_attr.c") + ) + + config = { + "hash_string": "0x9e8dc7ac", + "paths": { + "output_enums": str( + importlib.resources.files("tests").joinpath("outputs/test_func_attr_enum.h") + ), + "output_header": str( + importlib.resources.files("tests").joinpath("outputs/test_func_attr_header.h") + ), + "output_impl": output_impl, + }, + } + + generator_c.process(self.device, config) + + with open(output_impl) as f: + content = f.read() + + # good_device.yaml has TM_RAMFUNC on some functions + if "TM_RAMFUNC" in content: + self.assertIn("TM_RAMFUNC uint8_t avlos_", content, + "func_attr should appear before function declaration") + + def test_endpoint_array_generation(self): + """Test that endpoint array is correctly generated.""" + output_impl = str( + importlib.resources.files("tests").joinpath("outputs/test_ep_array.c") + ) + + config = { + "hash_string": "0x9e8dc7ac", + "paths": { + "output_enums": str( + importlib.resources.files("tests").joinpath("outputs/test_ep_array_enum.h") + ), + "output_header": str( + importlib.resources.files("tests").joinpath("outputs/test_ep_array_header.h") + ), + "output_impl": output_impl, + }, + } + + generator_c.process(self.device, config) + + with open(output_impl) as f: + content = f.read() + + # Should have endpoint array declaration + self.assertIn("uint8_t (*avlos_endpoints[", content, + "Should declare endpoint array") + + # Should have proto hash function + self.assertIn("_avlos_get_proto_hash", content, + "Should have proto hash function") + + +class TestIntegration(unittest.TestCase): + """Test full pipeline integration.""" + + def test_full_pipeline_c_generation(self): + """Test complete C generation pipeline with all features.""" + import importlib.resources + + def_path_str = str( + importlib.resources.files("tests").joinpath("definition/good_device.yaml") + ) + + with open(def_path_str) as device_desc_stream: + obj = deserialize(yaml.safe_load(device_desc_stream)) + + output_impl = str( + importlib.resources.files("tests").joinpath("outputs/test_integration.c") + ) + + config = { + "hash_string": "0x12345678", + "paths": { + "output_enums": str( + importlib.resources.files("tests").joinpath("outputs/test_integration_enum.h") + ), + "output_header": str( + importlib.resources.files("tests").joinpath("outputs/test_integration_header.h") + ), + "output_impl": output_impl, + }, + } + + # Full pipeline: validation → generation → formatting + generator_c.process(obj, config) + + # Verify all files exist + import os + self.assertTrue(os.path.exists(config["paths"]["output_enums"])) + self.assertTrue(os.path.exists(config["paths"]["output_header"])) + self.assertTrue(os.path.exists(config["paths"]["output_impl"])) + + # Verify implementation file has key content + with open(output_impl) as f: + content = f.read() + + self.assertIn("avlos_", content) + self.assertIn("AVLOS_CMD_READ", content) + self.assertIn("avlos_endpoints[", content) + + def test_cpp_generation_pipeline(self): + """Test complete C++ generation pipeline.""" + import importlib.resources + + def_path_str = str( + importlib.resources.files("tests").joinpath("definition/good_device.yaml") + ) + + with open(def_path_str) as device_desc_stream: + obj = deserialize(yaml.safe_load(device_desc_stream)) + + config = { + "hash_string": "0x12345678", + "paths": { + "output_helpers": str( + importlib.resources.files("tests").joinpath("outputs/test_cpp_helpers.hpp") + ), + "output_header": str( + importlib.resources.files("tests").joinpath("outputs/test_cpp_device.hpp") + ), + "output_impl": str( + importlib.resources.files("tests").joinpath("outputs/test_cpp_device.cpp") + ), + }, + } + + # Full pipeline: validation → generation → formatting + generator_cpp.process(obj, config) + + # Verify files exist + import os + self.assertTrue(os.path.exists(config["paths"]["output_helpers"])) + self.assertTrue(os.path.exists(config["paths"]["output_header"])) + self.assertTrue(os.path.exists(config["paths"]["output_impl"])) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_validation.py b/tests/test_validation.py new file mode 100644 index 0000000..13e8fd0 --- /dev/null +++ b/tests/test_validation.py @@ -0,0 +1,249 @@ +""" +Tests for validation module. +""" +import unittest +import yaml +from avlos.deserializer import deserialize +from avlos.validation import ( + validate_c_identifier, + validate_endpoint_ids, + validate_function_names, + validate_names, + validate_all, + ValidationError, + C_RESERVED_WORDS, +) + + +class TestValidation(unittest.TestCase): + """Test validation functions.""" + + def test_valid_c_identifier(self): + """Test that valid C identifiers pass validation.""" + # These should not raise + validate_c_identifier("valid_name") + validate_c_identifier("_private") + validate_c_identifier("name123") + validate_c_identifier("CamelCase") + validate_c_identifier("snake_case_123") + + def test_invalid_c_identifier_special_chars(self): + """Test that identifiers with special characters are rejected.""" + with self.assertRaises(ValidationError) as cm: + validate_c_identifier("invalid-name") + self.assertIn("invalid-name", str(cm.exception)) + self.assertIn("Invalid C identifier", str(cm.exception)) + + with self.assertRaises(ValidationError): + validate_c_identifier("name with spaces") + + with self.assertRaises(ValidationError): + validate_c_identifier("name.with.dots") + + with self.assertRaises(ValidationError): + validate_c_identifier("name$special") + + def test_invalid_c_identifier_starts_with_digit(self): + """Test that identifiers starting with digits are rejected.""" + with self.assertRaises(ValidationError) as cm: + validate_c_identifier("123invalid") + self.assertIn("123invalid", str(cm.exception)) + self.assertIn("Invalid C identifier", str(cm.exception)) + + def test_c_reserved_words(self): + """Test that C reserved words are rejected.""" + reserved_samples = ['int', 'void', 'return', 'if', 'else', 'while', 'for', + 'struct', 'union', 'enum', 'static', 'const', '_Bool'] + + for word in reserved_samples: + self.assertIn(word, C_RESERVED_WORDS) + with self.assertRaises(ValidationError) as cm: + validate_c_identifier(word) + self.assertIn(word, str(cm.exception)) + self.assertIn("reserved word", str(cm.exception)) + + def test_long_identifier_warning(self): + """Test that very long identifiers generate warnings.""" + # 64+ character identifier (C99 requires at least 63 significant chars) + long_name = "a" * 70 + # Should not raise, but might print warning + validate_c_identifier(long_name) + + def test_valid_device_passes_all_validation(self): + """Test that good_device.yaml passes all validations.""" + import importlib.resources + + def_path_str = str( + importlib.resources.files("tests").joinpath("definition/good_device.yaml") + ) + + with open(def_path_str) as device_desc_stream: + obj = deserialize(yaml.safe_load(device_desc_stream)) + errors = validate_all(obj) + self.assertEqual(errors, [], f"good_device.yaml should pass validation but got: {errors}") + + def test_validate_endpoint_ids_no_conflicts(self): + """Test endpoint ID validation with no conflicts.""" + import importlib.resources + + def_path_str = str( + importlib.resources.files("tests").joinpath("definition/good_device.yaml") + ) + + with open(def_path_str) as device_desc_stream: + obj = deserialize(yaml.safe_load(device_desc_stream)) + errors = validate_endpoint_ids(obj) + self.assertEqual(errors, [], "Should have no endpoint ID conflicts") + + def test_validate_function_names_no_conflicts(self): + """Test function name validation with no conflicts.""" + import importlib.resources + + def_path_str = str( + importlib.resources.files("tests").joinpath("definition/good_device.yaml") + ) + + with open(def_path_str) as device_desc_stream: + obj = deserialize(yaml.safe_load(device_desc_stream)) + errors = validate_function_names(obj) + self.assertEqual(errors, [], "Should have no function name conflicts") + + def test_validate_names_valid(self): + """Test name validation for valid device tree.""" + import importlib.resources + + def_path_str = str( + importlib.resources.files("tests").joinpath("definition/good_device.yaml") + ) + + with open(def_path_str) as device_desc_stream: + obj = deserialize(yaml.safe_load(device_desc_stream)) + errors = validate_names(obj) + self.assertEqual(errors, [], "All names should be valid C identifiers") + + def test_invalid_getter_name_caught(self): + """Test that invalid getter names are caught.""" + yaml_content = """ + name: test_device + remote_attributes: + - name: value + summary: Test value + dtype: uint32 + getter_name: invalid-getter-name + """ + + obj = deserialize(yaml.safe_load(yaml_content)) + errors = validate_function_names(obj) + + self.assertTrue(len(errors) > 0, "Should detect invalid getter name") + self.assertTrue(any("invalid-getter-name" in err for err in errors)) + + def test_invalid_setter_name_caught(self): + """Test that invalid setter names are caught.""" + yaml_content = """ + name: test_device + remote_attributes: + - name: value + summary: Test value + dtype: uint32 + setter_name: invalid-setter + getter_name: valid_getter + """ + + # Note: setter name "invalid-setter" has a dash which is invalid + obj = deserialize(yaml.safe_load(yaml_content)) + errors = validate_function_names(obj) + + self.assertTrue(len(errors) > 0, "Should detect invalid setter name") + + def test_reserved_word_as_getter_name(self): + """Test that C reserved words as getter names are caught.""" + yaml_content = """ + name: test_device + remote_attributes: + - name: value + summary: Test value + dtype: uint32 + getter_name: return + """ + + obj = deserialize(yaml.safe_load(yaml_content)) + errors = validate_function_names(obj) + + self.assertTrue(len(errors) > 0, "Should detect reserved word as getter name") + self.assertTrue(any("reserved word" in err for err in errors)) + + def test_invalid_node_name_caught(self): + """Test that invalid node names are caught.""" + yaml_content = """ + name: invalid-device-name + remote_attributes: + - name: value + summary: Test value + dtype: uint32 + getter_name: get_value + """ + + obj = deserialize(yaml.safe_load(yaml_content)) + errors = validate_names(obj) + + self.assertTrue(len(errors) > 0, "Should detect invalid node name") + self.assertTrue(any("invalid-device-name" in err for err in errors)) + + def test_nested_invalid_node_name(self): + """Test that invalid names in nested nodes are caught.""" + yaml_content = """ + name: device + remote_attributes: + - name: motor + summary: Motor controller + remote_attributes: + - name: invalid-nested-name + summary: Invalid nested attribute + dtype: float + getter_name: get_value + """ + + obj = deserialize(yaml.safe_load(yaml_content)) + errors = validate_names(obj) + + self.assertTrue(len(errors) > 0, "Should detect invalid nested node name") + self.assertTrue(any("invalid-nested-name" in err for err in errors)) + + def test_validation_error_with_context(self): + """Test that validation errors include helpful context.""" + with self.assertRaises(ValidationError) as cm: + validate_c_identifier("123bad", "test context") + + error_msg = str(cm.exception) + self.assertIn("123bad", error_msg) + self.assertIn("test context", error_msg) + + def test_validate_all_collects_multiple_errors(self): + """Test that validate_all collects all errors, not just the first one.""" + yaml_content = """ + name: test-device + remote_attributes: + - name: attr-one + summary: Test attribute + dtype: uint32 + getter_name: invalid-getter + - name: attr-two + summary: Another test + dtype: float + setter_name: return + """ + + obj = deserialize(yaml.safe_load(yaml_content)) + errors = validate_all(obj) + + # Should have multiple errors: + # - invalid device name (test-device) + # - invalid attribute names (attr-one, attr-two) + # - invalid getter name (invalid-getter) + # - reserved word setter name (return) + self.assertTrue(len(errors) >= 4, f"Should collect multiple errors, got {len(errors)}: {errors}") + + +if __name__ == '__main__': + unittest.main() From 2df5cf762a546c0ca2702f5994afa9a2c86227b5 Mon Sep 17 00:00:00 2001 From: Yannis Chatzikonstantinou Date: Sat, 10 Jan 2026 15:03:20 +0200 Subject: [PATCH 4/5] add pre commit and format --- .github/workflows/ci.yml | 2 +- .github/workflows/docs.yml | 2 +- .pre-commit-config.yaml | 43 +++++ .vscode/launch.json | 2 +- .vscode/settings.json | 2 +- LICENSE | 1 - README.md | 7 +- avlos/__init__.py | 5 +- avlos/bitmask_field.py | 3 +- avlos/cli.py | 8 +- avlos/datatypes.py | 3 +- avlos/definitions/__init__.py | 2 +- avlos/definitions/remote_attribute.py | 12 +- avlos/definitions/remote_bitmask.py | 8 +- avlos/definitions/remote_enum.py | 8 +- avlos/definitions/remote_function.py | 21 +-- avlos/definitions/remote_node.py | 43 ++--- avlos/definitions/remote_root_node.py | 2 + avlos/deserializer.py | 5 +- avlos/enum_field.py | 3 +- avlos/formatting.py | 10 +- avlos/generators/filters.py | 9 +- avlos/generators/generator_c.py | 11 +- avlos/generators/generator_cpp.py | 11 +- avlos/generators/generator_dbc.py | 2 + avlos/generators/generator_rst.py | 2 + avlos/json_codec.py | 2 + avlos/mixins/func_attr_node.py | 1 - avlos/processor.py | 9 +- avlos/templates/device.cpp.jinja | 5 +- avlos/templates/device.dbc.jinja | 2 +- avlos/templates/device.hpp.jinja | 2 +- avlos/templates/docs.rst.jinja | 2 +- avlos/templates/remote_object.cpp.jinja | 8 +- avlos/templates/remote_object.hpp.jinja | 4 +- avlos/unit_field.py | 2 +- avlos/validation.py | 77 +++++--- docs/cli.rst | 2 +- docs/conf.py | 14 +- docs/config.rst | 4 +- docs/index.rst | 2 - docs/introduction.rst | 3 +- docs/spec_format.rst | 1 - example/README.md | 2 +- example/avlos_config.yaml | 2 +- example/device.yaml | 2 +- setup.cfg | 32 ++++ setup.py | 3 +- tests/definition/bad_device_name.yaml | 1 - tests/definition/good_device.yaml | 8 +- tests/definition/obsolete_device.yaml | 3 +- tests/test_counter.py | 4 +- tests/test_data_model_properties.py | 103 +++++------ tests/test_deserialization.py | 42 ++--- tests/test_functions.py | 3 +- tests/test_generation.py | 68 +++----- tests/test_impex.py | 12 +- tests/test_remote_objects.py | 34 ++-- tests/test_templates.py | 222 +++++++++--------------- tests/test_validation.py | 44 +++-- 60 files changed, 453 insertions(+), 499 deletions(-) create mode 100644 .pre-commit-config.yaml create mode 100644 setup.cfg diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7a7d9c0..8f12959 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -33,4 +33,4 @@ jobs: python -m unittest - name: Test DBC Files run: | - cantools dump tests/outputs/test.dbc \ No newline at end of file + cantools dump tests/outputs/test.dbc diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index b9ec98b..d21c55d 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -23,4 +23,4 @@ jobs: - name: Build docs run: | cd docs - make html SPHINXOPTS="-W --keep-going -n" \ No newline at end of file + make html SPHINXOPTS="-W --keep-going -n" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..f70d15d --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,43 @@ +# Pre-commit hooks for Avlos +# See https://pre-commit.com for more information +repos: + # Python code formatting with black + - repo: https://github.com/psf/black + rev: 24.1.1 + hooks: + - id: black + language_version: python3 + args: ['--line-length=127'] + + # Python import sorting + - repo: https://github.com/PyCQA/isort + rev: 5.13.2 + hooks: + - id: isort + args: ['--profile', 'black', '--line-length', '127'] + + # Python linting with flake8 + - repo: https://github.com/PyCQA/flake8 + rev: 7.0.0 + hooks: + - id: flake8 + args: ['--max-line-length=127', '--max-complexity=10', '--extend-ignore=E203,W503'] + + # Trailing whitespace and file endings + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files + - id: check-merge-conflict + - id: mixed-line-ending + + # RST file checking + - repo: https://github.com/rstcheck/rstcheck + rev: v6.2.0 + hooks: + - id: rstcheck + args: ['--report-level=warning'] + additional_dependencies: ['sphinx'] diff --git a/.vscode/launch.json b/.vscode/launch.json index 9dc3b30..c2ee3b0 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -11,4 +11,4 @@ "justMyCode": false } ] -} \ No newline at end of file +} diff --git a/.vscode/settings.json b/.vscode/settings.json index 791ca6d..af6bd11 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -9,4 +9,4 @@ "python.testing.pytestEnabled": false, "python.testing.unittestEnabled": true, "python.formatting.provider": "black" -} \ No newline at end of file +} diff --git a/LICENSE b/LICENSE index 31f6f88..84bed23 100644 --- a/LICENSE +++ b/LICENSE @@ -5,4 +5,3 @@ Permission is hereby granted, free of charge, to any person obtaining a copy of The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - diff --git a/README.md b/README.md index 2bf3e31..4aae39e 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ Stop writing serialization code twice. Stop debugging protocol mismatches. - 🔒 **Type-safe across the boundary** → Catch errors at build time, not runtime - 🎯 **Battle-tested** → Production-proven in [Tinymovr](https://tinymovr.com) motor controllers -[Αυλός (Avlόs)](https://en.wikipedia.org/wiki/Aulos) - _flute_, also _channel_. +[Αυλός (Avlόs)](https://en.wikipedia.org/wiki/Aulos) - _flute_, also _channel_.

@@ -74,7 +74,7 @@ Given the above, Avlos can generate the following: - [CAN DBC file](https://www.csselectronics.com/pages/can-dbc-file-database-intro) (CAN database), for every endpoint, for use with CAN-based comm channels. -In addition, Avlos will compute a checksum for the spec and add it as a variable to the implementation so that it can be retrieved by the client for comparing client and device specs. +In addition, Avlos will compute a checksum for the spec and add it as a variable to the implementation so that it can be retrieved by the client for comparing client and device specs. The output location, as well as many other attributes of the files are flexible and easily configurable. @@ -127,7 +127,7 @@ In addition, the object resulting from the deserialization of the spec can be us import yaml from avlos import deserialize from myProject import myChannel # update this - + device_description = ... obj = deserialize(yaml.safe_load(device_description)) obj.set_channel(myChannel()) @@ -164,4 +164,3 @@ Between releases, development versions are automatically generated (e.g., `0.8.7 ## 🔑 License MIT - diff --git a/avlos/__init__.py b/avlos/__init__.py index e3bffd2..b883c62 100644 --- a/avlos/__init__.py +++ b/avlos/__init__.py @@ -6,8 +6,9 @@ # Package is not installed, version will be determined from git try: from setuptools_scm import get_version - __version__ = get_version(root='..', relative_to=__file__) + + __version__ = get_version(root="..", relative_to=__file__) except (ImportError, LookupError): __version__ = "unknown" -__all__ = ["get_registry", "__version__"] \ No newline at end of file +__all__ = ["get_registry", "__version__"] diff --git a/avlos/bitmask_field.py b/avlos/bitmask_field.py index 2ab4679..0a4e956 100644 --- a/avlos/bitmask_field.py +++ b/avlos/bitmask_field.py @@ -1,5 +1,6 @@ import enum -from marshmallow import fields, ValidationError + +from marshmallow import ValidationError, fields class BitmaskField(fields.Field): diff --git a/avlos/cli.py b/avlos/cli.py index 1b05c66..14e2fbf 100644 --- a/avlos/cli.py +++ b/avlos/cli.py @@ -9,12 +9,14 @@ --config= Path of the Avlos config file [default: ./avlos_config.yaml] """ -import yaml -from typing import Dict import logging -import pkg_resources import urllib.request +from typing import Dict + +import pkg_resources +import yaml from docopt import docopt + from avlos.deserializer import deserialize from avlos.processor import process_with_config_file diff --git a/avlos/datatypes.py b/avlos/datatypes.py index 274d9c5..9b0f269 100644 --- a/avlos/datatypes.py +++ b/avlos/datatypes.py @@ -1,5 +1,6 @@ from enum import Enum -from marshmallow import fields, ValidationError + +from marshmallow import ValidationError, fields class DataType(Enum): diff --git a/avlos/definitions/__init__.py b/avlos/definitions/__init__.py index 9bd0a71..ece9cf9 100644 --- a/avlos/definitions/__init__.py +++ b/avlos/definitions/__init__.py @@ -1,6 +1,6 @@ from avlos.definitions.remote_attribute import RemoteAttribute -from avlos.definitions.remote_function import RemoteFunction, RemoteArgument, RemoteArgumentSchema from avlos.definitions.remote_bitmask import RemoteBitmask from avlos.definitions.remote_enum import RemoteEnum +from avlos.definitions.remote_function import RemoteArgument, RemoteArgumentSchema, RemoteFunction from avlos.definitions.remote_node import RemoteNode, RemoteNodeSchema from avlos.definitions.remote_root_node import RootNode, RootNodeSchema diff --git a/avlos/definitions/remote_attribute.py b/avlos/definitions/remote_attribute.py index 030b225..f97e45d 100644 --- a/avlos/definitions/remote_attribute.py +++ b/avlos/definitions/remote_attribute.py @@ -1,9 +1,9 @@ from avlos import get_registry from avlos.mixins.comm_node import CommNode -from avlos.mixins.named_node import NamedNode -from avlos.mixins.meta_node import MetaNode -from avlos.mixins.impex_node import ImpexNode from avlos.mixins.func_attr_node import FuncAttrNode +from avlos.mixins.impex_node import ImpexNode +from avlos.mixins.meta_node import MetaNode +from avlos.mixins.named_node import NamedNode class RemoteAttribute(CommNode, NamedNode, MetaNode, ImpexNode, FuncAttrNode): @@ -91,11 +91,7 @@ def str_dump(self): format_str = "{0} [{1}]: {2:.6g}" else: format_str = "{0} [{1}]: {2}" - return format_str.format( - self.name, - self.dtype.nickname, - value - ) + return format_str.format(self.name, self.dtype.nickname, value) @property def getter_strategy(self) -> str: diff --git a/avlos/definitions/remote_bitmask.py b/avlos/definitions/remote_bitmask.py index 946f000..98e49fc 100644 --- a/avlos/definitions/remote_bitmask.py +++ b/avlos/definitions/remote_bitmask.py @@ -1,9 +1,9 @@ -from avlos.mixins.comm_node import CommNode -from avlos.mixins.named_node import NamedNode -from avlos.mixins.meta_node import MetaNode -from avlos.mixins.impex_node import ImpexNode from avlos.datatypes import DataType +from avlos.mixins.comm_node import CommNode from avlos.mixins.func_attr_node import FuncAttrNode +from avlos.mixins.impex_node import ImpexNode +from avlos.mixins.meta_node import MetaNode +from avlos.mixins.named_node import NamedNode class RemoteBitmask(CommNode, NamedNode, MetaNode, ImpexNode): diff --git a/avlos/definitions/remote_enum.py b/avlos/definitions/remote_enum.py index 5b0b03b..06d61f7 100644 --- a/avlos/definitions/remote_enum.py +++ b/avlos/definitions/remote_enum.py @@ -1,9 +1,9 @@ -from avlos.mixins.comm_node import CommNode -from avlos.mixins.named_node import NamedNode -from avlos.mixins.meta_node import MetaNode -from avlos.mixins.impex_node import ImpexNode from avlos.datatypes import DataType +from avlos.mixins.comm_node import CommNode from avlos.mixins.func_attr_node import FuncAttrNode +from avlos.mixins.impex_node import ImpexNode +from avlos.mixins.meta_node import MetaNode +from avlos.mixins.named_node import NamedNode class RemoteEnum(CommNode, NamedNode, MetaNode, ImpexNode): diff --git a/avlos/definitions/remote_function.py b/avlos/definitions/remote_function.py index 645e992..eacb0c2 100644 --- a/avlos/definitions/remote_function.py +++ b/avlos/definitions/remote_function.py @@ -1,14 +1,11 @@ -from marshmallow import ( - Schema, - fields, - post_load, -) -from avlos.unit_field import UnitField +from marshmallow import Schema, fields, post_load + from avlos.datatypes import DataTypeField from avlos.mixins.comm_node import CommNode -from avlos.mixins.named_node import NamedNode -from avlos.mixins.meta_node import MetaNode from avlos.mixins.func_attr_node import FuncAttrNode +from avlos.mixins.meta_node import MetaNode +from avlos.mixins.named_node import NamedNode +from avlos.unit_field import UnitField class RemoteFunction(CommNode, NamedNode, MetaNode): @@ -49,9 +46,7 @@ def __call__(self, *args): mags.append(arg_val.to(arg_obj.unit).magnitude) except AttributeError: mags.append(arg_val) - data = self.channel.serializer.serialize( - mags, *[arg.dtype for arg in self.arguments] - ) + data = self.channel.serializer.serialize(mags, *[arg.dtype for arg in self.arguments]) self.channel.send(data, self.ep_id) if not self.dtype.is_void: data = self.channel.recv(self.ep_id) @@ -97,9 +92,7 @@ class RemoteArgumentSchema(Schema): arguments """ - name = fields.String( - required=True, error_messages={"required": "Name is required."} - ) + name = fields.String(required=True, error_messages={"required": "Name is required."}) summary = fields.String() dtype = DataTypeField(required=True) unit = UnitField() diff --git a/avlos/definitions/remote_node.py b/avlos/definitions/remote_node.py index 8714fa1..cdb8458 100644 --- a/avlos/definitions/remote_node.py +++ b/avlos/definitions/remote_node.py @@ -1,26 +1,16 @@ from collections import OrderedDict -from marshmallow import ( - Schema, - fields, - post_load, - validates_schema, - ValidationError, -) -from avlos.unit_field import UnitField + +from marshmallow import Schema, ValidationError, fields, post_load, validates_schema + from avlos.bitmask_field import BitmaskField -from avlos.enum_field import EnumField from avlos.counter import get_counter from avlos.datatypes import DataTypeField +from avlos.definitions import RemoteArgumentSchema, RemoteAttribute, RemoteBitmask, RemoteEnum, RemoteFunction +from avlos.enum_field import EnumField from avlos.mixins.comm_node import CommNode -from avlos.mixins.named_node import NamedNode from avlos.mixins.impex_node import ImpexNode -from avlos.definitions import ( - RemoteAttribute, - RemoteFunction, - RemoteArgumentSchema, - RemoteBitmask, - RemoteEnum, -) +from avlos.mixins.named_node import NamedNode +from avlos.unit_field import UnitField class RemoteNode(CommNode, NamedNode, ImpexNode): @@ -112,12 +102,7 @@ def str_dump(self, indent, depth): lines = [] for key, val in self.remote_attributes.items(): if isinstance(val, RemoteNode): - val_str = ( - indent - + key - + (": " if depth == 1 else ":\n") - + val.str_dump(indent + " ", depth - 1) - ) + val_str = indent + key + (": " if depth == 1 else ":\n") + val.str_dump(indent + " ", depth - 1) else: val_str = indent + val.str_dump() lines.append(val_str) @@ -139,9 +124,7 @@ class RemoteNodeSchema(Schema): RemoteAttribute, RemoteBitmask and RemoteFunction classes """ - name = fields.String( - required=True, error_messages={"required": "Name is required."} - ) + name = fields.String(required=True, error_messages={"required": "Name is required."}) summary = fields.String() remote_attributes = fields.List(fields.Nested(lambda: RemoteNodeSchema())) dtype = DataTypeField() @@ -206,13 +189,9 @@ def validate_schema(self, data, **kwargs): and "setter_name" not in data and "caller_name" not in data ): - raise ValidationError( - "Either a getter, setter, caller or remote attributes list is required" - ) + raise ValidationError("Either a getter, setter, caller or remote attributes list is required") if "getter_name" in data and "setter_name" in data and "caller_name" in data: - raise ValidationError( - "A getter, setter, and caller cannot coexist in a single endpoint" - ) + raise ValidationError("A getter, setter, and caller cannot coexist in a single endpoint") if ( ("getter_name" in data or "setter_name" in data or "caller_name" in data) and "dtype" not in data diff --git a/avlos/definitions/remote_root_node.py b/avlos/definitions/remote_root_node.py index bb05c32..9521c06 100644 --- a/avlos/definitions/remote_root_node.py +++ b/avlos/definitions/remote_root_node.py @@ -1,5 +1,7 @@ from functools import cached_property + from marshmallow import fields, post_load + from avlos.definitions import RemoteNode, RemoteNodeSchema diff --git a/avlos/deserializer.py b/avlos/deserializer.py index e21ffff..43f2b21 100644 --- a/avlos/deserializer.py +++ b/avlos/deserializer.py @@ -1,7 +1,8 @@ -import json import hashlib -from avlos.definitions import RootNodeSchema +import json + from avlos.counter import make_counter +from avlos.definitions import RootNodeSchema def deserialize(device_description): diff --git a/avlos/enum_field.py b/avlos/enum_field.py index 3345275..3d21bb1 100644 --- a/avlos/enum_field.py +++ b/avlos/enum_field.py @@ -1,5 +1,6 @@ import enum -from marshmallow import fields, ValidationError + +from marshmallow import ValidationError, fields class EnumField(fields.Field): diff --git a/avlos/formatting.py b/avlos/formatting.py index 3437317..3a9a344 100644 --- a/avlos/formatting.py +++ b/avlos/formatting.py @@ -1,10 +1,9 @@ """ Code formatting utilities for generated code. """ -import subprocess + import shutil -import sys -from pathlib import Path +import subprocess def is_clang_format_available() -> bool: @@ -28,10 +27,7 @@ def format_c_code(file_path: str, style: str = "LLVM") -> bool: try: result = subprocess.run( - ["clang-format", "-i", f"--style={style}", file_path], - capture_output=True, - timeout=10, - check=False + ["clang-format", "-i", f"--style={style}", file_path], capture_output=True, timeout=10, check=False ) return result.returncode == 0 except (subprocess.TimeoutExpired, subprocess.SubprocessError): diff --git a/avlos/generators/filters.py b/avlos/generators/filters.py index 1c98c7e..0d4ee32 100644 --- a/avlos/generators/filters.py +++ b/avlos/generators/filters.py @@ -1,6 +1,6 @@ import os -from typing import List from copy import copy +from typing import List def avlos_endpoints(input) -> List: @@ -16,14 +16,11 @@ def avlos_endpoints(input) -> List: Returns: Flat list of all endpoint objects found in the tree """ + def traverse_endpoint_list(ep_list, ep_out_list: List) -> None: """Helper function to recursively traverse endpoint tree.""" for ep in ep_list: - if ( - hasattr(ep, "getter_name") - or hasattr(ep, "setter_name") - or hasattr(ep, "caller_name") - ): + if hasattr(ep, "getter_name") or hasattr(ep, "setter_name") or hasattr(ep, "caller_name"): ep_out_list.append(ep) elif hasattr(ep, "remote_attributes"): traverse_endpoint_list(ep.remote_attributes.values(), ep_out_list) diff --git a/avlos/generators/generator_c.py b/avlos/generators/generator_c.py index f608791..c412b07 100644 --- a/avlos/generators/generator_c.py +++ b/avlos/generators/generator_c.py @@ -1,14 +1,11 @@ import os import sys + from jinja2 import Environment, PackageLoader, select_autoescape -from avlos.generators.filters import ( - avlos_endpoints, - avlos_enum_eps, - avlos_bitmask_eps, - as_include, -) -from avlos.validation import validate_all, ValidationError + from avlos.formatting import format_c_code, is_clang_format_available +from avlos.generators.filters import as_include, avlos_bitmask_eps, avlos_endpoints, avlos_enum_eps +from avlos.validation import ValidationError, validate_all env = Environment(loader=PackageLoader("avlos"), autoescape=select_autoescape()) diff --git a/avlos/generators/generator_cpp.py b/avlos/generators/generator_cpp.py index 2f5d338..f87c61a 100644 --- a/avlos/generators/generator_cpp.py +++ b/avlos/generators/generator_cpp.py @@ -1,15 +1,12 @@ import os import sys from pathlib import Path + from jinja2 import Environment, PackageLoader, select_autoescape -from avlos.generators.filters import ( - avlos_enum_eps, - avlos_bitmask_eps, - file_from_path, - capitalize_first, -) -from avlos.validation import validate_all, ValidationError + from avlos.formatting import format_c_code, is_clang_format_available +from avlos.generators.filters import avlos_bitmask_eps, avlos_enum_eps, capitalize_first, file_from_path +from avlos.validation import ValidationError, validate_all env = Environment(loader=PackageLoader("avlos"), autoescape=select_autoescape()) diff --git a/avlos/generators/generator_dbc.py b/avlos/generators/generator_dbc.py index c6df1ca..1986d5d 100644 --- a/avlos/generators/generator_dbc.py +++ b/avlos/generators/generator_dbc.py @@ -1,5 +1,7 @@ import os + from jinja2 import Environment, PackageLoader, select_autoescape + from avlos.generators.filters import avlos_endpoints env = Environment(loader=PackageLoader("avlos"), autoescape=select_autoescape()) diff --git a/avlos/generators/generator_rst.py b/avlos/generators/generator_rst.py index 0c159d8..f195fbc 100644 --- a/avlos/generators/generator_rst.py +++ b/avlos/generators/generator_rst.py @@ -1,5 +1,7 @@ import os + from jinja2 import Environment, PackageLoader, select_autoescape + from avlos.generators.filters import avlos_endpoints env = Environment(loader=PackageLoader("avlos"), autoescape=select_autoescape()) diff --git a/avlos/json_codec.py b/avlos/json_codec.py index f0cf027..ca21ba4 100644 --- a/avlos/json_codec.py +++ b/avlos/json_codec.py @@ -1,5 +1,7 @@ import json + import pint + from avlos import get_registry diff --git a/avlos/mixins/func_attr_node.py b/avlos/mixins/func_attr_node.py index 6322fde..5d5b67b 100644 --- a/avlos/mixins/func_attr_node.py +++ b/avlos/mixins/func_attr_node.py @@ -1,4 +1,3 @@ - class FuncAttrNode: def __init__(self, func_attr): self.func_attr = func_attr diff --git a/avlos/processor.py b/avlos/processor.py index fbd19b8..bedefe4 100644 --- a/avlos/processor.py +++ b/avlos/processor.py @@ -1,6 +1,7 @@ -from os.path import join, dirname, basename, realpath -import yaml from importlib import import_module +from os.path import basename, dirname, join, realpath + +import yaml def process_with_config_file(device_instance, avlos_config_path, traverse_path=False): @@ -37,7 +38,5 @@ def process_with_config_object(device_instance, avlos_config): """ for module_name, module_config in avlos_config["generators"].items(): if "enabled" in module_config and True == module_config["enabled"]: - generator = import_module( - ".generators.{}".format(module_name), package="avlos" - ) + generator = import_module(".generators.{}".format(module_name), package="avlos") generator.process(device_instance, module_config) diff --git a/avlos/templates/device.cpp.jinja b/avlos/templates/device.cpp.jinja index 3f14910..f2418f5 100644 --- a/avlos/templates/device.cpp.jinja +++ b/avlos/templates/device.cpp.jinja @@ -11,7 +11,7 @@ { {{attr.dtype.c_name}} value = 0; this->send({{attr.ep_id}}, this->_data, 0, true); - if (this->recv({{attr.ep_id}}, this->_data, &(this->_dlc), this->delay_us_value)) + if (this->recv({{attr.ep_id}}, this->_data, &(this->_dlc), this->delay_us_value)) { read_le(&value, this->_data); } @@ -24,7 +24,7 @@ void {{ device_name | capitalize_first }}::get_{{attr.name}}(char out_value[]) { this->send({{attr.ep_id}}, this->_data, 0, true); this->_dlc = 0; - if (this->recv({{attr.ep_id}}, this->_data, &(this->_dlc), this->delay_us_value)) + if (this->recv({{attr.ep_id}}, this->_data, &(this->_dlc), this->delay_us_value)) { memcpy(out_value, this->_data, this->_dlc); } @@ -87,4 +87,3 @@ void {{ device_name | capitalize_first }}::set_{{attr.name}}(char value[]) {%- endif %} {%- endfor %} {%- endif %} - diff --git a/avlos/templates/device.dbc.jinja b/avlos/templates/device.dbc.jinja index 387d60f..063b207 100644 --- a/avlos/templates/device.dbc.jinja +++ b/avlos/templates/device.dbc.jinja @@ -22,4 +22,4 @@ BO_ {{attr.ep_id}} {{attr.caller_name}}: {{[attr.arguments|length, 1]|max}} Vect {%- if count.append(count.pop() + arg.dtype.size) %}{% endif %} {# increment count by 1 #} {%- endfor %} {%- endif %} -{% endfor %} \ No newline at end of file +{% endfor %} diff --git a/avlos/templates/device.hpp.jinja b/avlos/templates/device.hpp.jinja index 4a4ec65..954542f 100644 --- a/avlos/templates/device.hpp.jinja +++ b/avlos/templates/device.hpp.jinja @@ -51,7 +51,7 @@ class {{ device_name | capitalize_first }} : Node public: {{ device_name | capitalize_first }}(uint8_t _can_node_id, send_callback _send_cb, recv_callback _recv_cb, delay_us_callback _delay_us_cb, uint32_t _delay_us_value): - Node(_can_node_id, _send_cb, _recv_cb, _delay_us_cb, _delay_us_value) + Node(_can_node_id, _send_cb, _recv_cb, _delay_us_cb, _delay_us_value) {%- if instance.remote_attributes %} {%- for attr in instance.remote_attributes.values() %} {%- if attr.remote_attributes %} diff --git a/avlos/templates/docs.rst.jinja b/avlos/templates/docs.rst.jinja index 4c82f53..9a4673a 100644 --- a/avlos/templates/docs.rst.jinja +++ b/avlos/templates/docs.rst.jinja @@ -42,4 +42,4 @@ ID: {{ attr.ep_id }} {{ attr.summary }} {%- endif %} -{% endfor %} \ No newline at end of file +{% endfor %} diff --git a/avlos/templates/remote_object.cpp.jinja b/avlos/templates/remote_object.cpp.jinja index 397967f..d914259 100644 --- a/avlos/templates/remote_object.cpp.jinja +++ b/avlos/templates/remote_object.cpp.jinja @@ -11,7 +11,7 @@ { {{attr.dtype.c_name}} value = 0; this->send({{attr.ep_id}}, this->_data, 0, true); - if (this->recv({{attr.ep_id}}, this->_data, &(this->_dlc), this->delay_us_value)) + if (this->recv({{attr.ep_id}}, this->_data, &(this->_dlc), this->delay_us_value)) { read_le(&value, this->_data); } @@ -24,7 +24,7 @@ void {{ instance.name | capitalize_first }}_::get_{{attr.name}}(char out_value[] { this->send({{attr.ep_id}}, this->_data, 0, true); this->_dlc = 0; - if (this->recv({{attr.ep_id}}, this->_data, &(this->_dlc), this->delay_us_value)) + if (this->recv({{attr.ep_id}}, this->_data, &(this->_dlc), this->delay_us_value)) { memcpy(out_value, this->_data, this->_dlc); } @@ -90,7 +90,7 @@ void {{ instance.name | capitalize_first }}_::set_{{attr.name}}(char value[]) {%- if attr.dtype.c_name != "void" %} {{attr.dtype.c_name}} value = 0; this->send(17, this->_data, 0, true); - if (this->recv(17, this->_data, &(this->_dlc), this->delay_us_value)) + if (this->recv(17, this->_data, &(this->_dlc), this->delay_us_value)) { read_le(&value, this->_data); } @@ -101,5 +101,3 @@ void {{ instance.name | capitalize_first }}_::set_{{attr.name}}(char value[]) {%- endif %} {%- endfor %} {%- endif %} - - diff --git a/avlos/templates/remote_object.hpp.jinja b/avlos/templates/remote_object.hpp.jinja index 164879f..99bc640 100644 --- a/avlos/templates/remote_object.hpp.jinja +++ b/avlos/templates/remote_object.hpp.jinja @@ -23,7 +23,7 @@ class {{ instance.name | capitalize_first }}_ : Node public: {{ instance.name | capitalize_first }}_(uint8_t _can_node_id, send_callback _send_cb, recv_callback _recv_cb, delay_us_callback _delay_us_cb, uint32_t _delay_us_value): - Node(_can_node_id, _send_cb, _recv_cb, _delay_us_cb, _delay_us_value) + Node(_can_node_id, _send_cb, _recv_cb, _delay_us_cb, _delay_us_value) {%- if instance.remote_attributes %} {%- for attr in instance.remote_attributes.values() %} {%- if attr.remote_attributes %} @@ -55,7 +55,7 @@ class {{ instance.name | capitalize_first }}_ : Node {%- set comma = joiner(", ") %} {{attr.dtype.c_name}} {{attr.name}}({%- for arg in attr.arguments %}{{ comma() }}{{arg.dtype.c_name}} {{ arg.name }} {%- endfor %}); {%- endif %} - + {%- endfor %} {%- endif %} diff --git a/avlos/unit_field.py b/avlos/unit_field.py index 182b684..99efac6 100644 --- a/avlos/unit_field.py +++ b/avlos/unit_field.py @@ -1,5 +1,5 @@ -from marshmallow import fields, ValidationError import pint +from marshmallow import ValidationError, fields _registry = None diff --git a/avlos/validation.py b/avlos/validation.py index d0c5030..5a47777 100644 --- a/avlos/validation.py +++ b/avlos/validation.py @@ -2,24 +2,62 @@ Pre-generation validation for Avlos code generation. Validates C identifiers, detects conflicts, and ensures consistency. """ -from typing import List + import re +from typing import List # C reserved words (C11 standard) C_RESERVED_WORDS = { - 'auto', 'break', 'case', 'char', 'const', 'continue', 'default', 'do', - 'double', 'else', 'enum', 'extern', 'float', 'for', 'goto', 'if', - 'int', 'long', 'register', 'return', 'short', 'signed', 'sizeof', 'static', - 'struct', 'switch', 'typedef', 'union', 'unsigned', 'void', 'volatile', 'while', - '_Alignas', '_Alignof', '_Atomic', '_Bool', '_Complex', '_Generic', '_Imaginary', - '_Noreturn', '_Static_assert', '_Thread_local' + "auto", + "break", + "case", + "char", + "const", + "continue", + "default", + "do", + "double", + "else", + "enum", + "extern", + "float", + "for", + "goto", + "if", + "int", + "long", + "register", + "return", + "short", + "signed", + "sizeof", + "static", + "struct", + "switch", + "typedef", + "union", + "unsigned", + "void", + "volatile", + "while", + "_Alignas", + "_Alignof", + "_Atomic", + "_Bool", + "_Complex", + "_Generic", + "_Imaginary", + "_Noreturn", + "_Static_assert", + "_Thread_local", } -C_IDENTIFIER_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$') +C_IDENTIFIER_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") class ValidationError(Exception): """Raised when validation fails.""" + pass @@ -43,15 +81,15 @@ def validate_c_identifier(name: str, context: str = "") -> None: if name in C_RESERVED_WORDS: ctx = f" ({context})" if context else "" - raise ValidationError( - f"Invalid C identifier '{name}'{ctx}. '{name}' is a C reserved word." - ) + raise ValidationError(f"Invalid C identifier '{name}'{ctx}. '{name}' is a C reserved word.") if len(name) > 63: # C99 requires at least 63 significant characters for identifiers ctx = f" ({context})" if context else "" - print(f"Warning: Identifier '{name}'{ctx} is very long ({len(name)} chars). " - f"Some compilers may truncate after 63 characters.") + print( + f"Warning: Identifier '{name}'{ctx} is very long ({len(name)} chars). " + f"Some compilers may truncate after 63 characters." + ) def validate_endpoint_ids(instance) -> List[str]: @@ -72,10 +110,7 @@ def validate_endpoint_ids(instance) -> List[str]: for ep in avlos_endpoints(instance): ep_id = ep.ep_id if ep_id in ep_id_map: - errors.append( - f"Duplicate endpoint ID {ep_id}: " - f"'{ep.full_name}' and '{ep_id_map[ep_id].full_name}'" - ) + errors.append(f"Duplicate endpoint ID {ep_id}: " f"'{ep.full_name}' and '{ep_id_map[ep_id].full_name}'") else: ep_id_map[ep_id] = ep @@ -98,19 +133,19 @@ def validate_function_names(instance) -> List[str]: # Check getter/setter/caller names are valid C identifiers for ep in avlos_endpoints(instance): - if hasattr(ep, 'getter_name') and ep.getter_name: + if hasattr(ep, "getter_name") and ep.getter_name: try: validate_c_identifier(ep.getter_name, f"getter for {ep.full_name}") except ValidationError as e: errors.append(str(e)) - if hasattr(ep, 'setter_name') and ep.setter_name: + if hasattr(ep, "setter_name") and ep.setter_name: try: validate_c_identifier(ep.setter_name, f"setter for {ep.full_name}") except ValidationError as e: errors.append(str(e)) - if hasattr(ep, 'caller_name') and ep.caller_name: + if hasattr(ep, "caller_name") and ep.caller_name: try: validate_c_identifier(ep.caller_name, f"caller for {ep.full_name}") except ValidationError as e: @@ -155,7 +190,7 @@ def traverse_nodes(node, path=""): errors.append(str(e)) # Recursively check children - if hasattr(node, 'remote_attributes'): + if hasattr(node, "remote_attributes"): for child in node.remote_attributes.values(): traverse_nodes(child, current_path) diff --git a/docs/cli.rst b/docs/cli.rst index 3ffb99c..2c5a441 100644 --- a/docs/cli.rst +++ b/docs/cli.rst @@ -8,4 +8,4 @@ Avlos CLI .. code-block:: console - avlos from url https://your.url/spec.yaml \ No newline at end of file + avlos from url https://your.url/spec.yaml diff --git a/docs/conf.py b/docs/conf.py index 20247b5..cce0628 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -4,7 +4,7 @@ # list see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html -master_doc = 'index' +master_doc = "index" # -- Path setup -------------------------------------------------------------- @@ -19,9 +19,9 @@ # -- Project information ----------------------------------------------------- -project = 'Avlos' -copyright = '2022, Yannis Chatzikonstantinou' -author = 'Yannis Chatzikonstantinou' +project = "Avlos" +copyright = "2022, Yannis Chatzikonstantinou" +author = "Yannis Chatzikonstantinou" # -- General configuration --------------------------------------------------- @@ -32,12 +32,12 @@ extensions = [] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # -- Options for HTML output ------------------------------------------------- @@ -45,7 +45,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'alabaster' +html_theme = "alabaster" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, diff --git a/docs/config.rst b/docs/config.rst index b7bcdde..ace34ee 100644 --- a/docs/config.rst +++ b/docs/config.rst @@ -59,7 +59,7 @@ Output Config The output config defines the output modules that will be used and their options. Example, showing C code generation for embedded devices: .. code-block:: - + generators: generator_c: enabled: true @@ -100,4 +100,4 @@ Ensure the output config exists in the current folder. avlos from url https://your.url/spec.yaml -This will generate the outputs according to the configuration in the output config file. \ No newline at end of file +This will generate the outputs according to the configuration in the output config file. diff --git a/docs/index.rst b/docs/index.rst index e4b9294..00b3df1 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -13,5 +13,3 @@ Welcome to Avlos documentation! spec_format config cli - - diff --git a/docs/introduction.rst b/docs/introduction.rst index 7758d95..b3130cc 100644 --- a/docs/introduction.rst +++ b/docs/introduction.rst @@ -3,7 +3,7 @@ Introduction Avlos makes it easy to create protocol implementations to communicate with remote embedded devices. -Given a remote embedded device, a client that wants to talk with the device, and a YAML file that represents the remote device structure that we want exposed to the client (the spec), Avlos will generate a protocol implementation based on the spec. It will also generate documentation and more. +Given a remote embedded device, a client that wants to talk with the device, and a YAML file that represents the remote device structure that we want exposed to the client (the spec), Avlos will generate a protocol implementation based on the spec. It will also generate documentation and more. .. figure:: diagram.png :width: 800 @@ -29,4 +29,3 @@ Versioning Avlos uses git tags for version management via `setuptools-scm `_. To create a new release, tag the commit: ``git tag v0.X.Y`` - diff --git a/docs/spec_format.rst b/docs/spec_format.rst index ebfe456..ef9a583 100644 --- a/docs/spec_format.rst +++ b/docs/spec_format.rst @@ -1,3 +1,2 @@ Avlos Specification Format ************************** - diff --git a/example/README.md b/example/README.md index e582c71..1c67575 100644 --- a/example/README.md +++ b/example/README.md @@ -6,4 +6,4 @@ avlos from file device.yaml -You'll get your output in the corresponding folders. \ No newline at end of file +You'll get your output in the corresponding folders. diff --git a/example/avlos_config.yaml b/example/avlos_config.yaml index 41e6f64..5b21b96 100644 --- a/example/avlos_config.yaml +++ b/example/avlos_config.yaml @@ -11,4 +11,4 @@ generators: generator_rst: enabled: true paths: - output_file: docs/protocol.rst \ No newline at end of file + output_file: docs/protocol.rst diff --git a/example/device.yaml b/example/device.yaml index e168b87..2538a05 100644 --- a/example/device.yaml +++ b/example/device.yaml @@ -17,4 +17,4 @@ remote_attributes: dtype: bool getter_name: toaster_get_relay_state setter_name: toaster_set_relay_state - summary: The toaster heating relay element state. \ No newline at end of file + summary: The toaster heating relay element state. diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..28688d8 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,32 @@ +[flake8] +max-line-length = 127 +max-complexity = 15 +extend-ignore = E203, W503, E711, E712, C901 +exclude = + .git, + __pycache__, + build, + dist, + venv, + .eggs, + *.egg, + .tox +per-file-ignores = + # __init__.py files can have unused imports for package API + */__init__.py:F401,F403 + # cli.py: logger is configured but may not be directly used + avlos/cli.py:F841 + # Test files may have unused imports for clarity + tests/*.py:F401,F841 + # Generator files may import for future use + avlos/generators/*.py:F401 + avlos/json_codec.py:F401 + avlos/mixins/*.py:F401 + avlos/mixins/impex_node.py:E711,E712 + # Processor and validation may use legacy comparison style + avlos/processor.py:E712 + avlos/validation.py:C901 + +[isort] +profile = black +line_length = 127 diff --git a/setup.py b/setup.py index 0ae896e..5576955 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,8 @@ #!/usr/bin/env python import pathlib -from setuptools import setup, find_packages + +from setuptools import find_packages, setup # The directory containing this file HERE = pathlib.Path(__file__).parent diff --git a/tests/definition/bad_device_name.yaml b/tests/definition/bad_device_name.yaml index 11fa4bf..03aefc8 100644 --- a/tests/definition/bad_device_name.yaml +++ b/tests/definition/bad_device_name.yaml @@ -5,4 +5,3 @@ remote_attributes: - dtype: uint32 getter_name: system_get_sn summary: Retrieve the unique device serial number. - diff --git a/tests/definition/good_device.yaml b/tests/definition/good_device.yaml index b79846a..57aa921 100644 --- a/tests/definition/good_device.yaml +++ b/tests/definition/good_device.yaml @@ -8,14 +8,13 @@ remote_attributes: - name: nickname dtype: string getter_name: system_get_name - setter_name: system_set_name + setter_name: system_set_name summary: Retrieve the device name - name: errors flags: [UNDERVOLTAGE] - meta: {dynamic: True} getter_name: system_get_error summary: Retrieve any device errors. - meta: {"lalala": "ok"} + meta: {dynamic: True, lalala: "ok"} - name: Vbus dtype: float unit: volt @@ -33,7 +32,7 @@ remote_attributes: caller_name: move_to dtype: void func_attr: TM_RAMFUNC - arguments: + arguments: - name: position dtype: float unit: tick @@ -91,4 +90,3 @@ remote_attributes: setter_name: encoder_set_bandwidth summary: Access the encoder observer bandwidth. func_attr: TM_RAMFUNC - diff --git a/tests/definition/obsolete_device.yaml b/tests/definition/obsolete_device.yaml index 7f74ad5..d79359e 100644 --- a/tests/definition/obsolete_device.yaml +++ b/tests/definition/obsolete_device.yaml @@ -8,7 +8,6 @@ remote_attributes: summary: Retrieve the unique device serial number. - name: errors flags: [UNDERVOLTAGE] - meta: {dynamic: True} getter_name: system_get_error summary: Retrieve any device errors. - meta: {"lalala": "ok"} + meta: {dynamic: True, lalala: "ok"} diff --git a/tests/test_counter.py b/tests/test_counter.py index dc4ce05..04f4264 100644 --- a/tests/test_counter.py +++ b/tests/test_counter.py @@ -1,7 +1,7 @@ -from avlos.counter import get_counter, make_counter, delete_counter - import unittest +from avlos.counter import delete_counter, get_counter, make_counter + class TestCounter(unittest.TestCase): def test_make_counter_return(self): diff --git a/tests/test_data_model_properties.py b/tests/test_data_model_properties.py index 165f4bd..86004e3 100644 --- a/tests/test_data_model_properties.py +++ b/tests/test_data_model_properties.py @@ -1,14 +1,17 @@ """ Tests for data model properties added for code generation. """ + import unittest + import yaml -from avlos.deserializer import deserialize + +from avlos.datatypes import DataType from avlos.definitions.remote_attribute import RemoteAttribute -from avlos.definitions.remote_function import RemoteFunction -from avlos.definitions.remote_enum import RemoteEnum from avlos.definitions.remote_bitmask import RemoteBitmask -from avlos.datatypes import DataType +from avlos.definitions.remote_enum import RemoteEnum +from avlos.definitions.remote_function import RemoteFunction +from avlos.deserializer import deserialize class TestDataModelProperties(unittest.TestCase): @@ -41,10 +44,14 @@ def test_byval_getter_strategy_float(self): def test_byval_getter_strategy_integers(self): """Test that integer types return 'byval' getter strategy.""" int_types = [ - DataType.UINT8, DataType.INT8, - DataType.UINT16, DataType.INT16, - DataType.UINT32, DataType.INT32, - DataType.UINT64, DataType.INT64, + DataType.UINT8, + DataType.INT8, + DataType.UINT16, + DataType.INT16, + DataType.UINT32, + DataType.INT32, + DataType.UINT64, + DataType.INT64, ] for dtype in int_types: @@ -55,10 +62,8 @@ def test_byval_getter_strategy_integers(self): getter_name="get_value", ) - self.assertEqual(attr.getter_strategy, "byval", - f"{dtype} should use byval strategy") - self.assertEqual(attr.setter_strategy, "byval", - f"{dtype} should use byval strategy") + self.assertEqual(attr.getter_strategy, "byval", f"{dtype} should use byval strategy") + self.assertEqual(attr.setter_strategy, "byval", f"{dtype} should use byval strategy") def test_endpoint_function_name_simple(self): """Test endpoint function name for simple attribute.""" @@ -79,19 +84,16 @@ def test_endpoint_function_name_nested(self): """Test endpoint function name for nested attribute.""" import importlib.resources - def_path_str = str( - importlib.resources.files("tests").joinpath("definition/good_device.yaml") - ) + def_path_str = str(importlib.resources.files("tests").joinpath("definition/good_device.yaml")) with open(def_path_str) as device_desc_stream: obj = deserialize(yaml.safe_load(device_desc_stream)) # Find a nested attribute (e.g., motor.R) - if hasattr(obj, 'motor') and hasattr(obj.motor, 'R'): + if hasattr(obj, "motor") and hasattr(obj.motor, "R"): attr = obj.motor.R expected_name = "avlos_motor_R" - self.assertEqual(attr.endpoint_function_name, expected_name, - f"motor.R should generate {expected_name}") + self.assertEqual(attr.endpoint_function_name, expected_name, f"motor.R should generate {expected_name}") def test_is_string_type_true(self): """Test is_string_type property for char[] type.""" @@ -133,15 +135,13 @@ def test_remote_function_endpoint_name_nested(self): """Test endpoint function name for nested RemoteFunction.""" import importlib.resources - def_path_str = str( - importlib.resources.files("tests").joinpath("definition/good_device.yaml") - ) + def_path_str = str(importlib.resources.files("tests").joinpath("definition/good_device.yaml")) with open(def_path_str) as device_desc_stream: obj = deserialize(yaml.safe_load(device_desc_stream)) # Find nested function (e.g., controller.set_pos_vel_setpoints) - if hasattr(obj, 'controller') and hasattr(obj.controller, 'set_pos_vel_setpoints'): + if hasattr(obj, "controller") and hasattr(obj.controller, "set_pos_vel_setpoints"): func = obj.controller.set_pos_vel_setpoints expected_name = "avlos_controller_set_pos_vel_setpoints" self.assertEqual(func.endpoint_function_name, expected_name) @@ -194,14 +194,11 @@ class TestFlags(IntFlag): def test_backward_compatibility_generated_code(self): """Test that generated code is functionally equivalent to before refactoring.""" import importlib.resources + from avlos.generators import generator_c - def_path_str = str( - importlib.resources.files("tests").joinpath("definition/good_device.yaml") - ) - output_impl = str( - importlib.resources.files("tests").joinpath("outputs/test_backward_compat.c") - ) + def_path_str = str(importlib.resources.files("tests").joinpath("definition/good_device.yaml")) + output_impl = str(importlib.resources.files("tests").joinpath("outputs/test_backward_compat.c")) with open(def_path_str) as device_desc_stream: obj = deserialize(yaml.safe_load(device_desc_stream)) @@ -209,12 +206,8 @@ def test_backward_compatibility_generated_code(self): config = { "hash_string": "0x9e8dc7ac", "paths": { - "output_enums": str( - importlib.resources.files("tests").joinpath("outputs/test_enum_compat.h") - ), - "output_header": str( - importlib.resources.files("tests").joinpath("outputs/test_header_compat.h") - ), + "output_enums": str(importlib.resources.files("tests").joinpath("outputs/test_enum_compat.h")), + "output_header": str(importlib.resources.files("tests").joinpath("outputs/test_header_compat.h")), "output_impl": output_impl, }, } @@ -230,10 +223,8 @@ def test_backward_compatibility_generated_code(self): self.assertIn("avlos_", generated_code, "Should have avlos_ prefixed functions") self.assertIn("AVLOS_CMD_READ", generated_code, "Should handle read commands") self.assertIn("AVLOS_CMD_WRITE", generated_code, "Should handle write commands") - self.assertIn("_avlos_getter_string", generated_code, - "Should have string getter helper") - self.assertIn("_avlos_setter_string", generated_code, - "Should have string setter helper") + self.assertIn("_avlos_getter_string", generated_code, "Should have string getter helper") + self.assertIn("_avlos_setter_string", generated_code, "Should have string setter helper") # Verify function declarations use properties # (all endpoint functions should be present) @@ -242,11 +233,10 @@ def test_backward_compatibility_generated_code(self): def test_all_endpoints_have_function_names(self): """Test that all endpoints from good_device.yaml have endpoint_function_name.""" import importlib.resources + from avlos.generators.filters import avlos_endpoints - def_path_str = str( - importlib.resources.files("tests").joinpath("definition/good_device.yaml") - ) + def_path_str = str(importlib.resources.files("tests").joinpath("definition/good_device.yaml")) with open(def_path_str) as device_desc_stream: obj = deserialize(yaml.safe_load(device_desc_stream)) @@ -254,22 +244,22 @@ def test_all_endpoints_have_function_names(self): endpoints = avlos_endpoints(obj) for ep in endpoints: - self.assertTrue(hasattr(ep, 'endpoint_function_name'), - f"Endpoint {ep.name} should have endpoint_function_name property") + self.assertTrue( + hasattr(ep, "endpoint_function_name"), f"Endpoint {ep.name} should have endpoint_function_name property" + ) func_name = ep.endpoint_function_name - self.assertTrue(func_name.startswith("avlos_"), - f"Endpoint function name should start with avlos_, got: {func_name}") - self.assertNotIn(".", func_name, - f"Endpoint function name should not contain dots, got: {func_name}") + self.assertTrue( + func_name.startswith("avlos_"), f"Endpoint function name should start with avlos_, got: {func_name}" + ) + self.assertNotIn(".", func_name, f"Endpoint function name should not contain dots, got: {func_name}") def test_getter_setter_strategy_consistency(self): """Test that getter and setter strategies are consistent.""" import importlib.resources + from avlos.generators.filters import avlos_endpoints - def_path_str = str( - importlib.resources.files("tests").joinpath("definition/good_device.yaml") - ) + def_path_str = str(importlib.resources.files("tests").joinpath("definition/good_device.yaml")) with open(def_path_str) as device_desc_stream: obj = deserialize(yaml.safe_load(device_desc_stream)) @@ -278,14 +268,17 @@ def test_getter_setter_strategy_consistency(self): for ep in endpoints: # Skip functions (they don't have getter/setter strategies in the same way) - if hasattr(ep, 'caller_name') and not hasattr(ep, 'getter_name'): + if hasattr(ep, "caller_name") and not hasattr(ep, "getter_name"): continue - if hasattr(ep, 'getter_strategy') and hasattr(ep, 'setter_strategy'): + if hasattr(ep, "getter_strategy") and hasattr(ep, "setter_strategy"): # Getter and setter strategies should match for attributes - self.assertEqual(ep.getter_strategy, ep.setter_strategy, - f"Endpoint {ep.name} should have consistent getter/setter strategies") + self.assertEqual( + ep.getter_strategy, + ep.setter_strategy, + f"Endpoint {ep.name} should have consistent getter/setter strategies", + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_deserialization.py b/tests/test_deserialization.py index 1556646..4f2a1d3 100644 --- a/tests/test_deserialization.py +++ b/tests/test_deserialization.py @@ -1,19 +1,19 @@ -import yaml import importlib.resources +import unittest import urllib.request -from avlos.deserializer import deserialize -from avlos.definitions.remote_node import RemoteNodeSchema + import marshmallow import pint -import unittest +import yaml + +from avlos.definitions.remote_node import RemoteNodeSchema +from avlos.deserializer import deserialize from tests.dummy_channel import DummyChannel class TestDeserialization(unittest.TestCase): def test_success(self): - def_path_str = str( - importlib.resources.files("tests").joinpath("definition/good_device.yaml") - ) + def_path_str = str(importlib.resources.files("tests").joinpath("definition/good_device.yaml")) with open(def_path_str) as device_description: obj = deserialize(yaml.safe_load(device_description)) obj._channel = DummyChannel() @@ -28,51 +28,33 @@ def test_success_url(self): print(obj) def test_undefined_unit(self): - def_path_str = str( - importlib.resources.files("tests").joinpath( - "definition/bad_device_unit.yaml" - ) - ) + def_path_str = str(importlib.resources.files("tests").joinpath("definition/bad_device_unit.yaml")) with open(def_path_str) as device_description: with self.assertRaises(pint.errors.UndefinedUnitError): deserialize(yaml.safe_load(device_description)) def test_bitmask_labels(self): - def_path_str = str( - importlib.resources.files("tests").joinpath("definition/good_device.yaml") - ) + def_path_str = str(importlib.resources.files("tests").joinpath("definition/good_device.yaml")) with open(def_path_str) as device_description: device = deserialize(yaml.safe_load(device_description)) device._channel = DummyChannel() self.assertEqual(device.errors.value, 0) def test_empty_bitmask_labels(self): - def_path_str = str( - importlib.resources.files("tests").joinpath( - "definition/bad_device_bitmask.yaml" - ) - ) + def_path_str = str(importlib.resources.files("tests").joinpath("definition/bad_device_bitmask.yaml")) with open(def_path_str) as device_description: with self.assertRaises(marshmallow.exceptions.ValidationError): deserialize(yaml.safe_load(device_description)) def test_version_field_present(self): - def_path_str = str( - importlib.resources.files("tests").joinpath( - "definition/obsolete_device.yaml" - ) - ) + def_path_str = str(importlib.resources.files("tests").joinpath("definition/obsolete_device.yaml")) with open(def_path_str) as device_description: device = deserialize(yaml.safe_load(device_description)) device._channel = DummyChannel() self.assertEqual(device.errors.value, 0) def test_validation_fail(self): - def_path_str = str( - importlib.resources.files("tests").joinpath( - "definition/bad_device_name.yaml" - ) - ) + def_path_str = str(importlib.resources.files("tests").joinpath("definition/bad_device_name.yaml")) with open(def_path_str) as device_description: with self.assertRaises(marshmallow.exceptions.ValidationError): deserialize(yaml.safe_load(device_description)) diff --git a/tests/test_functions.py b/tests/test_functions.py index 8e5fd67..368bee2 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -1,6 +1,7 @@ -from avlos.generators.filters import as_include, file_from_path, capitalize_first import unittest +from avlos.generators.filters import as_include, capitalize_first, file_from_path + class TestFunctions(unittest.TestCase): def test_file_from_path_filter(self): diff --git a/tests/test_generation.py b/tests/test_generation.py index 5271b6e..ae1968c 100644 --- a/tests/test_generation.py +++ b/tests/test_generation.py @@ -1,30 +1,26 @@ -import yaml +import importlib.resources import subprocess +import unittest from pathlib import Path -import importlib.resources -from avlos.deserializer import deserialize -from avlos.processor import process_with_config_file + +import yaml +from rstcheck_core import _extras +from rstcheck_core import config as config_mod +from rstcheck_core import runner + import avlos.generators.generator_c as generator_c import avlos.generators.generator_cpp as generator_cpp import avlos.generators.generator_rst as generator_rst -from rstcheck_core import _extras, config as config_mod, runner -import unittest +from avlos.deserializer import deserialize +from avlos.processor import process_with_config_file class TestGeneration(unittest.TestCase): def test_c_output_manual(self): - def_path_str = str( - importlib.resources.files("tests").joinpath("definition/good_device.yaml") - ) - enum_path_str = str( - importlib.resources.files("tests").joinpath("outputs/tm_enums.h") - ) - header_path_str = str( - importlib.resources.files("tests").joinpath("outputs/test.h") - ) - impl_path_str = str( - importlib.resources.files("tests").joinpath("outputs/test.c") - ) + def_path_str = str(importlib.resources.files("tests").joinpath("definition/good_device.yaml")) + enum_path_str = str(importlib.resources.files("tests").joinpath("outputs/tm_enums.h")) + header_path_str = str(importlib.resources.files("tests").joinpath("outputs/test.h")) + impl_path_str = str(importlib.resources.files("tests").joinpath("outputs/test.c")) with open(def_path_str) as device_desc_stream: obj = deserialize(yaml.safe_load(device_desc_stream)) config = { @@ -45,18 +41,10 @@ def test_c_output_manual(self): self.assertEqual(result.returncode, 0) def test_cpp_output_manual(self): - def_path_str = str( - importlib.resources.files("tests").joinpath("definition/good_device.yaml") - ) - helper_path_str = str( - importlib.resources.files("tests").joinpath("outputs/tm_helpers.hpp") - ) - header_path_str = str( - importlib.resources.files("tests").joinpath("outputs/base_device.hpp") - ) - impl_path_str = str( - importlib.resources.files("tests").joinpath("outputs/base_device.cpp") - ) + def_path_str = str(importlib.resources.files("tests").joinpath("definition/good_device.yaml")) + helper_path_str = str(importlib.resources.files("tests").joinpath("outputs/tm_helpers.hpp")) + header_path_str = str(importlib.resources.files("tests").joinpath("outputs/base_device.hpp")) + impl_path_str = str(importlib.resources.files("tests").joinpath("outputs/base_device.cpp")) with open(def_path_str) as device_desc_stream: obj = deserialize(yaml.safe_load(device_desc_stream)) config = { @@ -77,12 +65,8 @@ def test_cpp_output_manual(self): self.assertEqual(result.returncode, 0) def test_rst_output_manual(self): - def_path_str = str( - importlib.resources.files("tests").joinpath("definition/good_device.yaml") - ) - out_path_str = str( - importlib.resources.files("tests").joinpath("outputs/test.rst") - ) + def_path_str = str(importlib.resources.files("tests").joinpath("definition/good_device.yaml")) + out_path_str = str(importlib.resources.files("tests").joinpath("outputs/test.rst")) with open(def_path_str) as device_desc_stream: obj = deserialize(yaml.safe_load(device_desc_stream)) config = { @@ -93,19 +77,13 @@ def test_rst_output_manual(self): rstcheck_config = config_mod.RstcheckConfig() path = Path(out_path_str) - _runner = runner.RstcheckMainRunner( - check_paths=[path], rstcheck_config=rstcheck_config, overwrite_config=False - ) + _runner = runner.RstcheckMainRunner(check_paths=[path], rstcheck_config=rstcheck_config, overwrite_config=False) _runner.check() _runner.print_result() def test_avlos_config(self): - def_path_str = str( - importlib.resources.files("tests").joinpath("definition/good_device.yaml") - ) - config_file_path_str = str( - importlib.resources.files("tests").joinpath("definition/avlos_config.yaml") - ) + def_path_str = str(importlib.resources.files("tests").joinpath("definition/good_device.yaml")) + config_file_path_str = str(importlib.resources.files("tests").joinpath("definition/avlos_config.yaml")) with open(def_path_str) as device_desc_stream: obj = deserialize(yaml.safe_load(device_desc_stream)) process_with_config_file(obj, config_file_path_str) diff --git a/tests/test_impex.py b/tests/test_impex.py index 593bd0f..9e52559 100644 --- a/tests/test_impex.py +++ b/tests/test_impex.py @@ -1,18 +1,18 @@ +import importlib.resources import json +import unittest + import yaml -import importlib.resources + from avlos.deserializer import deserialize -from avlos.unit_field import get_registry from avlos.json_codec import AvlosEncoder -import unittest +from avlos.unit_field import get_registry from tests.dummy_channel import DummyChannel class TestImpex(unittest.TestCase): def test_import_export_root_object(self): - def_path_str = str( - importlib.resources.files("tests").joinpath("definition/good_device.yaml") - ) + def_path_str = str(importlib.resources.files("tests").joinpath("definition/good_device.yaml")) with open(def_path_str) as device_description: obj = deserialize(yaml.safe_load(device_description)) obj._channel = DummyChannel() diff --git a/tests/test_remote_objects.py b/tests/test_remote_objects.py index 0a54b6f..0926f60 100644 --- a/tests/test_remote_objects.py +++ b/tests/test_remote_objects.py @@ -1,8 +1,10 @@ -import yaml import importlib.resources +import unittest + +import yaml + from avlos.deserializer import deserialize from avlos.unit_field import get_registry -import unittest from tests.dummy_channel import DummyChannel _reg = get_registry() @@ -10,9 +12,7 @@ class TestRemoteObjects(unittest.TestCase): def test_read_remote_properties(self): - def_path_str = str( - importlib.resources.files("tests").joinpath("definition/good_device.yaml") - ) + def_path_str = str(importlib.resources.files("tests").joinpath("definition/good_device.yaml")) with open(def_path_str) as device_description: obj = deserialize(yaml.safe_load(device_description)) obj._channel = DummyChannel() @@ -37,9 +37,7 @@ def test_read_remote_properties(self): self.assertEqual(obj.nickname, "other") def test_remote_enum_read(self): - def_path_str = str( - importlib.resources.files("tests").joinpath("definition/good_device.yaml") - ) + def_path_str = str(importlib.resources.files("tests").joinpath("definition/good_device.yaml")) with open(def_path_str) as device_description: obj = deserialize(yaml.safe_load(device_description)) obj._channel = DummyChannel() @@ -50,9 +48,7 @@ def test_remote_enum_read(self): self.assertEqual(obj.controller.mode, modes.CLOSED_LOOP) def test_remote_enum_write(self): - def_path_str = str( - importlib.resources.files("tests").joinpath("definition/good_device.yaml") - ) + def_path_str = str(importlib.resources.files("tests").joinpath("definition/good_device.yaml")) with open(def_path_str) as device_description: obj = deserialize(yaml.safe_load(device_description)) obj._channel = DummyChannel() @@ -84,9 +80,7 @@ def test_remote_enum_write(self): self.assertEqual(obj._channel.value, 1) def test_remote_function_call(self): - def_path_str = str( - importlib.resources.files("tests").joinpath("definition/good_device.yaml") - ) + def_path_str = str(importlib.resources.files("tests").joinpath("definition/good_device.yaml")) with open(def_path_str) as device_description: obj = deserialize(yaml.safe_load(device_description)) obj._channel = DummyChannel() @@ -96,9 +90,7 @@ def test_remote_function_call(self): self.assertEqual(100 * _reg("tick"), obj.controller.set_pos_vel_setpoints(0, 0)) def test_remote_function_call_w_units(self): - def_path_str = str( - importlib.resources.files("tests").joinpath("definition/good_device.yaml") - ) + def_path_str = str(importlib.resources.files("tests").joinpath("definition/good_device.yaml")) with open(def_path_str) as device_description: obj = deserialize(yaml.safe_load(device_description)) obj._channel = DummyChannel() @@ -111,9 +103,7 @@ def test_remote_function_call_w_units(self): obj._channel.write_off() def test_non_existent_remote_attributes_fail(self): - def_path_str = str( - importlib.resources.files("tests").joinpath("definition/good_device.yaml") - ) + def_path_str = str(importlib.resources.files("tests").joinpath("definition/good_device.yaml")) with open(def_path_str) as device_description: obj = deserialize(yaml.safe_load(device_description)) obj._channel = DummyChannel() @@ -125,9 +115,7 @@ def test_non_existent_remote_attributes_fail(self): print(val) def test_meta_dictionary(self): - def_path_str = str( - importlib.resources.files("tests").joinpath("definition/good_device.yaml") - ) + def_path_str = str(importlib.resources.files("tests").joinpath("definition/good_device.yaml")) with open(def_path_str) as device_description: obj = deserialize(yaml.safe_load(device_description)) self.assertEqual(1, len(obj.errors.meta)) diff --git a/tests/test_templates.py b/tests/test_templates.py index 7fafa86..c168d16 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -1,12 +1,15 @@ """ Tests for Jinja2 templates and generated code patterns. """ + +import importlib.resources import unittest + import yaml -import importlib.resources + +from avlos.datatypes import DataType from avlos.deserializer import deserialize from avlos.generators import generator_c, generator_cpp -from avlos.datatypes import DataType class TestTemplateMacros(unittest.TestCase): @@ -14,28 +17,20 @@ class TestTemplateMacros(unittest.TestCase): def setUp(self): """Set up test fixtures.""" - def_path_str = str( - importlib.resources.files("tests").joinpath("definition/good_device.yaml") - ) + def_path_str = str(importlib.resources.files("tests").joinpath("definition/good_device.yaml")) with open(def_path_str) as device_desc_stream: self.device = deserialize(yaml.safe_load(device_desc_stream)) def test_char_array_getter_uses_helper(self): """Test that char[] getter generates code using _avlos_getter_string.""" - output_impl = str( - importlib.resources.files("tests").joinpath("outputs/test_char_getter.c") - ) + output_impl = str(importlib.resources.files("tests").joinpath("outputs/test_char_getter.c")) config = { "hash_string": "0x9e8dc7ac", "paths": { - "output_enums": str( - importlib.resources.files("tests").joinpath("outputs/test_char_enum.h") - ), - "output_header": str( - importlib.resources.files("tests").joinpath("outputs/test_char_header.h") - ), + "output_enums": str(importlib.resources.files("tests").joinpath("outputs/test_char_enum.h")), + "output_header": str(importlib.resources.files("tests").joinpath("outputs/test_char_header.h")), "output_impl": output_impl, }, } @@ -46,33 +41,29 @@ def test_char_array_getter_uses_helper(self): content = f.read() # Should contain the string helper function definition - self.assertIn("_avlos_getter_string", content, - "Should define _avlos_getter_string helper") + self.assertIn("_avlos_getter_string", content, "Should define _avlos_getter_string helper") - # Should contain helper function signature - self.assertIn("uint8_t (*getter)(char*)", content, - "Helper should have correct signature") + # Should contain helper function signature (check without spaces since clang-format may add them) + self.assertTrue( + "uint8_t (*getter)(char*)" in content or "uint8_t (*getter)(char *)" in content, + "Helper should have correct signature", + ) # Should call the helper in char[] endpoint functions # (nickname is a char[] attribute in good_device.yaml) - self.assertIn("_avlos_getter_string(buffer, buffer_len, system_get_name)", content, - "Should use helper for char[] getter") + self.assertIn( + "_avlos_getter_string(buffer, buffer_len, system_get_name)", content, "Should use helper for char[] getter" + ) def test_char_array_setter_uses_helper(self): """Test that char[] setter generates code using _avlos_setter_string.""" - output_impl = str( - importlib.resources.files("tests").joinpath("outputs/test_char_setter.c") - ) + output_impl = str(importlib.resources.files("tests").joinpath("outputs/test_char_setter.c")) config = { "hash_string": "0x9e8dc7ac", "paths": { - "output_enums": str( - importlib.resources.files("tests").joinpath("outputs/test_char_enum2.h") - ), - "output_header": str( - importlib.resources.files("tests").joinpath("outputs/test_char_header2.h") - ), + "output_enums": str(importlib.resources.files("tests").joinpath("outputs/test_char_enum2.h")), + "output_header": str(importlib.resources.files("tests").joinpath("outputs/test_char_header2.h")), "output_impl": output_impl, }, } @@ -83,32 +74,26 @@ def test_char_array_setter_uses_helper(self): content = f.read() # Should contain the string helper function definition - self.assertIn("_avlos_setter_string", content, - "Should define _avlos_setter_string helper") + self.assertIn("_avlos_setter_string", content, "Should define _avlos_setter_string helper") - # Should contain helper function signature - self.assertIn("void (*setter)(const char*)", content, - "Helper should have correct signature") + # Should contain helper function signature (check without spaces since clang-format may add them) + self.assertTrue( + "void (*setter)(const char*)" in content or "void (*setter)(const char *)" in content, + "Helper should have correct signature", + ) # Should call the helper in char[] endpoint functions - self.assertIn("_avlos_setter_string(buffer, system_set_name)", content, - "Should use helper for char[] setter") + self.assertIn("_avlos_setter_string(buffer, system_set_name)", content, "Should use helper for char[] setter") def test_numeric_getter_byval(self): """Test that numeric getters use by-value pattern.""" - output_impl = str( - importlib.resources.files("tests").joinpath("outputs/test_numeric.c") - ) + output_impl = str(importlib.resources.files("tests").joinpath("outputs/test_numeric.c")) config = { "hash_string": "0x9e8dc7ac", "paths": { - "output_enums": str( - importlib.resources.files("tests").joinpath("outputs/test_numeric_enum.h") - ), - "output_header": str( - importlib.resources.files("tests").joinpath("outputs/test_numeric_header.h") - ), + "output_enums": str(importlib.resources.files("tests").joinpath("outputs/test_numeric_enum.h")), + "output_header": str(importlib.resources.files("tests").joinpath("outputs/test_numeric_header.h")), "output_impl": output_impl, }, } @@ -119,31 +104,24 @@ def test_numeric_getter_byval(self): content = f.read() # Should contain memcpy pattern for by-value types - self.assertIn("memcpy(buffer, &v, sizeof(v))", content, - "Should use memcpy for by-value getters") + self.assertIn("memcpy(buffer, &v, sizeof(v))", content, "Should use memcpy for by-value getters") # Should declare local variable for value # (check for patterns like "float v;" or "uint32_t v;") self.assertTrue( "float v;" in content or "uint32_t v;" in content or "uint8_t v;" in content, - "Should declare local variable for value" + "Should declare local variable for value", ) def test_void_function_no_return_value(self): """Test that void return type functions don't generate return value code.""" - output_impl = str( - importlib.resources.files("tests").joinpath("outputs/test_void_func.c") - ) + output_impl = str(importlib.resources.files("tests").joinpath("outputs/test_void_func.c")) config = { "hash_string": "0x9e8dc7ac", "paths": { - "output_enums": str( - importlib.resources.files("tests").joinpath("outputs/test_void_enum.h") - ), - "output_header": str( - importlib.resources.files("tests").joinpath("outputs/test_void_header.h") - ), + "output_enums": str(importlib.resources.files("tests").joinpath("outputs/test_void_enum.h")), + "output_header": str(importlib.resources.files("tests").joinpath("outputs/test_void_header.h")), "output_impl": output_impl, }, } @@ -161,28 +139,20 @@ def test_void_function_no_return_value(self): reset_func = content[start:end] # Void functions should NOT have ret_val - self.assertNotIn("ret_val", reset_func, - "Void function should not have return value") + self.assertNotIn("ret_val", reset_func, "Void function should not have return value") # Should call function directly without assignment - self.assertIn("system_reset()", reset_func, - "Should call void function without assignment") + self.assertIn("system_reset()", reset_func, "Should call void function without assignment") def test_function_with_args_unpacks_buffer(self): """Test that functions with arguments unpack from buffer.""" - output_impl = str( - importlib.resources.files("tests").joinpath("outputs/test_func_args.c") - ) + output_impl = str(importlib.resources.files("tests").joinpath("outputs/test_func_args.c")) config = { "hash_string": "0x9e8dc7ac", "paths": { - "output_enums": str( - importlib.resources.files("tests").joinpath("outputs/test_func_args_enum.h") - ), - "output_header": str( - importlib.resources.files("tests").joinpath("outputs/test_func_args_header.h") - ), + "output_enums": str(importlib.resources.files("tests").joinpath("outputs/test_func_args_enum.h")), + "output_header": str(importlib.resources.files("tests").joinpath("outputs/test_func_args_header.h")), "output_impl": output_impl, }, } @@ -193,16 +163,13 @@ def test_function_with_args_unpacks_buffer(self): content = f.read() # Should have offset tracking for multiple arguments - self.assertIn("uint8_t _offset = 0", content, - "Should track offset for argument unpacking") + self.assertIn("uint8_t _offset = 0", content, "Should track offset for argument unpacking") # Should unpack arguments with memcpy - self.assertIn("memcpy(&", content, - "Should use memcpy to unpack arguments") + self.assertIn("memcpy(&", content, "Should use memcpy to unpack arguments") # Should increment offset - self.assertIn("_offset += sizeof(", content, - "Should increment offset for each argument") + self.assertIn("_offset += sizeof(", content, "Should increment offset for each argument") def test_all_data_types_generate(self): """Test that all supported data types can be generated.""" @@ -262,19 +229,13 @@ def test_all_data_types_generate(self): obj = deserialize(yaml.safe_load(yaml_content)) - output_impl = str( - importlib.resources.files("tests").joinpath("outputs/test_all_types.c") - ) + output_impl = str(importlib.resources.files("tests").joinpath("outputs/test_all_types.c")) config = { "hash_string": "0xdeadbeef", "paths": { - "output_enums": str( - importlib.resources.files("tests").joinpath("outputs/test_all_types_enum.h") - ), - "output_header": str( - importlib.resources.files("tests").joinpath("outputs/test_all_types_header.h") - ), + "output_enums": str(importlib.resources.files("tests").joinpath("outputs/test_all_types_enum.h")), + "output_header": str(importlib.resources.files("tests").joinpath("outputs/test_all_types_header.h")), "output_impl": output_impl, }, } @@ -287,34 +248,34 @@ def test_all_data_types_generate(self): # Verify all types are present expected_types = [ - "uint8_t", "int8_t", "uint16_t", "int16_t", - "uint32_t", "int32_t", "uint64_t", "int64_t", - "float", "double", "bool" + "uint8_t", + "int8_t", + "uint16_t", + "int16_t", + "uint32_t", + "int32_t", + "uint64_t", + "int64_t", + "float", + "double", + "bool", ] for dtype in expected_types: - self.assertIn(dtype, content, - f"Generated code should contain {dtype}") + self.assertIn(dtype, content, f"Generated code should contain {dtype}") # String types use helper functions, so check for that instead of "char[]" - self.assertIn("_avlos_getter_string", content, - "Generated code should contain string helper function") + self.assertIn("_avlos_getter_string", content, "Generated code should contain string helper function") def test_func_attr_in_output(self): """Test that func_attr (e.g., TM_RAMFUNC) appears in generated code.""" - output_impl = str( - importlib.resources.files("tests").joinpath("outputs/test_func_attr.c") - ) + output_impl = str(importlib.resources.files("tests").joinpath("outputs/test_func_attr.c")) config = { "hash_string": "0x9e8dc7ac", "paths": { - "output_enums": str( - importlib.resources.files("tests").joinpath("outputs/test_func_attr_enum.h") - ), - "output_header": str( - importlib.resources.files("tests").joinpath("outputs/test_func_attr_header.h") - ), + "output_enums": str(importlib.resources.files("tests").joinpath("outputs/test_func_attr_enum.h")), + "output_header": str(importlib.resources.files("tests").joinpath("outputs/test_func_attr_header.h")), "output_impl": output_impl, }, } @@ -326,24 +287,17 @@ def test_func_attr_in_output(self): # good_device.yaml has TM_RAMFUNC on some functions if "TM_RAMFUNC" in content: - self.assertIn("TM_RAMFUNC uint8_t avlos_", content, - "func_attr should appear before function declaration") + self.assertIn("TM_RAMFUNC uint8_t avlos_", content, "func_attr should appear before function declaration") def test_endpoint_array_generation(self): """Test that endpoint array is correctly generated.""" - output_impl = str( - importlib.resources.files("tests").joinpath("outputs/test_ep_array.c") - ) + output_impl = str(importlib.resources.files("tests").joinpath("outputs/test_ep_array.c")) config = { "hash_string": "0x9e8dc7ac", "paths": { - "output_enums": str( - importlib.resources.files("tests").joinpath("outputs/test_ep_array_enum.h") - ), - "output_header": str( - importlib.resources.files("tests").joinpath("outputs/test_ep_array_header.h") - ), + "output_enums": str(importlib.resources.files("tests").joinpath("outputs/test_ep_array_enum.h")), + "output_header": str(importlib.resources.files("tests").joinpath("outputs/test_ep_array_header.h")), "output_impl": output_impl, }, } @@ -354,12 +308,10 @@ def test_endpoint_array_generation(self): content = f.read() # Should have endpoint array declaration - self.assertIn("uint8_t (*avlos_endpoints[", content, - "Should declare endpoint array") + self.assertIn("uint8_t (*avlos_endpoints[", content, "Should declare endpoint array") # Should have proto hash function - self.assertIn("_avlos_get_proto_hash", content, - "Should have proto hash function") + self.assertIn("_avlos_get_proto_hash", content, "Should have proto hash function") class TestIntegration(unittest.TestCase): @@ -369,26 +321,18 @@ def test_full_pipeline_c_generation(self): """Test complete C generation pipeline with all features.""" import importlib.resources - def_path_str = str( - importlib.resources.files("tests").joinpath("definition/good_device.yaml") - ) + def_path_str = str(importlib.resources.files("tests").joinpath("definition/good_device.yaml")) with open(def_path_str) as device_desc_stream: obj = deserialize(yaml.safe_load(device_desc_stream)) - output_impl = str( - importlib.resources.files("tests").joinpath("outputs/test_integration.c") - ) + output_impl = str(importlib.resources.files("tests").joinpath("outputs/test_integration.c")) config = { "hash_string": "0x12345678", "paths": { - "output_enums": str( - importlib.resources.files("tests").joinpath("outputs/test_integration_enum.h") - ), - "output_header": str( - importlib.resources.files("tests").joinpath("outputs/test_integration_header.h") - ), + "output_enums": str(importlib.resources.files("tests").joinpath("outputs/test_integration_enum.h")), + "output_header": str(importlib.resources.files("tests").joinpath("outputs/test_integration_header.h")), "output_impl": output_impl, }, } @@ -398,6 +342,7 @@ def test_full_pipeline_c_generation(self): # Verify all files exist import os + self.assertTrue(os.path.exists(config["paths"]["output_enums"])) self.assertTrue(os.path.exists(config["paths"]["output_header"])) self.assertTrue(os.path.exists(config["paths"]["output_impl"])) @@ -414,9 +359,7 @@ def test_cpp_generation_pipeline(self): """Test complete C++ generation pipeline.""" import importlib.resources - def_path_str = str( - importlib.resources.files("tests").joinpath("definition/good_device.yaml") - ) + def_path_str = str(importlib.resources.files("tests").joinpath("definition/good_device.yaml")) with open(def_path_str) as device_desc_stream: obj = deserialize(yaml.safe_load(device_desc_stream)) @@ -424,15 +367,9 @@ def test_cpp_generation_pipeline(self): config = { "hash_string": "0x12345678", "paths": { - "output_helpers": str( - importlib.resources.files("tests").joinpath("outputs/test_cpp_helpers.hpp") - ), - "output_header": str( - importlib.resources.files("tests").joinpath("outputs/test_cpp_device.hpp") - ), - "output_impl": str( - importlib.resources.files("tests").joinpath("outputs/test_cpp_device.cpp") - ), + "output_helpers": str(importlib.resources.files("tests").joinpath("outputs/test_cpp_helpers.hpp")), + "output_header": str(importlib.resources.files("tests").joinpath("outputs/test_cpp_device.hpp")), + "output_impl": str(importlib.resources.files("tests").joinpath("outputs/test_cpp_device.cpp")), }, } @@ -441,10 +378,11 @@ def test_cpp_generation_pipeline(self): # Verify files exist import os + self.assertTrue(os.path.exists(config["paths"]["output_helpers"])) self.assertTrue(os.path.exists(config["paths"]["output_header"])) self.assertTrue(os.path.exists(config["paths"]["output_impl"])) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_validation.py b/tests/test_validation.py index 13e8fd0..8bc1f14 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,17 +1,20 @@ """ Tests for validation module. """ + import unittest + import yaml + from avlos.deserializer import deserialize from avlos.validation import ( + C_RESERVED_WORDS, + ValidationError, + validate_all, validate_c_identifier, validate_endpoint_ids, validate_function_names, validate_names, - validate_all, - ValidationError, - C_RESERVED_WORDS, ) @@ -52,8 +55,21 @@ def test_invalid_c_identifier_starts_with_digit(self): def test_c_reserved_words(self): """Test that C reserved words are rejected.""" - reserved_samples = ['int', 'void', 'return', 'if', 'else', 'while', 'for', - 'struct', 'union', 'enum', 'static', 'const', '_Bool'] + reserved_samples = [ + "int", + "void", + "return", + "if", + "else", + "while", + "for", + "struct", + "union", + "enum", + "static", + "const", + "_Bool", + ] for word in reserved_samples: self.assertIn(word, C_RESERVED_WORDS) @@ -73,9 +89,7 @@ def test_valid_device_passes_all_validation(self): """Test that good_device.yaml passes all validations.""" import importlib.resources - def_path_str = str( - importlib.resources.files("tests").joinpath("definition/good_device.yaml") - ) + def_path_str = str(importlib.resources.files("tests").joinpath("definition/good_device.yaml")) with open(def_path_str) as device_desc_stream: obj = deserialize(yaml.safe_load(device_desc_stream)) @@ -86,9 +100,7 @@ def test_validate_endpoint_ids_no_conflicts(self): """Test endpoint ID validation with no conflicts.""" import importlib.resources - def_path_str = str( - importlib.resources.files("tests").joinpath("definition/good_device.yaml") - ) + def_path_str = str(importlib.resources.files("tests").joinpath("definition/good_device.yaml")) with open(def_path_str) as device_desc_stream: obj = deserialize(yaml.safe_load(device_desc_stream)) @@ -99,9 +111,7 @@ def test_validate_function_names_no_conflicts(self): """Test function name validation with no conflicts.""" import importlib.resources - def_path_str = str( - importlib.resources.files("tests").joinpath("definition/good_device.yaml") - ) + def_path_str = str(importlib.resources.files("tests").joinpath("definition/good_device.yaml")) with open(def_path_str) as device_desc_stream: obj = deserialize(yaml.safe_load(device_desc_stream)) @@ -112,9 +122,7 @@ def test_validate_names_valid(self): """Test name validation for valid device tree.""" import importlib.resources - def_path_str = str( - importlib.resources.files("tests").joinpath("definition/good_device.yaml") - ) + def_path_str = str(importlib.resources.files("tests").joinpath("definition/good_device.yaml")) with open(def_path_str) as device_desc_stream: obj = deserialize(yaml.safe_load(device_desc_stream)) @@ -245,5 +253,5 @@ def test_validate_all_collects_multiple_errors(self): self.assertTrue(len(errors) >= 4, f"Should collect multiple errors, got {len(errors)}: {errors}") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() From a923df4c326b9619fecce700ac2bdf09302ce3be Mon Sep 17 00:00:00 2001 From: Yannis Chatzikonstantinou Date: Sat, 10 Jan 2026 15:05:35 +0200 Subject: [PATCH 5/5] fix config --- .pre-commit-README.md | 57 ++++++++++++++++++++++++++++++++++++ tests/test_remote_objects.py | 3 +- 2 files changed, 59 insertions(+), 1 deletion(-) create mode 100644 .pre-commit-README.md diff --git a/.pre-commit-README.md b/.pre-commit-README.md new file mode 100644 index 0000000..4cad98e --- /dev/null +++ b/.pre-commit-README.md @@ -0,0 +1,57 @@ +# Pre-commit Hooks for Avlos + +This project uses [pre-commit](https://pre-commit.com/) to automatically format and lint code before commits. + +## Setup + +1. Install pre-commit (if not already installed): + ```bash + pip install pre-commit + ``` + +2. Install the git hooks: + ```bash + pre-commit install + ``` + +## What Gets Checked + +The pre-commit hooks run the following checks: + +- **black**: Python code formatting (line length: 127) +- **isort**: Import statement sorting +- **flake8**: Python linting (configured in setup.cfg) +- **trailing-whitespace**: Removes trailing whitespace +- **end-of-file-fixer**: Ensures files end with a newline +- **check-yaml**: Validates YAML syntax +- **check-added-large-files**: Prevents large files from being committed +- **check-merge-conflict**: Detects merge conflict markers +- **mixed-line-ending**: Ensures consistent line endings +- **rstcheck**: Validates RST documentation files + +## Manual Execution + +To run all hooks on all files (useful after initial setup or major changes): +```bash +pre-commit run --all-files +``` + +To run hooks on specific files: +```bash +pre-commit run --files path/to/file1.py path/to/file2.py +``` + +To skip hooks for a single commit (not recommended): +```bash +git commit --no-verify +``` + +## Configuration + +- Pre-commit hooks are configured in `.pre-commit-config.yaml` +- Flake8 and isort settings are in `setup.cfg` +- clang-format style is configured in `.clang-format` + +## CI Integration + +The same linting checks run in GitHub Actions CI, so passing pre-commit locally ensures CI will pass. diff --git a/tests/test_remote_objects.py b/tests/test_remote_objects.py index 0926f60..2be22d7 100644 --- a/tests/test_remote_objects.py +++ b/tests/test_remote_objects.py @@ -118,8 +118,9 @@ def test_meta_dictionary(self): def_path_str = str(importlib.resources.files("tests").joinpath("definition/good_device.yaml")) with open(def_path_str) as device_description: obj = deserialize(yaml.safe_load(device_description)) - self.assertEqual(1, len(obj.errors.meta)) + self.assertEqual(2, len(obj.errors.meta)) self.assertEqual("ok", obj.errors.meta["lalala"]) + self.assertEqual(True, obj.errors.meta["dynamic"]) self.assertEqual(1, len(obj.reset.meta)) self.assertEqual(True, obj.reset.meta["reload_data"]) self.assertEqual(0, len(obj.sn.meta))