Cream/AutoFormer/model/module/qkv_super.py
Lines 72 to 77 in 4a13c40
I think there is something wrong with the way weight sharing is done here. I believe this code should instead be:
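```python
N = weight.size(0) // 3  # rows per q/k/v block in the full (super) weight
# keep the first sample_out_dim // 3 rows of each of the three contiguous blocks
sample_weight = torch.cat([sample_weight[i*N:i*N+sample_out_dim//3, :] for i in range(3)], dim=0)
```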
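To make the difference concrete, here is a minimal sketch that applies two row slicings to a toy weight whose rows simply store their own index, so the result lists which rows of the full weight end up in the sampled weight. It assumes the current code at lines 72 to 77 slices rows with a step of 3 (`sample_weight[i:sample_out_dim:3, :]` for `i` in 0, 1, 2); that snippet is not quoted above, so treat it as an assumption. `interleaved_rows` and `chunked_rows` are hypothetical helper names.

```python
from typing import List

import torch


def interleaved_rows(total_rows: int, sample_out_dim: int) -> List[int]:
    # Assumed current slicing: rows i, i+3, i+6, ... below sample_out_dim, for i = 0, 1, 2.
    idx = torch.arange(total_rows).unsqueeze(1)  # each "row" stores its own index
    picked = torch.cat([idx[i:sample_out_dim:3, :] for i in range(3)], dim=0)
    return picked.squeeze(1).tolist()


def chunked_rows(total_rows: int, sample_out_dim: int) -> List[int]:
    # Slicing proposed in this issue: first sample_out_dim // 3 rows of each
    # contiguous block of N rows.
    idx = torch.arange(total_rows).unsqueeze(1)
    N = total_rows // 3
    picked = torch.cat([idx[i*N:i*N + sample_out_dim // 3, :] for i in range(3)], dim=0)
    return picked.squeeze(1).tolist()


# Toy super weight with 12 rows (3 x 4), sampled down to 6 output rows (3 x 2).
print(interleaved_rows(12, 6))  # [0, 3, 1, 4, 2, 5]
print(chunked_rows(12, 6))      # [0, 1, 4, 5, 8, 9]
```

Both return the same number of rows. They differ only in which rows of the full weight are selected.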
To make this more intuitive, I drew a schematic diagram showing how the weights of 4-head and 5-head self-attention (SA) are shared within the same Linear.weight.
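The 4-head versus 5-head sharing pattern can also be listed row by row. The following is a minimal pure-Python sketch under stated assumptions: `HEAD_DIM = 2` and `MAX_HEADS = 5` are hypothetical toy values; the "interleaved" scheme assumes the current code takes rows with a step of 3 (as in `weight[i:sample_out_dim:3, :]`); and the sampled weight is assumed to be consumed downstream as three contiguous chunks q, k, v. For each head count, it prints which rows of the full Linear.weight are used for q, k and v.

```python
from typing import Dict, List

HEAD_DIM = 2                      # hypothetical; tiny so the row lists stay readable
MAX_HEADS = 5                     # assumed largest head count in the search space
SUPER_DIM = MAX_HEADS * HEAD_DIM  # rows per q/k/v block in the full weight


def split_qkv(rows: List[int]) -> Dict[str, List[int]]:
    # Assumption: the sampled weight is read as three equal contiguous chunks q, k, v.
    d = len(rows) // 3
    return {"q": rows[:d], "k": rows[d:2 * d], "v": rows[2 * d:]}


def interleaved(num_heads: int) -> Dict[str, List[int]]:
    # Assumed current scheme: every third row, starting at offsets 0, 1, 2.
    out_dim = 3 * num_heads * HEAD_DIM
    rows = list(range(3 * SUPER_DIM))
    return split_qkv([r for i in range(3) for r in rows[i:out_dim:3]])


def chunked(num_heads: int) -> Dict[str, List[int]]:
    # Scheme proposed in this issue: a prefix of each contiguous block.
    out_dim = 3 * num_heads * HEAD_DIM
    rows = list(range(3 * SUPER_DIM))
    return split_qkv([r for i in range(3)
                      for r in rows[i * SUPER_DIM:i * SUPER_DIM + out_dim // 3]])


for name, scheme in [("interleaved", interleaved), ("chunked", chunked)]:
    four, five = scheme(4), scheme(5)
    for part in "qkv":
        print(name, part, "4 heads:", four[part], "5 heads:", five[part])
```

Under these assumptions, both schemes give the 4-head configuration a prefix of the 5-head q, k and v rows. They differ in which rows of the full Linear.weight are treated as q, k and v: every third row versus three contiguous blocks.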
Maybe I have misunderstood the implementation here; could you help check it?