|
9 | 9 | using Microsoft.ML.Internal.Utilities;
|
10 | 10 | using Microsoft.ML.Runtime;
|
11 | 11 |
|
12 |
| -namespace Microsoft.ML.EntryPoints |
| 12 | +namespace Microsoft.ML.EntryPoints; |
| 13 | + |
| 14 | +[BestFriend] |
| 15 | +internal static class EntryPointUtils |
13 | 16 | {
|
14 |
| - [BestFriend] |
15 |
| - internal static class EntryPointUtils |
| 17 | + private static readonly FuncStaticMethodInfo1<TlcModule.RangeAttribute, object, bool> _isValueWithinRangeMethodInfo |
| 18 | + = new FuncStaticMethodInfo1<TlcModule.RangeAttribute, object, bool>(IsValueWithinRange<int>); |
| 19 | + |
| 20 | + private static bool IsValueWithinRange<T>(TlcModule.RangeAttribute range, object obj) |
16 | 21 | {
|
17 |
| - private static readonly FuncStaticMethodInfo1<TlcModule.RangeAttribute, object, bool> _isValueWithinRangeMethodInfo |
18 |
| - = new FuncStaticMethodInfo1<TlcModule.RangeAttribute, object, bool>(IsValueWithinRange<int>); |
| 22 | + T val; |
| 23 | + if (obj is Optional<T> asOptional) |
| 24 | + val = asOptional.Value; |
| 25 | + else |
| 26 | + val = (T)obj; |
19 | 27 |
|
20 |
| - private static bool IsValueWithinRange<T>(TlcModule.RangeAttribute range, object obj) |
21 |
| - { |
22 |
| - T val; |
23 |
| - if (obj is Optional<T> asOptional) |
24 |
| - val = asOptional.Value; |
25 |
| - else |
26 |
| - val = (T)obj; |
27 |
| - |
28 |
| - return |
29 |
| - (range.Min == null || ((IComparable)range.Min).CompareTo(val) <= 0) && |
30 |
| - (range.Inf == null || ((IComparable)range.Inf).CompareTo(val) < 0) && |
31 |
| - (range.Max == null || ((IComparable)range.Max).CompareTo(val) >= 0) && |
32 |
| - (range.Sup == null || ((IComparable)range.Sup).CompareTo(val) > 0); |
33 |
| - } |
| 28 | + return |
| 29 | + (range.Min == null || ((IComparable)range.Min).CompareTo(val) <= 0) && |
| 30 | + (range.Inf == null || ((IComparable)range.Inf).CompareTo(val) < 0) && |
| 31 | + (range.Max == null || ((IComparable)range.Max).CompareTo(val) >= 0) && |
| 32 | + (range.Sup == null || ((IComparable)range.Sup).CompareTo(val) > 0); |
| 33 | + } |
34 | 34 |
|
35 |
| - public static bool IsValueWithinRange(this TlcModule.RangeAttribute range, object val) |
36 |
| - { |
37 |
| - Contracts.AssertValue(range); |
38 |
| - Contracts.AssertValue(val); |
39 |
| - // Avoid trying to cast double as float. If range |
40 |
| - // was specified using floats, but value being checked |
41 |
| - // is double, change range to be of type double |
42 |
| - if (range.Type == typeof(float) && val is double) |
43 |
| - range.CastToDouble(); |
44 |
| - return Utils.MarshalInvoke(_isValueWithinRangeMethodInfo, range.Type, range, val); |
45 |
| - } |
| 35 | + public static bool IsValueWithinRange(this TlcModule.RangeAttribute range, object val) |
| 36 | + { |
| 37 | + Contracts.AssertValue(range); |
| 38 | + Contracts.AssertValue(val); |
| 39 | + // Avoid trying to cast double as float. If range |
| 40 | + // was specified using floats, but value being checked |
| 41 | + // is double, change range to be of type double |
| 42 | + if (range.Type == typeof(float) && val is double) |
| 43 | + range.CastToDouble(); |
| 44 | + return Utils.MarshalInvoke(_isValueWithinRangeMethodInfo, range.Type, range, val); |
| 45 | + } |
46 | 46 |
|
47 |
| - /// <summary> |
48 |
| - /// Performs checks on an EntryPoint input class equivalent to the checks that are done |
49 |
| - /// when parsing a JSON EntryPoint graph. |
50 |
| - /// |
51 |
| - /// Call this method from EntryPoint methods to ensure that range and required checks are performed |
52 |
| - /// in a consistent manner when EntryPoints are created directly from code. |
53 |
| - /// </summary> |
54 |
| - public static void CheckInputArgs(IExceptionContext ectx, object args) |
| 47 | + /// <summary> |
| 48 | + /// Performs checks on an EntryPoint input class equivalent to the checks that are done |
| 49 | + /// when parsing a JSON EntryPoint graph. |
| 50 | + /// |
| 51 | + /// Call this method from EntryPoint methods to ensure that range and required checks are performed |
| 52 | + /// in a consistent manner when EntryPoints are created directly from code. |
| 53 | + /// </summary> |
| 54 | + public static void CheckInputArgs(IExceptionContext ectx, object args) |
| 55 | + { |
| 56 | + foreach (var fieldInfo in args.GetType().GetFields(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance)) |
55 | 57 | {
|
56 |
| - foreach (var fieldInfo in args.GetType().GetFields(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance)) |
57 |
| - { |
58 |
| - var attr = fieldInfo.GetCustomAttributes(typeof(ArgumentAttribute), false).FirstOrDefault() |
59 |
| - as ArgumentAttribute; |
60 |
| - if (attr == null || attr.Visibility == ArgumentAttribute.VisibilityType.CmdLineOnly) |
61 |
| - continue; |
62 |
| - |
63 |
| - var fieldVal = fieldInfo.GetValue(args); |
64 |
| - var fieldType = fieldInfo.FieldType; |
65 |
| - |
66 |
| - // Optionals are either left in their Implicit constructed state or |
67 |
| - // a new Explicit optional is constructed. They should never be set |
68 |
| - // to null. |
69 |
| - if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Optional<>) && fieldVal == null) |
70 |
| - throw ectx.Except("Field '{0}' is Optional<> and set to null instead of an explicit value.", fieldInfo.Name); |
71 |
| - |
72 |
| - if (attr.IsRequired) |
73 |
| - { |
74 |
| - bool equalToDefault; |
75 |
| - if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Optional<>)) |
76 |
| - equalToDefault = !((Optional)fieldVal).IsExplicit; |
77 |
| - else |
78 |
| - equalToDefault = fieldType.IsValueType ? Activator.CreateInstance(fieldType).Equals(fieldVal) : fieldVal == null; |
79 |
| - |
80 |
| - if (equalToDefault) |
81 |
| - throw ectx.Except("Field '{0}' is required but is not set.", fieldInfo.Name); |
82 |
| - } |
83 |
| - |
84 |
| - var rangeAttr = fieldInfo.GetCustomAttributes(typeof(TlcModule.RangeAttribute), false).FirstOrDefault() |
85 |
| - as TlcModule.RangeAttribute; |
86 |
| - if (rangeAttr != null && fieldVal != null && !rangeAttr.IsValueWithinRange(fieldVal)) |
87 |
| - throw ectx.Except("Field '{0}' is set to a value that falls outside the range bounds.", fieldInfo.Name); |
88 |
| - } |
89 |
| - } |
| 58 | + var attr = fieldInfo.GetCustomAttributes(typeof(ArgumentAttribute), false).FirstOrDefault() |
| 59 | + as ArgumentAttribute; |
| 60 | + if (attr == null || attr.Visibility == ArgumentAttribute.VisibilityType.CmdLineOnly) |
| 61 | + continue; |
90 | 62 |
|
91 |
| - public static IHost CheckArgsAndCreateHost(IHostEnvironment env, string hostName, object input) |
92 |
| - { |
93 |
| - Contracts.CheckValue(env, nameof(env)); |
94 |
| - var host = env.Register(hostName); |
95 |
| - host.CheckValue(input, nameof(input)); |
96 |
| - CheckInputArgs(host, input); |
97 |
| - return host; |
98 |
| - } |
| 63 | + var fieldVal = fieldInfo.GetValue(args); |
| 64 | + var fieldType = fieldInfo.FieldType; |
99 | 65 |
|
100 |
| - /// <summary> |
101 |
| - /// Searches for the given column name in the schema. This method applies a |
102 |
| - /// common policy that throws an exception if the column is not found |
103 |
| - /// and the column name was explicitly specified. If the column is not found |
104 |
| - /// and the column name was not explicitly specified, it returns null. |
105 |
| - /// </summary> |
106 |
| - public static string FindColumnOrNull(IExceptionContext ectx, DataViewSchema schema, Optional<string> value) |
107 |
| - { |
108 |
| - Contracts.CheckValueOrNull(ectx); |
109 |
| - ectx.CheckValue(schema, nameof(schema)); |
110 |
| - ectx.CheckValue(value, nameof(value)); |
| 66 | + // Optionals are either left in their Implicit constructed state or |
| 67 | + // a new Explicit optional is constructed. They should never be set |
| 68 | + // to null. |
| 69 | + if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Optional<>) && fieldVal == null) |
| 70 | + throw ectx.Except("Field '{0}' is Optional<> and set to null instead of an explicit value.", fieldInfo.Name); |
111 | 71 |
|
112 |
| - if (value == "") |
113 |
| - return null; |
114 |
| - if (schema.GetColumnOrNull(value) == null) |
| 72 | + if (attr.IsRequired) |
115 | 73 | {
|
116 |
| - if (value.IsExplicit) |
117 |
| - throw ectx.Except("Column '{0}' not found", value); |
118 |
| - return null; |
| 74 | + bool equalToDefault; |
| 75 | + if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Optional<>)) |
| 76 | + equalToDefault = !((Optional)fieldVal).IsExplicit; |
| 77 | + else |
| 78 | + equalToDefault = fieldType.IsValueType ? Activator.CreateInstance(fieldType).Equals(fieldVal) : fieldVal == null; |
| 79 | + |
| 80 | + if (equalToDefault) |
| 81 | + throw ectx.Except("Field '{0}' is required but is not set.", fieldInfo.Name); |
119 | 82 | }
|
120 |
| - return value; |
| 83 | + |
| 84 | + var rangeAttr = fieldInfo.GetCustomAttributes(typeof(TlcModule.RangeAttribute), false).FirstOrDefault() |
| 85 | + as TlcModule.RangeAttribute; |
| 86 | + if (rangeAttr != null && fieldVal != null && !rangeAttr.IsValueWithinRange(fieldVal)) |
| 87 | + throw ectx.Except("Field '{0}' is set to a value that falls outside the range bounds.", fieldInfo.Name); |
121 | 88 | }
|
| 89 | + } |
122 | 90 |
|
123 |
| - /// <summary> |
124 |
| - /// Converts EntryPoint Optional{T} types into nullable types, with the |
125 |
| - /// implicit value being converted to the null value. |
126 |
| - /// </summary> |
127 |
| - public static T? AsNullable<T>(this Optional<T> opt) where T : struct |
| 91 | + public static IHost CheckArgsAndCreateHost(IHostEnvironment env, string hostName, object input) |
| 92 | + { |
| 93 | + Contracts.CheckValue(env, nameof(env)); |
| 94 | + var host = env.Register(hostName); |
| 95 | + host.CheckValue(input, nameof(input)); |
| 96 | + CheckInputArgs(host, input); |
| 97 | + return host; |
| 98 | + } |
| 99 | + |
| 100 | + /// <summary> |
| 101 | + /// Searches for the given column name in the schema. This method applies a |
| 102 | + /// common policy that throws an exception if the column is not found |
| 103 | + /// and the column name was explicitly specified. If the column is not found |
| 104 | + /// and the column name was not explicitly specified, it returns null. |
| 105 | + /// </summary> |
| 106 | + public static string FindColumnOrNull(IExceptionContext ectx, DataViewSchema schema, Optional<string> value) |
| 107 | + { |
| 108 | + Contracts.CheckValueOrNull(ectx); |
| 109 | + ectx.CheckValue(schema, nameof(schema)); |
| 110 | + ectx.CheckValue(value, nameof(value)); |
| 111 | + |
| 112 | + if (value == "") |
| 113 | + return null; |
| 114 | + if (schema.GetColumnOrNull(value) == null) |
128 | 115 | {
|
129 |
| - if (opt.IsExplicit) |
130 |
| - return opt.Value; |
131 |
| - else |
132 |
| - return null; |
| 116 | + if (value.IsExplicit) |
| 117 | + throw ectx.Except("Column '{0}' not found", value); |
| 118 | + return null; |
133 | 119 | }
|
| 120 | + return value; |
| 121 | + } |
| 122 | + |
| 123 | + /// <summary> |
| 124 | + /// Converts EntryPoint Optional{T} types into nullable types, with the |
| 125 | + /// implicit value being converted to the null value. |
| 126 | + /// </summary> |
| 127 | + public static T? AsNullable<T>(this Optional<T> opt) where T : struct |
| 128 | + { |
| 129 | + if (opt.IsExplicit) |
| 130 | + return opt.Value; |
| 131 | + else |
| 132 | + return null; |
134 | 133 | }
|
135 | 134 | }
|
0 commit comments