Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion backend/app/api/v1/endpoints/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
ensure_unique_storage_name,
normalize_original_filename,
sanitize_storage_filename,
build_storage_name_with_email,
)

settings = get_settings()
Expand Down Expand Up @@ -290,6 +291,7 @@ async def ensure_uploaded_paper_local(
content = await _download_pdf_from_url(candidate_url)
stored_filename, file_url, file_size, file_hash = await _save_pdf_bytes(
current_user.id,
current_user.email,
content,
preferred_name=record.original_filename,
)
Expand Down Expand Up @@ -375,12 +377,14 @@ async def _download_pdf_from_url(url: str) -> bytes:

async def _save_pdf_bytes(
user_id: int,
user_email: str,
content: bytes,
*,
preferred_name: str | None = None,
) -> tuple[str, str, int, str]:
display_name = normalize_original_filename(preferred_name or f"user_{user_id}.pdf")
storage_candidate = sanitize_storage_filename(display_name)
storage_candidate = build_storage_name_with_email(display_name, user_email)
storage_candidate = sanitize_storage_filename(storage_candidate)
stored_filename, destination = ensure_unique_storage_name(UPLOAD_DIR, storage_candidate)
await asyncio.to_thread(destination.write_bytes, content)
file_url = f"/media/uploads/{stored_filename}"
Expand Down
128 changes: 103 additions & 25 deletions backend/app/api/v1/endpoints/papers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
ParsedPaperCacheRepository,
UploadedPaperRepository,
)
from app.db.note_repository import NoteRepository
from app.db.conversation_repository import ConversationRepository
from app.db.session import get_db
from app.dependencies.auth import get_current_user
Expand All @@ -46,6 +47,7 @@
ensure_unique_storage_name,
normalize_original_filename,
sanitize_storage_filename,
build_storage_name_with_email,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -151,6 +153,8 @@ async def list_uploaded_papers(
async def upload_paper(
file: UploadFile = File(..., description="需要上传的 PDF 文件"),
folder_id: int | None = Form(None, description="文件夹 ID,不填则保存在未分类"),
conflict_resolution: str | None = Form(None, description="冲突处理方式:overwrite 或 rename"),
new_filename: str | None = Form(None, description="当 conflict_resolution=rename 时的新文件名"),
current_user=Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> upload_schema.UploadedPaperRead:
Expand All @@ -170,42 +174,108 @@ async def upload_paper(
)

cleaned_bytes = raw_bytes[:MAX_UPLOAD_BYTES]
file_hash = _calculate_file_hash(cleaned_bytes)

original_display_name = normalize_original_filename(file.filename)
storage_candidate = sanitize_storage_filename(original_display_name)
stored_filename, destination = ensure_unique_storage_name(UPLOAD_DIR, storage_candidate)

await asyncio.to_thread(destination.write_bytes, cleaned_bytes)

relative_url = f"/media/uploads/{stored_filename}"

folder_repo = LibraryFolderRepository(db)
repo = UploadedPaperRepository(db)
folder_repo = LibraryFolderRepository(db)
resolved_folder_id = await _ensure_folder_access(
folder_repo,
folder_id=folder_id if folder_id and folder_id > 0 else None,
user_id=current_user.id,
)

target_display_name = normalize_original_filename(file.filename)
existing = await repo.get_by_original_name(current_user.id, target_display_name)

# 解析冲突处理策略
resolution = (conflict_resolution or "").strip().lower() or None
Copy link

Copilot AI Dec 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variable resolution is normalized to lowercase and then compared against None and string literals. However, if conflict_resolution is an empty string after stripping, the expression resolves to None due to "or None". This logic could be clearer. Consider: resolution = conflict_resolution.strip().lower() if conflict_resolution and conflict_resolution.strip() else None

Suggested change
resolution = (conflict_resolution or "").strip().lower() or None
resolution = (
conflict_resolution.strip().lower()
if conflict_resolution and conflict_resolution.strip()
else None
)

Copilot uses AI. Check for mistakes.
if existing and not resolution:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail={
"message": "当前用户已存在同名文件",
"conflict": True,
"filename": target_display_name,
"options": ["overwrite", "rename"],
},
)

if resolution == "rename":
if not new_filename:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="重命名上传时必须提供新文件名")
target_display_name = normalize_original_filename(new_filename)
existing = await repo.get_by_original_name(current_user.id, target_display_name)
if existing:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail={
"message": "新的文件名仍然存在冲突,请更换名称",
"conflict": True,
"filename": target_display_name,
"options": ["overwrite", "rename"],
},
)
elif resolution not in {None, "overwrite"}:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="无效的冲突处理方式")

if existing and resolution == "overwrite":
stored_filename = existing.stored_filename
destination = (UPLOAD_DIR / stored_filename).resolve()
relative_url = existing.file_url or f"/media/uploads/{stored_filename}"
# 先删除旧文件,再写入新内容,保持物理名不变
try:
if destination.exists():
await asyncio.to_thread(destination.unlink)
except Exception:
logger.warning("Failed to remove existing file before overwrite: %s", destination)
else:
storage_candidate = build_storage_name_with_email(target_display_name, current_user.email)
storage_candidate = sanitize_storage_filename(storage_candidate)
stored_filename, destination = ensure_unique_storage_name(UPLOAD_DIR, storage_candidate)
relative_url = f"/media/uploads/{stored_filename}"

await asyncio.to_thread(destination.write_bytes, cleaned_bytes)
file_hash = _calculate_file_hash(cleaned_bytes)

metadata_json: dict | None = None
try:
metadata_json = await extract_pdf_metadata_async(destination, original_display_name)
metadata_json = await extract_pdf_metadata_async(destination, target_display_name)
except Exception as exc: # pragma: no cover - best effort only
logger.warning("Failed to extract metadata for uploaded PDF: %s", exc)
metadata_json = None

try:
record = await repo.create(
user_id=current_user.id,
stored_filename=stored_filename,
original_filename=original_display_name,
content_type=file.content_type or "application/pdf",
file_size=len(cleaned_bytes),
file_url=relative_url,
file_hash=file_hash,
folder_id=resolved_folder_id,
metadata_json=metadata_json,
)
if existing and resolution == "overwrite":
await repo.purge_cached_artifacts(existing)

conv_repo = ConversationRepository(db)
note_repo = NoteRepository(db)
await conv_repo.delete_conversations_for_paper(current_user.id, existing.id)
await note_repo.detach_uploaded_paper(current_user.id, existing.id)

# 更新记录为新的文件
await repo.update_file_fields(
existing,
stored_filename=stored_filename,
file_url=relative_url,
file_size=len(cleaned_bytes),
file_hash=file_hash,
content_type=file.content_type or "application/pdf",
)
await repo.update_metadata(existing, metadata_json)
record = existing
else:
# 新建记录(无冲突或重命名)
record = await repo.create(
user_id=current_user.id,
stored_filename=stored_filename,
original_filename=target_display_name,
content_type=file.content_type or "application/pdf",
file_size=len(cleaned_bytes),
file_url=relative_url,
file_hash=file_hash,
folder_id=resolved_folder_id,
metadata_json=metadata_json,
)
await db.commit()
except Exception:
await db.rollback()
Expand Down Expand Up @@ -834,21 +904,29 @@ async def _handle_paper_qa(
conversation_id = request.conversation_id

if conversation_id:
# 验证对话存在且属于当前用户,并属于智能阅读
# 验证对话存在且属于当前用户,并属于智能阅读,且绑定到该文档
conversation = await conv_repo.get_conversation(conversation_id, current_user.id)
if not conversation or conversation.category != "reading":
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Conversation not found or access denied"
)
if conversation.paper_id not in (None, request.paper_id):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Conversation not found for this paper",
)
if conversation.paper_id is None:
conversation.paper_id = request.paper_id
await db.flush()
# 获取历史消息
history_messages = await conv_repo.get_conversation_messages(conversation_id, current_user.id)
else:
# 创建新对话
# 创建新对话并绑定该文档
paper_title = parse_result.get("metadata", {}).get("title", "未命名文档")
conversation = await conv_repo.create_conversation(
current_user.id,
ConversationCreate(title=f"关于《{paper_title}》的讨论", category="reading")
ConversationCreate(title=f"关于《{paper_title}》的讨论", category="reading", paper_id=request.paper_id)
)
conversation_id = conversation.id
history_messages = []
Expand Down
44 changes: 40 additions & 4 deletions backend/app/api/v1/endpoints/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,20 @@
from __future__ import annotations

import asyncio
import shutil
from pathlib import Path
from typing import Final
from uuid import uuid4

from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.config import get_settings
from app.core.security import hash_password, verify_password
from app.db.repository import UserRepository
from app.models.uploaded_paper import UploadedPaper
from app.models.parsed_paper_cache import ParsedPaperCache
from app.db.session import get_db
from app.dependencies.auth import get_current_user
from app.schemas import user as user_schema
Expand Down Expand Up @@ -45,6 +49,16 @@ def _remove_avatar_file(avatar_url: str | None) -> None:
pass


def _remove_path_safely(target: Path) -> None:
try:
if target.is_file() or target.is_symlink():
target.unlink()
elif target.is_dir():
shutil.rmtree(target)
except OSError:
pass


@router.post(
"",
response_model=user_schema.UserRead,
Expand Down Expand Up @@ -179,13 +193,35 @@ async def delete_account(
current_user=Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Soft-delete the user account by marking it inactive."""
"""Hard delete user and all related data/files so the email can be reused."""

# Collect uploaded papers before DB deletion (to remove files/cache/parsed dirs)
result = await db.execute(
select(UploadedPaper.id, UploadedPaper.stored_filename, UploadedPaper.file_hash).where(
UploadedPaper.user_id == current_user.id
)
)
uploads = list(result.all())

# Remove physical files and parsed outputs
for paper_id, stored_filename, file_hash in uploads:
upload_path = settings.media_path / "uploads" / stored_filename
parse_dir = settings.media_path / "parsed" / f"paper_{paper_id}"
_remove_path_safely(upload_path)
_remove_path_safely(parse_dir)

if file_hash:
await db.execute(delete(ParsedPaperCache).where(ParsedPaperCache.file_hash == file_hash))

# Remove avatar if under media
_remove_avatar_file(getattr(current_user, "avatar_url", None))

repo = UserRepository(db)
await repo.update(current_user, {"is_active": False})
# 删除用户前先删除上传记录,避免 ORM 删除流程尝试将 user_id 置空导致约束错误
await db.execute(delete(UploadedPaper).where(UploadedPaper.user_id == current_user.id))

# Finally delete the user (FK cascades will clean remaining dependencies)
await db.delete(current_user)

await db.commit()

return {"message": "账户已成功注销"}
return {"message": "账户已彻底删除,可使用该邮箱重新注册"}
Comment on lines +196 to +227
Copy link

Copilot AI Dec 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hard-deleting user accounts and all related data could lead to accidental data loss. Consider implementing a confirmation mechanism (e.g., requiring the user to type their email or a confirmation phrase) before permanently deleting the account. Additionally, consider implementing a grace period or soft-delete first to allow recovery if the deletion was accidental.

Copilot uses AI. Check for mistakes.
27 changes: 26 additions & 1 deletion backend/app/db/conversation_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@ async def create_conversation(self, user_id: int, data: ConversationCreate) -> C
user_id=user_id,
title=data.title,
category=data.category or "search",
paper_id=getattr(data, "paper_id", None),
)
self.db.add(conversation)
await self.db.commit()
await self.db.refresh(conversation)
return conversation

async def get_conversation(self, conversation_id: int, user_id: int) -> Optional[Conversation]:
async def get_conversation(self, conversation_id: int, user_id: int, *, paper_id: int | None = None) -> Optional[Conversation]:
"""获取特定对话(含消息)"""
stmt = (
select(Conversation)
Expand All @@ -43,6 +44,8 @@ async def get_conversation(self, conversation_id: int, user_id: int) -> Optional
)
.options(selectinload(Conversation.messages))
)
if paper_id is not None:
stmt = stmt.where(Conversation.paper_id == paper_id)
Comment on lines +47 to +48
Copy link

Copilot AI Dec 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The condition paper_id is not None uses explicit None comparison, but on line 48 the comparison uses == paper_id. For consistency with Python best practices and the existing code style, use explicit None comparison throughout: if paper_id is not None:

Copilot uses AI. Check for mistakes.
result = await self.db.execute(stmt)
return result.scalar_one_or_none()

Expand Down Expand Up @@ -111,6 +114,28 @@ async def delete_conversation(self, conversation_id: int, user_id: int) -> bool:
await self.db.commit()
return True

async def delete_conversations_for_paper(self, user_id: int, paper_id: int) -> int:
"""软删除绑定到指定文档的阅读类对话,返回删除数量"""
stmt = (
select(Conversation)
.where(
Conversation.user_id == user_id,
Conversation.paper_id == paper_id,
Conversation.category == "reading",
Conversation.is_deleted == False,
)
.options(selectinload(Conversation.messages))
)
result = await self.db.execute(stmt)
conversations = result.scalars().all()
deleted = 0
for conv in conversations:
conv.is_deleted = True
deleted += 1
if deleted:
await self.db.commit()
return deleted

async def add_message(
self,
conversation_id: int,
Expand Down
15 changes: 14 additions & 1 deletion backend/app/db/note_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections.abc import Mapping
from typing import Any

from sqlalchemy import func, select
from sqlalchemy import func, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.note import Note
Expand Down Expand Up @@ -85,3 +85,16 @@ async def update(self, note: Note, updates: Mapping[str, Any]) -> Note:
async def delete(self, note: Note) -> None:
await self._session.delete(note)
await self._session.flush()

async def detach_uploaded_paper(self, user_id: int, paper_id: int) -> int:
"""Set uploaded_paper_id to NULL for notes linked to the given paper, returns affected rows."""

stmt = (
update(Note)
.where(Note.user_id == user_id, Note.uploaded_paper_id == paper_id)
.values(uploaded_paper_id=None)
.execution_options(synchronize_session="fetch")
)
result = await self._session.execute(stmt)
await self._session.flush()
return result.rowcount or 0
Loading
Loading