@@ -50,11 +50,23 @@ void dequantize(float *code, unsigned char *A, float *out, int n)
5050 CUDA_CHECK_RETURN (cudaPeekAtLastError ());
5151}
5252
/**
 * Host-side launcher for blockwise quantization.
 *
 * Selects the kQuantizeBlockwise instantiation that matches `blocksize` and
 * launches it with one CUDA block per `blocksize` input elements (the last
 * block covers the remainder when `n` is not a multiple of `blocksize`).
 *
 * Supported configurations (blocksize -> threads per CUDA block):
 *   4096 -> 1024, 2048 -> 512, 1024 -> 256, 512 -> 256.
 * The stochastic variant (STOCHASTIC == 1) is only instantiated for
 * blocksize == 4096; every other block size uses the non-stochastic kernel.
 *
 * @param code        forwarded unchanged to the kernel
 * @param A           input elements of type T, forwarded to the kernel
 * @param absmax      forwarded unchanged to the kernel
 * @param out         output bytes, forwarded to the kernel
 * @param rand        forwarded unchanged to the kernel
 * @param rand_offset forwarded unchanged to the kernel
 * @param blocksize   number of elements handled per CUDA block; must be one of
 *                    4096, 2048, 1024 or 512
 * @param n           total number of elements; nothing is launched when n <= 0
 *
 * An unsupported `blocksize` (or STOCHASTIC == 1 with blocksize != 4096) is
 * reported through CUDA_CHECK_RETURN as cudaErrorInvalidValue instead of
 * silently launching nothing or the wrong kernel.
 * NOTE(review): CUDA_CHECK_RETURN is defined outside this view; it is assumed
 * to accept any cudaError_t expression - confirm.
 */
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n)
{
  // A launch with zero blocks is an invalid configuration, and there is
  // nothing to quantize anyway.
  if (n <= 0)
    return;

  // Validate before dividing by blocksize (blocksize == 0 would be UB) and
  // with a runtime check, because assert() is compiled out under NDEBUG.
  bool supported = blocksize == 4096;
  if (STOCHASTIC != 1)
    supported = supported || blocksize == 2048 || blocksize == 1024 || blocksize == 512;

  if (!supported)
  {
    CUDA_CHECK_RETURN(cudaErrorInvalidValue);
    return; // in case the macro does not terminate the process
  }

  // Ceil-division written without `n + blocksize - 1` to avoid int overflow
  // for n close to INT_MAX.
  const int num_blocks = n / blocksize + (n % blocksize == 0 ? 0 : 1);

  // The third template argument is presumably the number of items each thread
  // processes (blocksize / threads per block) - TODO confirm against the kernel.
  switch (blocksize)
  {
    case 4096:
      kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
      break;
    case 2048:
      kQuantizeBlockwise<T, 2048, 4, 0><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
      break;
    case 1024:
      kQuantizeBlockwise<T, 1024, 4, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
      break;
    case 512:
      kQuantizeBlockwise<T, 512, 2, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
      break;
  }

  // Surface launch-configuration errors from the launch above.
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
6072
@@ -66,6 +78,11 @@ template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, flo
6678 kDequantizeBlockwise <T, 4096 , 1024 , 4 ><<<num_blocks, 4096 /4 >>> (code, A, absmax, out, n);
6779 else if (blocksize == 2048 )
6880 kDequantizeBlockwise <T, 2048 , 512 , 4 ><<<num_blocks, 2048 /4 >>> (code, A, absmax, out, n);
81+ else if (blocksize == 1024 )
82+ kDequantizeBlockwise <T, 1024 , 256 , 4 ><<<num_blocks, 1024 /4 >>> (code, A, absmax, out, n);
83+ else if (blocksize == 512 )
84+ kDequantizeBlockwise <T, 512 , 256 , 2 ><<<num_blocks, 512 /2 >>> (code, A, absmax, out, n);
85+
6986 CUDA_CHECK_RETURN (cudaPeekAtLastError ());
7087}
7188
@@ -659,10 +676,10 @@ template void transformRowToFormat<COL_AMPERE, 1>(char * A, char *out, int rows,
// Explicit template instantiations: these force the compiler to emit the
// listed specializations in this translation unit.
template void estimateQuantiles(half *A, float *code, float offset, int n);
template void estimateQuantiles(float *A, float *code, float offset, int n);

// quantizeBlockwise<T, STOCHASTIC>: one instantiation per element type
// (half, float) and per STOCHASTIC value (0, 1). The signature carries the
// runtime `blocksize` argument; instantiations without it would no longer
// match the template.
template void quantizeBlockwise<half, 0>(float * code, half *A, float *absmax, unsigned char *out, float * rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 0>(float * code, float *A, float *absmax, unsigned char *out, float * rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<half, 1>(float * code, half *A, float *absmax, unsigned char *out, float * rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 1>(float * code, float *A, float *absmax, unsigned char *out, float * rand, int rand_offset, int blocksize, const int n);

// dequantizeBlockwise<T>: T is the output element type (half or float).
template void dequantizeBlockwise<half>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
template void dequantizeBlockwise<float>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
668685
0 commit comments