Browse Source

feat: 以项目制,配置模型api-key和知识库实例id

zhangwl 43 minutes ago
parent
commit
6dd9aaa206

+ 12 - 10
app/core/ark_client.py

@@ -1,13 +1,15 @@
-from volcenginesdkarkruntime import Ark
 from openai import OpenAI
-from ..config.config import Config
+from ..db.ai_config import get_config_by_app_name
 
-config = Config()
-# 初始化客户端
-# client = Ark(api_key=config.API_KEY, base_url=config.BASE_URL,default_headers={"ark-beta-knowledge-search": "true"})  # 启用私域知识库搜索
 
-client = OpenAI(
-    base_url=config.BASE_URL,
-    api_key=config.API_KEY,
-    default_headers={"ark-beta-knowledge-search": "true"}  # 启用私域知识库搜索
-)
+def get_client(app_name: str = "com.yunxiangshengtai") -> OpenAI:
+    """根据appName动态初始化client"""
+    config = get_config_by_app_name(app_name)
+    if not config:
+        raise ValueError(f"未找到appName '{app_name}' 的配置")
+
+    return OpenAI(
+        base_url=config["baseUrl"],
+        api_key=config["apiKey"],
+        default_headers={"ark-beta-knowledge-search": "true"}
+    )

+ 95 - 0
app/db/ai_config.py

@@ -0,0 +1,95 @@
+from pymongo import MongoClient
+from bson import ObjectId
+from datetime import datetime
+from dotenv import load_dotenv
+import os
+
+load_dotenv()
+
+MONGO_URI = os.getenv("ARK_LOGS_MONGO_URI")
+
+client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=5000)
+db = client["arklogs"]
+ai_config_col = db["ai_config"]
+
+
+def _ensure_index():
+    try:
+        ai_config_col.create_index([("appName", 1)], unique=True)
+        ai_config_col.create_index([("status", 1), ("deleted", 1)])
+    except Exception:
+        pass
+
+
+def get_config_by_app_name(app_name: str) -> dict | None:
+    """根据appName查询配置,只返回启用且未删除的记录"""
+    try:
+        _ensure_index()
+        doc = ai_config_col.find_one({
+            "appName": app_name,
+            "status": 1,
+            "deleted": 0
+        })
+        if doc:
+            doc["_id"] = str(doc["_id"])
+        return doc
+    except Exception as e:
+        print(f"MongoDB 查询 ai_config 失败: {e}")
+        return None
+
+
+def create_config(data: dict) -> str | None:
+    """创建新的配置"""
+    try:
+        _ensure_index()
+        data["status"] = 1
+        data["deleted"] = 0
+        data["createdAt"] = datetime.now()
+        data["updatedAt"] = datetime.now()
+        result = ai_config_col.insert_one(data)
+        return str(result.inserted_id)
+    except Exception as e:
+        print(f"MongoDB 创建 ai_config 失败: {e}")
+        return None
+
+
+def update_config(app_name: str, data: dict) -> int:
+    """更新配置"""
+    try:
+        _ensure_index()
+        data["updatedAt"] = datetime.now()
+        result = ai_config_col.update_one(
+            {"appName": app_name},
+            {"$set": data}
+        )
+        return result.matched_count
+    except Exception as e:
+        print(f"MongoDB 更新 ai_config 失败: {e}")
+        return 0
+
+
+def delete_config(app_name: str) -> int:
+    """软删除配置"""
+    try:
+        result = ai_config_col.update_one(
+            {"appName": app_name},
+            {"$set": {"deleted": 1, "updatedAt": datetime.now()}}
+        )
+        return result.matched_count
+    except Exception as e:
+        print(f"MongoDB 删除 ai_config 失败: {e}")
+        return 0
+
+
+def get_all_configs() -> list:
+    """获取所有启用的配置"""
+    try:
+        _ensure_index()
+        docs = ai_config_col.find({
+            "status": 1,
+            "deleted": 0
+        })
+        return [{"_id": str(doc["_id"]), **{k: v for k, v in doc.items() if k != "_id"}} for doc in docs]
+    except Exception as e:
+        print(f"MongoDB 查询所有 ai_config 失败: {e}")
+        return []

+ 149 - 0
app/routers/ai_config.py

@@ -0,0 +1,149 @@
+from fastapi import APIRouter, HTTPException, status, Depends
+from pydantic import BaseModel
+from typing import Optional, Annotated
+from datetime import datetime
+from ..db.ai_config import (
+    get_config_by_app_name,
+    create_config,
+    update_config,
+    delete_config,
+    get_all_configs
+)
+from ..dependencies.auth import resolve_user_id
+
+router = APIRouter()
+
+
+class AIConfigCreate(BaseModel):
+    appName: str
+    knowledgeId: str
+    apiKey: str
+    baseUrl: str
+
+
+class AIConfigUpdate(BaseModel):
+    knowledgeId: Optional[str] = None
+    apiKey: Optional[str] = None
+    baseUrl: Optional[str] = None
+    status: Optional[int] = None
+
+
+class AIConfigResponse(BaseModel):
+    id: str
+    appName: str
+    knowledgeId: str
+    apiKey: str
+    baseUrl: str
+    status: int
+    deleted: int
+    createdAt: datetime
+    updatedAt: datetime
+
+
+@router.post("/", response_model=dict, summary="创建AI配置", description="创建新的AI应用配置映射")
+async def create_ai_config(
+    config: AIConfigCreate
+) -> dict:
+    """创建新的AI配置"""
+    existing = get_config_by_app_name(config.appName)
+    if existing:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=f"appName '{config.appName}' 已存在"
+        )
+
+    config_id = create_config(config.model_dump())
+    if not config_id:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="创建配置失败"
+        )
+
+    return {
+        "message": "配置创建成功",
+        "id": config_id,
+        "appName": config.appName
+    }
+
+
+@router.get("/{app_name}", response_model=dict, summary="获取AI配置", description="根据appName获取AI配置")
+async def get_ai_config(
+    app_name: str
+) -> dict:
+    """获取指定appName的配置"""
+    config = get_config_by_app_name(app_name)
+    if not config:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"appName '{app_name}' 的配置不存在"
+        )
+    return config
+
+
+@router.put("/{app_name}", response_model=dict, summary="更新AI配置", description="更新指定appName的AI配置")
+async def update_ai_config(
+    app_name: str,
+    config_update: AIConfigUpdate
+) -> dict:
+    """更新指定appName的配置"""
+    existing = get_config_by_app_name(app_name)
+    if not existing:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"appName '{app_name}' 的配置不存在"
+        )
+
+    update_data = {k: v for k, v in config_update.model_dump().items() if v is not None}
+    if not update_data:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="没有要更新的字段"
+        )
+
+    matched = update_config(app_name, update_data)
+    if matched == 0:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="更新配置失败"
+        )
+
+    return {
+        "message": "配置更新成功",
+        "appName": app_name
+    }
+
+
+@router.delete("/{app_name}", response_model=dict, summary="删除AI配置", description="删除指定appName的AI配置")
+async def delete_ai_config(
+    app_name: str
+) -> dict:
+    """删除指定appName的配置"""
+    existing = get_config_by_app_name(app_name)
+    if not existing:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"appName '{app_name}' 的配置不存在"
+        )
+
+    matched = delete_config(app_name)
+    if matched == 0:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="删除配置失败"
+        )
+
+    return {
+        "message": "配置删除成功",
+        "appName": app_name
+    }
+
+
+@router.get("/", response_model=dict, summary="获取所有AI配置", description="获取所有启用的AI配置")
+async def list_ai_configs(
+) -> dict:
+    """获取所有启用的配置"""
+    configs = get_all_configs()
+    return {
+        "total": len(configs),
+        "configs": configs
+    }

+ 41 - 23
app/routers/chat.py

@@ -6,20 +6,31 @@ import json
 import asyncio
 import threading
 
-from ..core.ark_client import config, client
+from ..core.ark_client import get_client
+from ..config.config import Config
 from ..schemas.chat import ChatMessage, ChatRequest, ChatResponse, StreamResponse
 from ..utils.chat_utils import get_latest_user_message, get_previous_response_id, get_doubao_tools, get_web_search_tools, get_knowledge_search_tools
 from ..routers.users import get_current_active_user, User
 from ..db.mongo import save_chat_log, save_chat_history, get_chat_history, delete_chat_history, get_sessions
+from ..db.ai_config import get_config_by_app_name
 from ..dependencies.auth import resolve_user_id
 
+config = Config()
+
 router = APIRouter()
 
 
-async def generate_stream_response(request: ChatRequest, user_id: str):
+async def generate_stream_response(request: ChatRequest, user_id: str, app_name: str = "com.yunxiangshengtai"):
     session_id = request.session_id
     latest_user_msg = None
     try:
+        ai_config = get_config_by_app_name(app_name)
+        if not ai_config:
+            raise ValueError(f"未找到appName '{app_name}' 的配置")
+
+        client = get_client(app_name)
+        knowledge_id = ai_config["knowledgeId"]
+
         latest_user_msg = get_latest_user_message(request.messages)
         if not latest_user_msg:
             raise ValueError("请求中没有找到user角色的消息")
@@ -152,8 +163,8 @@ async def generate_stream_response(request: ChatRequest, user_id: str):
             2. 不得把低可信、未验证、可能不实的信息写入答案;
             3. 不得编造事实、时间、数据、人物关系或产品能力;
             4. 不得向用户暴露知识库、检索、搜索策略、来源筛选过程或内部判断过程;
-            5. 不得输出“思考过程”“搜索关键词”“为什么需要搜索”等内部推理内容;
-            6. 不得使用根据参考资料/根据知识库/根据搜索结果”等表述。
+            5. 不得输出”思考过程””搜索关键词””为什么需要搜索”等内部推理内容;
+            6. 不得使用根据参考资料/根据知识库/根据搜索结果”等表述。
 
             ## 八、最终目标
             在保证回答自然、清晰、易懂的前提下:
@@ -166,7 +177,7 @@ async def generate_stream_response(request: ChatRequest, user_id: str):
         system_prompt = {"role": "system", "content": [{"type": "input_text", "text": system_prompt}]}
 
         api_messages = [system_prompt, {"role": latest_user_msg.role, "content": latest_user_msg.content}]
-        tools = get_web_search_tools() + get_knowledge_search_tools()
+        tools = get_web_search_tools() + get_knowledge_search_tools(knowledge_id)
         previous_response_id = get_previous_response_id(user_id, session_id)
 
         stream = client.responses.create(
@@ -294,11 +305,12 @@ async def generate_stream_response(request: ChatRequest, user_id: str):
 async def chat(
     request: ChatRequest,
     user_id: Annotated[str, Depends(resolve_user_id)],
+    app_name: str = Query(default="com.yunxiangshengtai", description="应用包名"),
 ):
     try:
         if request.stream:
             return StreamingResponse(
-                generate_stream_response(request, user_id),
+                generate_stream_response(request, user_id, app_name),
                 media_type="text/plain",
                 headers={
                     "Cache-Control": "no-cache",
@@ -313,6 +325,12 @@ async def chat(
         if not latest_user_msg:
             raise ValueError("请求中没有找到user角色的消息")
 
+        ai_config = get_config_by_app_name(app_name)
+        if not ai_config:
+            raise ValueError(f"未找到appName '{app_name}' 的配置")
+
+        client = get_client(app_name)
+
         save_chat_history(
             user_id=user_id,
             session_id=session_id,
@@ -400,23 +418,23 @@ async def chat(
         )
         raise HTTPException(status_code=500, detail=error_message)
 
-
-@router.get("/models")
-async def get_models(current_user: Annotated[User, Depends(get_current_active_user)]):
-    try:
-        models = client.models.list()
-        return {
-            "models": [model.id for model in models.data],
-            "default_model": config.MODEL_NAME,
-            "user": current_user.username
-        }
-    except Exception:
-        return {
-            "models": [config.MODEL_NAME],
-            "default_model": config.MODEL_NAME,
-            "note": "使用默认模型配置",
-            "user": current_user.username
-        }
+#
+# @router.get("/models")
+# async def get_models(current_user: Annotated[User, Depends(get_current_active_user)]):
+#     try:
+#         models = client.models.list()
+#         return {
+#             "models": [model.id for model in models.data],
+#             "default_model": config.MODEL_NAME,
+#             "user": current_user.username
+#         }
+#     except Exception:
+#         return {
+#             "models": [config.MODEL_NAME],
+#             "default_model": config.MODEL_NAME,
+#             "note": "使用默认模型配置",
+#             "user": current_user.username
+#         }
 
 
 @router.get("/history")

+ 7 - 3
app/routers/chat_tools.py

@@ -1,11 +1,13 @@
 from fastapi import APIRouter, HTTPException
 from datetime import datetime
 
-from ..core.ark_client import config, client
+from ..core.ark_client import get_client
+from ..config.config import Config
 from ..schemas.chat import ChatMessage, ChatResponse, CommentRequest, CirclePromptConfig, HistoricalFigure, RephraseRequest, FigureUpsert
 from ..db.souyue_mongo import get_mblog_by_id
 from ..db.mongo import get_circle_prompt, upsert_circle_prompt, get_all_figures, get_figure_by_id, insert_figure, update_figure, delete_figure
 
+
 router = APIRouter()
 
 
@@ -76,8 +78,9 @@ async def generate_post_comment(request: CommentRequest):
 
     content = file_list + [{"type": "input_text", "text": input_text}]
     print(f"concat text: {content}")
+    client = get_client(app_name)
     response = client.responses.create(
-        model=config.MODEL_NAME,
+        model=Config.MODEL_NAME,
         input=[{"role": "user", "content": content}],
 
     )
@@ -156,8 +159,9 @@ async def rephrase_as_figure(request: RephraseRequest):
         f"原文:{request.text}"
     )
 
+    client = get_client()
     response = client.responses.create(
-        model=config.MODEL_NAME,
+        model=Config.MODEL_NAME,
         input=[{"role": "user", "content": prompt}],
         stream=False,
         store=False,

+ 6 - 2
app/routers/users.py

@@ -42,6 +42,7 @@ class LoginRequest(BaseModel):
     """
     username: str
     password: str
+    appName: Optional[str] = "com.yunxiangshengtai"
 
 
 class Token(BaseModel):
@@ -54,6 +55,7 @@ class Token(BaseModel):
     refresh_token: str  # JWT刷新令牌
     token_type: str  # 令牌类型,通常是"bearer"
     username: str
+    appName: str
 
 
 class TokenData(BaseModel):
@@ -356,7 +358,8 @@ async def login_for_access_token(
         "access_token": access_token,
         "refresh_token": refresh_token,
         "token_type": "bearer",
-        "username": user.username
+        "username": user.username,
+        "appName": login_data.appName
     }
 
 
@@ -483,7 +486,8 @@ async def refresh_access_token(refresh_token: str) -> Token:
         "access_token": new_access_token,
         "refresh_token": refresh_token,
         "token_type": "bearer",
-        "username": username
+        "username": username,
+        "appName": "com.yunxiangshengtai"
     }
 
 

+ 2 - 2
app/schemas/chat.py

@@ -1,7 +1,7 @@
 from pydantic import BaseModel, Field, ConfigDict
 from typing import List, Optional, Dict, Any
 from datetime import datetime
-from ..core.ark_client import config
+from ..config.config import Config
 
 
 class ChatMessage(BaseModel):
@@ -32,7 +32,7 @@ class ChatRequest(BaseModel):
     model_config = ConfigDict(populate_by_name=True)
 
     messages: List[ChatMessage]
-    model: Optional[str] = config.MODEL_NAME
+    model: Optional[str] = Config.MODEL_NAME
     stream: Optional[bool] = False
     source: Optional[str] = None  # source=app 时走第三方 token 认证
     token: Optional[str] = None   # App 端传入的第三方 token

+ 6 - 6
app/utils/chat_utils.py

@@ -1,6 +1,6 @@
 from typing import List, Optional, Dict
 from ..schemas.chat import ChatMessage
-from ..core.ark_client import config
+from ..config.config import Config
 from ..db.mongo import get_last_response_id
 
 
@@ -24,7 +24,7 @@ def get_web_search_tools() -> list:
         "type": "web_search",
         "max_keyword": 20,
         "limit": 20,
-        "sources": ["douyin", "moji", "toutiao"],# 附加搜索来源(抖音百科、墨迹天气、头条图文等平台)
+        "sources": ["douyin", "toutiao"],# 附加搜索来源(抖音百科、墨迹天气、头条图文等平台)
         "user_location": {  # 用户地理位置(用于优化搜索结果)
             "type": "approximate",  # 大致位置
             "country": "中国",
@@ -33,10 +33,10 @@ def get_web_search_tools() -> list:
         }
     }]
 # 私域知识库 --- 此处的知识库也需要映射
-def get_knowledge_search_tools() -> list:
+def get_knowledge_search_tools(knowledge_resource_id: str) -> list:
     return [{
             "type": "knowledge_search",
-            "knowledge_resource_id": "kb-25bb30f3d7463a76",  # 替换为实际知识库ID
+            "knowledge_resource_id": knowledge_resource_id,  # 替换为实际知识库ID
             "limit": 10,  # 最多返回10条搜索结果
         }]
 
@@ -48,12 +48,12 @@ def get_doubao_tools() -> list:
             # 联网搜索功能
             "ai_search": {
                 "type": "disabled",
-                "role_description": config.ROLE_DESCRIPTION
+                "role_description": Config.ROLE_DESCRIPTION
             },
             # 边想边搜功能
             "reasoning_search":{
                 "type": "enabled",
-                "role_description": config.ROLE_DESCRIPTION
+                "role_description": Config.ROLE_DESCRIPTION
             }
         },
         "user_location": {

+ 2 - 1
main.py

@@ -1,7 +1,7 @@
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 
-from app.routers import users, chat,chat_tools
+from app.routers import users, chat, chat_tools, ai_config
 
 # 创建FastAPI应用实例
 app = FastAPI(title="聊天机器人", version="1.0.0", description="基于fastapi+VUE的聊天机器人")
@@ -21,6 +21,7 @@ app.add_middleware(
 app.include_router(users.router, prefix="/main/users", tags=["用户管理"])
 app.include_router(chat.router, prefix="/main/chat", tags=["聊天管理"])
 app.include_router(chat_tools.router, prefix="/main/chatTools", tags=["AI工具管理"])
+app.include_router(ai_config.router, prefix="/main/aiConfig", tags=["AI配置管理"])