
Commit 417f79a

[Executorch][llm] Enable local global attention in export_llama script
Differential Revision: D73891423
Pull Request resolved: #10612
1 parent 766a3be commit 417f79a

File tree

3 files changed: +65 -1 lines changed


examples/models/llama/export_llama_lib.py

+35
@@ -62,6 +62,7 @@
 from .source_transformation.custom_kv_cache import (
     replace_kv_cache_with_custom_kv_cache,
     replace_kv_cache_with_quantized_kv_cache,
+    replace_kv_cache_with_ring_kv_cache,
 )

 from .source_transformation.quantize import (
@@ -153,6 +154,23 @@ def build_model(
     return export_llama(args)


+def parse_list_of_ints(s):
+    import ast
+
+    try:
+        parsed = ast.literal_eval(s)
+        if isinstance(parsed, list) and all(isinstance(i, int) for i in parsed):
+            print(parsed)
+            return parsed
+        raise argparse.ArgumentTypeError(
+            "Must be a list of integers, e.g., [0, 16, 0, 16]"
+        )
+    except Exception:
+        raise argparse.ArgumentTypeError(
+            "Must be a list of integers, e.g., [0, 16, 0, 16]"
+        )
+
+
 def build_args_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser()
     parser.add_argument("-o", "--output-dir", default=".", help="output directory")
@@ -363,6 +381,15 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="maximum length of context for model to remember",
     )

+    parser.add_argument(
+        "--local_global_attention",
+        type=parse_list_of_ints,
+        default=None,
+        help="List of integers specifying the local and global attention pattern, e.g., [0, 16, 0, 16] specifies that every other layer uses a sliding window of 16."
+        " [0, 16, 32] specifies that the 2nd and 3rd layers use sliding windows of 16 and 32 respectively."
+        " [16] specifies that all layers use a sliding window of 16.",
+    )
+
     parser.add_argument("-2", "--fairseq2", action="store_true")
     parser.add_argument("-v", "--verbose", action="store_true")
     parser.add_argument(
@@ -1307,6 +1334,14 @@ def _get_source_transforms( # noqa
     if args.vulkan:
         transforms.append(replace_with_vulkan_rotary_emb)

+    if getattr(args, "local_global_attention", None) is not None:
+        transforms.append(
+            partial(
+                replace_kv_cache_with_ring_kv_cache,
+                layer_sizes=args.local_global_attention,
+            )
+        )
+
     return transforms
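
With these changes the attention pattern is passed on the export_llama command line as a Python-style list literal, e.g. --local_global_attention "[0, 16, 0, 16]", and parse_list_of_ints converts it into a list of window sizes before the source transform runs. A minimal sketch of that parsing behavior (illustrative only; the import path assumes the installed executorch package layout and may differ per setup):

# Illustrative sketch of how the --local_global_attention value is parsed.
# Assumes export_llama_lib is importable as below; adjust the path if needed.
import argparse

from executorch.examples.models.llama.export_llama_lib import parse_list_of_ints

pattern = parse_list_of_ints("[0, 16, 0, 16]")
assert pattern == [0, 16, 0, 16]  # every other layer gets a 16-token sliding window

pattern = parse_list_of_ints("[16]")
assert pattern == [16]  # a single entry is tiled across all layers (see custom_kv_cache.py below)

try:
    parse_list_of_ints("16")  # not a list literal, so it is rejected
except argparse.ArgumentTypeError as e:
    print(e)  # Must be a list of integers, e.g., [0, 16, 0, 16]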

examples/models/llama/source_transformation/custom_kv_cache.py

+14 -1
@@ -555,8 +555,17 @@ def replace_kv_cache_with_ring_kv_cache(module, layer_sizes):
     # This is needed to ensure that custom ops are registered
     from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401

+    assert len(module.layers) >= len(
+        layer_sizes
+    ), f"Length of layer sizes {len(layer_sizes)} must not exceed the number of layers in the module {len(module.layers)}."
+    multiplier = len(module.layers) // len(layer_sizes)
+    modulo = len(module.layers) % len(layer_sizes)
+    assert (
+        modulo == 0
+    ), f"The model's number of layers must be a multiple of the pattern length. pattern: {layer_sizes} model's num layers {len(module.layers)}"
+    layer_sizes = layer_sizes * multiplier
     logging.info(
-        "Replacing kv cache with ring kv cache. This modifies the model in place."
+        f"Applying local sliding window attention with the following pattern {layer_sizes}."
     )
     assert len(layer_sizes) == len(
         module.layers
@@ -570,4 +579,8 @@ def replace_kv_cache_with_ring_kv_cache(module, layer_sizes):
         ), f"Transfomer block must have attention module. Transformer block {transformer_block}"
         attention = transformer_block.attention
         _replace_kv_cache_with_ring_kv_cache(attention, sliding_window_size)
+        # If the attention's SDPA is the custom SDPA op, make sure it is not
+        # doing causal attention and uses the explicit attention mask instead.
+        if "SDPACustom" in attention.SDPA.__class__.__name__:
+            attention.SDPA.use_attention_mask = True
     return module
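
replace_kv_cache_with_ring_kv_cache now tiles a short layer_sizes pattern across the whole model: the model's layer count must be an exact multiple of the pattern length, and (per the help text above) an entry of 0 leaves that layer on regular global attention while a non-zero entry becomes its sliding-window size. A standalone sketch of just that expansion step, with example layer counts standing in for len(module.layers):

# Standalone sketch of the pattern-expansion logic added above; expand_pattern
# and num_layers are illustrative names, not part of the ExecuTorch API.
def expand_pattern(layer_sizes, num_layers):
    assert num_layers >= len(layer_sizes), "pattern cannot be longer than the model"
    multiplier, modulo = divmod(num_layers, len(layer_sizes))
    assert modulo == 0, "layer count must be a multiple of the pattern length"
    return layer_sizes * multiplier


print(expand_pattern([0, 16], num_layers=4))      # [0, 16, 0, 16]
print(expand_pattern([16], num_layers=4))         # [16, 16, 16, 16]
print(expand_pattern([0, 16, 32], num_layers=6))  # [0, 16, 32, 0, 16, 32]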

examples/models/llama/tests/TARGETS

+16
@@ -85,3 +85,19 @@ python_unittest(
         "//executorch/examples/models/llama:sdpa",
     ],
 )
+
+python_unittest(
+    name = "test_export_llama_lib",
+    srcs = [
+        "test_export_llama_lib.py",
+    ],
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/examples/models/llama:export_library",
+        "//executorch/examples/models/llama:llama_transformer",
+        "//executorch/extension/pybindings:portable_lib",
+    ],
+)
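
The new Buck target runs test_export_llama_lib.py, whose contents are not part of this diff. A hypothetical sketch of the kind of check such a test could contain, exercising only the parse_list_of_ints helper added in this commit:

# Hypothetical sketch only; the real test_export_llama_lib.py is not shown here.
# The import path assumes the package layout implied by the Buck target above.
import argparse
import unittest

from executorch.examples.models.llama.export_llama_lib import parse_list_of_ints


class TestParseListOfInts(unittest.TestCase):
    def test_valid_pattern(self):
        self.assertEqual(parse_list_of_ints("[0, 16, 0, 16]"), [0, 16, 0, 16])

    def test_rejects_non_list(self):
        with self.assertRaises(argparse.ArgumentTypeError):
            parse_list_of_ints("not a list")


if __name__ == "__main__":
    unittest.main()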
