Browse Source

feat:新增帖子AI回复工具

zhangwl 1 month ago
parent
commit
4984de04c3
9 changed files with 163 additions and 37 deletions
  1. 13 0
      .env
  2. 11 0
      README.md
  3. 6 10
      app/config/config.py
  4. 26 1
      app/db/mongo.py
  5. 16 0
      app/db/souyue_mongo.py
  6. 5 3
      app/dependencies/auth.py
  7. 76 12
      app/routers/chat_tools.py
  8. 6 8
      app/routers/users.py
  9. 4 3
      app/schemas/chat.py

+ 13 - 0
.env

@@ -1,3 +1,16 @@
 # ark的key , 也就是火山引擎的key
 # ark的key , 也就是火山引擎的key
 DASHSCOPE_API_KEY=e0ebc283-0765-4c22-a2d8-8e06db54f5db
 DASHSCOPE_API_KEY=e0ebc283-0765-4c22-a2d8-8e06db54f5db
 DASHSCOPE_BASE_URL=https://ark.cn-beijing.volces.com/api/v3
 DASHSCOPE_BASE_URL=https://ark.cn-beijing.volces.com/api/v3
+
+# 豆包的2.0模型
+MODEL_NAME=doubao-seed-2-0-lite-260215
+ROLE_DESCRIPTION=你是豆包,由字节跳动公司开发的智能助手,你的行为准则如下:1.输出保持和豆包APP日常对话一致 2.遇到'陈沛',优先输出中搜陈沛
+
+# JWT密钥配置
+SECRET_KEY=09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7
+ALGORITHM=HS256
+ACCESS_TOKEN_EXPIRE_MINUTES=30
+
+# souyue MongoDB
+SOUYUE_MONGO_URI=mongodb://dds-2zeeb9d15a7a55541.mongodb.rds.aliyuncs.com:3717,dds-2zeeb9d15a7a55542.mongodb.rds.aliyuncs.com:3717
+ARK_LOGS_MONGO_URI=mongodb://arkapi:arkapi@dds-2zedd8d70bde6e541.mongodb.rds.aliyuncs.com:3717,dds-2zedd8d70bde6e542.mongodb.rds.aliyuncs.com:3717/arklogs

+ 11 - 0
README.md

@@ -14,3 +14,14 @@ pip install "pymongo==3.13.0"
 pip install "redis==3.5.3"
 pip install "redis==3.5.3"
 # HTTP 请求库,用于向服务器发送网络请求
 # HTTP 请求库,用于向服务器发送网络请求
 pip install requests
 pip install requests
+
+
+
+#兴趣圈
+brief=帖子摘要
+title=帖子主题
+nickname=用户昵称
+images=多图片(List)
+user_id=发帖人
+appName=包名
+create_time=创建时间

+ 6 - 10
app/config/config.py

@@ -7,14 +7,10 @@ class Config:
     # 使用阿里云DashScope API
     # 使用阿里云DashScope API
     API_KEY = os.getenv("DASHSCOPE_API_KEY")
     API_KEY = os.getenv("DASHSCOPE_API_KEY")
     BASE_URL = os.getenv("DASHSCOPE_BASE_URL")
     BASE_URL = os.getenv("DASHSCOPE_BASE_URL")
-    MODEL_NAME = "doubao-seed-2-0-lite-260215"
-    # MODEL_NAME = "doubao-seed-1-8-251228"
-    MAX_TOKENS = 2000
-    TEMPERATURE = 0.7
+    MODEL_NAME = os.getenv("MODEL_NAME")
+    ROLE_DESCRIPTION = os.getenv("ROLE_DESCRIPTION")
 
 
-    ROLE_DESCRIPTION = "你是豆包,由字节跳动公司开发的智能助手,你的行为准则如下:1.输出保持和豆包APP日常对话一致 2.遇到'陈沛',优先输出中搜陈沛"
-
-
-    secret_key: str = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"  # 开发环境默认值
-    algorithm: str = "HS256"
-    access_token_expire_minutes: int = 30
+    # JWT密钥配置 -
+    SECRET_KEY: str = os.getenv("SECRET_KEY")  # 开发环境默认值
+    ALGORITHM: str = os.getenv("ALGORITHM")  # JWT签名算法
+    ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 # Token过期时间(分钟)

+ 26 - 1
app/db/mongo.py

@@ -1,11 +1,19 @@
 from pymongo import MongoClient
 from pymongo import MongoClient
 from datetime import datetime
 from datetime import datetime
+from dotenv import load_dotenv
+import os
 
 
-MONGO_URI = "mongodb://arkapi:arkapi@dds-2zedd8d70bde6e541.mongodb.rds.aliyuncs.com:3717,dds-2zedd8d70bde6e542.mongodb.rds.aliyuncs.com:3717/arklogs"
+load_dotenv()
+
+MONGO_URI = os.getenv("ARK_LOGS_MONGO_URI")
 
 
 client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=5000)
 client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=5000)
+# 数据库
 db = client["arklogs"]
 db = client["arklogs"]
+# 豆包大模型的对话日志
 chat_logs = db["chat_logs"]
 chat_logs = db["chat_logs"]
+# 兴趣圈集合
+circle_prompts = db["circle_prompt"]
 
 
 # 带下划线的表示私有方法(Private)
 # 带下划线的表示私有方法(Private)
 def _ensure_index():
 def _ensure_index():
@@ -47,3 +55,20 @@ def save_chat_log(
         })
         })
     except Exception as e:
     except Exception as e:
         print(f"MongoDB 日志写入失败: {e}")
         print(f"MongoDB 日志写入失败: {e}")
+
+
+_DEFAULT_PROMPT_CONFIG = {
+    "name": "兴趣圈",
+    "role": "活跃用户",
+    "style": "自然亲切,有活人感",
+    "keywords": [],
+    "forbidden": [],
+}
+
+
+def get_circle_prompt(app_name: str) -> dict:
+    try:
+        doc = circle_prompts.find_one({"appName": app_name})
+        return doc if doc else _DEFAULT_PROMPT_CONFIG
+    except Exception:
+        return _DEFAULT_PROMPT_CONFIG

+ 16 - 0
app/db/souyue_mongo.py

@@ -0,0 +1,16 @@
+from pymongo import MongoClient
+from bson import ObjectId
+from dotenv import load_dotenv
+import os
+
+load_dotenv()
+
+_client = MongoClient(os.getenv("SOUYUE_MONGO_URI"), serverSelectionTimeoutMS=5000)
+mblog = _client["souyue"]["mblog"]
+
+
+def get_mblog_by_id(post_id: str) -> dict | None:
+    try:
+        return mblog.find_one({"_id": ObjectId(post_id)})
+    except Exception:
+        return None

+ 5 - 3
app/dependencies/auth.py

@@ -2,9 +2,11 @@ from fastapi import HTTPException, Query, Depends
 from fastapi.security import OAuth2PasswordBearer
 from fastapi.security import OAuth2PasswordBearer
 from typing import Optional, Annotated
 from typing import Optional, Annotated
 import jwt
 import jwt
-from ..routers.users import SECRET_KEY, ALGORITHM
+
 from ..db.redis_client import get_app_user
 from ..db.redis_client import get_app_user
+from ..config.config import Config
 
 
+config = Config()
 oauth2_scheme_optional = OAuth2PasswordBearer(tokenUrl="/users/token", auto_error=False)
 oauth2_scheme_optional = OAuth2PasswordBearer(tokenUrl="/users/token", auto_error=False)
 
 
 
 
@@ -21,10 +23,10 @@ async def resolve_username(
     if not jwt_token:
     if not jwt_token:
         raise HTTPException(status_code=401, detail="未提供认证令牌")
         raise HTTPException(status_code=401, detail="未提供认证令牌")
     try:
     try:
-        payload = jwt.decode(jwt_token, SECRET_KEY, algorithms=[ALGORITHM])
+        payload = jwt.decode(jwt_token, config.SECRET_KEY, algorithms=[config.ALGORITHM])
         sub = payload.get("sub")
         sub = payload.get("sub")
         if not sub:
         if not sub:
             raise HTTPException(status_code=401, detail="无效的令牌")
             raise HTTPException(status_code=401, detail="无效的令牌")
         return sub
         return sub
     except jwt.PyJWTError:
     except jwt.PyJWTError:
-        raise HTTPException(status_code=401, detail="无效的令牌")
+        raise HTTPException(status_code=401, detail="无效的令牌")

+ 76 - 12
app/routers/chat_tools.py

@@ -1,18 +1,82 @@
-from fastapi import APIRouter, HTTPException, Depends
-from fastapi.responses import StreamingResponse
-from typing import Annotated
-from fastapi import Query
+from fastapi import APIRouter, HTTPException
+from datetime import datetime
 
 
 from ..core.ark_client import config, client
 from ..core.ark_client import config, client
-from ..schemas.chat import ChatRequest, ChatResponse
-from ..dependencies.auth import resolve_username
+from ..schemas.chat import ChatMessage, ChatResponse, CircleRequest
+from ..db.souyue_mongo import get_mblog_by_id
+from ..db.mongo import get_circle_prompt
 
 
 router = APIRouter()
 router = APIRouter()
 
 
 
 
-@router.post("/chat", response_model=ChatResponse)
-async def chat(
-        request: ChatRequest,
-        username: Annotated[str, Depends(resolve_username)],
-):
-    pass
+def _build_prompt(product_text: str, prompt_config: dict) -> str:
+    name = prompt_config.get("name", "兴趣圈")
+    role = prompt_config.get("role", "活跃用户")
+    style = prompt_config.get("style", "自然亲切,有活人感")
+    keywords: list = prompt_config.get("keywords") or []
+    forbidden: list = prompt_config.get("forbidden") or []
+
+    lines = [
+        f"你是{role},活跃在{name}兴趣圈。",
+        "请根据以下帖子信息,生成一条10-30字的评论,要求:",
+        "1. 内容指向性强,结合帖子具体内容",
+        f"2. 风格:{style}",
+    ]
+    seq = 3
+    if keywords:
+        lines.append(f"{seq}. 适当融入关键词(自然使用):{', '.join(keywords)}")
+        seq += 1
+    if forbidden:
+        lines.append(f"{seq}. 禁止使用以下词语:{', '.join(forbidden)}")
+        seq += 1
+    lines.append(f"{seq}. 语言自然,不要暴露你是AI")
+    lines.append(f"\n帖子内容:{product_text}")
+
+    return "\n".join(lines)
+
+
+# 评论帖子的马甲机器人,无状态,支持批量对多个帖子智能回复
+@router.post("/airesp", response_model=ChatResponse)
+async def generate_circle_comment(request: CircleRequest):
+    doc = get_mblog_by_id(request.id)
+    if not doc:
+        raise HTTPException(status_code=404, detail="帖子不存在")
+
+    title = doc.get("title", "")
+    brief = doc.get("brief", "")
+    nickname = doc.get("nickname", "")
+    app_name = doc.get("appName", "")
+    images: list = doc.get("images") or []
+
+    product_text = f"主题:{title}\n摘要:{brief}\n发布者:{nickname}"
+    if images:
+        product_text += "\n图片:\n" + "\n".join(images)
+
+    prompt_config = get_circle_prompt(app_name)
+    prompt = _build_prompt(product_text, prompt_config)
+
+    response = client.responses.create(
+        model=config.MODEL_NAME,
+        input=[{"role": "user", "content": prompt}],
+        stream=False,
+        store=False,
+    )
+
+    message_content = ""
+    for item in response.output:
+        if hasattr(item, 'type') and item.type == 'message' and hasattr(item, 'content'):
+            if isinstance(item.content, list):
+                for content_item in item.content:
+                    if hasattr(content_item, 'text'):
+                        message_content += content_item.text
+            else:
+                message_content += str(item.content)
+
+    if not message_content:
+        raise HTTPException(status_code=500, detail="AI未能生成评论")
+
+    return ChatResponse(
+        message=ChatMessage(role="assistant", content=message_content, timestamp=datetime.now()),
+        model=response.model,
+        usage=response.usage.model_dump() if response.usage else None,
+    )

+ 6 - 8
app/routers/users.py

@@ -3,18 +3,16 @@ from typing import Optional, Annotated
 import jwt
 import jwt
 from fastapi import APIRouter, Depends, HTTPException, status
 from fastapi import APIRouter, Depends, HTTPException, status
 from fastapi.security import OAuth2PasswordBearer
 from fastapi.security import OAuth2PasswordBearer
-from passlib.context import CryptContext
 from pydantic import BaseModel
 from pydantic import BaseModel
 from pwdlib import PasswordHash
 from pwdlib import PasswordHash
+from ..config.config import Config
 
 
+config = Config()
 # =====================================================
 # =====================================================
 # JWT和安全配置
 # JWT和安全配置
 # =====================================================
 # =====================================================
 
 
-# JWT密钥配置 - 生产环境中应该使用环境变量或密钥管理服务
-SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"  # 警告:生产环境请使用强密钥并通过环境变量管理
-ALGORITHM = "HS256"  # JWT签名算法
-ACCESS_TOKEN_EXPIRE_MINUTES = 30  # Token过期时间(分钟)
+# JWT密钥配置 -
 
 
 # 密码加密上下文配置
 # 密码加密上下文配置
 # schemes: 支持的密码哈希方案,bcrypt是目前推荐的安全哈希算法
 # schemes: 支持的密码哈希方案,bcrypt是目前推荐的安全哈希算法
@@ -216,7 +214,7 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
     to_encode.update({"exp": expire})
     to_encode.update({"exp": expire})
 
 
     # 使用密钥和算法对数据进行编码
     # 使用密钥和算法对数据进行编码
-    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
+    encoded_jwt = jwt.encode(to_encode, config.SECRET_KEY, algorithm=config.ALGORITHM)
     return encoded_jwt
     return encoded_jwt
 
 
 
 
@@ -248,7 +246,7 @@ async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]) -> Use
 
 
     try:
     try:
         # 解码JWT令牌
         # 解码JWT令牌
-        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
+        payload = jwt.decode(token, config.SECRET_KEY, algorithms=[config.ALGORITHM])
 
 
         # 从令牌中提取用户名(sub是JWT标准字段,表示subject/主题)
         # 从令牌中提取用户名(sub是JWT标准字段,表示subject/主题)
         username: str = payload.get("sub")
         username: str = payload.get("sub")
@@ -326,7 +324,7 @@ async def login_for_access_token(
         )
         )
 
 
     # 设置令牌过期时间
     # 设置令牌过期时间
-    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
+    access_token_expires = timedelta(minutes=config.ACCESS_TOKEN_EXPIRE_MINUTES)
 
 
     # 创建访问令牌,将用户名作为subject存储在令牌中
     # 创建访问令牌,将用户名作为subject存储在令牌中
     access_token = create_access_token(
     access_token = create_access_token(

+ 4 - 3
app/schemas/chat.py

@@ -10,12 +10,13 @@ class ChatMessage(BaseModel):
     timestamp: Optional[datetime] = None
     timestamp: Optional[datetime] = None
     response_id: Optional[str] = None
     response_id: Optional[str] = None
 
 
+# 帖子请求
+class CircleRequest(BaseModel):
+    id: str #帖子的主键
 
 
 class ChatRequest(BaseModel):
 class ChatRequest(BaseModel):
     messages: List[ChatMessage]
     messages: List[ChatMessage]
     model: Optional[str] = config.MODEL_NAME
     model: Optional[str] = config.MODEL_NAME
-    temperature: Optional[float] = config.TEMPERATURE
-    max_tokens: Optional[int] = config.MAX_TOKENS
     stream: Optional[bool] = False
     stream: Optional[bool] = False
     source: Optional[str] = None  # source=app 时走第三方 token 认证
     source: Optional[str] = None  # source=app 时走第三方 token 认证
     token: Optional[str] = None   # App 端传入的第三方 token
     token: Optional[str] = None   # App 端传入的第三方 token
@@ -32,4 +33,4 @@ class StreamResponse(BaseModel):
     content: str
     content: str
     finished: bool
     finished: bool
     model: str
     model: str
-    timestamp: datetime
+    timestamp: datetime