Skip to content

Commit ad84c02

Browse files
committed
Fix weighted sequential strategy
1 parent 277ee7b commit ad84c02

File tree

1 file changed

+4
-10
lines changed

1 file changed

+4
-10
lines changed

backend/examples/assignment/strategies.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,10 @@ def __init__(self, dataset_size: int, weights: List[int]):
4545

4646
def assign(self) -> List[Assignment]:
4747
assignments = []
48-
proba = np.array(self.weights) / 100
49-
counts = np.round(proba * self.dataset_size).astype(int)
50-
reminder = self.dataset_size - sum(counts)
51-
for i in np.random.choice(range(len(self.weights)), size=reminder, p=proba):
52-
counts[i] += 1
53-
54-
start = 0
55-
for user, count in enumerate(counts):
56-
assignments.extend([Assignment(user=user, example=example) for example in range(start, start + count)])
57-
start += count
48+
ratio = list(np.round(np.cumsum(self.weights) / 100 * self.dataset_size).astype(int))
49+
ratio = [0] + ratio[:-1] + [self.dataset_size]
50+
for user, (start, end) in enumerate(zip(ratio, ratio[1:])): # Todo: use itertools.pairwise
51+
assignments.extend([Assignment(user=user, example=example) for example in range(start, end)])
5852
return assignments
5953

6054

0 commit comments

Comments
 (0)