Skip to content

Commit 6da26fe

Browse files
bertmaherfacebook-github-bot
authored andcommitted
[te] Fix pow (pytorch#48213)
Summary: Pull Request resolved: pytorch#48213 it was completely broken unless rhs was a constant. Test Plan: new unit test in test_jit_fuser_te.py Reviewed By: eellison Differential Revision: D25071639 fbshipit-source-id: ef1010a9fd551db646b83adfaa961648a5c388ae
1 parent ed57f80 commit 6da26fe

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

test/test_jit_fuser_te.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1366,6 +1366,25 @@ def eager(a, b):
13661366
b = torch.randint(-2, 2, (1, 64), device='cuda', dtype=torch.long)
13671367
script = self.checkScript(eager, (a, b))
13681368

1369+
def test_neg_pow(self):
1370+
def eager_tt(a: torch.Tensor, b: torch.Tensor):
1371+
return torch.neg(torch.pow(a, b))
1372+
1373+
def eager_ts(a: torch.Tensor, b: float):
1374+
return torch.neg(torch.pow(a, b))
1375+
1376+
def eager_st(a: float, b: torch.Tensor):
1377+
return torch.neg(torch.pow(a, b))
1378+
1379+
a = torch.rand(1, dtype=torch.float)
1380+
b = torch.rand(1, dtype=torch.float)
1381+
s = b.item()
1382+
script = self.checkScript(eager_tt, (a, b))
1383+
self.assertAllFused(script.graph_for(a, b))
1384+
script = self.checkScript(eager_ts, (a, s))
1385+
self.assertAllFused(script.graph_for(a, s))
1386+
script = self.checkScript(eager_st, (s, b))
1387+
self.assertAllFused(script.graph_for(s, b))
13691388

13701389
if __name__ == '__main__':
13711390
run_tests()

torch/csrc/jit/tensorexpr/kernel.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,10 +1026,11 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
10261026
case aten::pow: {
10271027
return computeTwoOperand(
10281028
"aten_pow", v, [](const ExprHandle& lhs, const ExprHandle& rhs) {
1029-
double val = 0;
1030-
if (rhs.node()->isConstant()) {
1031-
val = immediateAs<double>(IRSimplifier::simplify(rhs.node()));
1029+
if (!rhs.node()->isConstant()) {
1030+
return pow(lhs, rhs);
10321031
}
1032+
double val =
1033+
immediateAs<double>(IRSimplifier::simplify(rhs.node()));
10331034

10341035
if (val == 1.0f) {
10351036
return lhs;

0 commit comments

Comments
 (0)