|
22 | 22 | from tvm import relay |
23 | 23 | from tvm.relay.testing import run_opt_pass |
24 | 24 | from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir |
25 | | -from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants |
| 25 | +from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants, Convolution2DCompute |
26 | 26 |
|
27 | 27 | from .infra import make_ethosu_conv2d |
28 | 28 |
|
@@ -73,5 +73,67 @@ def _get_func(): |
73 | 73 | tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) |
74 | 74 |
|
75 | 75 |
|
| 76 | +# fmt: off |
| 77 | +@tvm.script.ir_module |
| 78 | +class WeightStream: |
| 79 | + @T.prim_func |
| 80 | + def main(placeholder: T.handle, ethosu_write: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, placeholder_3: T.handle, placeholder_4: T.handle) -> None: |
| 81 | + # function attr dict |
| 82 | + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) |
| 83 | + placeholder_5 = T.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8") |
| 84 | + ethosu_write_1 = T.match_buffer(ethosu_write, [1, 16, 16, 16], dtype="int8") |
| 85 | + buffer = T.match_buffer(placeholder_1, [416], dtype="uint8") |
| 86 | + buffer_1 = T.match_buffer(placeholder_2, [112], dtype="uint8") |
| 87 | + buffer_2 = T.match_buffer(placeholder_3, [272], dtype="uint8") |
| 88 | + buffer_3 = T.match_buffer(placeholder_4, [64], dtype="uint8") |
| 89 | + # body |
| 90 | + placeholder_global = T.allocate([416], "uint8", "global") |
| 91 | + placeholder_d_global = T.allocate([112], "uint8", "global") |
| 92 | + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 416, T.load("uint8", placeholder_global, 0), dtype="handle")) |
| 93 | + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 112, T.load("uint8", placeholder_d_global, 0), dtype="handle")) |
| 94 | + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 10, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 416, 12, T.load("uint8", placeholder_d_global, 0), 112, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) |
| 95 | + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2.data, 0), 272, T.load("uint8", placeholder_global, 0), dtype="handle")) |
| 96 | + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3.data, 0), 64, T.load("uint8", placeholder_d_global, 0), dtype="handle")) |
| 97 | + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 6, 16, 0, 16, T.load("int8", ethosu_write_1.data, 10), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 272, 12, T.load("uint8", placeholder_d_global, 0), 64, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) |
| 98 | + __tvm_meta__ = None |
| 99 | +# fmt: on |
| 100 | + |
| 101 | + |
| 102 | +def test_weight_stream(): |
| 103 | + def _cascader(cached_func, const_dict, sch): |
| 104 | + weight = cached_func.inputs[1] |
| 105 | + scale_bias = cached_func.inputs[2] |
| 106 | + out = cached_func.outputs[0] |
| 107 | + conv_compute = Convolution2DCompute.from_output(out) |
| 108 | + co = conv_compute.split(sch, 3, 10) |
| 109 | + cache_weight = sch.cache_read(weight, "global", [conv_compute.conv2d]) |
| 110 | + cache_scale_bias = sch.cache_read(scale_bias, "global", [conv_compute.conv2d]) |
| 111 | + sch[cache_weight].compute_at(sch[out], co) |
| 112 | + sch[cache_scale_bias].compute_at(sch[out], co) |
| 113 | + |
| 114 | + def _get_func(): |
| 115 | + ifm = relay.var("ifm", shape=(1, 16, 16, 32), dtype="int8") |
| 116 | + conv = make_ethosu_conv2d( |
| 117 | + ifm, |
| 118 | + 32, |
| 119 | + 16, |
| 120 | + (1, 1), |
| 121 | + (0, 0), |
| 122 | + (1, 1), |
| 123 | + (1, 1), |
| 124 | + ) |
| 125 | + func = relay.Function(relay.analysis.free_vars(conv), conv) |
| 126 | + func = run_opt_pass(func, relay.transform.InferType()) |
| 127 | + return func |
| 128 | + |
| 129 | + func = _get_func() |
| 130 | + mod, _ = lower_to_tir(func, cascader=_cascader) |
| 131 | + |
| 132 | + script = mod.script(show_meta=True) |
| 133 | + test_mod = tvm.script.from_source(script) |
| 134 | + reference_mod = WeightStream |
| 135 | + tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) |
| 136 | + |
| 137 | + |
76 | 138 | if __name__ == "__main__": |
77 | 139 | pytest.main([__file__]) |
0 commit comments