Hello, I want to reproduce your great work, but given my limited knowledge, I have two questions right now.
First, I rewrote the training phase and started training on WikiArt with a content-dir of 'wikiart/Rococo' and a style-dir of 'wikiart/Symbolism', but my intermediate results are not as good as yours. Which content-dir and style-dir did you choose on the WikiArt dataset?
Second, my style distribution loss does not converge; it always stays between roughly 4.2 and 4.4. My code is below:
```python
import random

import torch
import torch.nn as nn


class StyleDistLoss(nn.Module):
    '''
    Style distribution loss of s and s'.
    '''
    def __init__(self, pool_size):
        super(StyleDistLoss, self).__init__()
        self.pool_size = pool_size
        if self.pool_size > 0:
            # pool of detached style codes from previous iterations
            self.num_style_batch = 0
            self.style_batches = []
        self.loss = nn.L1Loss()

    def forward(self, sc, st):
        '''
        Return the standard Gaussian distribution loss of the input
        style source {sc} and style target {st}, which correspond to s and s' in the paper.
        '''
        styles = []
        if self.pool_size == 0:
            styles.extend([sc, st])
        else:
            styles += self.style_batches
            styles.extend([sc, st])
            detach_sc = sc.clone().detach()
            detach_st = st.clone().detach()
            if self.num_style_batch + 2 < self.pool_size:
                # pool not full yet: append the new codes
                self.style_batches.extend([detach_sc, detach_st])
                self.num_style_batch += 2
            else:
                # pool full: overwrite two randomly chosen entries
                random_idx = [x for x in range(self.num_style_batch)]
                random.shuffle(random_idx)
                self.style_batches[random_idx[0]] = detach_sc
                self.style_batches[random_idx[1]] = detach_st
        # (N, C, 1, 1) -> (N, C)
        tensor_styles = torch.squeeze(torch.cat(styles, 0))
        styles_mean = torch.mean(tensor_styles, dim=0)
        tminuss = tensor_styles - styles_mean
        # biased covariance estimate, shape (C, C)
        cov = torch.mm(tminuss.t(), tminuss) / tensor_styles.shape[0]
        std_cov = cov.diag(diagonal=0)
        total_loss = self.loss(styles_mean, torch.zeros_like(styles_mean))
        # covariance is compared against an all-ones matrix here
        total_loss += self.loss(cov, torch.ones_like(cov))
        total_loss += self.loss(std_cov, torch.ones_like(std_cov))
        return total_loss
```
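For comparison, here is a minimal sketch of the same moment-matching idea using the targets of a standard Gaussian N(0, I), i.e. a zero mean and an identity covariance matrix, whereas the code above compares `cov` against an all-ones matrix. It assumes the style codes have already been flattened to shape (N, C); the function name `gaussian_moment_loss` is hypothetical, and this is not necessarily how the paper defines the loss.

```python
import torch
import torch.nn.functional as F


def gaussian_moment_loss(styles: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: L1 distance between the batch moments of
    `styles` and those of a standard Gaussian N(0, I).

    Assumes `styles` has shape (N, C): N style codes with C dimensions each.
    """
    mean = styles.mean(dim=0)                         # (C,)
    centered = styles - mean                          # (N, C)
    cov = centered.t() @ centered / styles.shape[0]   # (C, C), biased estimate
    # Target moments of N(0, I): zero mean, identity covariance.
    eye = torch.eye(styles.shape[1], dtype=styles.dtype, device=styles.device)
    loss = F.l1_loss(mean, torch.zeros_like(mean))
    loss = loss + F.l1_loss(cov, eye)
    # Extra term on the variances, mirroring the diagonal term above.
    loss = loss + F.l1_loss(cov.diagonal(), torch.ones_like(mean))
    return loss


if __name__ == "__main__":
    torch.manual_seed(0)
    samples = torch.randn(20000, 8)  # samples already drawn from N(0, I)
    print(gaussian_moment_loss(samples).item())
```

With samples that really are drawn from N(0, I), this loss is small, while an all-ones covariance target cannot be reached by such samples because their off-diagonal covariances are near 0, not 1.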
Could you please give me some advice? Thanks!