Hi,
I'm having trouble understanding the label_g[:,1] used in computeBatchLoss():
Lines 225 to 238 in d6c0210:

```python
def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g):
    input_t, label_t, _series_list, _center_list = batch_tup

    input_g = input_t.to(self.device, non_blocking=True)
    label_g = label_t.to(self.device, non_blocking=True)

    logits_g, probability_g = self.model(input_g)

    loss_func = nn.CrossEntropyLoss(reduction='none')
    loss_g = loss_func(
        logits_g,
        label_g[:,1],
    )
    start_ndx = batch_ndx * batch_size
```
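As far as I can tell from the PyTorch docs, `nn.CrossEntropyLoss(reduction='none')` expects logits of shape `[N, C]` plus a target of `N` class indices, and returns one loss value per sample. Here is a small standalone check (the tensor values are my own toy example, not from the book):

```python
import torch
import torch.nn as nn

# Toy batch of 4 samples and 2 classes, mimicking the shapes discussed below.
logits_g = torch.randn(4, 2)
# One-hot labels in the same format as pos_t: column 1 is 1 exactly for the positive class.
label_g = torch.tensor([[1, 0], [0, 1], [0, 1], [1, 0]], dtype=torch.long)

loss_func = nn.CrossEntropyLoss(reduction='none')
# label_g[:, 1] picks the second column, which for this one-hot encoding
# is the same as the integer class index (0 = negative, 1 = positive).
loss_g = loss_func(logits_g, label_g[:, 1])
print(loss_g.shape)  # torch.Size([4]) -- one loss value per sample
```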
Assume the batch size is 32; then `logits_g` will have shape `[32, 2]`. `label_g` has the same shape `[32, 2]`, and if I'm not mistaken it is the one-hot vector defined in
Lines 203 to 210 in d6c0210
```python
pos_t = torch.tensor([
        not candidateInfo_tup.isNodule_bool,
        candidateInfo_tup.isNodule_bool
    ],
    dtype=torch.long,
)

return candidate_t, pos_t, candidateInfo_tup.series_uid, torch.tensor(center_irc)
```
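Printing `pos_t` for both values of `isNodule_bool` (a quick check I ran locally, not part of the repo) shows why the second column acts like a class index:

```python
import torch

for isNodule_bool in (False, True):
    pos_t = torch.tensor(
        [not isNodule_bool, isNodule_bool],
        dtype=torch.long,
    )
    # pos_t is [1, 0] for non-nodules and [0, 1] for nodules,
    # so pos_t[1] is already the integer class index.
    print(isNodule_bool, pos_t.tolist(), int(pos_t[1]))
```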
My question is: in the CrossEntropyLoss call, should we pass `label_g` instead of `label_g[:,1]` (which takes the 2nd column of each item)? Something like:
```python
loss_g = loss_func(
    logits_g,
    label_g,  # the one-hot vector instead of label_g[:,1]
)
```
or
```python
loss_g = loss_func(
    logits_g,
    torch.argmax(label_g, dim=1),  # if we want to use the index
)
```
Thanks