Commit badb76d

fix sync on ascend (#19)
1 parent ccc62cb commit badb76d

4 files changed, +11 -7 lines changed


lmdeploy/pytorch/engine/devices/ascend.py

Lines changed: 6 additions & 1 deletion
@@ -17,7 +17,8 @@ def update_step_context(cls, step_context):
             single_attention_mask = torch.logical_not(
                 torch.tril(
                     torch.ones(step_context.q_seq_length[i],
-                               step_context.block_offsets.shape[1] * block_size,
+                               step_context.block_offsets.shape[1] *
+                               block_size,
                                dtype=torch.bool).cuda(),
                     diagonal=step_context.kv_seq_length[i] -
                     step_context.q_seq_length[i],
@@ -38,6 +39,10 @@ def update_step_context(cls, step_context):
             kv_start_indices, device=step_context.block_offsets.device)
         setattr(step_context, 'kv_start_indices', kv_start_indices)
         setattr(step_context, 'attention_mask', attention_mask)
+        setattr(step_context, 'q_start_loc', step_context.q_start_loc.cpu())
+        setattr(step_context, 'q_seq_length', step_context.q_seq_length.cpu())
+        setattr(step_context, 'kv_seq_length',
+                step_context.kv_seq_length.cpu())
         is_unpaged_prefill = (not step_context.is_decoding) and all(
             (step_context.q_seq_length == step_context.kv_seq_length).tolist())
         setattr(step_context, 'is_unpaged_prefill', is_unpaged_prefill)
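Note: the substantive part of this hunk is the three new setattr calls. q_start_loc, q_seq_length and kv_seq_length are copied to the CPU once while the step context is prepared, so later host-side reads of this metadata reuse the cached copies instead of each forcing a device-to-host synchronization (presumably the "sync" in the commit title). A minimal sketch of the pattern, using a toy stand-in for the step context rather than lmdeploy's real class:

# Toy sketch of the caching pattern above; the class and its field values
# are hypothetical stand-ins, not lmdeploy's StepContext.
import torch


class StepContext:

    def __init__(self, device='cpu'):
        self.q_start_loc = torch.tensor([0, 4, 9], device=device)
        self.q_seq_length = torch.tensor([4, 5, 3], device=device)
        self.kv_seq_length = torch.tensor([4, 5, 3], device=device)
        self.is_decoding = False


def update_step_context(ctx):
    # Copy the metadata to host memory once per step; later reads such as
    # .tolist() or .item() then operate on CPU tensors instead of each
    # triggering a blocking device-to-host transfer.
    ctx.q_start_loc = ctx.q_start_loc.cpu()
    ctx.q_seq_length = ctx.q_seq_length.cpu()
    ctx.kv_seq_length = ctx.kv_seq_length.cpu()
    return ctx


ctx = update_step_context(StepContext())
is_unpaged_prefill = (not ctx.is_decoding) and all(
    (ctx.q_seq_length == ctx.kv_seq_length).tolist())
print(is_unpaged_prefill)  # True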

lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ def apply_rotary_pos_emb(
         cos = cos[position_ids_1d].view(1, bs, 1, -1)
         sin = sin[position_ids_1d].view(1, bs, 1, -1)
     else:
-        raise RuntimeError("Cannot handle cos/sin shape dims!")
+        raise RuntimeError('Cannot handle cos/sin shape dims!')
 
     if context:
         setattr(context, 'cos', cos)
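Note: this hunk only swaps double quotes for single quotes; behaviour is unchanged. For context, the surrounding lines gather one cos/sin row per token from the precomputed tables using the flattened position ids and reshape them to broadcast over heads. A rough illustration with made-up shapes (not the lmdeploy kernel itself):

# Rough illustration of the context lines above; all shapes are made up.
import torch

bs, dim, max_pos = 6, 8, 32            # flattened tokens, rotary dim, table size
cos_table = torch.randn(max_pos, dim)  # stands in for the cached cos table
position_ids_1d = torch.tensor([0, 1, 2, 0, 1, 2])

# Gather one cos row per token, then reshape so it broadcasts over a
# (1, tokens, heads, dim) query/key layout.
cos = cos_table[position_ids_1d].view(1, bs, 1, -1)
print(cos.shape)  # torch.Size([1, 6, 1, 8])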

lmdeploy/pytorch/kernels/ascend/rms_norm.py

Lines changed: 1 addition & 1 deletion
@@ -12,4 +12,4 @@ def rms_norm(hidden_states: Tensor,
         out = rms_norm_out
     else:
         out.copy_(rms_norm_out)
-    return rms_norm_out
+    return out
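Note: the rms_norm change fixes the return value, not the math. When the caller supplies an out tensor, the kernel result is copied into it, and that same tensor should be returned, so code holding a reference to out and code using the return value see the same storage. A minimal sketch of the corrected contract, with plain PyTorch standing in for the Ascend kernel and the signature beyond hidden_states/out assumed:

# Sketch of the corrected return contract; the reference math and the exact
# parameters other than `hidden_states` and `out` are assumptions.
from typing import Optional

import torch
from torch import Tensor


def rms_norm(hidden_states: Tensor,
             weight: Tensor,
             eps: float = 1e-6,
             out: Optional[Tensor] = None) -> Tensor:
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    rms_norm_out = hidden_states * torch.rsqrt(variance + eps) * weight
    if out is None:
        out = rms_norm_out
    else:
        out.copy_(rms_norm_out)
    return out  # was `return rms_norm_out`, a different tensor than the caller's buffer


x = torch.randn(2, 4)
buf = torch.empty_like(x)
y = rms_norm(x, torch.ones(4), out=buf)
assert y is buf  # the caller-provided buffer is what comes back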

lmdeploy/pytorch/models/qwen2_moe.py

Lines changed: 3 additions & 4 deletions
@@ -106,10 +106,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
         routing_weights = routing_weights.to(hidden_states.dtype)
 
-        out_states = torch.zeros(
-            (batch_size * sequence_length, hidden_dim),
-            dtype=hidden_states.dtype,
-            device=hidden_states.device)
+        out_states = torch.zeros((batch_size * sequence_length, hidden_dim),
+                                 dtype=hidden_states.dtype,
+                                 device=hidden_states.device)
 
         expert_mask = torch.nn.functional.one_hot(
             selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
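Note: this hunk is a formatting-only change; out_states is still allocated as a zeroed (tokens, hidden_dim) accumulator. For context, a rough sketch of how such an accumulator and the one-hot expert_mask from the last context line are typically used to route tokens to experts and accumulate the weighted outputs (toy shapes and a toy expert, not lmdeploy's Qwen2-MoE module):

# Context-only sketch with hypothetical shapes and a toy "expert".
import torch

tokens, hidden_dim, num_experts, top_k = 6, 8, 4, 2
hidden_states = torch.randn(tokens, hidden_dim)
router_logits = torch.randn(tokens, num_experts)

# Pick the top-k experts per token and renormalize their routing weights.
routing_weights, selected_experts = torch.topk(
    torch.softmax(router_logits, dim=-1), top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

out_states = torch.zeros((tokens, hidden_dim), dtype=hidden_states.dtype)
expert_mask = torch.nn.functional.one_hot(
    selected_experts, num_classes=num_experts).permute(2, 1, 0)

for expert_id in range(num_experts):
    # idx: which top-k slot picked this expert; top_x: which tokens did.
    idx, top_x = torch.where(expert_mask[expert_id])
    if top_x.numel() == 0:
        continue
    expert_out = hidden_states[top_x] * 2.0  # toy expert in place of an MLP
    # Accumulate each expert's weighted output back into the per-token rows.
    out_states.index_add_(
        0, top_x, expert_out * routing_weights[top_x, idx, None])

print(out_states.shape)  # torch.Size([6, 8])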
