Our method improves embedding performance while retaining the model's generative performance.
If you extend or use this work, please cite the paper where it was introduced:
@misc{sun2025gracegenerativerepresentationlearning,
title={GRACE: Generative Representation Learning via Contrastive Policy Optimization},
author={Jiashuo Sun and Shixuan Liu and Zhaochen Su and Xianrui Zhong and Pengcheng Jiang and Bowen Jin and Peiran Li and Weijia Shi and Jiawei Han},
year={2025},
eprint={2510.04506},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2510.04506},
}
- [2025-10-02]: 🚀 Our paper is available at https://arxiv.org/pdf/2510.04506.
- [2025-10-02]: 🚀 We release the code for training and evaluation.
- GRACE Overview
- Project Visualizations
- 📌 Data Processing Pipeline
- 🎯 Reinforcement Learning for Supervised Training
- 🤖 Reinforcement Learning for Unsupervised Training
- 🔍 Inference and Evaluation
- 📄 Licensing and Claims
We present GRACE (Generative Representation Learning via Contrastive Policy Optimization), a framework that turns LLMs into interpretable representation learners via policy-gradient optimization. The model first produces explicit rationales, and the final embeddings are derived from these textual outputs.
(1) We present the first empirical evidence that rewards derived from contrastive learning can be used to train policy models, yielding improved representational capabilities (see the sketch after this list).
(2) We propose a methodology that transforms existing LLMs into strong representation models while preserving their general-purpose generative capabilities.
(3) This work substantially advances the interpretability of text representations, as the model's reasoning can be inspected directly through its textual outputs.
(4) Our method yields an average gain of 11.5% over baseline models on the MTEB benchmark.
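To make contribution (1) concrete, here is a minimal sketch of how a contrastive score can act as a scalar reward for policy optimization. It assumes an InfoNCE-style objective over in-batch negatives with a temperature of 0.1 (matching +reward_model.reward_kwargs.temperature in train.sh); the actual reward manager in verl/ (reward_manager=hidden) also supports scaling, clustering, and cross-group terms and may differ in detail.

```python
# Minimal sketch: an InfoNCE-style contrastive score used as a per-rollout reward.
# query_emb / pos_emb are embeddings pooled from the policy's rollouts (assumption).
import torch
import torch.nn.functional as F

def contrastive_reward(query_emb: torch.Tensor,
                       pos_emb: torch.Tensor,
                       temperature: float = 0.1) -> torch.Tensor:
    """query_emb, pos_emb: [batch, dim]; returns one scalar reward per sample."""
    q = F.normalize(query_emb, dim=-1)
    p = F.normalize(pos_emb, dim=-1)
    logits = q @ p.T / temperature                      # [batch, batch] similarities
    labels = torch.arange(q.size(0), device=q.device)   # i-th query matches i-th positive
    # Higher reward when a query is closest to its own positive among in-batch negatives.
    return -F.cross_entropy(logits, labels, reduction="none")
```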
Install the environment with pip install -r requirements.txt. Our algorithm and pipeline are built on verl (version 0.4.0dev), so after installing the dependencies, also run pip install -e . at the project root.
- Install dependencies and the editable package:
pip install -r requirements.txt
pip install -e .
- Prepare data (both supervised and unsupervised are supported). The following script wraps the steps in this README:
bash process_data.sh
- Train (edit data.train_files, data.val_files, and +data.train_mode in train.sh if needed):
bash train.sh
- Evaluate on MTEB with vLLM (ensure MODEL_PATH in eval.sh points to the merged checkpoints):
bash eval.sh
GRACE/
assets/ # Figures used in README
scripts/ # Utilities such as model merging
verl/ # Training framework (0.4.0dev-based)
process_data.py # Convert raw data to parquet for training
offline_filter_data.py # Pre-filter overlong samples
eval_mteb.py # vLLM-based MTEB evaluation entry
train.sh # Supervised training example (GRPO)
eval.sh # Inference + evaluation example
process_data.sh # End-to-end data prep helper
requirements.txt
pyproject.toml # Packaging for editable install
README.md
You can find the data at the following links:
- Supervised data: data from Repetition Improves Language Model Embeddings
- Unsupervised data: wiki1m_for_simcse from SimCSE: Simple Contrastive Learning of Sentence Embeddings
After downloading the data, you can directly run:
#!/bin/bash
set -e # Exit on any error
echo "Processing supervised data..."
python process_data.py \
--input_file echo_data_total.jsonl \
--local_dir data/supervised \
--mode supervised \
--test_ratio 0.01
sleep 3
python offline_filter_data.py \
--train_parquet data/supervised/train.parquet \
--val_parquet data/supervised/test.parquet \
--out_dir data/supervised_filtered_overlong/ \
--tokenizer_path Qwen/Qwen2.5-1.5B-Instruct \
--max_len 1024
# echo "Processing unsupervised data..."
python process_data.py \
--input_file wiki1m_for_simcse.txt \
--local_dir data/unsupervised \
--mode unsupervised \
--test_ratio 0.01
sleep 3
python offline_filter_data.py \
--train_parquet data/unsupervised/train.parquet \
--val_parquet data/unsupervised/test.parquet \
--out_dir data/unsupervised_filtered_overlong/ \
--tokenizer_path Qwen/Qwen2.5-1.5B-Instruct \
--max_len 1024
This script is provided as process_data.sh. It first converts the raw data into a format that verl can recognize; because the dataset is large, overlong samples are filtered out in advance so that this step does not slow down training.
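For reference, the snippet below is a hypothetical sketch of what the overlong filter does: tokenize each sample and drop rows whose token count exceeds max_len. Column names and I/O details are illustrative; see offline_filter_data.py for the actual implementation.

```python
# Hypothetical sketch of offline overlong filtering on a verl-style parquet file.
# The real offline_filter_data.py may use different column names and options.
import pandas as pd
from transformers import AutoTokenizer

def filter_overlong(parquet_in: str, parquet_out: str,
                    tokenizer_path: str = "Qwen/Qwen2.5-1.5B-Instruct",
                    text_column: str = "prompt",   # illustrative column name
                    max_len: int = 1024) -> None:
    tok = AutoTokenizer.from_pretrained(tokenizer_path)
    df = pd.read_parquet(parquet_in)
    # Count tokens per sample and keep only those within the length budget.
    lengths = df[text_column].map(lambda t: len(tok(str(t))["input_ids"]))
    df[lengths <= max_len].to_parquet(parquet_out, index=False)

# filter_overlong("data/supervised/train.parquet",
#                 "data/supervised_filtered_overlong/train.parquet")
```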
To train the supervised model, run the following command:
#!/bin/bash
mkdir -p logs
export HF_DATASETS_CACHE="huggingface_cache"
export HF_HOME="huggingface_cache"
export CUDA_VISIBLE_DEVICES=0,1,2,3
export VLLM_TORCH_COMPILE_LEVEL=0
export TORCH_COMPILE_DISABLE=1
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
algorithm.norm_adv_by_std_in_grpo=False \
data.train_files=data/supervised_filtered_overlong/train.parquet \
data.val_files=data/supervised_filtered_overlong/test.parquet \
data.train_batch_size=64 \
data.val_batch_size=16 \
data.max_prompt_length=1024 \
data.max_response_length=2048 \
data.filter_overlong_prompts=False \
data.truncation='right' \
+data.train_mode=supervised \
actor_rollout_ref.model.path=Qwen/Qwen2.5-1.5B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=True \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
actor_rollout_ref.actor.fsdp_config.offload_policy=True \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.actor.ppo_mini_batch_size=16 \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=16384 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
actor_rollout_ref.rollout.enable_chunked_prefill=False \
actor_rollout_ref.rollout.enforce_eager=True \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.free_cache_engine=False \
actor_rollout_ref.rollout.n=8 \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \
algorithm.use_kl_in_reward=False \
reward_model.enable=False \
reward_model.reward_manager=hidden \
trainer.critic_warmup=0 \
+reward_model.reward_kwargs.temperature=0.1 \
+reward_model.reward_kwargs.with_scale=True \
+reward_model.reward_kwargs.clustering_weight=0.2 \
+reward_model.reward_kwargs.cross_group_weight=0.2 \
trainer.logger='["wandb"]' \
trainer.project_name='GRACE' \
trainer.experiment_name='test_exp' \
trainer.n_gpus_per_node=4 \
trainer.nnodes=1 \
trainer.save_freq=50 \
trainer.test_freq=-1 \
trainer.val_before_train=False \
trainer.total_epochs=2
This command is provided as train.sh. You need to modify data.train_files, data.val_files, and +data.train_mode before training.
🛠️ We used 4× H100-80G GPUs to train all models.
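For readers unfamiliar with the GRPO options above: with algorithm.norm_adv_by_std_in_grpo=False, each of the n=8 rollouts per prompt is baselined against the mean reward of its group, without dividing by the group's standard deviation. The sketch below is a simplified illustration of that estimator, not the verl implementation.

```python
# Simplified group-relative advantage estimate used by GRPO.
import torch

def grpo_advantages(rewards: torch.Tensor, norm_by_std: bool = False) -> torch.Tensor:
    """rewards: [num_prompts, n_rollouts]; returns advantages of the same shape."""
    baseline = rewards.mean(dim=-1, keepdim=True)    # per-prompt group mean
    adv = rewards - baseline
    if norm_by_std:                                  # disabled in train.sh
        adv = adv / (rewards.std(dim=-1, keepdim=True) + 1e-6)
    return adv
```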
To train the unsupervised model, run the following command:
#!/bin/bash
mkdir -p logs
export HF_DATASETS_CACHE="huggingface_cache"
export HF_HOME="huggingface_cache"
export CUDA_VISIBLE_DEVICES=0,1,2,3
export VLLM_TORCH_COMPILE_LEVEL=0
export TORCH_COMPILE_DISABLE=1
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
algorithm.norm_adv_by_std_in_grpo=False \
data.train_files=data/unsupervised_filtered_overlong/train.parquet \
data.val_files=data/unsupervised_filtered_overlong/test.parquet \
data.train_batch_size=64 \
data.val_batch_size=16 \
data.max_prompt_length=1024 \
data.max_response_length=2048 \
data.filter_overlong_prompts=False \
data.truncation='right' \
+data.train_mode=unsupervised \
actor_rollout_ref.model.path=Qwen/Qwen2.5-1.5B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=True \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
actor_rollout_ref.actor.fsdp_config.offload_policy=True \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.actor.ppo_mini_batch_size=16 \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=16384 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
actor_rollout_ref.rollout.enable_chunked_prefill=False \
actor_rollout_ref.rollout.enforce_eager=True \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.free_cache_engine=False \
actor_rollout_ref.rollout.n=8 \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \
algorithm.use_kl_in_reward=False \
reward_model.enable=False \
reward_model.reward_manager=hidden \
trainer.critic_warmup=0 \
+reward_model.reward_kwargs.temperature=0.1 \
+reward_model.reward_kwargs.with_scale=True \
+reward_model.reward_kwargs.clustering_weight=0.2 \
+reward_model.reward_kwargs.cross_group_weight=0.2 \
trainer.logger='["wandb"]' \
trainer.project_name='GRACE' \
trainer.experiment_name='test_exp' \
trainer.n_gpus_per_node=4 \
trainer.nnodes=1 \
trainer.save_freq=50 \
trainer.test_freq=-1 \
trainer.val_before_train=False \
trainer.total_epochs=2
Again, remember to modify data.train_files, data.val_files, and +data.train_mode before training.
We use vLLM for efficient inference and the MTEB benchmark for evaluation.
#!/bin/bash
# Environment variables
export CUDA_VISIBLE_DEVICES=0,1,2,3
export VLLM_TORCH_COMPILE_LEVEL=0
export TORCH_COMPILE_DISABLE=1
export VLLM_USE_MODELSCOPE=false
export VLLM_WORKER_MULTIPROC_METHOD=spawn
export VLLM_ENGINE_ITERATION_TIMEOUT_S=1800
export TOKENIZERS_PARALLELISM=false
export INPUT_MAX_LENGTH=1024
export THINKING_MAX_LENGTH=2048
export POOLING_METHOD=mean
export VLLM_TENSOR_PARALLEL_SIZE=4
export VLLM_GPU_MEMORY_UTILIZATION=0.4
export HF_DATASETS_CACHE="huggingface_cache"
export HF_HOME="huggingface_cache"
export TRAIN_MODEL_NAME=checkpoints/
export TARGET_DIR=merge_models/
# Create target directory if it doesn't exist
mkdir -p "$TARGET_DIR"
echo "Starting model merge..."
# Merge model
python scripts/merge_model.py merge --backend fsdp \
--hf_model_path Qwen/Qwen2.5-1.5B-Instruct \
--local_dir "$TRAIN_MODEL_NAME" \
--target_dir "$TARGET_DIR"
# echo "Model merge completed. Starting evaluation..."
# Task configuration by category
# Retrieval (15 tasks)
RETRIEVAL_TASKS="ArguAna CQADupstackRetrieval FiQA2018 NFCorpus SCIDOCS SciFact ClimateFEVER DBPedia NQ FEVER HotpotQA MSMARCO QuoraRetrieval TRECCOVID Touche2020"
# Reranking (4 tasks)
RERANKING_TASKS="AskUbuntuDupQuestions MindSmallReranking SciDocsRR StackOverflowDupQuestions"
# Clustering (11 tasks)
CLUSTERING_TASKS="ArxivClusteringP2P ArxivClusteringS2S BiorxivClusteringP2P BiorxivClusteringS2S MedrxivClusteringP2P MedrxivClusteringS2S RedditClustering RedditClusteringP2P StackExchangeClustering StackExchangeClusteringP2P TwentyNewsgroupsClustering"
# Pair Classification (3 tasks)
PAIR_CLASSIFICATION_TASKS="SprintDuplicateQuestions TwitterSemEval2015 TwitterURLCorpus"
# Classification (12 tasks)
CLASSIFICATION_TASKS="AmazonCounterfactualClassification AmazonPolarityClassification AmazonReviewsClassification Banking77Classification EmotionClassification ImdbClassification ToxicConversationsClassification MTOPDomainClassification MassiveIntentClassification MassiveScenarioClassification MTOPIntentClassification TweetSentimentClassification"
# STS - Semantic Textual Similarity (10 tasks)
STS_TASKS="BIOSSES SICK-R STS12 STS13 STS14 STS15 STS16 STS17 STSBenchmark STS22"
# Summarization (1 task)
SUMMARIZATION_TASKS="SummEval"
# All tasks combined
ALL_TASKS="$STS_TASKS $SUMMARIZATION_TASKS $RETRIEVAL_TASKS $RERANKING_TASKS $CLUSTERING_TASKS $PAIR_CLASSIFICATION_TASKS $CLASSIFICATION_TASKS"
# Subset tasks
SUBSET_TASKS="Banking77Classification EmotionClassification MedrxivClusteringS2S TwitterSemEval2015 AskUbuntuDupQuestions BIOSSES STS17 STSBenchmark SummEval"
# Choose which tasks to evaluate (modify this line as needed)
EVAL_TASKS=$ALL_TASKS # or $SUBSET_TASKS
# Model configuration
MODEL_NAME="GRACE-1.5B"
MODEL_PATH="$TARGET_DIR"
# Initialize timing log file
TIMING_LOG="timing_results_$(date +%Y%m%d_%H%M%S).txt"
echo "Task Timing Report - $(date)" > $TIMING_LOG
echo "=================================" >> $TIMING_LOG
# Overall start time
OVERALL_START=$(date +%s)
echo "Starting evaluation..."
python eval_mteb.py \
--model_path "$MODEL_PATH" \
--tasks $EVAL_TASKS \
--output_dir results \
--batch_size 16 \
--input_max_length $INPUT_MAX_LENGTH \
--thinking_max_length $THINKING_MAX_LENGTH \
--vllm_tensor_parallel_size $VLLM_TENSOR_PARALLEL_SIZE \
--vllm_gpu_memory_utilization $VLLM_GPU_MEMORY_UTILIZATION \
--model_name "$MODEL_NAME" \
--pooling_method $POOLING_METHOD
# Overall end time
OVERALL_END=$(date +%s)
OVERALL_DURATION=$((OVERALL_END - OVERALL_START))
OVERALL_MINUTES=$(echo "scale=2; $OVERALL_DURATION / 60" | bc)
echo "=================================" >> $TIMING_LOG
echo "Total evaluation time: $OVERALL_MINUTES minutes" >> $TIMING_LOG
echo "All tasks: $EVAL_TASKS" >> $TIMING_LOG
echo "=================================" >> $TIMING_LOG
# Print timing report
echo ""
echo "====== TIMING REPORT ======"
cat $TIMING_LOG
echo "==========================="
echo "Evaluation completed! All tasks processed with single model load."
This script is provided as eval.sh. Standalone inference code will be released soon.
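Until the standalone inference code is released, the snippet below is a hypothetical sketch of the two-stage embedding procedure suggested by the INPUT_MAX_LENGTH, THINKING_MAX_LENGTH, and POOLING_METHOD settings: generate an explicit rationale for the input, then mean-pool the final hidden states over the input plus rationale. eval_mteb.py uses vLLM, and its exact interface may differ.

```python
# Hypothetical sketch: rationale generation followed by mean pooling of hidden states.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_DIR = "merge_models/"  # merged checkpoint directory (TARGET_DIR in eval.sh)
tok = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForCausalLM.from_pretrained(MODEL_DIR, torch_dtype=torch.bfloat16)

def embed(text: str, thinking_max_length: int = 2048) -> torch.Tensor:
    inputs = tok(text, return_tensors="pt", truncation=True, max_length=1024)
    # Stage 1: generate an explicit rationale for the input text.
    gen = model.generate(**inputs, max_new_tokens=thinking_max_length, do_sample=False)
    # Stage 2: re-encode input + rationale and mean-pool the last hidden states.
    with torch.no_grad():
        out = model(gen, output_hidden_states=True)
    return out.hidden_states[-1].mean(dim=1).squeeze(0)   # [hidden_dim]
```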
Explore the key components and performance of GRACE through the following images:
- Pipeline of GRACE:
- Results for Supervised training:
- Results for Unsupervised training:
- Case Study:
- FlashAttention build errors: ensure CUDA/torch version compatibility and set TORCH_CUDA_ARCH_LIST; refer to the flash-attn docs if compilation fails.
- vLLM OOM/timeout: lower VLLM_GPU_MEMORY_UTILIZATION, reduce VLLM_TENSOR_PARALLEL_SIZE, or shorten THINKING_MAX_LENGTH.
- PyArrow/Parquet issues: this repo expects pyarrow>=19.0.0. Reinstall with pip install -U pyarrow if parquet IO fails.
- MTEB version mismatch: we pin mteb==1.38.33. Align the version to reproduce scores.
If you find GRACE useful in your research, please cite the paper using the BibTeX entry above.
This project builds upon excellent open-source efforts, including vLLM and verl.
This project is licensed under the Apache 2.0 License. The project assumes no legal responsibility for any output generated by the models and will not be held liable for any damages resulting from the use of the provided resources and outputs.