FlashAttention Triton error on the MosaicBERT models other than base #441

@Taytay

When I try to run MosaicBERT like this:

import transformers
from transformers import AutoModelForMaskedLM, BertTokenizer, pipeline

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
config = transformers.BertConfig.from_pretrained('mosaicml/mosaic-bert-base-seqlen-2048')
# It's essential to set the attention_probs_dropout_prob to 0.1
#config.alibi_starting_size = 2048 # maximum sequence length updated to 2048 from config default of 512
mlm = AutoModelForMaskedLM.from_pretrained('mosaicml/mosaic-bert-base-seqlen-2048', trust_remote_code=True, config=config)

mlm.to("cuda")
classifier = pipeline('fill-mask', model=mlm, tokenizer=tokenizer, device="cuda")

classifier("I [MASK] to the store yesterday.")

I get this error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~/YNAB/ML/ai_categorize/.conda-env/lib/python3.9/site-packages/triton/compiler/code_generator.py:1124, in ast_to_ttir(fn, signature, specialization, constants, debug, arch)
   1123 try:
-> 1124     generator.visit(fn.parse())
   1125 except CompilationError as e:

File ~/YNAB/ML/ai_categorize/.conda-env/lib/python3.9/site-packages/triton/compiler/code_generator.py:1017, in CodeGenerator.visit(self, node)
   1016     last_loc = self.builder.get_loc()
-> 1017 ret = super().visit(node)
   1018 # Reset the location to the last one before the visit

File ~/YNAB/ML/ai_categorize/.conda-env/lib/python3.9/ast.py:407, in NodeVisitor.visit(self, node)
    406 visitor = getattr(self, method, self.generic_visit)
--> 407 return visitor(node)

File ~/YNAB/ML/ai_categorize/.conda-env/lib/python3.9/site-packages/triton/compiler/code_generator.py:293, in CodeGenerator.visit_Module(self, node)
    292 def visit_Module(self, node):
--> 293     ast.NodeVisitor.generic_visit(self, node)

File ~/YNAB/ML/ai_categorize/.conda-env/lib/python3.9/ast.py:415, in NodeVisitor.generic_visit(self, node)
    414         if isinstance(item, AST):
--> 415             self.visit(item)
    416 elif isinstance(value, AST):
...
                            other=0.0)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
                        ^
TypeError("dot() got an unexpected keyword argument 'trans_b'")
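For anyone trying to confirm the cause: the last line of the traceback suggests that the installed Triton's `dot()` no longer accepts `trans_b`, while the model's remote code still passes it. Below is a minimal sketch of how one might check whether a callable accepts a given keyword. The helper `accepts_kwarg` and the stand-in functions `dot_old` and `dot_new` are hypothetical names made up for illustration, and their signatures are simplified assumptions, not Triton's real ones.

```python
import inspect


def accepts_kwarg(fn, name):
    """Return True if callable `fn` can be called with keyword `name`."""
    try:
        params = inspect.signature(fn).parameters
    except (TypeError, ValueError):
        # Some callables expose no introspectable signature.
        return False
    param = params.get(name)
    if param is not None:
        return param.kind is not inspect.Parameter.POSITIONAL_ONLY
    # A **kwargs catch-all would also accept the keyword.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


# Hypothetical stand-ins for an older and a newer dot() signature.
def dot_old(a, b, trans_a=False, trans_b=False):
    return None


def dot_new(a, b, allow_tf32=True):
    return None


print(accepts_kwarg(dot_old, "trans_b"))  # True
print(accepts_kwarg(dot_new, "trans_b"))  # False
```

With Triton installed, calling `accepts_kwarg(triton.language.dot, "trans_b")` should report the same thing for the real function, assuming Triton's decorator preserves the wrapped signature (I have not verified that for every release).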

This appears to have been fixed a few days ago by @jacobfulano in the mosaic-bert-base repo:
https://huggingface.co/mosaicml/mosaic-bert-base/blob/ed2a544063a892b78823cba2858d1e098c0e6012/config.json

It looks like that removes FlashAttention? Does that mean that the speed increase from FA is also removed?

Here's how I'm working around it in the meantime, in case someone else Googles this and stumbles across it:

config = transformers.BertConfig.from_pretrained('mosaicml/mosaic-bert-base')
# It's essential to set the attention_probs_dropout_prob to 0.1, which mosaic-bert-base does. So we just update the alibi_starting_size:
config.alibi_starting_size = 2048 # maximum sequence length updated to 2048 from config default of 512
mlm = AutoModelForMaskedLM.from_pretrained('mosaicml/mosaic-bert-base-seqlen-2048', trust_remote_code=True, config=config)
mlm.to("cuda")
classifier = pipeline('fill-mask', model=mlm, tokenizer=tokenizer, device="cuda")
classifier("I [MASK] to the store yesterday.")
