@@ -34,39 +34,30 @@ void NORETURN abort() {
34
34
}
35
35
#endif
36
36
37
+ // Global variable to capture data from layer_cb
38
+ uint16_t * g_out_dims = NULL ;
37
39
40
+ // XXX: caller must set g_out_dims
38
41
static tm_err_t layer_cb (tm_mdl_t * mdl , tml_head_t * lh )
39
- {
40
- #if 0
41
- //dump middle result
42
- int h = lh -> out_dims [1 ];
43
- int w = lh -> out_dims [2 ];
44
- int ch = lh -> out_dims [3 ];
45
- mtype_t * output = TML_GET_OUTPUT (mdl , lh );
46
- return TM_OK ;
47
- TM_PRINTF ("Layer %d callback ========\n" , mdl -> layer_i );
48
- #if 1
49
- for (int y = 0 ; y < h ; y ++ ){
50
- TM_PRINTF ("[" );
51
- for (int x = 0 ; x < w ; x ++ ){
52
- TM_PRINTF ("[" );
53
- for (int c = 0 ; c < ch ; c ++ ){
54
- #if TM_MDL_TYPE == TM_MDL_FP32
55
- TM_PRINTF ("%.3f," , output [(y * w + x )* ch + c ]);
56
- #else
57
- TM_PRINTF ("%.3f," , TML_DEQUANT (lh ,output [(y * w + x )* ch + c ]));
58
- #endif
59
- }
60
- TM_PRINTF ("]," );
42
+ {
43
+ const int h = lh -> out_dims [1 ];
44
+ const int w = lh -> out_dims [2 ];
45
+ const int ch = lh -> out_dims [3 ];
46
+ const bool is_out = lh -> is_out ;
47
+
48
+ #if DEBUG
49
+ mp_printf (& mp_plat_print ,
50
+ "cnn-layer-cb is_out=%d h=%d w=%d ch=%d size=%d \n" ,
51
+ (int )is_out , h , w , ch );
52
+ #endif
53
+
54
+ if (is_out ) {
55
+ for (int i = 0 ; i < 4 ; i ++ ) {
56
+ g_out_dims [i ] = lh -> out_dims [i ];
61
57
}
62
- TM_PRINTF ( "],\n" );
58
+
63
59
}
64
- TM_PRINTF ("\n" );
65
- #endif
66
- return TM_OK ;
67
- #else
68
60
return TM_OK ;
69
- #endif
70
61
}
71
62
72
63
#define DEBUG (1)
@@ -79,6 +70,7 @@ typedef struct _mp_obj_mod_cnn_t {
79
70
tm_mat_t input ;
80
71
uint8_t * model_buffer ;
81
72
uint8_t * data_buffer ;
73
+ uint16_t out_dims [4 ];
82
74
} mp_obj_mod_cnn_t ;
83
75
84
76
mp_obj_full_type_t mod_cnn_type ;
@@ -116,11 +108,17 @@ static mp_obj_t mod_cnn_new(mp_obj_t model_data_obj) {
116
108
117
109
// loading model
118
110
// will set the dimensions of the input matrix
111
+ o -> out_dims [0 ] = 0 ;
112
+ g_out_dims = o -> out_dims ;
119
113
tm_err_t load_err = tm_load (model , o -> model_buffer , o -> data_buffer , layer_cb , & o -> input );
120
114
if (load_err != TM_OK ) {
121
115
mp_raise_msg (& mp_type_RuntimeError , MP_ERROR_TEXT ("tm_load error" ));
122
116
}
123
117
118
+ #if DEBUG
119
+ mp_printf (& mp_plat_print , "cnn-new-done out.dims=%d \n" , o -> out_dims [0 ]);
120
+ #endif
121
+
124
122
return MP_OBJ_FROM_PTR (o );
125
123
}
126
124
static MP_DEFINE_CONST_FUN_OBJ_1 (mod_cnn_new_obj , mod_cnn_new ) ;
@@ -141,15 +139,15 @@ static MP_DEFINE_CONST_FUN_OBJ_1(mod_cnn_del_obj, mod_cnn_del);
141
139
142
140
143
141
// Add a node to the tree
144
- static mp_obj_t mod_cnn_run (mp_obj_t self_obj , mp_obj_t input_obj ) {
142
+ static mp_obj_t mod_cnn_run (mp_obj_t self_obj , mp_obj_t input_obj , mp_obj_t output_obj ) {
145
143
146
144
mp_obj_mod_cnn_t * o = MP_OBJ_TO_PTR (self_obj );
147
145
148
146
// Extract input
149
147
mp_buffer_info_t bufinfo ;
150
148
mp_get_buffer_raise (input_obj , & bufinfo , MP_BUFFER_RW );
151
149
if (bufinfo .typecode != 'B' ) {
152
- mp_raise_ValueError (MP_ERROR_TEXT ("expecting float array" ));
150
+ mp_raise_ValueError (MP_ERROR_TEXT ("expecting byte array" ));
153
151
}
154
152
uint8_t * input_buffer = bufinfo .buf ;
155
153
const int input_length = bufinfo .len / sizeof (* input_buffer );
@@ -160,6 +158,21 @@ static mp_obj_t mod_cnn_run(mp_obj_t self_obj, mp_obj_t input_obj) {
160
158
mp_raise_ValueError (MP_ERROR_TEXT ("wrong input size" ));
161
159
}
162
160
161
+ // Extract output
162
+ mp_get_buffer_raise (output_obj , & bufinfo , MP_BUFFER_RW );
163
+ if (bufinfo .typecode != 'f' ) {
164
+ mp_raise_ValueError (MP_ERROR_TEXT ("expecting float array" ));
165
+ }
166
+ float * output_buffer = bufinfo .buf ;
167
+ const int output_length = bufinfo .len / sizeof (* output_buffer );
168
+
169
+
170
+ // check buffer size wrt input
171
+ const int expect_out_length = o -> out_dims [1 ]* o -> out_dims [2 ]* o -> out_dims [3 ];
172
+ if (output_length != expect_out_length ) {
173
+ mp_raise_ValueError (MP_ERROR_TEXT ("wrong output size" ));
174
+ }
175
+
163
176
// Preprocess data
164
177
tm_mat_t in_uint8 = o -> input ;
165
178
in_uint8 .data = (mtype_t * )input_buffer ;
@@ -181,27 +194,33 @@ static mp_obj_t mod_cnn_run(mp_obj_t self_obj, mp_obj_t input_obj) {
181
194
mp_raise_ValueError (MP_ERROR_TEXT ("run error" ));
182
195
}
183
196
197
+ // Copy output into
184
198
tm_mat_t out = outs [0 ];
185
- float * data = out .dataf ;
186
- float maxp = 0 ;
187
- int maxi = -1 ;
188
-
189
- // TODO: pass the entire output vector out to Python
190
- // FIXME: unhardcode output handling
191
- for (int i = 0 ; i < 10 ; i ++ ){
192
- //printf("%d: %.3f\n", i, data[i]);
193
- if (data [i ] > maxp ) {
194
- maxi = i ;
195
- maxp = data [i ];
196
- }
199
+ for (int i = 0 ; i < expect_out_length ; i ++ ){
200
+ output_buffer [i ] = out .dataf [i ];
197
201
}
198
202
199
- return mp_obj_new_int (maxi );
203
+ return mp_const_none ;
204
+ }
205
+ static MP_DEFINE_CONST_FUN_OBJ_3 (mod_cnn_run_obj , mod_cnn_run ) ;
206
+
207
+
208
+ // Return the shape of the output
209
+ static mp_obj_t mod_cnn_output_dimensions (mp_obj_t self_obj ) {
210
+
211
+ mp_obj_mod_cnn_t * o = MP_OBJ_TO_PTR (self_obj );
212
+ const int dimensions = o -> out_dims [0 ];
213
+ mp_obj_tuple_t * tuple = MP_OBJ_TO_PTR (mp_obj_new_tuple (dimensions , NULL ));
214
+
215
+ for (int i = 0 ; i < dimensions ; i ++ ) {
216
+ tuple -> items [i ] = mp_obj_new_int (o -> out_dims [i + 1 ]);
217
+ }
218
+ return tuple ;
200
219
}
201
- static MP_DEFINE_CONST_FUN_OBJ_2 ( mod_cnn_run_obj , mod_cnn_run ) ;
220
+ static MP_DEFINE_CONST_FUN_OBJ_1 ( mod_cnn_output_dimensions_obj , mod_cnn_output_dimensions ) ;
202
221
203
222
204
- mp_map_elem_t mod_locals_dict_table [2 ];
223
+ mp_map_elem_t mod_locals_dict_table [3 ];
205
224
static MP_DEFINE_CONST_DICT (mod_locals_dict , mod_locals_dict_table ) ;
206
225
207
226
// This is the entry point and is called when the module is imported
@@ -217,6 +236,7 @@ mp_obj_t mpy_init(mp_obj_fun_bc_t *self, size_t n_args, size_t n_kw, mp_obj_t *a
217
236
// methods
218
237
mod_locals_dict_table [0 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR_run ), MP_OBJ_FROM_PTR (& mod_cnn_run_obj ) };
219
238
mod_locals_dict_table [1 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR___del__ ), MP_OBJ_FROM_PTR (& mod_cnn_del_obj ) };
239
+ mod_locals_dict_table [2 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR_output_dimensions ), MP_OBJ_FROM_PTR (& mod_cnn_output_dimensions_obj ) };
220
240
221
241
MP_OBJ_TYPE_SET_SLOT (& mod_cnn_type , locals_dict , (void * )& mod_locals_dict , 2 );
222
242
0 commit comments