Skip to content

Commit 84173c4

Browse files
committed
Adds predict_with_params, raw_scores_with_params functions
1 parent 47f6d2f commit 84173c4

File tree

6 files changed

+98
-10
lines changed

6 files changed

+98
-10
lines changed

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "lightgbm3"
3-
version = "1.0.1"
3+
version = "1.0.2"
44
edition = "2021"
55
authors = ["Dmitry Mottl <[email protected]>", "vaaaaanquish <[email protected]>"]
66
license = "MIT"
@@ -13,7 +13,7 @@ readme = "README.md"
1313
exclude = [".gitignore", ".github", ".gitmodules", "examples", "benches", "lightgbm3-sys"]
1414

1515
[dependencies]
16-
lightgbm3-sys = { path = "lightgbm3-sys", version = "1.0.0" }
16+
lightgbm3-sys = { path = "lightgbm3-sys", version = "1.0.2" }
1717
libc = "0.2"
1818
derive_builder = "0.12"
1919
serde_json = "1.0"

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ use lightgbm3::{Dataset, Booster};
6464
let bst = Booster::from_file("path/to/model.lgb").unwrap();
6565
let features = vec![1.0, 2.0, -5.0];
6666
let n_features = features.len();
67-
let y_pred = bst.predict(&features, n_features as i32, true).unwrap()[0];
67+
let y_pred = bst.predict_with_params(&features, n_features as i32, true, "num_threads=1").unwrap()[0];
6868
```
6969

7070
Look in the [`./examples/`](https://github.com/Mottl/lightgbm3-rs/blob/main/examples/) folder for more details:

lightgbm3-sys/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "lightgbm3-sys"
3-
version = "1.0.1"
3+
version = "1.0.2"
44
edition = "2021"
55
authors = ["Dmitry Mottl <[email protected]>", "vaaaaanquish <[email protected]>"]
66
build = "build.rs"

lightgbm3-sys/lightgbm

src/booster.rs

Lines changed: 92 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ impl Booster {
230230
n_features: i32,
231231
is_row_major: bool,
232232
predict_type: PredictType,
233+
parameters: Option<&str>,
233234
) -> Result<Vec<f64>> {
234235
if self.n_features <= 0 {
235236
return Err(Error::new("n_features should be greater than 0"));
@@ -252,7 +253,10 @@ impl Booster {
252253
)));
253254
}
254255
let n_rows = flat_x.len() / n_features as usize;
255-
let params = CString::new("").unwrap();
256+
let params_cstring = parameters
257+
.map(|s| CString::new(s))
258+
.unwrap_or(CString::new(""))
259+
.unwrap();
256260
let mut out_length: c_longlong = 0;
257261

258262
let out_result: Vec<f64> = vec![Default::default(); n_rows * self.n_classes as usize];
@@ -266,7 +270,7 @@ impl Booster {
266270
predict_type.into(), // predict_type
267271
0_i32, // start_iteration
268272
self.max_iterations, // num_iteration, <= 0 means no limit
269-
params.as_ptr() as *const c_char,
273+
params_cstring.as_ptr() as *const c_char,
270274
&mut out_length,
271275
out_result.as_ptr() as *mut c_double
272276
))?;
@@ -282,7 +286,31 @@ impl Booster {
282286
n_features: i32,
283287
is_row_major: bool,
284288
) -> Result<Vec<f64>> {
285-
self.real_predict(flat_x, n_features, is_row_major, PredictType::Normal)
289+
self.real_predict(flat_x, n_features, is_row_major, PredictType::Normal, None)
290+
}
291+
292+
/// Get predictions given `&[f32]` or `&[f64]` slice of features. The resulting vector
293+
/// will have the size of `n_rows` by `n_classes`.
294+
///
295+
/// Example:
296+
/// ```compile_fail
297+
/// use serde_json::json;
298+
/// let y_pred = bst.predict_with_params(&xs, 10, true, "num_threads=1").unwrap();
299+
/// ```
300+
pub fn predict_with_params<T: DType>(
301+
&self,
302+
flat_x: &[T],
303+
n_features: i32,
304+
is_row_major: bool,
305+
params: &str,
306+
) -> Result<Vec<f64>> {
307+
self.real_predict(
308+
flat_x,
309+
n_features,
310+
is_row_major,
311+
PredictType::Normal,
312+
Some(params),
313+
)
286314
}
287315

288316
/// Get raw scores given `&[f32]` or `&[f64]` slice of features. The resulting vector
@@ -293,7 +321,37 @@ impl Booster {
293321
n_features: i32,
294322
is_row_major: bool,
295323
) -> Result<Vec<f64>> {
296-
self.real_predict(flat_x, n_features, is_row_major, PredictType::RawScore)
324+
self.real_predict(
325+
flat_x,
326+
n_features,
327+
is_row_major,
328+
PredictType::RawScore,
329+
None,
330+
)
331+
}
332+
333+
/// Get raw scores given `&[f32]` or `&[f64]` slice of features. The resulting vector
334+
/// will have the size of `n_rows` by `n_classes`.
335+
///
336+
/// Example:
337+
/// ```compile_fail
338+
/// use serde_json::json;
339+
/// let y_pred = bst.predict_with_params(&xs, 10, true, "num_threads=1").unwrap();
340+
/// ```
341+
pub fn raw_scores_with_params<T: DType>(
342+
&self,
343+
flat_x: &[T],
344+
n_features: i32,
345+
is_row_major: bool,
346+
parameters: &str,
347+
) -> Result<Vec<f64>> {
348+
self.real_predict(
349+
flat_x,
350+
n_features,
351+
is_row_major,
352+
PredictType::RawScore,
353+
Some(parameters),
354+
)
297355
}
298356

299357
/// Predicts results for the given `x` and returns a vector or vectors (inner vectors will
@@ -482,6 +540,36 @@ mod tests {
482540
assert_eq!(normalized_result, vec![0, 0, 1]);
483541
}
484542

543+
#[test]
544+
fn predict_with_params() {
545+
let params = json! {
546+
{
547+
"num_iterations": 10,
548+
"objective": "binary",
549+
"metric": "auc",
550+
"data_random_seed": 0
551+
}
552+
};
553+
let bst = _train_booster(&params);
554+
// let feature = vec![vec![0.5; 28], vec![0.0; 28], vec![0.9; 28]];
555+
let mut feature = [0.0; 28 * 3];
556+
for i in 0..28 {
557+
feature[i] = 0.5;
558+
}
559+
for i in 56..feature.len() {
560+
feature[i] = 0.9;
561+
}
562+
563+
let result = bst
564+
.predict_with_params(&feature, 28, true, "num_threads=1")
565+
.unwrap();
566+
let mut normalized_result = Vec::new();
567+
for r in &result {
568+
normalized_result.push(if *r > 0.5 { 1 } else { 0 });
569+
}
570+
assert_eq!(normalized_result, vec![0, 0, 1]);
571+
}
572+
485573
#[test]
486574
fn num_feature() {
487575
let params = _default_params();

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
//! let bst = Booster::from_file("path/to/model.lgb").unwrap();
3838
//! let features = vec![1.0, 2.0, -5.0];
3939
//! let n_features = features.len();
40-
//! let y_pred = bst.predict(&features, n_features as i32, true).unwrap()[0];
40+
//! let y_pred = bst.predict_with_params(&features, n_features as i32, true, "num_threads=1").unwrap()[0];
4141
//! ```
4242
4343
macro_rules! lgbm_call {

0 commit comments

Comments
 (0)