Browse Source

feat: 以项目制,配置模型api-key和知识库实例id

zhangwl 1 hour ago
parent
commit
6dd9aaa206

+ 12 - 10
app/core/ark_client.py

@@ -1,13 +1,15 @@
-from volcenginesdkarkruntime import Ark
 from openai import OpenAI
 from openai import OpenAI
-from ..config.config import Config
+from ..db.ai_config import get_config_by_app_name
 
 
-config = Config()
-# 初始化客户端
-# client = Ark(api_key=config.API_KEY, base_url=config.BASE_URL,default_headers={"ark-beta-knowledge-search": "true"})  # 启用私域知识库搜索
 
 
-client = OpenAI(
-    base_url=config.BASE_URL,
-    api_key=config.API_KEY,
-    default_headers={"ark-beta-knowledge-search": "true"}  # 启用私域知识库搜索
-)
+def get_client(app_name: str = "com.yunxiangshengtai") -> OpenAI:
+    """根据appName动态初始化client"""
+    config = get_config_by_app_name(app_name)
+    if not config:
+        raise ValueError(f"未找到appName '{app_name}' 的配置")
+
+    return OpenAI(
+        base_url=config["baseUrl"],
+        api_key=config["apiKey"],
+        default_headers={"ark-beta-knowledge-search": "true"}
+    )

+ 95 - 0
app/db/ai_config.py

@@ -0,0 +1,95 @@
+from pymongo import MongoClient
+from bson import ObjectId
+from datetime import datetime
+from dotenv import load_dotenv
+import os
+
+load_dotenv()
+
+MONGO_URI = os.getenv("ARK_LOGS_MONGO_URI")
+
+client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=5000)
+db = client["arklogs"]
+ai_config_col = db["ai_config"]
+
+
+def _ensure_index():
+    try:
+        ai_config_col.create_index([("appName", 1)], unique=True)
+        ai_config_col.create_index([("status", 1), ("deleted", 1)])
+    except Exception:
+        pass
+
+
+def get_config_by_app_name(app_name: str) -> dict | None:
+    """根据appName查询配置,只返回启用且未删除的记录"""
+    try:
+        _ensure_index()
+        doc = ai_config_col.find_one({
+            "appName": app_name,
+            "status": 1,
+            "deleted": 0
+        })
+        if doc:
+            doc["_id"] = str(doc["_id"])
+        return doc
+    except Exception as e:
+        print(f"MongoDB 查询 ai_config 失败: {e}")
+        return None
+
+
+def create_config(data: dict) -> str | None:
+    """创建新的配置"""
+    try:
+        _ensure_index()
+        data["status"] = 1
+        data["deleted"] = 0
+        data["createdAt"] = datetime.now()
+        data["updatedAt"] = datetime.now()
+        result = ai_config_col.insert_one(data)
+        return str(result.inserted_id)
+    except Exception as e:
+        print(f"MongoDB 创建 ai_config 失败: {e}")
+        return None
+
+
+def update_config(app_name: str, data: dict) -> int:
+    """更新配置"""
+    try:
+        _ensure_index()
+        data["updatedAt"] = datetime.now()
+        result = ai_config_col.update_one(
+            {"appName": app_name},
+            {"$set": data}
+        )
+        return result.matched_count
+    except Exception as e:
+        print(f"MongoDB 更新 ai_config 失败: {e}")
+        return 0
+
+
+def delete_config(app_name: str) -> int:
+    """软删除配置"""
+    try:
+        result = ai_config_col.update_one(
+            {"appName": app_name},
+            {"$set": {"deleted": 1, "updatedAt": datetime.now()}}
+        )
+        return result.matched_count
+    except Exception as e:
+        print(f"MongoDB 删除 ai_config 失败: {e}")
+        return 0
+
+
+def get_all_configs() -> list:
+    """获取所有启用的配置"""
+    try:
+        _ensure_index()
+        docs = ai_config_col.find({
+            "status": 1,
+            "deleted": 0
+        })
+        return [{"_id": str(doc["_id"]), **{k: v for k, v in doc.items() if k != "_id"}} for doc in docs]
+    except Exception as e:
+        print(f"MongoDB 查询所有 ai_config 失败: {e}")
+        return []

+ 149 - 0
app/routers/ai_config.py

@@ -0,0 +1,149 @@
+from fastapi import APIRouter, HTTPException, status, Depends
+from pydantic import BaseModel
+from typing import Optional, Annotated
+from datetime import datetime
+from ..db.ai_config import (
+    get_config_by_app_name,
+    create_config,
+    update_config,
+    delete_config,
+    get_all_configs
+)
+from ..dependencies.auth import resolve_user_id
+
+router = APIRouter()
+
+
+class AIConfigCreate(BaseModel):
+    appName: str
+    knowledgeId: str
+    apiKey: str
+    baseUrl: str
+
+
+class AIConfigUpdate(BaseModel):
+    knowledgeId: Optional[str] = None
+    apiKey: Optional[str] = None
+    baseUrl: Optional[str] = None
+    status: Optional[int] = None
+
+
+class AIConfigResponse(BaseModel):
+    id: str
+    appName: str
+    knowledgeId: str
+    apiKey: str
+    baseUrl: str
+    status: int
+    deleted: int
+    createdAt: datetime
+    updatedAt: datetime
+
+
+@router.post("/", response_model=dict, summary="创建AI配置", description="创建新的AI应用配置映射")
+async def create_ai_config(
+    config: AIConfigCreate
+) -> dict:
+    """创建新的AI配置"""
+    existing = get_config_by_app_name(config.appName)
+    if existing:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=f"appName '{config.appName}' 已存在"
+        )
+
+    config_id = create_config(config.model_dump())
+    if not config_id:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="创建配置失败"
+        )
+
+    return {
+        "message": "配置创建成功",
+        "id": config_id,
+        "appName": config.appName
+    }
+
+
+@router.get("/{app_name}", response_model=dict, summary="获取AI配置", description="根据appName获取AI配置")
+async def get_ai_config(
+    app_name: str
+) -> dict:
+    """获取指定appName的配置"""
+    config = get_config_by_app_name(app_name)
+    if not config:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"appName '{app_name}' 的配置不存在"
+        )
+    return config
+
+
+@router.put("/{app_name}", response_model=dict, summary="更新AI配置", description="更新指定appName的AI配置")
+async def update_ai_config(
+    app_name: str,
+    config_update: AIConfigUpdate
+) -> dict:
+    """更新指定appName的配置"""
+    existing = get_config_by_app_name(app_name)
+    if not existing:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"appName '{app_name}' 的配置不存在"
+        )
+
+    update_data = {k: v for k, v in config_update.model_dump().items() if v is not None}
+    if not update_data:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="没有要更新的字段"
+        )
+
+    matched = update_config(app_name, update_data)
+    if matched == 0:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="更新配置失败"
+        )
+
+    return {
+        "message": "配置更新成功",
+        "appName": app_name
+    }
+
+
+@router.delete("/{app_name}", response_model=dict, summary="删除AI配置", description="删除指定appName的AI配置")
+async def delete_ai_config(
+    app_name: str
+) -> dict:
+    """删除指定appName的配置"""
+    existing = get_config_by_app_name(app_name)
+    if not existing:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"appName '{app_name}' 的配置不存在"
+        )
+
+    matched = delete_config(app_name)
+    if matched == 0:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="删除配置失败"
+        )
+
+    return {
+        "message": "配置删除成功",
+        "appName": app_name
+    }
+
+
+@router.get("/", response_model=dict, summary="获取所有AI配置", description="获取所有启用的AI配置")
+async def list_ai_configs(
+) -> dict:
+    """获取所有启用的配置"""
+    configs = get_all_configs()
+    return {
+        "total": len(configs),
+        "configs": configs
+    }

+ 41 - 23
app/routers/chat.py

@@ -6,20 +6,31 @@ import json
 import asyncio
 import asyncio
 import threading
 import threading
 
 
-from ..core.ark_client import config, client
+from ..core.ark_client import get_client
+from ..config.config import Config
 from ..schemas.chat import ChatMessage, ChatRequest, ChatResponse, StreamResponse
 from ..schemas.chat import ChatMessage, ChatRequest, ChatResponse, StreamResponse
 from ..utils.chat_utils import get_latest_user_message, get_previous_response_id, get_doubao_tools, get_web_search_tools, get_knowledge_search_tools
 from ..utils.chat_utils import get_latest_user_message, get_previous_response_id, get_doubao_tools, get_web_search_tools, get_knowledge_search_tools
 from ..routers.users import get_current_active_user, User
 from ..routers.users import get_current_active_user, User
 from ..db.mongo import save_chat_log, save_chat_history, get_chat_history, delete_chat_history, get_sessions
 from ..db.mongo import save_chat_log, save_chat_history, get_chat_history, delete_chat_history, get_sessions
+from ..db.ai_config import get_config_by_app_name
 from ..dependencies.auth import resolve_user_id
 from ..dependencies.auth import resolve_user_id
 
 
+config = Config()
+
 router = APIRouter()
 router = APIRouter()
 
 
 
 
-async def generate_stream_response(request: ChatRequest, user_id: str):
+async def generate_stream_response(request: ChatRequest, user_id: str, app_name: str = "com.yunxiangshengtai"):
     session_id = request.session_id
     session_id = request.session_id
     latest_user_msg = None
     latest_user_msg = None
     try:
     try:
+        ai_config = get_config_by_app_name(app_name)
+        if not ai_config:
+            raise ValueError(f"未找到appName '{app_name}' 的配置")
+
+        client = get_client(app_name)
+        knowledge_id = ai_config["knowledgeId"]
+
         latest_user_msg = get_latest_user_message(request.messages)
         latest_user_msg = get_latest_user_message(request.messages)
         if not latest_user_msg:
         if not latest_user_msg:
             raise ValueError("请求中没有找到user角色的消息")
             raise ValueError("请求中没有找到user角色的消息")
@@ -152,8 +163,8 @@ async def generate_stream_response(request: ChatRequest, user_id: str):
             2. 不得把低可信、未验证、可能不实的信息写入答案;
             2. 不得把低可信、未验证、可能不实的信息写入答案;
             3. 不得编造事实、时间、数据、人物关系或产品能力;
             3. 不得编造事实、时间、数据、人物关系或产品能力;
             4. 不得向用户暴露知识库、检索、搜索策略、来源筛选过程或内部判断过程;
             4. 不得向用户暴露知识库、检索、搜索策略、来源筛选过程或内部判断过程;
-            5. 不得输出“思考过程”“搜索关键词”“为什么需要搜索”等内部推理内容;
-            6. 不得使用根据参考资料/根据知识库/根据搜索结果”等表述。
+            5. 不得输出”思考过程””搜索关键词””为什么需要搜索”等内部推理内容;
+            6. 不得使用根据参考资料/根据知识库/根据搜索结果”等表述。
 
 
             ## 八、最终目标
             ## 八、最终目标
             在保证回答自然、清晰、易懂的前提下:
             在保证回答自然、清晰、易懂的前提下:
@@ -166,7 +177,7 @@ async def generate_stream_response(request: ChatRequest, user_id: str):
         system_prompt = {"role": "system", "content": [{"type": "input_text", "text": system_prompt}]}
         system_prompt = {"role": "system", "content": [{"type": "input_text", "text": system_prompt}]}
 
 
         api_messages = [system_prompt, {"role": latest_user_msg.role, "content": latest_user_msg.content}]
         api_messages = [system_prompt, {"role": latest_user_msg.role, "content": latest_user_msg.content}]
-        tools = get_web_search_tools() + get_knowledge_search_tools()
+        tools = get_web_search_tools() + get_knowledge_search_tools(knowledge_id)
         previous_response_id = get_previous_response_id(user_id, session_id)
         previous_response_id = get_previous_response_id(user_id, session_id)
 
 
         stream = client.responses.create(
         stream = client.responses.create(
@@ -294,11 +305,12 @@ async def generate_stream_response(request: ChatRequest, user_id: str):
 async def chat(
 async def chat(
     request: ChatRequest,
     request: ChatRequest,
     user_id: Annotated[str, Depends(resolve_user_id)],
     user_id: Annotated[str, Depends(resolve_user_id)],
+    app_name: str = Query(default="com.yunxiangshengtai", description="应用包名"),
 ):
 ):
     try:
     try:
         if request.stream:
         if request.stream:
             return StreamingResponse(
             return StreamingResponse(
-                generate_stream_response(request, user_id),
+                generate_stream_response(request, user_id, app_name),
                 media_type="text/plain",
                 media_type="text/plain",
                 headers={
                 headers={
                     "Cache-Control": "no-cache",
                     "Cache-Control": "no-cache",
@@ -313,6 +325,12 @@ async def chat(
         if not latest_user_msg:
         if not latest_user_msg:
             raise ValueError("请求中没有找到user角色的消息")
             raise ValueError("请求中没有找到user角色的消息")
 
 
+        ai_config = get_config_by_app_name(app_name)
+        if not ai_config:
+            raise ValueError(f"未找到appName '{app_name}' 的配置")
+
+        client = get_client(app_name)
+
         save_chat_history(
         save_chat_history(
             user_id=user_id,
             user_id=user_id,
             session_id=session_id,
             session_id=session_id,
@@ -400,23 +418,23 @@ async def chat(
         )
         )
         raise HTTPException(status_code=500, detail=error_message)
         raise HTTPException(status_code=500, detail=error_message)
 
 
-
-@router.get("/models")
-async def get_models(current_user: Annotated[User, Depends(get_current_active_user)]):
-    try:
-        models = client.models.list()
-        return {
-            "models": [model.id for model in models.data],
-            "default_model": config.MODEL_NAME,
-            "user": current_user.username
-        }
-    except Exception:
-        return {
-            "models": [config.MODEL_NAME],
-            "default_model": config.MODEL_NAME,
-            "note": "使用默认模型配置",
-            "user": current_user.username
-        }
+#
+# @router.get("/models")
+# async def get_models(current_user: Annotated[User, Depends(get_current_active_user)]):
+#     try:
+#         models = client.models.list()
+#         return {
+#             "models": [model.id for model in models.data],
+#             "default_model": config.MODEL_NAME,
+#             "user": current_user.username
+#         }
+#     except Exception:
+#         return {
+#             "models": [config.MODEL_NAME],
+#             "default_model": config.MODEL_NAME,
+#             "note": "使用默认模型配置",
+#             "user": current_user.username
+#         }
 
 
 
 
 @router.get("/history")
 @router.get("/history")

+ 7 - 3
app/routers/chat_tools.py

@@ -1,11 +1,13 @@
 from fastapi import APIRouter, HTTPException
 from fastapi import APIRouter, HTTPException
 from datetime import datetime
 from datetime import datetime
 
 
-from ..core.ark_client import config, client
+from ..core.ark_client import get_client
+from ..config.config import Config
 from ..schemas.chat import ChatMessage, ChatResponse, CommentRequest, CirclePromptConfig, HistoricalFigure, RephraseRequest, FigureUpsert
 from ..schemas.chat import ChatMessage, ChatResponse, CommentRequest, CirclePromptConfig, HistoricalFigure, RephraseRequest, FigureUpsert
 from ..db.souyue_mongo import get_mblog_by_id
 from ..db.souyue_mongo import get_mblog_by_id
 from ..db.mongo import get_circle_prompt, upsert_circle_prompt, get_all_figures, get_figure_by_id, insert_figure, update_figure, delete_figure
 from ..db.mongo import get_circle_prompt, upsert_circle_prompt, get_all_figures, get_figure_by_id, insert_figure, update_figure, delete_figure
 
 
+
 router = APIRouter()
 router = APIRouter()
 
 
 
 
@@ -76,8 +78,9 @@ async def generate_post_comment(request: CommentRequest):
 
 
     content = file_list + [{"type": "input_text", "text": input_text}]
     content = file_list + [{"type": "input_text", "text": input_text}]
     print(f"concat text: {content}")
     print(f"concat text: {content}")
+    client = get_client(app_name)
     response = client.responses.create(
     response = client.responses.create(
-        model=config.MODEL_NAME,
+        model=Config.MODEL_NAME,
         input=[{"role": "user", "content": content}],
         input=[{"role": "user", "content": content}],
 
 
     )
     )
@@ -156,8 +159,9 @@ async def rephrase_as_figure(request: RephraseRequest):
         f"原文:{request.text}"
         f"原文:{request.text}"
     )
     )
 
 
+    client = get_client()
     response = client.responses.create(
     response = client.responses.create(
-        model=config.MODEL_NAME,
+        model=Config.MODEL_NAME,
         input=[{"role": "user", "content": prompt}],
         input=[{"role": "user", "content": prompt}],
         stream=False,
         stream=False,
         store=False,
         store=False,

+ 6 - 2
app/routers/users.py

@@ -42,6 +42,7 @@ class LoginRequest(BaseModel):
     """
     """
     username: str
     username: str
     password: str
     password: str
+    appName: Optional[str] = "com.yunxiangshengtai"
 
 
 
 
 class Token(BaseModel):
 class Token(BaseModel):
@@ -54,6 +55,7 @@ class Token(BaseModel):
     refresh_token: str  # JWT刷新令牌
     refresh_token: str  # JWT刷新令牌
     token_type: str  # 令牌类型,通常是"bearer"
     token_type: str  # 令牌类型,通常是"bearer"
     username: str
     username: str
+    appName: str
 
 
 
 
 class TokenData(BaseModel):
 class TokenData(BaseModel):
@@ -356,7 +358,8 @@ async def login_for_access_token(
         "access_token": access_token,
         "access_token": access_token,
         "refresh_token": refresh_token,
         "refresh_token": refresh_token,
         "token_type": "bearer",
         "token_type": "bearer",
-        "username": user.username
+        "username": user.username,
+        "appName": login_data.appName
     }
     }
 
 
 
 
@@ -483,7 +486,8 @@ async def refresh_access_token(refresh_token: str) -> Token:
         "access_token": new_access_token,
         "access_token": new_access_token,
         "refresh_token": refresh_token,
         "refresh_token": refresh_token,
         "token_type": "bearer",
         "token_type": "bearer",
-        "username": username
+        "username": username,
+        "appName": "com.yunxiangshengtai"
     }
     }
 
 
 
 

+ 2 - 2
app/schemas/chat.py

@@ -1,7 +1,7 @@
 from pydantic import BaseModel, Field, ConfigDict
 from pydantic import BaseModel, Field, ConfigDict
 from typing import List, Optional, Dict, Any
 from typing import List, Optional, Dict, Any
 from datetime import datetime
 from datetime import datetime
-from ..core.ark_client import config
+from ..config.config import Config
 
 
 
 
 class ChatMessage(BaseModel):
 class ChatMessage(BaseModel):
@@ -32,7 +32,7 @@ class ChatRequest(BaseModel):
     model_config = ConfigDict(populate_by_name=True)
     model_config = ConfigDict(populate_by_name=True)
 
 
     messages: List[ChatMessage]
     messages: List[ChatMessage]
-    model: Optional[str] = config.MODEL_NAME
+    model: Optional[str] = Config.MODEL_NAME
     stream: Optional[bool] = False
     stream: Optional[bool] = False
     source: Optional[str] = None  # source=app 时走第三方 token 认证
     source: Optional[str] = None  # source=app 时走第三方 token 认证
     token: Optional[str] = None   # App 端传入的第三方 token
     token: Optional[str] = None   # App 端传入的第三方 token

+ 6 - 6
app/utils/chat_utils.py

@@ -1,6 +1,6 @@
 from typing import List, Optional, Dict
 from typing import List, Optional, Dict
 from ..schemas.chat import ChatMessage
 from ..schemas.chat import ChatMessage
-from ..core.ark_client import config
+from ..config.config import Config
 from ..db.mongo import get_last_response_id
 from ..db.mongo import get_last_response_id
 
 
 
 
@@ -24,7 +24,7 @@ def get_web_search_tools() -> list:
         "type": "web_search",
         "type": "web_search",
         "max_keyword": 20,
         "max_keyword": 20,
         "limit": 20,
         "limit": 20,
-        "sources": ["douyin", "moji", "toutiao"],# 附加搜索来源(抖音百科、墨迹天气、头条图文等平台)
+        "sources": ["douyin", "toutiao"],# 附加搜索来源(抖音百科、墨迹天气、头条图文等平台)
         "user_location": {  # 用户地理位置(用于优化搜索结果)
         "user_location": {  # 用户地理位置(用于优化搜索结果)
             "type": "approximate",  # 大致位置
             "type": "approximate",  # 大致位置
             "country": "中国",
             "country": "中国",
@@ -33,10 +33,10 @@ def get_web_search_tools() -> list:
         }
         }
     }]
     }]
 # 私域知识库 --- 此处的知识库也需要映射
 # 私域知识库 --- 此处的知识库也需要映射
-def get_knowledge_search_tools() -> list:
+def get_knowledge_search_tools(knowledge_resource_id: str) -> list:
     return [{
     return [{
             "type": "knowledge_search",
             "type": "knowledge_search",
-            "knowledge_resource_id": "kb-25bb30f3d7463a76",  # 替换为实际知识库ID
+            "knowledge_resource_id": knowledge_resource_id,  # 替换为实际知识库ID
             "limit": 10,  # 最多返回10条搜索结果
             "limit": 10,  # 最多返回10条搜索结果
         }]
         }]
 
 
@@ -48,12 +48,12 @@ def get_doubao_tools() -> list:
             # 联网搜索功能
             # 联网搜索功能
             "ai_search": {
             "ai_search": {
                 "type": "disabled",
                 "type": "disabled",
-                "role_description": config.ROLE_DESCRIPTION
+                "role_description": Config.ROLE_DESCRIPTION
             },
             },
             # 边想边搜功能
             # 边想边搜功能
             "reasoning_search":{
             "reasoning_search":{
                 "type": "enabled",
                 "type": "enabled",
-                "role_description": config.ROLE_DESCRIPTION
+                "role_description": Config.ROLE_DESCRIPTION
             }
             }
         },
         },
         "user_location": {
         "user_location": {

+ 2 - 1
main.py

@@ -1,7 +1,7 @@
 from fastapi import FastAPI
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 
 
-from app.routers import users, chat,chat_tools
+from app.routers import users, chat, chat_tools, ai_config
 
 
 # 创建FastAPI应用实例
 # 创建FastAPI应用实例
 app = FastAPI(title="聊天机器人", version="1.0.0", description="基于fastapi+VUE的聊天机器人")
 app = FastAPI(title="聊天机器人", version="1.0.0", description="基于fastapi+VUE的聊天机器人")
@@ -21,6 +21,7 @@ app.add_middleware(
 app.include_router(users.router, prefix="/main/users", tags=["用户管理"])
 app.include_router(users.router, prefix="/main/users", tags=["用户管理"])
 app.include_router(chat.router, prefix="/main/chat", tags=["聊天管理"])
 app.include_router(chat.router, prefix="/main/chat", tags=["聊天管理"])
 app.include_router(chat_tools.router, prefix="/main/chatTools", tags=["AI工具管理"])
 app.include_router(chat_tools.router, prefix="/main/chatTools", tags=["AI工具管理"])
+app.include_router(ai_config.router, prefix="/main/aiConfig", tags=["AI配置管理"])