Skip to content

Commit 946a80d

Browse files
committed
Fix bug jimichan#1: model from C, but different prediction results
1 parent f511f9a commit 946a80d

File tree

7 files changed

+127752
-51
lines changed

7 files changed

+127752
-51
lines changed

src/example/java/com/mayabot/mynlp/fasttext/AgnewsTest.kt

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,12 @@ import java.io.File
2525
//Process finished with exit code 0
2626

2727
/**
 * Example entry point: trains a model from `src/example/resources/ag.train`
 * using [ModelName.sup] and passes the result to [AgnewsTest.predict].
 *
 * @param args command-line arguments (not used)
 */
fun main(args: Array<String>) {
    val model = FastText.train(File("src/example/resources/ag.train"), ModelName.sup)
    AgnewsTest.predict(model)
}
3735

3836
object AgnewsTest{
@@ -44,7 +42,7 @@ object AgnewsTest{
4442
var right = 0
4543
val splitter = Splitter.on(CharMatcher.whitespace())
4644

47-
for (line in Files.asCharSource(File("data/fasttext/ag.test"), Charsets.UTF_8).readLines()) {
45+
for (line in Files.asCharSource(File("src/example/resources/ag.test"), Charsets.UTF_8).readLines()) {
4846

4947
val i = line.indexOf(',')
5048
val label = line.substring(0, i).trim { it <= ' ' }
@@ -54,7 +52,7 @@ object AgnewsTest{
5452
val predict = fastText.predict(splitter.split(text), 3)
5553

5654
if (!predict.isEmpty()) {
57-
if (label == predict.get(0).second) {
55+
if (label == predict[0].second) {
5856
right++
5957
}
6058
}
Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,25 @@
11
package com.mayabot.mynlp.fasttext
22

33
/**
 * Example entry point: loads a saved model and prints the top 10 predictions
 * for a fixed list of pre-tokenised sample sentences.
 *
 * @param args optional; `args[0]` is the model location handed to
 *   [FastText.loadModel]. When it is absent, the path that was previously
 *   hard-coded here is used, so running without arguments behaves as before.
 */
fun main(args: Array<String>) {
    // The default is a developer-local path; pass your own location as the first argument.
    val modelPath = args.firstOrNull() ?: "/Users/jimichan/Downloads/model0726.dir"
    val model = FastText.loadModel(modelPath)

    // Tokens are separated by single spaces; "^" and "$" are the first and last token of each sample.
    val samples = listOf(
            "^ 霸王腰花 是 什么 $",
            "^ 洗碗机 开 门 $",
            "^ 你好 吗 $",
            "^ 搞 个 三 菜 一 汤 $",
            "^ 还 要 多 久 $",
            "^ 播放 一首 刘德华 的 歌 $",
            "^ 今天 北京 天气 怎么样 $",
            "^ 启动 汽车 引擎 $",
            "^ 找 一家 中国餐厅 $",
            "^ 鱼翅 怎么做 $",
            "^ 有什么 清热解暑 的 菜 吗 $")

    for (line in samples) {
        val predictions = model.predict(line.split(" "), 10)
        println("$line \t\t$predictions")
    }
}

src/example/resources/ag.test

Lines changed: 7600 additions & 0 deletions
Large diffs are not rendered by default.

src/example/resources/ag.train

Lines changed: 120000 additions & 0 deletions
Large diffs are not rendered by default.

src/main/java/com/mayabot/mynlp/fasttext/Dictionary.kt

Lines changed: 83 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,15 @@ import com.google.common.base.CharMatcher
88
import com.google.common.base.Splitter
99
import java.io.File
1010
import java.io.IOException
11+
import java.math.BigInteger
1112
import java.nio.ByteBuffer
1213
import java.nio.ByteOrder
1314
import java.nio.channels.FileChannel
1415
import java.util.*
16+
import com.sun.tools.javac.tree.TreeInfo.args
17+
import com.google.common.primitives.UnsignedLong
18+
19+
1520

1621
const val HASH_C = 116049371
1722
const val MAX_VOCAB_SIZE = 30000000
@@ -45,31 +50,32 @@ class Dictionary(private val args: Args) {
4550
var size: Int = 0
4651
private set
4752

48-
var wordList:MutableList<Entry> = ArrayList(50000 * 4)
53+
var wordList: MutableList<Entry> = ArrayList(50000 * 4)
4954
private var word_hash_2_id: IntArray = IntArray(MAX_VOCAB_SIZE).apply {
5055
fill(-1)
5156
}
5257

53-
var nwords:Int = 0
54-
var nlabels:Int = 0
55-
var ntokens:Long = 0
58+
var nwords: Int = 0
59+
var nlabels: Int = 0
60+
var ntokens: Long = 0
5661

5762
var pruneidxSize = -1L
58-
var pdiscard:FloatArray = FloatArray(0)
59-
var pruneidx:IntIntMap = IntIntHashMap()
63+
var pdiscard: FloatArray = FloatArray(0)
64+
var pruneidx: IntIntMap = IntIntHashMap()
6065

6166
/**
6267
* maxn length of char ngram
6368
*/
6469
val maxn = args.maxn
6570
val minn = args.minn
6671
val bucket = args.bucket
72+
val bucketLong = args.bucket.toLong()
6773
val wordNgrams = args.wordNgrams
6874
val label = args.label
6975
val model = args.model
7076

7177

72-
fun isPruned() = pruneidxSize >=0
78+
/**
 * Reports whether [pruneidxSize] is non-negative.
 * The field is initialised to -1, so this is false until it has been set to 0 or more.
 */
fun isPruned(): Boolean {
    return pruneidxSize >= 0
}
7379

7480

7581
fun getType(id: Int): EntryType {
@@ -95,8 +101,8 @@ class Dictionary(private val args: Args) {
95101
} else word_hash_2_id[id]
96102
}
97103

98-
private fun getId(w: String, h:Long): Int {
99-
val id = find(w,h)
104+
private fun getId(w: String, h: Long): Int {
105+
val id = find(w, h)
100106
return if (id == -1) {
101107
-1 //词不存在
102108
} else word_hash_2_id[id]
@@ -110,7 +116,7 @@ class Dictionary(private val args: Args) {
110116
val id = word_hash_2_id[h]
111117

112118
if (id == -1) {
113-
wordList.add(Entry(w,1,getType(w)))
119+
wordList.add(Entry(w, 1, getType(w)))
114120
word_hash_2_id[h] = size++
115121
} else {
116122
wordList[id].count++
@@ -162,7 +168,6 @@ class Dictionary(private val args: Args) {
162168
}
163169

164170

165-
166171
/**
167172
* 读取分析原始语料,语料单词直接空格
168173
*
@@ -186,8 +191,7 @@ class Dictionary(private val args: Args) {
186191
file.useLines { lines ->
187192
lines.filterNot { it.isNullOrBlank() || it.startsWith("#") }
188193
.forEach { line ->
189-
splitter.split(line).forEach {
190-
token ->
194+
splitter.split(line).forEach { token ->
191195
add(token)
192196
if (ntokens % 1000000 == 0L && args.verbose > 1) {
193197
print("\rRead " + ntokens / 1000000 + "M words")
@@ -253,11 +257,14 @@ class Dictionary(private val args: Args) {
253257
continue
254258
}
255259

256-
val ngram = StringBuilder()
260+
var ngram: StringBuilder? = null
257261

258262
var j = i
259263
var n = 1
260264
while (j < word_len && n <= maxn) {
265+
if (ngram == null) {
266+
ngram = StringBuilder()
267+
}
261268
ngram.append(word[j++])
262269
while (j < word.length && charMatches(word[j])) {
263270
ngram.append(word[j++])
@@ -297,7 +304,7 @@ class Dictionary(private val args: Args) {
297304
val t = args.t
298305
for (i in 0 until size) {
299306
val f = wordList[i].count * 1.0f / ntokens
300-
pdiscard[i] = (Math.sqrt(t / f) + t/ f).toFloat()
307+
pdiscard[i] = (Math.sqrt(t / f) + t / f).toFloat()
301308
}
302309
}
303310

@@ -307,8 +314,8 @@ class Dictionary(private val args: Args) {
307314
.toMutableList()
308315
(wordList as java.util.ArrayList<Entry>).trimToSize()
309316

310-
size=0
311-
nwords=0
317+
size = 0
318+
nwords = 0
312319
nlabels = 0
313320

314321
word_hash_2_id.fill(-1)
@@ -318,7 +325,7 @@ class Dictionary(private val args: Args) {
318325
word_hash_2_id[h] = size++
319326
if (it.type == EntryType.word) {
320327
nwords++
321-
}else if (it.type == EntryType.label) {
328+
} else if (it.type == EntryType.label) {
322329
nlabels++
323330
}
324331
}
@@ -355,7 +362,6 @@ class Dictionary(private val args: Args) {
355362
val h = stringHash(token)
356363
val wid = getId(token, h)
357364
val type = if (wid < 0) getType(token) else getType(wid)
358-
359365
ntokens++
360366

361367
if (type == EntryType.word) {
@@ -371,21 +377,60 @@ class Dictionary(private val args: Args) {
371377
return ntokens
372378
}
373379

380+
companion object {
    /**
     * Multiplier applied to the running hash when word hashes are combined into
     * word-n-gram hashes. Derived from the file-level [HASH_C] constant so the
     * literal 116049371 is declared in exactly one place.
     */
    val coeff: UnsignedLong = UnsignedLong.valueOf(HASH_C.toLong())

    /**
     * 2^64 - 2^32 (0xFFFFFFFF00000000): an unsigned 64-bit value whose upper
     * 32 bits are all set. Added to a hash above `Integer.MAX_VALUE` to obtain
     * the 64-bit pattern of that hash read as a signed 32-bit number.
     */
    val U64_START: UnsignedLong = UnsignedLong.valueOf("18446744069414584320")
}
374384

375385
/**
 * Pushes one bucket id onto [line] for every word n-gram of length 2..[n]
 * formed from consecutive entries of [hashes].
 *
 * For each start index `i` a running hash is seeded from `hashes[i]` and, for
 * every following position `j < i + n`, updated as `h = h * coeff + hashes[j]`
 * in unsigned 64-bit arithmetic; `h mod bucket` is then handed to `pushHash`.
 *
 * Each hash is first reinterpreted as a signed 32-bit value and sign-extended,
 * e.g. hash 3675003649 -> int32 -619963647 -> uint64 18446744073089587969.
 * NOTE(review): presumably this mirrors the int32 -> uint64 conversion of the
 * C implementation referred to by the commit message - confirm against it.
 *
 * @param line   receives the n-gram bucket ids (via `pushHash`)
 * @param hashes per-token hashes of the current line, in token order
 * @param n      maximum n-gram length; values <= 1 produce nothing
 */
private fun addWordNgrams(line: IntArrayList,
                          hashes: LongArrayList,
                          n: Int) {
    val hashSize = hashes.size()
    // No pair of tokens can be formed, so nothing would ever be pushed.
    if (n <= 1 || hashSize < 2) {
        return
    }

    // Loop-invariant: build the modulus once instead of once per n-gram.
    val bucketModulus = UnsignedLong.valueOf(bucketLong)

    for (i in 0 until hashSize) {
        var h = toUnsignedLong64(hashes.get(i))
        var j = i + 1
        while (j < hashSize && j < i + n) {
            // Keep the low 32 bits and sign-extend them to 64 bits.
            val h2 = hashes.get(j).toInt().toLong()
            // fromLongBits reinterprets the two's-complement bits, so adding it is
            // the wrap-around equivalent of "plus h2" / "minus -h2" for either sign.
            h = h.times(coeff).plus(UnsignedLong.fromLongBits(h2))
            pushHash(line, h.mod(bucketModulus).toInt())
            j++
        }
    }
}
388422

423+
// from https://github.com/linkfluence/fastText4j/blob/b018438e84bebd20f89a701c35f022139418930c/src/main/java/fasttext/BaseDictionary.java
424+
425+
426+
/**
 * Converts [l] to an [UnsignedLong].
 *
 * Values up to `Integer.MAX_VALUE` are converted as-is; larger values are
 * additionally offset by [U64_START]. Negative input is passed to
 * `UnsignedLong.valueOf` unchanged (Guava documents that call as rejecting
 * negative values).
 *
 * @param l the value to convert
 * @return the unsigned 64-bit representation described above
 */
private fun toUnsignedLong64(l: Long): UnsignedLong {
    val unsigned = UnsignedLong.valueOf(l)
    if (l <= Integer.MAX_VALUE) {
        return unsigned
    }
    return U64_START.plus(unsigned)
}
433+
389434
private fun addSubwords(line: IntArrayList,
390435
token: String,
391436
wid: Int) {
@@ -482,30 +527,30 @@ class Dictionary(private val args: Args) {
482527
channel.writeLong(ntokens)
483528
channel.writeLong(pruneidxSize)
484529

485-
val buffer = ByteBuffer.allocate(1024*1024)
486-
val em = buffer.capacity()*0.25f
530+
val buffer = ByteBuffer.allocate(1024 * 1024)
531+
val em = buffer.capacity() * 0.25f
487532
for (entry in wordList) {
488533
buffer.writeUTF(entry.word)
489534
buffer.putLong(entry.count)
490535
buffer.put(entry.type.value.toByte())
491536

492-
if(buffer.remaining() < em){
493-
buffer.flip()
494-
while (buffer.hasRemaining()) {
495-
channel.write(buffer)
496-
}
497-
buffer.clear()
498-
}
537+
if (buffer.remaining() < em) {
538+
buffer.flip()
539+
while (buffer.hasRemaining()) {
540+
channel.write(buffer)
541+
}
542+
buffer.clear()
543+
}
499544
}
500545

501546
buffer.flip()
502547
while (buffer.hasRemaining()) {
503548
channel.write(buffer)
504549
}
505550

506-
val buffer2 = ByteBuffer.allocate(pruneidx.size()*4)
551+
val buffer2 = ByteBuffer.allocate(pruneidx.size() * 4)
507552
pruneidx.forEach {
508-
buffer2.putInt(it.key,it.value)
553+
buffer2.putInt(it.key, it.value)
509554
}
510555
buffer2.flip()
511556
channel.write(buffer2)
@@ -514,7 +559,7 @@ class Dictionary(private val args: Args) {
514559

515560

516561
@Throws(IOException::class)
517-
fun load(buffer: AutoDataInput) : Dictionary {
562+
fun load(buffer: AutoDataInput): Dictionary {
518563
// wordList.clear();
519564
// word2int_.clear();
520565

@@ -530,7 +575,7 @@ class Dictionary(private val args: Args) {
530575
//size 189997 18万的词汇
531576
//val byteArray = ByteArray(1024)
532577
for (i in 0 until size) {
533-
val e = Entry(buffer.readUTF(),buffer.readLong(),EntryType.fromValue(buffer.readUnsignedByte().toInt()))
578+
val e = Entry(buffer.readUTF(), buffer.readLong(), EntryType.fromValue(buffer.readUnsignedByte().toInt()))
534579
wordList.add(e)
535580
word_hash_2_id[find(e.word)] = i
536581
}
@@ -558,6 +603,7 @@ class Dictionary(private val args: Args) {
558603
*/
559604

560605
val Empty_IntArrayList = IntArrayList(0)
606+
561607
data class Entry(
562608
val word: String,
563609
var count: Long,

0 commit comments

Comments
 (0)