|
@@ -6,20 +6,31 @@ import json
|
|
|
import asyncio
|
|
import asyncio
|
|
|
import threading
|
|
import threading
|
|
|
|
|
|
|
|
-from ..core.ark_client import config, client
|
|
|
|
|
|
|
+from ..core.ark_client import get_client
|
|
|
|
|
+from ..config.config import Config
|
|
|
from ..schemas.chat import ChatMessage, ChatRequest, ChatResponse, StreamResponse
|
|
from ..schemas.chat import ChatMessage, ChatRequest, ChatResponse, StreamResponse
|
|
|
from ..utils.chat_utils import get_latest_user_message, get_previous_response_id, get_doubao_tools, get_web_search_tools, get_knowledge_search_tools
|
|
from ..utils.chat_utils import get_latest_user_message, get_previous_response_id, get_doubao_tools, get_web_search_tools, get_knowledge_search_tools
|
|
|
from ..routers.users import get_current_active_user, User
|
|
from ..routers.users import get_current_active_user, User
|
|
|
from ..db.mongo import save_chat_log, save_chat_history, get_chat_history, delete_chat_history, get_sessions
|
|
from ..db.mongo import save_chat_log, save_chat_history, get_chat_history, delete_chat_history, get_sessions
|
|
|
|
|
+from ..db.ai_config import get_config_by_app_name
|
|
|
from ..dependencies.auth import resolve_user_id
|
|
from ..dependencies.auth import resolve_user_id
|
|
|
|
|
|
|
|
|
|
+config = Config()
|
|
|
|
|
+
|
|
|
router = APIRouter()
|
|
router = APIRouter()
|
|
|
|
|
|
|
|
|
|
|
|
|
-async def generate_stream_response(request: ChatRequest, user_id: str):
|
|
|
|
|
|
|
+async def generate_stream_response(request: ChatRequest, user_id: str, app_name: str = "com.yunxiangshengtai"):
|
|
|
session_id = request.session_id
|
|
session_id = request.session_id
|
|
|
latest_user_msg = None
|
|
latest_user_msg = None
|
|
|
try:
|
|
try:
|
|
|
|
|
+ ai_config = get_config_by_app_name(app_name)
|
|
|
|
|
+ if not ai_config:
|
|
|
|
|
+ raise ValueError(f"未找到appName '{app_name}' 的配置")
|
|
|
|
|
+
|
|
|
|
|
+ client = get_client(app_name)
|
|
|
|
|
+ knowledge_id = ai_config["knowledgeId"]
|
|
|
|
|
+
|
|
|
latest_user_msg = get_latest_user_message(request.messages)
|
|
latest_user_msg = get_latest_user_message(request.messages)
|
|
|
if not latest_user_msg:
|
|
if not latest_user_msg:
|
|
|
raise ValueError("请求中没有找到user角色的消息")
|
|
raise ValueError("请求中没有找到user角色的消息")
|
|
@@ -152,8 +163,8 @@ async def generate_stream_response(request: ChatRequest, user_id: str):
|
|
|
2. 不得把低可信、未验证、可能不实的信息写入答案;
|
|
2. 不得把低可信、未验证、可能不实的信息写入答案;
|
|
|
3. 不得编造事实、时间、数据、人物关系或产品能力;
|
|
3. 不得编造事实、时间、数据、人物关系或产品能力;
|
|
|
4. 不得向用户暴露知识库、检索、搜索策略、来源筛选过程或内部判断过程;
|
|
4. 不得向用户暴露知识库、检索、搜索策略、来源筛选过程或内部判断过程;
|
|
|
- 5. 不得输出“思考过程”“搜索关键词”“为什么需要搜索”等内部推理内容;
|
|
|
|
|
- 6. 不得使用“根据参考资料/根据知识库/根据搜索结果”等表述。
|
|
|
|
|
|
|
+ 5. 不得输出”思考过程””搜索关键词””为什么需要搜索”等内部推理内容;
|
|
|
|
|
+ 6. 不得使用”根据参考资料/根据知识库/根据搜索结果”等表述。
|
|
|
|
|
|
|
|
## 八、最终目标
|
|
## 八、最终目标
|
|
|
在保证回答自然、清晰、易懂的前提下:
|
|
在保证回答自然、清晰、易懂的前提下:
|
|
@@ -166,7 +177,7 @@ async def generate_stream_response(request: ChatRequest, user_id: str):
|
|
|
system_prompt = {"role": "system", "content": [{"type": "input_text", "text": system_prompt}]}
|
|
system_prompt = {"role": "system", "content": [{"type": "input_text", "text": system_prompt}]}
|
|
|
|
|
|
|
|
api_messages = [system_prompt, {"role": latest_user_msg.role, "content": latest_user_msg.content}]
|
|
api_messages = [system_prompt, {"role": latest_user_msg.role, "content": latest_user_msg.content}]
|
|
|
- tools = get_web_search_tools() + get_knowledge_search_tools()
|
|
|
|
|
|
|
+ tools = get_web_search_tools() + get_knowledge_search_tools(knowledge_id)
|
|
|
previous_response_id = get_previous_response_id(user_id, session_id)
|
|
previous_response_id = get_previous_response_id(user_id, session_id)
|
|
|
|
|
|
|
|
stream = client.responses.create(
|
|
stream = client.responses.create(
|
|
@@ -294,11 +305,12 @@ async def generate_stream_response(request: ChatRequest, user_id: str):
|
|
|
async def chat(
|
|
async def chat(
|
|
|
request: ChatRequest,
|
|
request: ChatRequest,
|
|
|
user_id: Annotated[str, Depends(resolve_user_id)],
|
|
user_id: Annotated[str, Depends(resolve_user_id)],
|
|
|
|
|
+ app_name: str = Query(default="com.yunxiangshengtai", description="应用包名"),
|
|
|
):
|
|
):
|
|
|
try:
|
|
try:
|
|
|
if request.stream:
|
|
if request.stream:
|
|
|
return StreamingResponse(
|
|
return StreamingResponse(
|
|
|
- generate_stream_response(request, user_id),
|
|
|
|
|
|
|
+ generate_stream_response(request, user_id, app_name),
|
|
|
media_type="text/plain",
|
|
media_type="text/plain",
|
|
|
headers={
|
|
headers={
|
|
|
"Cache-Control": "no-cache",
|
|
"Cache-Control": "no-cache",
|
|
@@ -313,6 +325,12 @@ async def chat(
|
|
|
if not latest_user_msg:
|
|
if not latest_user_msg:
|
|
|
raise ValueError("请求中没有找到user角色的消息")
|
|
raise ValueError("请求中没有找到user角色的消息")
|
|
|
|
|
|
|
|
|
|
+ ai_config = get_config_by_app_name(app_name)
|
|
|
|
|
+ if not ai_config:
|
|
|
|
|
+ raise ValueError(f"未找到appName '{app_name}' 的配置")
|
|
|
|
|
+
|
|
|
|
|
+ client = get_client(app_name)
|
|
|
|
|
+
|
|
|
save_chat_history(
|
|
save_chat_history(
|
|
|
user_id=user_id,
|
|
user_id=user_id,
|
|
|
session_id=session_id,
|
|
session_id=session_id,
|
|
@@ -400,23 +418,23 @@ async def chat(
|
|
|
)
|
|
)
|
|
|
raise HTTPException(status_code=500, detail=error_message)
|
|
raise HTTPException(status_code=500, detail=error_message)
|
|
|
|
|
|
|
|
-
|
|
|
|
|
-@router.get("/models")
|
|
|
|
|
-async def get_models(current_user: Annotated[User, Depends(get_current_active_user)]):
|
|
|
|
|
- try:
|
|
|
|
|
- models = client.models.list()
|
|
|
|
|
- return {
|
|
|
|
|
- "models": [model.id for model in models.data],
|
|
|
|
|
- "default_model": config.MODEL_NAME,
|
|
|
|
|
- "user": current_user.username
|
|
|
|
|
- }
|
|
|
|
|
- except Exception:
|
|
|
|
|
- return {
|
|
|
|
|
- "models": [config.MODEL_NAME],
|
|
|
|
|
- "default_model": config.MODEL_NAME,
|
|
|
|
|
- "note": "使用默认模型配置",
|
|
|
|
|
- "user": current_user.username
|
|
|
|
|
- }
|
|
|
|
|
|
|
+#
|
|
|
|
|
+# @router.get("/models")
|
|
|
|
|
+# async def get_models(current_user: Annotated[User, Depends(get_current_active_user)]):
|
|
|
|
|
+# try:
|
|
|
|
|
+# models = client.models.list()
|
|
|
|
|
+# return {
|
|
|
|
|
+# "models": [model.id for model in models.data],
|
|
|
|
|
+# "default_model": config.MODEL_NAME,
|
|
|
|
|
+# "user": current_user.username
|
|
|
|
|
+# }
|
|
|
|
|
+# except Exception:
|
|
|
|
|
+# return {
|
|
|
|
|
+# "models": [config.MODEL_NAME],
|
|
|
|
|
+# "default_model": config.MODEL_NAME,
|
|
|
|
|
+# "note": "使用默认模型配置",
|
|
|
|
|
+# "user": current_user.username
|
|
|
|
|
+# }
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/history")
|
|
@router.get("/history")
|