1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @Time: 2023/11/7 22:45
5
+ @Author: zhidong
6
+ @File: reranker.py
7
+ @Desc:
8
+ """
9
+ import os
10
+ import numpy as np
11
+ import logging
12
+ import uvicorn
13
+ import datetime
14
+ from fastapi import FastAPI , Security , HTTPException
15
+ from fastapi .security import HTTPBearer , HTTPAuthorizationCredentials
16
+ from FlagEmbedding import FlagReranker
17
+ from pydantic import Field , BaseModel , validator
18
+ from typing import Optional , List
19
+
20
def response(code, msg, data=None):
    """Build the JSON envelope returned by the API.

    Args:
        code: Value stored under the ``"code"`` key.
        msg: Value stored under the ``"message"`` key.
        data: Value stored under the ``"data"`` key; ``None`` is replaced
            by a new empty list.

    Returns:
        A dict with the keys ``code``, ``message``, ``data`` and ``time``,
        where ``time`` is ``str(datetime.datetime.now())``.
    """
    timestamp = str(datetime.datetime.now())
    payload = [] if data is None else data
    return {
        "code": code,
        "message": msg,
        "data": payload,
        "time": timestamp,
    }
31
+
32
def success(data=None, msg=''):
    """Return a success envelope built by ``response``.

    The previous body was a bare ``return``, so the function ignored both
    arguments and always produced ``None``.

    Args:
        data: Payload forwarded to ``response`` (``None`` becomes ``[]`` there).
        msg: Message forwarded to ``response``.

    Returns:
        The dict produced by ``response`` with ``code`` set to 200, the same
        code the rerank endpoint in this file uses for a successful reply.
    """
    return response(200, msg, data)
34
+
35
+
36
# One candidate passage of a rerank request.
class Inputs(BaseModel):
    id: str  # identifier copied into the matching entry of the ranked output
    text: Optional[str]  # passage text that is paired with the query for scoring
39
+
40
+
41
# Request body of the rerank endpoint: one query plus the passages to score.
class QADocs(BaseModel):
    query: Optional[str]  # query text placed first in every [query, text] pair
    inputs: Optional[List[Inputs]]  # candidate passages; may be None, not only empty
44
+
45
+
46
class Singleton(type):
    """Metaclass that gives every class using it exactly one instance.

    The first call to the class builds the instance and stores it on the
    class as ``_instance``; every later call returns that stored object and
    ignores the arguments it was given.
    """

    def __call__(cls, *args, **kwargs):
        # Look only at the class's own namespace. ``hasattr`` would also find
        # an ``_instance`` inherited from a parent class, so a subclass would
        # wrongly hand back the parent's instance instead of its own.
        if '_instance' not in cls.__dict__:
            cls._instance = super().__call__(*args, **kwargs)
        return cls._instance
51
+
52
+
53
# Default model directory: "bge-reranker-base" next to this source file.
_MODULE_DIR = os.path.dirname(__file__)
RERANK_MODEL_PATH = os.path.join(_MODULE_DIR, "bge-reranker-base")
54
+
55
class Reranker(metaclass=Singleton):
    """Wrapper around a ``FlagReranker`` model, created through ``Singleton``.

    Because of the ``Singleton`` metaclass, only the first construction
    loads a model; later constructions return that same object.
    """

    def __init__(self, model_path):
        # The model is loaded with fp16 explicitly disabled.
        self.reranker = FlagReranker(model_path, use_fp16=False)

    def compute_score(self, pairs: List[List[str]]):
        """Score each ``[query, text]`` pair with the wrapped model.

        Args:
            pairs: Pairs handed unchanged to ``FlagReranker.compute_score``.

        Returns:
            ``None`` when ``pairs`` is empty. Otherwise the model's result,
            with a bare ``float`` wrapped into a one-element list.
        """
        if len(pairs) == 0:
            return None
        scores = self.reranker.compute_score(pairs)
        # The model may hand back a bare float; wrap it so callers can iterate.
        return [scores] if isinstance(scores, float) else scores
68
+
69
+
70
class Chat(object):
    """Rank the passages of a request by their relevance to its query."""

    def __init__(self, rerank_model_path: str = RERANK_MODEL_PATH):
        """Obtain the ``Reranker`` for ``rerank_model_path``."""
        self.reranker = Reranker(rerank_model_path)

    def fit_query_answer_rerank(self, query_docs: QADocs) -> List:
        """Score every input against the query and sort by score.

        Args:
            query_docs: Request holding the query and the candidate inputs.

        Returns:
            A list of ``{"id": ..., "score": ...}`` dicts ordered from the
            highest score to the lowest, where each score is the sigmoid of
            the value returned by the reranker. An empty list is returned
            when the request is ``None`` or has no inputs.
        """
        # ``inputs`` is declared Optional, so it can be None as well as
        # empty; ``len(None)`` used to raise TypeError here.
        if query_docs is None or not query_docs.inputs:
            return []

        pairs = [[query_docs.query, doc.text] for doc in query_docs.inputs]
        scores = self.reranker.compute_score(pairs)

        ranked = [
            # Sigmoid maps the raw model score into the (0, 1) range.
            {"id": doc.id, "score": 1 / (1 + np.exp(-score))}
            for doc, score in zip(query_docs.inputs, scores)
        ]
        ranked.sort(key=lambda item: item["score"], reverse=True)
        return ranked
86
+
87
app = FastAPI()
security = HTTPBearer()
# Bearer token that requests must present. It is read from the ACCESS_TOKEN
# environment variable at import time, so the setting also applies when the
# app is served by an external runner (e.g. ``uvicorn module:app``) and the
# ``__main__`` block never executes.
# NOTE(review): the literal fallback 'ACCESS_TOKEN' is kept only for backward
# compatibility; it is a publicly known placeholder, so set the environment
# variable in any real deployment.
env_bearer_token = os.getenv("ACCESS_TOKEN", 'ACCESS_TOKEN')
90
+
91
@app.post('/api/v1/rerank')
async def handle_post_request(docs: QADocs, credentials: HTTPAuthorizationCredentials = Security(security)):
    """Authenticate the caller, rerank the submitted passages and reply.

    Args:
        docs: Request body with the query and the passages to rank.
        credentials: Bearer credentials extracted by the ``HTTPBearer`` scheme.

    Returns:
        The ``response`` envelope with code 200 and the ranked list as data.

    Raises:
        HTTPException: 401 when the presented token does not match
            ``env_bearer_token``.
    """
    # Local import keeps this block self-contained; the file's import header
    # is left untouched.
    import secrets

    token = credentials.credentials
    # Constant-time comparison avoids leaking how many leading characters of
    # the token matched. Both sides are encoded because ``compare_digest``
    # rejects str values containing non-ASCII characters.
    if env_bearer_token is not None and not secrets.compare_digest(
        token.encode("utf-8"), env_bearer_token.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid token")
    chat = Chat()
    qa_docs_with_rerank = chat.fit_query_answer_rerank(docs)
    return response(200, msg="重排成功", data=qa_docs_with_rerank)
99
+
100
if __name__ == "__main__":
    # An ACCESS_TOKEN environment variable, when present, replaces the
    # current token; otherwise the existing value is kept.
    env_bearer_token = os.getenv("ACCESS_TOKEN", env_bearer_token)
    try:
        # Serve on all interfaces, port 6006.
        uvicorn.run(app, host='0.0.0.0', port=6006)
    except Exception as e:
        # Startup failures are reported on stdout rather than re-raised.
        print(f"API启动失败!\n 报错:\n { e } ")
0 commit comments