Skip to content

Commit 2d73312

Browse files
committed
update:新增限制时间内上传次数,移除IP黑名单列表过期IP,防止炸内存
1 parent 11fb809 commit 2d73312

File tree

4 files changed

+79
-41
lines changed

4 files changed

+79
-41
lines changed

depends.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,36 @@ async def admin_required(pwd: Union[str, None] = Header(default=None)):
1313

1414
class IPRateLimit:
1515
ips = {}
16+
count = 0
17+
minutes = 0
18+
19+
def __init__(self, count, minutes):
20+
self.count = count
21+
self.minutes = minutes
1622

1723
def check_ip(self, ip):
1824
# 检查ip是否被禁止
1925
if ip in self.ips:
20-
if self.ips[ip]['count'] >= settings.ERROR_COUNT:
21-
if self.ips[ip]['time'] + timedelta(minutes=settings.ERROR_MINUTE) > datetime.now():
26+
if self.ips[ip]['count'] >= self.count:
27+
if self.ips[ip]['time'] + timedelta(minutes=self.minutes) > datetime.now():
2228
return False
2329
else:
2430
self.ips.pop(ip)
2531
return True
26-
32+
2733
def add_ip(cls, ip):
2834
ip_info = cls.ips.get(ip, {'count': 0, 'time': datetime.now()})
2935
ip_info['count'] += 1
3036
cls.ips[ip] = ip_info
3137
return ip_info['count']
32-
38+
39+
async def remove_expired_ip(self):
40+
for ip in list(self.ips.keys()):
41+
if self.ips[ip]['time'] + timedelta(minutes=self.minutes) < datetime.now():
42+
self.ips.pop(ip)
43+
3344
def __call__(self, request: Request):
3445
ip = request.client.host
3546
if not self.check_ip(ip):
36-
raise HTTPException(status_code=400, detail="错误次数过多,请稍后再试")
47+
raise HTTPException(status_code=400, detail=f"请求次数过多,请稍后再试")
3748
return ip

main.py

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,58 +11,66 @@
1111
from sqlalchemy.ext.asyncio.session import AsyncSession
1212

1313
import settings
14-
from utils import delete_expire_files, storage, get_code
14+
from utils import delete_expire_files, storage, get_code, error_ip_limit, upload_ip_limit
1515
from database import get_session, Codes, init_models
16-
from depends import admin_required, IPRateLimit
16+
from depends import admin_required
1717

18+
# 实例化FastAPI
1819
app = FastAPI(debug=settings.DEBUG)
1920

21+
# 数据存储文件夹
2022
DATA_ROOT = Path(settings.DATA_ROOT)
2123
if not DATA_ROOT.exists():
2224
DATA_ROOT.mkdir(parents=True)
23-
25+
# 静态文件夹
2426
app.mount(settings.STATIC_URL, StaticFiles(directory=DATA_ROOT), name="static")
2527

2628

2729
@app.on_event('startup')
2830
async def startup():
31+
# 初始化数据库
2932
await init_models()
33+
# 启动后台任务,不定时删除过期文件
3034
asyncio.create_task(delete_expire_files())
3135

3236

37+
# 首页页面
3338
index_html = open('templates/index.html', 'r', encoding='utf-8').read() \
3439
.replace('{{title}}', settings.TITLE) \
3540
.replace('{{description}}', settings.DESCRIPTION) \
3641
.replace('{{keywords}}', settings.KEYWORDS)
42+
# 管理页面
3743
admin_html = open('templates/admin.html', 'r', encoding='utf-8').read() \
3844
.replace('{{title}}', settings.TITLE) \
3945
.replace('{{description}}', settings.DESCRIPTION) \
4046
.replace('{{keywords}}', settings.KEYWORDS)
4147

42-
ip_limit = IPRateLimit()
43-
4448

45-
@app.get(f'/{settings.ADMIN_ADDRESS}')
49+
@app.get(f'/{settings.ADMIN_ADDRESS}', description='管理页面', response_class=HTMLResponse)
4650
async def admin():
4751
return HTMLResponse(admin_html)
4852

4953

50-
@app.post(f'/{settings.ADMIN_ADDRESS}', dependencies=[Depends(admin_required)])
54+
@app.post(f'/{settings.ADMIN_ADDRESS}', dependencies=[Depends(admin_required)], description='查询数据库列表')
5155
async def admin_post(s: AsyncSession = Depends(get_session)):
52-
query = select(Codes)
53-
codes = (await s.execute(query)).scalars().all()
56+
# 查询数据库列表
57+
codes = (await s.execute(select(Codes))).scalars().all()
5458
return {'detail': '查询成功', 'data': codes}
5559

5660

57-
@app.delete(f'/{settings.ADMIN_ADDRESS}', dependencies=[Depends(admin_required)])
61+
@app.delete(f'/{settings.ADMIN_ADDRESS}', dependencies=[Depends(admin_required)], description='删除数据库记录')
5862
async def admin_delete(code: str, s: AsyncSession = Depends(get_session)):
63+
# 找到相应记录
5964
query = select(Codes).where(Codes.code == code)
65+
# 找到第一条记录
6066
file = (await s.execute(query)).scalars().first()
61-
if file:
62-
if file.type != 'text':
63-
await storage.delete_file(file.text)
64-
await s.delete(file)
65-
await s.commit()
67+
# 如果记录存在,并且不是文本
68+
if file and file.type != 'text':
69+
# 删除文件
70+
await storage.delete_file(file.text)
71+
# 删除数据库记录
72+
await s.delete(file)
73+
await s.commit()
6674
return {'detail': '删除成功'}
6775

6876

@@ -72,25 +80,30 @@ async def index():
7280

7381

7482
@app.get('/select')
75-
async def get_file(code: str, s: AsyncSession = Depends(get_session)):
83+
async def get_file(code: str, ip: str = Depends(error_ip_limit), s: AsyncSession = Depends(get_session)):
84+
# 查出数据库记录
7685
query = select(Codes).where(Codes.code == code)
7786
info = (await s.execute(query)).scalars().first()
87+
# 如果记录不存在,IP错误次数+1
7888
if not info:
79-
raise HTTPException(status_code=404, detail="口令不存在")
89+
error_ip_limit.add_ip(ip)
90+
raise HTTPException(status_code=404, detail="口令不存在,次数过多将被禁止访问")
91+
# 如果是文本,直接返回
8092
if info.type == 'text':
8193
return {'detail': '查询成功', 'data': info.text}
94+
# 如果是文件,返回文件
8295
else:
8396
filepath = await storage.get_filepath(info.text)
8497
return FileResponse(filepath, filename=info.name)
8598

8699

87100
@app.post('/')
88-
async def index(code: str, ip: str = Depends(ip_limit), s: AsyncSession = Depends(get_session)):
101+
async def index(code: str, ip: str = Depends(error_ip_limit), s: AsyncSession = Depends(get_session)):
89102
query = select(Codes).where(Codes.code == code)
90103
info = (await s.execute(query)).scalars().first()
91104
if not info:
92-
error_count = settings.ERROR_COUNT - ip_limit.add_ip(ip)
93-
raise HTTPException(status_code=404, detail=f"取件码错误,错误{error_count}次将被禁止10分钟")
105+
error_count = settings.ERROR_COUNT - error_ip_limit.add_ip(ip)
106+
raise HTTPException(status_code=404, detail=f"取件码错误,{error_count}次后将被禁止{settings.ERROR_MINUTE}分钟")
94107
if info.exp_time < datetime.datetime.now() or info.count == 0:
95108
if info.type != "text":
96109
await storage.delete_file(info.text)
@@ -109,7 +122,7 @@ async def index(code: str, ip: str = Depends(ip_limit), s: AsyncSession = Depend
109122

110123
@app.post('/share')
111124
async def share(background_tasks: BackgroundTasks, text: str = Form(default=None), style: str = Form(default='2'),
112-
value: int = Form(default=1), file: UploadFile = File(default=None),
125+
value: int = Form(default=1), file: UploadFile = File(default=None), ip: str = Depends(upload_ip_limit),
113126
s: AsyncSession = Depends(get_session)):
114127
code = await get_code(s)
115128
if style == '2':
@@ -137,6 +150,7 @@ async def share(background_tasks: BackgroundTasks, text: str = Form(default=None
137150
info = Codes(code=code, text=_text, size=size, type=_type, name=name, count=exp_count, exp_time=exp_time, key=key)
138151
s.add(info)
139152
await s.commit()
153+
upload_ip_limit.add_ip(ip)
140154
return {
141155
'detail': '分享成功,请点击取件码按钮查看上传列表',
142156
'data': {'code': code, 'key': key, 'name': name, 'text': _text}

settings.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,36 @@
11
from starlette.config import Config
22

3+
# 配置文件.env
34
config = Config(".env")
4-
5+
# 是否开启DEBUG模式
56
DEBUG = config('DEBUG', cast=bool, default=False)
6-
7+
# 端口
78
PORT = config('PORT', cast=int, default=12345)
8-
9+
# Sqlite数据库文件
910
DATABASE_URL = config('DATABASE_URL', cast=str, default="sqlite+aiosqlite:///database.db")
10-
11+
# 静态文件夹
1112
DATA_ROOT = config('DATA_ROOT', cast=str, default="./static")
12-
13+
# 静态文件夹URL
1314
STATIC_URL = config('STATIC_URL', cast=str, default="/static")
14-
15+
# 错误次数
1516
ERROR_COUNT = config('ERROR_COUNT', cast=int, default=5)
16-
17+
# 错误限制分钟数
1718
ERROR_MINUTE = config('ERROR_MINUTE', cast=int, default=10)
18-
19+
# 上传次数
20+
UPLOAD_COUNT = config('UPLOAD_COUNT', cast=int, default=60)
21+
# 上传限制分钟数
22+
UPLOAD_MINUTE = config('UPLOAD_MINUTE', cast=int, default=1)
23+
# 管理地址
1924
ADMIN_ADDRESS = config('ADMIN_ADDRESS', cast=str, default="admin")
20-
25+
# 管理密码
2126
ADMIN_PASSWORD = config('ADMIN_PASSWORD', cast=str, default="admin")
22-
27+
# 文件大小限制,默认10MB
2328
FILE_SIZE_LIMIT = config('FILE_SIZE_LIMIT', cast=int, default=10) * 1024 * 1024
24-
29+
# 网站标题
2530
TITLE = config('TITLE', cast=str, default="文件快递柜")
26-
31+
# 网站描述
2732
DESCRIPTION = config('DESCRIPTION', cast=str, default="FileCodeBox,文件快递柜,口令传送箱,匿名口令分享文本,文件等文件")
28-
33+
# 网站关键词
2934
KEYWORDS = config('KEYWORDS', cast=str, default="FileCodeBox,文件快递柜,口令传送箱,匿名口令分享文本,文件等文件")
30-
35+
# 存储引擎
3136
STORAGE_ENGINE = config('STORAGE_ENGINE', cast=str, default="filesystem")

utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,22 @@
55
from sqlalchemy.ext.asyncio.session import AsyncSession
66
import settings
77
from database import Codes, engine
8+
from depends import IPRateLimit
89
from storage import STORAGE_ENGINE
910

1011
storage = STORAGE_ENGINE[settings.STORAGE_ENGINE]()
1112

13+
# 错误IP限制器
14+
error_ip_limit = IPRateLimit(settings.ERROR_COUNT, settings.ERROR_MINUTE)
15+
# 上传文件限制器
16+
upload_ip_limit = IPRateLimit(settings.UPLOAD_COUNT, settings.UPLOAD_MINUTE)
17+
1218

1319
async def delete_expire_files():
1420
while True:
1521
async with AsyncSession(engine, expire_on_commit=False) as s:
22+
await error_ip_limit.remove_expired_ip()
23+
await upload_ip_limit.remove_expired_ip()
1624
query = select(Codes).where(or_(Codes.exp_time < datetime.datetime.now(), Codes.count == 0))
1725
exps = (await s.execute(query)).scalars().all()
1826
files = []
@@ -25,7 +33,7 @@ async def delete_expire_files():
2533
query = delete(Codes).where(Codes.id.in_(exps_ids))
2634
await s.execute(query)
2735
await s.commit()
28-
await asyncio.sleep(random.randint(60, 300))
36+
await asyncio.sleep(random.randint(2, 2))
2937

3038

3139
async def get_code(s: AsyncSession):

0 commit comments

Comments
 (0)