Skip to content

Commit 354019d

Browse files
mbaretmanupak
andauthored
[3/6] Arm(R) Ethos(TM)-U NPU TIR compiler with conv2d support (apache#8806)
* Arm(R) Ethos(TM)-U NPU TIR compiler with conv2d support This commit adds the lowering passes necessary to lower an NPU Relay module down to a TIR module that can be compiled for the NPU. Conv2d is supported as the first NPU operator. An intermediate TE stage between Relay and TIR allows support for scheduling the operators. Co-authored-by: Manupa Karunaratne <[email protected]> * Fix Conv2D TIR type sensitivity Change-Id: I3741f9dd8bb5952590ff8c586f6b96e5c3a03795 * Arm(R) Ethos(TM)-U NPU TIR passes and TE for Conv2D *fixing tests Change-Id: Id4a4c80f72ce29b98fc8b3954a1413c1c7fda500 * Fix import guards for tests Change-Id: Iaee06017bd125d3040ce42182c4ccdb80d7fc946 * Fix typing failures with ignores Change-Id: I81513f112a42b93cfdd3bcaf8e8852dd60ffe9e9 * Remove unused import Change-Id: I6596b62ab56e4ca8b31ef08293686f53f38454d2 * Reintroduce get_target_accel_type Change-Id: I0aaf83fe0204c0db435692e9b92dee6e9d6997fe Co-authored-by: Manupa Karunaratne <[email protected]>
1 parent 34570f2 commit 354019d

27 files changed

+4144
-2
lines changed

python/tvm/relay/backend/contrib/ethosu/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,5 @@
2020
from . import preprocess
2121
from . import errors
2222
from . import vela_api
23+
from . import tir_to_cs_translator
2324
from .util import partition_for_ethosu

python/tvm/relay/backend/contrib/ethosu/te/convolution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def conv2d_compute(
140140
).astype(ifm.dtype)
141141
* weight[cc, rh, rw, rc].astype(ifm.dtype)
142142
# This is a trick to load 10 elements of the scale_bias at once, not accurate maths
143-
+ (scale_bias[cc, 0] * scale_bias[cc, 9]),
143+
+ (scale_bias[cc, 0] * scale_bias[cc, 9]).astype(ifm.dtype),
144144
axis=[rh, rw, rc],
145145
),
146146
name="ethosu_conv2d",

python/tvm/relay/backend/contrib/ethosu/te/dma.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@ def _pad(*indices):
5959
not_zero.append(indices[i] < tensor.shape[i] + pad_before[i])
6060
if not_zero:
6161
not_zero = tvm.tir.all(*not_zero)
62-
return tvm.tir.if_then_else(not_zero, tensor(*index_tuple), tvm.tir.const(0, "uint8"))
62+
return tvm.tir.if_then_else(
63+
not_zero, tensor(*index_tuple), tvm.tir.const(0, tensor.dtype)
64+
)
6365
return tensor(*index_tuple)
6466

6567
return _pad
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Arm(R) Ethos(TM)-U NPU TIR codegen modules."""
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name, unused-argument
18+
"""The integration of Arm(R) Ethos(TM)-U NPU TIR compiler"""
19+
import tvm
20+
from tvm import relay
21+
from tvm.relay.expr_functor import ExprMutator
22+
from tvm.driver.build_module import get_binds
23+
24+
from .passes import ReplaceOperators, RemoveZeroStores, EncodeConstants
25+
from .scheduler import schedule
26+
27+
28+
def lower_ethosu(sch, args, const_dict, name="main"):
29+
"""Lower a schedule to TIR for the Arm(R) Ethos(TM)-U NPU target.
30+
31+
The resulting TIR module will contain a single function
32+
that comprises of a sequence of tir.extern_calls to NPU
33+
operations.
34+
35+
Parameters
36+
----------
37+
sch : tvm.te.Schedule
38+
The schedule to be lowered.
39+
args : Union[list of tvm.te.Tensor, TEGraph]
40+
The input/output tensors.
41+
const_dict : dict of int to numpy.ndarray
42+
The constant dictionary.
43+
name : str, optional
44+
The name of the lowered primitive function.
45+
46+
Returns
47+
-------
48+
mod : tvm.IRModule
49+
The lowered TIR module.
50+
const_dict : dict of int to numpy.ndarray
51+
The modified constant dictionary.
52+
53+
"""
54+
if not isinstance(args, list):
55+
args = list(args.inputs) + list(args.outputs)
56+
# config setup
57+
curr_pass_ctx = tvm.ir.transform.PassContext.current()
58+
curr_cfg = dict()
59+
for key, value in curr_pass_ctx.config.items():
60+
curr_cfg[key] = value
61+
tir_compiler_cfg = {
62+
"tir.LoopPartition": {
63+
"partition_const_loop": True,
64+
"no_unroll_loop_with_extent_one": True,
65+
},
66+
"tir.UnrollLoop": {"auto_max_depth": -1},
67+
}
68+
# Merge two configs
69+
curr_cfg = {**curr_cfg, **tir_compiler_cfg}
70+
71+
sch = sch.normalize()
72+
bounds = tvm.te.schedule.InferBound(sch)
73+
stmt = tvm.te.schedule.ScheduleOps(sch, bounds, True)
74+
75+
compact = tvm.te.schedule.VerifyCompactBuffer(stmt)
76+
binds, arg_list = get_binds(args, compact, None)
77+
func = tvm.te.schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)
78+
79+
func = func.with_attr("global_symbol", name)
80+
func = func.with_attr("tir.noalias", True)
81+
mod = tvm.IRModule({name: func})
82+
with tvm.transform.PassContext(config=curr_cfg):
83+
mod = tvm.tir.transform.Simplify()(mod)
84+
mod = tvm.tir.transform.StorageFlatten(64)(mod)
85+
mod = tvm.tir.transform.UnrollLoop()(mod)
86+
mod = tvm.tir.transform.LoopPartition()(mod)
87+
mod = RemoveZeroStores()(mod)
88+
mod = tvm.tir.transform.Simplify()(mod)
89+
mod = tvm.tir.transform.RemoveNoOp()(mod)
90+
mod = ReplaceOperators()(mod)
91+
mod = tvm.tir.transform.RemoveNoOp()(mod)
92+
mod, const_dict = EncodeConstants(const_dict)(mod)
93+
mod = tvm.tir.transform.StorageRewrite()(mod)
94+
mod = tvm.tir.transform.RemoveNoOp()(mod)
95+
return mod, const_dict
96+
97+
98+
def lower_to_te(prim_func):
99+
"""Lower a Relay primitive function to a Tensor Expression graph.
100+
101+
Parameters
102+
----------
103+
prim_func : tvm.relay.Function
104+
The Relay function to lowerethosu_runtime([]).
105+
106+
Returns
107+
-------
108+
out : TEGraph
109+
The lowered Tensor Expression graph.
110+
111+
"""
112+
f = tvm._ffi.get_global_func("relay.backend.contrib.ethosu.LowerToTE")
113+
return f(prim_func)
114+
115+
116+
class ExtractConstants(ExprMutator):
117+
"""The actual mutator pass to extract the constants from a function and replace them with
118+
Vars so the function can be lowered to a TE graph. Additionally returns all the values of
119+
the constants extracted."""
120+
121+
def __init__(self):
122+
super().__init__()
123+
self.constants = []
124+
125+
def visit_constant(self, const):
126+
if isinstance(const.checked_type, relay.ty.TensorType):
127+
if const.checked_type.concrete_shape != ():
128+
self.constants.append(const.data.asnumpy())
129+
name = "p" + str(len(self.constants))
130+
return relay.var(type_annotation=const.checked_type, name_hint=name)
131+
132+
return const
133+
134+
def visit_function(self, fn):
135+
new_body = self.visit(fn.body)
136+
new_params = list(relay.analysis.free_vars(new_body))
137+
return relay.Function(new_params, new_body)
138+
139+
def extract_constants(self, func):
140+
new_func = self.visit(func)
141+
return new_func, self.constants
142+
143+
144+
def extract_constants(func):
145+
"""Extract the constants from a function and replace them with
146+
Vars so the function can be lowered to a TE graph. Additionally
147+
returns all the values of the constants extracted.
148+
149+
Parameters
150+
----------
151+
func : tvm.relay.Function
152+
The Relay function from which to extract constants.
153+
154+
Returns
155+
-------
156+
new_func : tvm.relay.Function
157+
The Relay function with constants replaced by vars.
158+
const_dict : dict of int to numpy.ndarray
159+
A dict of the extracted constants keyed by their param index.
160+
161+
"""
162+
const_dict = {}
163+
params = len(func.params)
164+
new_func, consts = ExtractConstants().extract_constants(func)
165+
for i, const in enumerate(consts):
166+
const_dict[params + i] = const
167+
168+
new_func = tvm.relay.transform.InferType()(tvm.IRModule.from_expr(new_func))["main"]
169+
return new_func, const_dict
170+
171+
172+
def lower_to_tir(func, cascader=None):
173+
"""Lower a Relay function to TIR for the Arm(R) Ethos(TM)-U NPU target.
174+
175+
The Relay function should only contain operations supported
176+
by the NPU.
177+
178+
Parameters
179+
----------
180+
func : tvm.relay.Function
181+
The Relay function to lower.
182+
cascader : Callable
183+
An optional cascading function,
184+
185+
Returns
186+
-------
187+
mod : tvm.IRModule
188+
The lowered TIR module.
189+
consts : dict of int to numpy.ndarray
190+
A dict of the extracted constants keyed by their param index.
191+
192+
"""
193+
func, consts = extract_constants(func)
194+
mod = tvm.IRModule.from_expr(func)
195+
func = relay.transform.InferType()(mod)["main"]
196+
te_graph = lower_to_te(func)
197+
s = schedule(te_graph, consts, cascader)
198+
mod, consts = lower_ethosu(s, te_graph, consts)
199+
return mod, consts
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name, unused-argument
18+
"""Extract information from the convolution operators in TIR."""
19+
import tvm
20+
from ..vela_api import SCALE_BIAS_LENGTH
21+
from .utils import get_outer_loops, get_op_attrs, get_base_address, get_loads, get_stores
22+
from .dma import get_ifm_params, get_ofm_params
23+
from .spec import SerialKernel, SerialAddressRange, SerialActivation, Serial2DConvolution
24+
25+
26+
def get_conv2d_params(stmt, producers, consumers):
27+
"""Get the parameters necessary to construct a call_extern for a 2D convolution.
28+
29+
Parameters
30+
----------
31+
stmt : tvm.tir.AttrStmt
32+
The outermost attribute statement of a convolution loop nest.
33+
producers : dict of tvm.tir.Var to tvm.tir.AttrStmt
34+
A dictionary to associate pointers with the loop nest
35+
that produces their values.
36+
consumers : dict of tvm.tir.Var to tvm.tir.AttrStmt
37+
A dictionary to associate pointers with the loop nest
38+
that consumes their values.
39+
40+
Returns
41+
-------
42+
Serial2DConvolution
43+
The parameters needed to construct a 2D convolution.
44+
output_pointer : tvm.tir.Var
45+
The output pointer of the convolution operation.
46+
replace_pointer : tvm.tir.Var
47+
The output pointer of the DMA write operation, which is to replace
48+
the convolution output pointer.
49+
50+
"""
51+
attrs, body = get_op_attrs(stmt)
52+
_, _, _, _, _, inner = get_outer_loops(body, "NHWC")
53+
rh = inner
54+
rw = rh.body
55+
rc = rw.body
56+
# loads = [output, input, weights, scale_bias, scale_bias]
57+
loads = get_loads(rc.body)
58+
# stores = [output]
59+
stores = get_stores(rc.body)
60+
input_pointer = loads[1].buffer_var
61+
output_pointer = stores[0].buffer_var
62+
# Get feature map info
63+
serial_ifm, serial_padding = get_ifm_params(input_pointer, producers)
64+
serial_ofm, replace_pointer = get_ofm_params(output_pointer, consumers)
65+
# Get kernel info
66+
serial_kernel = SerialKernel(
67+
width=int(rw.extent),
68+
height=int(rh.extent),
69+
stride_w=int(attrs["stride_w"]),
70+
stride_h=int(attrs["stride_h"]),
71+
dilation_w=int(attrs["dilation_w"]),
72+
dilation_h=int(attrs["dilation_h"]),
73+
)
74+
# Get scale_bias info
75+
scale_bias_load = loads[3]
76+
scale_bias_base = get_base_address(scale_bias_load.index)
77+
serial_scale_bias = SerialAddressRange(
78+
address=tvm.tir.Load("uint8", scale_bias_load.buffer_var, scale_bias_base),
79+
length=SCALE_BIAS_LENGTH * serial_ofm[3],
80+
)
81+
# Get weight info
82+
weight_load = loads[2]
83+
weight_base = get_base_address(weight_load.index)
84+
serial_weight = SerialAddressRange(
85+
address=tvm.tir.Load("uint8", weight_load.buffer_var, weight_base),
86+
length=serial_ofm[3] * serial_kernel[0] * serial_kernel[1] * rc.extent,
87+
)
88+
# Get activation info
89+
serial_activation = SerialActivation(
90+
op=attrs["activation"], clip_min=attrs["clip_min"], clip_max=attrs["clip_max"]
91+
)
92+
return (
93+
Serial2DConvolution(
94+
ifm=serial_ifm,
95+
ofm=serial_ofm,
96+
kernel=serial_kernel,
97+
weight=serial_weight,
98+
weight_zero_point=attrs["weight_zero_point"],
99+
scale_bias=serial_scale_bias,
100+
padding=serial_padding,
101+
activation=serial_activation,
102+
upscale="NONE",
103+
),
104+
output_pointer,
105+
replace_pointer,
106+
)

0 commit comments

Comments
 (0)