TenantDrive/utils/detebase.py

469 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import sqlite3
import uuid
from datetime import datetime, timedelta, timezone
from typing import Dict, Any, Optional
class CloudDriveDatabase:
    """SQLite-backed store for drive providers, user drives and external links.

    Tables:
        drive_providers -- one row per provider; ``config_vars`` is stored as JSON text.
        user_drives     -- one row per user drive; ``login_config`` is stored as JSON text.
        external_links  -- shareable links with a total/used quota and an ``expiry_time``.

    Write methods commit immediately. When a write fails, the implicit
    transaction is rolled back and the failure is reported through the return
    value (``False`` / ``None``), never by leaving a half-open transaction.

    NOTE(review): the schema declares FOREIGN KEY constraints, but SQLite only
    enforces them after ``PRAGMA foreign_keys = ON`` is issued on the
    connection, which this class does not do. Deleting a provider or a drive
    therefore leaves dependent rows in place.
    """

    def __init__(self, db_path: str = "cloud_drive.db"):
        """Open (or create) the database at *db_path* and ensure the schema exists.

        Args:
            db_path: SQLite database path; ``":memory:"`` gives a throwaway database.
        """
        self.db_path = db_path
        self.conn = sqlite3.connect(db_path)
        # Rows behave like mappings, so dict(row) yields {column: value}.
        self.conn.row_factory = sqlite3.Row
        self.cursor = self.conn.cursor()
        try:
            self._create_tables()
        except Exception:
            # Do not leak the connection if schema setup fails.
            self.conn.close()
            raise

    def __enter__(self) -> "CloudDriveDatabase":
        """Allow ``with CloudDriveDatabase(...) as db:`` usage."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the connection when leaving a ``with`` block."""
        self.close()

    def _create_tables(self):
        """Create the required tables if missing and migrate older schemas."""
        # 1. Drive provider table
        self.cursor.execute('''
            CREATE TABLE IF NOT EXISTS drive_providers (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                provider_name TEXT UNIQUE NOT NULL,
                config_vars TEXT NOT NULL,
                remarks TEXT
            )
        ''')
        # 2. User drive table
        self.cursor.execute('''
            CREATE TABLE IF NOT EXISTS user_drives (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                provider_name TEXT NOT NULL,
                login_config TEXT NOT NULL,
                remarks TEXT,
                FOREIGN KEY (provider_name) REFERENCES drive_providers (provider_name)
            )
        ''')
        # 3. External link table
        self.cursor.execute('''
            CREATE TABLE IF NOT EXISTS external_links (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                drive_id INTEGER NOT NULL,
                total_quota REAL NOT NULL,
                used_quota REAL NOT NULL DEFAULT 0,
                link_uuid TEXT UNIQUE NOT NULL,
                remarks TEXT,
                FOREIGN KEY (drive_id) REFERENCES user_drives (id)
            )
        ''')
        # Databases created before expiry support lack this column.
        self._add_column_if_not_exists('external_links', 'expiry_time', 'TEXT')
        self.conn.commit()

    def _add_column_if_not_exists(self, table_name: str, column_name: str, column_type: str):
        """Add *column_name* to *table_name* unless the column already exists.

        The identifiers are interpolated into SQL directly (SQLite cannot bind
        identifiers as parameters), so they must be trusted constants.
        """
        self.cursor.execute(f"PRAGMA table_info({table_name})")
        # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk).
        columns = [column[1] for column in self.cursor.fetchall()]
        if column_name not in columns:
            try:
                self.cursor.execute(f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}")
                self.conn.commit()
                print(f"已成功添加列 '{column_name}' 到表 '{table_name}'")
            except sqlite3.OperationalError as e:
                print(f"添加列 '{column_name}' 到表 '{table_name}' 时出错: {e}")

    # ---- internal helpers ----------------------------------------------

    @staticmethod
    def _decode_json_column(row: sqlite3.Row, column: str) -> Dict[str, Any]:
        """Convert *row* to a dict and JSON-decode the value stored in *column*."""
        record = dict(row)
        record[column] = json.loads(record[column])
        return record

    def _rollback_quietly(self) -> None:
        """Roll back the implicit transaction left open by a failed write.

        Best effort by design: the caller is already reporting a failure, so an
        error raised by the rollback itself must not replace that result.
        """
        try:
            self.conn.rollback()
        except sqlite3.Error:
            pass

    @staticmethod
    def _parse_expiry_time(value: Any) -> Optional[datetime]:
        """Parse a stored ``expiry_time`` into a timezone-aware UTC datetime.

        Accepts ISO-8601-style text using either a space or ``T`` separator, with
        an optional ``+HH:MM`` offset or trailing ``Z``. Naive values are taken
        as local time, which is what ``create_external_link`` writes by default.

        Returns:
            The instant in UTC, or None if *value* cannot be parsed.
        """
        try:
            # Python < 3.11 fromisoformat() does not accept a trailing 'Z'.
            parsed = datetime.fromisoformat(str(value).strip().replace('Z', '+00:00'))
            if parsed.tzinfo is None:
                # astimezone() on a naive datetime assumes the system local zone.
                parsed = parsed.astimezone()
            return parsed.astimezone(timezone.utc)
        except (ValueError, OverflowError, OSError):
            return None

    # ---- drive_providers -----------------------------------------------

    def add_drive_provider(self, provider_name: str, config_vars: Dict[str, Any], remarks: Optional[str] = None) -> bool:
        """Insert a drive provider.

        Returns:
            True on success; False if the insert violates a constraint
            (e.g. *provider_name* already exists).
        """
        try:
            self.cursor.execute(
                "INSERT INTO drive_providers (provider_name, config_vars, remarks) VALUES (?, ?, ?)",
                (provider_name, json.dumps(config_vars, ensure_ascii=False), remarks)
            )
            self.conn.commit()
            return True
        except sqlite3.IntegrityError:
            # Provider name already exists (UNIQUE) or a NOT NULL column is missing.
            self._rollback_quietly()
            return False

    def get_drive_provider(self, provider_name: str) -> Optional[Dict[str, Any]]:
        """Return the provider row as a dict with decoded ``config_vars``, or None."""
        self.cursor.execute("SELECT * FROM drive_providers WHERE provider_name = ?", (provider_name,))
        row = self.cursor.fetchone()
        return self._decode_json_column(row, 'config_vars') if row is not None else None

    def get_all_drive_providers(self) -> list:
        """Return every provider as a dict with decoded ``config_vars``."""
        self.cursor.execute("SELECT * FROM drive_providers")
        return [self._decode_json_column(row, 'config_vars') for row in self.cursor.fetchall()]

    def update_drive_provider(self, provider_name: str, config_vars: Optional[Dict[str, Any]] = None, remarks: Optional[str] = None) -> bool:
        """Update a provider's config and/or remarks; None keeps the current value.

        Returns:
            True on success; False if the provider does not exist or any error occurs.
        """
        try:
            current = self.get_drive_provider(provider_name)
            if not current:
                return False
            new_config = config_vars if config_vars is not None else current['config_vars']
            if remarks is None:
                remarks = current['remarks']
            self.cursor.execute(
                "UPDATE drive_providers SET config_vars = ?, remarks = ? WHERE provider_name = ?",
                (json.dumps(new_config, ensure_ascii=False), remarks, provider_name)
            )
            self.conn.commit()
            return True
        except Exception:
            self._rollback_quietly()
            return False

    def delete_drive_provider(self, provider_name: str) -> bool:
        """Delete a provider; return True only if a row was removed."""
        try:
            self.cursor.execute("DELETE FROM drive_providers WHERE provider_name = ?", (provider_name,))
            deleted = self.cursor.rowcount > 0
            self.conn.commit()
            return deleted
        except Exception:
            self._rollback_quietly()
            return False

    # ---- user_drives ---------------------------------------------------

    def add_user_drive(self, provider_name: str, login_config: Dict[str, Any], remarks: Optional[str] = None) -> Optional[int]:
        """Insert a user drive for an existing provider.

        Returns:
            The new row id, or None if the provider is unknown or the insert fails.
        """
        try:
            # The FK is not enforced by SQLite here, so check the provider explicitly.
            if not self.get_drive_provider(provider_name):
                return None
            self.cursor.execute(
                "INSERT INTO user_drives (provider_name, login_config, remarks) VALUES (?, ?, ?)",
                (provider_name, json.dumps(login_config, ensure_ascii=False), remarks)
            )
            self.conn.commit()
            return self.cursor.lastrowid
        except Exception:
            self._rollback_quietly()
            return None

    def get_user_drive(self, drive_id: int) -> Optional[Dict[str, Any]]:
        """Return the user drive as a dict with decoded ``login_config``, or None."""
        self.cursor.execute("SELECT * FROM user_drives WHERE id = ?", (drive_id,))
        row = self.cursor.fetchone()
        return self._decode_json_column(row, 'login_config') if row is not None else None

    def get_user_drives_by_provider(self, provider_name: str) -> list:
        """Return all user drives of *provider_name* with decoded ``login_config``."""
        self.cursor.execute("SELECT * FROM user_drives WHERE provider_name = ?", (provider_name,))
        return [self._decode_json_column(row, 'login_config') for row in self.cursor.fetchall()]

    def get_all_user_drives(self) -> list:
        """Return every user drive with decoded ``login_config``."""
        self.cursor.execute("SELECT * FROM user_drives")
        return [self._decode_json_column(row, 'login_config') for row in self.cursor.fetchall()]

    def update_user_drive(self, drive_id: int, login_config: Optional[Dict[str, Any]] = None, remarks: Optional[str] = None) -> bool:
        """Update a drive's login config and/or remarks; None keeps the current value.

        Returns:
            True on success; False if the drive does not exist or any error occurs.
        """
        try:
            current = self.get_user_drive(drive_id)
            if not current:
                return False
            new_config = login_config if login_config is not None else current['login_config']
            if remarks is None:
                remarks = current['remarks']
            self.cursor.execute(
                "UPDATE user_drives SET login_config = ?, remarks = ? WHERE id = ?",
                (json.dumps(new_config, ensure_ascii=False), remarks, drive_id)
            )
            self.conn.commit()
            return True
        except Exception:
            self._rollback_quietly()
            return False

    def delete_user_drive(self, drive_id: int) -> bool:
        """Delete a user drive; return True only if a row was removed."""
        try:
            self.cursor.execute("DELETE FROM user_drives WHERE id = ?", (drive_id,))
            deleted = self.cursor.rowcount > 0
            self.conn.commit()
            return deleted
        except Exception:
            self._rollback_quietly()
            return False

    # ---- external_links ------------------------------------------------

    def create_external_link(self, drive_id: int, total_quota: float, remarks: Optional[str] = None, expiry_time: Optional[str] = None) -> Optional[str]:
        """Create an external link for an existing user drive.

        Args:
            drive_id: id of the owning user drive.
            total_quota: quota granted to the link.
            remarks: optional free-form note.
            expiry_time: expiry timestamp text; when empty, defaults to 24 hours
                from now as naive local time ``'%Y-%m-%d %H:%M:%S'``.

        Returns:
            The new link's UUID string, or None if the drive is unknown or an error occurs.
        """
        try:
            if not self.get_user_drive(drive_id):
                return None
            # Regenerate on the (astronomically unlikely) chance of a collision.
            link_uuid = str(uuid.uuid4())
            while self.get_external_link_by_uuid(link_uuid):
                link_uuid = str(uuid.uuid4())
            if not expiry_time:
                expiry_time = (datetime.now() + timedelta(hours=24)).strftime('%Y-%m-%d %H:%M:%S')
            self.cursor.execute(
                "INSERT INTO external_links (drive_id, total_quota, used_quota, link_uuid, remarks, expiry_time) VALUES (?, ?, 0, ?, ?, ?)",
                (drive_id, total_quota, link_uuid, remarks, expiry_time)
            )
            self.conn.commit()
            return link_uuid
        except Exception as e:
            self._rollback_quietly()
            print(f"创建外链错误: {e}")
            return None

    def get_external_link(self, link_id: int) -> Optional[Dict[str, Any]]:
        """Return the link with row id *link_id* as a dict, or None."""
        self.cursor.execute("SELECT * FROM external_links WHERE id = ?", (link_id,))
        row = self.cursor.fetchone()
        return dict(row) if row is not None else None

    def get_external_link_by_uuid(self, link_uuid: str) -> Optional[Dict[str, Any]]:
        """Return the link identified by *link_uuid* as a dict, or None."""
        self.cursor.execute("SELECT * FROM external_links WHERE link_uuid = ?", (link_uuid,))
        row = self.cursor.fetchone()
        return dict(row) if row is not None else None

    def get_external_links_by_drive(self, drive_id: int) -> list:
        """Return every link belonging to user drive *drive_id*."""
        self.cursor.execute("SELECT * FROM external_links WHERE drive_id = ?", (drive_id,))
        return [dict(row) for row in self.cursor.fetchall()]

    def update_external_link_quota(self, link_uuid: str, used_quota: float) -> bool:
        """Set a link's used quota.

        Returns:
            True on success; False if the link is unknown, *used_quota* is
            negative or exceeds ``total_quota``, or any error occurs.
        """
        try:
            link = self.get_external_link_by_uuid(link_uuid)
            if not link:
                return False
            # Usage must stay within [0, total_quota].
            if used_quota < 0 or used_quota > link['total_quota']:
                return False
            self.cursor.execute(
                "UPDATE external_links SET used_quota = ? WHERE link_uuid = ?",
                (used_quota, link_uuid)
            )
            self.conn.commit()
            return True
        except Exception:
            self._rollback_quietly()
            return False

    def update_external_link(self, link_uuid: str, total_quota: Optional[float] = None, remarks: Optional[str] = None,
                             expiry_time: Optional[str] = None) -> bool:
        """Update a link's total quota, remarks and/or expiry; None keeps the current value.

        Returns:
            True on success; False if the link is unknown, the new total quota is
            below the already-used quota, or any error occurs.
        """
        try:
            link = self.get_external_link_by_uuid(link_uuid)
            if not link:
                return False
            if total_quota is None:
                total_quota = link['total_quota']
            if remarks is None:
                remarks = link['remarks']
            if expiry_time is None:
                expiry_time = link['expiry_time']
            # The new total must still cover what has already been used.
            if total_quota < link['used_quota']:
                return False
            self.cursor.execute(
                "UPDATE external_links SET total_quota = ?, remarks = ?, expiry_time = ? WHERE link_uuid = ?",
                (total_quota, remarks, expiry_time, link_uuid)
            )
            self.conn.commit()
            return True
        except Exception:
            self._rollback_quietly()
            return False

    def delete_external_link(self, link_uuid: str) -> bool:
        """Delete a link; return True only if a row was removed."""
        try:
            self.cursor.execute("DELETE FROM external_links WHERE link_uuid = ?", (link_uuid,))
            deleted = self.cursor.rowcount > 0
            self.conn.commit()
            return deleted
        except Exception:
            self._rollback_quietly()
            return False

    # ---- statistics ----------------------------------------------------

    def get_total_user_drives_count(self) -> int:
        """Return the number of user drives (0 on error)."""
        try:
            self.cursor.execute("SELECT COUNT(*) FROM user_drives")
            result = self.cursor.fetchone()
            return result[0] if result else 0
        except Exception as e:
            print(f"获取用户网盘总数错误: {e}")
            return 0

    def get_active_external_links_count(self) -> int:
        """Count active links: quota not exhausted and expiry strictly in the future.

        Links with a NULL ``expiry_time`` are not counted. Timestamps are parsed
        and compared as instants rather than as strings, because stored values
        mix formats (the default is naive local ``'YYYY-MM-DD HH:MM:SS'``; callers
        may pass ISO strings with ``T``/offsets). Unparseable timestamps count
        as inactive.

        Returns:
            The count, or 0 if the query fails.
        """
        try:
            self.cursor.execute(
                "SELECT expiry_time FROM external_links "
                "WHERE used_quota < total_quota AND expiry_time IS NOT NULL"
            )
            now_utc = datetime.now(timezone.utc)
            expiries = (self._parse_expiry_time(row['expiry_time']) for row in self.cursor.fetchall())
            return sum(1 for expiry in expiries if expiry is not None and expiry > now_utc)
        except Exception as e:
            print(f"获取活跃外链数量错误: {e}")
            return 0

    def get_total_external_links_count(self) -> int:
        """Return the number of external links (0 on error)."""
        try:
            self.cursor.execute("SELECT COUNT(*) FROM external_links")
            result = self.cursor.fetchone()
            return result[0] if result else 0
        except Exception as e:
            print(f"获取外链总数错误: {e}")
            return 0

    def get_user_drives_count_by_provider(self) -> Dict[str, int]:
        """Return ``{provider_name: user drive count}`` ({} on error)."""
        try:
            self.cursor.execute("SELECT provider_name, COUNT(*) as count FROM user_drives GROUP BY provider_name")
            return {row['provider_name']: row['count'] for row in self.cursor.fetchall()}
        except Exception as e:
            print(f"按提供商统计用户网盘数量错误: {e}")
            return {}

    def close(self):
        """Close the database connection (safe to call more than once)."""
        if self.conn is not None:
            self.conn.close()
# Usage example
if __name__ == "__main__":
    import os

    # Create the database instance
    db = CloudDriveDatabase("cloud_drive.db")
    try:
        # Register a drive provider (returns False if it already exists)
        db.add_drive_provider(
            "夸克网盘",
            {
                "sign_wg": "",
                "kps_wg": "",
                "redirect_uri": "https://uop.quark.cn/cas/ajax/loginWithKpsAndQrcodeToken",
                "data": {
                    'client_id': '532',
                    'v': '1.2',
                    'request_id': "",
                    'sign_wg': "",
                    'kps_wg': "",
                    'vcode': "",
                    'token': ""
                }
            },
            "夸克网盘API配置"
        )
        # Add a user drive. Login tokens are read from the environment instead of
        # being hard-coded: credentials committed to source control are effectively public.
        drive_id = db.add_user_drive(
            "夸克网盘",
            {
                "sign_wg": os.environ.get("QUARK_SIGN_WG", ""),
                "kps_wg": os.environ.get("QUARK_KPS_WG", ""),
                "redirect_uri": "https://uop.quark.cn/cas/ajax/loginWithKpsAndQrcodeToken",
                "data": {
                    'client_id': '532',
                    'v': '1.2',
                    'request_id': "",
                    'sign_wg': "",
                    'kps_wg': "",
                    'vcode': "",
                    'token': ""
                }
            },
            "张三的百度网盘"
        )
        # Create an external link with a total quota of 3
        if drive_id:
            link_uuid = db.create_external_link(
                drive_id,
                3,
                "测试外链"
            )
            print(f"创建的外链UUID: {link_uuid}")
            if link_uuid:
                # Fetch the link
                link_info = db.get_external_link_by_uuid(link_uuid)
                print(f"外链信息: {link_info}")
                # Record usage. The value must not exceed total_quota (3),
                # otherwise the update is rejected and returns False.
                if not db.update_external_link_quota(link_uuid, 1):
                    print("更新已使用配额失败")
                # Fetch the link again to show the new used_quota
                link_info = db.get_external_link_by_uuid(link_uuid)
                print(f"更新后的外链信息: {link_info}")
    finally:
        # Always release the connection, even if a step above raised
        db.close()