Skip to content

Commit

Permalink
Recalculate bit when a config object changes (#3206)
Browse files Browse the repository at this point in the history
Co-authored-by: stephanie0x00 <[email protected]>
  • Loading branch information
originalsouth and stephanie0x00 authored Jul 10, 2024
1 parent 09b1362 commit 1b518cf
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
21 changes: 14 additions & 7 deletions octopoes/octopoes/core/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
)
from octopoes.models.exception import ObjectNotFoundException
from octopoes.models.explanation import InheritanceSection
from octopoes.models.ooi.config import Config
from octopoes.models.origin import Origin, OriginParameter, OriginType
from octopoes.models.pagination import Paginated
from octopoes.models.path import (
Expand Down Expand Up @@ -400,14 +401,20 @@ def _on_update_ooi(self, event: OOIDBEvent) -> None:
if event.new_data is None:
raise ValueError("Update event new_data should not be None")

inference_origins = self.origin_repository.list_origins(event.valid_time, source=event.new_data.reference)
inference_params = self.origin_parameter_repository.list_by_reference(
event.new_data.reference, valid_time=event.valid_time
)
for inference_param in inference_params:
inference_origins.append(self.origin_repository.get(inference_param.origin_id, event.valid_time))
if isinstance(event.new_data, Config):
relevant_bit_ids = [
bit.id for bit in get_bit_definitions().values() if bit.config_ooi_relation_path is not None
]
inference_origins = self.origin_repository.list_origins(event.valid_time, method=relevant_bit_ids)
else:
inference_origins = self.origin_repository.list_origins(event.valid_time, source=event.new_data.reference)
inference_params = self.origin_parameter_repository.list_by_reference(
event.new_data.reference, valid_time=event.valid_time
)
for inference_param in inference_params:
inference_origins.append(self.origin_repository.get(inference_param.origin_id, event.valid_time))

inference_origins = [o for o in inference_origins if o.origin_type == OriginType.INFERENCE]
inference_origins = [o for o in inference_origins if o.origin_type == OriginType.INFERENCE]
for inference_origin in inference_origins:
self._run_inference(inference_origin, event.valid_time)

Expand Down
7 changes: 6 additions & 1 deletion octopoes/octopoes/repositories/origin_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def list_origins(
task_id: UUID | None = None,
source: Reference | None = None,
result: Reference | None = None,
method: str | list[str] | None = None,
origin_type: OriginType | None = None,
) -> list[Origin]:
raise NotImplementedError
Expand Down Expand Up @@ -73,9 +74,10 @@ def list_origins(
task_id: UUID | None = None,
source: Reference | None = None,
result: Reference | None = None,
method: str | list[str] | None = None,
origin_type: OriginType | None = None,
) -> list[Origin]:
where_parameters = {"type": Origin.__name__}
where_parameters: dict[str, str | list[str]] = {"type": Origin.__name__}

if task_id:
where_parameters["task_id"] = str(task_id)
Expand All @@ -86,6 +88,9 @@ def list_origins(
if result:
where_parameters["result"] = str(result)

if method:
where_parameters["method"] = method

if origin_type:
where_parameters["origin_type"] = origin_type.value

Expand Down

0 comments on commit 1b518cf

Please sign in to comment.