@@ -230,6 +230,7 @@ impl Booster {
230230 n_features : i32 ,
231231 is_row_major : bool ,
232232 predict_type : PredictType ,
233+ parameters : Option < & str > ,
233234 ) -> Result < Vec < f64 > > {
234235 if self . n_features <= 0 {
235236 return Err ( Error :: new ( "n_features should be greater than 0" ) ) ;
@@ -252,7 +253,10 @@ impl Booster {
252253 ) ) ) ;
253254 }
254255 let n_rows = flat_x. len ( ) / n_features as usize ;
255- let params = CString :: new ( "" ) . unwrap ( ) ;
256+ let params_cstring = parameters
257+ . map ( |s| CString :: new ( s) )
258+ . unwrap_or ( CString :: new ( "" ) )
259+ . unwrap ( ) ;
256260 let mut out_length: c_longlong = 0 ;
257261
258262 let out_result: Vec < f64 > = vec ! [ Default :: default ( ) ; n_rows * self . n_classes as usize ] ;
@@ -266,7 +270,7 @@ impl Booster {
266270 predict_type. into( ) , // predict_type
267271 0_i32 , // start_iteration
268272 self . max_iterations, // num_iteration, <= 0 means no limit
269- params . as_ptr( ) as * const c_char,
273+ params_cstring . as_ptr( ) as * const c_char,
270274 & mut out_length,
271275 out_result. as_ptr( ) as * mut c_double
272276 ) ) ?;
@@ -282,7 +286,31 @@ impl Booster {
282286 n_features : i32 ,
283287 is_row_major : bool ,
284288 ) -> Result < Vec < f64 > > {
285- self . real_predict ( flat_x, n_features, is_row_major, PredictType :: Normal )
289+ self . real_predict ( flat_x, n_features, is_row_major, PredictType :: Normal , None )
290+ }
291+
292+ /// Get predictions given `&[f32]` or `&[f64]` slice of features. The resulting vector
293+ /// will have the size of `n_rows` by `n_classes`.
294+ ///
295+ /// Example:
296+ /// ```compile_fail
297+ /// use serde_json::json;
298+ /// let y_pred = bst.predict_with_params(&xs, 10, true, "num_threads=1").unwrap();
299+ /// ```
300+ pub fn predict_with_params < T : DType > (
301+ & self ,
302+ flat_x : & [ T ] ,
303+ n_features : i32 ,
304+ is_row_major : bool ,
305+ params : & str ,
306+ ) -> Result < Vec < f64 > > {
307+ self . real_predict (
308+ flat_x,
309+ n_features,
310+ is_row_major,
311+ PredictType :: Normal ,
312+ Some ( params) ,
313+ )
286314 }
287315
288316 /// Get raw scores given `&[f32]` or `&[f64]` slice of features. The resulting vector
@@ -293,7 +321,37 @@ impl Booster {
293321 n_features : i32 ,
294322 is_row_major : bool ,
295323 ) -> Result < Vec < f64 > > {
296- self . real_predict ( flat_x, n_features, is_row_major, PredictType :: RawScore )
324+ self . real_predict (
325+ flat_x,
326+ n_features,
327+ is_row_major,
328+ PredictType :: RawScore ,
329+ None ,
330+ )
331+ }
332+
333+ /// Get raw scores given `&[f32]` or `&[f64]` slice of features. The resulting vector
334+ /// will have the size of `n_rows` by `n_classes`.
335+ ///
336+ /// Example:
337+ /// ```compile_fail
338+ /// use serde_json::json;
339+ /// let y_pred = bst.predict_with_params(&xs, 10, true, "num_threads=1").unwrap();
340+ /// ```
341+ pub fn raw_scores_with_params < T : DType > (
342+ & self ,
343+ flat_x : & [ T ] ,
344+ n_features : i32 ,
345+ is_row_major : bool ,
346+ parameters : & str ,
347+ ) -> Result < Vec < f64 > > {
348+ self . real_predict (
349+ flat_x,
350+ n_features,
351+ is_row_major,
352+ PredictType :: RawScore ,
353+ Some ( parameters) ,
354+ )
297355 }
298356
299357 /// Predicts results for the given `x` and returns a vector or vectors (inner vectors will
@@ -482,6 +540,36 @@ mod tests {
482540 assert_eq ! ( normalized_result, vec![ 0 , 0 , 1 ] ) ;
483541 }
484542
543+ #[ test]
544+ fn predict_with_params ( ) {
545+ let params = json ! {
546+ {
547+ "num_iterations" : 10 ,
548+ "objective" : "binary" ,
549+ "metric" : "auc" ,
550+ "data_random_seed" : 0
551+ }
552+ } ;
553+ let bst = _train_booster ( & params) ;
554+ // let feature = vec![vec![0.5; 28], vec![0.0; 28], vec![0.9; 28]];
555+ let mut feature = [ 0.0 ; 28 * 3 ] ;
556+ for i in 0 ..28 {
557+ feature[ i] = 0.5 ;
558+ }
559+ for i in 56 ..feature. len ( ) {
560+ feature[ i] = 0.9 ;
561+ }
562+
563+ let result = bst
564+ . predict_with_params ( & feature, 28 , true , "num_threads=1" )
565+ . unwrap ( ) ;
566+ let mut normalized_result = Vec :: new ( ) ;
567+ for r in & result {
568+ normalized_result. push ( if * r > 0.5 { 1 } else { 0 } ) ;
569+ }
570+ assert_eq ! ( normalized_result, vec![ 0 , 0 , 1 ] ) ;
571+ }
572+
485573 #[ test]
486574 fn num_feature ( ) {
487575 let params = _default_params ( ) ;
0 commit comments