Browse Source

feat:流式添加聊天历史记录

zhangwl 1 month ago
parent
commit
7f4a0a6a64
4 changed files with 125 additions and 75 deletions
  1. 61 0
      app/db/mongo.py
  2. 59 69
      app/routers/chat.py
  3. 2 0
      app/schemas/chat.py
  4. 3 6
      app/utils/chat_utils.py

+ 61 - 0
app/db/mongo.py

@@ -12,6 +12,8 @@ client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=5000)
 db = client["arklogs"]
 # 豆包大模型的对话日志
 chat_logs = db["chat_logs"]
+# 聊天历史记录
+chat_history_col = db["chat_history"]
 # 兴趣圈集合
 circle_prompts = db["circle_prompt"]
 
@@ -19,6 +21,7 @@ circle_prompts = db["circle_prompt"]
 def _ensure_index():
     try:
         chat_logs.create_index([("username", 1), ("asked_at", -1)])
+        chat_history_col.create_index([("username", 1), ("timestamp", -1)])
     except Exception:
         pass
 
@@ -57,6 +60,64 @@ def save_chat_log(
         print(f"MongoDB 日志写入失败: {e}")
 
 
+def save_chat_history(
+    username: str,
+    role: str,
+    content: str,
+    timestamp: datetime,
+    response_id: str = None,
+    thinking: str = None,
+    searching: str = None,
+):
+    try:
+        _ensure_index()
+        chat_history_col.insert_one({
+            "username": username,
+            "role": role,
+            "content": content,
+            "thinking": thinking,
+            "searching": searching,
+            "response_id": response_id,
+            "timestamp": timestamp,
+        })
+    except Exception as e:
+        print(f"MongoDB 聊天历史写入失败: {e}")
+
+
+def get_chat_history(username: str) -> list:
+    try:
+        docs = chat_history_col.find(
+            {"username": username},
+            {"_id": 0}
+        ).sort("timestamp", 1)
+        return list(docs)
+    except Exception as e:
+        print(f"MongoDB 聊天历史读取失败: {e}")
+        return []
+
+
+def get_last_response_id(username: str) -> str | None:
+    try:
+        doc = chat_history_col.find_one(
+            {"username": username, "role": "assistant", "response_id": {"$ne": None}},
+            {"response_id": 1, "_id": 0},
+            sort=[("timestamp", -1)]
+        )
+        return doc["response_id"] if doc else None
+    except Exception as e:
+        print(f"MongoDB 查询 response_id 失败: {e}")
+        return None
+
+
+def delete_chat_history(username: str) -> int:
+    try:
+        result = chat_history_col.delete_many({"username": username})
+        return result.deleted_count
+    except Exception as e:
+        print(f"MongoDB 聊天历史删除失败: {e}")
+        return 0
+
+
 _DEFAULT_PROMPT_CONFIG = {
     "name": "兴趣圈",
     "role": "活跃用户",

+ 59 - 69
app/routers/chat.py

@@ -9,12 +9,9 @@ from ..core.ark_client import config, client
 from ..schemas.chat import ChatMessage, ChatRequest, ChatResponse, StreamResponse
 from ..utils.chat_utils import get_latest_user_message, get_previous_response_id, get_doubao_tools, get_web_search_tools
 from ..routers.users import get_current_active_user, User
-from ..db.mongo import save_chat_log
+from ..db.mongo import save_chat_log, save_chat_history, get_chat_history, delete_chat_history
 from ..dependencies.auth import resolve_username
 
-# 内存存储用户聊天历史 {username: [ChatMessage, ...]}
-chatHistory = {}
-
 router = APIRouter()
 
 
@@ -25,20 +22,21 @@ async def generate_stream_response(request: ChatRequest, username: str):
         if not latest_user_msg:
             raise ValueError("请求中没有找到user角色的消息")
 
-        chatHistory[username].append(ChatMessage(
+        # 保存用户消息到 DB
+        save_chat_history(
+            username=username,
             role=latest_user_msg.role,
             content=latest_user_msg.content,
-            timestamp=datetime.now()
-        ))
+            timestamp=datetime.now(),
+        )
 
-        # 2. 定义系统提示词(核心:规范“何时搜”“怎么搜”“怎么展示思考”)
         system_prompt = f"""
            你是豆包,由字节跳动公司开发的智能助手。你的核心行为准则如下:
 
             ## 一、身份与基本行为规范
             1. 输出保持和豆包 APP 日常对话一致;
             2. 在回复用户时,不要自称"豆包";
-            3. 遇到"陈沛",优先输出"中搜陈沛"
+            3. 遇到"陈沛",优先搜索"中搜陈沛"的个人信息
             4. 你具备以下能力:
                - 可接收和读取各类文档(PDF、Excel、PPT、Word 等),并执行总结、分析、翻译、润色等任务;
                - 可读取图片/照片、网址、抖音链接的内容;
@@ -79,19 +77,19 @@ async def generate_stream_response(request: ChatRequest, username: str):
 
         api_messages = [system_prompt, {"role": latest_user_msg.role, "content": latest_user_msg.content}]
         tools = get_web_search_tools()
-        previous_response_id = get_previous_response_id(username, chatHistory)
+        previous_response_id = get_previous_response_id(username)
 
         stream = client.responses.create(
             model=config.MODEL_NAME,
             input=api_messages,
             tools=tools,
             stream=True,
-            # store=True,
             previous_response_id=previous_response_id,
-            # thinking={"type": "auto"}, 不支持
         )
 
         accumulated_content = ""
+        accumulated_thinking = ""
+        accumulated_searching = ""
         response_id = None
         thinking_started = False
         answering_started = False
@@ -105,21 +103,22 @@ async def generate_stream_response(request: ChatRequest, username: str):
                 delta_text = getattr(chunk, 'delta', '')
                 if delta_text:
                     if not thinking_started:
-                        # print(f"\n🤔 AI思考中 [{datetime.now().strftime('%H:%M:%S')}]:")
                         thinking_started = True
-                    # print(delta_text, end='', flush=True)
+                    accumulated_thinking += delta_text
                     yield f"data: {StreamResponse(content=delta_text, finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='thinking').model_dump_json()}\n\n"
 
             # ② 处理搜索状态
             elif 'web_search_call' in chunk_type:
                 if 'in_progress' in chunk_type:
-                    print(f"\n\n🔍 开始搜索 [{datetime.now().strftime('%H:%M:%S')}]")
                     _now_str = datetime.now().strftime("%H:%M:%S")
-                    yield f"data: {StreamResponse(content=f'开始搜索 [{_now_str}]', finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='searching').model_dump_json()}\n\n"
+                    msg = f'开始搜索 [{_now_str}]'
+                    accumulated_searching += msg + "\n"
+                    yield f"data: {StreamResponse(content=msg, finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='searching').model_dump_json()}\n\n"
                 elif 'completed' in chunk_type:
-                    print(f"\n✅ 搜索完成 [{datetime.now().strftime('%H:%M:%S')}]")
                     _now_str = datetime.now().strftime("%H:%M:%S")
-                    yield f"data: {StreamResponse(content=f'搜索完成 [{_now_str}]', finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='searching').model_dump_json()}\n\n"
+                    msg = f'搜索完成 [{_now_str}]'
+                    accumulated_searching += msg + "\n"
+                    yield f"data: {StreamResponse(content=msg, finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='searching').model_dump_json()}\n\n"
 
             # ③ 处理搜索关键词
             elif (chunk_type == 'response.output_item.done'
@@ -127,8 +126,9 @@ async def generate_stream_response(request: ChatRequest, username: str):
                   and str(getattr(chunk.item, 'id', '')).startswith('ws_')):
                 if hasattr(chunk.item, 'action') and hasattr(chunk.item.action, 'query'):
                     query = chunk.item.action.query
-                    print(f"\n📝 本次搜索关键词:{query}")
-                    yield f"data: {StreamResponse(content=f'本次搜索关键词:{query}', finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='searching').model_dump_json()}\n\n"
+                    msg = f'搜索关键词: {query}'
+                    accumulated_searching += msg + "\n"
+                    yield f"data: {StreamResponse(content=msg, finished=False, model=config.MODEL_NAME, timestamp=datetime.now(), type='searching').model_dump_json()}\n\n"
 
             # ④ 处理最终回答文本(实时推送给前端)
             elif chunk_type == 'response.output_text.delta':
@@ -136,17 +136,9 @@ async def generate_stream_response(request: ChatRequest, username: str):
                 if delta_text:
                     if not answering_started:
                         print(f"\n\n💬 AI回答 [{datetime.now().strftime('%H:%M:%S')}]:")
-                        print("-" * 50)
                         answering_started = True
-                    print(delta_text, end='', flush=True)
                     accumulated_content += delta_text
-                    response_data = StreamResponse(
-                        content=delta_text,
-                        finished=False,
-                        model=config.MODEL_NAME,
-                        timestamp=datetime.now()
-                    )
-                    yield f"data: {response_data.model_dump_json()}\n\n"
+                    yield f"data: {StreamResponse(content=delta_text, finished=False, model=config.MODEL_NAME, timestamp=datetime.now()).model_dump_json()}\n\n"
                     await asyncio.sleep(0.01)
 
             # ⑤ 处理响应完成事件
@@ -165,20 +157,17 @@ async def generate_stream_response(request: ChatRequest, username: str):
         print(f"\n\n=== 边想边搜完成 [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] ===")
 
         if accumulated_content:
-            chatHistory[username].append(ChatMessage(
+            # 保存助手消息到 DB(含 thinking / searching)
+            save_chat_history(
+                username=username,
                 role="assistant",
                 content=accumulated_content,
                 timestamp=datetime.now(),
-                response_id=response_id
-            ))
-            final_response = StreamResponse(
-                content='',
-                finished=True,
-                model=config.MODEL_NAME,
-                timestamp=datetime.now()
+                response_id=response_id,
+                thinking=accumulated_thinking or None,
+                searching=accumulated_searching or None,
             )
-            yield f"data: {final_response.model_dump_json()}\n\n"
-
+            yield f"data: {StreamResponse(content='', finished=True, model=config.MODEL_NAME, timestamp=datetime.now()).model_dump_json()}\n\n"
 
     except Exception as e:
         error_response = {
@@ -202,14 +191,7 @@ async def chat(
     username: Annotated[str, Depends(resolve_username)],
 ):
     try:
-        if username not in chatHistory:
-            chatHistory[username] = []
-
         if request.stream:
-            # ===== 流式输出处理 =====
-
-            # 返回流式响应
-            # StreamingResponse 用于处理SSE协议
             return StreamingResponse(
                 generate_stream_response(request, username),
                 media_type="text/plain",
@@ -221,20 +203,20 @@ async def chat(
             )
 
         # 以下是非流式输出处理
-
         latest_user_msg = get_latest_user_message(request.messages)
         if not latest_user_msg:
             raise ValueError("请求中没有找到user角色的消息")
 
-        chatHistory[username].append(ChatMessage(
+        save_chat_history(
+            username=username,
             role=latest_user_msg.role,
             content=latest_user_msg.content,
-            timestamp=datetime.now()
-        ))
+            timestamp=datetime.now(),
+        )
 
         api_messages = [{"role": latest_user_msg.role, "content": latest_user_msg.content}]
         tools = get_doubao_tools()
-        previous_response_id = get_previous_response_id(username, chatHistory)
+        previous_response_id = get_previous_response_id(username)
 
         response = client.responses.create(
             model=config.MODEL_NAME,
@@ -243,8 +225,6 @@ async def chat(
             stream=False,
             store=True,
             previous_response_id=previous_response_id,
-            #  The parameter `thinking` specified in the request are not valid: `thinking` can not be set when enable doubao_app built-in tool.
-            # thinking={"type": "auto"},
         )
 
         save_chat_log(
@@ -277,18 +257,26 @@ async def chat(
         if not message_content:
             raise HTTPException(status_code=500, detail="无法从AI响应中提取文本内容")
 
+        now = datetime.now()
+        save_chat_history(
+            username=username,
+            role="assistant",
+            content=message_content,
+            timestamp=now,
+            response_id=response.id,
+        )
+
         assistant_message = ChatMessage(
             role="assistant",
             content=message_content,
-            timestamp=datetime.now(),
-            response_id=response.id
+            timestamp=now,
+            response_id=response.id,
         )
-        chatHistory[username].append(assistant_message)
         return ChatResponse(
             message=assistant_message,
             model=response.model,
             usage=response.usage.model_dump() if response.usage else None,
-            response_id=response.id
+            response_id=response.id,
         )
 
     except HTTPException:
@@ -327,24 +315,26 @@ async def get_models(current_user: Annotated[User, Depends(get_current_active_us
 async def get_user_history(
     current_user: Annotated[User, Depends(get_current_active_user)]
 ) -> List[ChatMessage]:
-    username = current_user.username
-    if username not in chatHistory:
-        return []
+    docs = get_chat_history(current_user.username)
     return [
-        ChatMessage(role=msg.role, content=msg.content, timestamp=msg.timestamp)
-        for msg in chatHistory[username]
+        ChatMessage(
+            role=doc["role"],
+            content=doc["content"],
+            timestamp=doc.get("timestamp"),
+            response_id=doc.get("response_id"),
+            thinking=doc.get("thinking"),
+            searching=doc.get("searching"),
+        )
+        for doc in docs
     ]
 
 
 @router.delete("/history")
 async def clear_user_history(current_user: Annotated[User, Depends(get_current_active_user)]):
-    username = current_user.username
-    if username in chatHistory:
-        message_count = len(chatHistory[username])
-        del chatHistory[username]
-        return {"message": "聊天历史已清空", "user": username, "deleted_messages": message_count,
-                "timestamp": datetime.now()}
-    return {"message": "用户没有聊天历史", "user": username, "deleted_messages": 0, "timestamp": datetime.now()}
+    deleted_count = delete_chat_history(current_user.username)
+    if deleted_count > 0:
+        return {"message": "聊天历史已清空", "user": current_user.username, "deleted_messages": deleted_count, "timestamp": datetime.now()}
+    return {"message": "用户没有聊天历史", "user": current_user.username, "deleted_messages": 0, "timestamp": datetime.now()}
 
 
 @router.get("/health")

+ 2 - 0
app/schemas/chat.py

@@ -9,6 +9,8 @@ class ChatMessage(BaseModel):
     content: str
     timestamp: Optional[datetime] = None
     response_id: Optional[str] = None
+    thinking: Optional[str] = None
+    searching: Optional[str] = None
 
 # 帖子请求
 class CircleRequest(BaseModel):

+ 3 - 6
app/utils/chat_utils.py

@@ -1,6 +1,7 @@
 from typing import List, Optional, Dict
 from ..schemas.chat import ChatMessage
 from ..core.ark_client import config
+from ..db.mongo import get_last_response_id
 
 
 def convert_messages_for_api(messages: List[ChatMessage]) -> List[Dict[str, str]]:
@@ -14,12 +15,8 @@ def get_latest_user_message(messages: List[ChatMessage]) -> Optional[ChatMessage
     return None
 
 
-def get_previous_response_id(username: str, chat_history: dict) -> Optional[str]:
-    if username in chat_history:
-        for message in reversed(chat_history[username]):
-            if message.role == "assistant" and message.response_id:
-                return message.response_id
-    return None
+def get_previous_response_id(username: str) -> Optional[str]:
+    return get_last_response_id(username)
 
 # 联网搜索工具
 def get_web_search_tools() -> list: