Browse Source

feat:创建新对话

zhangwl 1 month ago
parent
commit
3f86902d90
6 changed files with 61 additions and 45 deletions
  1. 15 14
      app/db/mongo.py
  2. 4 4
      app/dependencies/auth.py
  3. 33 23
      app/routers/chat.py
  4. 3 1
      app/routers/users.py
  5. 4 1
      app/schemas/chat.py
  6. 2 2
      app/utils/chat_utils.py

+ 15 - 14
app/db/mongo.py

@@ -17,17 +17,16 @@ chat_history_col = db["chat_history"]
 # 兴趣圈集合
 # 兴趣圈集合
 circle_prompts = db["circle_prompt"]
 circle_prompts = db["circle_prompt"]
 
 
-# 带下划线的表示私有方法(Private)
 def _ensure_index():
 def _ensure_index():
     try:
     try:
-        chat_logs.create_index([("username", 1), ("asked_at", -1)])
-        chat_history_col.create_index([("username", 1), ("timestamp", -1)])
+        chat_logs.create_index([("user_id", 1), ("asked_at", -1)])
+        chat_history_col.create_index([("user_id", 1), ("session_id", 1), ("timestamp", -1)])
     except Exception:
     except Exception:
         pass
         pass
 
 
 
 
 def save_chat_log(
 def save_chat_log(
-    username: str,
+    user_id: str,
     question: str,
     question: str,
     stream_mode: bool,
     stream_mode: bool,
     raw_response: str = None,
     raw_response: str = None,
@@ -38,7 +37,7 @@ def save_chat_log(
     保存聊天原始响应日志到 MongoDB
     保存聊天原始响应日志到 MongoDB
 
 
     Args:
     Args:
-        username: 提问人
+        user_id: 提问人
         question: 提问的问题
         question: 提问的问题
         stream_mode: 回答方式(流式或非流式)
         stream_mode: 回答方式(流式或非流式)
         raw_response: API 原始响应的 repr 字符串
         raw_response: API 原始响应的 repr 字符串
@@ -48,7 +47,7 @@ def save_chat_log(
     try:
     try:
         _ensure_index()
         _ensure_index()
         chat_logs.insert_one({
         chat_logs.insert_one({
-            "username": username,
+            "user_id": user_id,
             "question": question,
             "question": question,
             "stream_mode": stream_mode,
             "stream_mode": stream_mode,
             "raw_response": raw_response,
             "raw_response": raw_response,
@@ -61,7 +60,8 @@ def save_chat_log(
 
 
 
 
 def save_chat_history(
 def save_chat_history(
-    username: str,
+    user_id: str,
+    session_id: str,
     role: str,
     role: str,
     content: str,
     content: str,
     timestamp: datetime,
     timestamp: datetime,
@@ -72,7 +72,8 @@ def save_chat_history(
     try:
     try:
         _ensure_index()
         _ensure_index()
         chat_history_col.insert_one({
         chat_history_col.insert_one({
-            "username": username,
+            "user_id": user_id,
+            "session_id": session_id,
             "role": role,
             "role": role,
             "content": content,
             "content": content,
             "thinking": thinking,
             "thinking": thinking,
@@ -84,10 +85,10 @@ def save_chat_history(
         print(f"MongoDB 聊天历史写入失败: {e}")
         print(f"MongoDB 聊天历史写入失败: {e}")
 
 
 
 
-def get_chat_history(username: str) -> list:
+def get_chat_history(user_id: str, session_id: str) -> list:
     try:
     try:
         docs = chat_history_col.find(
         docs = chat_history_col.find(
-            {"username": username},
+            {"user_id": user_id, "session_id": session_id},
             {"_id": 0}
             {"_id": 0}
         ).sort("timestamp", 1)
         ).sort("timestamp", 1)
         return list(docs)
         return list(docs)
@@ -96,10 +97,10 @@ def get_chat_history(username: str) -> list:
         return []
         return []
 
 
 
 
-def get_last_response_id(username: str) -> str | None:
+def get_last_response_id(user_id: str, session_id: str) -> str | None:
     try:
     try:
         doc = chat_history_col.find_one(
         doc = chat_history_col.find_one(
-            {"username": username, "role": "assistant", "response_id": {"$ne": None}},
+            {"user_id": user_id, "session_id": session_id, "role": "assistant", "response_id": {"$ne": None}},
             {"response_id": 1, "_id": 0},
             {"response_id": 1, "_id": 0},
             sort=[("timestamp", -1)]
             sort=[("timestamp", -1)]
         )
         )
@@ -109,9 +110,9 @@ def get_last_response_id(username: str) -> str | None:
         return None
         return None
 
 
 
 
-def delete_chat_history(username: str) -> int:
+def delete_chat_history(user_id: str, session_id: str) -> int:
     try:
     try:
-        result = chat_history_col.delete_many({"username": username})
+        result = chat_history_col.delete_many({"user_id": user_id, "session_id": session_id})
         return result.deleted_count
         return result.deleted_count
     except Exception as e:
     except Exception as e:
         print(f"MongoDB 聊天历史删除失败: {e}")
         print(f"MongoDB 聊天历史删除失败: {e}")

+ 4 - 4
app/dependencies/auth.py

@@ -10,7 +10,7 @@ config = Config()
 oauth2_scheme_optional = OAuth2PasswordBearer(tokenUrl="/users/token", auto_error=False)
 oauth2_scheme_optional = OAuth2PasswordBearer(tokenUrl="/users/token", auto_error=False)
 
 
 
 
-async def resolve_username(
+async def resolve_user_id(
     jwt_token: Annotated[Optional[str], Depends(oauth2_scheme_optional)] = None,
     jwt_token: Annotated[Optional[str], Depends(oauth2_scheme_optional)] = None,
     source: Optional[str] = Query(default=None),
     source: Optional[str] = Query(default=None),
     token: Optional[str] = Query(default=None),
     token: Optional[str] = Query(default=None),
@@ -24,9 +24,9 @@ async def resolve_username(
         raise HTTPException(status_code=401, detail="未提供认证令牌")
         raise HTTPException(status_code=401, detail="未提供认证令牌")
     try:
     try:
         payload = jwt.decode(jwt_token, config.SECRET_KEY, algorithms=[config.ALGORITHM])
         payload = jwt.decode(jwt_token, config.SECRET_KEY, algorithms=[config.ALGORITHM])
-        sub = payload.get("sub")
-        if not sub:
+        user_id = payload.get("userId")
+        if not user_id:
             raise HTTPException(status_code=401, detail="无效的令牌")
             raise HTTPException(status_code=401, detail="无效的令牌")
-        return sub
+        return user_id
     except jwt.PyJWTError:
     except jwt.PyJWTError:
         raise HTTPException(status_code=401, detail="无效的令牌")
         raise HTTPException(status_code=401, detail="无效的令牌")

+ 33 - 23
app/routers/chat.py

@@ -1,6 +1,6 @@
-from fastapi import APIRouter, HTTPException, Depends
+from fastapi import APIRouter, HTTPException, Depends, Query
 from fastapi.responses import StreamingResponse
 from fastapi.responses import StreamingResponse
-from typing import List, Annotated
+from typing import List, Annotated, Optional
 from datetime import datetime
 from datetime import datetime
 import json
 import json
 import asyncio
 import asyncio
@@ -10,12 +10,13 @@ from ..schemas.chat import ChatMessage, ChatRequest, ChatResponse, StreamRespons
 from ..utils.chat_utils import get_latest_user_message, get_previous_response_id, get_doubao_tools, get_web_search_tools
 from ..utils.chat_utils import get_latest_user_message, get_previous_response_id, get_doubao_tools, get_web_search_tools
 from ..routers.users import get_current_active_user, User
 from ..routers.users import get_current_active_user, User
 from ..db.mongo import save_chat_log, save_chat_history, get_chat_history, delete_chat_history
 from ..db.mongo import save_chat_log, save_chat_history, get_chat_history, delete_chat_history
-from ..dependencies.auth import resolve_username
+from ..dependencies.auth import resolve_user_id
 
 
 router = APIRouter()
 router = APIRouter()
 
 
 
 
-async def generate_stream_response(request: ChatRequest, username: str):
+async def generate_stream_response(request: ChatRequest, user_id: str):
+    session_id = request.session_id
     latest_user_msg = None
     latest_user_msg = None
     try:
     try:
         latest_user_msg = get_latest_user_message(request.messages)
         latest_user_msg = get_latest_user_message(request.messages)
@@ -24,7 +25,8 @@ async def generate_stream_response(request: ChatRequest, username: str):
 
 
         # 保存用户消息到 DB
         # 保存用户消息到 DB
         save_chat_history(
         save_chat_history(
-            username=username,
+            user_id=user_id,
+            session_id=session_id,
             role=latest_user_msg.role,
             role=latest_user_msg.role,
             content=latest_user_msg.content,
             content=latest_user_msg.content,
             timestamp=datetime.now(),
             timestamp=datetime.now(),
@@ -77,7 +79,7 @@ async def generate_stream_response(request: ChatRequest, username: str):
 
 
         api_messages = [system_prompt, {"role": latest_user_msg.role, "content": latest_user_msg.content}]
         api_messages = [system_prompt, {"role": latest_user_msg.role, "content": latest_user_msg.content}]
         tools = get_web_search_tools()
         tools = get_web_search_tools()
-        previous_response_id = get_previous_response_id(username)
+        previous_response_id = get_previous_response_id(user_id, session_id)
 
 
         stream = client.responses.create(
         stream = client.responses.create(
             model=config.MODEL_NAME,
             model=config.MODEL_NAME,
@@ -135,7 +137,7 @@ async def generate_stream_response(request: ChatRequest, username: str):
                 delta_text = getattr(chunk, 'delta', '')
                 delta_text = getattr(chunk, 'delta', '')
                 if delta_text:
                 if delta_text:
                     if not answering_started:
                     if not answering_started:
-                        print(f"\n\n💬 AI回答 [{datetime.now().strftime('%H:%M:%S')}]:")
+                        # print(f"\n\n💬 AI回答 [{datetime.now().strftime('%H:%M:%S')}]:")
                         answering_started = True
                         answering_started = True
                     accumulated_content += delta_text
                     accumulated_content += delta_text
                     yield f"data: {StreamResponse(content=delta_text, finished=False, model=config.MODEL_NAME, timestamp=datetime.now()).model_dump_json()}\n\n"
                     yield f"data: {StreamResponse(content=delta_text, finished=False, model=config.MODEL_NAME, timestamp=datetime.now()).model_dump_json()}\n\n"
@@ -147,7 +149,7 @@ async def generate_stream_response(request: ChatRequest, username: str):
                 if response_obj and hasattr(response_obj, 'id'):
                 if response_obj and hasattr(response_obj, 'id'):
                     response_id = response_obj.id
                     response_id = response_obj.id
                 save_chat_log(
                 save_chat_log(
-                    username=username,
+                    user_id=user_id,
                     question=latest_user_msg.content,
                     question=latest_user_msg.content,
                     stream_mode=True,
                     stream_mode=True,
                     raw_response=repr(response_obj),
                     raw_response=repr(response_obj),
@@ -159,7 +161,8 @@ async def generate_stream_response(request: ChatRequest, username: str):
         if accumulated_content:
         if accumulated_content:
             # 保存助手消息到 DB(含 thinking / searching)
             # 保存助手消息到 DB(含 thinking / searching)
             save_chat_history(
             save_chat_history(
-                username=username,
+                user_id=user_id,
+                session_id=session_id,
                 role="assistant",
                 role="assistant",
                 content=accumulated_content,
                 content=accumulated_content,
                 timestamp=datetime.now(),
                 timestamp=datetime.now(),
@@ -176,7 +179,7 @@ async def generate_stream_response(request: ChatRequest, username: str):
             "timestamp": datetime.now().isoformat()
             "timestamp": datetime.now().isoformat()
         }
         }
         save_chat_log(
         save_chat_log(
-            username=username,
+            user_id=user_id,
             question=latest_user_msg.content if latest_user_msg else "",
             question=latest_user_msg.content if latest_user_msg else "",
             stream_mode=True,
             stream_mode=True,
             status="error",
             status="error",
@@ -188,12 +191,12 @@ async def generate_stream_response(request: ChatRequest, username: str):
 @router.post("/chat", response_model=ChatResponse)
 @router.post("/chat", response_model=ChatResponse)
 async def chat(
 async def chat(
     request: ChatRequest,
     request: ChatRequest,
-    username: Annotated[str, Depends(resolve_username)],
+    user_id: Annotated[str, Depends(resolve_user_id)],
 ):
 ):
     try:
     try:
         if request.stream:
         if request.stream:
             return StreamingResponse(
             return StreamingResponse(
-                generate_stream_response(request, username),
+                generate_stream_response(request, user_id),
                 media_type="text/plain",
                 media_type="text/plain",
                 headers={
                 headers={
                     "Cache-Control": "no-cache",
                     "Cache-Control": "no-cache",
@@ -203,12 +206,14 @@ async def chat(
             )
             )
 
 
         # 以下是非流式输出处理
         # 以下是非流式输出处理
+        session_id = request.session_id
         latest_user_msg = get_latest_user_message(request.messages)
         latest_user_msg = get_latest_user_message(request.messages)
         if not latest_user_msg:
         if not latest_user_msg:
             raise ValueError("请求中没有找到user角色的消息")
             raise ValueError("请求中没有找到user角色的消息")
 
 
         save_chat_history(
         save_chat_history(
-            username=username,
+            user_id=user_id,
+            session_id=session_id,
             role=latest_user_msg.role,
             role=latest_user_msg.role,
             content=latest_user_msg.content,
             content=latest_user_msg.content,
             timestamp=datetime.now(),
             timestamp=datetime.now(),
@@ -216,7 +221,7 @@ async def chat(
 
 
         api_messages = [{"role": latest_user_msg.role, "content": latest_user_msg.content}]
         api_messages = [{"role": latest_user_msg.role, "content": latest_user_msg.content}]
         tools = get_doubao_tools()
         tools = get_doubao_tools()
-        previous_response_id = get_previous_response_id(username)
+        previous_response_id = get_previous_response_id(user_id, session_id)
 
 
         response = client.responses.create(
         response = client.responses.create(
             model=config.MODEL_NAME,
             model=config.MODEL_NAME,
@@ -228,7 +233,7 @@ async def chat(
         )
         )
 
 
         save_chat_log(
         save_chat_log(
-            username=username,
+            user_id=user_id,
             question=latest_user_msg.content,
             question=latest_user_msg.content,
             stream_mode=False,
             stream_mode=False,
             raw_response=repr(response),
             raw_response=repr(response),
@@ -259,7 +264,8 @@ async def chat(
 
 
         now = datetime.now()
         now = datetime.now()
         save_chat_history(
         save_chat_history(
-            username=username,
+            user_id=user_id,
+            session_id=session_id,
             role="assistant",
             role="assistant",
             content=message_content,
             content=message_content,
             timestamp=now,
             timestamp=now,
@@ -284,7 +290,7 @@ async def chat(
     except Exception as e:
     except Exception as e:
         error_message = f"处理聊天请求时发生错误: {str(e)}"
         error_message = f"处理聊天请求时发生错误: {str(e)}"
         save_chat_log(
         save_chat_log(
-            username=username,
+            user_id=user_id,
             question=request.messages[-1].content if request.messages else "",
             question=request.messages[-1].content if request.messages else "",
             stream_mode=request.stream,
             stream_mode=request.stream,
             status="error",
             status="error",
@@ -313,9 +319,10 @@ async def get_models(current_user: Annotated[User, Depends(get_current_active_us
 
 
 @router.get("/history")
 @router.get("/history")
 async def get_user_history(
 async def get_user_history(
-    current_user: Annotated[User, Depends(get_current_active_user)]
+    current_user: Annotated[User, Depends(get_current_active_user)],
+    sessionId: str = Query(..., description="会话ID"),
 ) -> List[ChatMessage]:
 ) -> List[ChatMessage]:
-    docs = get_chat_history(current_user.username)
+    docs = get_chat_history(current_user.userId, sessionId)
     return [
     return [
         ChatMessage(
         ChatMessage(
             role=doc["role"],
             role=doc["role"],
@@ -330,11 +337,14 @@ async def get_user_history(
 
 
 
 
 @router.delete("/history")
 @router.delete("/history")
-async def clear_user_history(current_user: Annotated[User, Depends(get_current_active_user)]):
-    deleted_count = delete_chat_history(current_user.username)
+async def clear_user_history(
+    current_user: Annotated[User, Depends(get_current_active_user)],
+    sessionId: str = Query(..., description="会话ID"),
+):
+    deleted_count = delete_chat_history(current_user.userId, sessionId)
     if deleted_count > 0:
     if deleted_count > 0:
-        return {"message": "聊天历史已清空", "user": current_user.username, "deleted_messages": deleted_count, "timestamp": datetime.now()}
-    return {"message": "用户没有聊天历史", "user": current_user.username, "deleted_messages": 0, "timestamp": datetime.now()}
+        return {"message": "聊天历史已清空", "user": current_user.userId, "deleted_messages": deleted_count, "timestamp": datetime.now()}
+    return {"message": "用户没有聊天历史", "user": current_user.userId, "deleted_messages": 0, "timestamp": datetime.now()}
 
 
 
 
 @router.get("/health")
 @router.get("/health")

+ 3 - 1
app/routers/users.py

@@ -66,6 +66,7 @@ class User(BaseModel):
     用户基础信息模型
     用户基础信息模型
     定义用户的公开信息(不包含密码等敏感信息)
     定义用户的公开信息(不包含密码等敏感信息)
     """
     """
+    userId: str  # 用户ID(必需)
     username: str  # 用户名(必需)
     username: str  # 用户名(必需)
     email: Optional[str] = None  # 邮箱(可选)
     email: Optional[str] = None  # 邮箱(可选)
     full_name: Optional[str] = None  # 全名(可选)
     full_name: Optional[str] = None  # 全名(可选)
@@ -108,6 +109,7 @@ class UserUpdate(BaseModel):
 # 这里使用字典来模拟数据库存储,包含一个默认管理员账户
 # 这里使用字典来模拟数据库存储,包含一个默认管理员账户
 fake_users_db = {
 fake_users_db = {
     "root": {
     "root": {
+        "userId":"1",
         "username": "root",
         "username": "root",
         "full_name": "Administrator",
         "full_name": "Administrator",
         "email": "admin@example.com",
         "email": "admin@example.com",
@@ -328,7 +330,7 @@ async def login_for_access_token(
 
 
     # 创建访问令牌,将用户名作为subject存储在令牌中
     # 创建访问令牌,将用户名作为subject存储在令牌中
     access_token = create_access_token(
     access_token = create_access_token(
-        data={"sub": user.username},
+        data={"sub": user.username, "userId": user.userId},
         expires_delta=access_token_expires
         expires_delta=access_token_expires
     )
     )
 
 

+ 4 - 1
app/schemas/chat.py

@@ -1,4 +1,4 @@
-from pydantic import BaseModel
+from pydantic import BaseModel, Field, ConfigDict
 from typing import List, Optional, Dict, Any
 from typing import List, Optional, Dict, Any
 from datetime import datetime
 from datetime import datetime
 from ..core.ark_client import config
 from ..core.ark_client import config
@@ -29,11 +29,14 @@ class CirclePromptConfig(BaseModel):
 
 
 # Ai的请求对象
 # Ai的请求对象
 class ChatRequest(BaseModel):
 class ChatRequest(BaseModel):
+    model_config = ConfigDict(populate_by_name=True)
+
     messages: List[ChatMessage]
     messages: List[ChatMessage]
     model: Optional[str] = config.MODEL_NAME
     model: Optional[str] = config.MODEL_NAME
     stream: Optional[bool] = False
     stream: Optional[bool] = False
     source: Optional[str] = None  # source=app 时走第三方 token 认证
     source: Optional[str] = None  # source=app 时走第三方 token 认证
     token: Optional[str] = None   # App 端传入的第三方 token
     token: Optional[str] = None   # App 端传入的第三方 token
+    session_id: Optional[str] = Field(None, alias="sessionId")  # 会话ID,前端传 sessionId
 
 
 
 
 # Ai的返回对象
 # Ai的返回对象

+ 2 - 2
app/utils/chat_utils.py

@@ -15,8 +15,8 @@ def get_latest_user_message(messages: List[ChatMessage]) -> Optional[ChatMessage
     return None
     return None
 
 
 
 
-def get_previous_response_id(username: str) -> Optional[str]:
-    return get_last_response_id(username)
+def get_previous_response_id(user_id: str, session_id: str) -> Optional[str]:
+    return get_last_response_id(user_id, session_id)
 
 
 # 联网搜索工具
 # 联网搜索工具
 def get_web_search_tools() -> list:
 def get_web_search_tools() -> list: