Closed
Description
Hi,
I'm having trouble understanding the use of label_g[:,1] in computeBatchLoss():
Lines 225 to 238 in d6c0210
Assuming a batch size of 32, logits_g will have shape [32, 2], and label_g has the same shape [32, 2]. If I'm not mistaken, it holds the one-hot vectors defined in
Lines 203 to 210 in d6c0210
My question is: in the CrossEntropyLoss call, should we pass label_g instead of label_g[:,1] (which takes the second column of each item)? Something like:
loss_g = loss_func(
    logits_g,
    label_g,  # the one-hot vector instead of label_g[:,1]
)
or
loss_g = loss_func(
    logits_g,
    torch.argmax(label_g, dim=1),  # if we want to use the index
)
Thanks