4
4
from __future__ import unicode_literals
5
5
6
6
import numpy as np
7
- import time
8
7
import unittest
9
8
10
9
# Must happen before importing caffe2.python.*
11
10
import caffe2 .python .fakelowp .init_shared_libs # noqa
12
11
13
- from hypothesis import given , settings
12
+ from hypothesis import given
14
13
from hypothesis import strategies as st
15
14
from caffe2 .proto import caffe2_pb2
16
- from caffe2 .python import core , workspace , dyndep
15
+ from caffe2 .python import core , workspace
17
16
from caffe2 .python .onnx .onnxifi import onnxifi_caffe2_net
18
- from caffe2 .python .onnx .tests .test_utils import TestCase
19
17
from caffe2 .python .fakelowp .test_utils import print_test_debug_info
20
18
21
19
workspace .GlobalInit (["caffe2" , "--glow_global_fp16=1" ,
22
20
"--glow_global_fused_scale_offset_fp16=1" ,
23
21
"--glow_global_force_sls_fp16_accum=1" ])
24
22
25
23
26
- class SparseLengthsSumTest (unittest .TestCase ):
24
+ class SparseLengthsSum4BitFakeNNPIFp16Test (unittest .TestCase ):
27
25
@given (seed = st .integers (0 , 65535 ))
28
26
def test_slws_fused_4bit_rowwise_all_same (self , seed ):
29
27
np .random .seed (seed )
30
28
workspace .ResetWorkspace ()
31
29
n = 1
32
30
m = 2
33
31
data = np .ones ((n , m )).astype (np .float32 ) * 0.2 - 0.1
34
-
35
32
max_segments = 5
36
33
max_segment_length = 100
37
34
num_lengths = np .random .randint (1 , max_segments + 1 )
@@ -43,7 +40,6 @@ def test_slws_fused_4bit_rowwise_all_same(self, seed):
43
40
weights = np .random .uniform (low = - 0.5 , high = 0.5 ,
44
41
size = [len (indices )]).astype (np .float32 )
45
42
weights = np .ones (len (indices )).astype (np .float32 )
46
-
47
43
pred_net = caffe2_pb2 .NetDef ()
48
44
pred_net .name = "pred"
49
45
pred_net .external_input .extend (
@@ -56,7 +52,6 @@ def test_slws_fused_4bit_rowwise_all_same(self, seed):
56
52
["Y" ],
57
53
)
58
54
)
59
-
60
55
ref_net = caffe2_pb2 .NetDef ()
61
56
ref_net .name = "ref"
62
57
ref_net .external_input .extend (
@@ -69,7 +64,6 @@ def test_slws_fused_4bit_rowwise_all_same(self, seed):
69
64
["Y" ],
70
65
)
71
66
)
72
-
73
67
workspace .FeedBlob ("data" , data )
74
68
workspace .RunOperatorOnce (
75
69
core .CreateOperator (
@@ -78,7 +72,6 @@ def test_slws_fused_4bit_rowwise_all_same(self, seed):
78
72
['quantized_data' ]
79
73
)
80
74
)
81
-
82
75
print ("quantized" , workspace .FetchBlob ("quantized_data" ))
83
76
pred_net_onnxified = onnxifi_caffe2_net (
84
77
pred_net ,
@@ -89,24 +82,18 @@ def test_slws_fused_4bit_rowwise_all_same(self, seed):
89
82
adjust_batch = True ,
90
83
use_onnx = False
91
84
)
92
-
93
85
num_onnxified_ops = sum (
94
86
1 if o .type == "Onnxifi" else 0 for o in pred_net_onnxified .op )
95
87
np .testing .assert_equal (num_onnxified_ops , 1 )
96
-
97
88
workspace .FeedBlob ("indices" , indices )
98
89
workspace .FeedBlob ("lengths" , lengths )
99
90
workspace .FeedBlob ("weights" , weights )
100
-
101
91
workspace .CreateNet (pred_net_onnxified )
102
92
workspace .CreateNet (ref_net )
103
-
104
93
workspace .RunNet (pred_net_onnxified .name )
105
94
Y_glow = workspace .FetchBlob ('Y' )
106
-
107
95
workspace .RunNet (ref_net .name )
108
96
Y_c2 = workspace .FetchBlob ('Y' )
109
-
110
97
if not np .allclose (Y_c2 , Y_glow ):
111
98
print_test_debug_info (
112
99
"slws_fused_4bit_rowwise" ,
@@ -121,33 +108,35 @@ def test_slws_fused_4bit_rowwise_all_same(self, seed):
121
108
"rowwise_diff" : (Y_glow - Y_c2 )[:, 0 ]})
122
109
assert (0 )
123
110
124
- @given (seed = st .integers (0 , 65535 ))
125
- def test_slws_fused_4bit_rowwise (self , seed ):
126
- np .random .seed (seed )
111
+
112
+ @given (
113
+ seed = st .integers (0 , 65535 ),
114
+ num_rows = st .integers (2 , 20 ),
115
+ embedding_dim = st .sampled_from ([8 , 12 , 16 , 24 , 32 , 54 , 64 , 128 ]),
116
+ batch_size = st .integers (1 , 5 ),
117
+ max_weight = st .integers (0 , 100 ),
118
+ )
119
+ def test_slws_fused_4bit_rowwise (self , seed , num_rows , embedding_dim , batch_size , max_weight ):
127
120
workspace .ResetWorkspace ()
121
+ np .random .seed (seed )
122
+ data = np .random .rand (num_rows , embedding_dim ).astype (np .float32 )
123
+ lengths = np .random .choice (np .arange (1 , num_rows ), batch_size ).astype (np .int32 )
128
124
129
- n = 20000
130
- DIM = 6
131
- data = (4 * np .random .random_sample ((n , DIM )) + 1 ).astype (np .float32 )
125
+ indices = []
126
+ for length in lengths :
127
+ indices .extend (np .random .choice (np .arange (1 , num_rows ), length ))
128
+ indices = np .asarray (indices ).astype (np .int64 )
132
129
133
- max_segments = 200
134
- max_segment_length = 200
135
- num_lengths = np .random .randint (0 , max_segments + 1 )
136
- # number of segments to run
137
- lengths = np .random .randint (2 , max_segment_length + 1 , size = num_lengths ).astype (
138
- np .int32
139
- )
140
- num_indices = np .sum (lengths )
141
- indices = np .random .randint (low = 0 , high = n , size = num_indices , dtype = np .int64 )
142
- weights = np .random .uniform (low = 0.01 , high = 0.5 , size = [len (indices )]).astype (
143
- np .float32
144
- )
130
+ weights = np .random .uniform (
131
+ low = 0 ,
132
+ high = max_weight ,
133
+ size = [len (indices )]
134
+ ).astype (np .float32 )
145
135
146
136
pred_net = caffe2_pb2 .NetDef ()
147
137
pred_net .name = "pred"
148
138
pred_net .external_input .extend (
149
- ["quantized_data" , "weights" , "indices" , "lengths" ]
150
- )
139
+ ["quantized_data" , "weights" , "indices" , "lengths" ])
151
140
pred_net .external_output .append ("Y" )
152
141
pred_net .op .add ().CopyFrom (
153
142
core .CreateOperator (
@@ -160,8 +149,7 @@ def test_slws_fused_4bit_rowwise(self, seed):
160
149
ref_net = caffe2_pb2 .NetDef ()
161
150
ref_net .name = "ref"
162
151
ref_net .external_input .extend (
163
- ["quantized_data" , "weights" , "indices" , "lengths" ]
164
- )
152
+ ["quantized_data" , "weights" , "indices" , "lengths" ])
165
153
ref_net .external_output .append ("Y" )
166
154
ref_net .op .add ().CopyFrom (
167
155
core .CreateOperator (
@@ -174,49 +162,52 @@ def test_slws_fused_4bit_rowwise(self, seed):
174
162
workspace .FeedBlob ("data" , data )
175
163
workspace .RunOperatorOnce (
176
164
core .CreateOperator (
177
- "FloatToFused4BitRowwiseQuantized" , ["data" ], ["quantized_data" ]
165
+ "FloatToFused4BitRowwiseQuantized" ,
166
+ ["data" ],
167
+ ["quantized_data" ]
178
168
)
179
169
)
180
- onnxified_net = onnxifi_caffe2_net (
170
+
171
+ pred_net_onnxified = onnxifi_caffe2_net (
181
172
pred_net ,
182
173
{},
183
- max_batch_size = max_segments ,
184
- max_seq_size = max_segments * max_segment_length ,
174
+ max_batch_size = batch_size ,
175
+ max_seq_size = batch_size * np . max ( lengths ) ,
185
176
debug = True ,
186
177
adjust_batch = True ,
187
- use_onnx = False ,
178
+ use_onnx = False
188
179
)
180
+
181
+ num_onnxified_ops = sum (
182
+ 1 if o .type == "Onnxifi" else 0 for o in pred_net_onnxified .op )
183
+ np .testing .assert_equal (num_onnxified_ops , 1 )
184
+
189
185
workspace .FeedBlob ("indices" , indices )
190
186
workspace .FeedBlob ("lengths" , lengths )
191
187
workspace .FeedBlob ("weights" , weights )
192
188
193
- workspace .CreateNet (onnxified_net )
189
+ workspace .CreateNet (pred_net_onnxified )
194
190
workspace .CreateNet (ref_net )
195
191
196
- workspace .RunNet (onnxified_net .name )
197
- Y_glow = workspace .FetchBlob ("Y" )
192
+ workspace .RunNet (pred_net_onnxified .name )
193
+ Y_glow = workspace .FetchBlob ('Y' )
198
194
199
195
workspace .RunNet (ref_net .name )
200
- Y_ref = workspace .FetchBlob ("Y" )
196
+ Y_c2 = workspace .FetchBlob ('Y' )
201
197
202
- diff = np .abs ((Y_ref - Y_glow ) / (Y_ref + 1e-8 ))
203
- max_err = np .max (diff , axis = 1 )
204
- num_offenders = (max_err > 0 ).sum ()
205
- if num_offenders > 0 :
198
+ if not np .allclose (Y_c2 , Y_glow ):
206
199
print_test_debug_info (
207
- "slws_fused_4bit" ,
208
- {
209
- "indices" : indices ,
210
- "data" : data .shape ,
211
- "lengths" : lengths ,
212
- "weights" : weights ,
213
- "Y_glow" : Y_glow ,
214
- "Y_ref" : Y_ref ,
215
- "diff" : diff ,
216
- "rowwise_diff" : np .max (diff , axis = 1 ),
217
- },
218
- )
219
- assert 0
200
+ "slws_fused_4bit_rowwise" ,
201
+ {"seed" : seed ,
202
+ "indices" : indices ,
203
+ "data" : data ,
204
+ "lengths" : lengths ,
205
+ "weights" : weights ,
206
+ "Y_c2" : Y_c2 ,
207
+ "Y_glow" : Y_glow ,
208
+ "diff" : Y_glow - Y_c2 ,
209
+ "rowwise_diff" : (Y_glow - Y_c2 )[:, 0 ]})
210
+ assert (0 )
220
211
221
212
if __name__ == '__main__' :
222
213
unittest .main ()
0 commit comments