@@ -34,29 +34,30 @@ void NORETURN abort() {
34
34
}
35
35
#endif
36
36
37
- // Global variable to capture data from layer_cb
38
- uint16_t * g_out_dims = NULL ;
39
-
40
- // XXX: caller must set g_out_dims
41
- static tm_err_t layer_cb (tm_mdl_t * mdl , tml_head_t * lh )
37
+ // get model output shapes
38
+ //mdl: model handle; in: input mat; out: output mat
39
+ int TM_WEAK tm_get_outputs (tm_mdl_t * mdl , tm_mat_t * out , int out_length )
42
40
{
43
- const int h = lh -> out_dims [1 ];
44
- const int w = lh -> out_dims [2 ];
45
- const int ch = lh -> out_dims [3 ];
46
- const bool is_out = lh -> is_out ;
47
-
48
- #if DEBUG
49
- mp_printf (& mp_plat_print ,
50
- "cnn-layer-cb is_out=%d h=%d w=%d ch=%d size=%d \n" ,
51
- (int )is_out , h , w , ch );
52
- #endif
53
-
54
- if (is_out ) {
55
- for (int i = 0 ; i < 4 ; i ++ ) {
56
- g_out_dims [i ] = lh -> out_dims [i ];
41
+ // NOTE: based on tm_run, but without actually executing
42
+ int out_idx = 0 ;
43
+ mdl -> layer_body = mdl -> b -> layers_body ;
44
+ for (mdl -> layer_i = 0 ; mdl -> layer_i < mdl -> b -> layer_cnt ; mdl -> layer_i ++ ){
45
+ tml_head_t * h = (tml_head_t * )(mdl -> layer_body );
46
+ if (h -> is_out ) {
47
+ if (out_idx < out_length ) {
48
+ memcpy ((void * )(& out [out_idx ]), (void * )(& (h -> out_dims )), sizeof (uint16_t )* 4 );
49
+ out_idx += 1 ;
50
+ } else {
51
+ return -1 ;
52
+ }
57
53
}
58
-
54
+ mdl -> layer_body += ( h -> size );
59
55
}
56
+ return out_idx ;
57
+ }
58
+
59
+ static tm_err_t layer_cb (tm_mdl_t * mdl , tml_head_t * lh )
60
+ {
60
61
return TM_OK ;
61
62
}
62
63
@@ -108,15 +109,28 @@ static mp_obj_t mod_cnn_new(mp_obj_t model_data_obj) {
108
109
109
110
// loading model
110
111
// will set the dimensions of the input matrix
111
- o -> out_dims [0 ] = 0 ;
112
- g_out_dims = o -> out_dims ;
113
112
tm_err_t load_err = tm_load (model , o -> model_buffer , o -> data_buffer , layer_cb , & o -> input );
114
113
if (load_err != TM_OK ) {
115
114
mp_raise_msg (& mp_type_RuntimeError , MP_ERROR_TEXT ("tm_load error" ));
116
115
}
117
116
117
+ // find model output shape
118
+ o -> out_dims [0 ] = 0 ;
119
+ tm_mat_t outs [1 ];
120
+ const int outputs = tm_get_outputs (model , outs , 1 );
121
+ if (outputs != 1 ) {
122
+ mp_raise_msg (& mp_type_RuntimeError , MP_ERROR_TEXT ("only 1 output supported" ));
123
+ }
124
+ memcpy ((void * )(o -> out_dims ), (void * )(& (outs [0 ])), sizeof (uint16_t )* 4 );
125
+
126
+ if ((o -> out_dims [0 ] != 1 )) {
127
+ mp_raise_msg (& mp_type_RuntimeError , MP_ERROR_TEXT ("output must be 1d" ));
128
+ }
129
+ memcpy ((void * )(o -> out_dims ), (void * )(& (outs [0 ])), sizeof (uint16_t )* 4 );
130
+
118
131
#if DEBUG
119
- mp_printf (& mp_plat_print , "cnn-new-done out.dims=%d \n" , o -> out_dims [0 ]);
132
+ mp_printf (& mp_plat_print , "cnn-new-done outs=%d out.dims=(%d,%d,%d,%d) \n" ,
133
+ outputs , o -> out_dims [0 ], o -> out_dims [1 ], o -> out_dims [2 ], o -> out_dims [3 ]);
120
134
#endif
121
135
122
136
return MP_OBJ_FROM_PTR (o );
@@ -212,9 +226,14 @@ static mp_obj_t mod_cnn_output_dimensions(mp_obj_t self_obj) {
212
226
const int dimensions = o -> out_dims [0 ];
213
227
mp_obj_tuple_t * tuple = MP_OBJ_TO_PTR (mp_obj_new_tuple (dimensions , NULL ));
214
228
215
- for (int i = 0 ; i < dimensions ; i ++ ) {
216
- tuple -> items [i ] = mp_obj_new_int (o -> out_dims [i + 1 ]);
229
+ // A regular output should have C channels, and 1 for everything else
230
+ // TODO: support other shapes?
231
+ //dims==1, 11c
232
+ if (!(o -> out_dims [0 ] == 1 && o -> out_dims [1 ] == 1 && o -> out_dims [2 ] == 1 )) {
233
+ mp_raise_msg (& mp_type_RuntimeError , MP_ERROR_TEXT ("wrong output shape" ));
217
234
}
235
+
236
+ tuple -> items [0 ] = mp_obj_new_int (o -> out_dims [3 ]);
218
237
return tuple ;
219
238
}
220
239
static MP_DEFINE_CONST_FUN_OBJ_1 (mod_cnn_output_dimensions_obj , mod_cnn_output_dimensions ) ;
0 commit comments