|
| 1 | +#coding:UTF-8 |
| 2 | +''' |
| 3 | +Date:20161030 |
| 4 | +@author: zhaozhiyong |
| 5 | +''' |
| 6 | + |
| 7 | +import cPickle as pickle |
| 8 | +from random_forests_train import get_predict |
| 9 | + |
| 10 | +def load_data(file_name): |
| 11 | + '''导入待分类的数据集 |
| 12 | + input: file_name(string):待分类数据存储的位置 |
| 13 | + output: test_data(list) |
| 14 | + ''' |
| 15 | + f = open(file_name) |
| 16 | + test_data = [] |
| 17 | + for line in f.readlines(): |
| 18 | + lines = line.strip().split("\t") |
| 19 | + tmp = [] |
| 20 | + for x in lines: |
| 21 | + tmp.append(float(x)) |
| 22 | + tmp.append(0) # 保存初始的label |
| 23 | + test_data.append(tmp) |
| 24 | + f.close() |
| 25 | + return test_data |
| 26 | + |
| 27 | +def load_model(result_file, feature_file): |
| 28 | + '''导入随机森林模型和每一个分类树中选择的特征 |
| 29 | + input: result_file(string):随机森林模型存储的文件 |
| 30 | + feature_file(string):分类树选择的特征存储的文件 |
| 31 | + output: trees_result(list):随机森林模型 |
| 32 | + trees_fiture(list):每一棵分类树选择的特征 |
| 33 | + ''' |
| 34 | + # 1、导入选择的特征 |
| 35 | + trees_fiture = [] |
| 36 | + f_fea = open(feature_file) |
| 37 | + for line in f_fea.readlines(): |
| 38 | + lines = line.strip().split("\t") |
| 39 | + tmp = [] |
| 40 | + for x in lines: |
| 41 | + tmp.append(int(x)) |
| 42 | + trees_fiture.append(tmp) |
| 43 | + f_fea.close() |
| 44 | + |
| 45 | + # 2、导入随机森林模型 |
| 46 | + with open(result_file, 'r') as f: |
| 47 | + trees_result = pickle.load(f) |
| 48 | + |
| 49 | + return trees_result, trees_fiture |
| 50 | + |
| 51 | +def save_result(data_test, prediction, result_file): |
| 52 | + '''保存最终的预测结果 |
| 53 | + input: data_test(list):待预测的数据 |
| 54 | + prediction(list):预测的结果 |
| 55 | + result_file(string):存储最终预测结果的文件名 |
| 56 | + ''' |
| 57 | + m = len(prediction) |
| 58 | + n = len(data_test[0]) |
| 59 | + |
| 60 | + f_result = open(result_file, "w") |
| 61 | + for i in xrange(m): |
| 62 | + tmp = [] |
| 63 | + for j in xrange(n -1): |
| 64 | + tmp.append(str(data_test[i][j])) |
| 65 | + tmp.append(str(prediction[i])) |
| 66 | + f_result.writelines("\t".join(tmp) + "\n") |
| 67 | + f_result.close() |
| 68 | + |
| 69 | +if __name__ == "__main__": |
| 70 | + # 1、导入测试数据集 |
| 71 | + print "--------- 1、load test data --------" |
| 72 | + data_test = load_data("test_data.txt") |
| 73 | + # 2、导入随机森林模型 |
| 74 | + print "--------- 2、load random forest model ----------" |
| 75 | + trees_result, trees_feature = load_model("result_file", "feature_file") |
| 76 | + # 3、预测 |
| 77 | + print "--------- 3、get prediction -----------" |
| 78 | + prediction = get_predict(trees_result, trees_feature, data_test) |
| 79 | + # 4、保存最终的预测结果 |
| 80 | + print "--------- 4、save result -----------" |
| 81 | + save_result(data_test, prediction, "final_result") |
| 82 | + |
| 83 | + |
| 84 | + |
0 commit comments