Skip to content

Commit 82e7289

Browse files
committed
Add bulk assignment API
1 parent 6ef85df commit 82e7289

File tree

4 files changed

+96
-2
lines changed

4 files changed

+96
-2
lines changed

backend/examples/assignment/strategies.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import abc
22
import dataclasses
3+
import enum
34
import random
45
from typing import List
56

@@ -12,6 +13,23 @@ class Assignment:
1213
example: int
1314

1415

16+
class StrategyName(enum.Enum):
17+
weighted_sequential = enum.auto()
18+
weighted_random = enum.auto()
19+
sampling_without_replacement = enum.auto()
20+
21+
22+
def create_assignment_strategy(strategy_name: StrategyName, dataset_size: int, weights: List[int]) -> "BaseStrategy":
23+
if strategy_name == StrategyName.weighted_sequential:
24+
return WeightedSequentialStrategy(dataset_size, weights)
25+
elif strategy_name == StrategyName.weighted_random:
26+
return WeightedRandomStrategy(dataset_size, weights)
27+
elif strategy_name == StrategyName.sampling_without_replacement:
28+
return SamplingWithoutReplacementStrategy(dataset_size, weights)
29+
else:
30+
raise ValueError(f"Unknown strategy name: {strategy_name}")
31+
32+
1533
class BaseStrategy(abc.ABC):
1634
@abc.abstractmethod
1735
def assign(self) -> List[Assignment]:
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from typing import List
2+
3+
from pydantic import BaseModel, NonNegativeInt
4+
5+
6+
class Workload(BaseModel):
7+
weight: NonNegativeInt
8+
member_id: int
9+
10+
11+
class WorkloadAllocation(BaseModel):
12+
workloads: List[Workload]
13+
14+
@property
15+
def member_ids(self):
16+
return [w.member_id for w in self.workloads]
17+
18+
@property
19+
def weights(self):
20+
return [w.weight for w in self.workloads]

backend/examples/urls.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
from django.urls import path
22

3-
from .views.assignment import AssignmentDetail, AssignmentList, ResetAssignment
3+
from .views.assignment import (
4+
AssignmentDetail,
5+
AssignmentList,
6+
BulkAssignment,
7+
ResetAssignment,
8+
)
49
from .views.comment import CommentDetail, CommentList
510
from .views.example import ExampleDetail, ExampleList
611
from .views.example_state import ExampleStateList
@@ -9,6 +14,7 @@
914
path(route="assignments", view=AssignmentList.as_view(), name="assignment_list"),
1015
path(route="assignments/<uuid:assignment_id>", view=AssignmentDetail.as_view(), name="assignment_detail"),
1116
path(route="assignments/reset", view=ResetAssignment.as_view(), name="assignment_reset"),
17+
path(route="assignments/bulk_assign", view=BulkAssignment.as_view(), name="bulk_assignment"),
1218
path(route="examples", view=ExampleList.as_view(), name="example_list"),
1319
path(route="examples/<int:example_id>", view=ExampleDetail.as_view(), name="example_detail"),
1420
path(route="comments", view=CommentList.as_view(), name="comment_list"),

backend/examples/views/assignment.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
from django.shortcuts import get_object_or_404
22
from django_filters.rest_framework import DjangoFilterBackend
3+
from pydantic import ValidationError
34
from rest_framework import filters, generics, status
45
from rest_framework.permissions import IsAuthenticated
56
from rest_framework.views import APIView, Response
67

8+
from examples.assignment.strategies import StrategyName, create_assignment_strategy
9+
from examples.assignment.workload import WorkloadAllocation
710
from examples.models import Assignment
811
from examples.serializers import AssignmentSerializer
9-
from projects.models import Project
12+
from projects.models import Member, Project
1013
from projects.permissions import IsProjectAdmin, IsProjectStaffAndReadOnly
1114

1215

@@ -46,3 +49,50 @@ def project(self):
4649
def delete(self, *args, **kwargs):
4750
Assignment.objects.filter(project=self.project).delete()
4851
return Response(status=status.HTTP_204_NO_CONTENT)
52+
53+
54+
class BulkAssignment(APIView):
55+
serializer_class = AssignmentSerializer
56+
permission_classes = [IsAuthenticated & IsProjectAdmin]
57+
58+
def post(self, *args, **kwargs):
59+
try:
60+
strategy_name = StrategyName[self.request.data["strategy_name"]]
61+
except KeyError:
62+
return Response(
63+
{"detail": "Invalid strategy name"},
64+
status=status.HTTP_400_BAD_REQUEST,
65+
)
66+
67+
try:
68+
workload_allocation = WorkloadAllocation(workloads=self.request.data["workloads"])
69+
except ValidationError as e:
70+
return Response(
71+
{"detail": e.errors()},
72+
status=status.HTTP_400_BAD_REQUEST,
73+
)
74+
75+
project = get_object_or_404(Project, pk=self.kwargs["project_id"])
76+
members = Member.objects.filter(project=project, pk__in=workload_allocation.member_ids)
77+
if len(members) != len(workload_allocation.member_ids):
78+
return Response(
79+
{"detail": "Invalid member ids"},
80+
status=status.HTTP_400_BAD_REQUEST,
81+
)
82+
# Sort members by workload_allocation.member_ids
83+
members = sorted(members, key=lambda m: workload_allocation.member_ids.index(m.id))
84+
85+
dataset_size = project.examples.count() # Todo: unassigned examples
86+
strategy = create_assignment_strategy(strategy_name, dataset_size, workload_allocation.weights)
87+
assignments = strategy.assign()
88+
example_ids = project.examples.values_list("pk", flat=True)
89+
assignments = [
90+
Assignment(
91+
project=project,
92+
example=example_ids[assignment.example],
93+
assignee=members[assignment.user].user,
94+
)
95+
for assignment in assignments
96+
]
97+
Assignment.objects.bulk_create(assignments)
98+
return Response(status=status.HTTP_201_CREATED)

0 commit comments

Comments
 (0)