Skip to content

Commit 83596bd

Browse files
committed
produce a Declarations.yaml file that describes the Functions/Type/Tensor methods that the framework produced.
1 parent f3f8ce4 commit 83596bd

File tree

3 files changed

+117
-35
lines changed

3 files changed

+117
-35
lines changed

CMakeLists.txt

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,11 @@ ELSE()
8585
ENDIF()
8686

8787
# Can be compiled standalone
88-
IF(NOT TENSOR_LIB_INSTALL_BIN_DIR OR NOT TENSOR_LIB_INSTALL_LIB_DIR OR NOT TENSOR_LIB_INSTALL_INCLUDE_DIR)
89-
SET(TENSOR_LIB_INSTALL_BIN_DIR "bin" CACHE PATH "TENSOR_LIB install binary subdirectory")
90-
SET(TENSOR_LIB_INSTALL_LIB_DIR "lib" CACHE PATH "TENSOR_LIB install library subdirectory")
91-
SET(TENSOR_LIB_INSTALL_INCLUDE_DIR "include" CACHE PATH "TENSOR_LIB install include subdirectory")
88+
IF(NOT AT_INSTALL_BIN_DIR OR NOT AT_INSTALL_LIB_DIR OR NOT AT_INSTALL_INCLUDE_DIR OR NOT AT_INSTALL_SHARE_DIR)
89+
SET(AT_INSTALL_BIN_DIR "bin" CACHE PATH "AT install binary subdirectory")
90+
SET(AT_INSTALL_LIB_DIR "lib" CACHE PATH "AT install library subdirectory")
91+
SET(AT_INSTALL_INCLUDE_DIR "include" CACHE PATH "AT install include subdirectory")
92+
SET(AT_INSTALL_SHARE_DIR "share" CACHE PATH "AT install include subdirectory")
9293
ENDIF()
9394

9495
FILE(GLOB base_h RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.h")
@@ -115,8 +116,14 @@ IF(NOT DEFINED cwrap_files)
115116
)
116117
ENDIF()
117118

119+
SET(GEN_COMMAND
120+
python ${CMAKE_CURRENT_SOURCE_DIR}/gen.py ${CUDA_FLAG}
121+
-s ${CMAKE_CURRENT_SOURCE_DIR}
122+
${cwrap_files}
123+
)
124+
118125
EXECUTE_PROCESS(
119-
COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/gen.py ${CUDA_FLAG} -s ${CMAKE_CURRENT_SOURCE_DIR} --output-dependencies ${CMAKE_CURRENT_BINARY_DIR}/generated_cpp.txt ${cwrap_files}
126+
COMMAND ${GEN_COMMAND} --output-dependencies ${CMAKE_CURRENT_BINARY_DIR}/generated_cpp.txt
120127
RESULT_VARIABLE RETURN_VALUE
121128
)
122129
if (NOT RETURN_VALUE EQUAL 0)
@@ -130,7 +137,7 @@ FILE(GLOB_RECURSE all_templates "templates/*")
130137
FILE(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ATen)
131138

132139
ADD_CUSTOM_COMMAND(OUTPUT ${generated_cpp}
133-
COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/gen.py ${CUDA_FLAG} -s ${CMAKE_CURRENT_SOURCE_DIR} ${cwrap_files}
140+
COMMAND ${GEN_COMMAND}
134141
DEPENDS ${all_python} ${all_templates} ${cwrap_files})
135142

136143
SET(all_cpp ${base_cpp} ${generated_cpp})
@@ -153,14 +160,16 @@ IF(CUDA_FOUND)
153160
ENDIF()
154161

155162
INSTALL(TARGETS ATen
156-
RUNTIME DESTINATION "${TENSOR_LIB_INSTALL_BIN_DIR}"
157-
LIBRARY DESTINATION "${TENSOR_LIB_INSTALL_LIB_DIR}"
158-
ARCHIVE DESTINATION "${TENSOR_LIB_INSTALL_LIB_DIR}")
163+
RUNTIME DESTINATION "${AT_INSTALL_BIN_DIR}"
164+
LIBRARY DESTINATION "${AT_INSTALL_LIB_DIR}"
165+
ARCHIVE DESTINATION "${AT_INSTALL_LIB_DIR}")
159166

160167
FOREACH(HEADER ${base_h})
161-
INSTALL(FILES ${HEADER} DESTINATION ${TENSOR_LIB_INSTALL_INCLUDE_DIR}/ATen)
168+
INSTALL(FILES ${HEADER} DESTINATION ${AT_INSTALL_INCLUDE_DIR}/ATen)
162169
ENDFOREACH()
163170
FOREACH(HEADER ${generated_h})
164171
INSTALL(FILES ${CMAKE_CURRENT_BINARY_DIR}/${HEADER}
165-
DESTINATION ${TENSOR_LIB_INSTALL_INCLUDE_DIR}/ATen)
172+
DESTINATION ${AT_INSTALL_INCLUDE_DIR}/ATen)
166173
ENDFOREACH()
174+
INSTALL(FILES ${CMAKE_CURRENT_BINARY_DIR}/ATen/Declarations.yaml
175+
DESTINATION ${AT_INSTALL_SHARE_DIR}/ATen)

function_wrapper.py

Lines changed: 86 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,20 @@ def __init__(self, reason):
8585
'long': 'int64_t',
8686
}
8787

88+
DYNAMIC_TYPE = {
89+
'THTensor*': 'Tensor',
90+
'THBoolTensor*': 'BoolTensor',
91+
'THIndexTensor*': 'IndexTensor',
92+
'THIntegerTensor*': 'IntegerTensor',
93+
'THStorage*': 'Storage',
94+
'THGenerator*': 'Generator',
95+
'THSize*': 'IntList',
96+
'THStride*': 'IntList',
97+
'accreal': 'accreal',
98+
'real': 'real',
99+
'long': 'int64_t',
100+
}
101+
88102
TYPE_RETURN = {
89103
'THTensor*': 'Tensor',
90104
'THIndexTensor*': 'Tensor',
@@ -164,11 +178,29 @@ def to_return_type(arg, option):
164178
rt = rt + ' &'
165179
if not is_mutable_formal_argument(arg, option):
166180
rt = 'const ' + rt
167-
return rt
181+
return {
182+
'type': rt,
183+
'dynamic_type': DYNAMIC_TYPE.get(arg['type'], arg['type']),
184+
}
168185

169186

170187
def create_generic(top_env, declarations):
171188

189+
# change from THTensor* to Tensor & so we get how it will appear
190+
# in the aten argument list...
191+
def translate_formal(argument, option):
192+
type_str = TYPE_FORMAL_GENERIC.get(argument['type'], argument['type'])
193+
if type_str == 'Tensor &' and not is_mutable_formal_argument(argument, option):
194+
type_str = 'const ' + type_str
195+
translated = {
196+
'name': argument['name'],
197+
'type': type_str,
198+
'dynamic_type': DYNAMIC_TYPE.get(argument['type'], argument['type']),
199+
}
200+
if argument.get('output'):
201+
translated['output'] = True
202+
return translated
203+
172204
def get_formals(option):
173205
seen = set()
174206
result = []
@@ -185,38 +217,43 @@ def insert(argument):
185217
for argument in option['arguments']:
186218
if argument.get('output') and not argument.get('allocate', False):
187219
insert(argument)
188-
return result
189220

190-
def format_formal(argument, option):
191-
type_str = TYPE_FORMAL_GENERIC.get(argument['type'], argument['type'])
192-
if type_str == 'Tensor &' and not is_mutable_formal_argument(argument, option):
193-
type_str = 'const ' + type_str
194-
return '{} {}'.format(type_str, argument['name'])
221+
return [translate_formal(argument, option) for argument in result]
195222

196-
def format_return_type(option):
223+
def get_return_types(option):
197224
ret = option['return']
198225
if ret['kind'] == 'arguments':
199226
argument_indices = ret['arguments']
200227
if len(argument_indices) == 1:
201228
the_arg = option['arguments'][argument_indices[0]]
202-
return to_return_type(the_arg, option)
229+
return [to_return_type(the_arg, option)]
203230
else:
204-
types = [to_return_type(option['arguments'][idx], option)
205-
for idx in argument_indices]
206-
return "std::tuple<{}>".format(','.join(types))
207-
231+
return [to_return_type(option['arguments'][idx], option)
232+
for idx in argument_indices]
208233
elif ret['kind'] == 'type':
209-
return TYPE_RETURN.get(ret['type'], ret['type'])
234+
return [{
235+
'type': TYPE_RETURN.get(ret['type'], ret['type']),
236+
'dynamic_type': DYNAMIC_TYPE.get(ret['type'], ret['type']),
237+
}]
210238
else:
211239
raise Exception("format_return_type")
212240

241+
def format_return_type(return_types):
242+
if len(return_types) == 1:
243+
return return_types[0]['type']
244+
return "std::tuple<{}>".format(','.join(r['type'] for r in return_types))
245+
return return_types
246+
213247
def find_first_tensor(formals):
214-
for argument in formals:
215-
if argument['type'] == "THTensor*" or argument['type'] == 'TensorList':
216-
return argument['name']
248+
for formal in formals:
249+
if 'Tensor' == formal['dynamic_type'] or 'TensorList' == formal['dynamic_type']:
250+
return formal['name']
217251
return None
218252

219-
def process_option(option):
253+
def format_formal(f):
254+
return '{} {}'.format(f['type'],f['name'])
255+
256+
def process_option(option, output_options):
220257
option['inplace'] = re.search(
221258
'(^__i|[^_]_$)', option['api_name']) is not None
222259

@@ -226,13 +263,15 @@ def process_option(option):
226263
# print(yaml.dump(option))
227264
formals = get_formals(option)
228265
option['formals_list'] = formals
229-
option['formals'] = [format_formal(f, option) for f in formals]
266+
option['formals'] = [format_formal(f) for f in formals]
267+
option['returns'] = get_return_types(option)
230268
option['actuals'] = [f['name'] for f in formals]
231-
option['method_formals'] = [format_formal(f, option) for f in formals
269+
270+
option['method_formals'] = [format_formal(f) for f in formals
232271
if f['name'] != 'self']
233272
option['method_actuals'] = [
234273
f['name'] if f['name'] != 'self' else '*this' for f in formals]
235-
option['return_type'] = format_return_type(option)
274+
option['return_type'] = format_return_type(option['returns'])
236275

237276
option['const_mark'] = '' if option['inplace'] else ' const'
238277

@@ -253,22 +292,46 @@ def process_option(option):
253292
TENSOR_METHOD_DECLARATION.substitute(env))
254293
top_env['tensor_method_definitions'].append(
255294
TENSOR_METHOD_DEFINITION.substitute(env))
295+
output_options.append({
296+
'name': option['name'],
297+
'arguments': [f for f in formals if f['name'] != 'self'],
298+
'method_of': 'Tensor',
299+
'returns': option['returns'],
300+
'inplace': option['inplace'],
301+
})
256302

257303
if is_function:
258304
first_tensor = find_first_tensor(formals)
305+
output_option = {
306+
'name': option['name'],
307+
'arguments': formals,
308+
'returns': option['returns'],
309+
'inplace': option['inplace'],
310+
}
259311
if first_tensor is not None:
260312
option['inferred_type'] = 'infer_type({})'.format(first_tensor)
261313
top_env['function_declarations'].append(
262314
FUNCTION_DECLARATION.substitute(env))
263315
top_env['function_definitions'].append(
264316
FUNCTION_DEFINITION.substitute(env))
317+
else:
318+
output_option['method_of'] = 'Type'
319+
output_options.append(output_option)
265320

321+
output_declarations = []
266322
for declaration in declarations:
323+
output_options = []
267324
for option in declaration['options']:
268325
try:
269-
process_option(option)
326+
process_option(option,output_options)
270327
except NYIError:
271328
option['skip'] = True
329+
if len(output_options) > 0:
330+
output_declarations.append({
331+
'name': output_options[0]['name'],
332+
'options': output_options,
333+
})
334+
return output_declarations
272335

273336

274337
def create_derived(backend_type_env, declarations):
@@ -429,7 +492,7 @@ def emit_body(env, option):
429492
arg = arguments[0]
430493
body.append("return {};".format(arg['name']))
431494
else:
432-
types = [to_return_type(arg, option) for arg in arguments]
495+
types = [to_return_type(arg, option)['type'] for arg in arguments]
433496
# TODO: check for move semantics...
434497
names = [arg['name'] for arg in arguments]
435498
body.append(CodeTemplate("return std::tuple<${types}>(${names});").substitute(

gen.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from optparse import OptionParser
2+
import yaml
23

34
import cwrap_parser
45
import nn_parse
@@ -99,6 +100,14 @@ def write(filename, s):
99100
f.write(s)
100101

101102

103+
def format_yaml(data):
104+
if options.output_dependencies:
105+
return "" # yaml formatting is slow so don't do it if we will ditch it.
106+
noalias_dumper = yaml.dumper.SafeDumper
107+
noalias_dumper.ignore_aliases = lambda self, data: True
108+
return yaml.dump(data, default_flow_style=False, Dumper=noalias_dumper)
109+
110+
102111
def generate_storage_type_and_tensor(backend, density, scalar_type, declarations):
103112
scalar_name, c_type, accreal, th_scalar_type = scalar_type
104113
env = {}
@@ -218,7 +227,8 @@ def generate_storage_type_and_tensor(backend, density, scalar_type, declarations
218227
# note: this will fill in top_env['type/tensor_method_declarations/definitions']
219228
# and modify the declarations to include any information that will all_backends
220229
# be used by function_wrapper.create_derived
221-
function_wrapper.create_generic(top_env, declarations)
230+
output_declarations = function_wrapper.create_generic(top_env, declarations)
231+
write("Declarations.yaml", format_yaml(output_declarations))
222232

223233
# populated by generate_storage_type_and_tensor
224234
all_types = []

0 commit comments

Comments
 (0)