Superior 3INST parameters #26

Open: louiehelm wants to merge 1 commit into master

Conversation

@louiehelm commented Apr 28, 2025

Summary: Swapped out original LCG params in 3INST for a better, computationally screened, optimal MCG multiplier

0.2-3.9% lower PPL
1.25% faster processing (addition removed)

Details:

3INST's default LCG params are knowably suboptimal:
x = 89226354 * x + 64248484

The replacement multiplier 4,286,219,237 (0xFF7A83E5) was screened to satisfy the following properties, which the original 89,226,354 (0x5517C72) does not all meet:

  • Odd
  • > sqrt(2^32)
  • > 2^32/2 = 2147483648
  • Congruent to 5 (mod 8)
  • High figure of merit
  • "Spectrally good" in higher dimensions
  • Maximum potency

See "Computationally easy, spectrally good multipliers for congruential pseudorandom number generators" (Steele & Vigna) for deeper discussion.

Demanding that LCGs have good spectral properties is not always necessary, but for a 256-dimension Trellis decoder, spectral quality in higher dimensions is actually a key performance requirement.
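For concreteness, here is a minimal sketch of the 3INST decode step (illustrative only, not the exact EXL3 kernel; the mask/xor constants are the ones quoted later in this thread), showing the original LCG variant next to the proposed MCG variant, which drops the addition:

// Sketch of the 3INST decode step: original LCG vs. proposed MCG (addition removed).
#include <cuda_fp16.h>

__device__ __forceinline__ __half decode_3inst_lcg(uint32_t x)
{
    x = 89226354u * x + 64248484u;           // original LCG step
    x &= 0x8FFF8FFFu;                        // mask
    x ^= 0x3B603B60u;                        // xor (mask + xor can fuse into one LOP3)
    return __hadd(__ushort_as_half((unsigned short)(x & 0xFFFFu)),   // reinterpret the two
                  __ushort_as_half((unsigned short)(x >> 16)));      // 16-bit halves as fp16, add
}

__device__ __forceinline__ __half decode_3inst_mcg(uint32_t x)
{
    x *= 0xFF7A83E5u;                        // MCG step: multiply only, no addition
    x &= 0x8FFF8FFFu;
    x ^= 0x3B603B60u;
    return __hadd(__ushort_as_half((unsigned short)(x & 0xFFFFu)),
                  __ushort_as_half((unsigned short)(x >> 16)));
}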

Screened all 45458 spectrally good LCG multipliers and all 49507 spectrally good MCG multipliers uploaded to GitHub by @vigna.

Many LCG multipliers had strictly lower RMS in the decoder but none had universally superior PPL when actually quantizing and decoding real models.

But several MCG multipliers had both lower RMS in a vacuum AND universally better PPL on several models at multiple bitrates. The best one appears to be 0xFF7A83E5.

[image: 0xFF7A83E5 codebook]

Performance Comparison of Multiplier 0xFF7A83E5 Across Models

| Metric | Mistral-7B-Instruct | Llama-3.2-1B-Instruct |
| --- | --- | --- |
| **Original baseline (89226354)** | | |
| PPL @ 2.0bpw | 6.897755 | 23.382405 |
| PPL @ 3.0bpw | 5.962067 | 15.637272 |
| PPL @ 4.0bpw | 5.821786 | 14.540853 |
| **With multiplier 0xFF7A83E5** | | |
| PPL @ 2.0bpw | 6.881336 | 22.464554 |
| PPL @ 3.0bpw | 5.957563 | 15.410229 |
| PPL @ 4.0bpw | 5.820617 | 14.468759 |
| **Improvement (%)** | | |
| @ 2.0bpw | 0.238% | 3.93% |
| @ 3.0bpw | 0.076% | 1.45% |
| @ 4.0bpw | 0.020% | 0.50% |
| **Overall performance** | | |
| Avg improvement | 0.111% | 1.96% |
| Total improvement | 0.334% | 5.87% |

Also inspected how this impacts the CUDA code (removing the addition):

Original CUDA PTX:
mad.lo.s32 %r4, %r3, -1997118179, 64248484;
and.b32 %r5, %r4, -1879076865;
xor.b32 %r2, %r5, 996162400;

New CUDA PTX:
mul.lo.s32 %r4, %r3, -8748059;
and.b32 %r5, %r4, -1879076865;
xor.b32 %r2, %r5, 996162400;
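(For reference, the new immediate is just the proposed multiplier reinterpreted as a signed 32-bit constant: 0xFF7A83E5 = 4286219237 = 2^32 - 8748059, hence the -8748059.)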

Benchmarking reveals a small but measurable 1.25% speed increase from eliminating the addition.

Even a modest 1.25% speedup and a 0.2-3.9% PPL reduction are hard to come by elsewhere, given how efficient EXL3 is shaping up to be.

@turboderp (Member)

This is very interesting. I will have to do a lot of tests, also on larger models, just to thoroughly verify this. Because to actually include it would require doubling the number of kernel instances and adding some metadata to quantized models to identify the new codebook (a bit too far along to just invalidate all existing EXL3 models.)

I'm surprised there's a performance difference, but it looks like the compiler does produce an extra MOV instruction in the first case since 64248484 is more than 20 bits, and the GEMM kernel is currently running into an ALU bottleneck due to the trellis unpacking overhead (on Ampere mostly).

Also, was the 1.25% speed increase just the kernel latency or inference overall, and what hardware did you see this on?

@louiehelm (Author)

More testing today with larger models:

| Metric | Llama 3.1 8B Instruct | Llama 3.1 70B Instruct |
| --- | --- | --- |
| **Original baseline** | | |
| PPL @ 2.0bpw | 10.421164 | 6.533785 |
| PPL @ 3.0bpw | 8.119616 | 4.450084 |
| PPL @ 4.0bpw | 7.756961 | 3.822873 |
| **With multiplier 0xFF7A83E5** | | |
| PPL @ 2.0bpw | 10.453533 | 6.543885 |
| PPL @ 3.0bpw | 8.123410 | 4.422189 |
| PPL @ 4.0bpw | 7.754818 | 3.835484 |
| **Improvement (%)** | | |
| @ 2.0bpw | -0.31% | -0.15% |
| @ 3.0bpw | -0.05% | 0.63% |
| @ 4.0bpw | 0.03% | -0.33% |
| **Overall performance** | | |
| Avg improvement | -0.11% | 0.05% |
| Total improvement | -0.33% | 0.15% |

A few minor regressions, mostly break-even. Biggest gain is at 3.0bpw, which already appears to be the format's new sweet spot.

Improved parameter is most valuable if:

  1. Decoder improvement gives more measurable speedup on older hardware
  2. PPL improvements are more numerous and larger than minor regressions
  3. Other performance benchmarks bear out higher fidelity
  4. There's abstract benefit to having a principled parameter

Happy to run further evals. Which ones would be most helpful? Is there a KLD eval? Or another automated benchmark you find correlates well with actual model quality?

> was the 1.25% speed increase just the kernel latency or inference overall

Just the decoder's isolated CUDA routines in a vacuum.

> and what hardware did you see this on?

2 x 5090. Perhaps older hardware has larger speedup from improving the decoder?

@turboderp (Member)

Cutting out a MOV instruction is definitely worth it, even if it ends up breaking even on accuracy. I think if the difference is this small on perplexity, most likely it will be hard to detect any other way. I didn't write up a KLD test yet, but eval/model_diff.py would be a good place to start.

I'm a little tied up with Qwen3 at the moment, but I will get back to working out a nice way to incorporate this.

@louiehelm (Author)

More evidence for 0xFF7A83E5

Out of curiosity, I coded my own routines to reproduce the QTIP paper's distortion rate calculations.

Got slightly different results than what they reported:

| | My tests | QTIP | Difference |
| --- | --- | --- | --- |
| LLOYD-MAX | 0.1172 | 0.118 | 0.7% |
| 1MAD | 0.0687 | 0.069 | 0.4% |
| DR (infinite) | 0.0625 | 0.063 | 0.8% |
| 3INST | 0.0743 | 0.069 | 7.7% |

First 3 are nitpicks... essentially just replicating their work with 1 more sig fig.

  • Although, I'll note others have also published 0.117 as the distortion rate for 2-bit Lloyd-Max
  • DR (infinite) = 2^(-2 * bitrate) = 2^(-2 * 2) = 2^-4 = 0.0625. Rounding to 0.063 feels odd for a tight theoretical lower bound that only required one more digit to express precisely.
  • 1MAD is a tiny bit below 0.069 at ~0.0687. This algorithm seems to have inherently more sample variance, so 0.069 may be the best estimate -- and perhaps why they chose to report 2 sig figs across the board.

However 3INST is more confusing:

  • 3INST: 0.0743 average distortion rate. Never collected a sample outside [0.072 - 0.076]

Their code shows they tested another decoder called 2MAD. My guess is 2MAD has a measured distortion rate of 0.069 too and they were initially planning to report both 1MAD and 2MAD results. But then they found 3INST and probably decided that since 2MAD was just slower without being measurably better, they would instead only report 1MAD and 3INST data. But then somehow the data for 3INST distortion rate never got updated in Table 1 of the paper.

In any case, I've sampled the richest chunk of the MCG multiplier space and using abstract MSE measurements (rather than full end-to-end model quantizing / perplexity measures), 0xFF7A83E5 once again ranks at the top with the lowest avg distortion rate.

3INST (MCG-0xFF7A83E5): 0.0726 (0.0017 lower distortion rate than original 3INST params)

Next 3 best 3INST MCG multipliers (distortion reduction vs. the original params):

  • 0xEAB209FD 0.001425
  • 0xFC275E4D 0.001256
  • 0xF2BAA58D 0.001239

So far only tested ~10% of the spectrally good MCG multipliers from Vigna's work for MSE improvements. Should complete exhaustive rankings in next couple days.

But it's good to see a totally different evaluation criterion, using completely separate code, also identify the exact same multiplier as the best one.

And relative to the original 3INST baseline, this multiplier produces 2.3% lower distortion, which is roughly the average improvement I'm seeing in 2-bit model perplexity too. So this is a good sanity check that this particular parameter isn't just "getting lucky" on the specific models I've tested, but would actually be expected to be 2.3% better at 2 bits.

MSE distortion rate at k = 1, 3, 4, and 5 bits improved as well, by roughly 3.3%, 2.0%, 1.6%, and 1.2% vs. baseline.

LOP3 not currently used?

The QTIP paper mentions that 3INST is especially good precisely because the mask and xor steps can be combined into one low-level LOP3 instruction.

So theoretically nvcc should render these two lines:

x &= 0b10001111111111111000111111111111u;
x ^= 0b00111011011000000011101101100000u;

Down into something like this PTX instruction:

asm("lop3.b32 %0, %1, 0x8FFF8FFF, 0x3B603B60, 0x6A;" : "=r"(x) : "r"(x));

But nvcc never does this (at least on my system with sm120). Maybe it's not kosher to pass in the same register as both an input and an output with lop3? If so, changing the CUDA code from "x ^= ..." into "y = x ^ ..." might be enough to help the compiler find this optimization. I haven't found a painless way to reliably implement LOP3 yet, but it's probably a trivial patch for a smart CUDA wizard. Do we have any of those around here? If LOP3 isn't getting emitted in any of the EXL3 decoder kernels and they're using 2 separate instructions [and(...) and xor(...)] every pass, that's another 8-10% speedup being left on the table. Maybe this will finally make Ampere perform better?
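One way that might coax it out is to wrap the two steps in a tiny helper with inline PTX, along the lines of the asm above (a sketch, not the EXL3 code; the "+r" constraint ties input and output to the same register):

// Sketch: fuse the AND + XOR into a single LOP3 via inline PTX.
// 0x6A is the LUT immediate encoding (a & b) ^ c for this operand order.
__device__ __forceinline__ uint32_t mask_xor_lop3(uint32_t x)
{
    asm volatile ("lop3.b32 %0, %0, 0x8FFF8FFF, 0x3B603B60, 0x6A;" : "+r"(x));
    return x;
}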

There might also be another clever way to shave off a mov in the final re-packing steps by ordering it slightly differently as well. But we can explore that later. It's the highest hanging fruit and can be optimized down any time.

Removing addition, using the right multiplier, and making sure LOP3 gets implemented are all more important for getting better speed and quality.

Let me know when KLD eval is ready so we can triple verify this change actually provides 1-3% avg higher fidelity via lower MSE, lower perplexity, and lower KLD at all bitrates (especially lower ones). The speedup may grow to +10% too once lop3 is used.

@louiehelm (Author)

Top 6 3INST MCG parameters found through exhaustive MSE sampling of all Vigna 32-bit multipliers (plotted with your visualization and RMS calculation):

0xE2A54B45

0xFF7A83E5

0xA2EC794D

0x00D0AA65

0xCAF6A435

0xB26DB26D

These plots are a fine tool for getting a general feel for each codebook by examining part of its distribution. But do keep in mind that this is only plotting raw input <--> output for dim=1 at 8 different bitrates [1-bit - 8-bit]. However, the QTIP Trellis codes operate in 256 dimensions. So these charts are only showing 1 of the 256 distributions that affect overall performance. Also, unlike most applications where dim=1 is by far the most meaningful distribution, there's nothing special about dim=1 within a 256-dimension Trellis coder. So it should be expected that the quality of the 255 higher dimensions (aka higher-order spectral quality) actually matters just as much for this application.

Just pointing this out so we remember RMS=1.17 vs RMS=1.20 in these charts just means ~0.4% of all distributions are ~2.6% better. But that doesn't automatically mean the other 99.6% of the distribution quality will also be 2.6% better. Often higher order behavior is quite jagged and random and good distributions in one dimension only loosely correlate with distribution quality within other dimensions.

That's why I tested these six 3INST MCG values for actual MSE distortion using real 256 dimension Trellis Coding:

2-bit PLUS

3-bit PLUS

4-bit PLUS

Y-axis re-scaled to distortion relative to an optimal codebook [2^(-2*bitrate)]

Average distortion and variance were plotted by sampling each decoder with 145-195 different chunks of 8192 [16-bit] Gaussian values. This effectively simulates the precise error that occurs when quantizing and decoding 100+ of the most difficult-to-quantize tensors it's possible to construct. This is the standard model (worst case) for distortion rate calculations. It's broadly assumed that codebooks which are strongest in this worst-case setting "degrade gracefully" into also being the most robust in less challenging settings. One of those less challenging settings is modern LLM weights, which empirically only have ~76%-79% entropy (lower bound = Llama 2; upper bound = Llama 3). Modern LLMs haven't reached 100% entropy yet (and possibly won't/can't within the current transformer paradigm). In any case, just pointing out in passing that the codebooks I'm locating are optimized to minimize distortion in a slightly more challenging domain than what we intend to use them for here.

As for the data, my hunch that 1MAD had higher variance was right. That said, its range still looks better by most metrics (max distortion, avg distortion, etc.) than even improved 3INST decoders at 2-bit and 3-bit. However, the variance in the distortion gets worse at higher bitwidths and may become intolerable.

Is there a reason 3INST was initially chosen over 1MAD? I know ik_llama.cpp also chose 3INST for their experimental IQx_KT quants so I assume there's a good reason? It's just weird because in theory I'd expect 1MAD to both run faster and provide better results on average (below 4-bits).

Is it just the inconsistency (higher variance) in 1MAD quantization quality that was the problem? If that's all that's wrong, I could try to find better params for 1MAD to reduce distortion and variance some more. Default 1MAD LCG params certainly aren't optimal either. Would that make it a more attractive option again?

Current 3INST decoder is only giving:

  • 2-bit quant performance equivalent to an optimal decoder at 1.87525 bpw
  • 3-bit: equivalent to 2.86937 bpw
  • 4-bit: equivalent to 3.8528 bpw
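Those equivalent-bitrate figures appear to follow from inverting the rate-distortion bound used above, R_eff = -0.5 * log2(D): e.g. -0.5 * log2(0.0743) ≈ 1.875 bits for the 2-bit case.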

SUMMARY:

  • Using default 3INST creates ~18% more distortion (above an optimal decoder) which manifests as ~0.14bpw of wasted coding overhead at all EXL3 bitwidths.
  • My original proposed 3INST MCG multiplier (0xFF7A83E5) would only have ~0.12bpw coding overhead.
  • It's possible one of the other top 3INST MCG multipliers (like 0xCAF6A435) could bring that down to ~0.11bpw coding overhead.
  • If 1MAD has better params with variance more like 3INST, it could potentially have only ~0.08bpw coding overhead. Should I look for these better 1MAD params?
  • There's another, more general way to analyze distortion rates that shows optimal decoders should remove a full 1.0bpw in coding overhead from current EXL3 models. I'm skeptical but also can't identify why it's not right. So while I'm confident that better decoder params like the ones above will easily shift the bpw vs perplexity plots over to the left by at least 0.02bpw, and think 0.06bpw could be possible (with optimized 1MAD params), I can't rule out a shift by as much as 0.50bpw (~50% of 1.0bpw).

@turboderp (Member)

Still held up by other stuff. Currently completely refactoring the kernels and probably the next thing I'll be occupied with will be kernel fusion, an alternative GEMV path and maybe Ampere optimizations. It's just a bad time right now to also have the complexity of a second quantization format to worry about, though there's no reason it couldn't be plugged in a little later. I'll write a quick KLD test in a little bit, at any rate.

As for compiling two steps into a single LOP3, the compiler may simply be choosing not to do it because it isn't efficient. MOV+LOP3 should have half the throughput of LOP3 on its own, but it may have lower latency overall, and scheduling the MOV instruction might be free if there's some other pipe that's stalled at that point anyway. It's very hard to know exactly what the compiler's deal is sometimes since SASS documentation is so sparse.

A point about the codebook worth noting, perhaps, is that the quantizer doesn't always end up using all of it. This may be due to insufficient regularization (128D Hadamard rotations for the sake of kernel fusion and the ability to split models for tensor-parallelism at load time,) but especially at higher bitrates where there is more coverage, it ends up being preferable to scale down the input by a factor of 10-40% to make better use of the denser middle part of the distribution (and possibly widen that "spike" that always appears in the middle?)

I was initially testing with 1MAD and did see promising results. The distribution was a perfectly smooth Gaussian too, albeit with some gaps and very bad correlations at 1 bpw. I'll see if I can dig up the results in a little bit, or recreate them. Paper shows this plot for 2 bpw, with narrow bands that end up not really mattering:

image

I definitely didn't explore this fully, since there was also a whole framework to build around the quantization algorithm. The overall idea with the project is to make these algorithms accessible in a usable format. And QTIP, despite being SOTA, still remains largely unavailable, with only a handful of quantized models on HF (all made in relation to the paper it seems) that I still can't get working despite quite a bit of effort. So I had to make a decision eventually or be stuck endlessly obsessing over the details. That ultimately came down to 3INST achieving better perplexity on actual models in my tests, though not by much, and looking more efficient at a glance. Also, the QTIP paper describes the two methods as roughly equivalent, with a slight edge to 3INST at 2 bpw:

image

Can't say for sure if this was the same reason the llama.cpp people chose 3INST for their experiments, but at least when it comes to performance I'm not convinced 1MAD would end up being faster. It's essentially:

1MAD:

  • LCG
  • sum unsigned bytes (PRMT)
  • convert int->float
  • constant add, mul (can be FMA)

3INST:

  • LCG
  • bitwise xor, and (single LOP3 in theory)
  • reinterpret (free)
  • half add

At least on the surface, 3INST seems more efficient. I believe PRMT executes on the SFU pipeline and LOP3 has 4x the throughput (?), and also the float conversion takes 4 cycles or something.

The real bottleneck is unpacking the trellis, though. You end up needing a lot of registers, you want to stick to 32-bit SMEM access which means reading any 16-bit field requires loading two 32-bit values, doing a funnel shift and then masking (would be interesting maybe to explore LCGs/MCGs that are indifferent to the high 16 bits of the input.)
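That two-word read might look roughly like the following sketch (illustrative names, not the actual kernel):

// Sketch: read a 16-bit trellis field that may straddle two 32-bit SMEM words.
// 'bitpos' is the absolute bit offset of the field within the packed stream.
__device__ __forceinline__ uint32_t read_packed_u16(const uint32_t* smem, int bitpos)
{
    uint32_t lo = smem[bitpos >> 5];                    // word holding the low bits
    uint32_t hi = smem[(bitpos >> 5) + 1];              // next word (high bits, if any)
    uint32_t v  = __funnelshift_r(lo, hi, bitpos & 31); // shift the hi:lo pair right
    return v & 0xFFFFu;                                 // mask to 16 bits
}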

All that said it's possible 1MAD (or some other function) could be or could be made faster at the end of the day. That's definitely worth exploring. And if indeed there's something on the order of 0.5 bpw to be gained, even in theory, that needs to be looked into. It feels unlikely, though, noting that EXL3 already matches (sometimes outperforms) methods that rely on finetuned codebooks.

[image]

Wish I could get QTIP inference working in Transformers. Then I could just plot the finetuned hybrid code models onto the same graph and get a better idea. :) 🤷

@louiehelm (Author)

Like you, I also can't make actual QTIP code work. 😂 The requirements file is a lie. There's no combination of torch, cuda, numpy, fast-hadamard-transform, transformers, and qtip that can compile simultaneously. I gave up after trying every public fork of fast-hadamard-transform, creating my own fork that compiles in more than zero versions of torch+cuda, only to still not have it work.

I love QTIP (mathematically) but their research code is diabolical. The fact I can't make it run in the era of infinite LLM assistance is mind-blowing.

Their paper mentions using vabsdiff4 to do the 4 adds -- which can be emulated with PRMT -- or achieved using other neat methods people found a few years ago. Of course, even this random post about VABSDIFF4 incidentally cites Vigna. We should really just ask Vigna what the best 256-dimensional static codebook function is. You know it's already on a napkin on his desk and it's only 2 instructions somehow.

Got some prettier looking 1MAD distributions even a few weeks ago when I first looked into this:

[image: 1MAD distributions]

> especially at higher bitrates where there is more coverage, it ends up being preferable to scale down the input by a factor of 10-40% to make better use of the denser middle part of the distribution (and possibly widen that "spike" that always appears in the middle?)

That's VERY interesting! What file does the EXL3 code make these sorts of scaling decisions in? I assumed in practice that rotations would probably limit codebooks from ever reaching 100% usage, but the fact you're able to scale inputs going into a TC and book a "profit" in reduced distortion is a stronger indictment of the default codebook params than anything I've found so far. Maybe I could add some verbose output to the quantizer to track how often and how aggressively this sort of scaling hack wins out over the default full codebook? That's probably an excellent surrogate measure of absolute codebook quality.

I'll try to tighten up all these threads into more concrete answers later this week. It just takes time to fully characterize param space for these computed codebooks. Appreciate you need to stay on task with other components too so thanks for giving this attention when you have time. I'll keep chugging along and we'll sort this all out soon.

@turboderp (Member)

All the scaling and quant logic (aside from the Viterbi kernel and some other CUDA support functions) is in quantize.py. I tried to keep it neat but it always becomes a little messy when it needs to also work in less than an unlimited amount of VRAM. And of course every new model throws a few surprises at you. So 🤷

Main function is quantize_exl3, and regularize is the function that tries to make the input tensor as Gaussian as possible.

The input and output channel scales were added later on in development. They're not necessary for all models but in a few cases the Hadamard rotations just aren't enough to deal with extreme outlier channels. This might not be an issue if I was rotating full rows and columns, but doing so would both be less efficient and as mentioned result in tensors that can't be split for tensor parallelism at load time. You'd have to requantize the model for any given hardware configuration and I'm trying to avoid that.

I did a bunch of tests at a range of scales but couldn't find a way to predict the best scale for any given tensor at a given bitrate. So I ended up with a golden section search on a sample of the regularized input tensor, which should be solid under a couple of assumptions:

  • tiles along a wrapped diagonal over the entire tensor make a representative sample, and
  • LDLQ rounding isn't relevant for determining the scale, and
  • MSE is a unimodal function of scale
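Under those assumptions, the search itself is just a standard golden-section minimization of MSE over the scale, conceptually something like this host-side sketch (mse_at_scale is a hypothetical callback that test-quantizes the sampled tiles at a given scale; not the actual quantize.py code):

// Sketch: golden-section search for the tensor scale minimizing quantization MSE,
// assuming MSE is unimodal in the scale.
#include <cmath>
#include <functional>

double golden_section_scale(const std::function<double(double)>& mse_at_scale,
                            double lo, double hi, int iters = 24)
{
    const double invphi = (std::sqrt(5.0) - 1.0) / 2.0;    // ~0.618
    double c = hi - invphi * (hi - lo);
    double d = lo + invphi * (hi - lo);
    double fc = mse_at_scale(c);
    double fd = mse_at_scale(d);
    for (int i = 0; i < iters; ++i)
    {
        if (fc < fd) { hi = d; d = c; fd = fc; c = hi - invphi * (hi - lo); fc = mse_at_scale(c); }
        else         { lo = c; c = d; fc = fd; d = lo + invphi * (hi - lo); fd = mse_at_scale(d); }
    }
    return 0.5 * (lo + hi);                                 // near-optimal scale
}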

At least for the last assumption I did do some plots (blue is 2 bpw quantization and red is 5 bpw):

image

I think it's very likely scaling wouldn't improve anything if the regularized tensor was perfectly IID Gaussian, but it generally isn't. There may be outliers that simply can't be encoded and need to be clamped, for instance, so downscaling the tensor reduces the error from those outliers. And the higher the bitrate, the less penalty you incur from the mismatched distributions, so you end up with this tradeoff.

It's coincidental that 1.0 is the ideal scale for 2 bpw in this example. It all shifts around a bit depending on how amenable a given tensor is to the regularization process, or maybe how Gaussian it was to begin with.

@turboderp (Member)

I added a KLD test to model_diff.py. To run it:

python eval/model_diff.py -ma /mnt/models/test_model -mb /mnt/models/ref_model -r 5

This would do 5 rows of 2048 wiki2 tokens. You can do more rows of course, but you end up with two big (num_rows, 2048, vocab_size) tensors to manage and it's not very clever about that. Example output:

$ python eval/model_diff.py -ma /mnt/str/models/llama3.2-1b-instruct/exl3/4.0bpw/ -mb /mnt/str/models/llama3.2-1b-instruct/hf/ -r 10
 -- model.embed_tokens                         error: 0.000000   max_diff/norm: 0.000000
 -- model.layers.0                             error: 0.060575   max_diff/norm: 0.005043
 -- model.layers.1                             error: 0.034527   max_diff/norm: 0.025184
 -- model.layers.2                             error: 0.040355   max_diff/norm: 0.024767
 -- model.layers.3                             error: 0.046076   max_diff/norm: 0.024320
 -- model.layers.4                             error: 0.049348   max_diff/norm: 0.023913
 -- model.layers.5                             error: 0.053157   max_diff/norm: 0.023478
 -- model.layers.6                             error: 0.056017   max_diff/norm: 0.023225
 -- model.layers.7                             error: 0.058559   max_diff/norm: 0.022791
 -- model.layers.8                             error: 0.060759   max_diff/norm: 0.022081
 -- model.layers.9                             error: 0.062963   max_diff/norm: 0.020648
 -- model.layers.10                            error: 0.065641   max_diff/norm: 0.019616
 -- model.layers.11                            error: 0.069408   max_diff/norm: 0.017781
 -- model.layers.12                            error: 0.074267   max_diff/norm: 0.015942
 -- model.layers.13                            error: 0.081371   max_diff/norm: 0.014155
 -- model.layers.14                            error: 0.090546   max_diff/norm: 0.012164
 -- model.layers.15                            error: 0.097819   max_diff/norm: 0.007099
 -- model.norm                                 error: 0.113367   max_diff/norm: 0.001088
 -- lm_head                                    error: 0.084074   max_diff/norm: 0.000186
 -- A perplexity: 17.95771179
 -- B perplexity: 17.56365633
 -- A label in top-K:
      K = 1: 0.4377
      K = 2: 0.5479
      K = 3: 0.6058
      K = 4: 0.6494
      K = 5: 0.6766
 -- B label in top-K:
      K = 1: 0.4399
      K = 2: 0.5497
      K = 3: 0.6100
      K = 4: 0.6504
      K = 5: 0.6785
 -- Top-K agreement, A vs B:
      K = 1: 0.9245
      K = 2: 0.7808
      K = 3: 0.6125
      K = 4: 0.4508
      K = 5: 0.3154
 -- KL divergence (A, B):  0.01924682
 -- KL divergence (B, A):  0.01966666
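For reference, the KL divergence here is presumably the token-averaged divergence between the two models' next-token distributions, i.e. KL(A, B) = mean over positions of sum_v p_A(v) * log(p_A(v) / p_B(v)).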

@louiehelm (Author)

Fantastic! I'll try out KLD eval and collect data on various codebook params across different models.

Looks like your quantization function has lots of good instrumentation in it to gauge relative quality across different codebook functions too.

Thanks for sharing more details about how you constructed the regularization function. Very interesting! Trying hard to wrap my head around it. Forgive me if I'm not grasping all of the nuances right away. It just seems to violate so much of what I thought I knew about Hadamard rotations. My initial thoughts are something like:

  1. How are the Hadamard functions so ineffective at taming outliers? Is there something wrong with them? I would have expected them to completely eliminate outliers even in the most pathological distributions. So I'm perplexed to hear they can't handle much simpler distributions (like LLM weights).

  2. Even if somehow these rotation functions weren't capable enough to reliably eliminate outliers, wouldn't it still always be strictly better (from an overall MSE perspective) to regularize by re-scaling post-rotation, not pre-rotation?

  3. If I'm testing different codebook functions, I should probably comment out this line during testing to compare performance without regularization when deciding the best codebook, since the 50 and the 15 values that are hardcoded into the current regularization logic are implicitly tuned to compensate specifically for the imperfections in the default 3INST codebook, right?

  4. You say "in a few cases the Hadamard rotations just aren't enough to deal with extreme outlier channels". Out of curiosity, which specific models are the worst for this? My immediate intuition is that more ancient models like Llama 2 (and other models trained back in the native float32 era) would be most affected, right? Although even modern model developers still seem to be terrible at properly using their dynamic ranges, so maybe lots of stuff is still fragile? Perhaps recent DeepSeek models are less plagued by bad outliers? Curious what your data shows.

  5. This may be dumb but what's stopping you from just rotating blocks of every tensor in 256 weight chunks? You're not going to be able to split blocks any smaller (at least efficiently) due to the underlying dimensionality of the Trellis code. So you'd still have maximal opportunity for tensor parallelism all within the same file. Yeah, there's always gonna be that one guy whose 7x GPU system can't evenly split Qwen 3 0.6B, but that guy is already out of luck with your current splits, right? Wouldn't it actually always be better in all tensor parallelism situations to use the min block size [256]? Is there a competing interest in rotating and coding larger blocks? My intuition is Hadamard rotations would always perform worse over larger chunks and so would the Trellis code (especially when using a sub-optimal codebook).

  6. Is some subtlety in the blocking, scaling, rotation, and regularization code accounting for the "missing" ~0.5bpw delta in overhead between what a 256-dim Trellis code with current codebook distortion levels should be getting vs what EXL3 is actually getting right now? Perhaps the entire bpw vs perplexity curve would shift left ~0.5bpw if every tensor was rotated and coded separately in minimal 256 weight chunks? But maybe that's undesirable for some other reason or imparts too much overhead somewhere else in the online decoding pipeline? It seems abstractly like it would be computationally break-even (or even better in some multi-gpu settings) but perhaps there's a hidden gotcha between naive theory and actual implementation?

@turboderp (Member)

  1. Hadamard rotations do eliminate outliers in a sense, but they do so by mixing the outlier channels into all the other channels involved in the rotation. The outliers still end up dominating and drowning out the smaller weights.

  2. The issue probably comes down to how rotation is performed in batches of 128 columns/rows. When you have a single outlier channel that's orders of magnitude larger than all the others, the corresponding batch becomes much larger as well after rotation, and then the distribution across all batches is no longer Gaussian, even if each span of row or column vectors might be, individually. Normalizing the input channels prior to rotation solves this for 99% of all models.

I did try to rescale channels after rotation but couldn't get good results from it. It's also less efficient since you'd still need the separate input/output sign flips, and this way you can combine them into a single operation. I think a better solution might be to scale each 128x128 block of the regularized tensor independently. But this would require some modifications to the GEMM kernel, followed by days and days of testing to make sure this new method still works across some large enough selection of models. And all the while people will be begging me to add support for this or that new architecture, make it faster, support ROCm, when is multimodal ready, etc. I just have to prioritize and build on what works instead of going back to the drawing board every other week.

  3. If you comment out the call to the regularization function you won't have Hadamard rotations at all, so that's not going to give you very representative results. The two hardcoded values are a little arbitrary, but mostly there to solve a particular degenerate case where LDLQ rounding blows up for reasons I haven't had time to fully investigate. The workaround is (don't ask me why) to identify when the inputs to a linear layer are dominated by a few large channels, and then disable the output channel scales in those cases. You can also override it with the --no_output_scales argument on the command line.

  4. The model I'm currently fighting with is Command-R-Plus. There is a tensor with a single (!) weight hundreds of times larger than the rest, and one with a single column that's also abnormally huge. Here's what that looks like after the rotations:

image

This might be some sort of overflow, I'm not sure yet.

  5. I think either your intuition is backwards or mine is :). As I understand it, the more channels get mixed, the more normal the output will look. This is why QTIP and QUIP# rotate all the input and output channels. There are diminishing returns, though, and 128 seemed like a good tradeoff: end-to-end results are still better than AQLM, empirically, and most models have a head dimension of either 128 or 256, which is the main concern for tensor parallelism.

  6. If you rotate each 16x16 matrix tile separately, apart from not ending up with very well-conditioned tiles most likely (?), you would also need to scale them independently. This would add some overhead both in storage and latency.

@louiehelm (Author)

  1. My mental model of how Hadamard functions work doesn't seem much like a mixing function that benefits from sheer scale. There's simpler structures that algorithm designers could reach for if that's all they were using them for. What sanity check eventually convinced you the Hadamard rotations were implemented correctly? A slippery aspect is that slightly incorrect constructions will often still work (partially). For example, the code in the utils folder looks fine at a glance. But why is PAGE_SIZE = 256? I haven't carefully traced the code paths so forgive me if I'm aliasing concepts but wouldn't this also need to be 128 to match the height and width of the block dimensions? A double-sized Hadamard could still "work" for reducing outliers (and work very well relative to nothing) but its performance would be significantly degraded (especially in a setting like this where a single outlier slipping through can doom the quality of the entire block). This could explain why your expectations for Hadamard rotations are so much lower than mine.

  2. Don't suppose I could do something cheap to experiment with these trade-offs myself like replacing the 5 instances of 128 in quantize.py with 256 to test out different block sizes? Will GEMMs go along with these shenanigans?

  3. Yes! I found --no_output_scales in convert.md documentation right after my comment so decided to use that for certain arms of my automated benchmarking. Although, just noticed codebook_scale = 1.24371088 in quantize.py. Is that RMS error for default 3INST params? Perhaps I should be modifying this to the true RMS for new params (~3.5-6% lower) to get better quants?

  4. Feels like another manifestation of the PrefixQuant phenomenon. Been meaning to try out the image dumps. They look nifty and quite helpful. That's on my agenda for Friday. :)

  5. Hadamard functions are tricky to reason about correctly. I published my thesis on novel quantum algorithms for distributed consensus but that was years ago. So I'm certainly not as sharp on the subject as I once was.

  6. My thought was that no scaling would be beneficial or necessary when encoding blocks equal to the dimensionality of the Trellis. So there wouldn't be additional scaling factors to calculate, store, or process, since they wouldn't have a reason to exist.

@turboderp (Member)

  1. PAGE_SIZE is unrelated. It's just how many tokens make up one page in the K/V cache. Which is a fixed value for flash-attn, but could change if I migrate to FlashInfer in the future.

I'm pretty sure the Hadamard rotations function correctly because I implemented them in several different ways and they're mutually compatible. The quantizer uses a standard 128x128 matrix (from Sylvester's method) and a Torch matmul, and the kernels use warp shuffling shenanigans. In all combinations H^T@H = I, etc.

And I'm not so much saying that outliers "slip by", but rather that they still end up dominating in a few extreme cases. So for an extreme example, if you have a 4D vector like (1, 1, 1, 1000), the rotation is going to be (501.5, -499.5, -499.5, 499.5). Quantizing to a grid after that you're still going to lose the weaker signal to rounding. Or, if you scale it for quantization with some Gaussian-ish codebook, all of the rotated values end up being in the tails of the distribution, because the single outlier is added to or subtracted from every other channel.

  2. The value of 128 is pretty much fixed unless you want to start rewriting the kernels. You should be able to quantize tensors with other values of had_k and had_n, but inference won't work and quantizing a model will fail after the first module (when it tries to do a forward pass with the quantized module.)

  3. If you disable the output scales it will still try to estimate a global scale here. If you skip those three lines, it won't do the whole test quanting of the diagonal to tweak the scale, and the weights should end up with an RMS of whatever codebook_scale is set to. To disable the input channel scales, replace two lines right above that:

    in_channel_scales = block_rms(weight, dim = 1, keepdim = True)
    su = (su * in_channel_scales / (-codebook_scale) + 1e-10).float()  # mustn't be inplace
    weight /= su

->

    su = (su / (-codebook_scale) + 1e-10).float()
    weight /= su

TBH I've long since forgotten why it flips all the signs again there, and I can't imagine why it's needed. Adding the small eps seems pointless too. :D

  6. Well, in any case the codebook contains values from -4 to +4, and the weights are going to be on an arbitrary scale unless you normalize them one way or another. A single float16 scale per 16x16 tile would add 0.0625 bits per weight. So that's not catastrophic but something to keep in mind.

@turboderp (Member)

Yep, turns out the compiler just doesn't figure out how to combine the two ops into a LOP3. Free small performance increase with some inline PTX. Go figure. (:

@louiehelm (Author)

Initial KLD (forward pass) + PPL (10 rows; the initial PPL data at the start of the thread used 100 rows).

[image: 3INST multiplier KLD bar charts]

These plots show performance with quantize.py codebook_scale = real std dev (RMS)

Switching to % KLD and % PPL to better see differences. Also removed 2 lowest performing multipliers:

[image: Llama 3.2 1B Instruct multiplier KLD % change]

Hard to pick clear winner from L3.2 1B alone. But encouraging to see KLD dropping 4-8% with most multipliers at all bitrates.

[image: Mistral 7B Instruct multiplier KLD % change]

0xCAF6A435 looks promising, if these KLD improvements are actually indicative of overall fidelity:

  • 2.0 bit = 6-12% higher fidelity
  • 3.0 bit = 0.5-4% higher fidelity
  • 4.0 bit = 4.5-5% higher fidelity

REMAINING TODO:

  • collect Llama 3 70B Instruct data
  • spot check 1.0bit and 5.0 quants
  • sweep through 1MAD params to see how best ones perform

@louiehelm (Author)

0x00D0AA65 testing showed fewer improvements and more regressions than the best multipliers.

It's the only good 24-bit multiplier so in theory a sicko could:

asm volatile("mul24.lo.u32 %0, %1, 0x00D0AA65;" ...

Not seeing speedup on Blackwell from this. Supposedly mul24 isn't faster anymore on modern cards. So unless you see big speed boosts on Ampere, there's no compelling reason to consider 0x00D0AA65 over the other options.

[image: Llama 3.1 70B Instruct multiplier KLD % change]

70B: 0xCAF6A435 looks best. Some merit in considering 0xFF7A83E5 which typically has better 2bit and 3bit performance.

0xCAF6A435 is more well-rounded at all bitrates though.

[image: Llama 3.2 1B Instruct multiplier KLD % change at 1 bit]

1.0bit and 1.5bit look much better with 0xCAF6A435. The KLD regression at 5 bit is only 0.0001. On the other hand, the improvement at 1 bit was KLD: 4.82 --> 4.61, PPL: 1506 --> 1126.

Not implying sub-2bit quants will be good now, but for narrow use cases where a 1.8bpw quant fits a specific GPU and 2.0bpw can't, 0xCAF6A435 degrades more gracefully than default 3INST.

Can you start testing 0xCAF6A435 to confirm it's a good candidate? I think either that or 0xFF7A83E5 would be best.

Remember to update codebook_scale = std dev in quantize.py for best performance.

0xCAF6A435 = 1.206441
0xFF7A83E5 = 1.199048

@turboderp (Member)

Currently everything's tied up testing block-sparse layers. I've got scripts queued for the next several days just quantizing and testing. Not looking forward to the electricity bill :|

As for codebook_scale, were you testing this with the global scale search enabled? Because in theory that should override the base scale anyway. And it definitely needs to be accounted for, since some models break badly without the full scaling applied (Llama happens to be very well-behaved and I'm not sure it's all that representative at the end of the day.)

At any rate, the main thing I'm worried about right now is compilation time (of all things). Currently there are 128 unique instances of the GEMM kernel, and I'm exploring ways to make it more flexible without doubling that number every time a new feature is added. But there's no reason in principle the MCG parameter couldn't be variable at runtime, I'll know more soon.

@louiehelm (Author)

QTIP paper overstated 2-bit RPTC distortion rate too.

| | My calc | QTIP | Difference |
| --- | --- | --- | --- |
| LLOYD-MAX | 0.1172 | 0.118 | 0.7% |
| 3INST | 0.0743 | 0.069 | 7.7% |
| 1MAD | 0.0685 | 0.069 | 0.7% |
| RPTC | 0.0659 | 0.068 | 3.1% |
| DR (infinite) | 0.0625 | 0.063 | 0.8% |

The real L=16 RPTC distortion rate from the original 2010 RPTC citation was 0.06595 (or ~0.066) [calculated based on the final table in the paper]. This is 36% closer to the theoretical optimum, so not a small difference!

If 0.068 distortion were correct, it shouldn't be possible to make computed codebooks much better than standard 1MAD. But since the 2-bit RPTC limit is actually below 0.066 there are plausibly 1MAD params a few % better.

And I've found a few! These don't use an additive factor either -- should they be called 1MUL now?

[images: 2-bit, 3-bit, and 4-bit quantization MSE comparisons]

Average quality improvements with better 1MAD params are small, but the reduction in overall variability and worst-case distortion is a bigger deal. It ends up higher quality at most bitrates, with fewer compromises and regressions than standard 1MAD.

Also removes the addition step. Honestly I can't measure any speed difference between 1MAD and 3INST. Perhaps my final multiplication/conversion step was more efficient?

If 1MAD and 3INST really are equally fast, it could make sense to primarily use 1MAD instead. Default 1MAD was more marginal but these alternate 1MAD constructions outperform 3INST in theoretical MSE reduction (at least up to 4 bit quants). At 5-8 bit 3INST technically has lower variance and could be better in worst case scenarios like trying to encode pure noise but for Hadamard-rotated, 80% entropy LLM weights, there's no daylight between their actual performance. Quick test runs with Llama 3.2 1B and Mistral 7B showed lower PPL at all bitrates. I'll test Llama 3.1 70B next.

Any chance the smoother Gaussian of 0xAD9A2EC5 1MAD is enough to reduce or eliminate the Command-R-Plus outlier you're currently struggling with using regular 3INST?

Could you dump that one 128x128 block of Command-R-Plus with that outlier to a file (post-rotation)? Right now I'm using a completely synthetic memoryless Gaussian source (read: pure noise) as the heart of my multiplier ranking pipeline. Perhaps instead I should rank decoder multipliers against the actual catastrophic distributions they need to be most resilient to -- outlier spikes (presumably from attention sinks?) washing out variability in certain blocks post-rotation.

[image: 1MAD 0xAD9A2EC5 codebook]

BTW - this uses 0.00677 as the final multiplier (implied divisor of 147.71, as opposed to the 148.80 in the QTIP paper or the (implied) 118.84 in the current codebook.cuh 2MAD code).
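As a sanity check on that divisor, assuming it's meant to normalize the byte-sum to roughly unit standard deviation: the sum of four uniform bytes has sigma = sqrt(4 * (256^2 - 1) / 12) ≈ 147.8, and 1/147.7 ≈ 0.00677.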

@turboderp (Member)

Okay, so I'm working on making it switchable. So far I've added MCG mode as an option to the quantizer script, with an arbitrary multiplier. This way the framework handles both, and I can make some comparisons relatively easily. I'm not seeing a clear winner, though. Here's perplexity for L3-1B:

image

And KL divergence:

image

The latter is definitely the better measure to be focusing on, since it measures difference to the output of the unquantized model directly. Perplexity is simpler to measure but it can be misleading, e.g. by allowing for the quantized model to score better than the original (and I've seen this in a few cases with very large MoE models.)

Now, it's very hard to pick a clear winner from these. Even the original LCG wins at 2.25 bpw. That's produced as a mix of 2, 3 and 4 bpw tensors, all of which work better individually with the 0xCAF6A435 variant, which kinda says to me there's some chaotic dynamics going on and it's going to be very hard to predict the E2E error from the error on individual linear layers.

0xCAF6A435 does look like it has the best overall performance, at least for this model. The setting is stored per tensor, though, so there's room for using individual multipliers per tensor (must be the same across all parallel tensors in block-sparse layers though) so potentially some E2E optimization could produce an even better final model.

I'll do some tests with the addition-less 1MAD next. 1MUL is fine I guess, though I'd probably prefer MUL1 or something, to avoid labels starting with a numeric character.

@turboderp (Member)

Okay, so some preliminary results with the addition-free 1MAD variant, 0xAD9A2EC5 multiplier:

image

Very smooth distribution at least, despite the gaps, but that's a given when summing four more-or-less uniform random numbers. MCG indeed seems to be just as good as an LCG for this purpose, which I guess also makes sense since adding a constant isn't going to make any of those four components more uniform anyway.

At K=1, neighboring states are still highly correlated, though that probably doesn't mean much for L3-1B, where perplexity at 1 bpw is on the order of 1000 regardless. So it's not plotted here, either:

image

KLD looks a little nicer:

image

So it does fairly well at 2bpw, though the difference isn't as pronounced at higher bitrates. Presumably this is because the quantizer still scales the values down to use the denser middle portion of the codebook, since the weights don't end up being perfectly Gaussian.

I'll need to clean up the code a bit and split the three different modes into individual kernel instances since E2E performance seems to take a 5% hit if I have a single kernel branching off to three different matmul functions. Then I'll push the changes probably tomorrow, with some cleanup and warnings to make it clear models created with these settings may or may not be supported in future versions.

Next step I guess is to evaluate the performance of each CB function. The MUL1 version might end up being faster, I'm not sure. It does reduce to very few instructions with the vabsdiff4 trick:

template <int cb>
__device__ inline half decode_pcb(uint32_t x, uint32_t mult)
{
    if constexpr (cb == 0)  // 3INST LCG, original
    {
        x *= 89226354u;
        x += 64248484u;
        asm volatile ("lop3.b32 %0, %0, 0x8fff8fff, 0x3b603b60, 0x6a;" : "+r"(x));
        half2_uint32 xu(x);
        return __hadd(__low2half(xu.as_half2), __high2half(xu.as_half2));
    }
    if constexpr (cb == 1)  // 3INST MCG
    {
        x *= mult;
        asm volatile ("lop3.b32 %0, %0, 0x8fff8fff, 0x3b603b60, 0x6a;" : "+r"(x));
        half2_uint32 xu(x);
        return __hadd(__low2half(xu.as_half2), __high2half(xu.as_half2));
    }
    if constexpr (cb == 2)  // 1MAD MCG
    {
        x *= mult;
        uint32_t sum;
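        // vabsdiff4 with b = c = 0 just sums the four unsigned bytes of x (mean 510)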
        asm volatile ("vabsdiff4.u32.u32.u32.add %0, %1, %2, %3;" : "=r"(sum) : "r"(x), "r"(0), "r"(0) : );
        const __half k_inv_h = __ushort_as_half(0x1eee);   //  0.00677 = 1/147.7
        const __half k_bias_h = __ushort_as_half(0xc2e7);  // -3.452 = -510.0 * k_inv_h
        return __hfma(__int2half_rn(sum), k_inv_h, k_bias_h);
    }
}

But I still haven't profiled yet or checked the SASS it actually compiles to in the end.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants