fix parameter specials in Field.build_vocab #495

Merged: 5 commits merged into pytorch:master from fix/build_vocab on Feb 1, 2019

Conversation

speedcell4 (Contributor)

No description provided.

@mttk (Contributor) commented Jan 31, 2019

I believe that the specials argument was off-limits to users intentionally. Do you have any use-case not covered by pad, bos, eos or unk?

@speedcell4 (Contributor, Author) commented Feb 1, 2019

case 1

Say we have the sentence [I, have, an, apple]. When I want to convert it to its contiguous character-level representation [I, <space>, h, a, v, e, <space>, a, n, <space>, a, p, p, l, e], I need a special token for <space>.
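
For example (just a sketch assuming the `specials` kwarg this PR enables; the `char_tokenize` helper, the `CHAR` field and the `train` dataset are only illustrative):

# illustrative sketch: character-level tokenizer that maps spaces to a '<space>' token
def char_tokenize(text):
    return ['<space>' if ch == ' ' else ch for ch in text]

CHAR = Field(tokenize=char_tokenize)
CHAR.build_vocab(train, specials=['<space>'])  # guarantee '<space>' ends up in the vocab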

case 2

Another case is building a separate bi-LSTM. I want the forward field forward_word to use <fwd_init> and <fwd_eos> while the backward field backward_word uses <bwd_init> and <bwd_eos>, and their vocabularies should be shared. If torchtext supports that parameter, we can do this:

forward_word = Field(..., init_token='<fwd_init>', eos_token='<fwd_eos>')
backward_word = Field(..., init_token='<bwd_init>', eos_token='<bwd_eos>')

forward_word.build_vocab(train, specials=['<bwd_init>', '<bwd_eos>'])
backward_word.vocab = forward_word.vocab  # share one vocabulary between both fields

case 3

One more case is BERT-style masking: we randomly replace words with <mask> in the postprocessing stage. So at build_vocab time the field has not seen any <mask>, yet <mask> is needed for the later postprocessing work:

import random


def postprocessing(batch, vocab):
    res_batch = []
    for ex in batch:
        res_ex = []
        for token in ex:
            if random.random() < 0.2:
                # replace the token index with the index of '<mask>'
                res_ex.append(vocab.stoi['<mask>'])
            else:
                res_ex.append(token)
        res_batch.append(res_ex)
    return res_batch


word = Field(..., postprocessing=postprocessing)
word.build_vocab(train, specials=['<mask>'])

@mttk (Contributor) commented Feb 1, 2019

Makes sense. AFAIK, case 3 would be better handled while creating batches (because postprocessing is used just once, and you would want different tokens masked between epochs).
Could you fix the Travis errors and write a test for one of those use-cases (e.g. add a special token and then validate that it's added)?
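
By batch-time masking I mean something roughly like this (a sketch only; `WORD`, `train_iter`, `num_epochs` and the 0.2 probability are placeholders):

import torch

mask_index = WORD.vocab.stoi['<mask>']

for epoch in range(num_epochs):
    for batch in train_iter:
        data = batch.word
        # draw a fresh mask every batch, so the masked positions differ between epochs
        mask = torch.rand_like(data, dtype=torch.float) < 0.2
        data = data.masked_fill(mask, mask_index)
        # ... train on the masked batch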

@speedcell4 (Contributor, Author)

No, postprocessing is called once for every batch:

import random

from torchtext.data import Dataset, Example, Field, Iterator

sentence = 'I have an apple'


def postprocessing(batch, vocab):
    print(f'postprocessing is called')
    res_batch = []
    for ex in batch:
        res_ex = []
        for token in ex:
            if random.random() < 0.5:
                res_ex.append(vocab.stoi['<unk>'])
            else:
                res_ex.append(token)
        res_batch.append(res_ex)
    return res_batch


class Dummy(Dataset):
    def __init__(self, examples, fields):
        super(Dummy, self).__init__(examples, fields)

    @classmethod
    def iters(cls):
        WORD = Field(postprocessing=postprocessing)
        fields = [('word', WORD)]
        examples = [
            Example.fromlist([sentence], fields=fields),
        ]
        dataset = cls(examples, fields)
        WORD.build_vocab(dataset)
        return Iterator(dataset, batch_size=1)


if __name__ == '__main__':
    train = Dummy.iters()
    for _ in range(2):
        for batch in train:
            print(batch.word.tolist())

# postprocessing is called
# [[2], [0], [3], [4]]
# postprocessing is called
# [[2], [5], [0], [4]]

I will finish the unit tests soon
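
It will probably look something like this (just a sketch of the idea, not necessarily the exact test I will push):

def test_build_vocab_specials():
    # build a vocab with an extra special token and check that it is registered
    field = Field()
    fields = [('word', field)]
    examples = [Example.fromlist(['I have an apple'], fields)]
    dataset = Dataset(examples, fields)
    field.build_vocab(dataset, specials=['<mask>'])
    assert '<mask>' in field.vocab.stoi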

@speedcell4 (Contributor, Author) commented Feb 1, 2019

This project should provide an editor configuration file; my PyCharm always fixes non-PEP8 formatting automatically.

@mttk (Contributor) commented Feb 1, 2019

Thanks!

Could you elaborate on the editor configuration file? The flake8 config is in .flake8, but I assume you're not referring to that.

@mttk merged commit a6e520e into pytorch:master on Feb 1, 2019
@speedcell4 deleted the fix/build_vocab branch on February 1, 2019 at 13:32