Conversation

@levkropp

When trying to convert google/umt5-xl to GGUF, I encountered the error "Model UMT5Model is not supported".

The model repository doesn't provide safetensors files, so I downloaded the model with AutoModel.from_pretrained(). AutoModel loads the base model class (UMT5Model) rather than the task-specific variant (UMT5ForConditionalGeneration), and that class name is what save_pretrained() writes to config.json.
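
For context, the checkpoint was produced roughly like this (a sketch; depending on the checkpoint format, from_pretrained() may need extra arguments, and the local path matches the one used by the test script below):

from transformers import AutoModel, AutoTokenizer

# AutoModel resolves to the base class UMT5Model, not UMT5ForConditionalGeneration
model = AutoModel.from_pretrained("google/umt5-xl")
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xl")

# save_pretrained() records the loaded class in config.json, i.e.
#   "architectures": ["UMT5Model"]
model.save_pretrained("./models/umt5-xl")
tokenizer.save_pretrained("./models/umt5-xl")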

While UMT5ForConditionalGeneration was already registered, the base UMT5Model class was not.

Adding @ModelBase.register("UMT5Model") allows the conversion to work for models downloaded with AutoModel, mirroring how other base model classes like BloomModel, BertModel, and RobertaModel are registered alongside their task-specific variants.
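
The change itself is a one-line addition to the registration decorator in convert_hf_to_gguf.py. A sketch of what the registered class looks like (the exact set of previously registered names is from memory and may differ in the current tree):

@ModelBase.register(
    "T5Model",
    "T5ForConditionalGeneration",
    "MT5ForConditionalGeneration",
    "UMT5Model",  # added: the base class that AutoModel writes to config.json
    "UMT5ForConditionalGeneration",
)
class T5Model(TextModel):
    model_arch = gguf.MODEL_ARCH.T5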

I tested the resulting GGUF models (F32, F16, and Q8_0) and verified that the embeddings and encoder weights are identical to the original PyTorch model's (max difference: 0.0, mean difference: 0.0).
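
For anyone reproducing, the conversions were along these lines (a sketch using convert_hf_to_gguf.py's standard --outfile/--outtype flags; file names match the test script below):

python convert_hf_to_gguf.py ./models/umt5-xl --outfile google-umt5-xl-f32.gguf --outtype f32
python convert_hf_to_gguf.py ./models/umt5-xl --outfile google-umt5-xl-f16.gguf --outtype f16
python convert_hf_to_gguf.py ./models/umt5-xl --outfile google-umt5-xl-q8_0.gguf --outtype q8_0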

(Sorry for the duplicate PRs, I thought this was a misconfiguration on my part so I deleted my initial fork)

test_umt5_encoding.py
#!/usr/bin/env python3
"""
Test script to verify that the GGUF conversion produces the same encodings
as the original PyTorch model.
"""

import sys
from pathlib import Path

# Add the local gguf-py package to the path (assumes this script sits in the llama.cpp repo root)
sys.path.insert(0, str(Path(__file__).parent / 'gguf-py'))

import torch
import numpy as np
from transformers import AutoModel, AutoTokenizer
import gguf

def load_pytorch_model(model_path):
    """Load the original PyTorch model."""
    print(f"Loading PyTorch model from {model_path}...")
    model = AutoModel.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model.eval()
    return model, tokenizer

def load_gguf_model(gguf_path):
    """Load the GGUF model and extract encoder weights."""
    print(f"Loading GGUF model from {gguf_path}...")
    reader = gguf.GGUFReader(gguf_path)
    
    # Extract tensors
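    # (tensor.data exposes the raw on-disk values: F32/F16 tensors come back as
    # floats, while quantized types such as Q8_0 stay in packed block form, so
    # the element-wise comparisons below assume the F32 conversion)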
    tensors = {}
    for tensor in reader.tensors:
        tensors[tensor.name] = tensor.data
    
    return reader, tensors

def encode_text_pytorch(model, tokenizer, text):
    """Encode text using PyTorch model."""
    print(f"\nEncoding with PyTorch: '{text}'")
    inputs = tokenizer(text, return_tensors="pt", padding=True)
    
    with torch.no_grad():
        # Get encoder outputs
        encoder_outputs = model.encoder(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask']
        )
        last_hidden_state = encoder_outputs.last_hidden_state
    
    print(f"Output shape: {last_hidden_state.shape}")
    print(f"Output mean: {last_hidden_state.mean().item():.6f}")
    print(f"Output std: {last_hidden_state.std().item():.6f}")
    print(f"First 10 values: {last_hidden_state[0, 0, :10].numpy()}")
    
    return last_hidden_state.numpy(), inputs['input_ids']

def compare_embeddings(pytorch_model, gguf_tensors, token_ids):
    """Compare embedding lookup between PyTorch and GGUF."""
    print("\n=== Comparing Token Embeddings ===")
    
    # Get PyTorch embeddings
    pt_embed_weight = pytorch_model.shared.weight.detach().numpy()
    print(f"PyTorch embedding shape: {pt_embed_weight.shape}")
    
    # Get GGUF embeddings
    gguf_embed = gguf_tensors.get('token_embd.weight')
    if gguf_embed is None:
        print("ERROR: Could not find token_embd.weight in GGUF")
        return False
    
    print(f"GGUF embedding shape: {gguf_embed.shape}")
    
    # Compare a few token embeddings
    token_id = token_ids[0, 0].item()
    print(f"\nComparing embedding for token {token_id}:")
    
    pt_vec = pt_embed_weight[token_id]
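    # The reader may expose the embedding matrix transposed relative to
    # PyTorch's (vocab, dim) layout, so index whichever axis holds the vocab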
    gguf_vec = gguf_embed[:, token_id] if gguf_embed.shape[0] < gguf_embed.shape[1] else gguf_embed[token_id]
    
    print(f"PyTorch first 10: {pt_vec[:10]}")
    print(f"GGUF first 10: {gguf_vec[:10]}")
    
    # Calculate difference
    diff = np.abs(pt_vec - gguf_vec)
    max_diff = np.max(diff)
    mean_diff = np.mean(diff)
    
    print(f"\nMax difference: {max_diff:.6e}")
    print(f"Mean difference: {mean_diff:.6e}")
    
    # Check if they're close (allowing for floating point precision)
    if max_diff < 1e-5:
        print("✓ Embeddings match!")
        return True
    else:
        print("✗ Embeddings differ!")
        return False

def compare_encoder_weights(pytorch_model, gguf_tensors):
    """Compare encoder layer weights."""
    print("\n=== Comparing Encoder Weights ===")
    
    # Compare first layer attention query weights
    layer_idx = 0
    pt_q_weight = pytorch_model.encoder.block[layer_idx].layer[0].SelfAttention.q.weight.detach().numpy()
    gguf_q_weight = gguf_tensors.get(f'enc.blk.{layer_idx}.attn_q.weight')
    
    if gguf_q_weight is None:
        print(f"ERROR: Could not find enc.blk.{layer_idx}.attn_q.weight in GGUF")
        return False
    
    print(f"PyTorch Q weight shape: {pt_q_weight.shape}")
    print(f"GGUF Q weight shape: {gguf_q_weight.shape}")
    
    # GGUF may transpose the weights
    if pt_q_weight.shape != gguf_q_weight.shape:
        if pt_q_weight.shape == gguf_q_weight.T.shape:
            print("Transposing GGUF weight to match PyTorch...")
            gguf_q_weight = gguf_q_weight.T
        else:
            print("ERROR: Shape mismatch!")
            return False
    
    print(f"PyTorch first 5x5:\n{pt_q_weight[:5, :5]}")
    print(f"GGUF first 5x5:\n{gguf_q_weight[:5, :5]}")
    
    diff = np.abs(pt_q_weight - gguf_q_weight)
    max_diff = np.max(diff)
    mean_diff = np.mean(diff)
    
    print(f"\nMax difference: {max_diff:.6e}")
    print(f"Mean difference: {mean_diff:.6e}")
    
    if max_diff < 1e-5:
        print("✓ Weights match!")
        return True
    else:
        print("✗ Weights differ!")
        return False

def main():
    model_path = "./models/umt5-xl"
    gguf_path = "./google-umt5-xl-f32.gguf"
    test_text = "Hello, world! This is a test."
    
    print("="*70)
    print("UMT5-XL GGUF Conversion Verification Test")
    print("="*70)
    
    # Load models
    pytorch_model, tokenizer = load_pytorch_model(model_path)
    reader, gguf_tensors = load_gguf_model(gguf_path)
    
    # Encode text with PyTorch
    pt_output, token_ids = encode_text_pytorch(pytorch_model, tokenizer, test_text)
    
    print(f"\nToken IDs: {token_ids}")
    
    # Compare embeddings
    embeddings_match = compare_embeddings(pytorch_model, gguf_tensors, token_ids)
    
    # Compare encoder weights
    weights_match = compare_encoder_weights(pytorch_model, gguf_tensors)
    
    # Summary
    print("\n" + "="*70)
    print("SUMMARY")
    print("="*70)
    print(f"Embeddings match: {'✓ YES' if embeddings_match else '✗ NO'}")
    print(f"Encoder weights match: {'✓ YES' if weights_match else '✗ NO'}")
    
    if embeddings_match and weights_match:
        print("\n✓ SUCCESS: GGUF conversion verified!")
        return 0
    else:
        print("\n✗ FAILURE: GGUF conversion has issues!")
        return 1

if __name__ == "__main__":
    sys.exit(main())

Register UMT5Model as a supported architecture variant for T5 model
conversion. This allows converting models like google/umt5-xl that use
the UMT5Model architecture class instead of UMT5ForConditionalGeneration.