1+ # coding:utf-8
2+
3+ from numpy import *
4+
def load_data(path):
    """Read a tab-separated numeric table from *path*.

    Each non-blank line becomes one row of floats.  A field consisting of
    a single "-" is stored as 0.0.

    Fields are split on a tab and then stripped of surrounding whitespace,
    so both "a\tb" and "a\t b" layouts are accepted.  Blank lines are
    skipped instead of raising ``ValueError``.

    Args:
        path: Path of the text file to read.

    Returns:
        A list of rows, each row being a list of floats.

    Raises:
        ValueError: If a field is neither "-" nor parseable as a float.
    """
    data = []
    # The context manager guarantees the handle is closed, even on a parse error.
    with open(path) as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                continue
            row = []
            for field in stripped.split("\t"):
                field = field.strip()
                row.append(0.0 if field == "-" else float(field))
            data.append(row)
    return data
20+
def _regularized_loss(ratings, observed, p, q, beta):
    """Return the regularised squared error summed over the observed entries."""
    import numpy as np

    loss = 0.0
    for i, j in observed:
        row, col = p[i, :], q[:, j]
        residual = ratings[i, j] - np.dot(row, col)
        loss += residual * residual + beta * (np.dot(row, row) + np.dot(col, col)) / 2
    return loss


def gradAscent(data, K, alpha=0.0002, beta=0.02, maxCycles=10000, tol=0.001):
    """Factorise *data* into two rank-``K`` factors by stochastic gradient descent.

    Only entries strictly greater than 0 are treated as observed; all other
    entries are ignored during training.  Despite its name, the function
    minimises the regularised squared reconstruction error.

    Args:
        data: 2-D array-like (m x n) of values; entries <= 0 are skipped.
        K: Number of latent factors.
        alpha: Learning rate.
        beta: L2 regularisation weight.
        maxCycles: Maximum number of passes over the observed entries.
        tol: Training stops early once the total loss drops below this value.

    Returns:
        A tuple ``(p, q)`` of ``numpy.matrix`` objects with shapes (m, K) and
        (K, n), so that ``p * q`` approximates the observed entries of *data*.
    """
    # Imported locally so the function does not depend on the star-import
    # exposing ``mat``, an alias that NumPy 2.0 removed.
    import numpy as np

    dataMat = np.asmatrix(data)
    print(dataMat)
    m, n = dataMat.shape

    ratings = np.asarray(dataMat)
    # Factors start from uniform random values in [0, 1).
    p = np.random.random((m, K))
    q = np.random.random((K, n))

    # Observed cells in row-major order, i.e. the original i-then-j visiting order.
    observed = [(int(i), int(j)) for i, j in np.argwhere(ratings > 0)]

    for step in range(maxCycles):
        for i, j in observed:
            error = ratings[i, j] - np.dot(p[i, :], q[:, j])
            p[i, :] += alpha * (2 * error * q[:, j] - beta * p[i, :])
            # The column update deliberately sees the freshly updated row of p,
            # matching the element-by-element update order of the original loop.
            q[:, j] += alpha * (2 * error * p[i, :] - beta * q[:, j])

        # Loss function, used to check convergence.  It is accumulated over
        # every observed entry rather than overwritten by the last one.
        loss = _regularized_loss(ratings, observed, p, q, beta)

        if loss < tol:
            break
        if step % 1000 == 0:
            print(loss)

    return np.asmatrix(p), np.asmatrix(q)
63+
64+
if __name__ == "__main__":
    # Load the table from data.txt, factorise it with 5 latent factors and
    # print the reconstructed matrix (the product of the two factors).
    ratings = load_data("data.txt")
    row_factors, col_factors = gradAscent(ratings, 5)
    print(row_factors * col_factors)

# Sample output recorded in the original file (the factors are randomly
# initialised, so the exact values differ from run to run):
# [[ 4.01235942 2.99635324 2.8688891 4.96017332 4.30307713]
# [ 4.94058713 5.42013643 4.00107426 4.02193661 4.45573454]
# [ 3.9898181 4.46104338 4.97200637 3.48454365 3.00227572]
# [ 2.05763283 2.95284794 2.13460563 0.97376823 1.48549247]
# [ 4.50701719 3.99370117 2.00291634 4.50752955 4.97625932]]
0 commit comments