diff --git a/bytes/bytes/api/root.py b/bytes/bytes/api/root.py index f51f6ead207..9ffb21d2886 100644 --- a/bytes/bytes/api/root.py +++ b/bytes/bytes/api/root.py @@ -43,12 +43,12 @@ def validation_exception_handler(_: Request, exc: RequestValidationError | Valid @router.get("/", include_in_schema=False) -def health() -> RedirectResponse: +def root() -> RedirectResponse: return RedirectResponse(url="/health") @router.get("/health", response_model=ServiceHealth) -def root() -> ServiceHealth: +def health() -> ServiceHealth: bytes_health = ServiceHealth(service="bytes", healthy=True, version=__version__) return bytes_health diff --git a/octopoes/octopoes/api/router.py b/octopoes/octopoes/api/router.py index 4c31968d618..d1c374abab0 100644 --- a/octopoes/octopoes/api/router.py +++ b/octopoes/octopoes/api/router.py @@ -32,6 +32,7 @@ from octopoes.version import __version__ from octopoes.xtdb.client import XTDBSession from octopoes.xtdb.exceptions import XTDBException +from octopoes.xtdb.query import A from octopoes.xtdb.query import Query as XTDBQuery logger = getLogger(__name__) @@ -140,6 +141,51 @@ def query( return octopoes.ooi_repository.query(xtdb_query, valid_time) +@router.get("/query-many", tags=["Objects"]) +def query_many( + path: str, + sources: list[Reference] = Query(), + octopoes: OctopoesService = Depends(octopoes_service), + valid_time: datetime = Depends(extract_valid_time), +): + """ + How does this work and why do we do this? + + We want to fetch all results but be able to tie these back to the source that was used for a result. + If we query "Network.hostname" for a list of Networks ids, how do we know which hostname lives on which network? + The answer is to add the network id to the "select" statement, so the result is of the form + + [(network_id_1, hostname1), (network_id_2, hostname3), ...] + + Because you can only select variables in Datalog, "network_id_1" needs to be an Alias. Hence `source_alias`. + We need to tie that to the Network primary_key and add a where-in clause. The example projected on the code: + + q = XTDBQuery.from_path(object_path) # Adds "where ?Hostname.network = ?Network + + q.find(source_alias).pull(query.result_type) # "select ?network_id, ?Hostname + .where(object_path.segments[0].source_type, primary_key=source_alias) # where ?Network.primary_key = ?network_id + .where_in(object_path.segments[0].source_type, primary_key=sources) # and ?Network.primary_key in ["1", ...]" + """ + + if not sources: + return [] + + object_path = ObjectPath.parse(path) + if not object_path.segments: + raise HTTPException(status_code=400, detail="No path components provided.") + + q = XTDBQuery.from_path(object_path) + source_alias = A(object_path.segments[0].source_type, field="primary_key") + + return octopoes.ooi_repository.query( + q.find(source_alias) + .pull(q.result_type) + .where(object_path.segments[0].source_type, primary_key=source_alias) + .where_in(object_path.segments[0].source_type, primary_key=sources), + valid_time, + ) + + @router.post("/objects/load_bulk", tags=["Objects"]) def load_objects_bulk( octopoes: OctopoesService = Depends(octopoes_service), diff --git a/octopoes/octopoes/connector/octopoes.py b/octopoes/octopoes/connector/octopoes.py index cbd79d2cc60..8f442f322a2 100644 --- a/octopoes/octopoes/connector/octopoes.py +++ b/octopoes/octopoes/connector/octopoes.py @@ -272,13 +272,13 @@ def query( self, path: str, valid_time: datetime, - source: Reference | str | None = None, + source: OOI | Reference | str | None = None, offset: int = DEFAULT_OFFSET, limit: int = DEFAULT_LIMIT, ) -> list[OOI]: params = { "path": path, - "source": source, + "source": source.reference if isinstance(source, OOI) else source, "valid_time": str(valid_time), "offset": offset, "limit": limit, @@ -287,3 +287,22 @@ def query( TypeAdapter(OOIType).validate_python(ooi) for ooi in self.session.get(f"/{self.client}/query", params=params).json() ] + + def query_many( + self, + path: str, + valid_time: datetime, + sources: list[OOI | Reference | str], + ) -> list[tuple[str, OOIType]]: + if not sources: + return [] + + params = { + "path": path, + "sources": [str(ooi) for ooi in sources], + "valid_time": str(valid_time), + } + + result = self.session.get(f"/{self.client}/query-many", params=params).json() + + return TypeAdapter(list[tuple[str, OOIType]]).validate_python(result) diff --git a/octopoes/octopoes/repositories/ooi_repository.py b/octopoes/octopoes/repositories/ooi_repository.py index b70caa77da5..c2f55fb4216 100644 --- a/octopoes/octopoes/repositories/ooi_repository.py +++ b/octopoes/octopoes/repositories/ooi_repository.py @@ -730,12 +730,28 @@ def list_findings( }} """ - res = self.session.client.query(finding_query, valid_time) - findings = [self.deserialize(x[0]) for x in res] return Paginated( count=count, - items=findings, + items=[x[0] for x in self.query(finding_query, valid_time)], ) - def query(self, query: Query, valid_time: datetime) -> list[OOI]: - return [self.deserialize(row[0]) for row in self.session.client.query(query, valid_time=valid_time)] + def query(self, query: str | Query, valid_time: datetime) -> list[OOI | tuple]: + results = self.session.client.query(query, valid_time=valid_time) + + parsed_results = [] + for result in results: + parsed_result = [] + + for item in result: + try: + parsed_result.append(self.deserialize(item)) + except (ValueError, TypeError): + parsed_result.append(item) + + if len(parsed_result) == 1: + parsed_results.append(parsed_result[0]) + continue + + parsed_results.append(tuple(parsed_result)) + + return parsed_results diff --git a/octopoes/octopoes/xtdb/query.py b/octopoes/octopoes/xtdb/query.py index 1253b5e4410..699c82461f5 100644 --- a/octopoes/octopoes/xtdb/query.py +++ b/octopoes/octopoes/xtdb/query.py @@ -3,7 +3,7 @@ from octopoes.models import OOI from octopoes.models.path import Direction, Path -from octopoes.models.types import get_abstract_types, get_relations, to_concrete +from octopoes.models.types import get_abstract_types, to_concrete class InvalidField(ValueError): @@ -41,6 +41,10 @@ class Aliased: # https://stackoverflow.com/questions/61257658/python-dataclasses-mocking-the-default-factory-in-a-frozen-dataclass alias: UUID = field(default_factory=lambda: uuid4()) + # Sometimes an Alias refers to a plain field, not a whole model. The current solution is suboptimal + # as you can use aliases freely in Datalog but are now tied to the OOI types too much. TODO! + field: str | None = field(default=None) + Ref = type[OOI] | Aliased A = Aliased @@ -78,6 +82,14 @@ def where(self, ooi_type: Ref, **kwargs) -> "Query": return self + def where_in(self, ooi_type: Ref, **kwargs: list[str]) -> "Query": + """Allows for filtering on multiple values for a specific field.""" + + for field_name, values in kwargs.items(): + self._where_field_in(ooi_type, field_name, values) + + return self + def format(self) -> str: return self._compile(separator="\n ") @@ -119,10 +131,19 @@ def from_path(cls, path: Path) -> "Query": return query def pull(self, ooi_type: Ref) -> "Query": + """By default, we pull the target type. But when using find, count, etc., you have to pull explicitly.""" + self._find_clauses.append(f"(pull {self._get_object_alias(ooi_type)} [*])") return self + def find(self, item: Ref) -> "Query": + """Add a find clause, so we can select specific fields in a query to be returned as well.""" + + self._find_clauses.append(self._get_object_alias(item)) + + return self + def count(self, ooi_type: Ref) -> "Query": self._find_clauses.append(f"(count {self._get_object_alias(ooi_type)})") @@ -139,6 +160,23 @@ def offset(self, offset: int) -> "Query": return self def _where_field_is(self, ref: Ref, field_name: str, value: Ref | str | set[str]) -> None: + """ + We need isinstance(value, type) checks to verify value is an OOIType, as issubclass() fails on non-classes: + + >>> value = Network + >>> isinstance(value, OOIType) + False + >>> isinstance(value, OOI) + False + >>> isinstance(value, type) + True + >>> issubclass(value, OOI) + True + >>> issubclass(3, OOI) + Traceback (most recent call last): + [...] + TypeError: issubclass() arg 1 must be a class + """ ooi_type = ref.type if isinstance(ref, Aliased) else ref if field_name not in ooi_type.model_fields: @@ -146,17 +184,19 @@ def _where_field_is(self, ref: Ref, field_name: str, value: Ref | str | set[str] abstract_types = get_abstract_types() + if isinstance(value, str): + value = value.replace('"', r"\"") + if ooi_type in abstract_types: if isinstance(value, str): - value = value.replace('"', r"\"") - self._add_or_statement(ref, field_name, f'"{value}"') + self._add_or_statement_for_abstract_types(ref, field_name, f'"{value}"') return - if not isinstance(value, type): + if not isinstance(value, type | Aliased): raise InvalidField(f"value '{value}' for abstract class fields should be a string or an OOI Type") - if issubclass(value, OOI): - self._add_or_statement( + if isinstance(value, Aliased) or issubclass(value, OOI): + self._add_or_statement_for_abstract_types( ref, field_name, self._get_object_alias( @@ -166,7 +206,6 @@ def _where_field_is(self, ref: Ref, field_name: str, value: Ref | str | set[str] return if isinstance(value, str): - value = value.replace('"', r"\"") self._add_where_statement(ref, field_name, f'"{value}"') return @@ -176,11 +215,26 @@ def _where_field_is(self, ref: Ref, field_name: str, value: Ref | str | set[str] if not isinstance(value, Aliased) and not issubclass(value, OOI): raise InvalidField(f"{value} is not an OOI") - if field_name not in get_relations(ooi_type): - raise InvalidField(f'"{field_name}" is not a relation of {ooi_type.get_object_type()}') - self._add_where_statement(ref, field_name, self._get_object_alias(value)) + def _where_field_in(self, ref: Ref, field_name: str, values: list[str]) -> None: + ooi_type = ref.type if isinstance(ref, Aliased) else ref + + if field_name not in ooi_type.model_fields: + raise InvalidField(f'"{field_name}" is not a field of {ooi_type.get_object_type()}') + + new_values = [] + for value in values: + if not isinstance(value, str): + raise InvalidField("Only strings allowed as values for a WHERE IN statement for now.") + + value = value.replace('"', r"\"") + new_values.append(f'"{value}"') + + self._where_clauses.append( + self._or_statement_for_multiple_values(self._get_object_alias(ref), ooi_type, field_name, new_values) + ) + def _add_where_statement(self, ref: Ref, field_name: str, to_alias: str) -> None: ooi_type = ref.type if isinstance(ref, Aliased) else ref @@ -194,12 +248,12 @@ def _add_where_statement(self, ref: Ref, field_name: str, to_alias: str) -> None ) ) - def _add_or_statement(self, ref: Ref, field_name: str, to_alias: str) -> None: + def _add_or_statement_for_abstract_types(self, ref: Ref, field_name: str, to_alias: str) -> None: ooi_type = ref.type if isinstance(ref, Aliased) else ref self._where_clauses.append(self._assert_type(ref, ooi_type)) self._where_clauses.append( - self._or_statement( + self._or_statement_for_abstract_types( self._get_object_alias(ref), ooi_type.strict_subclasses(), field_name, @@ -207,7 +261,9 @@ def _add_or_statement(self, ref: Ref, field_name: str, to_alias: str) -> None: ) ) - def _or_statement(self, from_alias: str, concrete_types: list[type[OOI]], field_name: str, to_alias: str) -> str: + def _or_statement_for_abstract_types( + self, from_alias: str, concrete_types: list[type[OOI]], field_name: str, to_alias: str + ) -> str: relationships = [ self._relationship(from_alias, concrete_type.get_object_type(), field_name, to_alias) for concrete_type in concrete_types @@ -215,6 +271,15 @@ def _or_statement(self, from_alias: str, concrete_types: list[type[OOI]], field_ return f"(or {' '.join(relationships)} )" + def _or_statement_for_multiple_values( + self, from_alias: str, ooi_type: type[OOI], field_name: str, to_aliases: list[str] + ) -> str: + relationships = [ + self._relationship(from_alias, ooi_type.get_object_type(), field_name, to_alias) for to_alias in to_aliases + ] + + return f"(or {' '.join(relationships)} )" + def _relationship(self, from_alias: str, field_type: str, field_name: str, to_alias: str) -> str: return f"[ {from_alias} :{field_type}/{field_name} {to_alias} ]" @@ -258,7 +323,10 @@ def _compile(self, *, separator=" ") -> str: def _get_object_alias(self, object_type: Ref) -> str: if isinstance(object_type, Aliased): - return "?" + str(object_type.alias) + base = "?" + str(object_type.alias) + + # To have at least a way to separate aliases for types and plain fields in the raw query + return base if not object_type.field else base + "?" + object_type.field return object_type.get_object_type() diff --git a/octopoes/tests/integration/test_api_connector.py b/octopoes/tests/integration/test_api_connector.py index d70b10b6af4..e73eef80880 100644 --- a/octopoes/tests/integration/test_api_connector.py +++ b/octopoes/tests/integration/test_api_connector.py @@ -201,8 +201,23 @@ def test_query(octopoes_api_connector: OctopoesAPIConnector, valid_time: datetim assert len(results) == 1 assert str(results[0].port) == "443" - results = octopoes_api_connector.query(query, valid_time, source=hostnames[0].reference) + results = octopoes_api_connector.query(query, valid_time, source=hostnames[0]) assert len(results) == 0 - results = octopoes_api_connector.query(query, valid_time, source=hostnames[1].reference) + results = octopoes_api_connector.query(query, valid_time, source=hostnames[1]) assert len(results) == 1 + + query = "Hostname. is not an OOI" - with pytest.raises(InvalidField) as ctx: - Query(Network).where(Network, name=Network) - - assert ctx.exconly() == 'octopoes.xtdb.query.InvalidField: "name" is not a relation of Network' - def test_allow_string_for_foreign_keys(): query = Query(Network).where(Finding, ooi="Network|internet") @@ -200,6 +195,16 @@ def test_create_query_from_path_abstract(): assert query.format() == expected_query +def test_value_for_abstract_class_check(): + Query(IPAddress).where(IPAddress, network=Network).where(Network, name="test") + Query(IPAddress).where(IPAddress, network=A(Network)).where(Network, name="test") + + with pytest.raises(InvalidField) as ctx: + Query(IPAddress).where(IPAddress, network=3).where(Network, name="test") + + assert "value '3' for abstract class fields should be a string or an OOI Type" in ctx.exconly() + + def test_aliased_query(): h1 = A(Hostname, UUID("4b4afa7e-5b76-4506-a373-069216b051c2")) h2 = A(Hostname, UUID("98076f7a-7606-47ac-85b7-b511ee21ae42")) @@ -310,3 +315,32 @@ def test_build_system_query_with_path_segments(mocker): assert str(query) == str(path_query) assert query == path_query + + +def test_build_parth_query_with_multiple_sources(mocker): + mocker.patch("octopoes.xtdb.query.uuid4", return_value=UUID("311d6399-4bb4-4830-b077-661cc3f4f2c1")) + + query = Query(Website).where_in(Website, primary_key=["test_pk", "second_test_pk"]) + assert ( + query.format() + == """{:query {:find [(pull Website [*])] :where [ + (or [ Website :Website/primary_key "test_pk" ] [ Website :Website/primary_key "second_test_pk" ] ) + [ Website :object_type "Website" ]]}}""" + ) + + pk = A(Website, field="primary_key") + query = ( + Query(Website) + .find(pk) + .pull(Website) + .where(Website, primary_key=pk) + .where_in(Website, primary_key=["test_pk", "second_test_pk"]) + ) + + assert ( + query.format() + == """{:query {:find [?311d6399-4bb4-4830-b077-661cc3f4f2c1?primary_key (pull Website [*])] :where [ + (or [ Website :Website/primary_key "test_pk" ] [ Website :Website/primary_key "second_test_pk" ] ) + [ Website :Website/primary_key ?311d6399-4bb4-4830-b077-661cc3f4f2c1?primary_key ] + [ Website :object_type "Website" ]]}}""" + ) diff --git a/rocky/reports/report_types/aggregate_organisation_report/report.py b/rocky/reports/report_types/aggregate_organisation_report/report.py index 27febcb485e..81922c37d65 100644 --- a/rocky/reports/report_types/aggregate_organisation_report/report.py +++ b/rocky/reports/report_types/aggregate_organisation_report/report.py @@ -5,7 +5,7 @@ from django.utils.translation import gettext_lazy as _ from octopoes.connector.octopoes import OctopoesAPIConnector -from octopoes.models import OOI, Reference +from octopoes.models import OOI from octopoes.models.exception import ObjectNotFoundException from octopoes.models.ooi.config import Config from reports.report_types.definitions import AggregateReport @@ -444,27 +444,39 @@ def aggregate_reports( selected_report_types: list[str], valid_time: datetime, ) -> tuple[AggregateOrganisationReport, dict[str, Any], dict[str, Any], list[str]]: - aggregate_report = AggregateOrganisationReport(connector) - report_data: dict[str, Any] = {} - error_oois = [] + by_type: dict[str, list[str]] = {} for ooi in input_ooi_references: - report_data[ooi.primary_key] = {} + if ooi.get_object_type() not in by_type: + by_type[ooi.get_object_type()] = [] + + by_type[ooi.get_object_type()].append(str(ooi.reference)) + + all_types = [ + t + for t in AggregateOrganisationReport.reports["required"] + AggregateOrganisationReport.reports["optional"] + if t.id in selected_report_types + ] + report_data: dict[str, Any] = {} + errors = [] + + for report_type in all_types: + oois = {x for ooi_type in report_type.input_ooi_types for x in by_type.get(ooi_type.get_object_type(), [])} + try: - for options, report_types in aggregate_report.reports.items(): - # Mypy doesn't support TypedDict and .values() - # https://github.com/python/mypy/issues/7981 - for report_type in report_types: # type: ignore[attr-defined] - if ( - Reference.from_str(ooi).class_type in report_type.input_ooi_types - and report_type.id in selected_report_types - ): - report = report_type(connector) - data = report.generate_data(ooi.primary_key, valid_time=valid_time) - report_data[ooi.primary_key][report_type.id] = data + results = report_type(connector).collect_data(oois, valid_time) except ObjectNotFoundException: - logger.error("Object not found: %s", ooi.primary_key) - error_oois.append(ooi.primary_key) + logger.error("Object not found") + errors.append(report_type.id) + continue + + for ooi, data in results.items(): + if ooi not in report_data: + report_data[ooi] = {} + + report_data[ooi][report_type.id] = data + + aggregate_report = AggregateOrganisationReport(connector) post_processed_data = aggregate_report.post_process_data(report_data, valid_time=valid_time) - return aggregate_report, post_processed_data, report_data, error_oois + return aggregate_report, post_processed_data, report_data, errors diff --git a/rocky/reports/report_types/definitions.py b/rocky/reports/report_types/definitions.py index 4d36332f775..39d66801f66 100644 --- a/rocky/reports/report_types/definitions.py +++ b/rocky/reports/report_types/definitions.py @@ -1,9 +1,13 @@ +from collections.abc import Callable, Iterable from datetime import datetime from logging import getLogger from pathlib import Path -from typing import Any, TypedDict +from typing import Any, TypedDict, TypeVar from octopoes.connector.octopoes import OctopoesAPIConnector +from octopoes.models import Reference +from octopoes.models.ooi.dns.zone import Hostname +from octopoes.models.ooi.network import IPAddressV4, IPAddressV6 from octopoes.models.types import OOIType REPORTS_DIR = Path(__file__).parent @@ -25,6 +29,9 @@ def __init__(self, octopoes_api_connector: OctopoesAPIConnector): self.octopoes_api_connector = octopoes_api_connector +BaseReportType = TypeVar("BaseReportType", bound="BaseReport") + + class Report(BaseReport): plugins: ReportPlugins input_ooi_types: set[OOIType] @@ -32,6 +39,11 @@ class Report(BaseReport): def generate_data(self, input_ooi: str, valid_time: datetime) -> dict[str, Any]: raise NotImplementedError + def collect_data(self, input_oois: Iterable[str], valid_time: datetime) -> dict[str, dict[str, Any]]: + """Generate data for multiple OOIs. Child classes can override this method to improve performance.""" + + return {input_ooi: self.generate_data(input_ooi, valid_time) for input_ooi in input_oois} + @classmethod def class_attributes(cls) -> dict[str, Any]: return { @@ -43,6 +55,52 @@ def class_attributes(cls) -> dict[str, Any]: "template_path": cls.template_path, } + @staticmethod + def group_by_source( + query_result: list[tuple[str, OOIType]], + check: Callable[[OOIType], bool] | None = None, + ) -> dict[str, list[OOIType]]: + """Transform a query-many result from [(ref1, obj1), (ref1, obj2), ...] into {ref1: [obj1, obj2], ...}""" + + result: dict[str, list[OOIType]] = {} + + for source, ooi in query_result: + if source not in result: + result[source] = [] + + if not check or check(ooi): + result[source].append(ooi) + + return result + + @staticmethod + def group_finding_types_by_source( + query_result: list[tuple[str, OOIType]], + keep_ids: list[str] | None = None, + ) -> dict[str, list[OOIType]]: + if keep_ids: + return Report.group_by_source(query_result, lambda x: x.id in keep_ids) + + return Report.group_by_source(query_result) + + def to_hostnames(self, input_oois: Iterable[str], valid_time: datetime) -> dict[str, list[Reference]]: + """Turn a list of either Hostname and IPAddress reference strings into a list of related hostnames.""" + + refs = [Reference.from_str(input_ooi) for input_ooi in input_oois] + + hostnames_by_input_ooi = {str(ref): [ref] for ref in refs if ref.class_type == Hostname} + ip_refs = [ref for ref in refs if ref.class_type in (IPAddressV4, IPAddressV6)] + + for input_ooi, ip_hostname in self.octopoes_api_connector.query_many( + "IPAddress. dict[str, Any]: class AggregateReportSubReports(TypedDict): - required: list[type[Report] | type[MultiReport]] - optional: list[type[Report] | type[MultiReport]] + required: list[type[Report]] + optional: list[type[Report]] class AggregateReport(BaseReport): diff --git a/rocky/reports/report_types/helpers.py b/rocky/reports/report_types/helpers.py index 6d77bb151a2..890ab09b307 100644 --- a/rocky/reports/report_types/helpers.py +++ b/rocky/reports/report_types/helpers.py @@ -49,7 +49,7 @@ def get_report_types_for_ooi(ooi_pk: str) -> list[type[Report]]: return [report for report in REPORTS if ooi_type in report.input_ooi_types] -def get_report_types_for_oois(ooi_pks: list[str]) -> set[type[Report] | type[MultiReport]]: +def get_report_types_for_oois(ooi_pks: list[str]) -> set[type[Report]]: """ Get all report types that can be generated for a given list of OOIs """ @@ -86,7 +86,7 @@ def get_plugins_for_report_ids(reports: list[str]) -> dict[str, set[str]]: def get_report_types_from_aggregate_report( aggregate_report: type[AggregateReport], -) -> dict[str, set[type[Report] | type[MultiReport]]]: +) -> dict[str, set[type[Report]]]: required_reports = set() optional_reports = set() diff --git a/rocky/reports/report_types/ipv6_report/report.py b/rocky/reports/report_types/ipv6_report/report.py index 0ceedb1d1cd..a87683408b5 100644 --- a/rocky/reports/report_types/ipv6_report/report.py +++ b/rocky/reports/report_types/ipv6_report/report.py @@ -1,3 +1,4 @@ +from collections.abc import Iterable from dataclasses import dataclass from datetime import datetime from logging import getLogger @@ -5,11 +6,8 @@ from django.utils.translation import gettext_lazy as _ -from octopoes.models import Reference -from octopoes.models.exception import ObjectNotFoundException from octopoes.models.ooi.dns.zone import Hostname from octopoes.models.ooi.network import IPAddressV4, IPAddressV6 -from octopoes.models.path import Path from reports.report_types.definitions import Report logger = getLogger(__name__) @@ -29,33 +27,24 @@ class IPv6Report(Report): input_ooi_types = {Hostname, IPAddressV4, IPAddressV6} template_path = "ipv6_report/report.html" - def generate_data(self, input_ooi: str, valid_time: datetime) -> dict[str, Any]: + def collect_data(self, input_oois: Iterable[str], valid_time: datetime) -> dict[str, dict[str, Any]]: """ For hostnames, check whether they point to ipv6 addresses. For ip addresses, check all hostnames that point to them, and check whether they point to ipv6 addresses. """ - try: - ooi = self.octopoes_api_connector.get(Reference.from_str(input_ooi), valid_time) - except ObjectNotFoundException as e: - logger.error("No data found for OOI '%s' on date %s.", str(e), str(valid_time)) - raise ObjectNotFoundException(e) - - if ooi.reference.class_type == IPAddressV4 or ooi.reference.class_type == IPAddressV6: - path = Path.parse("IPAddress.
dict[str, Any]: - hostnames = [] - finding_types = {} - - try: - ooi = self.octopoes_api_connector.get(Reference.from_str(input_ooi), valid_time) - except ObjectNotFoundException as e: - logger.error("No data found for OOI '%s' on date %s.", str(e), str(valid_time)) - raise ObjectNotFoundException(e) - - if ooi.reference.class_type == Hostname: - hostnames = [ooi] - elif ooi.reference.class_type in (IPAddressV4, IPAddressV6): - hostnames = self.octopoes_api_connector.query( - "IPAddress. dict[str, dict[str, Any]]: + hostnames_by_input_ooi = self.to_hostnames(input_oois, valid_time) + all_hostnames = [h for key, hostnames in hostnames_by_input_ooi.items() for h in hostnames] + + filtered_finding_types = self.group_finding_types_by_source( + self.octopoes_api_connector.query_many("Hostname. list[OOI]: finding_types = self.octopoes_api_connector.query( "Hostname. dict[str, Any]: - hostnames = [] - - try: - ooi = self.octopoes_api_connector.get(Reference.from_str(input_ooi), valid_time) - except ObjectNotFoundException as e: - logger.error("No data found for OOI '%s' on date %s.", str(e), str(valid_time)) - raise ObjectNotFoundException(e) - - if ooi.reference.class_type == Hostname: - hostnames = [ooi] - - elif ooi.reference.class_type in (IPAddressV4, IPAddressV6): - hostnames = self.octopoes_api_connector.query( - "IPAddress. dict[str, dict[str, Any]]: + hostnames_by_input_ooi = self.to_hostnames(input_oois, valid_time) + all_hostnames = [h for key, hostnames in hostnames_by_input_ooi.items() for h in hostnames] + + query = "Hostname. dict[str, Any]: - hostnames = [] - - try: - ooi = self.octopoes_api_connector.get(Reference.from_str(input_ooi), valid_time) - except ObjectNotFoundException as e: - logger.error("No data found for OOI '%s' on date %s.", str(e), str(valid_time)) - raise ObjectNotFoundException(e) - - if ooi.reference.class_type == Hostname: - hostnames = [ooi] - - elif ooi.reference.class_type in (IPAddressV4, IPAddressV6): - hostnames = self.octopoes_api_connector.query( - "IPAddress. dict[str, dict[str, Any]]: + hostnames_by_input_ooi = self.to_hostnames(input_oois, valid_time) + all_hostnames = [h for key, hostnames in hostnames_by_input_ooi.items() for h in hostnames] + + query = "Hostname. tuple[AggregateOrganisationReport, Any, dict[Any, dict[Any, Any]]]: - aggregate_report, post_processed_data, report_data, error_oois = aggregate_reports( + aggregate_report, post_processed_data, report_data, report_errors = aggregate_reports( self.octopoes_api_connector, self.get_oois(), self.selected_report_types, self.observed_at ) # If OOI could not be found or the date is incorrect, it will be shown to the user as a message error - if error_oois: - oois = ", ".join(set(error_oois)) + if report_errors: + report_types = ", ".join(set(report_errors)) date = self.observed_at.date() - error_message = _("No data could be found for %(oois)s. Object(s) did not exist on %(date)s.") % { - "oois": oois, + error_message = _("No data could be found for %(report_types). Object(s) did not exist on %(date)s.") % { + "report_types": report_types, "date": date, } messages.add_message(self.request, messages.ERROR, error_message) diff --git a/rocky/reports/views/base.py b/rocky/reports/views/base.py index 973cb0ce1ce..583df7278a5 100644 --- a/rocky/reports/views/base.py +++ b/rocky/reports/views/base.py @@ -18,7 +18,7 @@ from octopoes.models import OOI from octopoes.models.types import OOIType from reports.forms import OOITypeMultiCheckboxForReportForm -from reports.report_types.definitions import MultiReport, Report, ReportType +from reports.report_types.definitions import BaseReportType, MultiReport, Report, ReportType from reports.report_types.helpers import get_plugins_for_report_ids, get_report_by_id from rocky.views.mixins import OOIList from rocky.views.ooi_view import OOIFilterView @@ -114,16 +114,14 @@ def get_ooi_filter_forms(self, ooi_types: set[OOIType]) -> dict[str, Form]: ) } - def get_report_types_for_generate_report( - self, reports: set[type[Report] | type[MultiReport]] - ) -> list[dict[str, str]]: + def get_report_types_for_generate_report(self, reports: set[type[BaseReportType]]) -> list[dict[str, str]]: return [ {"id": report_type.id, "name": report_type.name, "description": report_type.description} for report_type in reports ] def get_report_types_for_aggregate_report( - self, reports_dict: dict[str, set[type[Report] | type[MultiReport]]] + self, reports_dict: dict[str, set[type[Report]]] ) -> dict[str, list[dict[str, str]]]: report_types = {} for option, reports in reports_dict.items(): diff --git a/rocky/reports/views/generate_report.py b/rocky/reports/views/generate_report.py index 6079c6c5acc..b906b41cbe7 100644 --- a/rocky/reports/views/generate_report.py +++ b/rocky/reports/views/generate_report.py @@ -146,30 +146,48 @@ def get(self, request, *args, **kwargs): return super().get(request, *args, **kwargs) def generate_reports_for_oois(self) -> dict[str, dict[str, dict[str, Any]]]: + error_reports = [] report_data: dict[str, dict[str, dict[str, Any]]] = {} - error_oois = [] + by_type: dict[str, list[str]] = {} + for ooi in self.selected_oois: - report_data[ooi] = {} + ooi_type = Reference.from_str(ooi).class_ + + if ooi_type not in by_type: + by_type[ooi_type] = [] + + by_type[ooi_type].append(ooi) + + for report_type in self.report_types: + oois = { + ooi for ooi_type in report_type.input_ooi_types for ooi in by_type.get(ooi_type.get_object_type(), []) + } + try: - for report_type in self.report_types: - if Reference.from_str(ooi).class_type in report_type.input_ooi_types: - report = report_type(self.octopoes_api_connector) - data = report.generate_data(ooi, valid_time=self.observed_at) - template = report.template_path - report_data[ooi][report_type.name] = {"data": data, "template": template} + results = report_type(self.octopoes_api_connector).collect_data(oois, self.observed_at) except ObjectNotFoundException: - error_oois.append(ooi) + error_reports.append(report_type.id) + continue except StopIteration: - error_oois.append(ooi) + error_reports.append(report_type.id) + continue + + for ooi, data in results.items(): + if ooi not in report_data: + report_data[ooi] = {} + + report_data[ooi][report_type.name] = {"data": data, "template": report_type.template_path} + # If OOI could not be found or the date is incorrect, it will be shown to the user as a message error - if error_oois: - oois = ", ".join(set(error_oois)) + if error_reports: + report_types = ", ".join(set(error_reports)) date = self.observed_at.date() - error_message = _("No data could be found for %(oois)s. Object(s) did not exist on %(date)s.") % { - "oois": oois, + error_message = _("No data could be found for %(report_types). Object(s) did not exist on %(date)s.") % { + "report_types": report_types, "date": date, } messages.add_message(self.request, messages.ERROR, error_message) + return report_data def get_context_data(self, **kwargs): diff --git a/rocky/tests/conftest.py b/rocky/tests/conftest.py index ce290619f0d..213f783b3bf 100644 --- a/rocky/tests/conftest.py +++ b/rocky/tests/conftest.py @@ -30,6 +30,7 @@ from octopoes.models.ooi.web import URL, SecurityTXT, Website from octopoes.models.origin import Origin, OriginType from octopoes.models.transaction import TransactionRecord +from octopoes.models.types import OOIType from rocky.scheduler import Task LANG_LIST = [code for code, _ in settings.LANGUAGES] @@ -909,6 +910,20 @@ def query( ) -> list[OOI]: return self.queries[path][source] + def query_many( + self, + path: str, + valid_time: datetime, + sources: list[OOI | Reference | str], + ) -> list[tuple[str, OOIType]]: + result = [] + + for source in sources: + for ooi in self.queries[path][source]: + result.append((source, ooi)) + + return result + def get_history(self, reference: Reference) -> list[TransactionRecord]: return [ TransactionRecord( diff --git a/rocky/tests/integration/conftest.py b/rocky/tests/integration/conftest.py index 47fab357cfb..1ed4f57098b 100644 --- a/rocky/tests/integration/conftest.py +++ b/rocky/tests/integration/conftest.py @@ -9,7 +9,7 @@ from octopoes.api.models import Declaration, Observation from octopoes.connector.octopoes import OctopoesAPIConnector -from octopoes.models import DeclaredScanProfile, Reference +from octopoes.models import OOI, DeclaredScanProfile, Reference from octopoes.models.ooi.certificate import X509Certificate from octopoes.models.ooi.dns.zone import Hostname, ResolvedHostname from octopoes.models.ooi.findings import CVEFindingType, KATFindingType, RetireJSFindingType, RiskLevelSeverity @@ -53,7 +53,7 @@ def seed_system( valid_time: datetime, test_hostname: str = "example.com", test_ip: str = "192.0.2.3", -): +) -> dict[str, list[OOI]]: network = Network(name="test") octopoes_api_connector.save_declaration(Declaration(ooi=network, valid_time=valid_time)) diff --git a/rocky/tests/integration/test_bench.py b/rocky/tests/integration/test_bench.py index ffdc879e6dc..1d53d55c4f5 100644 --- a/rocky/tests/integration/test_bench.py +++ b/rocky/tests/integration/test_bench.py @@ -1,6 +1,8 @@ import pytest from reports.report_types.aggregate_organisation_report.report import AggregateOrganisationReport, aggregate_reports +from octopoes.models.ooi.dns.zone import Hostname +from octopoes.models.ooi.network import Network from tests.integration.conftest import seed_system @@ -10,9 +12,14 @@ def test_aggregate_report_benchmark(octopoes_api_connector, valid_time): for x in hostname_range: seed_system(octopoes_api_connector, valid_time, test_hostname=f"{x}.com", test_ip=f"192.0.{x % 7}.{x % 256}") - reports = AggregateOrganisationReport.reports["required"] + AggregateOrganisationReport.reports["optional"] + reports = [ + x.id for x in AggregateOrganisationReport.reports["required"] + AggregateOrganisationReport.reports["optional"] + ] _, data, _, _ = aggregate_reports( - octopoes_api_connector, [f"Hostname|test|{x}.com" for x in hostname_range], reports, valid_time + octopoes_api_connector, + [Hostname(name=f"{x}.com", network=Network(name="test").reference) for x in hostname_range], + reports, + valid_time, ) - assert data + assert data["systems"] diff --git a/rocky/tests/integration/test_reports.py b/rocky/tests/integration/test_reports.py index a0f115701d2..4e9fc1ad432 100644 --- a/rocky/tests/integration/test_reports.py +++ b/rocky/tests/integration/test_reports.py @@ -19,7 +19,7 @@ def test_web_report(octopoes_api_connector: OctopoesAPIConnector, valid_time): report = WebSystemReport(octopoes_api_connector) input_ooi = "Hostname|test|example.com" - data = report.generate_data(input_ooi, valid_time) + data = report.collect_data([input_ooi], valid_time)[input_ooi] assert data["input_ooi"] == input_ooi assert len(data["finding_types"]) == 1 @@ -42,7 +42,7 @@ def test_web_report(octopoes_api_connector: OctopoesAPIConnector, valid_time): ooi=Reference.from_str("HTTPResource|test|192.0.2.3|tcp|25|smtp|test|example.com|http|test|example.com|80|/"), ) octopoes_api_connector.save_declaration(Declaration(ooi=finding, valid_time=valid_time)) - checks = report.generate_data(input_ooi, valid_time)["web_checks"].checks + checks = report.collect_data([input_ooi], valid_time)[input_ooi]["web_checks"].checks assert checks[0].has_csp is False assert checks[0].has_no_csp_vulnerabilities is False @@ -51,7 +51,7 @@ def test_web_report(octopoes_api_connector: OctopoesAPIConnector, valid_time): ooi=Reference.from_str("Website|test|192.0.2.3|tcp|25|smtp|test|example.com"), ) octopoes_api_connector.save_declaration(Declaration(ooi=finding, valid_time=valid_time)) - data = report.generate_data(input_ooi, valid_time) + data = report.collect_data([input_ooi], valid_time)[input_ooi] assert data["web_checks"].checks[0].offers_https is False assert len(data["finding_types"]) == 3 @@ -62,7 +62,7 @@ def test_system_report(octopoes_api_connector: OctopoesAPIConnector, valid_time) report = SystemReport(octopoes_api_connector) input_ooi = "Hostname|test|example.com" - data = report.generate_data(input_ooi, valid_time) + data = report.collect_data([input_ooi], valid_time)[input_ooi] assert data["input_ooi"] == input_ooi assert data["summary"] == { diff --git a/rocky/tests/reports/test_ipv6_report.py b/rocky/tests/reports/test_ipv6_report.py index a08cfd87b44..9c86df760f6 100644 --- a/rocky/tests/reports/test_ipv6_report.py +++ b/rocky/tests/reports/test_ipv6_report.py @@ -13,7 +13,7 @@ def test_ipv6_report_hostname_with_ipv6(mock_octopoes_api_connector, valid_time, report = IPv6Report(mock_octopoes_api_connector) - data = report.generate_data(str(hostname.reference), valid_time) + data = report.collect_data([str(hostname.reference)], valid_time)[str(hostname.reference)] assert data[hostname.name] == {"enabled": True} @@ -30,7 +30,7 @@ def test_ipv6_report_hostname_without_ipv6(mock_octopoes_api_connector, valid_ti report = IPv6Report(mock_octopoes_api_connector) - data = report.generate_data(str(hostname.reference), valid_time) + data = report.collect_data([str(hostname.reference)], valid_time)[str(hostname.reference)] assert data[hostname.name] == {"enabled": False} @@ -50,7 +50,7 @@ def test_ipv6_report_ipv4_without_ipv6(mock_octopoes_api_connector, valid_time, report = IPv6Report(mock_octopoes_api_connector) - data = report.generate_data(str(ipaddressv4.reference), valid_time) + data = report.collect_data([str(ipaddressv4.reference)], valid_time)[str(ipaddressv4.reference)] assert data[hostname.name] == {"enabled": False} @@ -70,7 +70,7 @@ def test_ipv6_report_ipv4_with_ipv6(mock_octopoes_api_connector, valid_time, hos report = IPv6Report(mock_octopoes_api_connector) - data = report.generate_data(str(ipaddressv4.reference), valid_time) + data = report.collect_data([str(ipaddressv4.reference)], valid_time)[str(ipaddressv4.reference)] assert data[hostname.name] == {"enabled": True} @@ -90,6 +90,6 @@ def test_ipv6_report_ipv6_wit_ipv6(mock_octopoes_api_connector, valid_time, host report = IPv6Report(mock_octopoes_api_connector) - data = report.generate_data(str(ipaddressv6.reference), valid_time) + data = report.collect_data([str(ipaddressv6.reference)], valid_time)[str(ipaddressv6.reference)] assert data[hostname.name] == {"enabled": True} diff --git a/rocky/tests/reports/test_mail_report.py b/rocky/tests/reports/test_mail_report.py index 5574aa536df..b5ec2dbca4e 100644 --- a/rocky/tests/reports/test_mail_report.py +++ b/rocky/tests/reports/test_mail_report.py @@ -13,7 +13,7 @@ def test_mail_report_no_findings(mock_octopoes_api_connector, valid_time, hostna report = MailReport(mock_octopoes_api_connector) - data = report.generate_data(str(hostname.reference), valid_time) + data = report.collect_data([str(hostname.reference)], valid_time)[str(hostname.reference)] assert len(data["finding_types"][str(hostname.reference)]) == 0 assert data["number_of_hostnames"] == 1 @@ -34,7 +34,7 @@ def test_mail_report_spf_finding(mock_octopoes_api_connector, valid_time, hostna report = MailReport(mock_octopoes_api_connector) - data = report.generate_data(str(hostname.reference), valid_time) + data = report.collect_data([str(hostname.reference)], valid_time)[str(hostname.reference)] assert len(data["finding_types"][str(hostname.reference)]) == 1 assert data["number_of_hostnames"] == 1 @@ -58,7 +58,7 @@ def test_mail_report_dkim_finding( report = MailReport(mock_octopoes_api_connector) - data = report.generate_data(str(ipaddressv4.reference), valid_time) + data = report.collect_data([str(ipaddressv4.reference)], valid_time)[str(ipaddressv4.reference)] assert len(data["finding_types"][str(hostname.reference)]) == 1 assert data["number_of_hostnames"] == 1 @@ -82,7 +82,7 @@ def test_mail_report_dmarc_finding( report = MailReport(mock_octopoes_api_connector) - data = report.generate_data(str(ipaddressv4.reference), valid_time) + data = report.collect_data([str(ipaddressv4.reference)], valid_time)[str(ipaddressv4.reference)] assert len(data["finding_types"][str(hostname.reference)]) == 1 assert data["number_of_hostnames"] == 1 @@ -110,7 +110,7 @@ def test_mail_report_multiple_findings( report = MailReport(mock_octopoes_api_connector) - data = report.generate_data(str(hostname.reference), valid_time) + data = report.collect_data([str(hostname.reference)], valid_time)[str(hostname.reference)] assert len(data["finding_types"][str(hostname.reference)]) == 3 assert data["number_of_hostnames"] == 1 diff --git a/rocky/tests/reports/test_name_server_report.py b/rocky/tests/reports/test_name_server_report.py index cd5c92db930..adb60dd9190 100644 --- a/rocky/tests/reports/test_name_server_report.py +++ b/rocky/tests/reports/test_name_server_report.py @@ -14,7 +14,7 @@ def test_name_server_report_no_hostname(mock_octopoes_api_connector, valid_time, report = NameServerSystemReport(mock_octopoes_api_connector) - data = report.generate_data(str(ipaddressv4.reference), valid_time) + data = report.collect_data([str(ipaddressv4.reference)], valid_time)[str(ipaddressv4.reference)] assert len(data["name_server_checks"].checks) == 0 assert data["name_server_checks"].has_dnssec == 0 @@ -38,7 +38,7 @@ def test_name_server_report_no_finding_types(mock_octopoes_api_connector, valid_ report = NameServerSystemReport(mock_octopoes_api_connector) - data = report.generate_data(str(hostname.reference), valid_time) + data = report.collect_data([str(hostname.reference)], valid_time)[str(hostname.reference)] assert len(data["name_server_checks"].checks) == 1 assert data["name_server_checks"].has_dnssec == 1 @@ -86,7 +86,7 @@ def test_name_server_report_multiple_finding_types( report = NameServerSystemReport(mock_octopoes_api_connector) - data = report.generate_data(str(ipaddressv4.reference), valid_time) + data = report.collect_data([str(ipaddressv4.reference)], valid_time)[str(ipaddressv4.reference)] assert len(data["name_server_checks"].checks) == 1 assert data["name_server_checks"].has_dnssec == 0 diff --git a/rocky/tests/reports/test_web_systems_report.py b/rocky/tests/reports/test_web_systems_report.py index fd7d3c986bc..aa6eea3ab6b 100644 --- a/rocky/tests/reports/test_web_systems_report.py +++ b/rocky/tests/reports/test_web_systems_report.py @@ -35,7 +35,7 @@ def test_web_report_no_findings(mock_octopoes_api_connector, valid_time, hostnam report = WebSystemReport(mock_octopoes_api_connector) - data = report.generate_data(str(hostname.reference), valid_time) + data = report.collect_data([str(hostname.reference)], valid_time)[str(hostname.reference)] assert bool(data["web_checks"]) @@ -77,7 +77,7 @@ def test_web_report_all_findings( report = WebSystemReport(mock_octopoes_api_connector) - data = report.generate_data(str(hostname.reference), valid_time) + data = report.collect_data([str(hostname.reference)], valid_time)[str(hostname.reference)] checks = data["web_checks"]