11
11
from sqlalchemy .ext .asyncio .session import AsyncSession
12
12
13
13
import settings
14
- from utils import delete_expire_files , storage , get_code
14
+ from utils import delete_expire_files , storage , get_code , error_ip_limit , upload_ip_limit
15
15
from database import get_session , Codes , init_models
16
- from depends import admin_required , IPRateLimit
16
+ from depends import admin_required
17
17
18
+ # 实例化FastAPI
18
19
app = FastAPI (debug = settings .DEBUG )
19
20
21
+ # 数据存储文件夹
20
22
DATA_ROOT = Path (settings .DATA_ROOT )
21
23
if not DATA_ROOT .exists ():
22
24
DATA_ROOT .mkdir (parents = True )
23
-
25
+ # 静态文件夹
24
26
app .mount (settings .STATIC_URL , StaticFiles (directory = DATA_ROOT ), name = "static" )
25
27
26
28
27
29
@app .on_event ('startup' )
28
30
async def startup ():
31
+ # 初始化数据库
29
32
await init_models ()
33
+ # 启动后台任务,不定时删除过期文件
30
34
asyncio .create_task (delete_expire_files ())
31
35
32
36
37
+ # 首页页面
33
38
index_html = open ('templates/index.html' , 'r' , encoding = 'utf-8' ).read () \
34
39
.replace ('{{title}}' , settings .TITLE ) \
35
40
.replace ('{{description}}' , settings .DESCRIPTION ) \
36
41
.replace ('{{keywords}}' , settings .KEYWORDS )
42
+ # 管理页面
37
43
admin_html = open ('templates/admin.html' , 'r' , encoding = 'utf-8' ).read () \
38
44
.replace ('{{title}}' , settings .TITLE ) \
39
45
.replace ('{{description}}' , settings .DESCRIPTION ) \
40
46
.replace ('{{keywords}}' , settings .KEYWORDS )
41
47
42
- ip_limit = IPRateLimit ()
43
-
44
48
45
- @app .get (f'/{ settings .ADMIN_ADDRESS } ' )
49
+ @app .get (f'/{ settings .ADMIN_ADDRESS } ' , description = '管理页面' , response_class = HTMLResponse )
46
50
async def admin ():
47
51
return HTMLResponse (admin_html )
48
52
49
53
50
- @app .post (f'/{ settings .ADMIN_ADDRESS } ' , dependencies = [Depends (admin_required )])
54
+ @app .post (f'/{ settings .ADMIN_ADDRESS } ' , dependencies = [Depends (admin_required )], description = '查询数据库列表' )
51
55
async def admin_post (s : AsyncSession = Depends (get_session )):
52
- query = select ( Codes )
53
- codes = (await s .execute (query )).scalars ().all ()
56
+ # 查询数据库列表
57
+ codes = (await s .execute (select ( Codes ) )).scalars ().all ()
54
58
return {'detail' : '查询成功' , 'data' : codes }
55
59
56
60
57
- @app .delete (f'/{ settings .ADMIN_ADDRESS } ' , dependencies = [Depends (admin_required )])
61
+ @app .delete (f'/{ settings .ADMIN_ADDRESS } ' , dependencies = [Depends (admin_required )], description = '删除数据库记录' )
58
62
async def admin_delete (code : str , s : AsyncSession = Depends (get_session )):
63
+ # 找到相应记录
59
64
query = select (Codes ).where (Codes .code == code )
65
+ # 找到第一条记录
60
66
file = (await s .execute (query )).scalars ().first ()
61
- if file :
62
- if file .type != 'text' :
63
- await storage .delete_file (file .text )
64
- await s .delete (file )
65
- await s .commit ()
67
+ # 如果记录存在,并且不是文本
68
+ if file and file .type != 'text' :
69
+ # 删除文件
70
+ await storage .delete_file (file .text )
71
+ # 删除数据库记录
72
+ await s .delete (file )
73
+ await s .commit ()
66
74
return {'detail' : '删除成功' }
67
75
68
76
@@ -72,25 +80,30 @@ async def index():
72
80
73
81
74
82
@app .get ('/select' )
75
- async def get_file (code : str , s : AsyncSession = Depends (get_session )):
83
+ async def get_file (code : str , ip : str = Depends (error_ip_limit ), s : AsyncSession = Depends (get_session )):
84
+ # 查出数据库记录
76
85
query = select (Codes ).where (Codes .code == code )
77
86
info = (await s .execute (query )).scalars ().first ()
87
+ # 如果记录不存在,IP错误次数+1
78
88
if not info :
79
- raise HTTPException (status_code = 404 , detail = "口令不存在" )
89
+ error_ip_limit .add_ip (ip )
90
+ raise HTTPException (status_code = 404 , detail = "口令不存在,次数过多将被禁止访问" )
91
+ # 如果是文本,直接返回
80
92
if info .type == 'text' :
81
93
return {'detail' : '查询成功' , 'data' : info .text }
94
+ # 如果是文件,返回文件
82
95
else :
83
96
filepath = await storage .get_filepath (info .text )
84
97
return FileResponse (filepath , filename = info .name )
85
98
86
99
87
100
@app .post ('/' )
88
- async def index (code : str , ip : str = Depends (ip_limit ), s : AsyncSession = Depends (get_session )):
101
+ async def index (code : str , ip : str = Depends (error_ip_limit ), s : AsyncSession = Depends (get_session )):
89
102
query = select (Codes ).where (Codes .code == code )
90
103
info = (await s .execute (query )).scalars ().first ()
91
104
if not info :
92
- error_count = settings .ERROR_COUNT - ip_limit .add_ip (ip )
93
- raise HTTPException (status_code = 404 , detail = f"取件码错误,错误 { error_count } 次将被禁止10分钟 " )
105
+ error_count = settings .ERROR_COUNT - error_ip_limit .add_ip (ip )
106
+ raise HTTPException (status_code = 404 , detail = f"取件码错误,{ error_count } 次后将被禁止 { settings . ERROR_MINUTE } 分钟 " )
94
107
if info .exp_time < datetime .datetime .now () or info .count == 0 :
95
108
if info .type != "text" :
96
109
await storage .delete_file (info .text )
@@ -109,7 +122,7 @@ async def index(code: str, ip: str = Depends(ip_limit), s: AsyncSession = Depend
109
122
110
123
@app .post ('/share' )
111
124
async def share (background_tasks : BackgroundTasks , text : str = Form (default = None ), style : str = Form (default = '2' ),
112
- value : int = Form (default = 1 ), file : UploadFile = File (default = None ),
125
+ value : int = Form (default = 1 ), file : UploadFile = File (default = None ), ip : str = Depends ( upload_ip_limit ),
113
126
s : AsyncSession = Depends (get_session )):
114
127
code = await get_code (s )
115
128
if style == '2' :
@@ -137,6 +150,7 @@ async def share(background_tasks: BackgroundTasks, text: str = Form(default=None
137
150
info = Codes (code = code , text = _text , size = size , type = _type , name = name , count = exp_count , exp_time = exp_time , key = key )
138
151
s .add (info )
139
152
await s .commit ()
153
+ upload_ip_limit .add_ip (ip )
140
154
return {
141
155
'detail' : '分享成功,请点击取件码按钮查看上传列表' ,
142
156
'data' : {'code' : code , 'key' : key , 'name' : name , 'text' : _text }
0 commit comments