Skip to content

Commit d5b24ec

Browse files
authored
add reranker (labring#679)
1 parent 2e75851 commit d5b24ec

File tree

4 files changed

+171
-0
lines changed

4 files changed

+171
-0
lines changed
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Base image: PyTorch 2.0.1 runtime with CUDA 11.7 and cuDNN 8.
FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime

# please download the model from https://huggingface.co/BAAI/bge-reranker-base and put it in the same directory as Dockerfile
COPY ./bge-reranker-base ./bge-reranker-base

# With more than one source, the destination must be a directory ending in "/".
COPY app.py Dockerfile requirement.txt ./

RUN python3 -m pip install -r requirement.txt

# Exec form: python runs as PID 1 (no wrapping shell), so it receives the
# stop signal sent by `docker stop` directly.
ENTRYPOINT ["python3", "app.py"]
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
2+
## 推荐配置
3+
4+
推荐配置如下:
5+
6+
{{< table "table-hover table-striped-columns" >}}
7+
| 类型 | 内存 | 显存 | 硬盘空间 | 启动命令 |
8+
|------|---------|---------|----------|--------------------------|
9+
| base | >=4GB | >=3GB | >=8GB | python app.py |
10+
{{< /table >}}
11+
12+
## 部署
13+
14+
### 环境要求
15+
16+
- Python 3.10.11
17+
- CUDA 11.7
18+
- 科学上网环境
19+
20+
### 源码部署
21+
22+
1. 根据上面的环境配置配置好环境,具体教程自行 GPT;
23+
2. 下载 [python 文件](app.py)
24+
3. 在命令行输入命令 `pip install -r requirement.txt`
25+
4. 按照[https://huggingface.co/BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base)下载模型仓库到app.py同级目录
26+
5. 添加环境变量 `export ACCESS_TOKEN=XXXXXX` 配置 token,这里的 token 只是加一层验证,防止接口被人盗用,默认值为 `ACCESS_TOKEN`
27+
6. 执行命令 `python app.py`
28+
29+
然后等待模型下载,直到模型加载完毕为止。如果出现报错先问 GPT。
30+
31+
启动成功后应该会显示如下地址:
32+
33+
![](/imgs/chatglm2.png)
34+
35+
> 这里的 `http://0.0.0.0:6006` 就是连接地址。
36+
37+
### docker 部署
38+
39+
**镜像和端口**
40+
41+
+ 镜像名: `luanshaotong/reranker:v0.1`
42+
+ 端口号: 6006
43+
44+
```
45+
# 设置安全凭证(即oneapi中的渠道密钥)
46+
通过环境变量ACCESS_TOKEN引入,默认值:ACCESS_TOKEN。
47+
有关docker环境变量引入的方法请自寻教程,此处不再赘述。
48+
```
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
"""
4+
@Time: 2023/11/7 22:45
5+
@Author: zhidong
6+
@File: reranker.py
7+
@Desc:
8+
"""
9+
import os
10+
import numpy as np
11+
import logging
12+
import uvicorn
13+
import datetime
14+
from fastapi import FastAPI, Security, HTTPException
15+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
16+
from FlagEmbedding import FlagReranker
17+
from pydantic import Field, BaseModel, validator
18+
from typing import Optional, List
19+
20+
def response(code, msg, data=None):
    """Wrap a result in the service's standard response envelope.

    Args:
        code: Status code stored under ``"code"``.
        msg: Message stored under ``"message"``.
        data: Payload stored under ``"data"``; ``None`` becomes a new empty
            list.

    Returns:
        A dict with the keys ``code``, ``message``, ``data`` and ``time``,
        where ``time`` is ``str(datetime.datetime.now())`` taken at call time.
    """
    timestamp = str(datetime.datetime.now())
    payload = [] if data is None else data
    return {
        "code": code,
        "message": msg,
        "data": payload,
        "time": timestamp,
    }
31+
32+
def success(data=None, msg=''):
    """Placeholder helper: ignores both arguments and returns ``None``.

    NOTE(review): the name and signature suggest this was meant to build a
    success envelope from ``data`` and ``msg`` -- confirm the intent before
    relying on it.
    """
    return None
34+
35+
36+
class Inputs(BaseModel):
    """One candidate document submitted for reranking."""

    # Caller-supplied identifier; it is returned next to the document's score.
    id: str
    # Document text that is paired with the query for scoring.
    text: Optional[str]
39+
40+
41+
class QADocs(BaseModel):
    """Request body of the rerank endpoint: a query plus candidate documents."""

    # Query text every candidate document is scored against.
    query: Optional[str]
    # Candidate documents to score and order.
    inputs: Optional[List[Inputs]]
44+
45+
46+
class Singleton(type):
    """Metaclass that gives every class using it exactly one instance.

    The first call ``Cls(*args, **kwargs)`` constructs the instance and stores
    it on the class as ``_instance``; each later call returns that stored
    object, and the arguments of those later calls are ignored.
    """

    def __call__(cls, *args, **kwargs):
        # Inspect only the class's own namespace: ``hasattr`` would also find
        # an ``_instance`` inherited from a base class and hand a subclass its
        # parent's object instead of an instance of the subclass.
        if '_instance' not in cls.__dict__:
            cls._instance = super().__call__(*args, **kwargs)
        return cls._instance
51+
52+
53+
# Default model location: a "bge-reranker-base" directory next to this file.
RERANK_MODEL_PATH = os.path.join(os.path.dirname(__file__), "bge-reranker-base")
54+
55+
class Reranker(metaclass=Singleton):
    """Process-wide wrapper around a ``FlagReranker`` model.

    The ``Singleton`` metaclass means the model is constructed once; later
    ``Reranker(...)`` calls return the same object.
    """

    def __init__(self, model_path):
        """Load the reranker found at ``model_path`` with fp16 disabled."""
        self.reranker = FlagReranker(model_path, use_fp16=False)

    def compute_score(self, pairs: List[List[str]]):
        """Score each ``[query, text]`` pair.

        Returns:
            ``None`` when ``pairs`` is empty; otherwise the model's scores,
            with a bare ``float`` result wrapped in a one-element list.
        """
        if len(pairs) == 0:
            return None
        scores = self.reranker.compute_score(pairs)
        return [scores] if isinstance(scores, float) else scores
68+
69+
70+
class Chat(object):
    """Scores candidate documents against a query and orders them by score."""

    def __init__(self, rerank_model_path: str = RERANK_MODEL_PATH):
        """Attach the reranker for ``rerank_model_path``."""
        self.reranker = Reranker(rerank_model_path)

    def fit_query_answer_rerank(self, query_docs: QADocs) -> List:
        """Rerank ``query_docs.inputs`` against ``query_docs.query``.

        Args:
            query_docs: Request holding the query and the candidate documents.

        Returns:
            A list of ``{"id": ..., "score": ...}`` dicts sorted by descending
            score, where the score is the sigmoid of the raw model score.
            Returns ``[]`` when ``query_docs`` is ``None`` or has no inputs.
        """
        # ``inputs`` is declared Optional, so it can be None as well as empty;
        # the truthiness test covers both (``len(None)`` raised TypeError).
        if query_docs is None or not query_docs.inputs:
            return []
        pairs = [[query_docs.query, doc.text] for doc in query_docs.inputs]
        scores = self.reranker.compute_score(pairs)
        # Squash each raw score into (0, 1) with the logistic function.
        ranked = [
            {"id": doc.id, "score": 1 / (1 + np.exp(-score))}
            for doc, score in zip(query_docs.inputs, scores)
        ]
        ranked.sort(key=lambda item: item["score"], reverse=True)
        return ranked
86+
87+
# Application object on which the rerank route is registered.
app = FastAPI()
# Security dependency that supplies the request's bearer credentials.
security = HTTPBearer()
# Expected bearer token. Read from the ACCESS_TOKEN environment variable at
# import time so the setting also applies when this module is served by an
# external ASGI runner instead of being executed as a script; falls back to
# the literal string "ACCESS_TOKEN" when the variable is unset.
env_bearer_token = os.getenv("ACCESS_TOKEN", "ACCESS_TOKEN")
90+
91+
@app.post('/api/v1/rerank')
async def handle_post_request(docs: QADocs, credentials: HTTPAuthorizationCredentials = Security(security)):
    """Authenticate the caller, rerank the submitted documents, return them.

    Args:
        docs: Query and candidate documents parsed from the request body.
        credentials: Bearer credentials supplied by the security dependency.

    Returns:
        The standard response envelope with code 200 and the reranked list.

    Raises:
        HTTPException: 401 when the bearer token does not match the
            configured token.
    """
    import secrets  # local import: keeps this block self-contained

    token = credentials.credentials
    # Constant-time comparison avoids leaking, through response timing, how
    # much of the token matched. Encoding to bytes lets compare_digest accept
    # non-ASCII tokens, which it rejects when given as str.
    if env_bearer_token is not None and not secrets.compare_digest(
        token.encode("utf-8"), env_bearer_token.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid token")
    chat = Chat()
    qa_docs_with_rerank = chat.fit_query_answer_rerank(docs)
    return response(200, msg="重排成功", data=qa_docs_with_rerank)
99+
100+
if __name__ == "__main__":
    # An ACCESS_TOKEN environment variable, when set, replaces the default
    # token; otherwise the existing value is kept.
    env_bearer_token = os.getenv("ACCESS_TOKEN", env_bearer_token)
    try:
        # Serve on all interfaces, port 6006.
        uvicorn.run(app, host='0.0.0.0', port=6006)
    except Exception as e:
        print(f"API启动失败!\n报错:\n{e}")
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
fastapi==0.104.1
2+
FlagEmbedding==1.1.5
3+
pydantic==1.10.13
4+
uvicorn==0.17.6
5+
itsdangerous
6+
protobuf

0 commit comments

Comments
 (0)