Skip to content

Commit 4f6b478

Browse files
authored
Address review comments on Arm(R) Ethos(TM)-U PR 3/6 (apache#9159)
* Address review comments on Arm(R) Ethos(TM)-U PR 3/6 Change-Id: I22961885a503be31f6a72622ae0b5f874cc6f463 * Fix rebasing error Change-Id: I3e2fde786096ea331fcb366080fa779ec4ea4a5d * Fix more rebasing problems Change-Id: I1026e3ccee33a3fdec9ebbf6456bae244ad4f1d5
1 parent 9f27be6 commit 4f6b478

File tree

15 files changed

+197
-404
lines changed

15 files changed

+197
-404
lines changed

python/tvm/relay/backend/contrib/ethosu/tir/compiler.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
# pylint: disable=invalid-name, unused-argument
18-
"""The integration of Arm(R) Ethos(TM)-U NPU TIR compiler"""
18+
"""The integration of the Arm(R) Ethos(TM)-U NPU TIR compiler."""
1919
import tvm
2020
from tvm import relay
2121
from tvm.relay.expr_functor import ExprMutator
@@ -29,7 +29,7 @@ def lower_ethosu(sch, args, const_dict, name="main"):
2929
"""Lower a schedule to TIR for the Arm(R) Ethos(TM)-U NPU target.
3030
3131
The resulting TIR module will contain a single function
32-
that comprises of a sequence of tir.extern_calls to NPU
32+
that consists of a sequence of tir.extern_calls to NPU
3333
operations.
3434
3535
Parameters
@@ -96,20 +96,20 @@ def lower_ethosu(sch, args, const_dict, name="main"):
9696

9797

9898
def lower_to_te(prim_func):
99-
"""Lower a Relay primitive function to a Tensor Expression graph.
99+
"""Lower a Relay primitive function to a Tensor Expression in an unscheduled CachedFunc.
100100
101101
Parameters
102102
----------
103103
prim_func : tvm.relay.Function
104-
The Relay function to lowerethosu_runtime([]).
104+
The Relay function to lower.
105105
106106
Returns
107107
-------
108-
out : TEGraph
109-
The lowered Tensor Expression graph.
108+
out : CachedFunc
109+
The lowered Tensor Expression as part of a CachedFunc.
110110
111111
"""
112-
f = tvm._ffi.get_global_func("relay.backend.contrib.ethosu.LowerToTE")
112+
f = tvm._ffi.get_global_func("relay.backend.LowerToTE")
113113
return f(prim_func)
114114

115115

@@ -193,7 +193,7 @@ def lower_to_tir(func, cascader=None):
193193
func, consts = extract_constants(func)
194194
mod = tvm.IRModule.from_expr(func)
195195
func = relay.transform.InferType()(mod)["main"]
196-
te_graph = lower_to_te(func)
197-
s = schedule(te_graph, consts, cascader)
198-
mod, consts = lower_ethosu(s, te_graph, consts)
196+
cached_func = lower_to_te(func)
197+
s = schedule(cached_func, consts, cascader)
198+
mod, consts = lower_ethosu(s, cached_func, consts)
199199
return mod, consts

python/tvm/relay/backend/contrib/ethosu/tir/convolution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
# pylint: disable=invalid-name, unused-argument
18-
"""Extract information from the convolution operators in TIR."""
18+
"""Extract parameters from the convolution operators in TIR."""
1919
import tvm
2020
from ..vela_api import SCALE_BIAS_LENGTH
2121
from .utils import get_outer_loops, get_op_attrs, get_base_address, get_loads, get_stores

python/tvm/relay/backend/contrib/ethosu/tir/dma.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
# pylint: disable=invalid-name, unused-argument
18-
"""Extract information from the DMA operators in TIR."""
18+
"""Extract parameters from the DMA operators in TIR."""
1919
import tvm
2020
from .utils import get_outer_loops, get_base_address, get_strides, get_op_attrs
2121
from .spec import SerialFeatureMap, SerialPadding

python/tvm/relay/backend/contrib/ethosu/tir/passes.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
# pylint: disable=invalid-name, unused-argument
18-
"""The TIR passes to be run on Arm(R) Ethos(TM)-U NPU TIR Compiler"""
18+
"""The TIR passes to be run on Arm(R) Ethos(TM)-U NPU TIR Compiler."""
1919
import numpy as np # type: ignore
2020

2121
import tvm
@@ -301,7 +301,7 @@ def EncodeConstants(const_dict):
301301
pointer_to_buffer = {}
302302
rewrite_buffer = {}
303303
rewrite_pointer = {}
304-
accel_type = vela_api.get_target_accel_type() # type: ignore
304+
accel_config = vela_api.get_accelerator_config()
305305

306306
def _align_scale_bias(tir_extern_call, bias):
307307
"""Align the scale_bias to 16 bytes."""
@@ -316,7 +316,7 @@ def _align_scale_bias(tir_extern_call, bias):
316316

317317
def _encode_weights(tir_extern_call, weights):
318318
"""Encode the weights for a TIR extern call."""
319-
value_bytes = vela_api.encode_weights(tir_extern_call, weights, accel_type)
319+
value_bytes = vela_api.encode_weights(tir_extern_call, weights, accel_config)
320320
value = np.frombuffer(value_bytes, dtype="uint8")
321321
return value
322322

python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,17 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
# pylint: disable=invalid-name, unused-argument
18-
"""Different schedulers for Arm(R) Ethos(TM)-U NPU"""
18+
"""Scheduling for Arm(R) Ethos(TM)-U NPU."""
1919
import tvm
2020

2121

22-
def schedule(te_graph, const_dict, cascader=None):
23-
"""Schedule a TE graph for NPU compilation.
22+
def schedule(cached_func, const_dict, cascader=None):
23+
"""Schedule a CachedFunc for NPU compilation.
2424
2525
Parameters
2626
----------
27-
te_graph
28-
The TE graph to schedule.
27+
cached_func : CachedFunc
28+
The CachedFunc to schedule.
2929
const_dict : dict of int to numpy.ndarray
3030
The constant dictionary.
3131
cascader : callable, optional
@@ -38,10 +38,10 @@ def schedule(te_graph, const_dict, cascader=None):
3838
The completed schedule for the graph.
3939
4040
"""
41-
s = tvm.te.create_schedule([t.op for t in te_graph.outputs])
41+
s = tvm.te.create_schedule([t.op for t in cached_func.outputs])
4242
if cascader:
43-
cascader(te_graph, const_dict, s)
44-
inline_no_ops(te_graph, s)
43+
cascader(cached_func, const_dict, s)
44+
inline_no_ops(cached_func, s)
4545
schedule_pragmas(s)
4646
schedule_cache_reads(s)
4747
return s
@@ -96,7 +96,7 @@ def total_cascader(stripe_size):
9696
9797
"""
9898

99-
def _cascader(te_graph, const_dict, sch):
99+
def _cascader(cached_func, const_dict, sch):
100100
scheduled = set()
101101

102102
def _visit(tensor, stage, ax):
@@ -106,8 +106,8 @@ def _visit(tensor, stage, ax):
106106
for input_tensor in tensor.op.input_tensors:
107107
_visit(input_tensor, stage, ax)
108108

109-
assert len(te_graph.outputs) == 1
110-
out = te_graph.outputs[0]
109+
assert len(cached_func.outputs) == 1
110+
out = cached_func.outputs[0]
111111
oi, _ = tile_nd(sch, out, stripe_size)
112112
for ax in oi:
113113
sch[out].unroll(ax)
@@ -126,22 +126,22 @@ def copy_constants():
126126
The planning function.
127127
"""
128128

129-
def _planner(te_graph, const_dict, sch):
129+
def _planner(cached_func, const_dict, sch):
130130
planned = set() # type: ignore
131131

132132
def _visit(tensor, reader):
133133
if tensor is not planned:
134134
planned.add(tensor)
135135
if isinstance(tensor.op, tvm.te.PlaceholderOp):
136-
index = list(te_graph.inputs).index(tensor)
136+
index = list(cached_func.inputs).index(tensor)
137137
if index in const_dict:
138138
sch.cache_read(tensor, "global", [reader])
139139

140140
elif isinstance(tensor.op, tvm.te.ComputeOp):
141141
for input_tensor in tensor.op.input_tensors:
142142
_visit(input_tensor, tensor)
143143

144-
for output_tensor in te_graph.outputs:
144+
for output_tensor in cached_func.outputs:
145145
_visit(output_tensor, None)
146146

147147
return _planner
@@ -216,16 +216,16 @@ def _detect_cache_read(stage):
216216
stage.pragma(fax, "op", "ethosu_copy")
217217

218218

219-
def inline_no_ops(te_graph, sch):
219+
def inline_no_ops(cached_func, sch):
220220
"""Inline 'no-ops' - operations that in principle do nothing.
221221
222222
Modifies the schedule in-place. For now we inline reshape and
223223
strided slice - more could be added.
224224
225225
Parameters
226226
----------
227-
te_graph
228-
The TE graph.
227+
cached_func : CachedFunc
228+
The cached func.
229229
sch : tvm.te.Schedule
230230
The schedule.
231231
@@ -241,7 +241,7 @@ def _visit(tensor):
241241
for input_tensor in tensor.op.input_tensors:
242242
_visit(input_tensor)
243243

244-
for out in te_graph.outputs:
244+
for out in cached_func.outputs:
245245
_visit(out)
246246

247247

python/tvm/relay/backend/contrib/ethosu/tir/transform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
# pylint: disable=invalid-name, unused-argument
18-
"""Extract information from the transform operators in TIR."""
18+
"""Extract parameters from the transform operators in TIR."""
1919
import tvm
2020
from .spec import SerialCopy
2121
from .utils import get_base_address, get_op_attrs

python/tvm/relay/backend/contrib/ethosu/tir/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
# pylint: disable=invalid-name
18-
"""Helper utility functions used by the TIR compiler"""
18+
"""Helper utility functions used by the NPU TIR compiler"""
1919
import tvm
2020
from tvm import arith
2121

0 commit comments

Comments
 (0)