Skip to content

Commit 7c63df3

Browse files
committed
cnn: Try fix hardcoded output handling
Would give corrupt results for anything else than outputs with 10 lenght The .run() function now requires an array.array('f') where to place output There is a new method, output_dimensions() to get the shape of the output
1 parent d4682d7 commit 7c63df3

File tree

2 files changed

+68
-45
lines changed

2 files changed

+68
-45
lines changed

src/tinymaix_cnn/mod_cnn.c

+65-45
Original file line numberDiff line numberDiff line change
@@ -34,39 +34,30 @@ void NORETURN abort() {
3434
}
3535
#endif
3636

37+
// Global variable to capture data from layer_cb
38+
uint16_t *g_out_dims = NULL;
3739

40+
// XXX: caller must set g_out_dims
3841
static tm_err_t layer_cb(tm_mdl_t* mdl, tml_head_t* lh)
39-
{
40-
#if 0
41-
//dump middle result
42-
int h = lh->out_dims[1];
43-
int w = lh->out_dims[2];
44-
int ch= lh->out_dims[3];
45-
mtype_t* output = TML_GET_OUTPUT(mdl, lh);
46-
return TM_OK;
47-
TM_PRINTF("Layer %d callback ========\n", mdl->layer_i);
48-
#if 1
49-
for(int y=0; y<h; y++){
50-
TM_PRINTF("[");
51-
for(int x=0; x<w; x++){
52-
TM_PRINTF("[");
53-
for(int c=0; c<ch; c++){
54-
#if TM_MDL_TYPE == TM_MDL_FP32
55-
TM_PRINTF("%.3f,", output[(y*w+x)*ch+c]);
56-
#else
57-
TM_PRINTF("%.3f,", TML_DEQUANT(lh,output[(y*w+x)*ch+c]));
58-
#endif
59-
}
60-
TM_PRINTF("],");
42+
{
43+
const int h = lh->out_dims[1];
44+
const int w = lh->out_dims[2];
45+
const int ch= lh->out_dims[3];
46+
const bool is_out = lh->is_out;
47+
48+
#if DEBUG
49+
mp_printf(&mp_plat_print,
50+
"cnn-layer-cb is_out=%d h=%d w=%d ch=%d size=%d \n",
51+
(int)is_out, h, w, ch);
52+
#endif
53+
54+
if (is_out) {
55+
for (int i=0; i<4; i++) {
56+
g_out_dims[i] = lh->out_dims[i];
6157
}
62-
TM_PRINTF("],\n");
58+
6359
}
64-
TM_PRINTF("\n");
65-
#endif
66-
return TM_OK;
67-
#else
6860
return TM_OK;
69-
#endif
7061
}
7162

7263
#define DEBUG (1)
@@ -79,6 +70,7 @@ typedef struct _mp_obj_mod_cnn_t {
7970
tm_mat_t input;
8071
uint8_t *model_buffer;
8172
uint8_t *data_buffer;
73+
uint16_t out_dims[4];
8274
} mp_obj_mod_cnn_t;
8375

8476
mp_obj_full_type_t mod_cnn_type;
@@ -116,11 +108,17 @@ static mp_obj_t mod_cnn_new(mp_obj_t model_data_obj) {
116108

117109
// loading model
118110
// will set the dimensions of the input matrix
111+
o->out_dims[0] = 0;
112+
g_out_dims = o->out_dims;
119113
tm_err_t load_err = tm_load(model, o->model_buffer, o->data_buffer, layer_cb, &o->input);
120114
if (load_err != TM_OK) {
121115
mp_raise_msg(&mp_type_RuntimeError, MP_ERROR_TEXT("tm_load error"));
122116
}
123117

118+
#if DEBUG
119+
mp_printf(&mp_plat_print, "cnn-new-done out.dims=%d \n", o->out_dims[0]);
120+
#endif
121+
124122
return MP_OBJ_FROM_PTR(o);
125123
}
126124
static MP_DEFINE_CONST_FUN_OBJ_1(mod_cnn_new_obj, mod_cnn_new);
@@ -141,15 +139,15 @@ static MP_DEFINE_CONST_FUN_OBJ_1(mod_cnn_del_obj, mod_cnn_del);
141139

142140

143141
// Add a node to the tree
144-
static mp_obj_t mod_cnn_run(mp_obj_t self_obj, mp_obj_t input_obj) {
142+
static mp_obj_t mod_cnn_run(mp_obj_t self_obj, mp_obj_t input_obj, mp_obj_t output_obj) {
145143

146144
mp_obj_mod_cnn_t *o = MP_OBJ_TO_PTR(self_obj);
147145

148146
// Extract input
149147
mp_buffer_info_t bufinfo;
150148
mp_get_buffer_raise(input_obj, &bufinfo, MP_BUFFER_RW);
151149
if (bufinfo.typecode != 'B') {
152-
mp_raise_ValueError(MP_ERROR_TEXT("expecting float array"));
150+
mp_raise_ValueError(MP_ERROR_TEXT("expecting byte array"));
153151
}
154152
uint8_t *input_buffer = bufinfo.buf;
155153
const int input_length = bufinfo.len / sizeof(*input_buffer);
@@ -160,6 +158,21 @@ static mp_obj_t mod_cnn_run(mp_obj_t self_obj, mp_obj_t input_obj) {
160158
mp_raise_ValueError(MP_ERROR_TEXT("wrong input size"));
161159
}
162160

161+
// Extract output
162+
mp_get_buffer_raise(output_obj, &bufinfo, MP_BUFFER_RW);
163+
if (bufinfo.typecode != 'f') {
164+
mp_raise_ValueError(MP_ERROR_TEXT("expecting float array"));
165+
}
166+
float *output_buffer = bufinfo.buf;
167+
const int output_length = bufinfo.len / sizeof(*output_buffer);
168+
169+
170+
// check buffer size wrt input
171+
const int expect_out_length = o->out_dims[1]*o->out_dims[2]*o->out_dims[3];
172+
if (output_length != expect_out_length) {
173+
mp_raise_ValueError(MP_ERROR_TEXT("wrong output size"));
174+
}
175+
163176
// Preprocess data
164177
tm_mat_t in_uint8 = o->input;
165178
in_uint8.data = (mtype_t*)input_buffer;
@@ -181,27 +194,33 @@ static mp_obj_t mod_cnn_run(mp_obj_t self_obj, mp_obj_t input_obj) {
181194
mp_raise_ValueError(MP_ERROR_TEXT("run error"));
182195
}
183196

197+
// Copy output into
184198
tm_mat_t out = outs[0];
185-
float* data = out.dataf;
186-
float maxp = 0;
187-
int maxi = -1;
188-
189-
// TODO: pass the entire output vector out to Python
190-
// FIXME: unhardcode output handling
191-
for(int i=0; i<10; i++){
192-
//printf("%d: %.3f\n", i, data[i]);
193-
if (data[i] > maxp) {
194-
maxi = i;
195-
maxp = data[i];
196-
}
199+
for(int i=0; i<expect_out_length; i++){
200+
output_buffer[i] = out.dataf[i];
197201
}
198202

199-
return mp_obj_new_int(maxi);
203+
return mp_const_none;
204+
}
205+
static MP_DEFINE_CONST_FUN_OBJ_3(mod_cnn_run_obj, mod_cnn_run);
206+
207+
208+
// Return the shape of the output
209+
static mp_obj_t mod_cnn_output_dimensions(mp_obj_t self_obj) {
210+
211+
mp_obj_mod_cnn_t *o = MP_OBJ_TO_PTR(self_obj);
212+
const int dimensions = o->out_dims[0];
213+
mp_obj_tuple_t *tuple = MP_OBJ_TO_PTR(mp_obj_new_tuple(dimensions, NULL));
214+
215+
for (int i=0; i<dimensions; i++) {
216+
tuple->items[i] = mp_obj_new_int(o->out_dims[i+1]);
217+
}
218+
return tuple;
200219
}
201-
static MP_DEFINE_CONST_FUN_OBJ_2(mod_cnn_run_obj, mod_cnn_run);
220+
static MP_DEFINE_CONST_FUN_OBJ_1(mod_cnn_output_dimensions_obj, mod_cnn_output_dimensions);
202221

203222

204-
mp_map_elem_t mod_locals_dict_table[2];
223+
mp_map_elem_t mod_locals_dict_table[3];
205224
static MP_DEFINE_CONST_DICT(mod_locals_dict, mod_locals_dict_table);
206225

207226
// This is the entry point and is called when the module is imported
@@ -217,6 +236,7 @@ mp_obj_t mpy_init(mp_obj_fun_bc_t *self, size_t n_args, size_t n_kw, mp_obj_t *a
217236
// methods
218237
mod_locals_dict_table[0] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_run), MP_OBJ_FROM_PTR(&mod_cnn_run_obj) };
219238
mod_locals_dict_table[1] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR___del__), MP_OBJ_FROM_PTR(&mod_cnn_del_obj) };
239+
mod_locals_dict_table[2] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_output_dimensions), MP_OBJ_FROM_PTR(&mod_cnn_output_dimensions_obj) };
220240

221241
MP_OBJ_TYPE_SET_SLOT(&mod_cnn_type, locals_dict, (void*)&mod_locals_dict, 2);
222242

tests/test_cnn.py

+3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ def test_cnn_create():
1212
model_data = array.array('B', f.read())
1313
model = emlearn_cnn.new(model_data)
1414

15+
out_shape = model.output_dimensions()
16+
assert out_shape == (10,)
17+
1518
# TODO: enable these checks
1619
#wrong_type = array.array('f', [])
1720
#model.run(wrong_type)

0 commit comments

Comments
 (0)