Skip to content

Commit fa46d7a

Browse files
Add support for hamming_window op (#18036)
* Add support for hamming_window
* Add testcase for exportedProgram
* Move key-value pair in exportedProgram to correct place
* Fix lint issue
* Change default periodic value to True
* Fix lint issue
1 parent 9cb6705 commit fa46d7a

File tree

9 files changed

+206
-0
lines changed

9 files changed

+206
-0
lines changed

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,20 @@ def _one_hot(self, node: fx.Node) -> relax.Var:
284284

285285
return self.block_builder.emit(relax.op.one_hot(x, on_value, off_value, num_classes, axis))
286286

287+
def _hamming_window(self, node: fx.Node) -> relax.Var:
    """Translate a torch ``hamming_window`` call into ``relax.op.hamming_window``.

    Positional defaults mirror ``torch.hamming_window(window_length,
    periodic=True, alpha=0.54, beta=0.46)``.
    """
    args = self.retrieve_args(node)
    window_size = args[0]
    # Fall back to the torch defaults for any omitted trailing positionals.
    defaults = (True, 0.54, 0.46)
    periodic, alpha, beta = (
        args[idx + 1] if len(args) > idx + 1 else default
        for idx, default in enumerate(defaults)
    )
    # "float" is the translator's stand-in for torch's default float dtype.
    dtype = self._convert_data_type(node.kwargs.get("dtype", "float"))
    call = relax.op.hamming_window(window_size, periodic, alpha, beta, dtype)
    return self.block_builder.emit(call)
300+
287301
def _zeros(self, node: fx.Node) -> relax.Var:
288302
args = self.retrieve_args(node)
289303
size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],))
@@ -528,6 +542,10 @@ def create_convert_map(
528542
"fill_.Scalar": self._inplace_fill,
529543
"full.default": self._full,
530544
"full_like.default": self._full_like,
545+
"hamming_window.periodic": self._hamming_window,
546+
"hamming_window.periodic_alpha": self._hamming_window,
547+
"hamming_window.periodic_alpha_beta": self._hamming_window,
548+
"hamming_window.default": self._hamming_window,
531549
"index_select.default": self._index_select,
532550
"lift_fresh_copy.default": self._to_copy,
533551
"linspace.default": self._linspace,

python/tvm/relax/op/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
arange,
7474
full,
7575
full_like,
76+
hamming_window,
7677
ones,
7778
ones_like,
7879
eye,

python/tvm/relax/op/create.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,41 @@ def is_int(expr):
283283
return _ffi_api.arange(start, end, step, dtype) # type: ignore
284284

285285

286+
def hamming_window(window_size, periodic, alpha, beta, dtype):
    """Hamming window function.

    Parameters
    ----------
    window_size : Union[int, PrimExpr, Expr]
        The size of the returned window.

    periodic : Union[bool, PrimExpr, Expr]
        If True, returns a window to be used as a periodic function.
        If False, returns a symmetric window.

    alpha : Union[float, PrimExpr, Expr]
        The coefficient alpha.

    beta : Union[float, PrimExpr, Expr]
        The coefficient beta.

    dtype : str
        The data type of the created tensor.

    Returns
    -------
    ret : relax.Expr
        The result tensor.
    """

    # Plain Python scalars must be wrapped into PrimValue so the operator
    # always receives relax expressions; relax Exprs pass through unchanged.
    def _as_expr(value):
        return value if isinstance(value, Expr) else PrimValue(value)

    window_size = _as_expr(window_size)
    periodic = _as_expr(periodic)
    alpha = _as_expr(alpha)
    beta = _as_expr(beta)

    return _ffi_api.hamming_window(window_size, periodic, alpha, beta, dtype)  # type: ignore
319+
320+
286321
def tril(x: Expr, k: Union[int, PrimExpr, Expr] = 0) -> Expr:
287322
"""Return the lower triangular part of a matrix or a batch of matrices.
288323

python/tvm/relax/transform/legalize_ops/create.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,14 @@ def is_const_scalar(x: PrimValue):
114114
return const(np.arange(start.value, end.value, step.value, dtype=dtype), dtype=dtype)
115115
else:
116116
return bb.call_te(topi.arange, start, end, step, dtype)
117+
118+
119+
@register_legalize("relax.hamming_window")
def _hamming_window(bb: BlockBuilder, call: Call) -> Expr:
    """Legalize relax.hamming_window into a TE call of topi.hamming_window."""
    assert len(call.args) == 4
    # All four operands are PrimValue wrappers; unwrap each to its PrimExpr.
    window_size, periodic, alpha, beta = (arg.value for arg in call.args)
    return bb.call_te(
        topi.hamming_window, window_size, periodic, alpha, beta, call.attrs.dtype
    )

python/tvm/script/ir_builder/relax/ir.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@
9999
grad,
100100
greater,
101101
greater_equal,
102+
hamming_window,
102103
hint_on_device,
103104
index_put,
104105
image,
@@ -786,6 +787,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
786787
"grad",
787788
"greater",
788789
"greater_equal",
790+
"hamming_window",
789791
"hexagon",
790792
"hint_on_device",
791793
"index_put",

python/tvm/topi/transform.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
"""Injective transformation operators"""
1919
from __future__ import absolute_import as _abs
2020

21+
from math import pi
22+
import numpy as np
23+
2124
import tvm
2225
from tvm import te, topi
2326

@@ -1106,3 +1109,45 @@ def index_tensor(data, indices):
11061109
z = topi.index_tensor(x, [row, col]) # shape (2, 3)
11071110
"""
11081111
return topi.adv_index(data, indices)
1112+
1113+
1114+
def hamming_window(window_size, periodic, alpha, beta, dtype):
    """Hamming window function.

    Computes ``alpha - beta * cos(2*pi*i / (N - 1))`` for ``i`` in
    ``[0, N)``; for a periodic window the computation uses ``N + 1``
    points and drops the last one.

    Parameters
    ----------
    window_size: tvm.Expr
        The size of returned window.

    periodic: tvm.Expr
        If True, returns a window to be used as periodic function.
        If False, return a symmetric window.

    alpha: tvm.Expr
        The co-efficient alpha.

    beta: tvm.Expr
        The co-efficient beta.

    dtype: str
        The data type of the result tensor.

    Returns
    -------
    ret : tvm.te.Tensor
        The result tensor.
    """
    # Degenerate single-element window is the constant 1.
    # NOTE(review): `window_size == 1` relies on window_size being a concrete
    # (constant) expression at this point — TODO confirm for symbolic sizes.
    if window_size == 1:
        return topi.const_vector(np.array([1], dtype=dtype))

    periodic = topi.cast(periodic, "bool")

    # NOTE(review): `if periodic:` evaluates the truthiness of the cast
    # expression; this presumably assumes `periodic` is a compile-time
    # constant (as supplied by the legalizer) — verify against callers.
    if periodic:
        window_size += 1

    index = topi.arange(0, window_size, dtype=dtype)
    angular_freq = 2 * pi * index / (window_size - 1)
    cos_values = topi.cos(angular_freq)
    window = topi.cast(alpha - beta * cos_values, dtype=dtype)

    # Periodic case: computed over window_size + 1 points, so slice off the
    # final element to return exactly the requested length.
    if periodic:
        return topi.strided_slice(window, [0], [window.shape[0] - 1])

    return window

src/relax/op/tensor/create.cc

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
#include <string>
3030
#include <utility>
3131

32+
#include "tvm/relax/expr.h"
33+
3234
namespace tvm {
3335
namespace relax {
3436

@@ -363,6 +365,57 @@ TVM_REGISTER_OP("relax.arange")
363365
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
364366
.set_attr<Bool>("FPurity", Bool(true));
365367

368+
/* relax.hamming_window */
369+
Expr hamming_window(PrimValue window_size, PrimValue periodic, PrimValue alpha, PrimValue beta,
370+
DataType dtype) {
371+
ObjectPtr<InitAttrs> attrs = make_object<InitAttrs>();
372+
attrs->dtype = dtype;
373+
static const Op& op = Op::Get("relax.hamming_window");
374+
return Call(op, {std::move(window_size), std::move(periodic), std::move(alpha), std::move(beta)},
375+
Attrs(attrs), {});
376+
}
377+
378+
TVM_FFI_REGISTER_GLOBAL("relax.op.hamming_window").set_body_typed(hamming_window);
379+
380+
StructInfo InferStructInfoHammingWindow(const Call& call, const BlockBuilder& ctx) {
  DataType dtype = call->attrs.as<InitAttrs>()->dtype;
  // The window is cosine-based, so only floating-point outputs make sense.
  // BUG FIX: the original condition tested `dtype.is_uint()` twice; the
  // duplicate is replaced with `dtype.is_bool()` so boolean dtypes are
  // rejected as well, consistent with the error message below.
  if (dtype.is_int() || dtype.is_uint() || dtype.is_bool()) {
    ctx->ReportFatal(Diagnostic::Error(call)
                     << "Hamming Window expects the datatype to be float but got " << dtype);
  }
  // Every operand must be a compile-time PrimValue; extract its PrimExpr.
  auto get_prim_value = [&ctx](const Expr& expr, std::string key) {
    if (!expr->IsInstance<PrimValueNode>()) {
      ctx->ReportFatal(Diagnostic::Error(expr)
                       << "Hamming_window expects the `" << key << "` to be a PrimValue, but got "
                       << expr->GetTypeKey());
    }
    return expr.as<PrimValueNode>()->value;
  };
  PrimExpr window_size = get_prim_value(call->args[0], "window_size");

  // Reject only sizes that are provably < 1; symbolic sizes pass through.
  arith::Analyzer analyzer;
  if (analyzer.CanProveLess(window_size, 1)) {
    ctx->ReportFatal(Diagnostic::Error(call)
                     << "Hamming_window expects the window_size must be greater than zero but got "
                     << window_size);
  }
  window_size = analyzer.Simplify(window_size);
  // The output is a 1-D tensor of `window_size` elements regardless of
  // `periodic` (the periodic variant drops its extra sample in lowering).
  return TensorStructInfo(ShapeExpr({window_size}), dtype);
}
405+
406+
// Operator registration for relax.hamming_window: four PrimValue inputs
// (window_size, periodic, alpha, beta) plus the output dtype in InitAttrs.
// The op is pure and follows the surrounding mixed-precision policy.
TVM_REGISTER_OP("relax.hamming_window")
    .set_attrs_type<InitAttrs>()
    .set_num_inputs(4)
    .add_argument("window_size", "PrimValue", "The size of the window")
    .add_argument("periodic", "PrimValue",
                  "If True, returns a window to be used as periodic function. If False, return a "
                  "symmetric window")
    .add_argument("alpha", "PrimValue", "The coefficient alpha")
    .add_argument("beta", "PrimValue", "The coefficient beta")
    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoHammingWindow)
    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
    .set_attr<Bool>("FPurity", Bool(true));
418+
366419
/* relax.tril & relax.triu */
367420
TVM_REGISTER_NODE_TYPE(TriluAttrs);
368421

src/relax/op/tensor/create.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <tvm/relax/attrs/create.h>
2929

3030
#include "../op_common.h"
31+
#include "tvm/relax/expr.h"
3132

3233
namespace tvm {
3334
namespace relax {
@@ -118,6 +119,19 @@ Expr eye_like(Expr x, PrimValue k, Optional<DataType> dtype);
118119
/*! \brief Construct a tensor with evenly spaced elements. */
119120
Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype);
120121

122+
/*!
123+
* \brief Hamming window function.
124+
* \param window_size The size of the returned window.
125+
* \param periodic If True, returns a window to be used as periodic function.
126+
* If False, return a symmetric window.
127+
* \param alpha The co-efficient alpha.
128+
* \param beta The co-efficient beta.
129+
* \param dtype The data type of the created tensor.
130+
* \return The result tensor.
131+
*/
132+
Expr hamming_window(PrimValue window_size, PrimValue periodic, PrimValue alpha, PrimValue beta,
133+
DataType dtype);
134+
121135
/*! \brief Return the lower triangular part of a matrix or a batch of matrices. */
122136
Expr tril(Expr x, Expr k);
123137

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4349,6 +4349,33 @@ def main(
43494349
verify_model(Arange(), example_args, {}, Expected)
43504350

43514351

4352+
def test_hamming_window():
    """Check torch.hamming_window translates to R.hamming_window."""

    # The torch module ignores its input and returns a fixed 20-point
    # periodic hamming window.
    class HammingWindow(Module):
        def forward(self, input):
            return torch.hamming_window(20, True, dtype=torch.float32)

    # Expected relax module: periodic=True is encoded as prim_value(1), and
    # alpha/beta carry torch's defaults 0.54 / 0.46 (printed at full float
    # precision by TVMScript).
    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            input: R.Tensor((10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((20,), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((20,), dtype="float32") = R.hamming_window(
                    R.prim_value(20),
                    R.prim_value(1),
                    R.prim_value(T.float32(0.54000000000000004)),
                    R.prim_value(T.float32(0.46000000000000002)),
                    dtype="float32",
                )
                gv: R.Tuple(R.Tensor((20,), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    # The (10, 10) example input only shapes main's signature; it does not
    # affect the window itself.
    example_args = (torch.randn(10, 10, dtype=torch.float32),)
    verify_model(HammingWindow(), example_args, {}, Expected)
4377+
4378+
43524379
def test_contiguous():
43534380
class Contiguous(Module):
43544381
def forward(self, input):

0 commit comments

Comments
 (0)