Skip to content

Commit 85631dc

Browse files
Merge pull request tensorflow#21616 from karllessard:java-ops-default-type-attrs
PiperOrigin-RevId: 210615110
2 parents af94082 + a053d7b commit 85631dc

File tree

5 files changed

+118
-43
lines changed

5 files changed

+118
-43
lines changed

tensorflow/java/src/gen/cc/java_defs.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ limitations under the License.
2121
#include <string>
2222
#include <utility>
2323

24+
#include "tensorflow/core/framework/types.h"
25+
2426
namespace tensorflow {
2527
namespace java {
2628

@@ -95,6 +97,34 @@ class Type {
9597
static Type IterableOf(const Type& type) {
9698
return Interface("Iterable").add_parameter(type);
9799
}
100+
static Type ForDataType(DataType data_type) {
101+
switch (data_type) {
102+
case DataType::DT_BOOL:
103+
return Class("Boolean");
104+
case DataType::DT_STRING:
105+
return Class("String");
106+
case DataType::DT_FLOAT:
107+
return Class("Float");
108+
case DataType::DT_DOUBLE:
109+
return Class("Double");
110+
case DataType::DT_UINT8:
111+
return Class("UInt8", "org.tensorflow.types");
112+
case DataType::DT_INT32:
113+
return Class("Integer");
114+
case DataType::DT_INT64:
115+
return Class("Long");
116+
case DataType::DT_RESOURCE:
117+
// TODO(karllessard) create a Resource utility class that could be
118+
// used to store a resource and its type (passed in a second argument).
119+
// For now, we need to force a wildcard and we will unfortunately lose
120+
// track of the resource type.
121+
// Falling through...
122+
default:
123+
// Any other datatypes does not have a equivalent in Java and must
124+
// remain a wildcard (e.g. DT_COMPLEX64, DT_QINT8, ...)
125+
return Wildcard();
126+
}
127+
}
98128
const Kind& kind() const { return kind_; }
99129
const string& name() const { return name_; }
100130
const string& package() const { return package_; }

tensorflow/java/src/gen/cc/op_generator.cc

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License.
1818
#include <memory>
1919
#include <set>
2020
#include <string>
21+
#include <utility>
2122
#include <vector>
2223

2324
#include "tensorflow/core/framework/op_gen_lib.h"
@@ -100,6 +101,10 @@ void CollectOpDependencies(const OpSpec& op, RenderMode mode,
100101
for (const AttributeSpec& attribute : op.attributes()) {
101102
out->push_back(attribute.var().type());
102103
out->push_back(attribute.jni_type());
104+
if (attribute.has_default_value() &&
105+
attribute.type().kind() == Type::GENERIC) {
106+
out->push_back(Type::ForDataType(attribute.default_value()->type()));
107+
}
103108
}
104109
for (const AttributeSpec& optional_attribute : op.optional_attributes()) {
105110
out->push_back(optional_attribute.var().type());
@@ -139,6 +144,60 @@ void WriteSetAttrDirective(const AttributeSpec& attr, bool optional,
139144
}
140145
}
141146

147+
void RenderSecondaryFactoryMethod(const OpSpec& op, const Type& op_class,
148+
std::map<string, Type> default_types,
149+
SourceWriter* writer) {
150+
// Build the return type for the secondary factory, replacing generic
151+
// parameters with their default value if any
152+
Type return_type = Type::Class(op_class.name(), op_class.package());
153+
for (const Type& parameter : op_class.parameters()) {
154+
if (parameter.kind() == Type::GENERIC &&
155+
default_types.find(parameter.name()) != default_types.end()) {
156+
return_type.add_parameter(default_types.at(parameter.name()));
157+
} else {
158+
return_type.add_parameter(parameter);
159+
}
160+
}
161+
Method factory = Method::Create("create", return_type);
162+
Javadoc factory_doc = Javadoc::Create(
163+
"Factory method to create a class to wrap a new " + op_class.name() +
164+
" operation to the graph, using "
165+
"default output types.");
166+
Variable scope =
167+
Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op"));
168+
AddArgument(scope, "current graph scope", &factory, &factory_doc);
169+
std::stringstream factory_statement;
170+
factory_statement << "return create(scope";
171+
for (const ArgumentSpec& input : op.inputs()) {
172+
AddArgument(input.var(), input.description(), &factory, &factory_doc);
173+
factory_statement << ", " << input.var().name();
174+
}
175+
for (const AttributeSpec& attr : op.attributes()) {
176+
// Only add attributes that are not types or have no default value to the
177+
// signature of the secondary factory
178+
factory_statement << ", ";
179+
if (attr.type().kind() == Type::GENERIC &&
180+
default_types.find(attr.type().name()) != default_types.end()) {
181+
factory_statement << default_types.at(attr.type().name()).name()
182+
<< ".class";
183+
} else {
184+
AddArgument(attr.var(), attr.description(), &factory, &factory_doc);
185+
factory_statement << attr.var().name();
186+
}
187+
}
188+
if (!op.optional_attributes().empty()) {
189+
Variable options_var = Variable::Varargs("options", Type::Class("Options"));
190+
AddArgument(options_var, "carries optional attributes values", &factory,
191+
&factory_doc);
192+
factory_statement << ", " << options_var.name();
193+
}
194+
factory_doc.add_tag("return", "a new instance of " + op_class.name());
195+
196+
writer->BeginMethod(factory, PUBLIC | STATIC, &factory_doc);
197+
writer->Append(factory_statement.str().c_str()).Append(");").EndLine();
198+
writer->EndMethod();
199+
}
200+
142201
void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
143202
SourceWriter* writer) {
144203
Method factory = Method::Create("create", op_class);
@@ -151,8 +210,17 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
151210
for (const ArgumentSpec& input : op.inputs()) {
152211
AddArgument(input.var(), input.description(), &factory, &factory_doc);
153212
}
213+
std::map<string, Type> default_types;
154214
for (const AttributeSpec& attr : op.attributes()) {
155215
AddArgument(attr.var(), attr.description(), &factory, &factory_doc);
216+
// If this attribute is a type with a default value, save its value
217+
// for passing it implicitly in a secondary factory method
218+
if (attr.has_default_value() && attr.type().kind() == Type::GENERIC) {
219+
Type default_type = Type::ForDataType(attr.default_value()->type());
220+
if (!default_type.wildcard()) {
221+
default_types.insert(std::make_pair(attr.type().name(), default_type));
222+
}
223+
}
156224
}
157225
if (!op.optional_attributes().empty()) {
158226
AddArgument(Variable::Varargs("options", Type::Class("Options")),
@@ -194,6 +262,12 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
194262
.Append("(opBuilder.build());")
195263
.EndLine();
196264
writer->EndMethod();
265+
266+
// If this operation has type attributes with a default value, create a
267+
// second factory method that infers those values implicitly
268+
if (!default_types.empty()) {
269+
RenderSecondaryFactoryMethod(op, op_class, default_types, writer);
270+
}
197271
}
198272

199273
void RenderConstructor(const OpSpec& op, const Type& op_class,

tensorflow/java/src/gen/cc/op_specs.cc

Lines changed: 5 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -96,43 +96,10 @@ Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, bool* iterable_out) {
9696
*iterable_out = true;
9797
visited_attrs_.insert(std::make_pair(arg_def.number_attr(), Type::Int()));
9898
}
99-
10099
Type type = Type::Wildcard();
101100
if (arg_def.type() != DataType::DT_INVALID) {
102-
// resolve type from DataType
103-
switch (arg_def.type()) {
104-
case DataType::DT_BOOL:
105-
type = Type::Class("Boolean");
106-
break;
107-
case DataType::DT_STRING:
108-
type = Type::Class("String");
109-
break;
110-
case DataType::DT_FLOAT:
111-
type = Type::Class("Float");
112-
break;
113-
case DataType::DT_DOUBLE:
114-
type = Type::Class("Double");
115-
break;
116-
case DataType::DT_UINT8:
117-
type = Type::Class("UInt8", "org.tensorflow.types");
118-
break;
119-
case DataType::DT_INT32:
120-
type = Type::Class("Integer");
121-
break;
122-
case DataType::DT_INT64:
123-
type = Type::Class("Long");
124-
break;
125-
case DataType::DT_RESOURCE:
126-
// TODO(karllessard) create a Resource utility class that could be
127-
// used to store a resource and its type (passed in a second argument).
128-
// For now, we need to force a wildcard and we will unfortunately lose
129-
// track of the resource type.
130-
break;
131-
default:
132-
// Any other datatypes does not have a equivalent in Java and must
133-
// remain a wildcard (e.g. DT_COMPLEX64, DT_QINT8, ...)
134-
break;
135-
}
101+
type = Type::ForDataType(arg_def.type());
102+
136103
} else if (!arg_def.type_attr().empty()) {
137104
// resolve type from attribute (if already visited, retrieve its type)
138105
if (IsAttributeVisited(arg_def.type_attr())) {
@@ -337,7 +304,7 @@ AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def,
337304
bool iterable = false;
338305
std::pair<Type, Type> types = type_resolver->TypesOf(attr_def, &iterable);
339306
Type var_type = types.first.kind() == Type::GENERIC
340-
? Type::Class("Class").add_parameter(types.first)
307+
? Type::ClassOf(types.first)
341308
: types.first;
342309
if (iterable) {
343310
var_type = Type::ListOf(var_type);
@@ -346,7 +313,8 @@ AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def,
346313
attr_api_def.name(),
347314
Variable::Create(SnakeToCamelCase(attr_api_def.rename_to()), var_type),
348315
types.first, types.second, ParseDocumentation(attr_api_def.description()),
349-
iterable, attr_api_def.has_default_value());
316+
iterable,
317+
attr_def.has_default_value() ? &attr_def.default_value() : nullptr);
350318
}
351319

352320
ArgumentSpec CreateOutput(const OpDef_ArgDef& output_def,

tensorflow/java/src/gen/cc/op_specs.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,26 +94,30 @@ class AttributeSpec {
9494
// jni_type: the type of this attribute in JNI layer (see OperationBuilder)
9595
// description: a description of this attribute, in javadoc
9696
// iterable: true if this attribute is a list
97-
// has_default_value: true if this attribute has a default value if not set
97+
// default_value: default value for this attribute or nullptr if none. Any
98+
// value referenced by this pointer must outlive the lifetime
99+
// of the AttributeSpec. This is guaranteed if the value is
100+
// issued by an OpDef of the global OpRegistry.
98101
AttributeSpec(const string& op_def_name, const Variable& var,
99102
const Type& type, const Type& jni_type,
100103
const string& description, bool iterable,
101-
bool has_default_value)
104+
const AttrValue* default_value)
102105
: op_def_name_(op_def_name),
103106
var_(var),
104107
type_(type),
105108
description_(description),
106109
iterable_(iterable),
107110
jni_type_(jni_type),
108-
has_default_value_(has_default_value) {}
111+
default_value_(default_value) {}
109112

110113
const string& op_def_name() const { return op_def_name_; }
111114
const Variable& var() const { return var_; }
112115
const Type& type() const { return type_; }
113116
const string& description() const { return description_; }
114117
bool iterable() const { return iterable_; }
115118
const Type& jni_type() const { return jni_type_; }
116-
bool has_default_value() const { return has_default_value_; }
119+
bool has_default_value() const { return default_value_ != nullptr; }
120+
const AttrValue* default_value() const { return default_value_; }
117121

118122
private:
119123
const string op_def_name_;
@@ -122,7 +126,7 @@ class AttributeSpec {
122126
const string description_;
123127
const bool iterable_;
124128
const Type jni_type_;
125-
const bool has_default_value_;
129+
const AttrValue* default_value_;
126130
};
127131

128132
class OpSpec {

tensorflow/java/src/gen/cc/source_writer.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ limitations under the License.
1616
#include <string>
1717
#include <algorithm>
1818
#include <list>
19-
#include <string>
2019

2120
#include "tensorflow/java/src/gen/cc/source_writer.h"
2221

0 commit comments

Comments
 (0)