Skip to content

Commit e807743

Browse files
authored
[ETHOSU] Add early simplify to fix LoopPartition (apache#9387)
* [ETHOSU] Add early simplify to fix LoopPartition

  Certain loops aren't correctly partitioned if the loop condition hasn't been simplified. This can happen when a copy loop is split by a non-factor. To fix this, an additional Simplify pass is added to the TIR pipeline prior to LoopPartition.

  Change-Id: Icd4ff14648ccaed41384da50c6d183a122b30048

* Fix linting again

  Change-Id: I9c9dc2ee2c679861866b23531e88584b94198e51
1 parent 1f8ef2a commit e807743

File tree

2 files changed

+64
-1
lines changed

2 files changed

+64
-1
lines changed

python/tvm/relay/backend/contrib/ethosu/tir/compiler.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -78,6 +78,7 @@ def lower_ethosu(sch, args, const_dict, name="main"):
7878
mod = tvm.tir.transform.Simplify()(mod)
7979
mod = tvm.tir.transform.StorageFlatten(64)(mod)
8080
mod = tvm.tir.transform.UnrollLoop()(mod)
81+
mod = tvm.tir.transform.Simplify()(mod)
8182
mod = tvm.tir.transform.LoopPartition()(mod)
8283
mod = RemoveZeroStores()(mod)
8384
mod = tvm.tir.transform.Simplify()(mod)

tests/python/contrib/test_ethosu/test_replace_copy.py

Lines changed: 63 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@
2222
from tvm import relay
2323
from tvm.relay.testing import run_opt_pass
2424
from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir
25-
from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants
25+
from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants, Convolution2DCompute
2626

2727
from .infra import make_ethosu_conv2d
2828

@@ -73,5 +73,67 @@ def _get_func():
7373
tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True)
7474

7575

76+
# fmt: off
77+
@tvm.script.ir_module
78+
class WeightStream:
79+
@T.prim_func
80+
def main(placeholder: T.handle, ethosu_write: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, placeholder_3: T.handle, placeholder_4: T.handle) -> None:
81+
# function attr dict
82+
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
83+
placeholder_5 = T.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8")
84+
ethosu_write_1 = T.match_buffer(ethosu_write, [1, 16, 16, 16], dtype="int8")
85+
buffer = T.match_buffer(placeholder_1, [416], dtype="uint8")
86+
buffer_1 = T.match_buffer(placeholder_2, [112], dtype="uint8")
87+
buffer_2 = T.match_buffer(placeholder_3, [272], dtype="uint8")
88+
buffer_3 = T.match_buffer(placeholder_4, [64], dtype="uint8")
89+
# body
90+
placeholder_global = T.allocate([416], "uint8", "global")
91+
placeholder_d_global = T.allocate([112], "uint8", "global")
92+
T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 416, T.load("uint8", placeholder_global, 0), dtype="handle"))
93+
T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 112, T.load("uint8", placeholder_d_global, 0), dtype="handle"))
94+
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 10, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 416, 12, T.load("uint8", placeholder_d_global, 0), 112, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle"))
95+
T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2.data, 0), 272, T.load("uint8", placeholder_global, 0), dtype="handle"))
96+
T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3.data, 0), 64, T.load("uint8", placeholder_d_global, 0), dtype="handle"))
97+
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 6, 16, 0, 16, T.load("int8", ethosu_write_1.data, 10), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 272, 12, T.load("uint8", placeholder_d_global, 0), 64, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle"))
98+
__tvm_meta__ = None
99+
# fmt: on
100+
101+
102+
def test_weight_stream():
103+
def _cascader(cached_func, const_dict, sch):
104+
weight = cached_func.inputs[1]
105+
scale_bias = cached_func.inputs[2]
106+
out = cached_func.outputs[0]
107+
conv_compute = Convolution2DCompute.from_output(out)
108+
co = conv_compute.split(sch, 3, 10)
109+
cache_weight = sch.cache_read(weight, "global", [conv_compute.conv2d])
110+
cache_scale_bias = sch.cache_read(scale_bias, "global", [conv_compute.conv2d])
111+
sch[cache_weight].compute_at(sch[out], co)
112+
sch[cache_scale_bias].compute_at(sch[out], co)
113+
114+
def _get_func():
115+
ifm = relay.var("ifm", shape=(1, 16, 16, 32), dtype="int8")
116+
conv = make_ethosu_conv2d(
117+
ifm,
118+
32,
119+
16,
120+
(1, 1),
121+
(0, 0),
122+
(1, 1),
123+
(1, 1),
124+
)
125+
func = relay.Function(relay.analysis.free_vars(conv), conv)
126+
func = run_opt_pass(func, relay.transform.InferType())
127+
return func
128+
129+
func = _get_func()
130+
mod, _ = lower_to_tir(func, cascader=_cascader)
131+
132+
script = mod.script(show_meta=True)
133+
test_mod = tvm.script.from_source(script)
134+
reference_mod = WeightStream
135+
tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True)
136+
137+
76138
if __name__ == "__main__":
77139
pytest.main([__file__])

0 commit comments

Comments (0)