Skip to content

KAFKA-19478 [2/N]: Remove task pairs #20127

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -91,8 +90,6 @@ private void initialize(final GroupSpec groupSpec, final TopologyDescriber topol
localState.totalCapacity = groupSpec.members().size();
localState.tasksPerMember = computeTasksPerMember(localState.allTasks, localState.totalCapacity);

localState.taskPairs = new TaskPairs(localState.allTasks * (localState.allTasks - 1) / 2);

localState.processIdToState = new HashMap<>();
localState.activeTaskToPrevMember = new HashMap<>();
localState.standbyTaskToPrevMember = new HashMap<>();
Expand Down Expand Up @@ -175,7 +172,7 @@ private void assignActive(final Set<TaskId> activeTasks) {
final Member prevMember = localState.activeTaskToPrevMember.get(task);
if (prevMember != null && hasUnfulfilledQuota(prevMember)) {
localState.processIdToState.get(prevMember.processId).addTask(prevMember.memberId, task, true);
updateHelpers(prevMember, task, true);
updateHelpers(prevMember, true);
it.remove();
}
}
Expand All @@ -187,7 +184,7 @@ private void assignActive(final Set<TaskId> activeTasks) {
final Member prevMember = findMemberWithLeastLoad(prevMembers, task, true);
if (prevMember != null && hasUnfulfilledQuota(prevMember)) {
localState.processIdToState.get(prevMember.processId).addTask(prevMember.memberId, task, true);
updateHelpers(prevMember, task, true);
updateHelpers(prevMember, true);
it.remove();
}
}
Expand All @@ -204,7 +201,7 @@ private void assignActive(final Set<TaskId> activeTasks) {
}
localState.processIdToState.get(member.processId).addTask(member.memberId, task, true);
it.remove();
updateHelpers(member, task, true);
updateHelpers(member, true);

}
}
Expand All @@ -221,20 +218,10 @@ private Member findMemberWithLeastLoad(final Set<Member> members, TaskId taskId,
if (members == null || members.isEmpty()) {
return null;
}
Set<Member> rightPairs = members.stream()
.filter(member -> localState.taskPairs.hasNewPair(taskId, localState.processIdToState.get(member.processId).assignedTasks()))
.collect(Collectors.toSet());
if (rightPairs.isEmpty()) {
rightPairs = members;
}
Optional<ProcessState> processWithLeastLoad = rightPairs.stream()
Optional<ProcessState> processWithLeastLoad = members.stream()
.map(member -> localState.processIdToState.get(member.processId))
.min(Comparator.comparingDouble(ProcessState::load));

// processWithLeastLoad must be present at this point, but we do a double check
if (processWithLeastLoad.isEmpty()) {
return null;
}
// if the same exact former member is needed
if (returnSameMember) {
return localState.standbyTaskToPrevMember.get(taskId).stream()
Expand Down Expand Up @@ -275,8 +262,7 @@ private void assignStandby(final Set<TaskId> standbyTasks, final int numStandbyR

// prev active task
Member prevMember = localState.activeTaskToPrevMember.get(task);
if (prevMember != null && availableProcesses.contains(prevMember.processId) && isLoadBalanced(prevMember.processId)
&& localState.taskPairs.hasNewPair(task, localState.processIdToState.get(prevMember.processId).assignedTasks())) {
if (prevMember != null && availableProcesses.contains(prevMember.processId) && isLoadBalanced(prevMember.processId)) {
standby = prevMember;
}

Expand Down Expand Up @@ -304,7 +290,7 @@ private void assignStandby(final Set<TaskId> standbyTasks, final int numStandbyR
}
}
localState.processIdToState.get(standby.processId).addTask(standby.memberId, task, false);
updateHelpers(standby, task, false);
updateHelpers(standby, false);
}

}
Expand All @@ -323,10 +309,7 @@ private boolean isLoadBalanced(final String processId) {
return process.hasCapacity() || isLeastLoadedProcess;
}

private void updateHelpers(final Member member, final TaskId taskId, final boolean isActive) {
// add all pair combinations: update taskPairs
localState.taskPairs.addPairs(taskId, localState.processIdToState.get(member.processId).assignedTasks());

private void updateHelpers(final Member member, final boolean isActive) {
if (isActive) {
// update task per process
maybeUpdateTasksPerMember(localState.processIdToState.get(member.processId).activeTaskCount());
Expand All @@ -344,75 +327,6 @@ private static int computeTasksPerMember(final int numberOfTasks, final int numb
return tasksPerMember;
}

private static class TaskPairs {
private final Set<Pair> pairs;
private final int maxPairs;

TaskPairs(final int maxPairs) {
this.maxPairs = maxPairs;
this.pairs = new HashSet<>(maxPairs);
}

boolean hasNewPair(final TaskId task1,
final Set<TaskId> taskIds) {
if (pairs.size() == maxPairs) {
return false;
}
if (taskIds.size() == 0) {
return true;
}
for (final TaskId taskId : taskIds) {
if (!pairs.contains(pair(task1, taskId))) {
return true;
}
}
return false;
}

void addPairs(final TaskId taskId, final Set<TaskId> assigned) {
for (final TaskId id : assigned) {
if (!id.equals(taskId))
pairs.add(pair(id, taskId));
}
}

Pair pair(final TaskId task1, final TaskId task2) {
if (task1.compareTo(task2) < 0) {
return new Pair(task1, task2);
}
return new Pair(task2, task1);
}


private static class Pair {
private final TaskId task1;
private final TaskId task2;

Pair(final TaskId task1, final TaskId task2) {
this.task1 = task1;
this.task2 = task2;
}

@Override
public boolean equals(final Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
final Pair pair = (Pair) o;
return Objects.equals(task1, pair.task1) &&
Objects.equals(task2, pair.task2);
}

@Override
public int hashCode() {
return Objects.hash(task1, task2);
}
}
}

static class Member {
private final String processId;
private final String memberId;
Expand All @@ -425,7 +339,6 @@ public Member(final String processId, final String memberId) {

private static class LocalState {
// helper data structures:
private TaskPairs taskPairs;
Map<TaskId, Member> activeTaskToPrevMember;
Map<TaskId, Set<Member>> standbyTaskToPrevMember;
Map<String, ProcessState> processIdToState;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,10 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static java.util.Arrays.asList;
import static org.apache.kafka.common.utils.Utils.mkEntry;
import static org.apache.kafka.common.utils.Utils.mkMap;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
Expand Down Expand Up @@ -581,86 +579,6 @@ public void shouldAssignMoreTasksToClientWithMoreCapacity() {
assertEquals(4, getAllActiveTaskCount(result, "member1"));
}

@Test
public void shouldNotHaveSameAssignmentOnAnyTwoHosts() {
final AssignmentMemberSpec memberSpec1 = createAssignmentMemberSpec("process1");
final AssignmentMemberSpec memberSpec2 = createAssignmentMemberSpec("process2");
final AssignmentMemberSpec memberSpec3 = createAssignmentMemberSpec("process3");
final AssignmentMemberSpec memberSpec4 = createAssignmentMemberSpec("process4");
final List<String> allMemberIds = asList("member1", "member2", "member3", "member4");
Map<String, AssignmentMemberSpec> members = mkMap(
mkEntry("member1", memberSpec1), mkEntry("member2", memberSpec2), mkEntry("member3", memberSpec3), mkEntry("member4", memberSpec4));

final GroupAssignment result = assignor.assign(
new GroupSpecImpl(members,
mkMap(mkEntry(NUM_STANDBY_REPLICAS_CONFIG, "1"))),
new TopologyDescriberImpl(4, true, List.of("test-subtopology"))
);

for (final String memberId : allMemberIds) {
final List<Integer> taskIds = getAllTaskIds(result, memberId);
for (final String otherMemberId : allMemberIds) {
if (!memberId.equals(otherMemberId)) {
assertNotEquals(taskIds, getAllTaskIds(result, otherMemberId));
}
}
}
}

@Test
public void shouldNotHaveSameAssignmentOnAnyTwoHostsWhenThereArePreviousActiveTasks() {
final AssignmentMemberSpec memberSpec1 = createAssignmentMemberSpec("process1", mkMap(mkEntry("test-subtopology", Sets.newSet(1, 2))), Map.of());
final AssignmentMemberSpec memberSpec2 = createAssignmentMemberSpec("process2", mkMap(mkEntry("test-subtopology", Sets.newSet(3))), Map.of());
final AssignmentMemberSpec memberSpec3 = createAssignmentMemberSpec("process3", mkMap(mkEntry("test-subtopology", Sets.newSet(0))), Map.of());
final AssignmentMemberSpec memberSpec4 = createAssignmentMemberSpec("process4");
final List<String> allMemberIds = asList("member1", "member2", "member3", "member4");
Map<String, AssignmentMemberSpec> members = mkMap(
mkEntry("member1", memberSpec1), mkEntry("member2", memberSpec2), mkEntry("member3", memberSpec3), mkEntry("member4", memberSpec4));

final GroupAssignment result = assignor.assign(
new GroupSpecImpl(members,
mkMap(mkEntry(NUM_STANDBY_REPLICAS_CONFIG, "1"))),
new TopologyDescriberImpl(4, true, List.of("test-subtopology"))
);

for (final String memberId : allMemberIds) {
final List<Integer> taskIds = getAllTaskIds(result, memberId);
for (final String otherMemberId : allMemberIds) {
if (!memberId.equals(otherMemberId)) {
assertNotEquals(taskIds, getAllTaskIds(result, otherMemberId));
}
}
}
}

@Test
public void shouldNotHaveSameAssignmentOnAnyTwoHostsWhenThereArePreviousStandbyTasks() {
final AssignmentMemberSpec memberSpec1 = createAssignmentMemberSpec("process1",
mkMap(mkEntry("test-subtopology", Sets.newSet(1, 2))), mkMap(mkEntry("test-subtopology", Sets.newSet(3, 0))));
final AssignmentMemberSpec memberSpec2 = createAssignmentMemberSpec("process2",
mkMap(mkEntry("test-subtopology", Sets.newSet(3, 0))), mkMap(mkEntry("test-subtopology", Sets.newSet(1, 2))));
final AssignmentMemberSpec memberSpec3 = createAssignmentMemberSpec("process3");
final AssignmentMemberSpec memberSpec4 = createAssignmentMemberSpec("process4");
final List<String> allMemberIds = asList("member1", "member2", "member3", "member4");
Map<String, AssignmentMemberSpec> members = mkMap(
mkEntry("member1", memberSpec1), mkEntry("member2", memberSpec2), mkEntry("member3", memberSpec3), mkEntry("member4", memberSpec4));

final GroupAssignment result = assignor.assign(
new GroupSpecImpl(members,
mkMap(mkEntry(NUM_STANDBY_REPLICAS_CONFIG, "1"))),
new TopologyDescriberImpl(4, true, List.of("test-subtopology"))
);

for (final String memberId : allMemberIds) {
final List<Integer> taskIds = getAllTaskIds(result, memberId);
for (final String otherMemberId : allMemberIds) {
if (!memberId.equals(otherMemberId)) {
assertNotEquals(taskIds, getAllTaskIds(result, otherMemberId));
}
}
}
}

@Test
public void shouldReBalanceTasksAcrossAllClientsWhenCapacityAndTaskCountTheSame() {
final AssignmentMemberSpec memberSpec3 = createAssignmentMemberSpec("process3", mkMap(mkEntry("test-subtopology", Sets.newSet(0, 1, 2, 3))), Map.of());
Expand Down Expand Up @@ -1020,13 +938,6 @@ private Map<String, Set<Integer>> mergeAllActiveTasks(GroupAssignment result, St
return res;
}

private List<Integer> getAllTaskIds(GroupAssignment result, String... memberIds) {
List<Integer> res = new ArrayList<>();
res.addAll(getAllActiveTaskIds(result, memberIds));
res.addAll(getAllStandbyTaskIds(result, memberIds));
return res;
}

private Map<String, Set<Integer>> mergeAllStandbyTasks(GroupAssignment result, String... memberIds) {
Map<String, Set<Integer>> res = new HashMap<>();
for (String memberId : memberIds) {
Expand Down