Questions about the wikiart training sets #12

@Rancherzhang

Hello, I want to reproduce your great work, but with my limited knowledge I have two questions.
First, I rewrote the training phase and started training on WikiArt with content-dir set to 'wikiart/Rococo' and style-dir set to 'wikiart/Symbolism', but my intermediate results are not as good as yours. Which content-dir and style-dir did you use on the WikiArt dataset?
Second, my style distribution loss does not converge; it always stays between about 4.2 and 4.4. My code is below:

import random

import torch
import torch.nn as nn


class StyleDistLoss(nn.Module):
    '''
    Style distribution loss of s and s'.
    '''
    def __init__(self, pool_size):
        super(StyleDistLoss, self).__init__()
        self.pool_size = pool_size  # assumes pool_size == 0 (no pool) or pool_size >= 2
        if self.pool_size > 0:
            self.num_style_batch = 0
            self.style_batches = []  # detached style codes from earlier iterations
        self.loss = nn.L1Loss()

    def forward(self, sc, st):
        '''
            Return the standard Gaussian distribution loss of the input
            style source {sc} and style target {st}, which correspond to s and s' in the paper.
        '''
        styles = []
        if self.pool_size == 0:
            styles.extend([sc, st])
        else:
            # Pooled codes are constants; only the current sc and st carry gradients.
            styles += self.style_batches
            styles.extend([sc, st])

            detach_sc = sc.clone().detach()
            detach_st = st.clone().detach()

            if self.num_style_batch + 2 <= self.pool_size:
                # Pool not full yet: append both codes.
                self.style_batches.extend([detach_sc, detach_st])
                self.num_style_batch += 2
            else:
                # Pool full: overwrite two distinct, randomly chosen entries.
                idx_sc, idx_st = random.sample(range(self.num_style_batch), 2)
                self.style_batches[idx_sc] = detach_sc
                self.style_batches[idx_st] = detach_st
        # Flatten to (N, D); unlike squeeze(), this never drops the batch dimension.
        tensor_styles = torch.cat(styles, 0).flatten(1)
        styles_mean = torch.mean(tensor_styles, dim=0)
        tminuss = tensor_styles - styles_mean
        cov = torch.mm(tminuss.t(), tminuss) / tensor_styles.shape[0]
        std_cov = cov.diag(diagonal=0)
        # A standard Gaussian has zero mean and identity covariance, so the
        # covariance target is eye(D), not a matrix of ones.
        eye = torch.eye(cov.shape[0], dtype=cov.dtype, device=cov.device)
        total_loss = self.loss(styles_mean, torch.zeros_like(styles_mean))
        total_loss += self.loss(cov, eye)
        total_loss += self.loss(std_cov, torch.ones_like(std_cov))
        return total_loss
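
To see why the covariance target matters, here is a small standalone check. It is only a sketch: `batch_moments` and `moment_loss` are hypothetical helpers that mirror the mean and covariance terms of the loss above (mean L1, like `nn.L1Loss`), without the pool or the separate diagonal term. With an all-ones target, codes whose dimensions are all copies of one scalar score near zero while genuine N(0, I) samples are penalized; the identity target behaves the other way round.

```python
import torch


def batch_moments(z: torch.Tensor):
    """Return the mean (D,) and biased covariance (D, D) of z with shape (N, D)."""
    mean = z.mean(dim=0)
    centered = z - mean
    cov = centered.t() @ centered / z.shape[0]
    return mean, cov


def moment_loss(z: torch.Tensor, cov_target: torch.Tensor) -> torch.Tensor:
    """Mean L1 distance of the batch mean to 0 plus that of the covariance to cov_target."""
    mean, cov = batch_moments(z)
    return mean.abs().mean() + (cov - cov_target).abs().mean()


torch.manual_seed(0)
n, d = 4096, 8
eye = torch.eye(d)
ones = torch.ones(d, d)

gaussian = torch.randn(n, d)                # codes that really are N(0, I)
collapsed = torch.randn(n, 1).expand(n, d)  # every dimension is the same scalar

loss_gauss_eye = moment_loss(gaussian, eye).item()        # small
loss_gauss_ones = moment_loss(gaussian, ones).item()      # about 56/64 from off-diagonals
loss_collapsed_eye = moment_loss(collapsed, eye).item()   # about 56/64 from off-diagonals
loss_collapsed_ones = moment_loss(collapsed, ones).item() # small
print(loss_gauss_eye, loss_gauss_ones, loss_collapsed_eye, loss_collapsed_ones)
```

In other words, an all-ones covariance target pushes the style codes toward a single shared direction rather than toward a standard Gaussian.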

Could you please give me some advice? Thanks!
