diff --git a/plugins/module_utils/common/validators.py b/plugins/module_utils/common/validators.py new file mode 100644 index 000000000..07e9193a7 --- /dev/null +++ b/plugins/module_utils/common/validators.py @@ -0,0 +1,357 @@ +# -*- coding: utf-8 -*- + +# Copyright: (c) 2026, Akshayanat C S (@achengam) + +# GNU General Public License v3.0+ (see LICENSE or https://www.gnu.org/licenses/gpl-3.0.txt) + +"""Generic validators for common network and system fields. + +These validators are reusable across different models and modules. +They follow a consistent pattern: +- Accept str | None or other types +- Return None when input is None or empty after stripping +- Raise ValueError with descriptive messages on validation failure +""" + +from __future__ import annotations + +import re +from ipaddress import ip_address, ip_network +from typing import Callable + +# ------------------------------------------------------------------ +# Internal helpers +# ------------------------------------------------------------------ + + +def _normalize_optional_string(v: str | None) -> str | None: + """Normalize an optional string value. + + Converts the value to string, strips whitespace, and returns ``None`` + if the value is ``None`` or empty after stripping. + + This is a common preprocessing step for string-based validators. + + Args: + v: Raw value that may be ``None``, empty, or contain whitespace. + + Returns: + Stripped non-empty string, or ``None``. + """ + if v is None: + return None + v = str(v).strip() + if not v: + return None + return v + + +def _require_field( + v: str, + validator_func: Callable[[str | None], str | None], + field_name: str, +) -> str: + """Validate and require a non-empty field value. + + Generic wrapper that calls a nullable validator and raises + ``ValueError`` when the result is ``None`` (empty after stripping). + + Args: + v: Raw field value from Pydantic. + validator_func: Validator function that accepts optional string. + field_name: Field name used in the error message. + + Returns: + Validated non-empty string. + + Raises: + ValueError: When the value is empty or fails validation. + """ + result = validator_func(v) + if result is None: + raise ValueError(f"{field_name} cannot be empty") + return result + + +# ------------------------------------------------------------------ +# IP Address and Network Validators +# ------------------------------------------------------------------ + + +def validate_ip_address(v: str | None) -> str | None: + """Validate IPv4 or IPv6 address. + + Args: + v: Raw IP address value. + + Returns: + Validated IP address string, or ``None`` if input is None/empty. + + Raises: + ValueError: When the value is not a valid IPv4/v6 address. + """ + v = _normalize_optional_string(v) + if v is None: + return None + try: + ip_address(v) + return v + except ValueError: + raise ValueError(f"Invalid IP address format: {v}") + + +def validate_cidr(v: str | None) -> str | None: + """Validate CIDR notation (IP/mask). + + Args: + v: Raw CIDR value. + + Returns: + Validated CIDR string, or ``None`` if input is None/empty. + + Raises: + ValueError: When the value is not valid CIDR notation. + """ + v = _normalize_optional_string(v) + if v is None: + return None + if "/" not in v: + raise ValueError(f"CIDR notation required (IP/mask format): {v}") + try: + ip_network(v, strict=False) + return v + except ValueError: + raise ValueError(f"Invalid CIDR format: {v}") + + +def validate_ip_or_cidr_as_cidr(v: str | None) -> str | None: + """Validate IP or CIDR and normalize to CIDR notation. + + Accepts either a plain IP address or CIDR notation. + Plain IPs are normalized to /32 (IPv4) or /128 (IPv6). + + Args: + v: Raw IP or CIDR value. + + Returns: + Normalized CIDR string (e.g., "10.1.1.1/32"), or ``None``. + + Raises: + ValueError: When the value is not a valid IP or CIDR. + """ + v = _normalize_optional_string(v) + if v is None: + return None + + # If already in CIDR notation, validate it + if "/" in v: + try: + network = ip_network(v, strict=False) + return str(network) + except ValueError: + raise ValueError(f"Invalid CIDR format: {v}") + + # Plain IP - validate and append appropriate mask + try: + ip_obj = ip_address(v) + # IPv4 gets /32, IPv6 gets /128 + prefix_len = 32 if ip_obj.version == 4 else 128 + return f"{v}/{prefix_len}" + except ValueError: + raise ValueError(f"Invalid IP address format: {v}") + + +def require_ip_address(v: str) -> str: + """Validate and require a non-empty IP address. + + Args: + v: Raw IP address value. + + Returns: + Validated IP address string. + + Raises: + ValueError: When the value is empty or not a valid IPv4/v6 address. + """ + return _require_field(v, validate_ip_address, "IP address") + + +def require_ip_or_cidr_as_cidr(v: str) -> str: + """Validate and require a non-empty IP or CIDR, normalized to CIDR. + + Args: + v: Raw IP or CIDR value. + + Returns: + Validated CIDR string (plain IPs converted to /32 or /128). + + Raises: + ValueError: When the value is empty or not a valid IP/CIDR. + """ + return _require_field(v, validate_ip_or_cidr_as_cidr, "IP or CIDR") + + +def validate_cidr_optional(v: str | None) -> str | None: + """Validate an optional CIDR string; pass through ``None`` unchanged. + + Args: + v: Raw CIDR value or ``None``. + + Returns: + Validated CIDR string, or ``None``. + + Raises: + ValueError: When the value is present but not valid CIDR notation. + """ + if v is None: + return None + return _require_field(v, validate_cidr, "CIDR") + + +# ------------------------------------------------------------------ +# Hostname and MAC Address Validators +# ------------------------------------------------------------------ + + +def validate_hostname(v: str | None) -> str | None: + """Validate hostname format (RFC 1123). + + Args: + v: Raw hostname value. + + Returns: + Validated hostname string, or ``None`` if input is None/empty. + + Raises: + ValueError: When the value fails RFC 1123 checks. + """ + v = _normalize_optional_string(v) + if v is None: + return None + # RFC 1123 hostname validation + if len(v) > 255: + raise ValueError("Hostname cannot exceed 255 characters") + # Allow alphanumeric, dots, hyphens, underscores + # Must start with alphanumeric, cannot end with dot or contain consecutive dots + if not re.match(r"^[a-zA-Z0-9][a-zA-Z0-9._-]*$", v): + raise ValueError(f"Invalid hostname format. Must start with alphanumeric and contain only alphanumeric, dots, hyphens, underscores: {v}") + if v.endswith(".") or ".." in v: + raise ValueError(f"Invalid hostname format. Cannot end with dot or contain consecutive dots: {v}") + return v + + +def require_hostname(v: str) -> str: + """Validate and require a non-empty hostname. + + Args: + v: Raw hostname value. + + Returns: + Validated hostname string. + + Raises: + ValueError: When the value is empty or fails RFC 1123 checks. + """ + return _require_field(v, validate_hostname, "hostname") + + +def validate_mac_address(v: str | None) -> str | None: + """Validate MAC address format and normalize. + + Accepts multiple common MAC address formats: + - Colon-separated: AA:BB:CC:DD:EE:FF + - Hyphen-separated: AA-BB-CC-DD-EE-FF + - Cisco dot-notation: aabb.ccdd.eeff + - Bare hex: aabbccddeeff + + Returns normalized uppercase colon-separated format: AA:BB:CC:DD:EE:FF + + Args: + v: Raw MAC address value. + + Returns: + Normalized MAC address (AA:BB:CC:DD:EE:FF), or ``None`` if input is None/empty. + + Raises: + ValueError: When the value is not a valid MAC address. + """ + v = _normalize_optional_string(v) + if v is None: + return None + + # Remove common separators and convert to lowercase + clean_mac = v.lower().replace(":", "").replace("-", "").replace(".", "") + + # Validate it's exactly 12 hex characters + if not re.match(r"^[0-9a-f]{12}$", clean_mac): + raise ValueError( + f"Invalid MAC address format: {v}. " + f"Expected 12 hex digits in formats like AA:BB:CC:DD:EE:FF, " + f"AA-BB-CC-DD-EE-FF, aabb.ccdd.eeff, or aabbccddeeff" + ) + + # Normalize to colon-separated format (AA:BB:CC:DD:EE:FF) + normalized = ":".join(clean_mac[i : i + 2] for i in range(0, 12, 2)) + return normalized.upper() + + +def require_mac_address(v: str) -> str: + """Validate and require a non-empty MAC address. + + Args: + v: Raw MAC address value. + + Returns: + Normalized MAC address string (AA:BB:CC:DD:EE:FF). + + Raises: + ValueError: When the value is empty or not a valid MAC address. + """ + return _require_field(v, validate_mac_address, "MAC address") + + +# ------------------------------------------------------------------ +# Credential Pair Validators +# ------------------------------------------------------------------ + + +def check_credentials_pair( + username: str | None, + password: str | None, + username_field: str = "username", + password_field: str = "password", +) -> None: + """Enforce mutual-presence of credential pairs. + + Both username and password must either be absent together or present together. + + Args: + username: Username value (may be ``None``). + password: Password value (may be ``None``). + username_field: Field name for username (for error messages). + password_field: Field name for password (for error messages). + + Raises: + ValueError: When exactly one of the two is provided. + """ + has_user = bool(username) + has_pass = bool(password) + if has_user and not has_pass: + raise ValueError(f"{password_field} must be set when {username_field} is specified") + if has_pass and not has_user: + raise ValueError(f"{username_field} must be set when {password_field} is specified") + + +__all__ = [ + "validate_ip_address", + "validate_cidr", + "validate_ip_or_cidr_as_cidr", + "validate_hostname", + "validate_mac_address", + "require_ip_address", + "require_ip_or_cidr_as_cidr", + "require_hostname", + "require_mac_address", + "validate_cidr_optional", + "check_credentials_pair", +] diff --git a/plugins/module_utils/endpoints/mixins.py b/plugins/module_utils/endpoints/mixins.py index e7f0620c9..940f09043 100644 --- a/plugins/module_utils/endpoints/mixins.py +++ b/plugins/module_utils/endpoints/mixins.py @@ -32,6 +32,12 @@ class FabricNameMixin(BaseModel): fabric_name: Optional[str] = Field(default=None, min_length=1, max_length=64, description="Fabric name") +class FilterMixin(BaseModel): + """Mixin for endpoints that require a Lucene filter expression.""" + + filter: Optional[str] = Field(default=None, min_length=1, description="Lucene filter expression") + + class ForceShowRunMixin(BaseModel): """Mixin for endpoints that require force_show_run parameter.""" @@ -62,16 +68,22 @@ class LoginIdMixin(BaseModel): login_id: Optional[str] = Field(default=None, min_length=1, description="Login ID") +class MaxMixin(BaseModel): + """Mixin for endpoints that require a max results parameter.""" + + max: Optional[int] = Field(default=None, ge=1, description="Maximum number of results") + + class NetworkNameMixin(BaseModel): """Mixin for endpoints that require network_name parameter.""" network_name: Optional[str] = Field(default=None, min_length=1, max_length=64, description="Network name") -class NodeNameMixin(BaseModel): - """Mixin for endpoints that require node_name parameter.""" +class OffsetMixin(BaseModel): + """Mixin for endpoints that require a pagination offset parameter.""" - node_name: Optional[str] = Field(default=None, min_length=1, description="Node name") + offset: Optional[int] = Field(default=None, ge=0, description="Pagination offset") class SwitchSerialNumberMixin(BaseModel): @@ -80,7 +92,19 @@ class SwitchSerialNumberMixin(BaseModel): switch_sn: Optional[str] = Field(default=None, min_length=1, description="Switch serial number") +class TicketIdMixin(BaseModel): + """Mixin for endpoints that require ticket_id parameter.""" + + ticket_id: Optional[str] = Field(default=None, min_length=1, description="Change control ticket ID") + + class VrfNameMixin(BaseModel): """Mixin for endpoints that require vrf_name parameter.""" vrf_name: Optional[str] = Field(default=None, min_length=1, max_length=64, description="VRF name") + + +class NodeNameMixin(BaseModel): + """Mixin for endpoints that require node_name parameter.""" + + node_name: Optional[str] = Field(default=None, min_length=1, description="Node name") diff --git a/plugins/module_utils/endpoints/v1/manage/manage_fabrics_switches.py b/plugins/module_utils/endpoints/v1/manage/manage_fabrics_switches.py new file mode 100644 index 000000000..de65b02dc --- /dev/null +++ b/plugins/module_utils/endpoints/v1/manage/manage_fabrics_switches.py @@ -0,0 +1,183 @@ +# -*- coding: utf-8 -*- + +# Copyright: (c) 2026, Akshayanat C S (@achengam) + +# GNU General Public License v3.0+ (see LICENSE or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +ND Manage Fabric Switches endpoint models. + +This module contains endpoint definitions for switch CRUD operations +within fabrics in the ND Manage API. + +Endpoints covered: +- List switches in a fabric +- Add switches to a fabric +""" + +from __future__ import annotations + +__author__ = "Akshayanat C S" + +from typing import Literal + +from ansible_collections.cisco.nd.plugins.module_utils.enums import HttpVerbEnum +from ansible_collections.cisco.nd.plugins.module_utils.endpoints.mixins import ( + ClusterNameMixin, + FabricNameMixin, + FilterMixin, + MaxMixin, + OffsetMixin, + TicketIdMixin, +) +from ansible_collections.cisco.nd.plugins.module_utils.endpoints.query_params import ( + EndpointQueryParams, +) +from ansible_collections.cisco.nd.plugins.module_utils.endpoints.v1.manage.base_path import ( + BasePath, +) +from ansible_collections.cisco.nd.plugins.module_utils.common.pydantic_compat import ( + Field, +) +from ansible_collections.cisco.nd.plugins.module_utils.endpoints.base import ( + NDEndpointBaseModel, +) + + +class FabricSwitchesGetEndpointParams(FilterMixin, MaxMixin, OffsetMixin, EndpointQueryParams): + """ + # Summary + + Endpoint-specific query parameters for list fabric switches endpoint. + + ## Parameters + + - hostname: Filter by switch hostname (optional) + - max: Maximum number of results (optional, from `MaxMixin`) + - offset: Pagination offset (optional, from `OffsetMixin`) + - filter: Lucene filter expression (optional, from `FilterMixin`) + + ## Usage + + ```python + params = FabricSwitchesGetEndpointParams(hostname="leaf1", max=100) + query_string = params.to_query_string() + # Returns: "hostname=leaf1&max=100" + ``` + """ + + hostname: str | None = Field(default=None, min_length=1, description="Filter by switch hostname") + + +class FabricSwitchesAddEndpointParams(ClusterNameMixin, TicketIdMixin, EndpointQueryParams): + """ + # Summary + + Endpoint-specific query parameters for add switches to fabric endpoint. + + ## Parameters + + - cluster_name: Target cluster name for multi-cluster deployments (optional, from `ClusterNameMixin`) + - ticket_id: Change control ticket ID (optional, from `TicketIdMixin`) + + ## Usage + + ```python + params = FabricSwitchesAddEndpointParams(cluster_name="cluster1", ticket_id="CHG12345") + query_string = params.to_query_string() + # Returns: "clusterName=cluster1&ticketId=CHG12345" + ``` + """ + + +class _EpManageFabricsSwitchesBase(FabricNameMixin, NDEndpointBaseModel): + """ + Base class for Fabric Switches endpoints. + + Provides common functionality for all HTTP methods on the + /api/v1/manage/fabrics/{fabricName}/switches endpoint. + """ + + @property + def _base_path(self) -> str: + """Build the base endpoint path.""" + if self.fabric_name is None: + raise ValueError("fabric_name must be set before accessing path") + return BasePath.path("fabrics", self.fabric_name, "switches") + + +class EpManageFabricsSwitchesGet(_EpManageFabricsSwitchesBase): + """ + # Summary + + List Fabric Switches Endpoint + + ## Description + + Endpoint to list all switches in a specific fabric with optional filtering. + + ## Path + + - /api/v1/manage/fabrics/{fabricName}/switches + - /api/v1/manage/fabrics/{fabricName}/switches?hostname=leaf1&max=100 + + ## Verb + + - GET + + ## Query Parameters + + - hostname: Filter by switch hostname (optional) + - max: Maximum number of results (optional) + - offset: Pagination offset (optional) + - filter: Lucene filter expression (optional) + + ## Usage + + ```python + # List all switches + request = EpManageFabricsSwitchesGet() + request.fabric_name = "MyFabric" + path = request.path + verb = request.verb + + # List with filtering + request = EpManageFabricsSwitchesGet() + request.fabric_name = "MyFabric" + request.endpoint_params.hostname = "leaf1" + request.endpoint_params.max = 100 + path = request.path + verb = request.verb + # Path will be: /api/v1/manage/fabrics/MyFabric/switches?hostname=leaf1&max=100 + ``` + """ + + class_name: Literal["EpManageFabricsSwitchesGet"] = Field( + default="EpManageFabricsSwitchesGet", + frozen=True, + description="Class name for backward compatibility", + ) + endpoint_params: FabricSwitchesGetEndpointParams = Field( + default_factory=FabricSwitchesGetEndpointParams, + description="Endpoint-specific query parameters", + ) + + @property + def path(self) -> str: + """ + # Summary + + Build the endpoint path with optional query string. + + ## Returns + + - Complete endpoint path string, optionally including query parameters + """ + query_string = self.endpoint_params.to_query_string() + if query_string: + return f"{self._base_path}?{query_string}" + return self._base_path + + @property + def verb(self) -> HttpVerbEnum: + """Return the HTTP verb for this endpoint.""" + return HttpVerbEnum.GET diff --git a/plugins/module_utils/fabric_inventory.py b/plugins/module_utils/fabric_inventory.py new file mode 100644 index 000000000..be8d21e7a --- /dev/null +++ b/plugins/module_utils/fabric_inventory.py @@ -0,0 +1,174 @@ +# Copyright: (c) 2026, Jeet Ram (@jeeram) +# Copyright: (c) 2026, Akshayanat C S (@achengam) + +# GNU General Public License v3.0+ (see LICENSE or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import annotations + +import logging +from typing import Any + +from ansible_collections.cisco.nd.plugins.module_utils.endpoints.v1.manage.manage_fabrics_switches import ( + EpManageFabricsSwitchesGet, +) +from ansible_collections.cisco.nd.plugins.module_utils.nd_config_collection import ( + NDConfigCollection, +) + +# ========================================================================= +# Exceptions +# ========================================================================= + + +class SwitchOperationError(Exception): + """Raised when a switch operation fails.""" + + +# ========================================================================= +# API Response Validation +# ========================================================================= + + +class ApiDataChecker: + """Detect controller-embedded errors in API response DATA payloads. + + The Nexus Dashboard API signals certain errors by embedding an error + object inside ``DATA`` as ``{"code": , "message": ""}`` even + when the transport-level result is marked successful. Any payload dict + that contains a ``"code"`` key is treated as an error; the absence of + ``"code"`` means the payload is a genuine data body. + """ + + @staticmethod + def check( + data: Any, + context: str, + log: logging.Logger, + fail_callback=None, + ) -> None: + """Fail or raise if the response DATA contains an embedded error code. + + Args: + data: Value returned by ``nd.request()`` or extracted from + ``response_current["DATA"]``. + context: Human-readable description of the operation. + log: Logger instance. + fail_callback: Optional callable (e.g. ``module.fail_json``) that + accepts a ``msg`` keyword argument. When provided + it is called on error instead of raising + ``SwitchOperationError``. + """ + if isinstance(data, dict) and "code" in data: + error_msg = data.get("message", "Unknown error") + msg = f"{context} failed \u2014 controller returned error: " f"{error_msg} (code={data['code']})" + log.error(msg) + if fail_callback is not None: + fail_callback(msg=msg) + else: + raise SwitchOperationError(msg) + + +# ========================================================================= +# Fabric Switch Inventory +# ========================================================================= + + +class FabricSwitchInventory: + """Index a list of switch model instances for fast lookup by IP or ID. + + Use :meth:`from_fabric` to fetch, parse, and index in a single call, or + construct directly from an already-parsed list. :meth:`by_ip` and + :meth:`by_id` return keyed lookup dicts. + + Example:: + + inventory = FabricSwitchInventory.from_fabric(nd, fabric, log, SwitchDataModel) + switch = inventory.by_ip().get("192.0.2.1") + switch = inventory.by_id().get("FDO123456AB") + collection = inventory.collection # NDConfigCollection + """ + + def __init__(self, switches: list) -> None: + """Initialise the index from an already-parsed list of switch models. + + Args: + switches: List of parsed switch model instances. + """ + self.switches: list = switches + self.collection: NDConfigCollection | None = None + + @classmethod + def from_fabric(cls, nd, fabric: str, log: logging.Logger, model_class: type) -> "FabricSwitchInventory": + """Fetch, parse, and index the switch inventory for a fabric in one call. + + Args: + nd: NDModule instance used for the API request. + fabric: Fabric name to query. + log: Logger instance. + model_class: Pydantic model class to parse switch entries into + (e.g. ``SwitchDataModel``). + + Returns: + A new ``FabricSwitchInventory`` with ``switches`` and + ``collection`` populated. + """ + raw = cls.query_fabric_switches(nd, fabric, log) + collection = NDConfigCollection.from_api_response(response_data=raw, model_class=model_class) + instance = cls(list(collection)) + instance.collection = collection + return instance + + def by_ip(self) -> dict[str, Any]: + """Return switches keyed by fabric management IP address. + + Returns: + Dict mapping ``fabric_management_ip`` → model instance. + Entries with an empty or ``None`` IP are excluded. + """ + return {sw.fabric_management_ip: sw for sw in self.switches if sw.fabric_management_ip} + + def by_id(self) -> dict[str, Any]: + """Return switches keyed by switch ID (serial number). + + Returns: + Dict mapping ``switch_id`` → model instance. + Entries with an empty or ``None`` ID are excluded. + """ + return {sw.switch_id: sw for sw in self.switches if sw.switch_id} + + @staticmethod + def query_fabric_switches(nd, fabric: str, log: logging.Logger) -> list[dict[str, Any]]: + """Fetch the raw switch inventory list for a fabric from the controller. + + Args: + nd: NDModule instance used for the API request. + fabric: Fabric name to query. + log: Logger instance. + + Returns: + List of raw switch dicts as returned by the controller API. + """ + endpoint = EpManageFabricsSwitchesGet() + endpoint.fabric_name = fabric + log.debug("query_fabric_switches: querying inventory for fabric '%s'", fabric) + + try: + response = nd.request(path=endpoint.path, verb=endpoint.verb) + except Exception as exc: + msg = f"Failed to retrieve switch inventory for fabric '{fabric}': {exc}" + log.error(msg) + nd.module.fail_json(msg=msg) + return [] + + ApiDataChecker.check( + response, + f"Switch inventory query for fabric '{fabric}'", + log, + nd.module.fail_json, + ) + + if isinstance(response, list): + return response + if isinstance(response, dict): + return response.get("switches", []) + return [] diff --git a/plugins/module_utils/models/manage_switches/__init__.py b/plugins/module_utils/models/manage_switches/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/plugins/module_utils/models/manage_switches/enums.py b/plugins/module_utils/models/manage_switches/enums.py new file mode 100644 index 000000000..949b2e8c6 --- /dev/null +++ b/plugins/module_utils/models/manage_switches/enums.py @@ -0,0 +1,258 @@ +# -*- coding: utf-8 -*- + +# Copyright: (c) 2026, Akshayanat C S (@achengam) + +# GNU General Public License v3.0+ (see LICENSE or https://www.gnu.org/licenses/gpl-3.0.txt) + +"""Enumerations for Switch and Inventory Operations. + +Extracted from OpenAPI schema (manage.json) for Nexus Dashboard Manage APIs v1.1.332. +""" + +from __future__ import annotations + +from enum import Enum + +# ============================================================================= +# ENUMS - Extracted from OpenAPI Schema components/schemas +# ============================================================================= + + +class SwitchRole(str, Enum): + """ + Switch role enumeration. + + Based on: components/schemas/switchRole + Description: The role of the switch, meta is a read-only switch role + """ + + BORDER = "border" + BORDER_GATEWAY = "borderGateway" + BORDER_GATEWAY_SPINE = "borderGatewaySpine" + BORDER_GATEWAY_SUPER_SPINE = "borderGatewaySuperSpine" + BORDER_SPINE = "borderSpine" + BORDER_SUPER_SPINE = "borderSuperSpine" + LEAF = "leaf" + SPINE = "spine" + SUPER_SPINE = "superSpine" + TIER2_LEAF = "tier2Leaf" + TOR = "tor" + ACCESS = "access" + AGGREGATION = "aggregation" + CORE_ROUTER = "coreRouter" + EDGE_ROUTER = "edgeRouter" + META = "meta" # read-only + NEIGHBOR = "neighbor" + + @classmethod + def choices(cls) -> list[str]: + """Return list of valid choices.""" + return [e.value for e in cls] + + +class SystemMode(str, Enum): + """ + System mode enumeration. + + Based on: components/schemas/systemMode + """ + + NORMAL = "normal" + MAINTENANCE = "maintenance" + MIGRATION = "migration" + INCONSISTENT = "inconsistent" + WAITING = "waiting" + NOT_APPLICABLE = "notApplicable" + + @classmethod + def choices(cls) -> list[str]: + return [e.value for e in cls] + + +class PlatformType(str, Enum): + """ + Switch platform type enumeration. + + Used for POST /fabrics/{fabricName}/switches (AddSwitches). + Includes all platform types supported by the add-switches endpoint. + Based on: components/schemas + """ + + NX_OS = "nx-os" + OTHER = "other" + IOS_XE = "ios-xe" + IOS_XR = "ios-xr" + SONIC = "sonic" + APIC = "apic" + + @classmethod + def choices(cls) -> list[str]: + return [e.value for e in cls] + + +class SnmpV3AuthProtocol(str, Enum): + """ + SNMPv3 authentication protocols. + + Based on: components/schemas/snmpV3AuthProtocol and schemas-snmpV3AuthProtocol + """ + + MD5 = "md5" + SHA = "sha" + MD5_DES = "md5-des" + MD5_AES = "md5-aes" + SHA_AES = "sha-aes" + SHA_DES = "sha-des" + SHA_AES_256 = "sha-aes-256" + SHA_224 = "sha-224" + SHA_224_AES = "sha-224-aes" + SHA_224_AES_256 = "sha-224-aes-256" + SHA_256 = "sha-256" + SHA_256_AES = "sha-256-aes" + SHA_256_AES_256 = "sha-256-aes-256" + SHA_384 = "sha-384" + SHA_384_AES = "sha-384-aes" + SHA_384_AES_256 = "sha-384-aes-256" + SHA_512 = "sha-512" + SHA_512_AES = "sha-512-aes" + SHA_512_AES_256 = "sha-512-aes-256" + + @classmethod + def choices(cls) -> list[str]: + return [e.value for e in cls] + + +class DiscoveryStatus(str, Enum): + """ + Switch discovery status. + + Based on: components/schemas/additionalSwitchData.discoveryStatus + """ + + OK = "ok" + DISCOVERING = "discovering" + REDISCOVERING = "rediscovering" + DEVICE_SHUTTING_DOWN = "deviceShuttingDown" + UNREACHABLE = "unreachable" + IP_ADDRESS_CHANGE = "ipAddressChange" + DISCOVERY_TIMEOUT = "discoveryTimeout" + RETRYING = "retrying" + SSH_SESSION_ERROR = "sshSessionError" + TIMEOUT = "timeout" + UNKNOWN_USER_PASSWORD = "unknownUserPassword" + CONNECTION_ERROR = "connectionError" + NOT_APPLICABLE = "notApplicable" + + @classmethod + def choices(cls) -> list[str]: + return [e.value for e in cls] + + +class ConfigSyncStatus(str, Enum): + """ + Configuration sync status. + + Based on: components/schemas/switchConfigSyncStatus + """ + + DEPLOYED = "deployed" + DEPLOYMENT_IN_PROGRESS = "deploymentInProgress" + FAILED = "failed" + IN_PROGRESS = "inProgress" + IN_SYNC = "inSync" + NOT_APPLICABLE = "notApplicable" + OUT_OF_SYNC = "outOfSync" + PENDING = "pending" + PREVIEW_IN_PROGRESS = "previewInProgress" + SUCCESS = "success" + + @classmethod + def choices(cls) -> list[str]: + return [e.value for e in cls] + + +class VpcRole(str, Enum): + """ + VPC role enumeration. + + Based on: components/schemas/schemas-vpcRole + """ + + PRIMARY = "primary" + SECONDARY = "secondary" + OPERATIONAL_PRIMARY = "operationalPrimary" + OPERATIONAL_SECONDARY = "operationalSecondary" + NONE_ESTABLISHED = "noneEstablished" + + @classmethod + def choices(cls) -> list[str]: + return [e.value for e in cls] + + +class RemoteCredentialStore(str, Enum): + """ + Remote credential store type. + + Based on: components/schemas/remoteCredentialStore + """ + + LOCAL = "local" + CYBERARK = "cyberark" + + @classmethod + def choices(cls) -> list[str]: + return [e.value for e in cls] + + +class AnomalyLevel(str, Enum): + """ + Anomaly level classification. + + Based on: components/schemas/anomalyLevel + """ + + CRITICAL = "critical" + MAJOR = "major" + MINOR = "minor" + WARNING = "warning" + HEALTHY = "healthy" + NOT_APPLICABLE = "notApplicable" + UNKNOWN = "unknown" + + @classmethod + def choices(cls) -> list[str]: + return [e.value for e in cls] + + +class AdvisoryLevel(str, Enum): + """ + Advisory level classification. + + Based on: components/schemas/advisoryLevel + """ + + CRITICAL = "critical" + MAJOR = "major" + MINOR = "minor" + WARNING = "warning" + HEALTHY = "healthy" + NONE = "none" + NOT_APPLICABLE = "notApplicable" + + @classmethod + def choices(cls) -> list[str]: + return [e.value for e in cls] + + +__all__ = [ + "SwitchRole", + "SystemMode", + "PlatformType", + "SnmpV3AuthProtocol", + "DiscoveryStatus", + "ConfigSyncStatus", + "VpcRole", + "RemoteCredentialStore", + "AnomalyLevel", + "AdvisoryLevel", +] diff --git a/plugins/module_utils/models/manage_switches/switch_data_models.py b/plugins/module_utils/models/manage_switches/switch_data_models.py new file mode 100644 index 000000000..5edf50aa2 --- /dev/null +++ b/plugins/module_utils/models/manage_switches/switch_data_models.py @@ -0,0 +1,338 @@ +# -*- coding: utf-8 -*- + +# Copyright: (c) 2026, Akshayanat C S (@achengam) + +# GNU General Public License v3.0+ (see LICENSE or https://www.gnu.org/licenses/gpl-3.0.txt) + +"""Switch inventory data models (API response representations). + +Based on OpenAPI schema for Nexus Dashboard Manage APIs v1.1.332. +""" + +from __future__ import annotations + +from typing import Any, ClassVar, Literal + +from ansible_collections.cisco.nd.plugins.module_utils.common.pydantic_compat import ( + Field, + field_validator, +) +from ansible_collections.cisco.nd.plugins.module_utils.models.base import NDBaseModel +from ansible_collections.cisco.nd.plugins.module_utils.models.nested import ( + NDNestedModel, +) + +from ansible_collections.cisco.nd.plugins.module_utils.models.manage_switches.enums import ( + AdvisoryLevel, + AnomalyLevel, + ConfigSyncStatus, + DiscoveryStatus, + PlatformType, + RemoteCredentialStore, + SwitchRole, + SystemMode, + VpcRole, +) +from .validators import require_serial_number, validate_ip_address + + +class TelemetryIpCollection(NDNestedModel): + """ + Inband and out-of-band telemetry IP addresses for a switch. + """ + + identifiers: ClassVar[list[str]] = [] + inband_ipv4_address: str | None = Field(default=None, alias="inbandIpV4Address", description="Inband IPv4 address") + inband_ipv6_address: str | None = Field(default=None, alias="inbandIpV6Address", description="Inband IPv6 address") + out_of_band_ipv4_address: str | None = Field( + default=None, + alias="outOfBandIpV4Address", + description="Out of band IPv4 address", + ) + out_of_band_ipv6_address: str | None = Field( + default=None, + alias="outOfBandIpV6Address", + description="Out of band IPv6 address", + ) + + +class VpcData(NDNestedModel): + """ + vPC pair configuration and operational status for a switch. + """ + + identifiers: ClassVar[list[str]] = [] + vpc_domain: int = Field(alias="vpcDomain", ge=1, le=1000, description="vPC domain ID") + peer_switch_id: str = Field(alias="peerSwitchId", description="vPC peer switch serial number") + consistent_status: bool | None = Field( + default=None, + alias="consistentStatus", + description="Flag to indicate the vPC status is consistent", + ) + intended_peer_name: str | None = Field( + default=None, + alias="intendedPeerName", + description="Intended vPC host name for pre-provisioned peer switch", + ) + keep_alive_status: str | None = Field(default=None, alias="keepAliveStatus", description="vPC peer keep alive status") + peer_link_status: str | None = Field(default=None, alias="peerLinkStatus", description="vPC peer link status") + peer_name: str | None = Field(default=None, alias="peerName", description="vPC peer switch name") + vpc_role: VpcRole | None = Field(default=None, alias="vpcRole", description="The vPC role") + + @field_validator("peer_switch_id", mode="before") + @classmethod + def validate_peer_serial(cls, v: str) -> str: + return require_serial_number(v, "peer_switch_id") + + +class SwitchMetadata(NDNestedModel): + """ + Internal database identifiers associated with a switch record. + """ + + identifiers: ClassVar[list[str]] = [] + switch_db_id: int | None = Field(default=None, alias="switchDbId", description="Database Id of the switch") + switch_uuid: str | None = Field(default=None, alias="switchUuid", description="Internal unique Id of the switch") + + +class AdditionalSwitchData(NDNestedModel): + """ + Platform-specific additional data for NX-OS switches. + """ + + identifiers: ClassVar[list[str]] = [] + usage: str | None = Field(default="others", description="The usage of additional data") + config_sync_status: ConfigSyncStatus | None = Field(default=None, alias="configSyncStatus", description="Configuration sync status") + discovery_status: DiscoveryStatus | None = Field(default=None, alias="discoveryStatus", description="Discovery status") + domain_name: str | None = Field(default=None, alias="domainName", description="Domain name") + smart_switch: bool | None = Field( + default=None, + alias="smartSwitch", + description="Flag that indicates if the switch is equipped with DPUs or not", + ) + hypershield_connectivity_status: str | None = Field( + default=None, + alias="hypershieldConnectivityStatus", + description="Smart switch connectivity status to hypershield controller", + ) + hypershield_tenant: str | None = Field(default=None, alias="hypershieldTenant", description="Hypershield tenant name") + hypershield_integration_name: str | None = Field( + default=None, + alias="hypershieldIntegrationName", + description="Hypershield Integration Id", + ) + source_interface_name: str | None = Field( + default=None, + alias="sourceInterfaceName", + description="Source interface for switch discovery", + ) + source_vrf_name: str | None = Field( + default=None, + alias="sourceVrfName", + description="Source VRF for switch discovery", + ) + platform_type: PlatformType | None = Field(default=None, alias="platformType", description="Platform type of the switch") + discovered_system_mode: SystemMode | None = Field(default=None, alias="discoveredSystemMode", description="Discovered system mode") + intended_system_mode: SystemMode | None = Field(default=None, alias="intendedSystemMode", description="Intended system mode") + scalable_unit: str | None = Field(default=None, alias="scalableUnit", description="Name of the scalable unit") + system_mode: SystemMode | None = Field(default=None, alias="systemMode", description="System mode") + vendor: str | None = Field(default=None, description="Vendor of the switch") + username: str | None = Field(default=None, description="Discovery user name") + remote_credential_store: RemoteCredentialStore | None = Field(default=None, alias="remoteCredentialStore") + meta: SwitchMetadata | None = Field(default=None, description="Switch metadata") + + +class AdditionalAciSwitchData(NDNestedModel): + """ + Platform-specific additional data for ACI leaf and spine switches. + """ + + identifiers: ClassVar[list[str]] = [] + usage: str | None = Field(default="aci", description="The usage of additional data") + admin_status: Literal["inService", "outOfService"] | None = Field(default=None, alias="adminStatus", description="Admin status") + health_score: int | None = Field( + default=None, + alias="healthScore", + ge=1, + le=100, + description="Switch health score", + ) + last_reload_time: str | None = Field( + default=None, + alias="lastReloadTime", + description="Timestamp when the system is last reloaded", + ) + last_software_update_time: str | None = Field( + default=None, + alias="lastSoftwareUpdateTime", + description="Timestamp when the software is last updated", + ) + node_id: int | None = Field(default=None, alias="nodeId", ge=1, description="Node ID") + node_status: Literal["active", "inActive"] | None = Field(default=None, alias="nodeStatus", description="Node status") + pod_id: int | None = Field(default=None, alias="podId", ge=1, description="Pod ID") + remote_leaf_group_name: str | None = Field(default=None, alias="remoteLeafGroupName", description="Remote leaf group name") + switch_added: str | None = Field( + default=None, + alias="switchAdded", + description="Timestamp when the switch is added", + ) + tep_pool: str | None = Field(default=None, alias="tepPool", description="TEP IP pool") + + +class Metadata(NDNestedModel): + """ + Pagination and result-count metadata from a list API response. + """ + + identifiers: ClassVar[list[str]] = [] + + counts: dict[str, int] | None = Field(default=None, description="Count information including total and remaining") + + +class SwitchDataModel(NDBaseModel): + """ + Inventory record for a single switch as returned by the fabric switches API. + + Path: GET /fabrics/{fabricName}/switches + """ + + identifiers: ClassVar[list[str]] = ["switch_id"] + identifier_strategy: ClassVar[Literal["single", "composite", "hierarchical", "singleton"] | None] = "single" + exclude_from_diff: ClassVar[set] = {"system_up_time", "anomaly_level", "advisory_level", "alert_suspend"} + switch_id: str = Field( + alias="switchId", + description="Serial number of Switch or Node Id of ACI switch", + ) + serial_number: str | None = Field( + default=None, + alias="serialNumber", + description="Serial number of switch or APIC controller node", + ) + additional_data: AdditionalSwitchData | AdditionalAciSwitchData | None = Field(default=None, alias="additionalData", description="Additional switch data") + advisory_level: AdvisoryLevel | None = Field(default=None, alias="advisoryLevel") + anomaly_level: AnomalyLevel | None = Field(default=None, alias="anomalyLevel") + alert_suspend: str | None = Field(default=None, alias="alertSuspend") + fabric_management_ip: str | None = Field( + default=None, + alias="fabricManagementIp", + description="Switch IPv4/v6 address used for management", + ) + fabric_name: str | None = Field(default=None, alias="fabricName", description="Fabric name", max_length=64) + fabric_type: str | None = Field(default=None, alias="fabricType", description="Fabric type") + hostname: str | None = Field(default=None, description="Switch host name") + model: str | None = Field(default=None, description="Model of switch or APIC controller node") + software_version: str | None = Field( + default=None, + alias="softwareVersion", + description="Software version of switch or APIC controller node", + ) + switch_role: SwitchRole | None = Field(default=None, alias="switchRole") + system_up_time: str | None = Field(default=None, alias="systemUpTime", description="System up time") + vpc_configured: bool | None = Field( + default=None, + alias="vpcConfigured", + description="Flag to indicate switch is part of a vPC domain", + ) + vpc_data: VpcData | None = Field(default=None, alias="vpcData") + telemetry_ip_collection: TelemetryIpCollection | None = Field(default=None, alias="telemetryIpCollection") + + @field_validator("additional_data", mode="before") + @classmethod + def parse_additional_data(cls, v: Any) -> Any: + """Route additionalData to the correct nested model. + + The ND API may omit the ``usage`` field for non-ACI switches. + Default to ``"others"`` so Pydantic selects ``AdditionalSwitchData`` + and coerces ``discoveryStatus`` / ``systemMode`` as proper enums. + """ + if v is None or not isinstance(v, dict): + return v + if "usage" not in v: + v = {**v, "usage": "others"} + return v + + @field_validator("switch_id", mode="before") + @classmethod + def validate_switch_id(cls, v: str) -> str: + return require_serial_number(v, "switch_id") + + @field_validator("fabric_management_ip", mode="before") + @classmethod + def validate_mgmt_ip(cls, v: str | None) -> str | None: + return validate_ip_address(v) + + def to_payload(self) -> dict[str, Any]: + """Convert to API payload format.""" + return self.model_dump(by_alias=True, exclude_none=True) + + @classmethod + def from_response(cls, response: dict[str, Any]) -> "SwitchDataModel": + """ + Create model instance from API response. + + Handles two response formats: + 1. Inventory API format: {switchId, fabricManagementIp, switchRole, ...} + 2. Discovery API format: {serialNumber, ip, hostname, model, softwareVersion, status, ...} + + Args: + response: Response dict from either inventory or discovery API + + Returns: + SwitchDataModel instance + """ + # Detect format and transform if needed + if "switchId" in response or "fabricManagementIp" in response: + # Already in inventory format - use as-is + return cls.model_validate(response) + + # Discovery format - transform to inventory format + transformed = { + "switchId": response.get("serialNumber"), + "serialNumber": response.get("serialNumber"), + "fabricManagementIp": response.get("ip"), + "hostname": response.get("hostname"), + "model": response.get("model"), + "softwareVersion": response.get("softwareVersion"), + "mode": response.get("mode", "Normal"), + } + + # Only add switchRole if present in response (avoid overwriting with None) + if "switchRole" in response: + transformed["switchRole"] = response["switchRole"] + elif "role" in response: + transformed["switchRole"] = response["role"] + + return cls.model_validate(transformed) + + def to_config_dict(self) -> dict[str, Any]: + """Return this inventory record using the 7 standard user-facing fields. + + Produces a consistent dict for previous/current output keys. All 7 + fields are always present (None when not available). Credential fields + are never included. + + Returns: + Dict with keys: seed_ip, serial_number, hostname, model, + role, software_version, mode. + """ + ad = self.additional_data + return { + "seed_ip": self.fabric_management_ip or self.switch_id or "", + "serial_number": self.serial_number, + "hostname": self.hostname, + "model": self.model, + "role": self.switch_role, + "software_version": self.software_version, + "mode": (ad.system_mode if ad and hasattr(ad, "system_mode") else None), + } + + +__all__ = [ + "TelemetryIpCollection", + "VpcData", + "SwitchMetadata", + "AdditionalSwitchData", + "AdditionalAciSwitchData", + "Metadata", + "SwitchDataModel", +] diff --git a/plugins/module_utils/models/manage_switches/validators.py b/plugins/module_utils/models/manage_switches/validators.py new file mode 100644 index 000000000..f001ba7a7 --- /dev/null +++ b/plugins/module_utils/models/manage_switches/validators.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- + +# Copyright: (c) 2026, Akshayanat C S (@achengam) + +# GNU General Public License v3.0+ (see LICENSE or https://www.gnu.org/licenses/gpl-3.0.txt) + +"""Switch-specific validators. + +Domain-specific validators for switch models. Generic validators for IP, MAC, +hostname, etc. are imported from common.validators and re-exported for convenience. +""" + +from __future__ import annotations + +import re + +# Import and re-export generic validators from common module +from ...common.validators import ( + _normalize_optional_string, + _require_field, + validate_ip_address, + validate_cidr, + validate_ip_or_cidr_as_cidr, + validate_hostname, + validate_mac_address, + require_ip_address, + require_ip_or_cidr_as_cidr, + require_hostname, + require_mac_address, + validate_cidr_optional, + check_credentials_pair, +) + +# ------------------------------------------------------------------ +# Switch-specific validators +# ------------------------------------------------------------------ + + +def validate_serial_number(v: str | None) -> str | None: + """Validate switch serial number format. + + Args: + v: Raw serial number value. + + Returns: + Validated serial number, or ``None`` if input is None/empty. + + Raises: + ValueError: When the serial number contains invalid characters. + """ + v = _normalize_optional_string(v) + if v is None: + return None + # Serial numbers are typically alphanumeric with optional hyphens + if not re.match(r"^[A-Za-z0-9_-]+$", v): + raise ValueError(f"Serial number must be alphanumeric with optional hyphens/underscores: {v}") + return v + + +def require_serial_number(v: str, field_name: str = "serial_number") -> str: + """Validate and require a non-empty serial number. + + Args: + v: Raw serial number value. + field_name: Field name used in the error message. + + Returns: + Validated serial number string. + + Raises: + ValueError: When the value is empty or contains invalid characters. + """ + return _require_field(v, validate_serial_number, field_name) + + +def validate_vpc_domain(v: int | None) -> int | None: + """Validate VPC domain ID (1-1000). + + Args: + v: VPC domain ID. + + Returns: + Validated VPC domain ID, or ``None`` if input is None. + + Raises: + ValueError: When the value is out of valid range. + """ + if v is None: + return None + if not 1 <= v <= 1000: + raise ValueError(f"VPC domain must be between 1 and 1000: {v}") + return v + + +def check_discovery_credentials_pair(username: str | None, password: str | None) -> None: + """Enforce mutual-presence of discovery credentials. + + Both ``discovery_username`` and ``discovery_password`` must either be + absent together or present together. + + Args: + username: discovery_username value (may be ``None``). + password: discovery_password value (may be ``None``). + + Raises: + ValueError: When exactly one of the two is provided. + """ + check_credentials_pair(username, password, "discovery_username", "discovery_password") + + +__all__ = [ + # Generic validators (re-exported from common.validators) + "validate_ip_address", + "validate_cidr", + "validate_ip_or_cidr_as_cidr", + "validate_hostname", + "validate_mac_address", + "require_ip_address", + "require_ip_or_cidr_as_cidr", + "require_hostname", + "require_mac_address", + "validate_cidr_optional", + "check_credentials_pair", + # Switch-specific validators + "validate_serial_number", + "require_serial_number", + "validate_vpc_domain", + "check_discovery_credentials_pair", +] diff --git a/plugins/module_utils/nd.py b/plugins/module_utils/nd.py index f8f14e5d0..486e182c1 100644 --- a/plugins/module_utils/nd.py +++ b/plugins/module_utils/nd.py @@ -18,6 +18,7 @@ from ansible.module_utils._text import to_native, to_text from ansible.module_utils.connection import Connection from ansible_collections.cisco.nd.plugins.module_utils.constants import ALLOWED_STATES_TO_APPEND_SENT_AND_PROPOSED +from ansible_collections.cisco.nd.plugins.module_utils.utils import issubset def sanitize_dict(dict_to_sanitize, keys=None, values=None, recursive=True, remove_none_values=True): @@ -67,43 +68,6 @@ def cmp(a, b): return (a > b) - (a < b) -def issubset(subset, superset): - """Recurse through a nested dictionary and check if it is a subset of another.""" - - if type(subset) is not type(superset): - return False - - if not isinstance(subset, dict): - if isinstance(subset, list): - if len(subset) != len(superset): - return False - - remaining = list(superset) - for item in subset: - for index, candidate in enumerate(remaining): - if issubset(item, candidate) and issubset(candidate, item): - del remaining[index] - break - else: - return False - return True - return subset == superset - - for key, value in subset.items(): - if value is None: - continue - - if key not in superset: - return False - - superset_value = superset.get(key) - - if not issubset(value, superset_value): - return False - - return True - - def update_qs(params): """Append key-value pairs to self.filter_string""" accepted_params = dict((k, v) for (k, v) in params.items() if v is not None) diff --git a/tests/unit/module_utils/test_common_validators.py b/tests/unit/module_utils/test_common_validators.py new file mode 100644 index 000000000..933f67ec7 --- /dev/null +++ b/tests/unit/module_utils/test_common_validators.py @@ -0,0 +1,390 @@ +# -*- coding: utf-8 -*- + +# Copyright: (c) 2026, Akshayanat C S (@achengam) + +# GNU General Public License v3.0+ (see LICENSE or https://www.gnu.org/licenses/gpl-3.0.txt) + +"""Unit tests for common validators.""" + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import pytest + +from ansible_collections.cisco.nd.plugins.module_utils.common.validators import ( + _normalize_optional_string, + _require_field, + validate_ip_address, + validate_cidr, + validate_ip_or_cidr_as_cidr, + validate_hostname, + validate_mac_address, + require_ip_address, + require_ip_or_cidr_as_cidr, + require_hostname, + require_mac_address, + validate_cidr_optional, + check_credentials_pair, +) + + +class TestNormalizeOptionalString: + """Tests for _normalize_optional_string helper.""" + + def test_none_input(self): + """Test None input returns None.""" + assert _normalize_optional_string(None) is None + + def test_empty_string(self): + """Test empty string returns None.""" + assert _normalize_optional_string("") is None + + def test_whitespace_only(self): + """Test whitespace-only string returns None.""" + assert _normalize_optional_string(" ") is None + assert _normalize_optional_string("\t\n") is None + + def test_valid_string(self): + """Test valid string is stripped and returned.""" + assert _normalize_optional_string("hello") == "hello" + assert _normalize_optional_string(" hello ") == "hello" + assert _normalize_optional_string("\thello\n") == "hello" + + def test_numeric_input(self): + """Test numeric input is converted to string.""" + assert _normalize_optional_string(123) == "123" + assert _normalize_optional_string(0) == "0" + + +class TestRequireField: + """Tests for _require_field helper.""" + + def test_valid_value(self): + """Test valid value passes through.""" + + def dummy_validator(v): + return v if v else None + + result = _require_field("test", dummy_validator, "test_field") + assert result == "test" + + def test_empty_value_raises(self): + """Test empty value raises ValueError.""" + + def dummy_validator(v): + return None + + with pytest.raises(ValueError, match="test_field cannot be empty"): + _require_field("", dummy_validator, "test_field") + + +class TestValidateIpAddress: + """Tests for validate_ip_address.""" + + def test_none_input(self): + """Test None input returns None.""" + assert validate_ip_address(None) is None + + def test_empty_string(self): + """Test empty string returns None.""" + assert validate_ip_address("") is None + assert validate_ip_address(" ") is None + + def test_valid_ipv4(self): + """Test valid IPv4 addresses.""" + assert validate_ip_address("192.168.1.1") == "192.168.1.1" + assert validate_ip_address("10.0.0.1") == "10.0.0.1" + assert validate_ip_address("255.255.255.255") == "255.255.255.255" + assert validate_ip_address("0.0.0.0") == "0.0.0.0" + + def test_valid_ipv6(self): + """Test valid IPv6 addresses.""" + assert validate_ip_address("2001:db8::1") == "2001:db8::1" + assert validate_ip_address("::1") == "::1" + assert validate_ip_address("fe80::1") == "fe80::1" + + def test_invalid_ip(self): + """Test invalid IP addresses raise ValueError.""" + with pytest.raises(ValueError, match="Invalid IP address format"): + validate_ip_address("256.1.1.1") + + with pytest.raises(ValueError, match="Invalid IP address format"): + validate_ip_address("not-an-ip") + + with pytest.raises(ValueError, match="Invalid IP address format"): + validate_ip_address("192.168.1") + + def test_whitespace_handling(self): + """Test whitespace is stripped.""" + assert validate_ip_address(" 192.168.1.1 ") == "192.168.1.1" + + +class TestValidateCidr: + """Tests for validate_cidr.""" + + def test_none_input(self): + """Test None input returns None.""" + assert validate_cidr(None) is None + + def test_empty_string(self): + """Test empty string returns None.""" + assert validate_cidr("") is None + + def test_valid_ipv4_cidr(self): + """Test valid IPv4 CIDR notation.""" + assert validate_cidr("192.168.1.0/24") == "192.168.1.0/24" + assert validate_cidr("10.0.0.0/8") == "10.0.0.0/8" + assert validate_cidr("172.16.0.0/12") == "172.16.0.0/12" + + def test_valid_ipv6_cidr(self): + """Test valid IPv6 CIDR notation.""" + assert validate_cidr("2001:db8::/32") == "2001:db8::/32" + assert validate_cidr("fe80::/10") == "fe80::/10" + + def test_missing_slash(self): + """Test missing slash raises ValueError.""" + with pytest.raises(ValueError, match="CIDR notation required"): + validate_cidr("192.168.1.0") + + def test_invalid_cidr(self): + """Test invalid CIDR raises ValueError.""" + with pytest.raises(ValueError, match="Invalid CIDR format"): + validate_cidr("192.168.1.0/33") + + with pytest.raises(ValueError, match="Invalid CIDR format"): + validate_cidr("not-a-cidr/24") + + +class TestValidateIpOrCidrAsCidr: + """Tests for validate_ip_or_cidr_as_cidr.""" + + def test_none_input(self): + """Test None input returns None.""" + assert validate_ip_or_cidr_as_cidr(None) is None + + def test_plain_ipv4_normalized(self): + """Test plain IPv4 is normalized to /32.""" + assert validate_ip_or_cidr_as_cidr("192.168.1.1") == "192.168.1.1/32" + assert validate_ip_or_cidr_as_cidr("10.0.0.1") == "10.0.0.1/32" + + def test_plain_ipv6_normalized(self): + """Test plain IPv6 is normalized to /128.""" + assert validate_ip_or_cidr_as_cidr("2001:db8::1") == "2001:db8::1/128" + assert validate_ip_or_cidr_as_cidr("::1") == "::1/128" + + def test_ipv4_cidr_validated(self): + """Test IPv4 CIDR is validated and returned.""" + result = validate_ip_or_cidr_as_cidr("192.168.1.0/24") + assert result == "192.168.1.0/24" + + def test_ipv6_cidr_validated(self): + """Test IPv6 CIDR is validated and returned.""" + result = validate_ip_or_cidr_as_cidr("2001:db8::/32") + assert result == "2001:db8::/32" + + def test_invalid_ip(self): + """Test invalid IP raises ValueError.""" + with pytest.raises(ValueError, match="Invalid IP address format"): + validate_ip_or_cidr_as_cidr("not-an-ip") + + def test_invalid_cidr(self): + """Test invalid CIDR raises ValueError.""" + with pytest.raises(ValueError, match="Invalid CIDR format"): + validate_ip_or_cidr_as_cidr("192.168.1.0/33") + + +class TestValidateHostname: + """Tests for validate_hostname.""" + + def test_none_input(self): + """Test None input returns None.""" + assert validate_hostname(None) is None + + def test_empty_string(self): + """Test empty string returns None.""" + assert validate_hostname("") is None + + def test_valid_hostnames(self): + """Test valid hostnames.""" + assert validate_hostname("switch1") == "switch1" + assert validate_hostname("my-switch") == "my-switch" + assert validate_hostname("switch_1") == "switch_1" + assert validate_hostname("switch.example.com") == "switch.example.com" + assert validate_hostname("SW01-LEAF-01") == "SW01-LEAF-01" + + def test_too_long(self): + """Test hostname exceeding 255 characters.""" + long_name = "a" * 256 + with pytest.raises(ValueError, match="cannot exceed 255 characters"): + validate_hostname(long_name) + + def test_invalid_start(self): + """Test hostname starting with invalid character.""" + with pytest.raises(ValueError, match="Must start with alphanumeric"): + validate_hostname("-switch") + + with pytest.raises(ValueError, match="Must start with alphanumeric"): + validate_hostname(".switch") + + def test_invalid_characters(self): + """Test hostname with invalid characters.""" + with pytest.raises(ValueError, match="Must start with alphanumeric"): + validate_hostname("switch@host") + + def test_ending_with_dot(self): + """Test hostname ending with dot.""" + with pytest.raises(ValueError, match="Cannot end with dot"): + validate_hostname("switch.") + + def test_consecutive_dots(self): + """Test hostname with consecutive dots.""" + with pytest.raises(ValueError, match="consecutive dots"): + validate_hostname("switch..example.com") + + +class TestValidateMacAddress: + """Tests for validate_mac_address.""" + + def test_none_input(self): + """Test None input returns None.""" + assert validate_mac_address(None) is None + + def test_empty_string(self): + """Test empty string returns None.""" + assert validate_mac_address("") is None + + def test_colon_separated(self): + """Test colon-separated format.""" + assert validate_mac_address("AA:BB:CC:DD:EE:FF") == "AA:BB:CC:DD:EE:FF" + assert validate_mac_address("aa:bb:cc:dd:ee:ff") == "AA:BB:CC:DD:EE:FF" + assert validate_mac_address("00:11:22:33:44:55") == "00:11:22:33:44:55" + + def test_hyphen_separated(self): + """Test hyphen-separated format.""" + assert validate_mac_address("AA-BB-CC-DD-EE-FF") == "AA:BB:CC:DD:EE:FF" + assert validate_mac_address("aa-bb-cc-dd-ee-ff") == "AA:BB:CC:DD:EE:FF" + + def test_cisco_dot_notation(self): + """Test Cisco dot notation format.""" + assert validate_mac_address("aabb.ccdd.eeff") == "AA:BB:CC:DD:EE:FF" + assert validate_mac_address("AABB.CCDD.EEFF") == "AA:BB:CC:DD:EE:FF" + + def test_bare_hex(self): + """Test bare hex format.""" + assert validate_mac_address("aabbccddeeff") == "AA:BB:CC:DD:EE:FF" + assert validate_mac_address("AABBCCDDEEFF") == "AA:BB:CC:DD:EE:FF" + + def test_mixed_case(self): + """Test mixed case normalization.""" + assert validate_mac_address("Aa:Bb:Cc:Dd:Ee:Ff") == "AA:BB:CC:DD:EE:FF" + + def test_invalid_length(self): + """Test invalid MAC address length.""" + with pytest.raises(ValueError, match="Invalid MAC address format"): + validate_mac_address("AA:BB:CC:DD:EE") + + with pytest.raises(ValueError, match="Invalid MAC address format"): + validate_mac_address("AA:BB:CC:DD:EE:FF:00") + + def test_invalid_characters(self): + """Test invalid characters in MAC address.""" + with pytest.raises(ValueError, match="Invalid MAC address format"): + validate_mac_address("GG:HH:II:JJ:KK:LL") + + with pytest.raises(ValueError, match="Invalid MAC address format"): + validate_mac_address("not-a-mac") + + +class TestRequireValidators: + """Tests for require_* validators.""" + + def test_require_ip_address_valid(self): + """Test require_ip_address with valid IP.""" + assert require_ip_address("192.168.1.1") == "192.168.1.1" + + def test_require_ip_address_empty(self): + """Test require_ip_address with empty value.""" + with pytest.raises(ValueError, match="IP address cannot be empty"): + require_ip_address("") + + def test_require_hostname_valid(self): + """Test require_hostname with valid hostname.""" + assert require_hostname("switch1") == "switch1" + + def test_require_hostname_empty(self): + """Test require_hostname with empty value.""" + with pytest.raises(ValueError, match="hostname cannot be empty"): + require_hostname("") + + def test_require_mac_address_valid(self): + """Test require_mac_address with valid MAC.""" + assert require_mac_address("AA:BB:CC:DD:EE:FF") == "AA:BB:CC:DD:EE:FF" + + def test_require_mac_address_empty(self): + """Test require_mac_address with empty value.""" + with pytest.raises(ValueError, match="MAC address cannot be empty"): + require_mac_address("") + + def test_require_ip_or_cidr_as_cidr_valid(self): + """Test require_ip_or_cidr_as_cidr with valid IP.""" + assert require_ip_or_cidr_as_cidr("192.168.1.1") == "192.168.1.1/32" + + def test_require_ip_or_cidr_as_cidr_empty(self): + """Test require_ip_or_cidr_as_cidr with empty value.""" + with pytest.raises(ValueError, match="IP or CIDR cannot be empty"): + require_ip_or_cidr_as_cidr("") + + +class TestValidateCidrOptional: + """Tests for validate_cidr_optional.""" + + def test_none_input(self): + """Test None input returns None.""" + assert validate_cidr_optional(None) is None + + def test_valid_cidr(self): + """Test valid CIDR passes through.""" + assert validate_cidr_optional("192.168.1.0/24") == "192.168.1.0/24" + + def test_empty_raises(self): + """Test empty string raises ValueError.""" + with pytest.raises(ValueError, match="CIDR cannot be empty"): + validate_cidr_optional("") + + +class TestCheckCredentialsPair: + """Tests for check_credentials_pair.""" + + def test_both_none(self): + """Test both username and password are None.""" + # Should not raise + check_credentials_pair(None, None) + + def test_both_present(self): + """Test both username and password are present.""" + # Should not raise + check_credentials_pair("admin", "password123") + + def test_both_empty_strings(self): + """Test both are empty strings (falsy).""" + # Should not raise + check_credentials_pair("", "") + + def test_username_only(self): + """Test only username is provided.""" + with pytest.raises(ValueError, match="password must be set when username is specified"): + check_credentials_pair("admin", None) + + def test_password_only(self): + """Test only password is provided.""" + with pytest.raises(ValueError, match="username must be set when password is specified"): + check_credentials_pair(None, "password123") + + def test_custom_field_names(self): + """Test custom field names in error messages.""" + with pytest.raises(ValueError, match="api_password must be set when api_username is specified"): + check_credentials_pair("admin", None, "api_username", "api_password") + + with pytest.raises(ValueError, match="api_username must be set when api_password is specified"): + check_credentials_pair(None, "password", "api_username", "api_password") diff --git a/tests/unit/module_utils/test_manage_switches_validators.py b/tests/unit/module_utils/test_manage_switches_validators.py new file mode 100644 index 000000000..3c6d826b5 --- /dev/null +++ b/tests/unit/module_utils/test_manage_switches_validators.py @@ -0,0 +1,293 @@ +# -*- coding: utf-8 -*- + +# Copyright: (c) 2026, Akshayanat C S (@achengam) + +# GNU General Public License v3.0+ (see LICENSE or https://www.gnu.org/licenses/gpl-3.0.txt) + +"""Unit tests for switch-specific validators.""" + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import pytest + +from ansible_collections.cisco.nd.plugins.module_utils.models.manage_switches.validators import ( + validate_serial_number, + require_serial_number, + validate_vpc_domain, + check_discovery_credentials_pair, +) + + +class TestValidateSerialNumber: + """Tests for validate_serial_number.""" + + def test_none_input(self): + """Test None input returns None.""" + assert validate_serial_number(None) is None + + def test_empty_string(self): + """Test empty string returns None.""" + assert validate_serial_number("") is None + assert validate_serial_number(" ") is None + + def test_valid_serial_numbers(self): + """Test valid serial number formats.""" + assert validate_serial_number("ABC123") == "ABC123" + assert validate_serial_number("FOC12345678") == "FOC12345678" + assert validate_serial_number("SN-123-456") == "SN-123-456" + assert validate_serial_number("SN_123_456") == "SN_123_456" + assert validate_serial_number("123456") == "123456" + assert validate_serial_number("ABCD-1234-EFGH") == "ABCD-1234-EFGH" + + def test_alphanumeric_only(self): + """Test alphanumeric serial numbers.""" + assert validate_serial_number("ABC123DEF456") == "ABC123DEF456" + assert validate_serial_number("1234567890") == "1234567890" + assert validate_serial_number("ABCDEFGH") == "ABCDEFGH" + + def test_with_hyphens(self): + """Test serial numbers with hyphens.""" + assert validate_serial_number("A-B-C") == "A-B-C" + assert validate_serial_number("123-456-789") == "123-456-789" + + def test_with_underscores(self): + """Test serial numbers with underscores.""" + assert validate_serial_number("A_B_C") == "A_B_C" + assert validate_serial_number("123_456_789") == "123_456_789" + + def test_mixed_separators(self): + """Test serial numbers with mixed hyphens and underscores.""" + assert validate_serial_number("ABC_123-DEF") == "ABC_123-DEF" + + def test_whitespace_stripped(self): + """Test whitespace is stripped.""" + assert validate_serial_number(" ABC123 ") == "ABC123" + assert validate_serial_number("\tFOC12345\n") == "FOC12345" + + def test_invalid_characters(self): + """Test invalid characters raise ValueError.""" + with pytest.raises(ValueError, match="Serial number must be alphanumeric"): + validate_serial_number("ABC@123") + + with pytest.raises(ValueError, match="Serial number must be alphanumeric"): + validate_serial_number("ABC 123") + + with pytest.raises(ValueError, match="Serial number must be alphanumeric"): + validate_serial_number("ABC.123") + + with pytest.raises(ValueError, match="Serial number must be alphanumeric"): + validate_serial_number("ABC/123") + + with pytest.raises(ValueError, match="Serial number must be alphanumeric"): + validate_serial_number("ABC#123") + + def test_special_characters_not_allowed(self): + """Test various special characters are rejected.""" + invalid_chars = ["!", "@", "#", "$", "%", "^", "&", "*", "(", ")", "+", "=", "[", "]", "{", "}", "|", "\\", "/", "?", ".", ",", "<", ">", " "] + for char in invalid_chars: + with pytest.raises(ValueError, match="Serial number must be alphanumeric"): + validate_serial_number(f"ABC{char}123") + + +class TestRequireSerialNumber: + """Tests for require_serial_number.""" + + def test_valid_serial_number(self): + """Test valid serial number passes through.""" + assert require_serial_number("ABC123") == "ABC123" + assert require_serial_number("FOC12345678") == "FOC12345678" + + def test_empty_raises(self): + """Test empty value raises ValueError with default field name.""" + with pytest.raises(ValueError, match="serial_number cannot be empty"): + require_serial_number("") + + def test_whitespace_only_raises(self): + """Test whitespace-only value raises ValueError.""" + with pytest.raises(ValueError, match="serial_number cannot be empty"): + require_serial_number(" ") + + def test_custom_field_name(self): + """Test custom field name in error message.""" + with pytest.raises(ValueError, match="device_serial cannot be empty"): + require_serial_number("", "device_serial") + + def test_invalid_format_with_custom_field(self): + """Test invalid format with custom field name.""" + with pytest.raises(ValueError, match="Serial number must be alphanumeric"): + require_serial_number("ABC@123", "device_serial") + + def test_switch_id_field_name(self): + """Test with switch_id field name (common use case).""" + with pytest.raises(ValueError, match="switch_id cannot be empty"): + require_serial_number("", "switch_id") + + assert require_serial_number("FOC123", "switch_id") == "FOC123" + + def test_peer_switch_id_field_name(self): + """Test with peer_switch_id field name (common use case).""" + with pytest.raises(ValueError, match="peer_switch_id cannot be empty"): + require_serial_number("", "peer_switch_id") + + assert require_serial_number("FOC456", "peer_switch_id") == "FOC456" + + +class TestValidateVpcDomain: + """Tests for validate_vpc_domain.""" + + def test_none_input(self): + """Test None input returns None.""" + assert validate_vpc_domain(None) is None + + def test_valid_vpc_domains(self): + """Test valid VPC domain IDs.""" + assert validate_vpc_domain(1) == 1 + assert validate_vpc_domain(100) == 100 + assert validate_vpc_domain(500) == 500 + assert validate_vpc_domain(1000) == 1000 + + def test_boundary_values(self): + """Test boundary values.""" + assert validate_vpc_domain(1) == 1 # Minimum + assert validate_vpc_domain(1000) == 1000 # Maximum + + def test_below_minimum(self): + """Test VPC domain ID below minimum.""" + with pytest.raises(ValueError, match="VPC domain must be between 1 and 1000"): + validate_vpc_domain(0) + + with pytest.raises(ValueError, match="VPC domain must be between 1 and 1000"): + validate_vpc_domain(-1) + + with pytest.raises(ValueError, match="VPC domain must be between 1 and 1000"): + validate_vpc_domain(-100) + + def test_above_maximum(self): + """Test VPC domain ID above maximum.""" + with pytest.raises(ValueError, match="VPC domain must be between 1 and 1000"): + validate_vpc_domain(1001) + + with pytest.raises(ValueError, match="VPC domain must be between 1 and 1000"): + validate_vpc_domain(5000) + + def test_common_values(self): + """Test commonly used VPC domain IDs.""" + common_domains = [1, 10, 50, 100, 200, 500, 999, 1000] + for domain in common_domains: + assert validate_vpc_domain(domain) == domain + + +class TestCheckDiscoveryCredentialsPair: + """Tests for check_discovery_credentials_pair.""" + + def test_both_none(self): + """Test both discovery_username and discovery_password are None.""" + # Should not raise + check_discovery_credentials_pair(None, None) + + def test_both_present(self): + """Test both discovery_username and discovery_password are present.""" + # Should not raise + check_discovery_credentials_pair("admin", "password123") + check_discovery_credentials_pair("user", "pass") + + def test_both_empty_strings(self): + """Test both are empty strings (falsy).""" + # Should not raise + check_discovery_credentials_pair("", "") + + def test_username_only(self): + """Test only discovery_username is provided.""" + with pytest.raises(ValueError, match="discovery_password must be set when discovery_username is specified"): + check_discovery_credentials_pair("admin", None) + + with pytest.raises(ValueError, match="discovery_password must be set when discovery_username is specified"): + check_discovery_credentials_pair("admin", "") + + def test_password_only(self): + """Test only discovery_password is provided.""" + with pytest.raises(ValueError, match="discovery_username must be set when discovery_password is specified"): + check_discovery_credentials_pair(None, "password123") + + with pytest.raises(ValueError, match="discovery_username must be set when discovery_password is specified"): + check_discovery_credentials_pair("", "password123") + + def test_various_credential_combinations(self): + """Test various valid credential combinations.""" + # Valid combinations + valid_combinations = [ + (None, None), + ("", ""), + ("admin", "pass"), + ("user123", "P@ssw0rd!"), + ("discovery_user", "complex_pass_123"), + ] + + for username, password in valid_combinations: + # Should not raise + check_discovery_credentials_pair(username, password) + + def test_invalid_combinations(self): + """Test various invalid credential combinations.""" + invalid_combinations = [ + ("admin", None), + ("admin", ""), + (None, "password"), + ("", "password"), + ("user", None), + (None, "pass"), + ] + + for username, password in invalid_combinations: + with pytest.raises(ValueError): + check_discovery_credentials_pair(username, password) + + +class TestValidatorIntegration: + """Integration tests for validators working together.""" + + def test_serial_number_in_vpc_context(self): + """Test serial number validation in VPC configuration context.""" + # Valid VPC configuration + serial = validate_serial_number("FOC12345678") + vpc_domain = validate_vpc_domain(100) + + assert serial == "FOC12345678" + assert vpc_domain == 100 + + def test_discovery_with_valid_serial(self): + """Test discovery credentials with serial number validation.""" + serial = require_serial_number("FOC12345", "device_serial") + + # Valid credentials + check_discovery_credentials_pair("admin", "password") + + assert serial == "FOC12345" + + def test_multiple_validators_chain(self): + """Test multiple validators in a chain.""" + # Simulate validating a switch configuration + serial = require_serial_number("FOC123456", "switch_id") + vpc_domain = validate_vpc_domain(50) + + # Discovery credentials optional (both None) + check_discovery_credentials_pair(None, None) + + assert serial == "FOC123456" + assert vpc_domain == 50 + + def test_error_handling_order(self): + """Test that validators fail fast with clear errors.""" + # First validator should catch empty serial + with pytest.raises(ValueError, match="serial_number cannot be empty"): + require_serial_number("") + + # Invalid VPC domain + with pytest.raises(ValueError, match="VPC domain must be between"): + validate_vpc_domain(2000) + + # Incomplete credentials + with pytest.raises(ValueError, match="discovery_password must be set"): + check_discovery_credentials_pair("admin", None)