Skip to content

Commit 7e401c0

Browse files
dskhudia authored and facebook-github-bot committed
part2: Move embedding quantization kernels to fbgemm for better sharing between C2/PT (#425)
Summary: Pull Request resolved: #425 8bit with float scale and bias. Test and benchmark added. ``` With scale and bias as float bit_rate, rows, cols, elems_per_usec, GB/Sec 8, 100, 16, 556.20, 2.22 8, 100, 64, 1022.51, 4.09 8, 100, 128, 1121.43, 4.49 8, 100, 256, 1292.61, 5.17 8, 100, 512, 1526.69, 6.11 8, 100, 1024, 1407.09, 5.63 8, 100, 2048, 1620.34, 6.48 8, 120, 16, 562.60, 2.25 8, 120, 64, 1058.52, 4.23 8, 120, 128, 1082.74, 4.33 8, 120, 256, 1382.87, 5.53 8, 120, 512, 1513.15, 6.05 8, 120, 1024, 1441.19, 5.76 8, 120, 2048, 1634.99, 6.54 8, 1000, 16, 598.05, 2.39 8, 1000, 64, 1151.16, 4.60 8, 1000, 128, 1071.58, 4.29 8, 1000, 256, 1278.66, 5.11 8, 1000, 512, 1441.13, 5.76 8, 1000, 1024, 1605.48, 6.42 8, 1000, 2048, 1764.24, 7.06 ``` Reviewed By: supriyar Differential Revision: D23455486 fbshipit-source-id: e0dea307c42d614747302544a7179fa40194dad6
1 parent 1289a1f commit 7e401c0

File tree

6 files changed

+264
-10
lines changed

6 files changed

+264
-10
lines changed

bench/EmbeddingQuantizeBenchmark.cc

Lines changed: 37 additions & 9 deletions
Original file line number · Diff line number · Diff line change
@@ -8,6 +8,7 @@
88
#include <initializer_list>
99
#include <iomanip>
1010
#include <iostream>
11+
#include <vector>
1112

1213
#ifdef _OPENMP
1314
#include <omp.h>
@@ -20,33 +21,57 @@
2021
using namespace std;
2122
using namespace fbgemm;
2223

24+
// T is the type of scale and bias
25+
template <typename T>
2326
void performance_test() {
2427
constexpr int NWARMUP = 4;
2528
constexpr int NITER = 256;
2629

30+
if (is_same<T, float16>::value) {
31+
cout << "With scale and bias as float16" << endl;
32+
} else {
33+
cout << "With scale and bias as float" << endl;
34+
}
2735
cout << setw(8) << "bit_rate"
2836
<< ", " << setw(6) << "rows"
2937
<< "," << setw(6) << "cols"
3038
<< "," << setw(16) << "elems_per_usec"
3139
<< "," << setw(10) << "GB/Sec" << endl;
32-
for (int bit_rate : {2, 4, 8}) {
40+
std::vector<int> bit_rates;
41+
if (is_same<T, float16>::value) {
42+
bit_rates = {2, 4, 8};
43+
} else {
44+
// float
45+
bit_rates = {8};
46+
}
47+
for (int bit_rate : bit_rates) {
3348
for (int rowSize : {100, 120, 1000}) {
3449
for (int colSize : {16, 64, 128, 256, 512, 1024, 2048}) {
3550
aligned_vector<float> inpVec(rowSize * colSize);
3651
randFill<float>(inpVec, -10.0f, 10.0f);
3752

38-
int elements_per_byte = 8 / bit_rate;
39-
int out_emb_cols =
40-
(colSize + elements_per_byte - 1) / elements_per_byte;
53+
int out_emb_cols = colSize;
54+
55+
if (is_same<T, float16>::value) {
56+
int elements_per_byte = 8 / bit_rate;
57+
out_emb_cols = (colSize + elements_per_byte - 1) / elements_per_byte;
58+
}
4159
int outVecSize = rowSize * (out_emb_cols + 2 * sizeof(float16));
4260
aligned_vector<uint8_t> outVec(outVecSize);
4361

4462
double duration = 0.0f;
4563

4664
duration = measureWithWarmup(
4765
[&]() {
48-
FloatToFusedNBitRowwiseQuantizedSBHalf(
49-
bit_rate, inpVec.data(), rowSize, colSize, outVec.data());
66+
is_same<T, float16>::value
67+
? FloatToFusedNBitRowwiseQuantizedSBHalf(
68+
bit_rate,
69+
inpVec.data(),
70+
rowSize,
71+
colSize,
72+
outVec.data())
73+
: FloatToFused8BitRowwiseQuantizedSBFloat(
74+
inpVec.data(), rowSize, colSize, outVec.data());
5075
},
5176
NWARMUP,
5277
NITER,
@@ -63,8 +88,10 @@ void performance_test() {
6388

6489
cout << setw(8) << bit_rate << "," << setw(6) << rowSize << ", "
6590
<< setw(6) << colSize << ",";
66-
cout << setw(16) << std::fixed << std::setprecision(2) << elements_per_usec << ", ";
67-
cout << setw(10) << std::fixed << std::setprecision(2) << gigabyes_per_sec << endl;
91+
cout << setw(16) << std::fixed << std::setprecision(2)
92+
<< elements_per_usec << ", ";
93+
cout << setw(10) << std::fixed << std::setprecision(2)
94+
<< gigabyes_per_sec << endl;
6895
} // for each cols
6996
} // for each rows
7097
} // for each bit_rate
@@ -78,6 +105,7 @@ int main() {
78105
omp_set_num_threads(1);
79106
}
80107
#endif
81-
performance_test();
108+
performance_test<float16>();
109+
performance_test<float>();
82110
return 0;
83111
}

include/fbgemm/QuantUtils.h

Lines changed: 27 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -274,4 +274,31 @@ FBGEMM_API void FloatToFusedNBitRowwiseQuantizedSBHalfRef(
274274
int input_rows,
275275
int input_columns,
276276
std::uint8_t* output);
277+
278+
/**
279+
* Convert float inputs to rowwise quantized (8-bit) outputs.
280+
* Scale and Bias are in float. Each row's Scale and Bias are stored in
281+
* the row itself (fused) at the end.
282+
*
283+
* This version intentionally supports only 8-bit because we want to discourage
284+
* the usage of float scale and bias with 2 and 4 bit cases as that diminishes
285+
* the overall memory savings.
286+
*
287+
*/
288+
FBGEMM_API void FloatToFused8BitRowwiseQuantizedSBFloat(
289+
const float* input,
290+
int input_rows,
291+
int input_columns,
292+
std::uint8_t* output);
293+
294+
/**
295+
* Same as FloatToFused8BitRowwiseQuantizedSBFloat but unoptimized.
296+
* This should not be called directly except in testing.
297+
*/
298+
FBGEMM_API void FloatToFused8BitRowwiseQuantizedSBFloatRef(
299+
const float* input,
300+
int input_rows,
301+
int input_columns,
302+
std::uint8_t* output);
303+
277304
} // namespace fbgemm

include/fbgemm/QuantUtilsAvx2.h

Lines changed: 6 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -126,4 +126,10 @@ void FloatToFusedNBitRowwiseQuantizedSBHalfAvx2(
126126
int input_columns,
127127
std::uint8_t* output);
128128

129+
void FloatToFused8BitRowwiseQuantizedSBFloatAvx2(
130+
const float* input,
131+
int input_rows,
132+
int input_columns,
133+
std::uint8_t* output);
134+
129135
} // namespace fbgemm

src/QuantUtils.cc

Lines changed: 44 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -555,4 +555,48 @@ void FloatToFusedNBitRowwiseQuantizedSBHalf(
555555
}
556556
}
557557

558+
void FloatToFused8BitRowwiseQuantizedSBFloatRef(
559+
const float* input,
560+
int input_rows,
561+
int input_columns,
562+
std::uint8_t* output) {
563+
constexpr float kEpsilon = 1e-8f;
564+
565+
int output_columns = input_columns + 2 * sizeof(float);
566+
for (std::size_t row = 0; row < input_rows; ++row) {
567+
const float* input_row = input + row * input_columns;
568+
std::uint8_t* output_row = output + row * output_columns;
569+
float* output_row_scale_bias =
570+
reinterpret_cast<float*>(output_row + input_columns);
571+
572+
float minimum_element =
573+
*std::min_element(input_row, input_row + input_columns);
574+
float maximum_element =
575+
*std::max_element(input_row, input_row + input_columns);
576+
float range = maximum_element - minimum_element;
577+
578+
output_row_scale_bias[0] = range / 255.0f;
579+
output_row_scale_bias[1] = minimum_element;
580+
const auto inverse_scale = 255.0f / (range + kEpsilon);
581+
for (std::size_t col = 0; col < input_columns; ++col) {
582+
output_row[col] =
583+
std::lrintf((input_row[col] - minimum_element) * inverse_scale);
584+
}
585+
}
586+
}
587+
588+
void FloatToFused8BitRowwiseQuantizedSBFloat(
589+
const float* input,
590+
int input_rows,
591+
int input_columns,
592+
std::uint8_t* output) {
593+
if (cpuinfo_initialize() && fbgemmHasAvx2Support()) {
594+
FloatToFused8BitRowwiseQuantizedSBFloatAvx2(
595+
input, input_rows, input_columns, output);
596+
} else {
597+
FloatToFused8BitRowwiseQuantizedSBFloatRef(
598+
input, input_rows, input_columns, output);
599+
}
600+
}
601+
558602
} // namespace fbgemm

src/QuantUtilsAvx2.cc

Lines changed: 101 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -1622,4 +1622,105 @@ template void FloatToFusedNBitRowwiseQuantizedSBHalfAvx2<8>(
16221622
int input_columns,
16231623
std::uint8_t* output);
16241624

1625+
void FloatToFused8BitRowwiseQuantizedSBFloatAvx2(
1626+
const float* input,
1627+
int input_rows,
1628+
int input_columns,
1629+
std::uint8_t* output) {
1630+
constexpr int VLEN = 8;
1631+
constexpr float kEpsilon = 1e-8f;
1632+
1633+
__m256i permute_mask1_v =
1634+
_mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
1635+
// clang-format off
1636+
__m256i shuffle_mask_v = _mm256_set_epi8(
1637+
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
1638+
0xff, 0xff, 0xff, 0xff, 0x0c, 0x08, 0x04, 0x00,
1639+
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
1640+
0xff, 0xff, 0xff, 0xff, 0x0c, 0x08, 0x04, 0x00);
1641+
// clang-format on
1642+
1643+
__m256i permute_mask2_v =
1644+
_mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);
1645+
1646+
int output_columns = input_columns + 2 * sizeof(float);
1647+
for (std::size_t row = 0; row < input_rows; ++row) {
1648+
const float* input_row = input + row * input_columns;
1649+
std::uint8_t* output_row = output + row * output_columns;
1650+
float* output_row_scale_bias =
1651+
reinterpret_cast<float*>(output_row + input_columns);
1652+
1653+
float minimum_element = FLT_MAX;
1654+
float maximum_element = -FLT_MAX;
1655+
__m256 min_v = _mm256_set1_ps(minimum_element);
1656+
__m256 max_v = _mm256_set1_ps(maximum_element);
1657+
std::size_t col;
1658+
for (col = 0; col < input_columns / VLEN * VLEN; col += VLEN) {
1659+
__m256 in_v = _mm256_loadu_ps(input_row + col);
1660+
min_v = _mm256_min_ps(min_v, in_v);
1661+
max_v = _mm256_max_ps(max_v, in_v);
1662+
}
1663+
alignas(64) float min_buf[VLEN], max_buf[VLEN];
1664+
_mm256_store_ps(min_buf, min_v);
1665+
_mm256_store_ps(max_buf, max_v);
1666+
for (int i = 0; i < VLEN; ++i) {
1667+
minimum_element = std::min(minimum_element, min_buf[i]);
1668+
maximum_element = std::max(maximum_element, max_buf[i]);
1669+
}
1670+
for (; col < input_columns; ++col) {
1671+
minimum_element = std::min(minimum_element, input_row[col]);
1672+
maximum_element = std::max(maximum_element, input_row[col]);
1673+
}
1674+
1675+
float range = maximum_element - minimum_element;
1676+
1677+
output_row_scale_bias[0] = range / 255.0f;
1678+
output_row_scale_bias[1] = minimum_element;
1679+
const auto inverse_scale = 255.0f / (range + kEpsilon);
1680+
min_v = _mm256_set1_ps(minimum_element);
1681+
__m256 inverse_scale_v = _mm256_set1_ps(inverse_scale);
1682+
1683+
for (col = 0; col < input_columns / (4 * VLEN) * (4 * VLEN);
1684+
col += 4 * VLEN) {
1685+
__m256i x_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
1686+
_mm256_sub_ps(_mm256_loadu_ps(input_row + col), min_v),
1687+
inverse_scale_v));
1688+
__m256i y_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
1689+
_mm256_sub_ps(_mm256_loadu_ps(input_row + col + VLEN), min_v),
1690+
inverse_scale_v));
1691+
__m256i z_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
1692+
_mm256_sub_ps(_mm256_loadu_ps(input_row + col + 2 * VLEN), min_v),
1693+
inverse_scale_v));
1694+
__m256i w_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
1695+
_mm256_sub_ps(_mm256_loadu_ps(input_row + col + 3 * VLEN), min_v),
1696+
inverse_scale_v));
1697+
1698+
// An instruction sequence to save 32 32-bit integers as 8-bit integers
1699+
__m256i xy_packed_v = _mm256_packs_epi32(x_rounded_v, y_rounded_v);
1700+
__m256i zw_packed_v = _mm256_packs_epi32(z_rounded_v, w_rounded_v);
1701+
__m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
1702+
xyzw_packed_v =
1703+
_mm256_permutevar8x32_epi32(xyzw_packed_v, permute_mask1_v);
1704+
_mm256_storeu_si256(
1705+
reinterpret_cast<__m256i*>(output_row + col), xyzw_packed_v);
1706+
}
1707+
for (; col < input_columns / VLEN * VLEN; col += VLEN) {
1708+
__m256i rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
1709+
_mm256_sub_ps(_mm256_loadu_ps(input_row + col), min_v),
1710+
inverse_scale_v));
1711+
1712+
// An instruction sequence to save 8 32-bit integers as 8-bit integers
1713+
rounded_v = _mm256_shuffle_epi8(rounded_v, shuffle_mask_v);
1714+
rounded_v = _mm256_permutevar8x32_epi32(rounded_v, permute_mask2_v);
1715+
_mm_storel_epi64(
1716+
reinterpret_cast<__m128i*>(output_row + col),
1717+
_mm256_castsi256_si128(rounded_v));
1718+
}
1719+
for (; col < input_columns; ++col) {
1720+
output_row[col] =
1721+
std::lrintf((input_row[col] - minimum_element) * inverse_scale);
1722+
}
1723+
}
1724+
}
1725+
16251726
} // namespace fbgemm

test/QuantUtilsTest.cc

Lines changed: 49 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -34,6 +34,11 @@ class FusedQuantizeDequantizeTest : public testing::TestWithParam<int> {};
3434
class EmbeddingQuantizeTest
3535
: public testing::TestWithParam<tuple<int, int, int>> {};
3636

37+
// Parameter are input rows and input columns
38+
// Scale and Bias are of type float (SBFloat)
39+
class EmbeddingQuantizeSBFloatTest
40+
: public testing::TestWithParam<tuple<int, int>> {};
41+
3742
INSTANTIATE_TEST_CASE_P(
3843
InstantiationName,
3944
QuantizeGroupwiseTest,
@@ -57,6 +62,13 @@ INSTANTIATE_TEST_CASE_P(
5762
::testing::ValuesIn({1, 2, 3}),
5863
::testing::ValuesIn({1, 2, 5, 8, 9, 16, 20, 28, 32, 33, 64, 65})));
5964

65+
INSTANTIATE_TEST_CASE_P(
66+
InstantiationName,
67+
EmbeddingQuantizeSBFloatTest,
68+
::testing::Combine(
69+
::testing::ValuesIn({1, 2, 3}),
70+
::testing::ValuesIn({1, 2, 5, 8, 9, 16, 20, 28, 32, 33, 64, 65})));
71+
6072
template <typename T, layout_t LT>
6173
void ref_impl(
6274
const vector<float>& src,
@@ -185,7 +197,14 @@ ::testing::AssertionResult isQEmbeddingClose(
185197
res_ref.data() + i * ld + out_emb_cols)[1]);
186198
} else {
187199
// float scale and bias
188-
// TODO:
200+
scaleTest = reinterpret_cast<const float*>(
201+
res.data() + i * ld + out_emb_cols)[0];
202+
biasTest = reinterpret_cast<const float*>(
203+
res.data() + i * ld + out_emb_cols)[1];
204+
scaleRef = reinterpret_cast<const float*>(
205+
res_ref.data() + i * ld + out_emb_cols)[0];
206+
biasRef = reinterpret_cast<const float*>(
207+
res_ref.data() + i * ld + out_emb_cols)[1];
189208
}
190209
if (fabs(scaleTest - scaleRef) > std::numeric_limits<float>::epsilon()) {
191210
ss << " scale mismatch for row:" << i;
@@ -548,3 +567,32 @@ TEST_P(EmbeddingQuantizeTest, embeddingHalfTest) {
548567

549568
EXPECT_TRUE(isQEmbeddingClose<float16>(outVecTest, outVecRef, rows, out_emb_cols));
550569
}
570+
571+
TEST_P(EmbeddingQuantizeSBFloatTest, embeddingFloatTest) {
572+
int rows, cols;
573+
tie(rows, cols) = GetParam();
574+
575+
random_device rd;
576+
mt19937 gen(rd());
577+
578+
uniform_real_distribution<float> disFP(-10.0f, 10.0f);
579+
580+
vector<float> inpVec(rows * cols);
581+
582+
generate(inpVec.begin(), inpVec.end(), [&, disFP]() mutable {
583+
return disFP(gen);
584+
});
585+
586+
int outVecSize = rows * (cols + 2 * sizeof(float));
587+
588+
vector<uint8_t> outVecRef(outVecSize);
589+
vector<uint8_t> outVecTest(outVecSize);
590+
591+
FloatToFused8BitRowwiseQuantizedSBFloatRef(
592+
inpVec.data(), rows, cols, outVecRef.data());
593+
FloatToFused8BitRowwiseQuantizedSBFloat(
594+
inpVec.data(), rows, cols, outVecTest.data());
595+
596+
// The number of input columns is the same as the number of output columns
597+
EXPECT_TRUE(isQEmbeddingClose<float>(outVecTest, outVecRef, rows, cols));
598+
}

0 commit comments

Comments
 (0)