@@ -15,7 +15,7 @@ class VectorQuantizerOutput:
     """codebook index, can be a discrete one or soft one (categorical distribution over codebook)"""
     loss: torch.Tensor
     """quantization loss (e.g., |z_q - e|, or prior distribution regularization)"""
-    diversity: float
+    util_var: float
     """evaluate if the utilization of the codebook is "uniform" enough"""
     logits: torch.Tensor | None = None
     """original logits over the codebook for probabilistic VQ"""
@@ -49,16 +49,16 @@ def _load_from_state_dict(self, state_dict: dict[str, torch.Tensor], prefix: str
         state_dict[proj_weight_key] = weight
         return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
 
-    def cal_diversity(self, index: torch.Tensor):
-        bias_correction = self.num_embeddings
-        if index.ndim == 4:
-            # discrete
-            raise NotImplementedError
-        else:
-            # probabilistic
-            p = einops.reduce(index, '... n_e -> n_e', 'mean')
-            var = p.var(unbiased=False)
-            return var * bias_correction
+    # def cal_diversity(self, index: torch.Tensor):
+    #     bias_correction = self.num_embeddings
+    #     if index.ndim == 4:
+    #         # discrete
+    #         raise NotImplementedError
+    #     else:
+    #         # probabilistic
+    #         p = einops.reduce(index, '... n_e -> n_e', 'mean')
+    #         var = p.var(unbiased=False)
+    #         return var * bias_correction
 
     def forward(self, z: torch.Tensor, fabric: Fabric | None = None) -> VectorQuantizerOutput:
         """
@@ -113,11 +113,11 @@ def __init__(self, num_embeddings: int, embedding_dim: int, in_channels: int | N
         self.pdr_eps = pdr_eps
 
     def get_pdr_loss(self, probs: torch.Tensor, fabric: Fabric | None):
-        """prior distribution regularization"""
+        """calculate prior distribution regularization and util_var"""
         mean_probs = einops.reduce(probs, '... d -> d', reduction='mean')
         if fabric is not None and fabric.world_size > 1:
             mean_probs = fabric.all_reduce(mean_probs) - mean_probs.detach() + mean_probs
-        return (mean_probs * (mean_probs + self.pdr_eps).log()).sum()
+        return (mean_probs * (mean_probs + self.pdr_eps).log()).sum(), mean_probs.var(unbiased=False)
 
     def embed_index(self, index_probs: torch.Tensor):
         z_q = einops.einsum(index_probs, self.embedding.weight, '... ne, ne d -> ... d')
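A minimal standalone sketch of the new get_pdr_loss contract implied by this hunk (the computation mirrors the diff; the helper name, pdr_eps value, and example shapes are assumptions): it now returns both the prior-distribution-regularization loss and util_var, so callers no longer need cal_diversity.

import einops
import torch

def pdr_loss_and_util_var(probs: torch.Tensor, pdr_eps: float = 1e-8):
    # average the per-position codebook distributions over all leading dims
    mean_probs = einops.reduce(probs, '... d -> d', reduction='mean')
    # negative entropy of the mean distribution: smallest when usage is uniform
    pdr_loss = (mean_probs * (mean_probs + pdr_eps).log()).sum()
    # variance of the mean utilization: 0 when the codebook is used uniformly
    util_var = mean_probs.var(unbiased=False)
    return pdr_loss, util_var

probs = torch.softmax(torch.randn(4, 16, 32), dim=-1)  # batch of 4, 16 positions, 32 codes
loss, util_var = pdr_loss_and_util_var(probs)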
@@ -151,13 +151,13 @@ def adjust_temperature(self, global_step: int, max_steps: int):
 
     def forward(self, z: torch.Tensor, fabric: Fabric | None = None):
         logits, probs, entropy = self.project_over_codebook(z)
-        loss = self.get_pdr_loss(probs, fabric)
+        loss, util_var = self.get_pdr_loss(probs, fabric)
         if self.training:
             index_probs = nnf.gumbel_softmax(logits, self.temperature, self.hard_gumbel, dim=-1)
         else:
             index_probs = probs
         z_q = self.embed_index(index_probs)
-        return VectorQuantizerOutput(z_q, index_probs, loss, self.cal_diversity(index_probs), logits, entropy)
+        return VectorQuantizerOutput(z_q, index_probs, loss, util_var, logits, entropy)
 
 class SoftVQ(ProbabilisticVQ):
     def __init__(self, num_embeddings: int, embedding_dim: int, in_channels: int | None = None, prune: int | None = 3):
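A rough sketch of the training-time path in the ProbabilisticVQ.forward hunk above (shapes, temperature, and the hard flag are illustrative assumptions): gumbel_softmax draws a relaxed, optionally hard, one-hot sample per position, and util_var now comes from get_pdr_loss instead of a separate cal_diversity call.

import torch
import torch.nn.functional as nnf

logits = torch.randn(2, 16, 32)                    # (batch, positions, num_embeddings)
index_probs = nnf.gumbel_softmax(logits, tau=0.5, hard=True, dim=-1)
codebook = torch.randn(32, 64)                     # (num_embeddings, embedding_dim)
z_q = torch.einsum('bpn,nd->bpd', index_probs, codebook)  # same role as embed_index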
@@ -175,6 +175,6 @@ def forward(self, z: torch.Tensor, fabric: Fabric | None = None):
         index_probs = torch.zeros_like(logits)
         index_probs.scatter_(-1, top_indices, top_logits.softmax(dim=-1))
         index_probs = index_probs + probs - probs.detach()
-        loss = self.get_pdr_loss(index_probs, fabric)
+        loss, util_var = self.get_pdr_loss(index_probs, fabric)
         z_q = self.embed_index(index_probs)
-        return VectorQuantizerOutput(z_q, index_probs, loss, self.cal_diversity(index_probs), logits, entropy)
+        return VectorQuantizerOutput(z_q, index_probs, loss, util_var, logits, entropy)
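For reference, a minimal sketch of the SoftVQ pruning shown in this hunk (tensor values and the prune count are illustrative): keep only the top-k logits per position, renormalize them with softmax, and add probs - probs.detach() so gradients still flow through the full, unpruned distribution.

import torch

logits = torch.randn(2, 16, 32)
probs = logits.softmax(dim=-1)
top_logits, top_indices = logits.topk(3, dim=-1)            # prune = 3, as in __init__
index_probs = torch.zeros_like(logits)
index_probs.scatter_(-1, top_indices, top_logits.softmax(dim=-1))
index_probs = index_probs + probs - probs.detach()          # straight-through to probs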