Skip to content

add node kind to node field specifiers #6359

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
May 6, 2025
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions backend/infrahub/core/branch/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,6 @@ async def merge_branch(

merger: BranchMerger | None = None
async with lock.registry.global_graph_lock():
# await update_diff(model=RequestDiffUpdate(branch_name=obj.name))

diff_repository = await component_registry.get_component(DiffRepository, db=db, branch=obj)
diff_coordinator = await component_registry.get_component(DiffCoordinator, db=db, branch=obj)
diff_merger = await component_registry.get_component(DiffMerger, db=db, branch=obj)
Expand Down
7 changes: 4 additions & 3 deletions backend/infrahub/core/diff/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from infrahub.database import InfrahubDatabase
from infrahub.log import get_logger

from .model.field_specifiers_map import NodeFieldSpecifierMap
from .model.path import CalculatedDiffs

log = get_logger()
Expand All @@ -26,8 +27,8 @@ class DiffCalculationRequest:
branch_from_time: Timestamp
from_time: Timestamp
to_time: Timestamp
current_node_field_specifiers: dict[str, set[str]] | None = field(default=None)
new_node_field_specifiers: dict[str, set[str]] | None = field(default=None)
current_node_field_specifiers: NodeFieldSpecifierMap | None = field(default=None)
new_node_field_specifiers: NodeFieldSpecifierMap | None = field(default=None)


class DiffCalculator:
Expand Down Expand Up @@ -75,7 +76,7 @@ async def calculate_diff(
from_time: Timestamp,
to_time: Timestamp,
include_unchanged: bool = True,
previous_node_specifiers: dict[str, set[str]] | None = None,
previous_node_specifiers: NodeFieldSpecifierMap | None = None,
) -> CalculatedDiffs:
if diff_branch.name == registry.default_branch:
diff_branch_from_time = from_time
Expand Down
3 changes: 1 addition & 2 deletions backend/infrahub/core/diff/combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,8 +384,7 @@ def _combine_nodes(self, node_pairs: list[NodePair]) -> set[EnrichedDiffNode]:
):
combined_nodes.add(
EnrichedDiffNode(
uuid=node_pair.later.uuid,
kind=node_pair.later.kind,
identifier=node_pair.later.identifier,
label=node_pair.later.label,
changed_at=node_pair.later.changed_at or node_pair.earlier.changed_at,
action=combined_action,
Expand Down
72 changes: 44 additions & 28 deletions backend/infrahub/core/diff/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
from infrahub.exceptions import ValidationError
from infrahub.log import get_logger

from .model.field_specifiers_map import NodeFieldSpecifierMap
from .model.path import (
BranchTrackingId,
EnrichedDiffRoot,
EnrichedDiffRootMetadata,
EnrichedDiffs,
EnrichedDiffsMetadata,
NameTrackingId,
NodeIdentifier,
TrackingId,
)

Expand Down Expand Up @@ -43,7 +45,7 @@ class EnrichedDiffRequest:
from_time: Timestamp
to_time: Timestamp
tracking_id: TrackingId
node_field_specifiers: dict[str, set[str]] = field(default_factory=dict)
node_field_specifiers: NodeFieldSpecifierMap = field(default_factory=NodeFieldSpecifierMap)

def __repr__(self) -> str:
return (
Expand Down Expand Up @@ -141,15 +143,17 @@ async def update_branch_diff(self, base_branch: Branch, diff_branch: Branch) ->
self.lock_registry.get(name=incremental_lock_name, namespace=self.lock_namespace),
):
log.info(f"Acquired lock to run branch diff update for {base_branch.name} - {diff_branch.name}")
enriched_diffs = await self._update_diffs(
enriched_diffs, node_identifiers_to_drop = await self._update_diffs(
base_branch=base_branch,
diff_branch=diff_branch,
from_time=from_time,
to_time=to_time,
tracking_id=tracking_id,
force_branch_refresh=False,
)
await self.diff_repo.save(enriched_diffs=enriched_diffs)
await self.diff_repo.save(
enriched_diffs=enriched_diffs, node_identifiers_to_drop=list(node_identifiers_to_drop)
)
await self._update_core_data_checks(enriched_diff=enriched_diffs.diff_branch_diff)
log.info(f"Branch diff update complete for {base_branch.name} - {diff_branch.name}")
return enriched_diffs.diff_branch_diff
Expand All @@ -168,7 +172,7 @@ async def create_or_update_arbitrary_timeframe_diff(
)
async with self.lock_registry.get(name=general_lock_name, namespace=self.lock_namespace):
log.info(f"Acquired lock to run arbitrary diff update for {base_branch.name} - {diff_branch.name}")
enriched_diffs = await self._update_diffs(
enriched_diffs, node_identifiers_to_drop = await self._update_diffs(
base_branch=base_branch,
diff_branch=diff_branch,
from_time=from_time,
Expand All @@ -177,7 +181,9 @@ async def create_or_update_arbitrary_timeframe_diff(
force_branch_refresh=False,
)

await self.diff_repo.save(enriched_diffs=enriched_diffs)
await self.diff_repo.save(
enriched_diffs=enriched_diffs, node_identifiers_to_drop=list(node_identifiers_to_drop)
)
await self._update_core_data_checks(enriched_diff=enriched_diffs.diff_branch_diff)
log.info(f"Arbitrary diff update complete for {base_branch.name} - {diff_branch.name}")
return enriched_diffs.diff_branch_diff
Expand Down Expand Up @@ -205,7 +211,7 @@ async def recalculate(
from_time = current_branch_diff.from_time
branched_from_time = Timestamp(diff_branch.get_branched_from())
from_time = max(from_time, branched_from_time)
enriched_diffs = await self._update_diffs(
enriched_diffs, _ = await self._update_diffs(
base_branch=base_branch,
diff_branch=diff_branch,
from_time=branched_from_time,
Expand Down Expand Up @@ -282,7 +288,7 @@ async def _update_diffs(
to_time: Timestamp,
tracking_id: TrackingId,
force_branch_refresh: Literal[True] = ...,
) -> EnrichedDiffs: ...
) -> tuple[EnrichedDiffs, set[NodeIdentifier]]: ...

@overload
async def _update_diffs(
Expand All @@ -293,7 +299,7 @@ async def _update_diffs(
to_time: Timestamp,
tracking_id: TrackingId,
force_branch_refresh: Literal[False] = ...,
) -> EnrichedDiffs | EnrichedDiffsMetadata: ...
) -> tuple[EnrichedDiffs | EnrichedDiffsMetadata, set[NodeIdentifier]]: ...

async def _update_diffs(
self,
Expand All @@ -303,7 +309,7 @@ async def _update_diffs(
to_time: Timestamp,
tracking_id: TrackingId,
force_branch_refresh: bool = False,
) -> EnrichedDiffs | EnrichedDiffsMetadata:
) -> tuple[EnrichedDiffs | EnrichedDiffsMetadata, set[NodeIdentifier]]:
# start with empty diffs b/c we only care about their metadata for now, hydrate them with data as needed
diff_pairs_metadata = await self.diff_repo.get_diff_pairs_metadata(
base_branch_names=[base_branch.name],
Expand All @@ -312,7 +318,7 @@ async def _update_diffs(
to_time=to_time,
tracking_id=tracking_id,
)
aggregated_enriched_diffs = await self._aggregate_enriched_diffs(
aggregated_enriched_diffs, node_identifiers_to_drop = await self._aggregate_enriched_diffs(
diff_request=EnrichedDiffRequest(
base_branch=base_branch,
diff_branch=diff_branch,
Expand Down Expand Up @@ -343,7 +349,7 @@ async def _update_diffs(
# this is an EnrichedDiffsMetadata, so there are no nodes to enrich
if not isinstance(aggregated_enriched_diffs, EnrichedDiffs):
aggregated_enriched_diffs.update_metadata(from_time=from_time, to_time=to_time, tracking_id=tracking_id)
return aggregated_enriched_diffs
return aggregated_enriched_diffs, set()

await self.conflicts_enricher.add_conflicts_to_branch_diff(
base_diff_root=aggregated_enriched_diffs.base_branch_diff,
Expand All @@ -353,27 +359,27 @@ async def _update_diffs(
enriched_diff_root=aggregated_enriched_diffs.diff_branch_diff, conflicts_only=True
)

return aggregated_enriched_diffs
return aggregated_enriched_diffs, node_identifiers_to_drop

@overload
async def _aggregate_enriched_diffs(
self,
diff_request: EnrichedDiffRequest,
partial_enriched_diffs: list[EnrichedDiffsMetadata],
) -> EnrichedDiffs | EnrichedDiffsMetadata: ...
) -> tuple[EnrichedDiffs | EnrichedDiffsMetadata, set[NodeIdentifier]]: ...

@overload
async def _aggregate_enriched_diffs(
self,
diff_request: EnrichedDiffRequest,
partial_enriched_diffs: None,
) -> EnrichedDiffs: ...
) -> tuple[EnrichedDiffs, set[NodeIdentifier]]: ...

async def _aggregate_enriched_diffs(
self,
diff_request: EnrichedDiffRequest,
partial_enriched_diffs: list[EnrichedDiffsMetadata] | None,
) -> EnrichedDiffs | EnrichedDiffsMetadata:
) -> tuple[EnrichedDiffs | EnrichedDiffsMetadata, set[NodeIdentifier]]:
"""
If return is an EnrichedDiffsMetadata, it acts as a pointer to a diff in the database that has all the
necessary data for this diff_request. Might have a different time range and/or tracking_id
Expand All @@ -385,6 +391,7 @@ async def _aggregate_enriched_diffs(
diff_request=diff_request, is_incremental_diff=False
)

node_identifiers_to_drop: set[NodeIdentifier] = set()
if partial_enriched_diffs is not None and not aggregated_enriched_diffs:
ordered_diffs = self._get_ordered_diff_pairs(diff_pairs=partial_enriched_diffs, allow_overlap=False)
ordered_diff_reprs = [repr(d) for d in ordered_diffs]
Expand Down Expand Up @@ -430,31 +437,31 @@ async def _aggregate_enriched_diffs(
)
current_time = end_time

aggregated_enriched_diffs = await self._concatenate_diffs_and_requests(
aggregated_enriched_diffs, node_identifiers_to_drop = await self._concatenate_diffs_and_requests(
diff_or_request_list=incremental_diffs_and_requests, full_diff_request=diff_request
)

# no changes during this time period, so generate an EnrichedDiffs with no nodes
if not aggregated_enriched_diffs:
return self._build_enriched_diffs_with_no_nodes(diff_request=diff_request)
return self._build_enriched_diffs_with_no_nodes(diff_request=diff_request), node_identifiers_to_drop

# metadata-only diff, means that a diff exists in the database that covers at least
# part of this time period, but it might need to have its start or end time extended
# to cover time ranges with no changes
if not isinstance(aggregated_enriched_diffs, EnrichedDiffs):
return aggregated_enriched_diffs
return aggregated_enriched_diffs, node_identifiers_to_drop

# a new diff (with nodes) covering the time period
aggregated_enriched_diffs.update_metadata(
from_time=diff_request.from_time, to_time=diff_request.to_time, tracking_id=diff_request.tracking_id
)
return aggregated_enriched_diffs
return aggregated_enriched_diffs, node_identifiers_to_drop

async def _concatenate_diffs_and_requests(
self,
diff_or_request_list: Sequence[EnrichedDiffsMetadata | EnrichedDiffRequest | None],
full_diff_request: EnrichedDiffRequest,
) -> EnrichedDiffs | EnrichedDiffsMetadata | None:
) -> tuple[EnrichedDiffs | EnrichedDiffsMetadata | None, set[NodeIdentifier]]:
"""
Returns None if diff_or_request_list is empty or all Nones
meaning there are no changes for the diff during this time period
Expand All @@ -464,7 +471,7 @@ async def _concatenate_diffs_and_requests(
meaning multiple diffs (some that may have been freshly calculated) were combined
"""
previous_diff_pair: EnrichedDiffs | EnrichedDiffsMetadata | None = None
updated_node_uuids: set[str] = set()
updated_node_identifiers: set[NodeIdentifier] = set()
for diff_or_request in diff_or_request_list:
if isinstance(diff_or_request, EnrichedDiffRequest):
if previous_diff_pair:
Expand All @@ -478,8 +485,8 @@ async def _concatenate_diffs_and_requests(
calculated_diff = await self._calculate_enriched_diff(
diff_request=diff_or_request, is_incremental_diff=is_incremental_diff
)
updated_node_uuids |= calculated_diff.base_node_uuids
updated_node_uuids |= calculated_diff.branch_node_uuids
updated_node_identifiers |= calculated_diff.base_node_identifiers
updated_node_identifiers |= calculated_diff.branch_node_identifiers
single_enriched_diffs: EnrichedDiffs | EnrichedDiffsMetadata = calculated_diff

elif isinstance(diff_or_request, EnrichedDiffsMetadata):
Expand All @@ -495,17 +502,22 @@ async def _concatenate_diffs_and_requests(
previous_diff_pair = await self._combine_diffs(
earlier=previous_diff_pair,
later=single_enriched_diffs,
node_uuids=updated_node_uuids,
node_identifiers=updated_node_identifiers,
)
log.info("Diffs combined.")

return previous_diff_pair
node_identifiers_to_drop: set[NodeIdentifier] = set()
if isinstance(previous_diff_pair, EnrichedDiffs):
# nodes that were updated and that no longer exist on this diff have been removed
node_identifiers_to_drop = updated_node_identifiers - previous_diff_pair.branch_node_identifiers

return previous_diff_pair, node_identifiers_to_drop

async def _combine_diffs(
self,
earlier: EnrichedDiffs | EnrichedDiffsMetadata,
later: EnrichedDiffs | EnrichedDiffsMetadata,
node_uuids: set[str],
node_identifiers: set[NodeIdentifier],
) -> EnrichedDiffs | EnrichedDiffsMetadata:
log.info(f"Earlier diff to combine: {earlier!r}")
log.info(f"Later diff to combine: {later!r}")
Expand All @@ -522,11 +534,15 @@ async def _combine_diffs(
# hydrate the diffs to combine, if necessary
if not isinstance(earlier, EnrichedDiffs):
log.info("Hydrating earlier diff...")
earlier = await self.diff_repo.hydrate_diff_pair(enriched_diffs_metadata=earlier, node_uuids=node_uuids)
earlier = await self.diff_repo.hydrate_diff_pair(
enriched_diffs_metadata=earlier, node_identifiers=node_identifiers
)
log.info("Earlier diff hydrated.")
if not isinstance(later, EnrichedDiffs):
log.info("Hydrating later diff...")
later = await self.diff_repo.hydrate_diff_pair(enriched_diffs_metadata=later, node_uuids=node_uuids)
later = await self.diff_repo.hydrate_diff_pair(
enriched_diffs_metadata=later, node_identifiers=node_identifiers
)
log.info("Later diff hydrated.")

return await self.diff_combiner.combine(earlier_diffs=earlier, later_diffs=later)
Expand Down
5 changes: 3 additions & 2 deletions backend/infrahub/core/diff/data_check_synchronizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from enum import Enum

from infrahub.core.constants import BranchConflictKeep, InfrahubKind
from infrahub.core.diff.query.filters import EnrichedDiffQueryFilters
from infrahub.core.integrity.object_conflict.conflict_recorder import ObjectConflictValidatorRecorder
from infrahub.core.manager import NodeManager
from infrahub.core.node import Node
Expand Down Expand Up @@ -74,7 +75,7 @@ async def synchronize(self, enriched_diff: EnrichedDiffRoot | EnrichedDiffRootMe
retrieved_diff_conflicts_only = await self.diff_repository.get_one(
diff_branch_name=enriched_diff.diff_branch_name,
diff_id=enriched_diff.uuid,
filters={"only_conflicted": True},
filters=EnrichedDiffQueryFilters(only_conflicted=True),
)
enriched_diff_all_conflicts = retrieved_diff_conflicts_only
# if `enriched_diff` is an EnrichedDiffRootsMetadata, then there have been no changes to the diff and
Expand Down Expand Up @@ -116,7 +117,7 @@ def _get_keep_branch_for_enriched_conflict(
def _update_diff_conflicts(self, updated_diff: EnrichedDiffRoot, retrieved_diff: EnrichedDiffRoot) -> None:
for updated_node in updated_diff.nodes:
try:
retrieved_node = retrieved_diff.get_node(node_uuid=updated_node.uuid)
retrieved_node = retrieved_diff.get_node(node_identifier=updated_node.identifier)
except ValueError:
retrieved_node = None
if not retrieved_node:
Expand Down
Loading
Loading