Skip to content

Commit ae6247c

Browse files
committed
addressing feedback
1 parent 64718ab commit ae6247c

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

+10-10
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,12 @@ pub enum DiffActivity {
6969
/// length of a slice/vec. This is used for safety checks on slices.
7070
FakeActivitySize,
7171
}
72+
73+
impl DiffActivity {
74+
pub fn is_dual_or_const(&self) -> bool {
75+
use DiffActivity::*;
76+
matches!(self, |Dual | DualOnly | Dualv | DualvOnly | Const)
77+
}
7278
/// We generate one of these structs for each `#[autodiff(...)]` attribute.
7379
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
7480
pub struct AutoDiffItem {
@@ -140,11 +146,7 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
140146
DiffMode::Error => false,
141147
DiffMode::Source => false,
142148
DiffMode::Forward => {
143-
activity == DiffActivity::Dual
144-
|| activity == DiffActivity::Dualv
145-
|| activity == DiffActivity::DualOnly
146-
|| activity == DiffActivity::DualvOnly
147-
|| activity == DiffActivity::Const
149+
activity.is_dual_or_const()
148150
}
149151
DiffMode::Reverse => {
150152
activity == DiffActivity::Const
@@ -163,10 +165,8 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
163165
pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
164166
use DiffActivity::*;
165167
// It's always allowed to mark something as Const, since we won't compute derivatives wrt. it.
166-
if matches!(activity, Const) {
167-
return true;
168-
}
169-
if matches!(activity, Dual | DualOnly | Dualv | DualvOnly) {
168+
// Dual variants also support all types.
169+
if activity.is_dual_or_const() {
170170
return true;
171171
}
172172
// FIXME(ZuseZ4) We should make this more robust to also
@@ -183,7 +183,7 @@ pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
183183
DiffMode::Error => false,
184184
DiffMode::Source => false,
185185
DiffMode::Forward => {
186-
matches!(activity, Dual | DualOnly | Dualv | DualvOnly | Const)
186+
activity.is_dual_or_const()
187187
}
188188
DiffMode::Reverse => {
189189
matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)

0 commit comments

Comments
 (0)