# fmt: off


import numpy as np
import torch
import pytest
import triton
import triton.language as tl

def matching_int(dtype):
    if dtype.primitive_bitwidth == 8:
        return torch.int8
    elif dtype.primitive_bitwidth == 16:
        return torch.int16
    elif dtype.primitive_bitwidth == 32:
        return torch.int32
    elif dtype.primitive_bitwidth == 64:
        return torch.int64
    else:
        raise ValueError('unsupported number of bits')

@triton.jit
def type_convert_triton(src, dst, rounding : tl.constexpr, BLOCK_SIZE : tl.constexpr):

    idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    x = tl.load(src + idxs)
    y = x.to(dst.dtype.element_ty, fp_downcast_rounding=rounding)
    tl.store(dst + idxs, y)


def launch_type_convert_triton(src, src_dtype, dst_dtype, rounding=None, BLOCK_SIZE=4096):

    dst = torch.empty(src.shape, dtype=matching_int(dst_dtype), device='cuda')
    type_convert_triton[(src.shape[0] // BLOCK_SIZE,)](triton.reinterpret(src, src_dtype), triton.reinterpret(dst, dst_dtype), rounding, BLOCK_SIZE)
    return dst

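# Illustrative usage (a sketch only, assuming a CUDA device; not exercised by
# the tests below):
#
#   bits = torch.randint(-2**15, 2**15, (4096,), dtype=torch.int16, device='cuda')
#   out  = launch_type_convert_triton(bits, tl.float16, tl.float32)
#
# `bits` carries raw float16 bit patterns; `out` is an int32 tensor carrying
# the corresponding float32 bit patterns (matching_int picks the integer
# "carrier" dtype, and triton.reinterpret gives the kernel the floating-point
# view).  fp16 -> fp32 is exact, so `rounding` stays None; a rounding mode is
# only needed for downcasts.
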

@triton.jit
def exhaustive_populate(dst, offset, BLOCK_SIZE : tl.constexpr, force_odd : tl.constexpr, output_bits : tl.constexpr, max_repr : tl.constexpr):

    idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    vals = (idxs + offset).to(tl.uint32)

    # pseudorandom permutation:
    multiplier = vals << 1
    multiplier += 3511
    vals *= multiplier

    if force_odd:
        vals *= 2
        vals += 1

    if (output_bits == 8):
        vals &= 0xff
        avals = vals & 0x7f
    elif (output_bits == 16):
        vals &= 0xffff
        avals = vals & 0x7fff
    elif (output_bits == 32):
        avals = vals & 0x7fffffff

    vals = tl.where(avals <= max_repr, vals, 0)

    if (output_bits == 8):
        vals = vals.to(tl.uint8)
    elif (output_bits == 16):
        vals = vals.to(tl.uint16)

    vals = vals.to(dst.dtype.element_ty, bitcast=True)
    tl.store(dst + idxs, vals)

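# Note on the permutation above: vals * (2*vals + 3511) is a bijection on
# uint32, since f(x) - f(y) = (x - y) * (2*(x + y) + 3511) and the second
# factor is odd, hence invertible mod 2**32.  For a 32-bit output this means
# that sweeping `offset` over the 256 blocks of 2**24 indices (as the float32
# downcast tests do) produces every float32 bit pattern exactly once before
# the max_repr filter zeroes out the unwanted ones.
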

def launch_exhaustive_populate(dst_dtype, offset, numel, force_odd, output_bits, max_repr, BLOCK_SIZE=4096):

    assert(numel % BLOCK_SIZE == 0)
    dst = torch.empty((numel,), dtype=matching_int(dst_dtype), device='cuda')
    exhaustive_populate[(numel // BLOCK_SIZE,)](triton.reinterpret(dst, dst_dtype), offset, BLOCK_SIZE, force_odd, output_bits, max_repr)
    return dst


@triton.jit
def arbitrary_fp32_downcast(x, rounding : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr):

    tl.static_assert(x.dtype == tl.float32, "input must be float32")
    numbits_dst : tl.constexpr = 1 + exponent_bits + mantissa_bits
    tl.static_assert((numbits_dst == 8) or (numbits_dst == 16), "numbits_dst must be 8 or 16")

    x = x.to(tl.uint32, bitcast=True)

    mantissa = (x & 0x7fffff)
    exponent = ((x >> 23) & 0xff).to(tl.int32)
    mantissa = tl.where(exponent == 0, mantissa, mantissa + 0x800000).to(tl.int32)
    exponent = tl.where(exponent == 0, exponent, exponent - 1)

    sign = (x >> 31)

    exponent = exponent + exponent_bias - 127
    adjustment : tl.constexpr = 0.5 ** (23 - mantissa_bits)
    mantissa = mantissa.to(tl.float32) * adjustment

    # make exponent nonnegative:
    mantissa = tl.where(exponent > -16, mantissa, 0.0)  # destination has fewer than 16 mantissa bits, so safe
    exponent = tl.where(exponent > -16, exponent, 0)
    mantissa = tl.where(exponent > -8, mantissa, mantissa * 0.00390625)
    exponent = tl.where(exponent > -8, exponent, exponent + 8)
    mantissa = tl.where(exponent > -4, mantissa, mantissa * 0.0625)
    exponent = tl.where(exponent > -4, exponent, exponent + 4)
    mantissa = tl.where(exponent > -2, mantissa, mantissa * 0.25)
    exponent = tl.where(exponent > -2, exponent, exponent + 2)
    mantissa = tl.where(exponent > -1, mantissa, mantissa * 0.5)
    exponent = tl.where(exponent > -1, exponent, exponent + 1)

    if rounding == 'rtne':
        mantissa = tl.inline_asm_elementwise("""{
            cvt.rni.s32.f32 $0, $1;
        }""", "=r,r", [mantissa,], dtype=tl.int32, is_pure=True, pack=1).to(tl.uint32)
    elif rounding == 'rtz':
        mantissa = tl.inline_asm_elementwise("""{
            cvt.rzi.s32.f32 $0, $1;
        }""", "=r,r", [mantissa,], dtype=tl.int32, is_pure=True, pack=1).to(tl.uint32)
    else:
        raise ValueError('unrecognized rounding mode')

    # Reassemble output floating-point representation:
    exponent = exponent.to(tl.uint32)
    y = (sign << (exponent_bits + mantissa_bits)) + (exponent << mantissa_bits) + mantissa
    if numbits_dst == 8:
        y = y.to(tl.uint8)
    elif numbits_dst == 16:
        y = y.to(tl.uint16)
    return y

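# Worked example (illustrative) for the float8e5 / E5M2 parameters
# exponent_bits=5, mantissa_bits=2, exponent_bias=15 and x = 1.0f (bits
# 0x3f800000):
#   mantissa = 0 + 0x800000 (the implicit leading one is made explicit),
#   exponent = 127 - 1 = 126 to compensate for that extra bit,
#   exponent = 126 + 15 - 127 = 14,  mantissa = 0x800000 * 2**-21 = 4.0 -> 4,
#   y = (0 << 7) + (14 << 2) + 4 = 0x3c, the E5M2 encoding of 1.0.
# The mantissa's leading bit deliberately carries into the exponent field
# during reassembly, undoing the earlier "exponent - 1".
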

@triton.jit
def downcast_emulated(src, dst, rounding : tl.constexpr, BLOCK_SIZE : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr):

    tl.static_assert(src.dtype.element_ty == tl.float32, "src dtype must be float32")

    idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(src + idxs)
    y = arbitrary_fp32_downcast(x, rounding, exponent_bits, mantissa_bits, exponent_bias)
    y = y.to(dst.dtype.element_ty, bitcast=True)
    tl.store(dst + idxs, y)


def launch_downcast_emulated(src, src_dtype, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias, BLOCK_SIZE=4096):

    dst = torch.empty(src.shape, dtype=matching_int(dst_dtype), device='cuda')
    downcast_emulated[(src.shape[0] // BLOCK_SIZE,)](
        triton.reinterpret(src, src_dtype), triton.reinterpret(dst, dst_dtype), rounding, BLOCK_SIZE, exponent_bits, mantissa_bits, exponent_bias)
    return dst


@triton.jit
def upcast_emulated(src, dst, BLOCK_SIZE : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr):

    exponent_compensator : tl.constexpr = 2.0 ** (127 - exponent_bias)

    numbits_src : tl.constexpr = 1 + exponent_bits + mantissa_bits
    tl.static_assert((numbits_src == 8) or (numbits_src == 16), "numbits_src must be 8 or 16")
    tl.static_assert(dst.dtype.element_ty == tl.float32, "dst dtype must be float32")

    idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    x = tl.load(src + idxs)

    if numbits_src == 8:
        x = x.to(tl.uint8, bitcast=True)
    elif numbits_src == 16:
        x = x.to(tl.uint16, bitcast=True)

    x = x.to(tl.uint32)

    mantissa_mask : tl.constexpr = (1 << mantissa_bits) - 1
    exponent_mask : tl.constexpr = (1 << exponent_bits) - 1

    mantissa = x & mantissa_mask
    exponent = (x >> mantissa_bits) & exponent_mask
    sign = (x >> (numbits_src - 1))

    y = (sign << 31) | (exponent << 23) | (mantissa << (23 - mantissa_bits))
    y = y.to(tl.float32, bitcast=True)
    y = y * exponent_compensator

    tl.store(dst + idxs, y)

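# The emulated upcast plants the source's raw exponent field directly into the
# float32 exponent slot, which implicitly reinterprets it with float32's bias
# of 127 instead of exponent_bias; exponent_compensator = 2**(127 -
# exponent_bias) multiplies that bias difference back out.  For example, E5M2
# 1.0 (0x3c, exponent field 15) reassembles to the float32 pattern 15 << 23,
# i.e. 2**-112, and multiplying by 2**(127 - 15) = 2**112 restores 1.0.
# Subnormal inputs land in float32's subnormal range, which scales linearly,
# so they are recovered exactly as well.
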

def launch_upcast_emulated(src, exponent_bits, mantissa_bits, exponent_bias, BLOCK_SIZE=4096):

    dst = torch.empty(src.shape, dtype=torch.int32, device='cuda')
    upcast_emulated[(src.shape[0] // BLOCK_SIZE,)](src, triton.reinterpret(dst, tl.float32), BLOCK_SIZE, exponent_bits, mantissa_bits, exponent_bias)
    return dst


def downcast_test(src_dtype, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias, max_repr, offset):

    src = launch_exhaustive_populate(src_dtype, offset << 24, 2**24, False, src_dtype.primitive_bitwidth, max_repr)
    dst = launch_type_convert_triton(src, src_dtype, dst_dtype, rounding)
    src = launch_type_convert_triton(src, src_dtype, tl.float32)

    dst2 = launch_downcast_emulated(src, tl.float32, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias)

    dst = launch_upcast_emulated(dst, exponent_bits, mantissa_bits, exponent_bias)
    dst2 = launch_upcast_emulated(dst2, exponent_bits, mantissa_bits, exponent_bias)

    if not (torch.equal(dst, dst2)):

        print('Error!!!')

        dst = dst.cpu().detach().numpy()
        dst2 = dst2.cpu().detach().numpy()
        src = src.cpu().detach().numpy()

        print(src[dst != dst2][0])
        print(dst[dst != dst2][0])
        print(dst2[dst != dst2][0])
        print(hex(src.view(np.uint32)[dst != dst2][0]))
        print(hex(dst.view(np.uint32)[dst != dst2][0]))
        print(hex(dst2.view(np.uint32)[dst != dst2][0]))
        print('')
        raise ValueError('%d elements mismatch' % (dst != dst2).sum())

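# downcast_test cross-checks the hardware/compiler conversion against the
# bit-level emulation: the sampled source values go through `.to(...)` with the
# requested rounding and through arbitrary_fp32_downcast, both results are
# upcast back to float32 by the exact emulated upcast, and the resulting int32
# bit patterns must match element for element.
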

def upcast_test(src_dtype, dst_dtype, exponent_bits, mantissa_bits, exponent_bias, max_repr):

    numbits_src = exponent_bits + mantissa_bits + 1

    src = launch_exhaustive_populate(src_dtype, 0, 65536, False, numbits_src, max_repr)

    dst = launch_type_convert_triton(src, src_dtype, dst_dtype)
    dst = launch_type_convert_triton(dst, dst_dtype, tl.float32)

    dst2 = launch_upcast_emulated(src, exponent_bits, mantissa_bits, exponent_bias)

    assert(torch.equal(dst, dst2))

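# In both tests, max_repr is the largest magnitude bit pattern (sign bit
# cleared) that exhaustive_populate is allowed to keep; anything above it
# (Inf, NaN, and, in the downcast tests, finite values that would overflow the
# destination type) is replaced with zero.  For example, 0x7b is 57344.0 in
# float8e5 (its largest finite value), and 0x47600000 is the same 57344.0
# written as float32 bits.
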

@pytest.mark.parametrize("src_dtype, dst_dtype", [
    ('float16', 'float32'),
    ('bfloat16', 'float32'),

    ('float8e5', 'float16'),
    ('float8e5', 'bfloat16'),
    ('float8e5', 'float32'),

    ('float8e4b15', 'float16'),
    # ('float8e4b15', 'bfloat16'),  # Unsupported conversion from f8E4M3B11FNUZ to bf16
    ('float8e4b15', 'float32'),

    ('float8e4nv', 'float16'),
    ('float8e4nv', 'bfloat16'),
    ('float8e4nv', 'float32'),
])
def test_typeconvert_upcast(src_dtype, dst_dtype):

    if src_dtype == 'float8e4nv' and torch.cuda.get_device_capability(0) < (9, 0):
        pytest.skip("float8e4nv upcast tests only supported on compute capability 9.0+")

    # dtype : (exponent_bits, mantissa_bits, exponent_bias, max_repr)
    stuff = {
        'float8e4b15': (4, 3, 15, 0x7e),
        'float8e4nv': (4, 3, 7, 0x7e),
        'float8e5': (5, 2, 15, 0x7b),
        'float16': (5, 10, 15, 0x7bff),
        'bfloat16': (8, 7, 127, 0x7f7f),
    }[src_dtype]

    upcast_test(getattr(tl, src_dtype), getattr(tl, dst_dtype), *stuff)

@pytest.mark.parametrize("src_dtype, dst_dtype, rounding, max_repr", [
    ('float32', 'float16', 'rtne', 0x477fe000),
    ('float32', 'float16', 'rtz', 0x477fe000),
    ('float32', 'bfloat16', 'rtne', 0x7f7f0000),
    ('float32', 'bfloat16', 'rtz', 0x7f7f0000),
    ('float32', 'float8e5', 'rtne', 0x47600000),
    ('float32', 'float8e5', 'rtz', 0x47600000),
    ('float32', 'float8e4nv', 'rtne', 0x43e00000),
    # ('float32', 'float8e4b15', 'rtne', 0x3fe00000),  # Skip, no HW rtne conversion from f32 to f8e4b15

    ('bfloat16', 'float8e5', 'rtne', 0x4760),
    ('bfloat16', 'float8e4nv', 'rtne', 0x43e0),

    ('float16', 'float8e5', 'rtne', 0x7b00),
    ('float16', 'float8e4nv', 'rtne', 0x5f00),
])
def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr):

    if src_dtype != 'float32' and torch.cuda.get_device_capability(0) < (9, 0):
        pytest.skip("non-float32 downcast tests only supported on compute capability 9.0+")

    if dst_dtype.startswith('float8') and rounding == 'rtne' and torch.cuda.get_device_capability(0) < (9, 0):
        pytest.skip("float8 downcast with RTNE rounding tests only supported on compute capability 9.0+")

    # dtype : (exponent_bits, mantissa_bits, exponent_bias)
    stuff = {
        'float16': (5, 10, 15),
        'bfloat16': (8, 7, 127),
        'float8e5': (5, 2, 15),
        'float8e4b15': (4, 3, 15),
        'float8e4nv': (4, 3, 7),
    }[dst_dtype]

    for i in range(256):
        downcast_test(getattr(tl, src_dtype), getattr(tl, dst_dtype), rounding, *stuff, max_repr, i)