|
3 | 3 | import uuid
|
4 | 4 | import threading
|
5 | 5 | import random
|
6 |
| - |
7 | 6 | from fastapi import FastAPI, Depends, UploadFile, Form, File
|
8 | 7 | from starlette.requests import Request
|
9 | 8 | from starlette.responses import HTMLResponse, FileResponse
|
10 |
| -import random |
11 | 9 | from starlette.staticfiles import StaticFiles
|
12 | 10 |
|
13 |
| -from sqlalchemy import or_, select, update, delete, create_engine |
14 |
| -from sqlalchemy import select, update, delete |
| 11 | +from sqlalchemy import or_, select, update, delete |
15 | 12 | from sqlalchemy.ext.asyncio.session import AsyncSession
|
16 | 13 |
|
17 |
| -from database import engine, get_session, Base, Codes |
18 |
| - |
19 |
| - |
20 |
| -engine = create_engine('sqlite:///database.db', connect_args={"check_same_thread": False}) |
21 |
| -Base.metadata.create_all(bind=engine) |
| 14 | +from database import get_session, Codes |
22 | 15 |
|
23 | 16 | app = FastAPI()
|
24 | 17 | if not os.path.exists('./static'):
|
@@ -137,13 +130,14 @@ def ip_error(ip):
|
137 | 130 |
|
138 | 131 |
|
139 | 132 | @app.get('/select')
|
140 |
| -async def get_file(code: str, db: Session = Depends(get_db)): |
141 |
| - file = db.query(database.Codes).filter(database.Codes.code == code).first() |
142 |
| - if file: |
143 |
| - if file.type == 'text': |
144 |
| - return {'code': code, 'msg': '查询成功', 'data': file.text} |
| 133 | +async def get_file(code: str, s: AsyncSession = Depends(get_session)): |
| 134 | + query = select(Codes).where(Codes.code == code) |
| 135 | + info = (await s.execute(query)).scalars().first() |
| 136 | + if info: |
| 137 | + if info.type == 'text': |
| 138 | + return {'code': code, 'msg': '查询成功', 'data': info.text} |
145 | 139 | else:
|
146 |
| - return FileResponse('.' + file.text, filename=file.name) |
| 140 | + return FileResponse('.' + info.text, filename=info.name) |
147 | 141 | else:
|
148 | 142 | return {'code': 404, 'msg': '口令不存在'}
|
149 | 143 |
|
@@ -182,7 +176,7 @@ async def share(text: str = Form(default=None), style: str = Form(default='2'),
|
182 | 176 | query = select(Codes).where(or_(Codes.exp_time < datetime.datetime.now(), Codes.count == 0))
|
183 | 177 | exps = (await s.execute(query)).scalars().all()
|
184 | 178 | threading.Thread(target=delete_file, args=([[{'type': old.type, 'text': old.text}] for old in exps],)).start()
|
185 |
| - |
| 179 | + |
186 | 180 | exps_ids = [exp.id for exp in exps]
|
187 | 181 | query = delete(Codes).where(Codes.id.in_(exps_ids))
|
188 | 182 | await s.execute(query)
|
|
0 commit comments