
Commit cd20496

Conversion tests implementation (triton-lang#2808)
Integrates @apgoucher's work on testing floating-point conversions in Triton. My additions are small, mainly support for different rounding modes. This is the second attempt, this time using Triton types. Co-authored-by: apgoucher <[email protected]>

1 parent bd73596

File tree

2 files changed: +305 -0 lines

python/test/unit/conversion_tests.py

Lines changed: 305 additions & 0 deletions (the other changed file contains whitespace-only changes)

@@ -0,0 +1,305 @@
# fmt: off

import numpy as np
import torch
import pytest
import triton
import triton.language as tl


def matching_int(dtype):
    if dtype.primitive_bitwidth == 8:
        return torch.int8
    elif dtype.primitive_bitwidth == 16:
        return torch.int16
    elif dtype.primitive_bitwidth == 32:
        return torch.int32
    elif dtype.primitive_bitwidth == 64:
        return torch.int64
    else:
        raise ValueError('unsupported number of bits')

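Not part of the commit: matching_int exists because the host tensors only carry raw bit patterns, which are allocated as same-width integers and then reinterpreted as the floating-point type under test. A minimal sketch of that pattern, assuming a CUDA device is available:

    # Hypothetical usage: back bfloat16 data with int16 storage, then
    # reinterpret the buffer as bfloat16 when handing it to a kernel.
    storage = torch.empty((4096,), dtype=matching_int(tl.bfloat16), device='cuda')
    as_bf16 = triton.reinterpret(storage, tl.bfloat16)
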
@triton.jit
def type_convert_triton(src, dst, rounding : tl.constexpr, BLOCK_SIZE : tl.constexpr):

    idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    x = tl.load(src + idxs)
    y = x.to(dst.dtype.element_ty, fp_downcast_rounding=rounding)
    tl.store(dst + idxs, y)


def launch_type_convert_triton(src, src_dtype, dst_dtype, rounding=None, BLOCK_SIZE=4096):

    dst = torch.empty(src.shape, dtype=matching_int(dst_dtype), device='cuda')
    type_convert_triton[(src.shape[0] // BLOCK_SIZE,)](triton.reinterpret(src, src_dtype), triton.reinterpret(dst, dst_dtype), rounding, BLOCK_SIZE)
    return dst

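Not part of the commit: a sketch of how this launcher is invoked, assuming a float32 input tensor on CUDA:

    # Hypothetical call: downcast float32 to float16 with round-to-nearest-even;
    # the result comes back as the raw int16 bit patterns of the float16 values.
    x = torch.randn(4096, device='cuda')
    f16_bits = launch_type_convert_triton(x, tl.float32, tl.float16, rounding='rtne')
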
@triton.jit
def exhaustive_populate(dst, offset, BLOCK_SIZE : tl.constexpr, force_odd : tl.constexpr, output_bits : tl.constexpr, max_repr : tl.constexpr):

    idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    vals = (idxs + offset).to(tl.uint32)

    # pseudorandom permutation:
    multiplier = vals << 1
    multiplier += 3511
    vals *= multiplier

    if force_odd:
        vals *= 2
        vals += 1

    if (output_bits == 8):
        vals &= 0xff
        avals = vals & 0x7f
    elif (output_bits == 16):
        vals &= 0xffff
        avals = vals & 0x7fff
    elif (output_bits == 32):
        avals = vals & 0x7fffffff

    # zero out bit patterns whose magnitude exceeds the largest representable value:
    vals = tl.where(avals <= max_repr, vals, 0)

    if (output_bits == 8):
        vals = vals.to(tl.uint8)
    elif (output_bits == 16):
        vals = vals.to(tl.uint16)

    vals = vals.to(dst.dtype.element_ty, bitcast=True)
    tl.store(dst + idxs, vals)


def launch_exhaustive_populate(dst_dtype, offset, numel, force_odd, output_bits, max_repr, BLOCK_SIZE=4096):

    assert numel % BLOCK_SIZE == 0
    dst = torch.empty((numel,), dtype=matching_int(dst_dtype), device='cuda')
    exhaustive_populate[(numel // BLOCK_SIZE,)](triton.reinterpret(dst, dst_dtype), offset, BLOCK_SIZE, force_odd, output_bits, max_repr)
    return dst

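Not part of the commit: the "pseudorandom permutation" above computes v * (2v + 3511) mod 2**32, a polynomial that is bijective over Z/2**k (odd linear coefficient, even quadratic coefficient). A quick host-side bijectivity check, shrunk to a 16-bit modulus as an assumption for speed:

    import numpy as np
    v = np.arange(2**16, dtype=np.uint32)
    perm = (v * (2 * v + 3511)) & 0xffff    # same polynomial, reduced mod 2**16
    assert len(np.unique(perm)) == 2**16    # every output hit exactly once
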
@triton.jit
def arbitrary_fp32_downcast(x, rounding : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr):

    tl.static_assert(x.dtype == tl.float32, "input must be float32")
    numbits_dst : tl.constexpr = 1 + exponent_bits + mantissa_bits
    tl.static_assert((numbits_dst == 8) or (numbits_dst == 16), "numbits_dst must be 8 or 16")

    x = x.to(tl.uint32, bitcast=True)

    mantissa = (x & 0x7fffff)
    exponent = ((x >> 23) & 0xff).to(tl.int32)
    mantissa = tl.where(exponent == 0, mantissa, mantissa + 0x800000).to(tl.int32)
    exponent = tl.where(exponent == 0, exponent, exponent - 1)

    sign = (x >> 31)

    exponent = exponent + exponent_bias - 127
    adjustment : tl.constexpr = 0.5 ** (23 - mantissa_bits)
    mantissa = mantissa.to(tl.float32) * adjustment

    # make exponent nonnegative:
    mantissa = tl.where(exponent > -16, mantissa, 0.0)  # destination has fewer than 16 mantissa bits, so safe
    exponent = tl.where(exponent > -16, exponent, 0)
    mantissa = tl.where(exponent > -8, mantissa, mantissa * 0.00390625)
    exponent = tl.where(exponent > -8, exponent, exponent + 8)
    mantissa = tl.where(exponent > -4, mantissa, mantissa * 0.0625)
    exponent = tl.where(exponent > -4, exponent, exponent + 4)
    mantissa = tl.where(exponent > -2, mantissa, mantissa * 0.25)
    exponent = tl.where(exponent > -2, exponent, exponent + 2)
    mantissa = tl.where(exponent > -1, mantissa, mantissa * 0.5)
    exponent = tl.where(exponent > -1, exponent, exponent + 1)

    if rounding == 'rtne':
        mantissa = tl.inline_asm_elementwise("""{
    cvt.rni.s32.f32 $0, $1;
}""", "=r,r", [mantissa,], dtype=tl.int32, is_pure=True, pack=1).to(tl.uint32)
    elif rounding == 'rtz':
        mantissa = tl.inline_asm_elementwise("""{
    cvt.rzi.s32.f32 $0, $1;
}""", "=r,r", [mantissa,], dtype=tl.int32, is_pure=True, pack=1).to(tl.uint32)
    else:
        raise ValueError('unrecognized rounding mode')

    # reassemble output floating-point representation:
    exponent = exponent.to(tl.uint32)
    y = (sign << (exponent_bits + mantissa_bits)) + (exponent << mantissa_bits) + mantissa
    if numbits_dst == 8:
        y = y.to(tl.uint8)
    elif numbits_dst == 16:
        y = y.to(tl.uint16)
    return y

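Not part of the commit: a worked instance of the reassembly step, assuming the float8e4nv layout used later in the tests (1 sign bit, 4 exponent bits, 3 mantissa bits, bias 7):

    # 3.0 = 1.5 * 2**1, so biased exponent = 1 + 7 = 8 and mantissa = 0b100.
    sign, exponent, mantissa = 0, 8, 0b100
    y = (sign << 7) | (exponent << 3) | mantissa
    assert y == 0x44    # the float8e4nv bit pattern of 3.0
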
@triton.jit
def downcast_emulated(src, dst, rounding : tl.constexpr, BLOCK_SIZE : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr):

    tl.static_assert(src.dtype.element_ty == tl.float32, "src dtype must be float32")

    idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(src + idxs)
    y = arbitrary_fp32_downcast(x, rounding, exponent_bits, mantissa_bits, exponent_bias)
    y = y.to(dst.dtype.element_ty, bitcast=True)
    tl.store(dst + idxs, y)


def launch_downcast_emulated(src, src_dtype, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias, BLOCK_SIZE=4096):

    dst = torch.empty(src.shape, dtype=matching_int(dst_dtype), device='cuda')
    downcast_emulated[(src.shape[0] // BLOCK_SIZE,)](
        triton.reinterpret(src, src_dtype), triton.reinterpret(dst, dst_dtype), rounding, BLOCK_SIZE, exponent_bits, mantissa_bits, exponent_bias)
    return dst

@triton.jit
def upcast_emulated(src, dst, BLOCK_SIZE : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr):

    exponent_compensator : tl.constexpr = 2.0 ** (127 - exponent_bias)

    numbits_src : tl.constexpr = 1 + exponent_bits + mantissa_bits
    tl.static_assert((numbits_src == 8) or (numbits_src == 16), "numbits_src must be 8 or 16")
    tl.static_assert(dst.dtype.element_ty == tl.float32, "dst dtype must be float32")

    idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    x = tl.load(src + idxs)

    if numbits_src == 8:
        x = x.to(tl.uint8, bitcast=True)
    elif numbits_src == 16:
        x = x.to(tl.uint16, bitcast=True)

    x = x.to(tl.uint32)

    mantissa_mask : tl.constexpr = (1 << mantissa_bits) - 1
    exponent_mask : tl.constexpr = (1 << exponent_bits) - 1

    mantissa = x & mantissa_mask
    exponent = (x >> mantissa_bits) & exponent_mask
    sign = (x >> (numbits_src - 1))

    y = (sign << 31) | (exponent << 23) | (mantissa << (23 - mantissa_bits))
    y = y.to(tl.float32, bitcast=True)
    y = y * exponent_compensator

    tl.store(dst + idxs, y)


def launch_upcast_emulated(src, exponent_bits, mantissa_bits, exponent_bias, BLOCK_SIZE=4096):

    dst = torch.empty(src.shape, dtype=torch.int32, device='cuda')
    upcast_emulated[(src.shape[0] // BLOCK_SIZE,)](src, triton.reinterpret(dst, tl.float32), BLOCK_SIZE, exponent_bits, mantissa_bits, exponent_bias)
    return dst

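Not part of the commit: the exponent_compensator factor can be checked by hand for float8e5 (bias 15). Dropping the raw 5-bit exponent straight into the float32 exponent field implicitly re-biases it by 127, so multiplying by 2**(127 - 15) = 2**112 restores the intended value:

    raw = 0x3c                    # float8e5 encoding of 1.0
    exponent = (raw >> 2) & 0x1f  # biased exponent 15, mantissa 0
    assert 2.0 ** (exponent - 127) * 2.0 ** 112 == 1.0
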
def downcast_test(src_dtype, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias, max_repr, offset):

    src = launch_exhaustive_populate(src_dtype, offset << 24, 2**24, False, src_dtype.primitive_bitwidth, max_repr)
    dst = launch_type_convert_triton(src, src_dtype, dst_dtype, rounding)
    src = launch_type_convert_triton(src, src_dtype, tl.float32)

    dst2 = launch_downcast_emulated(src, tl.float32, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias)

    # route both results through the same emulated upcast so the hardware and
    # emulated downcasts can be compared bit-exactly as float32:
    dst = launch_upcast_emulated(dst, exponent_bits, mantissa_bits, exponent_bias)
    dst2 = launch_upcast_emulated(dst2, exponent_bits, mantissa_bits, exponent_bias)

    if not torch.equal(dst, dst2):
        print('Error!!!')

        dst = dst.cpu().detach().numpy()
        dst2 = dst2.cpu().detach().numpy()
        src = src.cpu().detach().numpy()

        print(src[dst != dst2][0])
        print(dst[dst != dst2][0])
        print(dst2[dst != dst2][0])
        print(hex(src.view(np.uint32)[dst != dst2][0]))
        print(hex(dst.view(np.uint32)[dst != dst2][0]))
        print(hex(dst2.view(np.uint32)[dst != dst2][0]))
        print('')
        raise ValueError('%d elements mismatch' % (dst != dst2).sum())

def upcast_test(src_dtype, dst_dtype, exponent_bits, mantissa_bits, exponent_bias, max_repr):

    numbits_src = exponent_bits + mantissa_bits + 1

    src = launch_exhaustive_populate(src_dtype, 0, 65536, False, numbits_src, max_repr)

    dst = launch_type_convert_triton(src, src_dtype, dst_dtype)
    dst = launch_type_convert_triton(dst, dst_dtype, tl.float32)

    dst2 = launch_upcast_emulated(src, exponent_bits, mantissa_bits, exponent_bias)

    assert torch.equal(dst, dst2)

@pytest.mark.parametrize("src_dtype, dst_dtype", [
    ('float16', 'float32'),
    ('bfloat16', 'float32'),

    ('float8e5', 'float16'),
    ('float8e5', 'bfloat16'),
    ('float8e5', 'float32'),

    ('float8e4b15', 'float16'),
    # ('float8e4b15', 'bfloat16'), # Unsupported conversion from f8E4M3B11FNUZ to bf16
    ('float8e4b15', 'float32'),

    ('float8e4nv', 'float16'),
    ('float8e4nv', 'bfloat16'),
    ('float8e4nv', 'float32'),
])
def test_typeconvert_upcast(src_dtype, dst_dtype):

    if src_dtype == 'float8e4nv' and torch.cuda.get_device_capability(0) < (9, 0):
        pytest.skip("float8e4nv upcast tests only supported on compute capability 9.0+")

    # dtype : (exponent_bits, mantissa_bits, exponent_bias, max_repr)
    stuff = {
        'float8e4b15': (4, 3, 15, 0x7e),
        'float8e4nv': (4, 3, 7, 0x7e),
        'float8e5': (5, 2, 15, 0x7b),
        'float16': (5, 10, 15, 0x7bff),
        'bfloat16': (8, 7, 127, 0x7f7f),
    }[src_dtype]

    upcast_test(getattr(tl, src_dtype), getattr(tl, dst_dtype), *stuff)

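Not part of the commit: for the IEEE-like formats in the table above, max_repr (the largest finite bit pattern) follows directly from the field widths, because the all-ones exponent is reserved for inf/NaN. The two float8e4 variants deviate, reserving at most one bit pattern for NaN, which is why both list 0x7e. A sketch for float16:

    exponent_bits, mantissa_bits = 5, 10
    max_biased_exponent = (1 << exponent_bits) - 2   # all-ones exponent means inf/NaN
    max_mantissa = (1 << mantissa_bits) - 1
    max_repr = (max_biased_exponent << mantissa_bits) | max_mantissa
    assert max_repr == 0x7bff                        # matches the float16 entry
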
@pytest.mark.parametrize("src_dtype, dst_dtype, rounding, max_repr", [
    ('float32', 'float16', 'rtne', 0x477fe000),
    ('float32', 'float16', 'rtz', 0x477fe000),
    ('float32', 'bfloat16', 'rtne', 0x7f7f0000),
    ('float32', 'bfloat16', 'rtz', 0x7f7f0000),
    ('float32', 'float8e5', 'rtne', 0x47600000),
    ('float32', 'float8e5', 'rtz', 0x47600000),
    ('float32', 'float8e4nv', 'rtne', 0x43e00000),
    # ('float32', 'float8e4b15', 'rtne', 0x3fe00000), # Skip, no HW rtne conversion from f32 to f8e4b15

    ('bfloat16', 'float8e5', 'rtne', 0x4760),
    ('bfloat16', 'float8e4nv', 'rtne', 0x43e0),

    ('float16', 'float8e5', 'rtne', 0x7b00),
    ('float16', 'float8e4nv', 'rtne', 0x5f00),
])
def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr):

    if src_dtype != 'float32' and torch.cuda.get_device_capability(0) < (9, 0):
        pytest.skip("non-float32 downcast tests only supported on compute capability 9.0+")

    if dst_dtype.startswith('float8') and rounding == 'rtne' and torch.cuda.get_device_capability(0) < (9, 0):
        pytest.skip("float8 downcast with RTNE rounding tests only supported on compute capability 9.0+")

    # dtype : (exponent_bits, mantissa_bits, exponent_bias)
    stuff = {
        'float16': (5, 10, 15),
        'bfloat16': (8, 7, 127),
        'float8e5': (5, 2, 15),
        'float8e4b15': (4, 3, 15),
        'float8e4nv': (4, 3, 7),
    }[dst_dtype]

    for i in range(256):
        downcast_test(getattr(tl, src_dtype), getattr(tl, dst_dtype), rounding, *stuff, max_repr, i)
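
Not part of the commit: with a CUDA GPU available, the new tests can be run directly, e.g. `pytest python/test/unit/conversion_tests.py -v`. Note that each downcast case sweeps all 2**32 float32 bit patterns (256 chunks of 2**24 values, one per loop iteration above), so the full suite takes a while.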
