Skip to content

Commit b4f3998

Browse files
author
Chun-Heng Huang
committed
Matlab interface parameter checking
Check 'nlhs' argument. Allow various output size for svmpredict.
1 parent d2630a2 commit b4f3998

File tree

8 files changed

+93
-60
lines changed

8 files changed

+93
-60
lines changed

matlab/libsvmread.c

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@ void exit_with_help()
2525
);
2626
}
2727

28-
static void fake_answer(mxArray *plhs[])
28+
static void fake_answer(int nlhs, mxArray *plhs[])
2929
{
30-
plhs[0] = mxCreateDoubleMatrix(0, 0, mxREAL);
31-
plhs[1] = mxCreateDoubleMatrix(0, 0, mxREAL);
30+
int i;
31+
for(i=0;i<nlhs;i++)
32+
plhs[i] = mxCreateDoubleMatrix(0, 0, mxREAL);
3233
}
3334

3435
static char *line;
@@ -53,7 +54,7 @@ static char* readline(FILE *input)
5354
}
5455

5556
// read in a problem (in libsvm format)
56-
void read_problem(const char *filename, mxArray *plhs[])
57+
void read_problem(const char *filename, int nlhs, mxArray *plhs[])
5758
{
5859
int max_index, min_index, inst_max_index, i;
5960
long elements, k;
@@ -66,7 +67,7 @@ void read_problem(const char *filename, mxArray *plhs[])
6667
if(fp == NULL)
6768
{
6869
mexPrintf("can't open input file %s\n",filename);
69-
fake_answer(plhs);
70+
fake_answer(nlhs, plhs);
7071
return;
7172
}
7273

@@ -96,7 +97,7 @@ void read_problem(const char *filename, mxArray *plhs[])
9697
if(endptr == idx || errno != 0 || *endptr != '\0' || index <= inst_max_index)
9798
{
9899
mexPrintf("Wrong input format at line %d\n",l+1);
99-
fake_answer(plhs);
100+
fake_answer(nlhs, plhs);
100101
return;
101102
}
102103
else
@@ -135,14 +136,14 @@ void read_problem(const char *filename, mxArray *plhs[])
135136
if(label == NULL)
136137
{
137138
mexPrintf("Empty line at line %d\n",i+1);
138-
fake_answer(plhs);
139+
fake_answer(nlhs, plhs);
139140
return;
140141
}
141142
labels[i] = strtod(label,&endptr);
142143
if(endptr == label || *endptr != '\0')
143144
{
144145
mexPrintf("Wrong input format at line %d\n",i+1);
145-
fake_answer(plhs);
146+
fake_answer(nlhs, plhs);
146147
return;
147148
}
148149

@@ -161,7 +162,7 @@ void read_problem(const char *filename, mxArray *plhs[])
161162
if (endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr)))
162163
{
163164
mexPrintf("Wrong input format at line %d\n",i+1);
164-
fake_answer(plhs);
165+
fake_answer(nlhs, plhs);
165166
return;
166167
}
167168
++k;
@@ -178,7 +179,7 @@ void read_problem(const char *filename, mxArray *plhs[])
178179
if(mexCallMATLAB(1, lhs, 1, rhs, "transpose"))
179180
{
180181
mexPrintf("Error: cannot transpose problem\n");
181-
fake_answer(plhs);
182+
fake_answer(nlhs, plhs);
182183
return;
183184
}
184185
plhs[1] = lhs[0];
@@ -188,25 +189,25 @@ void read_problem(const char *filename, mxArray *plhs[])
188189
void mexFunction( int nlhs, mxArray *plhs[],
189190
int nrhs, const mxArray *prhs[] )
190191
{
191-
if(nrhs == 1)
192+
char filename[256];
193+
194+
if(nrhs != 1 || nlhs != 2)
192195
{
193-
char filename[256];
196+
exit_with_help();
197+
fake_answer(nlhs, plhs);
198+
return;
199+
}
194200

195-
mxGetString(prhs[0], filename, mxGetN(prhs[0]) + 1);
201+
mxGetString(prhs[0], filename, mxGetN(prhs[0]) + 1);
196202

197-
if(filename == NULL)
198-
{
199-
mexPrintf("Error: filename is NULL\n");
200-
return;
201-
}
202-
203-
read_problem(filename, plhs);
204-
}
205-
else
203+
if(filename == NULL)
206204
{
207-
exit_with_help();
208-
fake_answer(plhs);
205+
mexPrintf("Error: filename is NULL\n");
209206
return;
210207
}
208+
209+
read_problem(filename, nlhs, plhs);
210+
211+
return;
211212
}
212213

matlab/libsvmwrite.c

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@ void exit_with_help()
1616
);
1717
}
1818

19+
static void fake_answer(int nlhs, mxArray *plhs[])
20+
{
21+
int i;
22+
for(i=0;i<nlhs;i++)
23+
plhs[i] = mxCreateDoubleMatrix(0, 0, mxREAL);
24+
}
25+
1926
void libsvmwrite(const char *filename, const mxArray *label_vec, const mxArray *instance_mat)
2027
{
2128
FILE *fp = fopen(filename,"w");
@@ -77,7 +84,14 @@ void libsvmwrite(const char *filename, const mxArray *label_vec, const mxArray *
7784

7885
void mexFunction( int nlhs, mxArray *plhs[],
7986
int nrhs, const mxArray *prhs[] )
80-
{
87+
{
88+
if(nlhs > 0)
89+
{
90+
exit_with_help();
91+
fake_answer(nlhs, plhs);
92+
return;
93+
}
94+
8195
// Transform the input Matrix to libsvm format
8296
if(nrhs == 3)
8397
{

matlab/svmpredict.c

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,14 @@ void read_sparse_instance(const mxArray *prhs, int index, struct svm_node *x)
3939
x[j].index = -1;
4040
}
4141

42-
static void fake_answer(mxArray *plhs[])
42+
static void fake_answer(int nlhs, mxArray *plhs[])
4343
{
44-
plhs[0] = mxCreateDoubleMatrix(0, 0, mxREAL);
45-
plhs[1] = mxCreateDoubleMatrix(0, 0, mxREAL);
46-
plhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL);
44+
int i;
45+
for(i=0;i<nlhs;i++)
46+
plhs[i] = mxCreateDoubleMatrix(0, 0, mxREAL);
4747
}
4848

49-
void predict(mxArray *plhs[], const mxArray *prhs[], struct svm_model *model, const int predict_probability)
49+
void predict(int nlhs, mxArray *plhs[], const mxArray *prhs[], struct svm_model *model, const int predict_probability)
5050
{
5151
int label_vector_row_num, label_vector_col_num;
5252
int feature_number, testing_instance_number;
@@ -55,6 +55,7 @@ void predict(mxArray *plhs[], const mxArray *prhs[], struct svm_model *model, co
5555
double *ptr_prob_estimates, *ptr_dec_values, *ptr;
5656
struct svm_node *x;
5757
mxArray *pplhs[1]; // transposed instance sparse matrix
58+
mxArray *tplhs[3]; // temporary storage for plhs[]
5859

5960
int correct = 0;
6061
int total = 0;
@@ -74,13 +75,13 @@ void predict(mxArray *plhs[], const mxArray *prhs[], struct svm_model *model, co
7475
if(label_vector_row_num!=testing_instance_number)
7576
{
7677
mexPrintf("Length of label vector does not match # of instances.\n");
77-
fake_answer(plhs);
78+
fake_answer(nlhs, plhs);
7879
return;
7980
}
8081
if(label_vector_col_num!=1)
8182
{
8283
mexPrintf("label (1st argument) should be a vector (# of column is 1).\n");
83-
fake_answer(plhs);
84+
fake_answer(nlhs, plhs);
8485
return;
8586
}
8687

@@ -98,7 +99,7 @@ void predict(mxArray *plhs[], const mxArray *prhs[], struct svm_model *model, co
9899
if(mexCallMATLAB(1, lhs, 1, rhs, "full"))
99100
{
100101
mexPrintf("Error: cannot full testing instance matrix\n");
101-
fake_answer(plhs);
102+
fake_answer(nlhs, plhs);
102103
return;
103104
}
104105
ptr_instance = mxGetPr(lhs[0]);
@@ -111,7 +112,7 @@ void predict(mxArray *plhs[], const mxArray *prhs[], struct svm_model *model, co
111112
if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose"))
112113
{
113114
mexPrintf("Error: cannot transpose testing instance matrix\n");
114-
fake_answer(plhs);
115+
fake_answer(nlhs, plhs);
115116
return;
116117
}
117118
}
@@ -125,14 +126,14 @@ void predict(mxArray *plhs[], const mxArray *prhs[], struct svm_model *model, co
125126
prob_estimates = (double *) malloc(nr_class*sizeof(double));
126127
}
127128

128-
plhs[0] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
129+
tplhs[0] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
129130
if(predict_probability)
130131
{
131132
// prob estimates are in plhs[2]
132133
if(svm_type==C_SVC || svm_type==NU_SVC)
133-
plhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class, mxREAL);
134+
tplhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class, mxREAL);
134135
else
135-
plhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL);
136+
tplhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL);
136137
}
137138
else
138139
{
@@ -141,14 +142,14 @@ void predict(mxArray *plhs[], const mxArray *prhs[], struct svm_model *model, co
141142
svm_type == EPSILON_SVR ||
142143
svm_type == NU_SVR ||
143144
nr_class == 1) // if only one class in training data, decision values are still returned.
144-
plhs[2] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
145+
tplhs[2] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
145146
else
146-
plhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class*(nr_class-1)/2, mxREAL);
147+
tplhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class*(nr_class-1)/2, mxREAL);
147148
}
148149

149-
ptr_predict_label = mxGetPr(plhs[0]);
150-
ptr_prob_estimates = mxGetPr(plhs[2]);
151-
ptr_dec_values = mxGetPr(plhs[2]);
150+
ptr_predict_label = mxGetPr(tplhs[0]);
151+
ptr_prob_estimates = mxGetPr(tplhs[2]);
152+
ptr_dec_values = mxGetPr(tplhs[2]);
152153
x = (struct svm_node*)malloc((feature_number+1)*sizeof(struct svm_node) );
153154
for(instance_index=0;instance_index<testing_instance_number;instance_index++)
154155
{
@@ -229,8 +230,8 @@ void predict(mxArray *plhs[], const mxArray *prhs[], struct svm_model *model, co
229230
(double)correct/total*100,correct,total);
230231

231232
// return accuracy, mean squared error, squared correlation coefficient
232-
plhs[1] = mxCreateDoubleMatrix(3, 1, mxREAL);
233-
ptr = mxGetPr(plhs[1]);
233+
tplhs[1] = mxCreateDoubleMatrix(3, 1, mxREAL);
234+
ptr = mxGetPr(tplhs[1]);
234235
ptr[0] = (double)correct/total*100;
235236
ptr[1] = error/total;
236237
ptr[2] = ((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/
@@ -239,12 +240,20 @@ void predict(mxArray *plhs[], const mxArray *prhs[], struct svm_model *model, co
239240
free(x);
240241
if(prob_estimates != NULL)
241242
free(prob_estimates);
243+
244+
switch(nlhs){
245+
case 3: plhs[2] = tplhs[2];
246+
plhs[1] = tplhs[1];
247+
case 1:
248+
case 0: plhs[0] = tplhs[0];
249+
}
242250
}
243251

244252
void exit_with_help()
245253
{
246254
mexPrintf(
247255
"Usage: [predicted_label, accuracy, decision_values/prob_estimates] = svmpredict(testing_label_vector, testing_instance_matrix, model, 'libsvm_options')\n"
256+
" [predicted_label] = svmpredict(testing_label_vector, testing_instance_matrix, model, 'libsvm_options')\n"
248257
"Parameters:\n"
249258
" model: SVM model structure from svmtrain.\n"
250259
" libsvm_options:\n"
@@ -264,16 +273,16 @@ void mexFunction( int nlhs, mxArray *plhs[],
264273
struct svm_model *model;
265274
info = &mexPrintf;
266275

267-
if(nrhs > 4 || nrhs < 3)
276+
if(nlhs == 2 || nlhs > 3 || nrhs > 4 || nrhs < 3)
268277
{
269278
exit_with_help();
270-
fake_answer(plhs);
279+
fake_answer(nlhs, plhs);
271280
return;
272281
}
273282

274283
if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) {
275284
mexPrintf("Error: label vector and instance matrix must be double\n");
276-
fake_answer(plhs);
285+
fake_answer(nlhs, plhs);
277286
return;
278287
}
279288

@@ -299,7 +308,7 @@ void mexFunction( int nlhs, mxArray *plhs[],
299308
if((++i>=argc) && argv[i-1][1] != 'q')
300309
{
301310
exit_with_help();
302-
fake_answer(plhs);
311+
fake_answer(nlhs, plhs);
303312
return;
304313
}
305314
switch(argv[i-1][1])
@@ -314,7 +323,7 @@ void mexFunction( int nlhs, mxArray *plhs[],
314323
default:
315324
mexPrintf("Unknown option: -%c\n", argv[i-1][1]);
316325
exit_with_help();
317-
fake_answer(plhs);
326+
fake_answer(nlhs, plhs);
318327
return;
319328
}
320329
}
@@ -324,7 +333,7 @@ void mexFunction( int nlhs, mxArray *plhs[],
324333
if (model == NULL)
325334
{
326335
mexPrintf("Error: can't read model: %s\n", error_msg);
327-
fake_answer(plhs);
336+
fake_answer(nlhs, plhs);
328337
return;
329338
}
330339

@@ -333,7 +342,7 @@ void mexFunction( int nlhs, mxArray *plhs[],
333342
if(svm_check_probability_model(model)==0)
334343
{
335344
mexPrintf("Model does not support probabiliy estimates\n");
336-
fake_answer(plhs);
345+
fake_answer(nlhs, plhs);
337346
svm_free_and_destroy_model(&model);
338347
return;
339348
}
@@ -344,14 +353,14 @@ void mexFunction( int nlhs, mxArray *plhs[],
344353
info("Model supports probability estimates, but disabled in predicton.\n");
345354
}
346355

347-
predict(plhs, prhs, model, prob_estimate_flag);
356+
predict(nlhs, plhs, prhs, model, prob_estimate_flag);
348357
// destroy model
349358
svm_free_and_destroy_model(&model);
350359
}
351360
else
352361
{
353362
mexPrintf("model file should be a struct array\n");
354-
fake_answer(plhs);
363+
fake_answer(nlhs, plhs);
355364
}
356365

357366
return;

0 commit comments

Comments
 (0)