|
| 1 | +# coding:UTF-8 |
| 2 | +''' |
| 3 | +Date:20161030 |
| 4 | +@author: zhaozhiyong |
| 5 | +''' |
| 6 | + |
| 7 | +import random as rd |
| 8 | +import cPickle as pickle |
| 9 | +from train_cart import predict,node |
| 10 | + |
| 11 | +def load_data(): |
| 12 | + '''导入测试数据集 |
| 13 | + ''' |
| 14 | + data_test = [] |
| 15 | + for i in xrange(400): |
| 16 | + tmp = [] |
| 17 | + tmp.append(rd.random()) # 随机生成[0,1]之间的样本 |
| 18 | + data_test.append(tmp) |
| 19 | + return data_test |
| 20 | + |
| 21 | +def load_model(tree_file): |
| 22 | + '''导入训练好的CART回归树模型 |
| 23 | + input: tree_file(list):保存CART回归树模型的文件 |
| 24 | + output: regression_tree:CART回归树 |
| 25 | + ''' |
| 26 | + with open(tree_file, 'r') as f: |
| 27 | + regression_tree = pickle.load(f) |
| 28 | + return regression_tree |
| 29 | + |
| 30 | +def get_prediction(data_test, regression_tree): |
| 31 | + '''对测试样本进行预测 |
| 32 | + input: data_test(list):需要预测的样本 |
| 33 | + regression_tree(regression_tree):训练好的回归树模型 |
| 34 | + output: result(list): |
| 35 | + ''' |
| 36 | + result = [] |
| 37 | + for x in data_test: |
| 38 | + result.append(predict(x, regression_tree)) |
| 39 | + return result |
| 40 | + |
| 41 | +def save_result(data_test, result, prediction_file): |
| 42 | + '''保存最终的预测结果 |
| 43 | + input: data_test(list):需要预测的数据集 |
| 44 | + result(list):预测的结果 |
| 45 | + prediction_file(string):保存结果的文件 |
| 46 | + ''' |
| 47 | + f = open(prediction_file, "w") |
| 48 | + for i in xrange(len(result)): |
| 49 | + a = str(data_test[i][0]) + "\t" + str(result[i]) + "\n" |
| 50 | + f.write(a) |
| 51 | + f.close() |
| 52 | + |
| 53 | +if __name__ == "__main__": |
| 54 | + # 1、导入待计算的数据 |
| 55 | + print "--------- 1、load data ----------" |
| 56 | + data_test = load_data() |
| 57 | + # 2、导入回归树模型 |
| 58 | + print "--------- 2、load regression tree ---------" |
| 59 | + regression_tree = load_model("regression_tree") |
| 60 | + # 3、进行预测 |
| 61 | + print "--------- 3、get prediction -----------" |
| 62 | + prediction = get_prediction(data_test, regression_tree) |
| 63 | + # 4、保存预测的结果 |
| 64 | + print "--------- 4、save result ----------" |
| 65 | + save_result(data_test, prediction, "prediction") |
0 commit comments