@@ -8,10 +8,15 @@ import com.google.common.base.CharMatcher
8
8
import com.google.common.base.Splitter
9
9
import java.io.File
10
10
import java.io.IOException
11
+ import java.math.BigInteger
11
12
import java.nio.ByteBuffer
12
13
import java.nio.ByteOrder
13
14
import java.nio.channels.FileChannel
14
15
import java.util.*
16
+ import com.sun.tools.javac.tree.TreeInfo.args
17
+ import com.google.common.primitives.UnsignedLong
18
+
19
+
15
20
16
21
const val HASH_C = 116049371
17
22
const val MAX_VOCAB_SIZE = 30000000
@@ -45,31 +50,32 @@ class Dictionary(private val args: Args) {
45
50
var size: Int = 0
46
51
private set
47
52
48
- var wordList: MutableList <Entry > = ArrayList (50000 * 4 )
53
+ var wordList: MutableList <Entry > = ArrayList (50000 * 4 )
49
54
private var word_hash_2_id: IntArray = IntArray (MAX_VOCAB_SIZE ).apply {
50
55
fill(- 1 )
51
56
}
52
57
53
- var nwords: Int = 0
54
- var nlabels: Int = 0
55
- var ntokens: Long = 0
58
+ var nwords: Int = 0
59
+ var nlabels: Int = 0
60
+ var ntokens: Long = 0
56
61
57
62
var pruneidxSize = - 1L
58
- var pdiscard: FloatArray = FloatArray (0 )
59
- var pruneidx: IntIntMap = IntIntHashMap ()
63
+ var pdiscard: FloatArray = FloatArray (0 )
64
+ var pruneidx: IntIntMap = IntIntHashMap ()
60
65
61
66
/* *
62
67
* maxn length of char ngram
63
68
*/
64
69
val maxn = args.maxn
65
70
val minn = args.minn
66
71
val bucket = args.bucket
72
+ val bucketLong = args.bucket.toLong()
67
73
val wordNgrams = args.wordNgrams
68
74
val label = args.label
69
75
val model = args.model
70
76
71
77
72
- fun isPruned () = pruneidxSize >= 0
78
+ fun isPruned () = pruneidxSize >= 0
73
79
74
80
75
81
fun getType (id : Int ): EntryType {
@@ -95,8 +101,8 @@ class Dictionary(private val args: Args) {
95
101
} else word_hash_2_id[id]
96
102
}
97
103
98
- private fun getId (w : String , h : Long ): Int {
99
- val id = find(w,h)
104
+ private fun getId (w : String , h : Long ): Int {
105
+ val id = find(w, h)
100
106
return if (id == - 1 ) {
101
107
- 1 // 词不存在
102
108
} else word_hash_2_id[id]
@@ -110,7 +116,7 @@ class Dictionary(private val args: Args) {
110
116
val id = word_hash_2_id[h]
111
117
112
118
if (id == - 1 ) {
113
- wordList.add(Entry (w,1 , getType(w)))
119
+ wordList.add(Entry (w, 1 , getType(w)))
114
120
word_hash_2_id[h] = size++
115
121
} else {
116
122
wordList[id].count++
@@ -162,7 +168,6 @@ class Dictionary(private val args: Args) {
162
168
}
163
169
164
170
165
-
166
171
/* *
167
172
* 读取分析原始语料,语料单词直接空格
168
173
*
@@ -186,8 +191,7 @@ class Dictionary(private val args: Args) {
186
191
file.useLines { lines ->
187
192
lines.filterNot { it.isNullOrBlank() || it.startsWith(" #" ) }
188
193
.forEach { line ->
189
- splitter.split(line).forEach {
190
- token ->
194
+ splitter.split(line).forEach { token ->
191
195
add(token)
192
196
if (ntokens % 1000000 == 0L && args.verbose > 1 ) {
193
197
print (" \r Read " + ntokens / 1000000 + " M words" )
@@ -253,11 +257,14 @@ class Dictionary(private val args: Args) {
253
257
continue
254
258
}
255
259
256
- val ngram = StringBuilder ()
260
+ var ngram: StringBuilder ? = null
257
261
258
262
var j = i
259
263
var n = 1
260
264
while (j < word_len && n <= maxn) {
265
+ if (ngram == null ) {
266
+ ngram = StringBuilder ()
267
+ }
261
268
ngram.append(word[j++ ])
262
269
while (j < word.length && charMatches(word[j])) {
263
270
ngram.append(word[j++ ])
@@ -297,7 +304,7 @@ class Dictionary(private val args: Args) {
297
304
val t = args.t
298
305
for (i in 0 until size) {
299
306
val f = wordList[i].count * 1.0f / ntokens
300
- pdiscard[i] = (Math .sqrt(t / f) + t/ f).toFloat()
307
+ pdiscard[i] = (Math .sqrt(t / f) + t / f).toFloat()
301
308
}
302
309
}
303
310
@@ -307,8 +314,8 @@ class Dictionary(private val args: Args) {
307
314
.toMutableList()
308
315
(wordList as java.util.ArrayList <Entry >).trimToSize()
309
316
310
- size= 0
311
- nwords= 0
317
+ size = 0
318
+ nwords = 0
312
319
nlabels = 0
313
320
314
321
word_hash_2_id.fill(- 1 )
@@ -318,7 +325,7 @@ class Dictionary(private val args: Args) {
318
325
word_hash_2_id[h] = size++
319
326
if (it.type == EntryType .word) {
320
327
nwords++
321
- }else if (it.type == EntryType .label) {
328
+ } else if (it.type == EntryType .label) {
322
329
nlabels++
323
330
}
324
331
}
@@ -355,7 +362,6 @@ class Dictionary(private val args: Args) {
355
362
val h = stringHash(token)
356
363
val wid = getId(token, h)
357
364
val type = if (wid < 0 ) getType(token) else getType(wid)
358
-
359
365
ntokens++
360
366
361
367
if (type == EntryType .word) {
@@ -371,21 +377,60 @@ class Dictionary(private val args: Args) {
371
377
return ntokens
372
378
}
373
379
380
+ companion object {
381
+ val coeff = UnsignedLong .valueOf(116049371L )
382
+ val U64_START = UnsignedLong .valueOf(" 18446744069414584320" )
383
+ }
374
384
375
385
private fun addWordNgrams (line : IntArrayList ,
376
386
hashes : LongArrayList ,
377
387
n : Int ) {
378
- for (i in 0 until hashes.size()) {
379
- var h = hashes.get(i)
388
+ // read word^ hash 3675003649 int32 -619963647 uint64 18446744073089587969 wid 1
389
+ // for (i in 0 until hashes.size()) {
390
+ // var h = hashes.get(i)
391
+ // var j = i + 1
392
+ // while (j < hashes.size() && j < i + n) {
393
+ // h = (h * 116049371) + hashes.get(j)
394
+ // pushHash(line, (h % bucket).toInt())
395
+ // j++
396
+ // }
397
+ // }
398
+ // AddWordNgramsHelper.addWordNGrams(line,hashes,n,bucket.toLong(),{x->
399
+ // pushHash(h)
400
+ // })
401
+ val hashSize = hashes.size()
402
+
403
+ for (i in 0 until hashSize) {
404
+ var h = toUnsignedLong64(hashes.get(i))
380
405
var j = i + 1
381
- while (j < hashes.size() && j < i + n) {
382
- h = h * 116049371 + hashes.get(j)
383
- pushHash(line, (h % bucket).toInt())
406
+ while (j < hashSize && j < i + n) {
407
+ // val h2 = hashes.get(j)
408
+ val h2 = hashes.get(j).toInt().toLong()
409
+
410
+ if (h2 >= 0 ) {
411
+ h = h.times(coeff).plus(UnsignedLong .valueOf(h2))
412
+ } else {
413
+ h = h.times(coeff).minus(UnsignedLong .valueOf(- h2))
414
+ }
415
+ var id = h.mod(UnsignedLong .valueOf(bucketLong)).toInt()
416
+
417
+ pushHash(line ,id)
384
418
j++
385
419
}
386
420
}
387
421
}
388
422
423
+ // from https://github.com/linkfluence/fastText4j/blob/b018438e84bebd20f89a701c35f022139418930c/src/main/java/fasttext/BaseDictionary.java
424
+
425
+
426
+ private fun toUnsignedLong64 (l : Long ): UnsignedLong {
427
+ return if (l > Integer .MAX_VALUE ) {
428
+ U64_START .plus(UnsignedLong .valueOf(l))
429
+ } else {
430
+ UnsignedLong .valueOf(l)
431
+ }
432
+ }
433
+
389
434
private fun addSubwords (line : IntArrayList ,
390
435
token : String ,
391
436
wid : Int ) {
@@ -482,30 +527,30 @@ class Dictionary(private val args: Args) {
482
527
channel.writeLong(ntokens)
483
528
channel.writeLong(pruneidxSize)
484
529
485
- val buffer = ByteBuffer .allocate(1024 * 1024 )
486
- val em = buffer.capacity()* 0.25f
530
+ val buffer = ByteBuffer .allocate(1024 * 1024 )
531
+ val em = buffer.capacity() * 0.25f
487
532
for (entry in wordList) {
488
533
buffer.writeUTF(entry.word)
489
534
buffer.putLong(entry.count)
490
535
buffer.put(entry.type.value.toByte())
491
536
492
- if (buffer.remaining() < em){
493
- buffer.flip()
494
- while (buffer.hasRemaining()) {
495
- channel.write(buffer)
496
- }
497
- buffer.clear()
498
- }
537
+ if (buffer.remaining() < em) {
538
+ buffer.flip()
539
+ while (buffer.hasRemaining()) {
540
+ channel.write(buffer)
541
+ }
542
+ buffer.clear()
543
+ }
499
544
}
500
545
501
546
buffer.flip()
502
547
while (buffer.hasRemaining()) {
503
548
channel.write(buffer)
504
549
}
505
550
506
- val buffer2 = ByteBuffer .allocate(pruneidx.size()* 4 )
551
+ val buffer2 = ByteBuffer .allocate(pruneidx.size() * 4 )
507
552
pruneidx.forEach {
508
- buffer2.putInt(it.key,it.value)
553
+ buffer2.putInt(it.key, it.value)
509
554
}
510
555
buffer2.flip()
511
556
channel.write(buffer2)
@@ -514,7 +559,7 @@ class Dictionary(private val args: Args) {
514
559
515
560
516
561
@Throws(IOException ::class )
517
- fun load (buffer : AutoDataInput ) : Dictionary {
562
+ fun load (buffer : AutoDataInput ): Dictionary {
518
563
// wordList.clear();
519
564
// word2int_.clear();
520
565
@@ -530,7 +575,7 @@ class Dictionary(private val args: Args) {
530
575
// size 189997 18万的词汇
531
576
// val byteArray = ByteArray(1024)
532
577
for (i in 0 until size) {
533
- val e = Entry (buffer.readUTF(),buffer.readLong(),EntryType .fromValue(buffer.readUnsignedByte().toInt()))
578
+ val e = Entry (buffer.readUTF(), buffer.readLong(), EntryType .fromValue(buffer.readUnsignedByte().toInt()))
534
579
wordList.add(e)
535
580
word_hash_2_id[find(e.word)] = i
536
581
}
@@ -558,6 +603,7 @@ class Dictionary(private val args: Args) {
558
603
*/
559
604
560
605
val Empty_IntArrayList = IntArrayList (0 )
606
+
561
607
data class Entry (
562
608
val word : String ,
563
609
var count : Long ,
0 commit comments