Commit 9aacb82

fixing image_encoder to work with cuda_graphs
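Summary: get_rel_pos combined tensors on different devices: the coordinate tensors built with torch.arange were created on the default device rather than on rel_pos.device. This mix was preventing CUDA graphs from optimizing the image encoder correctly.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]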
1 parent 6fdee8f commit 9aacb82

1 file changed, 2 additions and 2 deletions

segment_anything/modeling/image_encoder.py

Lines changed: 2 additions & 2 deletions
@@ -315,8 +315,8 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor
         rel_pos_resized = rel_pos
 
     # Scale the coords with short length if shapes for q and k are different.
-    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
-    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+    q_coords = (torch.arange(q_size).to(rel_pos.device)[:, None] * max(k_size / q_size, 1.0))
+    k_coords = (torch.arange(k_size).to(rel_pos.device)[None, :] * max(q_size / k_size, 1.0))
     relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
 
     return rel_pos_resized[relative_coords.long()]
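
For illustration, the device handling in this change can be sketched in isolation. The helper name `relative_coords_on_device` is hypothetical and not part of the repository. It also passes `device=` directly to `torch.arange`, a variant of the `.to(rel_pos.device)` call used in the commit that avoids creating the tensor on the default device first. The example runs on CPU, so it only shows that the index tensor follows `rel_pos` to whatever device it lives on; it makes no claim about CUDA graph behavior.

```python
import torch


def relative_coords_on_device(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: build the relative-position index on rel_pos's device."""
    q_scale = max(k_size / q_size, 1.0)
    k_scale = max(q_size / k_size, 1.0)
    # Create the coordinates directly on the target device (assumed variant of
    # the commit's torch.arange(...).to(rel_pos.device)).
    q_coords = torch.arange(q_size, device=rel_pos.device)[:, None] * q_scale
    k_coords = torch.arange(k_size, device=rel_pos.device)[None, :] * k_scale
    relative_coords = (q_coords - k_coords) + (k_size - 1) * k_scale
    return relative_coords.long()


# Toy table of shape (2 * max(q_size, k_size) - 1, head_dim); lives on CPU here.
rel_pos = torch.zeros(2 * 4 - 1, 8)
idx = relative_coords_on_device(4, 4, rel_pos)
looked_up = rel_pos[idx]  # index and table share a device, so no cross-device lookup
```

With `q_size == k_size == 4`, the index has shape `(4, 4)` and the lookup returns a tensor of shape `(4, 4, 8)`.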
