
Commit fa540c3

rename diversity
1 parent 71cb014 commit fa540c3

File tree

4 files changed, +21 -21 lines changed


conf/tokenizer/swin/loss.yaml

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
-quant_weight: 1.
-entropy_weight: 0.2
+quant_weight: 1
+entropy_weight: 0
 rec_loss: l1
 rec_weight: 1
 perceptual_loss:
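
For orientation, here is a minimal, hedged sketch of how weights like quant_weight and entropy_weight are typically applied to the corresponding loss terms. The actual combination lives in pumit/tokenizer/loss.py and is not part of this diff, so the function below is an illustrative assumption, not the repository's formula.

```python
import torch

# Illustrative only: a common way weights like those in this config scale the
# generator-side loss terms of a VQ tokenizer. Every term here is an assumption;
# the real combination in pumit/tokenizer/loss.py is not shown in this diff.
def total_gen_loss(
    rec_loss: torch.Tensor,
    quant_loss: torch.Tensor,
    entropy: torch.Tensor,
    rec_weight: float = 1.0,      # rec_weight: 1
    quant_weight: float = 1.0,    # quant_weight: 1 (was 1.)
    entropy_weight: float = 0.0,  # entropy_weight: 0 (was 0.2) -> term disabled
) -> torch.Tensor:
    return rec_weight * rec_loss + quant_weight * quant_loss + entropy_weight * entropy

# With entropy_weight set to 0, the entropy term no longer contributes to the total.
print(total_gen_loss(torch.tensor(0.5), torch.tensor(0.1), torch.tensor(2.3)))  # tensor(0.6000)
```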

pumit/tokenizer/loss.py

Lines changed: 1 addition & 1 deletion
@@ -126,13 +126,13 @@ def forward_gen(
             'rec_loss': rec_loss,
             'perceptual_loss': perceptual_loss,
             'quant_loss': vq_out.loss,
+            'util_var': vq_out.util_var,
             'vq_loss': vq_loss,
             'gan_loss': gan_loss,
             'gan_weight': gan_weight,
         }
         if vq_out.entropy is not None:
             log_dict['entropy'] = vq_out.entropy
-            log_dict['diversity'] = vq_out.diversity
         return loss, log_dict
 
     def disc_fix_rgb(self, x: torch.Tensor, not_rgb: torch.Tensor):

pumit/tokenizer/vq.py

Lines changed: 17 additions & 17 deletions
@@ -15,7 +15,7 @@ class VectorQuantizerOutput:
     """codebook index, can be a discrete one or soft one (categorical distribution over codebook)"""
     loss: torch.Tensor
     """quantization loss (e.g., |z_q - e|, or prior distribution regularization)"""
-    diversity: float
+    util_var: float
     """evaluate if the utilization of the codebook is "uniform" enough"""
     logits: torch.Tensor | None = None
     """original logits over the codebook for probabilistic VQ"""
@@ -49,16 +49,16 @@ def _load_from_state_dict(self, state_dict: dict[str, torch.Tensor], prefix: str
         state_dict[proj_weight_key] = weight
         return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
 
-    def cal_diversity(self, index: torch.Tensor):
-        bias_correction = self.num_embeddings
-        if index.ndim == 4:
-            # discrete
-            raise NotImplementedError
-        else:
-            # probabilistic
-            p = einops.reduce(index, '... n_e -> n_e', 'mean')
-            var = p.var(unbiased=False)
-            return var * bias_correction
+    # def cal_diversity(self, index: torch.Tensor):
+    #     bias_correction = self.num_embeddings
+    #     if index.ndim == 4:
+    #         # discrete
+    #         raise NotImplementedError
+    #     else:
+    #         # probabilistic
+    #         p = einops.reduce(index, '... n_e -> n_e', 'mean')
+    #         var = p.var(unbiased=False)
+    #         return var * bias_correction
 
     def forward(self, z: torch.Tensor, fabric: Fabric | None = None) -> VectorQuantizerOutput:
         """
@@ -113,11 +113,11 @@ def __init__(self, num_embeddings: int, embedding_dim: int, in_channels: int | N
         self.pdr_eps = pdr_eps
 
     def get_pdr_loss(self, probs: torch.Tensor, fabric: Fabric | None):
-        """prior distribution regularization"""
+        """calculate prior distribution regularization and util_var"""
         mean_probs = einops.reduce(probs, '... d -> d', reduction='mean')
         if fabric is not None and fabric.world_size > 1:
             mean_probs = fabric.all_reduce(mean_probs) - mean_probs.detach() + mean_probs
-        return (mean_probs * (mean_probs + self.pdr_eps).log()).sum()
+        return (mean_probs * (mean_probs + self.pdr_eps).log()).sum(), mean_probs.var(unbiased=False)
 
     def embed_index(self, index_probs: torch.Tensor):
         z_q = einops.einsum(index_probs, self.embedding.weight, '... ne, ne d -> ... d')
@@ -151,13 +151,13 @@ def adjust_temperature(self, global_step: int, max_steps: int):
 
     def forward(self, z: torch.Tensor, fabric: Fabric | None = None):
         logits, probs, entropy = self.project_over_codebook(z)
-        loss = self.get_pdr_loss(probs, fabric)
+        loss, util_var = self.get_pdr_loss(probs, fabric)
         if self.training:
             index_probs = nnf.gumbel_softmax(logits, self.temperature, self.hard_gumbel, dim=-1)
         else:
             index_probs = probs
         z_q = self.embed_index(index_probs)
-        return VectorQuantizerOutput(z_q, index_probs, loss, self.cal_diversity(index_probs), logits, entropy)
+        return VectorQuantizerOutput(z_q, index_probs, loss, util_var, logits, entropy)
 
 class SoftVQ(ProbabilisticVQ):
     def __init__(self, num_embeddings: int, embedding_dim: int, in_channels: int | None = None, prune: int | None = 3):
@@ -175,6 +175,6 @@ def forward(self, z: torch.Tensor, fabric: Fabric | None = None):
         index_probs = torch.zeros_like(logits)
         index_probs.scatter_(-1, top_indices, top_logits.softmax(dim=-1))
         index_probs = index_probs + probs - probs.detach()
-        loss = self.get_pdr_loss(index_probs, fabric)
+        loss, util_var = self.get_pdr_loss(index_probs, fabric)
         z_q = self.embed_index(index_probs)
-        return VectorQuantizerOutput(z_q, index_probs, loss, self.cal_diversity(index_probs), logits, entropy)
+        return VectorQuantizerOutput(z_q, index_probs, loss, util_var, logits, entropy)
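
To make the rename concrete, below is a minimal standalone sketch of what util_var measures, mirroring the new return value of get_pdr_loss (mean_probs.var(unbiased=False)). The util_var helper and the toy tensors are illustrative; only the formula comes from the diff above. Note that the now-commented-out cal_diversity reported the same variance scaled by num_embeddings (its bias_correction), so the logged magnitude changes with this rename.

```python
import torch

def util_var(probs: torch.Tensor) -> torch.Tensor:
    """probs: (..., num_embeddings) categorical distributions over the codebook.

    Variance of the mean assignment probability per codebook entry:
    0 when the codebook is used perfectly uniformly, larger when usage collapses.
    """
    # equivalent to einops.reduce(probs, '... d -> d', reduction='mean')
    mean_probs = probs.reshape(-1, probs.shape[-1]).mean(dim=0)
    return mean_probs.var(unbiased=False)

uniform = torch.full((4, 8), 1 / 8)  # every code equally likely
collapsed = torch.zeros(4, 8)
collapsed[:, 0] = 1.0                # all mass on one code
print(util_var(uniform))    # tensor(0.)
print(util_var(collapsed))  # tensor(0.1094)
```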

third-party/LuoLib (submodule reference updated)
