|
| 1 | +#!/usr/bin/env python |
| 2 | +# -*- coding: UTF-8 -*- |
| 3 | +""" |
| 4 | +@Author sineom |
| 5 | +@Date 2024/7/23-09:20 |
| 6 | + |
| 7 | +@description sqlite操作 |
| 8 | +@Copyright (c) 2022 by sineom, All Rights Reserved. |
| 9 | +""" |
| 10 | +import os |
| 11 | +import sqlite3 |
| 12 | + |
| 13 | +from common.log import logger |
| 14 | + |
| 15 | + |
| 16 | +class Db: |
| 17 | + def __init__(self): |
| 18 | + curdir = os.path.dirname(__file__) |
| 19 | + db_path = os.path.join(curdir, "chat.db") |
| 20 | + self.conn = sqlite3.connect(db_path, check_same_thread=False) |
| 21 | + c = self.conn.cursor() |
| 22 | + c.execute('''CREATE TABLE IF NOT EXISTS chat_records |
| 23 | + (sessionid TEXT, msgid INTEGER, user TEXT, content TEXT, type TEXT, timestamp INTEGER, is_triggered INTEGER, |
| 24 | + PRIMARY KEY (sessionid, msgid))''') |
| 25 | + |
| 26 | + # 创建一个总结时间表,记录合适开始了总结的时间 |
| 27 | + c.execute('''CREATE TABLE IF NOT EXISTS summary_time |
| 28 | + (sessionid TEXT, summary_time INTEGER, PRIMARY KEY (sessionid))''') |
| 29 | + |
| 30 | + # 创建一个关闭保存聊天记录的表 |
| 31 | + c.execute('''CREATE TABLE IF NOT EXISTS summary_stop |
| 32 | + (sessionid TEXT, PRIMARY KEY (sessionid))''') |
| 33 | + |
| 34 | + # 后期增加了is_triggered字段,这里做个过渡,这段代码某天会删除 |
| 35 | + c = c.execute("PRAGMA table_info(chat_records);") |
| 36 | + column_exists = False |
| 37 | + for column in c.fetchall(): |
| 38 | + logger.debug("[Summary] column: {}".format(column)) |
| 39 | + if column[1] == 'is_triggered': |
| 40 | + column_exists = True |
| 41 | + break |
| 42 | + if not column_exists: |
| 43 | + self.conn.execute("ALTER TABLE chat_records ADD COLUMN is_triggered INTEGER DEFAULT 0;") |
| 44 | + self.conn.execute("UPDATE chat_records SET is_triggered = 0;") |
| 45 | + |
| 46 | + self.conn.commit() |
| 47 | + # 禁用的群聊 |
| 48 | + self.disable_group = self._get_summary_stop() |
| 49 | + |
| 50 | + def insert_record(self, session_id, msg_id, user, content, msg_type, timestamp, is_triggered=0): |
| 51 | + c = self.conn.cursor() |
| 52 | + logger.debug("[Summary] insert record: {} {} {} {} {} {} {}".format(session_id, msg_id, user, content, msg_type, |
| 53 | + timestamp, is_triggered)) |
| 54 | + c.execute("INSERT OR REPLACE INTO chat_records VALUES (?,?,?,?,?,?,?)", |
| 55 | + (session_id, msg_id, user, content, msg_type, timestamp, is_triggered)) |
| 56 | + self.conn.commit() |
| 57 | + |
| 58 | + # 根据时间删除记录 |
| 59 | + def delete_records(self, start_timestamp): |
| 60 | + try: |
| 61 | + c = self.conn.cursor() |
| 62 | + c.execute(''' |
| 63 | + DELETE FROM chat_records |
| 64 | + WHERE timestamp < ? |
| 65 | + ''', start_timestamp,) |
| 66 | + self.conn.commit() |
| 67 | + logger.info("Records older have been cleaned.") |
| 68 | + except Exception as e: |
| 69 | + logger.error(e) |
| 70 | + |
| 71 | + # 保存总结时间,如果表中不存在则插入,如果存在则更新 |
| 72 | + def save_summary_time(self, session_id, summary_time): |
| 73 | + if self.get_summary_time(session_id) is None: |
| 74 | + self._insert_summary_time(session_id, summary_time) |
| 75 | + else: |
| 76 | + self._update_summary_time(session_id, summary_time) |
| 77 | + |
| 78 | + # 插入总结时间 |
| 79 | + def _insert_summary_time(self, session_id, summary_time): |
| 80 | + c = self.conn.cursor() |
| 81 | + logger.debug("[Summary] insert summary time: {} {}".format(session_id, summary_time)) |
| 82 | + c.execute("INSERT OR REPLACE INTO summary_time VALUES (?,?)", |
| 83 | + (session_id, summary_time)) |
| 84 | + self.conn.commit() |
| 85 | + |
| 86 | + # 更新总结时间 |
| 87 | + def _update_summary_time(self, session_id, summary_time): |
| 88 | + c = self.conn.cursor() |
| 89 | + logger.debug("[Summary] update summary time: {} {}".format(session_id, summary_time)) |
| 90 | + c.execute("UPDATE summary_time SET summary_time = ? WHERE sessionid = ?", |
| 91 | + (summary_time, session_id)) |
| 92 | + self.conn.commit() |
| 93 | + |
| 94 | + # 获取总结时间,如果不存在返回None |
| 95 | + def get_summary_time(self, session_id): |
| 96 | + c = self.conn.cursor() |
| 97 | + c.execute("SELECT summary_time FROM summary_time WHERE sessionid=?", (session_id,)) |
| 98 | + row = c.fetchone() |
| 99 | + if row is None: |
| 100 | + return None |
| 101 | + return row[0] |
| 102 | + |
| 103 | + def get_records(self, session_id, start_timestamp=0, limit=9999) -> list: |
| 104 | + c = self.conn.cursor() |
| 105 | + c.execute("SELECT * FROM chat_records WHERE sessionid=? and timestamp>? ORDER BY timestamp DESC LIMIT ?", |
| 106 | + (session_id, start_timestamp, limit)) |
| 107 | + return c.fetchall() |
| 108 | + |
| 109 | + # 删除禁用的群聊 |
| 110 | + def delete_summary_stop(self, session_id): |
| 111 | + try: |
| 112 | + c = self.conn.cursor() |
| 113 | + c.execute("DELETE FROM summary_stop WHERE sessionid=?", (session_id,)) |
| 114 | + self.conn.commit() |
| 115 | + if session_id in self.disable_group: |
| 116 | + self.disable_group.remove(session_id) |
| 117 | + except Exception as e: |
| 118 | + logger.error(e) |
| 119 | + |
| 120 | + # 保存禁用的群聊 |
| 121 | + def save_summary_stop(self, session_id): |
| 122 | + try: |
| 123 | + c = self.conn.cursor() |
| 124 | + c.execute("INSERT OR REPLACE INTO summary_stop VALUES (?)", |
| 125 | + (session_id,)) |
| 126 | + self.conn.commit() |
| 127 | + self.disable_group.add(session_id) |
| 128 | + except Exception as e: |
| 129 | + logger.error(e) |
| 130 | + |
| 131 | + # 获取所有禁用的群聊 |
| 132 | + def _get_summary_stop(self): |
| 133 | + c = self.conn.cursor() |
| 134 | + c.execute("SELECT sessionid FROM summary_stop") |
| 135 | + return set(c.fetchall()) |
0 commit comments