diff --git a/lufa/api_v1.py b/lufa/api_v1.py index bd12b95..66f669c 100644 --- a/lufa/api_v1.py +++ b/lufa/api_v1.py @@ -9,6 +9,7 @@ from lufa.decorators import debug_only from lufa.provider import get_api_repository, get_awx_client, get_database_manager from lufa.repository.api_repository import JobExport, LufaKeyError +from lufa.repository.backend_repository import ResourceNotFoundError MALFORMED_JSON = {"error": "Malformed json"} @@ -138,6 +139,21 @@ def compliance(): return jsonify(resp) +@bp.route("/compliance/hosts/", methods=["GET"]) +@ro_token_required +@pass_safe_exceptions +def compliance_hosts(ansible_host): + """ + Returns compliance state to the given host. + """ + repository = get_api_repository() + + try: + return jsonify(repository.get_host_compliance_state(ansible_host)) + except ResourceNotFoundError: + return jsonify({"error": f"Host {ansible_host} not found"}), 404 + + @bp.route("/tasks", methods=["POST"]) @token_required @pass_safe_exceptions diff --git a/lufa/repository/api_repository.py b/lufa/repository/api_repository.py index ba3cd70..c349e68 100644 --- a/lufa/repository/api_repository.py +++ b/lufa/repository/api_repository.py @@ -69,6 +69,12 @@ class JobTemplateComplianceStates(TypedDict): organisation: str +class HostComplianceState(TypedDict): + ansible_host: str + compliant: bool + noncompliant: list[JobTemplateComplianceStates] + + class FullJob(TypedDict): tower_job_id: int tower_job_template_id: int @@ -122,6 +128,11 @@ def get_all_noncompliant_hosts(self) -> dict[str, list[JobTemplateComplianceStat """ pass + @abstractmethod + def get_host_compliance_state(self, ansible_host: str) -> HostComplianceState: + """Returns compliance state to the given host.""" + pass + @abstractmethod def update_job( self, tower_job_id: int, end_time: Optional[TimeStamp] = None, artifacts: Optional[JSon] = None @@ -212,6 +223,33 @@ def get_all_noncompliant_hosts(self) -> dict[str, list[JobTemplateComplianceStat ret[line["ansible_host"]] = json.loads(line["noncompliant"]) return ret + def get_host_compliance_state(self, ansible_host: str) -> HostComplianceState: + conn = self.db_manager.get_db_connection() + cursor = conn.cursor() + cursor.execute( + """ + SELECT + c.ansible_host, + c.compliant, + n.noncompliant + FROM v_host_compliance c + LEFT JOIN v_host_noncompliance n + ON c.ansible_host = n.ansible_host + WHERE c.ansible_host = ? + """, + (ansible_host,), + ) + line = cursor.fetchone() + + if line is None: + raise ResourceNotFoundError(f"Host {ansible_host} not found") + + return HostComplianceState( + ansible_host=line["ansible_host"], + compliant=bool(line["compliant"]), + noncompliant=json.loads(line["noncompliant"]) if line["noncompliant"] else [], + ) + def add_stats(self, tower_job_id: int, stats: list[TowerJobStats]) -> None: conn: sqlite3.Connection = self.db_manager.get_db_connection() cursor = conn.cursor() @@ -768,6 +806,33 @@ def get_all_noncompliant_hosts(self) -> dict[str, list[JobTemplateComplianceStat ret[line["ansible_host"]] = line["noncompliant"] return ret + def get_host_compliance_state(self, ansible_host: str) -> HostComplianceState: + conn = self.db_manager.get_db_connection() + cursor = conn.cursor() + cursor.execute( + """ + SELECT + c.ansible_host, + c.compliant, + n.noncompliant + FROM v_host_compliance c + LEFT JOIN v_host_noncompliance n + ON c.ansible_host = n.ansible_host + WHERE c.ansible_host = %s + """, + (ansible_host,), + ) + line = cursor.fetchone() + + if line is None: + raise ResourceNotFoundError(f"Host {ansible_host} not found") + + return HostComplianceState( + ansible_host=line["ansible_host"], + compliant=bool(line["compliant"]), + noncompliant=line["noncompliant"] or [], + ) + def job_exists(self, tower_job_id) -> bool: conn = self.db_manager.get_db_connection() cursor = conn.cursor() diff --git a/tests/e2e/test_api.py b/tests/e2e/test_api.py index 336a7b9..c9b6bd1 100644 --- a/tests/e2e/test_api.py +++ b/tests/e2e/test_api.py @@ -164,6 +164,38 @@ def test_compliance(self, client): assert len(resp) == 1 assert len(list(resp.values())[0]) == 1 + def test_compliance_host_authorisation(self, client): + r = client.get(endpoint_uri + "/compliance/hosts/win999.example.com", headers=AUTH_HEADERS) + assert r.status_code in (200, 404) + + r = client.get(endpoint_uri + "/compliance/hosts/win999.example.com", headers=RO_AUTH_HEADERS) + assert r.status_code in (200, 404) + + r = client.get(endpoint_uri + "/compliance/hosts/win999.example.com", headers=INVALID_AUTH_HEADERS) + assert r.status_code == 401 + + def test_compliance_host(self, client): + r = client.post(endpoint_uri + "/jobs", json=generic_job_data) + assert r.status_code == 201, r.text + + r = client.post(endpoint_uri + "/stats", json=generic_stats_data) + assert r.status_code == 201, r.text + + r = client.get(endpoint_uri + "/compliance/hosts/win999.example.com", headers=AUTH_HEADERS) + assert r.status_code == 200, r.text + assert r.json["ansible_host"] == "win999.example.com" + assert r.json["compliant"] is True + assert r.json["noncompliant"] == [] + + r = client.get(endpoint_uri + "/compliance/hosts/win443.example.com", headers=AUTH_HEADERS) + assert r.status_code == 200, r.text + assert r.json["ansible_host"] == "win443.example.com" + assert r.json["compliant"] is False + assert len(r.json["noncompliant"]) == 1 + + r = client.get(endpoint_uri + "/compliance/hosts/unknown.example.com", headers=AUTH_HEADERS) + assert r.status_code == 404 + def test_post_jobs_multiple(self, client): # create x templates with y jobs each count_templates = 10