
Commit c3d40fd

Hao Lu authored and facebook-github-bot committed
[ATen] Use expect_contiguous in layer_norm (#58067)
Summary:
Pull Request resolved: #58067

- Use expect_contiguous in layer_norm to avoid unnecessary refcount bumps when the tensors are already contiguous
- Clean up some leftovers from the hacky wrapper removal: use c10::MaybeOwned<Tensor> for the bias tensors
- Skip the dispatcher for at::empty in the layer_norm impl in Static Runtime

Test Plan: CI

Reviewed By: swolchok

Differential Revision: D28214298

fbshipit-source-id: 73150fa62d5c18f41a2264f8e56bbe5e377ad045
1 parent c790fd2 commit c3d40fd
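
For readers unfamiliar with the pattern, here is a minimal illustrative sketch (not code from this commit; the function name is hypothetical) of how c10::MaybeOwned<Tensor>, at::borrow_from_optional_tensor, and Tensor::expect_contiguous fit together, as used throughout the diffs below:

// Illustrative sketch only; `layer_norm_borrow_sketch` is a hypothetical
// free function, not part of ATen.
#include <ATen/ATen.h>
#include <c10/util/MaybeOwned.h>

void layer_norm_borrow_sketch(
    const at::Tensor& input,
    const c10::optional<at::Tensor>& weight_opt,
    const c10::optional<at::Tensor>& bias_opt) {
  // Borrow from the optionals instead of materializing copies: no refcount
  // bump when a tensor is present, an owned empty Tensor when it is not.
  c10::MaybeOwned<at::Tensor> weight =
      at::borrow_from_optional_tensor(weight_opt);
  c10::MaybeOwned<at::Tensor> bias =
      at::borrow_from_optional_tensor(bias_opt);

  // expect_contiguous() also returns c10::MaybeOwned<Tensor>: it borrows
  // `input` when it is already contiguous and only materializes (and owns)
  // a contiguous copy otherwise.
  c10::MaybeOwned<at::Tensor> X = input.expect_contiguous();

  // The handle is dereferenced (*X, X->options()) rather than used as a
  // plain Tensor, which is why the diffs below add the * and -> accesses.
  const at::Tensor& w = *weight;
  const at::Tensor& b = *bias;
  at::Tensor mean = at::empty({1}, X->options());
  (void)w;
  (void)b;
  (void)mean;
}

Because the borrowed case never touches the refcount, callers pay only for a pointer copy when the inputs are already contiguous, which is the common case this commit optimizes for.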

File tree

6 files changed (+202 −183 lines)


aten/src/ATen/native/cuda/layer_norm_kernel.cu

+81 −74
@@ -424,32 +424,36 @@ void LayerNormBackwardKernelImpl(
 
 std::tuple<Tensor, Tensor, Tensor> layer_norm_cuda(
     const Tensor& input,
-    IntArrayRef normalized_shape, const c10::optional<Tensor>& weight_opt /* optional */, const c10::optional<Tensor>& bias_opt /* optional */,
+    IntArrayRef normalized_shape,
+    const c10::optional<Tensor>& weight_opt /* optional */,
+    const c10::optional<Tensor>& bias_opt /* optional */,
     double eps) {
   // See [Note: hacky wrapper removal for optional tensor]
-  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
+  c10::MaybeOwned<Tensor> weight_maybe_owned =
+      at::borrow_from_optional_tensor(weight_opt);
   const Tensor& weight = *weight_maybe_owned;
-  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
-
+  c10::MaybeOwned<Tensor> bias_maybe_owned =
+      at::borrow_from_optional_tensor(bias_opt);
+  const Tensor& bias = *bias_maybe_owned;
 
-  auto inputs = _prepare_layer_norm_inputs(input, normalized_shape, weight, bias);
-  auto X = std::get<0>(inputs);
-  auto gamma = std::get<1>(inputs);
-  auto beta = std::get<2>(inputs);
-  auto M = std::get<3>(inputs);
-  auto N = std::get<4>(inputs);
+  auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, bias);
+  auto M = M_N.first;
+  auto N = M_N.second;
+  auto X = input.expect_contiguous();
+  auto gamma = weight.expect_contiguous();
+  auto beta = bias.expect_contiguous();
 
   Tensor Y = at::native::empty_like(
-      X,
+      *X,
       c10::nullopt /* dtype */,
       c10::nullopt /* layout */,
       c10::nullopt /* device */,
       c10::nullopt /* pin_memory */,
       LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  Tensor mean = at::empty({M}, X.options());
-  Tensor rstd = at::empty({M}, X.options());
+  Tensor mean = at::empty({M}, X->options());
+  Tensor rstd = at::empty({M}, X->options());
   if (M > 0) {
-    LayerNormKernelImpl(X, gamma, beta, M, N, eps, &Y, &mean, &rstd);
+    LayerNormKernelImpl(*X, *gamma, *beta, M, N, eps, &Y, &mean, &rstd);
 
     const auto input_shape = input.sizes();
     const size_t axis = input.dim() - normalized_shape.size();
@@ -473,73 +477,76 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_cuda(
     const Tensor& input,
     IntArrayRef normalized_shape,
     const Tensor& mean,
-    const Tensor& rstd, const c10::optional<Tensor>& weight_opt /* optional */, const c10::optional<Tensor>& bias_opt /* optional */,
+    const Tensor& rstd,
+    const c10::optional<Tensor>& weight_opt /* optional */,
+    const c10::optional<Tensor>& bias_opt /* optional */,
     std::array<bool, 3> grad_input_mask) {
   // See [Note: hacky wrapper removal for optional tensor]
-  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
+  c10::MaybeOwned<Tensor> weight_maybe_owned =
+      at::borrow_from_optional_tensor(weight_opt);
   const Tensor& weight = *weight_maybe_owned;
-  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
-
+  c10::MaybeOwned<Tensor> bias_maybe_owned =
+      at::borrow_from_optional_tensor(bias_opt);
+  const Tensor& bias = *bias_maybe_owned;
 
-  auto inputs = _prepare_layer_norm_inputs(input, normalized_shape, weight, bias);
-  auto X = std::get<0>(inputs);
-  auto gamma = std::get<1>(inputs);
-  auto beta = std::get<2>(inputs);
-  auto M = std::get<3>(inputs);
-  auto N = std::get<4>(inputs);
+  auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, bias);
+  auto M = M_N.first;
+  auto N = M_N.second;
+  auto X = input.expect_contiguous();
+  auto gamma = weight.expect_contiguous();
+  auto beta = bias.expect_contiguous();
 
-  Tensor dX;
-  Tensor dgamma;
-  Tensor dbeta;
-  if (grad_input_mask[0]) {
-    dX = at::native::empty_like(
-        X,
-        c10::nullopt /* dtype */,
-        c10::nullopt /* layout */,
-        c10::nullopt /* device */,
-        c10::nullopt /* pin_memory */,
-        LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  }
-  if (grad_input_mask[1]) {
-    dgamma = M > 0 ? at::native::empty_like(
-                         gamma,
-                         c10::nullopt /* dtype */,
-                         c10::nullopt /* layout */,
-                         c10::nullopt /* device */,
-                         c10::nullopt /* pin_memory */,
-                         LEGACY_CONTIGUOUS_MEMORY_FORMAT)
-                   : at::native::zeros_like(
-                         gamma,
-                         c10::nullopt /* dtype */,
-                         c10::nullopt /* layout */,
-                         c10::nullopt /* device */,
-                         c10::nullopt /* pin_memory */,
-                         LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  }
-  if (grad_input_mask[2]) {
-    dbeta = M > 0 ? at::native::empty_like(
-                        beta,
-                        c10::nullopt /* dtype */,
-                        c10::nullopt /* layout */,
-                        c10::nullopt /* device */,
-                        c10::nullopt /* pin_memory */,
-                        LEGACY_CONTIGUOUS_MEMORY_FORMAT)
-                  : at::native::zeros_like(
-                        beta,
-                        c10::nullopt /* dtype */,
-                        c10::nullopt /* layout */,
-                        c10::nullopt /* device */,
-                        c10::nullopt /* pin_memory */,
-                        LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  }
-  if (M > 0) {
-    LayerNormBackwardKernelImpl(
-        dY, X, mean, rstd, gamma, M, N, &dX, &dgamma, &dbeta);
-  }
-  return std::make_tuple(std::move(dX), std::move(dgamma), std::move(dbeta));
+  Tensor dX;
+  Tensor dgamma;
+  Tensor dbeta;
+  if (grad_input_mask[0]) {
+    dX = at::native::empty_like(
+        *X,
+        c10::nullopt /* dtype */,
+        c10::nullopt /* layout */,
+        c10::nullopt /* device */,
+        c10::nullopt /* pin_memory */,
+        LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  }
+  if (grad_input_mask[1]) {
+    dgamma = M > 0 ? at::native::empty_like(
+                         *gamma,
+                         c10::nullopt /* dtype */,
+                         c10::nullopt /* layout */,
+                         c10::nullopt /* device */,
+                         c10::nullopt /* pin_memory */,
+                         LEGACY_CONTIGUOUS_MEMORY_FORMAT)
+                   : at::native::zeros_like(
+                         *gamma,
+                         c10::nullopt /* dtype */,
+                         c10::nullopt /* layout */,
+                         c10::nullopt /* device */,
+                         c10::nullopt /* pin_memory */,
+                         LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  }
+  if (grad_input_mask[2]) {
+    dbeta = M > 0 ? at::native::empty_like(
+                        *beta,
+                        c10::nullopt /* dtype */,
+                        c10::nullopt /* layout */,
+                        c10::nullopt /* device */,
+                        c10::nullopt /* pin_memory */,
+                        LEGACY_CONTIGUOUS_MEMORY_FORMAT)
+                  : at::native::zeros_like(
+                        *beta,
+                        c10::nullopt /* dtype */,
+                        c10::nullopt /* layout */,
+                        c10::nullopt /* device */,
+                        c10::nullopt /* pin_memory */,
+                        LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  }
+  if (M > 0) {
+    LayerNormBackwardKernelImpl(
+        dY, *X, mean, rstd, *gamma, M, N, &dX, &dgamma, &dbeta);
+  }
+  return std::make_tuple(std::move(dX), std::move(dgamma), std::move(dbeta));
 }
 
-
 REGISTER_DISPATCH(LayerNormKernel, &LayerNormKernelImpl);
 REGISTER_DISPATCH(LayerNormBackwardKernel, &LayerNormBackwardKernelImpl);
 

aten/src/ATen/native/layer_norm.cpp

+85 −81
@@ -60,24 +60,24 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_cpu(
   const Tensor& bias = *bias_maybe_owned;
 
 
-  auto inputs = _prepare_layer_norm_inputs(input, normalized_shape, weight, bias);
-  auto X = std::get<0>(inputs);
-  auto gamma = std::get<1>(inputs);
-  auto beta = std::get<2>(inputs);
-  auto M = std::get<3>(inputs);
-  auto N = std::get<4>(inputs);
+  auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, bias);
+  auto M = M_N.first;
+  auto N = M_N.second;
+  auto X = input.expect_contiguous();
+  auto gamma = weight.expect_contiguous();
+  auto beta = bias.expect_contiguous();
 
   Tensor Y = at::native::empty_like(
-      X,
+      *X,
       c10::nullopt /* dtype */,
       c10::nullopt /* layout */,
       c10::nullopt /* device */,
       c10::nullopt /* pin_memory */,
       at::MemoryFormat::Contiguous);
-  Tensor mean = at::empty({M}, X.options());
-  Tensor rstd = at::empty({M}, X.options());
+  Tensor mean = at::empty({M}, X->options());
+  Tensor rstd = at::empty({M}, X->options());
 
-  layer_norm_cpu_out(Y, mean, rstd, X, normalized_shape, gamma, beta, eps, M, N);
+  layer_norm_cpu_out(Y, mean, rstd, *X, normalized_shape, *gamma, *beta, eps, M, N);
   return std::make_tuple(std::move(Y), std::move(mean), std::move(rstd));
 }
 
@@ -86,70 +86,74 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_cpu(
     const Tensor& input,
     IntArrayRef normalized_shape,
     const Tensor& mean,
-    const Tensor& rstd, const c10::optional<Tensor>& weight_opt /* optional */, const c10::optional<Tensor>& bias_opt /* optional */,
+    const Tensor& rstd,
+    const c10::optional<Tensor>& weight_opt /* optional */,
+    const c10::optional<Tensor>& bias_opt /* optional */,
     std::array<bool, 3> grad_input_mask) {
   // See [Note: hacky wrapper removal for optional tensor]
-  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
+  c10::MaybeOwned<Tensor> weight_maybe_owned =
+      at::borrow_from_optional_tensor(weight_opt);
   const Tensor& weight = *weight_maybe_owned;
-  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
-
-
-  auto inputs = _prepare_layer_norm_inputs(input, normalized_shape, weight, bias);
-  auto X = std::get<0>(inputs);
-  auto gamma = std::get<1>(inputs);
-  auto beta = std::get<2>(inputs);
-  auto M = std::get<3>(inputs);
-  auto N = std::get<4>(inputs);
-
-  Tensor dX;
-  Tensor dgamma;
-  Tensor dbeta;
-  if (grad_input_mask[0]) {
-    dX = at::native::empty_like(
-        X,
-        c10::nullopt /* dtype */,
-        c10::nullopt /* layout */,
-        c10::nullopt /* device */,
-        c10::nullopt /* pin_memory */,
-        at::MemoryFormat::Contiguous);
-  }
-  if (grad_input_mask[1]) {
-    dgamma = M > 0 ? at::native::empty_like(
-                         gamma,
-                         c10::nullopt /* dtype */,
-                         c10::nullopt /* layout */,
-                         c10::nullopt /* device */,
-                         c10::nullopt /* pin_memory */,
-                         at::MemoryFormat::Contiguous)
-                   : at::native::zeros_like(
-                         gamma,
-                         c10::nullopt /* dtype */,
-                         c10::nullopt /* layout */,
-                         c10::nullopt /* device */,
-                         c10::nullopt /* pin_memory */,
-                         at::MemoryFormat::Contiguous);
-  }
-  if (grad_input_mask[2]) {
-    dbeta = M > 0 ? at::native::empty_like(
-                        beta,
-                        c10::nullopt /* dtype */,
-                        c10::nullopt /* layout */,
-                        c10::nullopt /* device */,
-                        c10::nullopt /* pin_memory */,
-                        at::MemoryFormat::Contiguous)
-                  : at::native::zeros_like(
-                        beta,
-                        c10::nullopt /* dtype */,
-                        c10::nullopt /* layout */,
-                        c10::nullopt /* device */,
-                        c10::nullopt /* pin_memory */,
-                        at::MemoryFormat::Contiguous);
-  }
-  if (M > 0) {
-    LayerNormBackwardKernel(
-        kCPU, dY, X, mean, rstd, gamma, M, N, &dX, &dgamma, &dbeta);
-  }
-  return std::make_tuple(std::move(dX), std::move(dgamma), std::move(dbeta));
+  c10::MaybeOwned<Tensor> bias_maybe_owned =
+      at::borrow_from_optional_tensor(bias_opt);
+  const Tensor& bias = *bias_maybe_owned;
+
+  auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, bias);
+  auto M = M_N.first;
+  auto N = M_N.second;
+  auto X = input.expect_contiguous();
+  auto gamma = weight.expect_contiguous();
+  auto beta = bias.expect_contiguous();
+
+  Tensor dX;
+  Tensor dgamma;
+  Tensor dbeta;
+  if (grad_input_mask[0]) {
+    dX = at::native::empty_like(
+        *X,
+        c10::nullopt /* dtype */,
+        c10::nullopt /* layout */,
+        c10::nullopt /* device */,
+        c10::nullopt /* pin_memory */,
+        at::MemoryFormat::Contiguous);
+  }
+  if (grad_input_mask[1]) {
+    dgamma = M > 0 ? at::native::empty_like(
+                         *gamma,
+                         c10::nullopt /* dtype */,
+                         c10::nullopt /* layout */,
+                         c10::nullopt /* device */,
+                         c10::nullopt /* pin_memory */,
+                         at::MemoryFormat::Contiguous)
+                   : at::native::zeros_like(
+                         *gamma,
+                         c10::nullopt /* dtype */,
+                         c10::nullopt /* layout */,
+                         c10::nullopt /* device */,
+                         c10::nullopt /* pin_memory */,
+                         at::MemoryFormat::Contiguous);
+  }
+  if (grad_input_mask[2]) {
+    dbeta = M > 0 ? at::native::empty_like(
+                        *beta,
+                        c10::nullopt /* dtype */,
+                        c10::nullopt /* layout */,
+                        c10::nullopt /* device */,
+                        c10::nullopt /* pin_memory */,
+                        at::MemoryFormat::Contiguous)
+                  : at::native::zeros_like(
+                        *beta,
+                        c10::nullopt /* dtype */,
+                        c10::nullopt /* layout */,
+                        c10::nullopt /* device */,
+                        c10::nullopt /* pin_memory */,
+                        at::MemoryFormat::Contiguous);
+  }
+  if (M > 0) {
+    LayerNormBackwardKernel(
+        kCPU, dY, *X, mean, rstd, *gamma, M, N, &dX, &dgamma, &dbeta);
+  }
+  return std::make_tuple(std::move(dX), std::move(dgamma), std::move(dbeta));
 }
 
 Tensor layer_norm(
@@ -160,7 +164,8 @@ Tensor layer_norm(
   // See [Note: hacky wrapper removal for optional tensor]
   c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
   const Tensor& weight = *weight_maybe_owned;
-  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
+  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
+  const Tensor& bias = *bias_maybe_owned;
 
 
   return std::get<0>(at::native_layer_norm(input, normalized_shape, weight, bias, eps));
@@ -179,15 +184,14 @@ std::tuple<Tensor, Tensor, Tensor> math_native_layer_norm(
   // See [Note: hacky wrapper removal for optional tensor]
   c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
   const Tensor& weight = *weight_maybe_owned;
-  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
-
-  auto inputs = _prepare_layer_norm_inputs(input, normalized_shape, weight, bias);
-  auto X = std::get<0>(inputs);
-  auto gamma = std::get<1>(inputs);
-  auto beta = std::get<2>(inputs);
-  auto M = std::get<3>(inputs);
-  // NOLINTNEXTLINE(clang-diagnostic-unused-variable,clang-analyzer-deadcode.DeadStores)
-  auto N = std::get<4>(inputs);
+  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
+  const Tensor& bias = *bias_maybe_owned;
+
+  auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, bias);
+  auto M = M_N.first;
+  auto X = input.expect_contiguous();
+  auto gamma = weight.expect_contiguous();
+
   auto input_shape = input.sizes();
   const auto input_ndim = input.dim();
   const int normalized_ndim = normalized_shape.size();
