@@ -69,6 +69,12 @@ pub enum DiffActivity {
69
69
/// length of a slice/vec. This is used for safety checks on slices.
70
70
FakeActivitySize ,
71
71
}
72
+
73
+ impl DiffActivity {
74
+ pub fn is_dual_or_const ( & self ) -> bool {
75
+ use DiffActivity :: * ;
76
+ matches ! ( self , |Dual | DualOnly | Dualv | DualvOnly | Const )
77
+ }
72
78
/// We generate one of these structs for each `#[autodiff(...)]` attribute.
73
79
#[ derive( Clone , Eq , PartialEq , Encodable , Decodable , Debug , HashStable_Generic ) ]
74
80
pub struct AutoDiffItem {
@@ -140,11 +146,7 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
140
146
DiffMode :: Error => false ,
141
147
DiffMode :: Source => false ,
142
148
DiffMode :: Forward => {
143
- activity == DiffActivity :: Dual
144
- || activity == DiffActivity :: Dualv
145
- || activity == DiffActivity :: DualOnly
146
- || activity == DiffActivity :: DualvOnly
147
- || activity == DiffActivity :: Const
149
+ activity. is_dual_or_const ( )
148
150
}
149
151
DiffMode :: Reverse => {
150
152
activity == DiffActivity :: Const
@@ -163,10 +165,8 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
163
165
pub fn valid_ty_for_activity ( ty : & P < Ty > , activity : DiffActivity ) -> bool {
164
166
use DiffActivity :: * ;
165
167
// It's always allowed to mark something as Const, since we won't compute derivatives wrt. it.
166
- if matches ! ( activity, Const ) {
167
- return true ;
168
- }
169
- if matches ! ( activity, Dual | DualOnly | Dualv | DualvOnly ) {
168
+ // Dual variants also support all types.
169
+ if activity. is_dual_or_const ( ) {
170
170
return true ;
171
171
}
172
172
// FIXME(ZuseZ4) We should make this more robust to also
@@ -183,7 +183,7 @@ pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
183
183
DiffMode :: Error => false ,
184
184
DiffMode :: Source => false ,
185
185
DiffMode :: Forward => {
186
- matches ! ( activity, Dual | DualOnly | Dualv | DualvOnly | Const )
186
+ activity. is_dual_or_const ( )
187
187
}
188
188
DiffMode :: Reverse => {
189
189
matches ! ( activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const )
0 commit comments