Skip to content

Commit 926751d

Browse files
committed
cnn: Implement custom function to get output size
The layer_cb of tm_load is not called at load time, only at tm_run() time
1 parent 7c63df3 commit 926751d

File tree

2 files changed

+59
-27
lines changed

2 files changed

+59
-27
lines changed

src/tinymaix_cnn/mod_cnn.c

+44-25
Original file line numberDiff line numberDiff line change
@@ -34,29 +34,30 @@ void NORETURN abort() {
3434
}
3535
#endif
3636

37-
// Global variable to capture data from layer_cb
38-
uint16_t *g_out_dims = NULL;
39-
40-
// XXX: caller must set g_out_dims
41-
static tm_err_t layer_cb(tm_mdl_t* mdl, tml_head_t* lh)
37+
// get model output shapes
38+
//mdl: model handle; in: input mat; out: output mat
39+
int TM_WEAK tm_get_outputs(tm_mdl_t* mdl, tm_mat_t* out, int out_length)
4240
{
43-
const int h = lh->out_dims[1];
44-
const int w = lh->out_dims[2];
45-
const int ch= lh->out_dims[3];
46-
const bool is_out = lh->is_out;
47-
48-
#if DEBUG
49-
mp_printf(&mp_plat_print,
50-
"cnn-layer-cb is_out=%d h=%d w=%d ch=%d size=%d \n",
51-
(int)is_out, h, w, ch);
52-
#endif
53-
54-
if (is_out) {
55-
for (int i=0; i<4; i++) {
56-
g_out_dims[i] = lh->out_dims[i];
41+
// NOTE: based on tm_run, but without actually executing
42+
int out_idx = 0;
43+
mdl->layer_body = mdl->b->layers_body;
44+
for(mdl->layer_i = 0; mdl->layer_i < mdl->b->layer_cnt; mdl->layer_i++){
45+
tml_head_t* h = (tml_head_t*)(mdl->layer_body);
46+
if(h->is_out) {
47+
if (out_idx < out_length) {
48+
memcpy((void*)(&out[out_idx]), (void*)(&(h->out_dims)), sizeof(uint16_t)*4);
49+
out_idx += 1;
50+
} else {
51+
return -1;
52+
}
5753
}
58-
54+
mdl->layer_body += (h->size);
5955
}
56+
return out_idx;
57+
}
58+
59+
static tm_err_t layer_cb(tm_mdl_t* mdl, tml_head_t* lh)
60+
{
6061
return TM_OK;
6162
}
6263

@@ -108,15 +109,28 @@ static mp_obj_t mod_cnn_new(mp_obj_t model_data_obj) {
108109

109110
// loading model
110111
// will set the dimensions of the input matrix
111-
o->out_dims[0] = 0;
112-
g_out_dims = o->out_dims;
113112
tm_err_t load_err = tm_load(model, o->model_buffer, o->data_buffer, layer_cb, &o->input);
114113
if (load_err != TM_OK) {
115114
mp_raise_msg(&mp_type_RuntimeError, MP_ERROR_TEXT("tm_load error"));
116115
}
117116

117+
// find model output shape
118+
o->out_dims[0] = 0;
119+
tm_mat_t outs[1];
120+
const int outputs = tm_get_outputs(model, outs, 1);
121+
if (outputs != 1) {
122+
mp_raise_msg(&mp_type_RuntimeError, MP_ERROR_TEXT("only 1 output supported"));
123+
}
124+
memcpy((void*)(o->out_dims), (void*)(&(outs[0])), sizeof(uint16_t)*4);
125+
126+
if ((o->out_dims[0] != 1)) {
127+
mp_raise_msg(&mp_type_RuntimeError, MP_ERROR_TEXT("output must be 1d"));
128+
}
129+
memcpy((void*)(o->out_dims), (void*)(&(outs[0])), sizeof(uint16_t)*4);
130+
118131
#if DEBUG
119-
mp_printf(&mp_plat_print, "cnn-new-done out.dims=%d \n", o->out_dims[0]);
132+
mp_printf(&mp_plat_print, "cnn-new-done outs=%d out.dims=(%d,%d,%d,%d) \n",
133+
outputs, o->out_dims[0], o->out_dims[1], o->out_dims[2], o->out_dims[3]);
120134
#endif
121135

122136
return MP_OBJ_FROM_PTR(o);
@@ -212,9 +226,14 @@ static mp_obj_t mod_cnn_output_dimensions(mp_obj_t self_obj) {
212226
const int dimensions = o->out_dims[0];
213227
mp_obj_tuple_t *tuple = MP_OBJ_TO_PTR(mp_obj_new_tuple(dimensions, NULL));
214228

215-
for (int i=0; i<dimensions; i++) {
216-
tuple->items[i] = mp_obj_new_int(o->out_dims[i+1]);
229+
// A regular output should have C channels, and 1 for everything else
230+
// TODO: support other shapes?
231+
//dims==1, 11c
232+
if (!(o->out_dims[0] == 1 && o->out_dims[1] == 1 && o->out_dims[2] == 1)) {
233+
mp_raise_msg(&mp_type_RuntimeError, MP_ERROR_TEXT("wrong output shape"));
217234
}
235+
236+
tuple->items[0] = mp_obj_new_int(o->out_dims[3]);
218237
return tuple;
219238
}
220239
static MP_DEFINE_CONST_FUN_OBJ_1(mod_cnn_output_dimensions_obj, mod_cnn_output_dimensions);

tests/test_cnn.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def test_cnn_create():
1313
model = emlearn_cnn.new(model_data)
1414

1515
out_shape = model.output_dimensions()
16-
assert out_shape == (10,)
16+
assert out_shape == (10,), (out_shape)
1717

1818
# TODO: enable these checks
1919
#wrong_type = array.array('f', [])
@@ -38,13 +38,25 @@ def print_2d_buffer(arr, rowstride):
3838

3939
print('\n')
4040

41+
def argmax(arr):
42+
idx_max = 0
43+
value_max = arr[0]
44+
for i in range(1, len(arr)):
45+
if arr[i] > value_max:
46+
value_max = arr[i]
47+
idx_max = i
48+
49+
return idx_max
50+
4151
def test_cnn_mnist():
4252

4353
model = None
4454
with open(MNIST_MODEL, 'rb') as f:
4555
model_data = array.array('B', f.read())
4656
model = emlearn_cnn.new(model_data)
4757

58+
probabilities = array.array('f', (-1 for _ in range(10)))
59+
4860
correct = 0
4961
for class_no in range(0, 10):
5062
data_path = MNIST_DATA_DIR + 'mnist_example_{0:d}.bin'.format(class_no)
@@ -54,7 +66,8 @@ def test_cnn_mnist():
5466

5567
#print_2d_buffer(img, 28)
5668

57-
out = model.run(img)
69+
model.run(img, probabilities)
70+
out = argmax(probabilities)
5871
# TODO replace with assert
5972
print('mnist-example-check', class_no, out, class_no == out)
6073
if out == class_no:

0 commit comments

Comments
 (0)