
Failure converting pretrained LitGPT checkpoints to HF format: a reproducible example #1871

Open
@2533245542


Bug description

Following the example below raises an `OSError` when the converted checkpoint is loaded with Hugging Face `transformers`.

```shell
## 1. pretrain the model in litgpt format
mkdir -p custom_texts
curl https://www.gutenberg.org/cache/epub/24440/pg24440.txt --output custom_texts/book1.txt
curl https://www.gutenberg.org/cache/epub/26393/pg26393.txt --output custom_texts/book2.txt

litgpt download EleutherAI/pythia-14m --tokenizer_only True

litgpt pretrain EleutherAI/pythia-14m \
  --tokenizer_dir EleutherAI/pythia-14m \
  --data TextFiles \
  --data.train_data_path "custom_texts/" \
  --train.max_tokens 10_000_000 \
  --out_dir out/custom-model

## 2. convert to hf format
litgpt convert_from_litgpt out/custom-model/final converted_dir
cp out/custom-model/final/config.json converted_dir
```

```python
## 3. load the model in python
import torch
from transformers import AutoModel

model_pth_path = 'converted_dir/model.pth'
model_dir_path = 'converted_dir/'
state_dict = torch.load(model_pth_path)
model = AutoModel.from_pretrained(model_dir_path, state_dict=state_dict, local_files_only=True)
```

```
OSError: Error no file named pytorch_model.bin, tf_model.h5, model.ckpt.index or flax_model.msgpack found in directory converted_dir/.
```
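The error says `from_pretrained` looks for one of the listed weight files in the directory, while `convert_from_litgpt` writes `model.pth`, so `converted_dir/` holds `model.pth` (plus the copied `config.json`) but none of the expected names. One possible workaround, sketched below, is to copy `model.pth` to `pytorch_model.bin`. This assumes `model.pth` is a plain PyTorch state dict whose keys already use the Hugging Face layout, which the working `state_dict=` approach in the Edit suggests. The helpers `find_weights_file` and `make_hf_loadable` are hypothetical (not part of LitGPT or `transformers`), and this has not been verified against this exact setup.

```python
import shutil
import tempfile
from pathlib import Path

# Weight filenames listed in the OSError above (assumption: newer transformers
# versions may accept additional names, e.g. model.safetensors).
EXPECTED_WEIGHT_FILES = (
    "pytorch_model.bin",
    "tf_model.h5",
    "model.ckpt.index",
    "flax_model.msgpack",
)


def find_weights_file(model_dir):
    """Return the first expected weight filename present in model_dir, or None."""
    for name in EXPECTED_WEIGHT_FILES:
        if (Path(model_dir) / name).is_file():
            return name
    return None


def make_hf_loadable(model_dir):
    """Hypothetical helper: copy model.pth to pytorch_model.bin if needed.

    Assumes model.pth is a plain state dict with Hugging Face key names.
    Returns the path of the weight file from_pretrained() should pick up.
    """
    model_dir = Path(model_dir)
    existing = find_weights_file(model_dir)
    if existing is not None:
        return model_dir / existing  # already loadable, nothing to do
    src = model_dir / "model.pth"
    if not src.is_file():
        raise FileNotFoundError(f"no model.pth in {model_dir}")
    dst = model_dir / "pytorch_model.bin"
    shutil.copyfile(src, dst)  # keep model.pth; only add the expected name
    return dst


if __name__ == "__main__":
    # Demo on a throwaway directory that mimics converted_dir/.
    with tempfile.TemporaryDirectory() as tmp:
        Path(tmp, "config.json").write_text("{}")
        Path(tmp, "model.pth").write_bytes(b"placeholder")
        print(find_weights_file(tmp))      # None: the condition behind the OSError
        print(make_hf_loadable(tmp).name)  # pytorch_model.bin
        print(find_weights_file(tmp))      # pytorch_model.bin
```

After running `make_hf_loadable("converted_dir/")`, the call `AutoModel.from_pretrained('converted_dir/', local_files_only=True)` should find `pytorch_model.bin` and read the weights from it, without needing a separate `state_dict` argument.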

Edit:

These two approaches do work, though.

Passing the converted weights to the Hub model via `state_dict`:

```python
import torch
from transformers import AutoModel

model_pth_path = 'converted_dir/model.pth'

state_dict = torch.load(model_pth_path)
model = AutoModel.from_pretrained('EleutherAI/pythia-14m', state_dict=state_dict)

# Sanity check: the loaded embedding should match the converted checkpoint...
print(model.embed_in.weight[0][:2])
print(state_dict['gpt_neox.embed_in.weight'][0][:2])
# ...and differ from Hugging Face's original weights, confirming the checkpoint was actually loaded
print(AutoModel.from_pretrained('EleutherAI/pythia-14m').embed_in.weight[0][:2])
```
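The comparison works across two different key layouts: the converted `state_dict` follows the `GPTNeoXForCausalLM` naming (`gpt_neox.embed_in.weight`), while `AutoModel` returns the bare `GPTNeoXModel`, whose attribute is simply `embed_in`; `from_pretrained` reconciles the `gpt_neox.` prefix internally. A minimal plain-Python sketch of that mapping follows. `to_base_model_keys` is a hypothetical illustration, not a `transformers` function, and the key list is a small example rather than the full checkpoint.

```python
def to_base_model_keys(state_dict_keys, prefix="gpt_neox."):
    """Hypothetical helper: map CausalLM-style keys to base-model keys.

    Keys starting with `prefix` lose it ('gpt_neox.embed_in.weight' ->
    'embed_in.weight'). Keys without it, such as the LM head
    'embed_out.weight', have no counterpart in the base model and are
    returned separately.
    """
    mapped = {}
    unmatched = []
    for key in state_dict_keys:
        if key.startswith(prefix):
            mapped[key] = key[len(prefix):]
        else:
            unmatched.append(key)
    return mapped, unmatched


# Example keys only; a real checkpoint has many more.
keys = ["gpt_neox.embed_in.weight", "gpt_neox.final_layer_norm.weight", "embed_out.weight"]
mapped, unmatched = to_base_model_keys(keys)
print(mapped["gpt_neox.embed_in.weight"])  # embed_in.weight
print(unmatched)                           # ['embed_out.weight']
```

This is also why the second approach uses `AutoModelForCausalLM`: its parameter names match the converted keys directly, so `load_state_dict` needs no prefix handling.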
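Building the model from the copied `config.json` and loading the state dict directly:

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

model_pth_path = 'converted_dir/model.pth'
model_dir_path = 'converted_dir/'

state_dict = torch.load(model_pth_path)
# Build an untrained model from the config.json copied into converted_dir/,
# then load the converted weights into it
config = AutoConfig.from_pretrained(model_dir_path)
model = AutoModelForCausalLM.from_config(config)
model.load_state_dict(state_dict)
```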

What operating system are you using?

macOS

LitGPT Version

Version: 0.5.3

Labels: bug, help wanted
