diff --git a/application/tests/web_main_test.py b/application/tests/web_main_test.py index 3f78a5e2f..4930b2b73 100644 --- a/application/tests/web_main_test.py +++ b/application/tests/web_main_test.py @@ -400,7 +400,11 @@ def test_find_document_by_tag(self) -> None: response = client.get(f"/rest/v1/tags?tag=CW") self.assertEqual(404, response.status_code) - expected = {"data": [cres["ca"].todict(), cres["cb"].todict()]} + expected = { + "data": [cres["ca"].todict(), cres["cb"].todict()], + "page": 1, + "total_pages": 1, + } response = client.get(f"/rest/v1/tags?tag=ta") self.assertEqual(200, response.status_code) @@ -442,6 +446,89 @@ def test_test_search(self) -> None: self.assertEqual(200, resp.status_code) self.assertDictEqual(resp.json[0], expected[0]) + @patch.object(db, "Node_collection") + def test_tags_pagination_and_format_parity(self, db_mock) -> None: + docs = [ + defs.CRE(id="111-111", description="A", name="A", tags=["ta"]), + defs.CRE(id="222-222", description="B", name="B", tags=["ta"]), + defs.CRE(id="333-333", description="C", name="C", tags=["ta"]), + ] + db_mock.return_value.get_by_tags.return_value = docs + + with self.app.test_client() as client: + json_response = client.get("/rest/v1/tags?tag=ta&page=2&items_per_page=1") + payload = json.loads(json_response.data.decode()) + self.assertEqual(200, json_response.status_code) + self.assertEqual(2, payload["page"]) + self.assertEqual(3, payload["total_pages"]) + self.assertEqual("222-222", payload["data"][0]["id"]) + + md_response = client.get( + "/rest/v1/tags?tag=ta&page=2&items_per_page=1&format=md" + ) + self.assertEqual(200, md_response.status_code) + self.assertIn("222-222", md_response.data.decode()) + + csv_response = client.get( + "/rest/v1/tags?tag=ta&page=2&items_per_page=1&format=csv" + ) + self.assertEqual(200, csv_response.status_code) + self.assertGreater(len(csv_response.data), 0) + + oscal_response = client.get( + "/rest/v1/tags?tag=ta&page=2&items_per_page=1&format=oscal" + ) + oscal_payload = json.loads(oscal_response.data.decode()) + self.assertEqual(200, oscal_response.status_code) + self.assertIn("metadata", oscal_payload) + + out_of_range_response = client.get( + "/rest/v1/tags?tag=ta&page=99&items_per_page=1" + ) + self.assertEqual(404, out_of_range_response.status_code) + + @patch.object(db, "Node_collection") + def test_text_search_pagination_and_format_parity(self, db_mock) -> None: + docs = [ + defs.Standard(name="SB", section="s1", subsection="a"), + defs.Standard(name="SB", section="s2", subsection="b"), + defs.Standard(name="SB", section="s3", subsection="c"), + ] + db_mock.return_value.text_search.return_value = docs + + with self.app.test_client() as client: + json_response = client.get( + "/rest/v1/text_search?text=SB&page=2&items_per_page=1" + ) + payload = json.loads(json_response.data.decode()) + self.assertEqual(200, json_response.status_code) + self.assertEqual(1, len(payload)) + self.assertEqual("s2", payload[0]["section"]) + + md_response = client.get( + "/rest/v1/text_search?text=SB&page=2&items_per_page=1&format=md" + ) + self.assertEqual(200, md_response.status_code) + self.assertIn("s2", md_response.data.decode()) + + csv_response = client.get( + "/rest/v1/text_search?text=SB&page=2&items_per_page=1&format=csv" + ) + self.assertEqual(200, csv_response.status_code) + self.assertGreater(len(csv_response.data), 0) + + oscal_response = client.get( + "/rest/v1/text_search?text=SB&page=2&items_per_page=1&format=oscal" + ) + oscal_payload = json.loads(oscal_response.data.decode()) + self.assertEqual(200, oscal_response.status_code) + self.assertEqual(1, len(oscal_payload["controls"])) + + out_of_range_response = client.get( + "/rest/v1/text_search?text=SB&page=99&items_per_page=1" + ) + self.assertEqual(404, out_of_range_response.status_code) + def test_find_root_cres(self) -> None: self.maxDiff = None collection = db.Node_collection().with_graph() diff --git a/application/web/web_main.py b/application/web/web_main.py index 77eee36e7..ed26fe129 100644 --- a/application/web/web_main.py +++ b/application/web/web_main.py @@ -5,6 +5,7 @@ from functools import wraps import json import logging +import math import os import io import pathlib @@ -68,6 +69,35 @@ class SupportedFormats(Enum): OSCAL = "oscal" +def _parse_positive_int(value: str | None, default: int) -> int: + if value is None: + return default + try: + parsed = int(value) + if parsed > 0: + return parsed + except ValueError: + pass + return default + + +def _paginate_documents( + documents: list[defs.Document], +) -> tuple[int, int, list[defs.Document]]: + page = _parse_positive_int(request.args.get("page"), 1) + items_per_page = _parse_positive_int( + request.args.get("items_per_page"), ITEMS_PER_PAGE + ) + + total_pages = max(1, math.ceil(len(documents) / items_per_page)) + if page > total_pages: + abort(404, "Page does not exist") + + start = (page - 1) * items_per_page + end = start + items_per_page + return page, total_pages, documents[start:end] + + def extend_cre_with_tag_links( cre: defs.CRE, collection: db.Node_collection ) -> defs.CRE: @@ -250,19 +280,20 @@ def find_document_by_tag() -> Any: opt_format = request.args.get("format") documents = database.get_by_tags(tags) if documents: - res = [doc.todict() for doc in documents] - result = {"data": res} + page, total_pages, paged_documents = _paginate_documents(documents) + res = [doc.todict() for doc in paged_documents] + result = {"data": res, "page": page, "total_pages": total_pages} # if opt_osib: # result["osib"] = odefs.cre2osib(documents).todict() if opt_format == SupportedFormats.Markdown.value: - return f"
{mdutils.cre_to_md(documents)}"
+ return f"{mdutils.cre_to_md(paged_documents)}"
elif opt_format == SupportedFormats.CSV.value:
docs = sheet_utils.ExportSheet().prepare_spreadsheet(
- docs=documents, storage=database
+ docs=paged_documents, storage=database
)
return write_csv(docs=docs).getvalue().encode("utf-8")
elif opt_format == SupportedFormats.OSCAL.value:
- return jsonify(json.loads(oscal_utils.list_to_oscal(documents)))
+ return jsonify(json.loads(oscal_utils.list_to_oscal(paged_documents)))
return jsonify(result)
abort(404, "Tag does not exist")
@@ -414,17 +445,18 @@ def text_search() -> Any:
opt_format = request.args.get("format")
documents = database.text_search(text)
if documents:
+ _, _, paged_documents = _paginate_documents(documents)
if opt_format == SupportedFormats.Markdown.value:
- return f"{mdutils.cre_to_md(documents)}"
+ return f"{mdutils.cre_to_md(paged_documents)}"
elif opt_format == SupportedFormats.CSV.value:
docs = sheet_utils.ExportSheet().prepare_spreadsheet(
- docs=documents, storage=database
+ docs=paged_documents, storage=database
)
return write_csv(docs=docs).getvalue().encode("utf-8")
elif opt_format == SupportedFormats.OSCAL.value:
- return jsonify(json.loads(oscal_utils.list_to_oscal(documents)))
+ return jsonify(json.loads(oscal_utils.list_to_oscal(paged_documents)))
- res = [doc.todict() for doc in documents]
+ res = [doc.todict() for doc in paged_documents]
return jsonify(res)
else:
abort(404, "No object matches the given search terms")
diff --git a/docs/api/openapi.yaml b/docs/api/openapi.yaml
index 73ff48c83..367e4c28e 100644
--- a/docs/api/openapi.yaml
+++ b/docs/api/openapi.yaml
@@ -65,6 +65,20 @@ paths:
schema:
type: string
enum: [json, md, csv, oscal]
+ - name: page
+ in: query
+ required: false
+ description: 1-based page number for paginated results.
+ schema:
+ type: integer
+ minimum: 1
+ - name: items_per_page
+ in: query
+ required: false
+ description: Number of items per page.
+ schema:
+ type: integer
+ minimum: 1
responses:
'200':
description: CRE found
@@ -125,6 +139,20 @@ paths:
schema:
type: string
enum: [json, md, csv, oscal]
+ - name: page
+ in: query
+ required: false
+ description: 1-based page number for paginated results.
+ schema:
+ type: integer
+ minimum: 1
+ - name: items_per_page
+ in: query
+ required: false
+ description: Number of items per page.
+ schema:
+ type: integer
+ minimum: 1
responses:
'200':
description: Nodes retrieved