@@ -18,6 +18,7 @@ limitations under the License.
18
18
#include < memory>
19
19
#include < set>
20
20
#include < string>
21
+ #include < utility>
21
22
#include < vector>
22
23
23
24
#include " tensorflow/core/framework/op_gen_lib.h"
@@ -100,6 +101,10 @@ void CollectOpDependencies(const OpSpec& op, RenderMode mode,
100
101
for (const AttributeSpec& attribute : op.attributes ()) {
101
102
out->push_back (attribute.var ().type ());
102
103
out->push_back (attribute.jni_type ());
104
+ if (attribute.has_default_value () &&
105
+ attribute.type ().kind () == Type::GENERIC) {
106
+ out->push_back (Type::ForDataType (attribute.default_value ()->type ()));
107
+ }
103
108
}
104
109
for (const AttributeSpec& optional_attribute : op.optional_attributes ()) {
105
110
out->push_back (optional_attribute.var ().type ());
@@ -139,6 +144,60 @@ void WriteSetAttrDirective(const AttributeSpec& attr, bool optional,
139
144
}
140
145
}
141
146
147
+ void RenderSecondaryFactoryMethod (const OpSpec& op, const Type& op_class,
148
+ std::map<string, Type> default_types,
149
+ SourceWriter* writer) {
150
+ // Build the return type for the secondary factory, replacing generic
151
+ // parameters with their default value if any
152
+ Type return_type = Type::Class (op_class.name (), op_class.package ());
153
+ for (const Type& parameter : op_class.parameters ()) {
154
+ if (parameter.kind () == Type::GENERIC &&
155
+ default_types.find (parameter.name ()) != default_types.end ()) {
156
+ return_type.add_parameter (default_types.at (parameter.name ()));
157
+ } else {
158
+ return_type.add_parameter (parameter);
159
+ }
160
+ }
161
+ Method factory = Method::Create (" create" , return_type);
162
+ Javadoc factory_doc = Javadoc::Create (
163
+ " Factory method to create a class to wrap a new " + op_class.name () +
164
+ " operation to the graph, using "
165
+ " default output types." );
166
+ Variable scope =
167
+ Variable::Create (" scope" , Type::Class (" Scope" , " org.tensorflow.op" ));
168
+ AddArgument (scope, " current graph scope" , &factory, &factory_doc);
169
+ std::stringstream factory_statement;
170
+ factory_statement << " return create(scope" ;
171
+ for (const ArgumentSpec& input : op.inputs ()) {
172
+ AddArgument (input.var (), input.description (), &factory, &factory_doc);
173
+ factory_statement << " , " << input.var ().name ();
174
+ }
175
+ for (const AttributeSpec& attr : op.attributes ()) {
176
+ // Only add attributes that are not types or have no default value to the
177
+ // signature of the secondary factory
178
+ factory_statement << " , " ;
179
+ if (attr.type ().kind () == Type::GENERIC &&
180
+ default_types.find (attr.type ().name ()) != default_types.end ()) {
181
+ factory_statement << default_types.at (attr.type ().name ()).name ()
182
+ << " .class" ;
183
+ } else {
184
+ AddArgument (attr.var (), attr.description (), &factory, &factory_doc);
185
+ factory_statement << attr.var ().name ();
186
+ }
187
+ }
188
+ if (!op.optional_attributes ().empty ()) {
189
+ Variable options_var = Variable::Varargs (" options" , Type::Class (" Options" ));
190
+ AddArgument (options_var, " carries optional attributes values" , &factory,
191
+ &factory_doc);
192
+ factory_statement << " , " << options_var.name ();
193
+ }
194
+ factory_doc.add_tag (" return" , " a new instance of " + op_class.name ());
195
+
196
+ writer->BeginMethod (factory, PUBLIC | STATIC, &factory_doc);
197
+ writer->Append (factory_statement.str ().c_str ()).Append (" );" ).EndLine ();
198
+ writer->EndMethod ();
199
+ }
200
+
142
201
void RenderFactoryMethods (const OpSpec& op, const Type& op_class,
143
202
SourceWriter* writer) {
144
203
Method factory = Method::Create (" create" , op_class);
@@ -151,8 +210,17 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
151
210
for (const ArgumentSpec& input : op.inputs ()) {
152
211
AddArgument (input.var (), input.description (), &factory, &factory_doc);
153
212
}
213
+ std::map<string, Type> default_types;
154
214
for (const AttributeSpec& attr : op.attributes ()) {
155
215
AddArgument (attr.var (), attr.description (), &factory, &factory_doc);
216
+ // If this attribute is a type with a default value, save its value
217
+ // for passing it implicitly in a secondary factory method
218
+ if (attr.has_default_value () && attr.type ().kind () == Type::GENERIC) {
219
+ Type default_type = Type::ForDataType (attr.default_value ()->type ());
220
+ if (!default_type.wildcard ()) {
221
+ default_types.insert (std::make_pair (attr.type ().name (), default_type));
222
+ }
223
+ }
156
224
}
157
225
if (!op.optional_attributes ().empty ()) {
158
226
AddArgument (Variable::Varargs (" options" , Type::Class (" Options" )),
@@ -194,6 +262,12 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
194
262
.Append (" (opBuilder.build());" )
195
263
.EndLine ();
196
264
writer->EndMethod ();
265
+
266
+ // If this operation has type attributes with a default value, create a
267
+ // second factory method that infers those values implicitly
268
+ if (!default_types.empty ()) {
269
+ RenderSecondaryFactoryMethod (op, op_class, default_types, writer);
270
+ }
197
271
}
198
272
199
273
void RenderConstructor (const OpSpec& op, const Type& op_class,
0 commit comments