Skip to content

Commit 9e3e161

Browse files
Update FM_train.py
1 parent 57892f5 commit 9e3e161

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

Chapter_3 Factorization Machine/FM_train.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@ def loadDataSet(data):
2323
lineArr.append(float(lines[i]))
2424
dataMat.append(lineArr)
2525

26-
#labelMat.append(float(lines[-1]) * 2 - 1) # 转换成{-1,1}
27-
labelMat.append(float(lines[-1]))
26+
labelMat.append(float(lines[-1]) * 2 - 1) # 转换成{-1,1}
2827
fr.close()
2928
return dataMat, labelMat
3029

@@ -34,7 +33,7 @@ def sigmoid(inx):
3433
def initialize_v(n, k):
3534
'''初始化交叉项
3635
input: n(int)特征的个数
37-
k(int)FM模型的交叉向量维度
36+
k(int)FM模型的超参数
3837
output: v(mat):交叉项的系数权重
3938
'''
4039
v = np.mat(np.zeros((n, k)))
@@ -174,7 +173,7 @@ def save_model(file_name, w0, w, v):
174173
dataTrain, labelTrain = loadDataSet("data_1.txt")
175174
print "---------- 2.learning ---------"
176175
# 2、利用随机梯度训练FM模型
177-
w0, w, v = stocGradAscent(np.mat(dataTrain), labelTrain, 2, 20000, 0.01)
176+
w0, w, v = stocGradAscent(np.mat(dataTrain), labelTrain, 3, 10000, 0.01)
178177
predict_result = getPrediction(np.mat(dataTrain), w0, w, v) # 得到训练的准确性
179178
print "----------training accuracy: %f" % (1 - getAccuracy(predict_result, labelTrain))
180179
print "---------- 3.save result ---------"

0 commit comments

Comments
 (0)