
Commit b0aaacd

Author: Tianyu Gao
Merge branch 'develop' into main
2 parents: 354784b + becfbf0

24 files changed: 1,701 additions and 28 deletions

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -1,5 +1,6 @@
 # Byte-compiled / optimized / DLL files
 __pycache__/
+simcse/__pycache__/
 *.py[cod]
 *$py.class

@@ -128,3 +129,4 @@ dmypy.json
 # Pyre type checker
 .pyre/
 .DS_Store
+.vscode

README.md

Lines changed: 81 additions & 24 deletions
@@ -11,18 +11,22 @@ Wait a minute! The authors are working day and night 💪, to make the code and
 We anticipate the code will be out * **in one week** *. -->
 
 <!-- * 4/26: SimCSE is now on [Gradio Web Demo](https://gradio.app/g/AK391/SimCSE) (Thanks [@AK391](https://github.com/AK391)!). Try it out! -->
+* 5/10: We released our [sentence embedding tool](#getting-started) and [demo code](./demo).
 * 4/23: We released our [training code](#training).
 * 4/20: We released our [model checkpoints](#use-our-models-out-of-the-box) and [evaluation code](#evaluation).
 * 4/18: We released [our paper](https://arxiv.org/pdf/2104.08821.pdf). Check it out!
 
 
-## Quick links
+## Quick Links
 
 - [Overview](#overview)
-- [Pre-trained sentence embeddings](#use-our-models-out-of-the-box)
-- [Requirements](#requirements)
-- [Evaluation](#evaluation)
-- [Training](#training)
+- [Getting Started](#getting-started)
+- [Model List](#model-list)
+- [Use SimCSE with Huggingface](#use-our-models-out-of-the-box)
+- [Train SimCSE](#train-simcse)
+- [Requirements](#requirements)
+- [Evaluation](#evaluation)
+- [Training](#training)
 - [Bugs or Questions?](#Bugs-or-questions)
 - [Citation](#citation)
 - [SimCSE Elsewhere](#simcse-elsewhere)
@@ -33,22 +37,71 @@ We propose a simple contrastive learning framework that works with both unlabele
 
 ![](figure/model.png)
 
-## Use our models out of the box
-Our pre-trained models are now publicly available with [HuggingFace's Transformers](https://github.com/huggingface/transformers). Models and their performance are presented as follows:
+## Getting Started
+
+We provide an easy-to-use sentence embedding tool based on our SimCSE model. To use the tool, first install the `simcse` package from PyPI:
+```bash
+pip install simcse
+```
+
+Or install it directly from our source code:
+```bash
+python setup.py install
+```
+
+Note that if you want to enable GPU encoding, you should install the correct version of PyTorch that supports CUDA. See the [PyTorch official website](https://pytorch.org) for instructions.
+
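As a quick sanity check that the installed PyTorch build can actually see a GPU before relying on GPU encoding, a minimal sketch (a one-liner, assuming `python` resolves to the environment where `torch` is installed):

```bash
# Prints True when a CUDA-enabled PyTorch build and a visible GPU are available.
python -c "import torch; print(torch.cuda.is_available())"
```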
+After installing the package, you can load our model with just two lines of code:
+```python
+from simcse import SimCSE
+model = SimCSE("princeton-nlp/sup-simcse-bert-base-uncased")
+```
+See the [model list](#model-list) for a full list of available models.
+
+Then you can use our model to **encode sentences into embeddings**:
+```python
+embeddings = model.encode("A woman is reading.")
+```
+
+**Compute the cosine similarities** between two groups of sentences:
+```python
+sentences_a = ['A woman is reading.', 'A man is playing a guitar.']
+sentences_b = ['He plays guitar.', 'A woman is making a photo.']
+similarities = model.similarity(sentences_a, sentences_b)
+```
+
+Or build an index for a group of sentences and **search** among them:
+```python
+sentences = ['A woman is reading.', 'A man is playing a guitar.']
+model.build_index(sentences)
+results = model.search("He plays guitar.")
+```
+
+We also support [faiss](https://github.com/facebookresearch/faiss), an efficient similarity search library. Just install the package following the [instructions](https://github.com/facebookresearch/faiss/blob/master/INSTALL.md) and `simcse` will automatically use `faiss` for efficient search.
+
+**WARNING**: We have found that `faiss` does not support Nvidia Ampere GPUs (3090 and A100) well. In that case, you should switch to other GPUs or install the CPU version of the `faiss` package.
+
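If you hit the Ampere issue above, one way to fall back to CPU-side search is to install a CPU-only build of `faiss`. A minimal sketch, assuming the `faiss-cpu` wheel on PyPI suits your environment (conda packages are an alternative):

```bash
# CPU-only faiss build; simcse picks it up automatically for search if it is importable.
pip install faiss-cpu
```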
+We also provide an easy-to-build [demo website](./demo) to show how SimCSE can be used in sentence retrieval.
+
+## Model List
+
+Our released models are listed as follows. You can import these models using the `simcse` package or [HuggingFace's Transformers](https://github.com/huggingface/transformers).
 | Model | Avg. STS |
-|:-------------------------------:|:--------:|
-| [unsup-simcse-bert-base-uncased](https://huggingface.co/princeton-nlp/unsup-simcse-bert-base-uncased) | 74.54 |
-| [unsup-simcse-bert-large-uncased](https://huggingface.co/princeton-nlp/unsup-simcse-bert-large-uncased) | 76.05 |
-| [unsup-simcse-roberta-base](https://huggingface.co/princeton-nlp/unsup-simcse-roberta-base) | 76.50 |
-| [unsup-simcse-roberta-large](https://huggingface.co/princeton-nlp/unsup-simcse-roberta-large) | 77.47 |
-| [sup-simcse-bert-base-uncased](https://huggingface.co/princeton-nlp/sup-simcse-bert-base-uncased) | 81.57 |
-| [sup-simcse-bert-large-uncased](https://huggingface.co/princeton-nlp/sup-simcse-bert-large-uncased) | 82.21 |
-| [sup-simcse-roberta-base](https://huggingface.co/princeton-nlp/sup-simcse-roberta-base) | 82.52 |
-| [sup-simcse-roberta-large](https://huggingface.co/princeton-nlp/sup-simcse-roberta-large) | 83.76 |
+|:-------------------------------|:--------:|
+| [princeton-nlp/unsup-simcse-bert-base-uncased](https://huggingface.co/princeton-nlp/unsup-simcse-bert-base-uncased) | 74.54 |
+| [princeton-nlp/unsup-simcse-bert-large-uncased](https://huggingface.co/princeton-nlp/unsup-simcse-bert-large-uncased) | 76.05 |
+| [princeton-nlp/unsup-simcse-roberta-base](https://huggingface.co/princeton-nlp/unsup-simcse-roberta-base) | 76.50 |
+| [princeton-nlp/unsup-simcse-roberta-large](https://huggingface.co/princeton-nlp/unsup-simcse-roberta-large) | 77.47 |
+| [princeton-nlp/sup-simcse-bert-base-uncased](https://huggingface.co/princeton-nlp/sup-simcse-bert-base-uncased) | 81.57 |
+| [princeton-nlp/sup-simcse-bert-large-uncased](https://huggingface.co/princeton-nlp/sup-simcse-bert-large-uncased) | 82.21 |
+| [princeton-nlp/sup-simcse-roberta-base](https://huggingface.co/princeton-nlp/sup-simcse-roberta-base) | 82.52 |
+| [princeton-nlp/sup-simcse-roberta-large](https://huggingface.co/princeton-nlp/sup-simcse-roberta-large) | 83.76 |
 
 **Naming rules**: `unsup` and `sup` represent "unsupervised" (trained on Wikipedia corpus) and "supervised" (trained on NLI datasets) respectively.
 
-You can easily import our model in an out-of-the-box way with HuggingFace's API:
+## Use SimCSE with Huggingface
+
+Besides using our provided sentence embedding tool, you can also easily import our models with HuggingFace's `transformers`:
 ```python
 import torch
 from scipy.spatial.distance import cosine
@@ -81,9 +134,11 @@ print("Cosine similarity between \"%s\" and \"%s\" is: %.3f" % (texts[0], texts[
 
 If you encounter any problems when directly loading the models with HuggingFace's API, you can also download the models manually from the above table and use `model = AutoModel.from_pretrained({PATH TO THE DOWNLOADED MODEL})`.
 
-If you only want to use our models in an out-of-the-box way, just installing the latest version of `torch`, `transformers` and `scipy` is enough. If you want to use our training or evaluation code, see the requirement section below.
+## Train SimCSE
 
-## Requirements
+In the following sections, we describe how to train a SimCSE model using our code.
+
+### Requirements
 
 First, install PyTorch by following the instructions from [the official website](https://pytorch.org). To faithfully reproduce our results, please use the correct `1.7.1` version corresponding to your platform/CUDA version. PyTorch versions higher than `1.7.1` should also work. For example, if you use Linux and **CUDA11** ([how to check CUDA version](https://varhowto.com/check-cuda-version/)), install PyTorch by the following command,
 
@@ -104,7 +159,7 @@ Then run the following script to install the remaining dependencies,
 pip install -r requirements.txt
 ```
 
-## Evaluation
+### Evaluation
 Our evaluation code for sentence embeddings is based on a modified version of [SentEval](https://github.com/facebookresearch/SentEval). It evaluates sentence embeddings on semantic textual similarity (STS) tasks and downstream transfer tasks. For STS tasks, our evaluation takes the "all" setting, and reports Spearman's correlation. See [our paper](https://arxiv.org/pdf/2104.08821.pdf) (Appendix B) for evaluation details.
 
 Before evaluation, please download the evaluation datasets by running
@@ -151,13 +206,13 @@ Arguments for the evaluation script are as follows,
 * `na`: Manually set tasks by `--tasks`.
 * `--tasks`: Specify which dataset(s) to evaluate on. Will be overridden if `--task_set` is not `na`. See the code for a full list of tasks.
 
-## Training
+### Training
 
-### Data
+#### Data
 
 For unsupervised SimCSE, we sample 1 million sentences from English Wikipedia; for supervised SimCSE, we use the SNLI and MNLI datasets. You can run `data/download_wiki.sh` and `data/download_nli.sh` to download the two datasets.
 
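A minimal sketch of fetching both corpora, assuming the scripts are run from the repository root and place the data under `data/`:

```bash
# Download the 1M Wikipedia sentences and the NLI (SNLI+MNLI) training data.
bash data/download_wiki.sh
bash data/download_nli.sh
```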

-### Training scripts
+#### Training scripts
 
 We provide example training scripts for both unsupervised and supervised SimCSE. In `run_unsup_example.sh`, we provide a single-GPU (or CPU) example for the unsupervised version, and in `run_sup_example.sh` we give a **multi-GPU** example for the supervised version. Both scripts call `train.py` for training. We explain the arguments as follows:
 * `--train_file`: Training file path. We support "txt" files (one line for one sentence) and "csv" files (2-column: pair data with no hard negative; 3-column: pair data with one corresponding hard negative instance). You can use our provided Wikipedia or NLI data, or you can use your own data with the same format.
@@ -173,10 +228,12 @@ All the other arguments are standard Huggingface's `transformers` training argum
 
 **REPRODUCTION**: For results in the paper, we use Nvidia 3090 GPUs with CUDA 11. Using different types of devices or different versions of CUDA/other software may lead to slightly different performance.
 
-### Convert models
+#### Convert models
 
 **IMPORTANT**: Our saved checkpoints are slightly different from Huggingface's pre-trained checkpoints. Run `python simcse_to_huggingface.py --path {PATH_TO_CHECKPOINT_FOLDER}` to convert it. After that, you can evaluate it with our [evaluation](#evaluation) code or directly use it [out of the box](#use-our-models-out-of-the-box).
 
+
+
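As a minimal sketch of the conversion step, with a hypothetical checkpoint directory standing in for `{PATH_TO_CHECKPOINT_FOLDER}`:

```bash
# "result/my-sup-simcse" is a hypothetical training output directory; use your own path.
python simcse_to_huggingface.py --path result/my-sup-simcse/
```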
 ## Bugs or questions?
 
 If you have any questions related to the code or the paper, feel free to email Tianyu (`[email protected]`) and Xingcheng (`[email protected]`). If you encounter any problems when using the code, or want to report a bug, you can open an issue. Please try to specify the problem with details so we can help you better and more quickly!

demo/README.md

Lines changed: 18 additions & 0 deletions
## Demo of SimCSE
Several demos are available for people to play with our pre-trained SimCSE.

### Flask Demo
<div align="center">
<img src="../figure/demo.gif" width="750">
</div>

We provide a simple web demo based on [flask](https://github.com/pallets/flask) to show how SimCSE can be directly used for information retrieval. To run this flask demo locally, make sure the SimCSE inference interfaces are set up:
```bash
git clone https://github.com/princeton-nlp/SimCSE
cd SimCSE
python setup.py develop
```
Then you can use `run_demo_example.sh` to launch the demo. By default, we build the index for 1,000 sentences sampled from the STS-B dataset. Feel free to build an index of your own corpora. You can also install [faiss](https://github.com/facebookresearch/faiss) to speed up the retrieval process.

### Gradio Demo
[AK391](https://github.com/AK391) has provided a [Gradio Web Demo](https://gradio.app/g/AK391/SimCSE) of SimCSE to show how the pre-trained models can predict the semantic similarity between two sentences.

demo/flaskdemo.py

Lines changed: 84 additions & 0 deletions
```python
import json
import argparse
import torch
import os
import random
import numpy as np
import requests
import logging
import math
import copy
import string

from tqdm import tqdm
from time import time
from flask import Flask, request, jsonify
from flask_cors import CORS
from tornado.wsgi import WSGIContainer
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop

from simcse import SimCSE

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)


def run_simcse_demo(port, args):
    # Flask app serving the static front end and the retrieval API.
    app = Flask(__name__, static_folder='./static')
    app.config['JSONIFY_PRETTYPRINT_REGULAR'] = False
    CORS(app)

    # Build the sentence index once at startup from the example sentence file.
    sentence_path = os.path.join(args.sentences_dir, args.example_sentences)
    query_path = os.path.join(args.sentences_dir, args.example_query)
    embedder = SimCSE(args.model_name_or_path)
    embedder.build_index(sentence_path)

    @app.route('/')
    def index():
        return app.send_static_file('index.html')

    @app.route('/api', methods=['GET'])
    def api():
        # Retrieve the top-k indexed sentences above the similarity threshold for a query.
        query = request.args['query']
        top_k = int(request.args['topk'])
        threshold = float(request.args['threshold'])
        start = time()
        results = embedder.search(query, top_k=top_k, threshold=threshold)
        ret = []
        out = {}
        for sentence, score in results:
            ret.append({"sentence": sentence, "score": score})
        span = time() - start
        out['ret'] = ret
        out['time'] = "{:.4f}".format(span)
        return jsonify(out)

    @app.route('/files/<path:path>')
    def static_files(path):
        return app.send_static_file('files/' + path)

    @app.route('/get_examples', methods=['GET'])
    def get_examples():
        # Return the example queries shown on the demo page.
        with open(query_path, 'r') as fp:
            examples = [line.strip() for line in fp.readlines()]
        return jsonify(examples)

    addr = args.ip + ":" + args.port
    logger.info(f'Starting Index server at {addr}')
    # Serve the Flask app through tornado's HTTP server.
    http_server = HTTPServer(WSGIContainer(app))
    http_server.listen(port)
    IOLoop.instance().start()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path', default=None, type=str)
    parser.add_argument('--device', default='cpu', type=str)
    parser.add_argument('--sentences_dir', default=None, type=str)
    parser.add_argument('--example_query', default=None, type=str)
    parser.add_argument('--example_sentences', default=None, type=str)
    parser.add_argument('--port', default='8888', type=str)
    parser.add_argument('--ip', default='http://127.0.0.1')
    parser.add_argument('--load_light', default=False, action='store_true')
    args = parser.parse_args()

    run_simcse_demo(args.port, args)
```
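Once the demo is running, the `/api` route above takes `query`, `topk`, and `threshold` as GET parameters and returns JSON with the retrieved sentences and the query time. A minimal sketch of querying it from the command line, assuming the default port `8888` and illustrative parameter values:

```bash
# Ask the running demo for the 5 most similar indexed sentences with similarity above 0.6.
curl "http://127.0.0.1:8888/api?query=He+plays+guitar&topk=5&threshold=0.6"
```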
File renamed without changes.

demo/run_demo_example.sh

Lines changed: 9 additions & 0 deletions
```bash
#!/bin/bash

# This example shows how to run the flask demo of SimCSE

python flaskdemo.py \
    --model_name_or_path princeton-nlp/sup-simcse-bert-base-uncased \
    --sentences_dir ./static/ \
    --example_query example_query.txt \
    --example_sentences example_sentence.txt
```

demo/static/example_query.txt

Lines changed: 3 additions & 0 deletions
```text
a man is playing music
a woman is making a photo
a woman is taking some food
```
