Browse Source

feat:新增帖子AI回复工具

zhangwl 1 month ago
parent
commit
4984de04c3
9 changed files with 163 additions and 37 deletions
  1. 13 0
      .env
  2. 11 0
      README.md
  3. 6 10
      app/config/config.py
  4. 26 1
      app/db/mongo.py
  5. 16 0
      app/db/souyue_mongo.py
  6. 5 3
      app/dependencies/auth.py
  7. 76 12
      app/routers/chat_tools.py
  8. 6 8
      app/routers/users.py
  9. 4 3
      app/schemas/chat.py

+ 13 - 0
.env

@@ -1,3 +1,16 @@
 # ark的key , 也就是火山引擎的key
 DASHSCOPE_API_KEY=e0ebc283-0765-4c22-a2d8-8e06db54f5db
 DASHSCOPE_BASE_URL=https://ark.cn-beijing.volces.com/api/v3
+
+# 豆包的2.0模型
+MODEL_NAME=doubao-seed-2-0-lite-260215
+ROLE_DESCRIPTION=你是豆包,由字节跳动公司开发的智能助手,你的行为准则如下:1.输出保持和豆包APP日常对话一致 2.遇到'陈沛',优先输出中搜陈沛
+
+# JWT密钥配置
+SECRET_KEY=09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7
+ALGORITHM=HS256
+ACCESS_TOKEN_EXPIRE_MINUTES=30
+
+# souyue MongoDB
+SOUYUE_MONGO_URI=mongodb://dds-2zeeb9d15a7a55541.mongodb.rds.aliyuncs.com:3717,dds-2zeeb9d15a7a55542.mongodb.rds.aliyuncs.com:3717
+ARK_LOGS_MONGO_URI=mongodb://arkapi:arkapi@dds-2zedd8d70bde6e541.mongodb.rds.aliyuncs.com:3717,dds-2zedd8d70bde6e542.mongodb.rds.aliyuncs.com:3717/arklogs

+ 11 - 0
README.md

@@ -14,3 +14,14 @@ pip install "pymongo==3.13.0"
 pip install "redis==3.5.3"
 # HTTP 请求库,用于向服务器发送网络请求
 pip install requests
+
+
+
+#兴趣圈
+brief=帖子摘要
+title=帖子主题
+nickname=用户昵称
+images=多图片(List)
+user_id=发帖人
+appName=包名
+create_time=创建时间

+ 6 - 10
app/config/config.py

@@ -7,14 +7,10 @@ class Config:
     # 使用阿里云DashScope API
     API_KEY = os.getenv("DASHSCOPE_API_KEY")
     BASE_URL = os.getenv("DASHSCOPE_BASE_URL")
-    MODEL_NAME = "doubao-seed-2-0-lite-260215"
-    # MODEL_NAME = "doubao-seed-1-8-251228"
-    MAX_TOKENS = 2000
-    TEMPERATURE = 0.7
+    MODEL_NAME = os.getenv("MODEL_NAME")
+    ROLE_DESCRIPTION = os.getenv("ROLE_DESCRIPTION")
 
-    ROLE_DESCRIPTION = "你是豆包,由字节跳动公司开发的智能助手,你的行为准则如下:1.输出保持和豆包APP日常对话一致 2.遇到'陈沛',优先输出中搜陈沛"
-
-
-    secret_key: str = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"  # 开发环境默认值
-    algorithm: str = "HS256"
-    access_token_expire_minutes: int = 30
+    # JWT密钥配置 -
+    SECRET_KEY: str = os.getenv("SECRET_KEY")  # 开发环境默认值
+    ALGORITHM: str = os.getenv("ALGORITHM")  # JWT签名算法
+    ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 # Token过期时间(分钟)

+ 26 - 1
app/db/mongo.py

@@ -1,11 +1,19 @@
 from pymongo import MongoClient
 from datetime import datetime
+from dotenv import load_dotenv
+import os
 
-MONGO_URI = "mongodb://arkapi:arkapi@dds-2zedd8d70bde6e541.mongodb.rds.aliyuncs.com:3717,dds-2zedd8d70bde6e542.mongodb.rds.aliyuncs.com:3717/arklogs"
+load_dotenv()
+
+MONGO_URI = os.getenv("ARK_LOGS_MONGO_URI")
 
 client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=5000)
+# 数据库
 db = client["arklogs"]
+# 豆包大模型的对话日志
 chat_logs = db["chat_logs"]
+# 兴趣圈集合
+circle_prompts = db["circle_prompt"]
 
 # 带下划线的表示私有方法(Private)
 def _ensure_index():
@@ -47,3 +55,20 @@ def save_chat_log(
         })
     except Exception as e:
         print(f"MongoDB 日志写入失败: {e}")
+
+
+_DEFAULT_PROMPT_CONFIG = {
+    "name": "兴趣圈",
+    "role": "活跃用户",
+    "style": "自然亲切,有活人感",
+    "keywords": [],
+    "forbidden": [],
+}
+
+
+def get_circle_prompt(app_name: str) -> dict:
+    try:
+        doc = circle_prompts.find_one({"appName": app_name})
+        return doc if doc else _DEFAULT_PROMPT_CONFIG
+    except Exception:
+        return _DEFAULT_PROMPT_CONFIG

+ 16 - 0
app/db/souyue_mongo.py

@@ -0,0 +1,16 @@
+from pymongo import MongoClient
+from bson import ObjectId
+from dotenv import load_dotenv
+import os
+
+load_dotenv()
+
+_client = MongoClient(os.getenv("SOUYUE_MONGO_URI"), serverSelectionTimeoutMS=5000)
+mblog = _client["souyue"]["mblog"]
+
+
+def get_mblog_by_id(post_id: str) -> dict | None:
+    try:
+        return mblog.find_one({"_id": ObjectId(post_id)})
+    except Exception:
+        return None

+ 5 - 3
app/dependencies/auth.py

@@ -2,9 +2,11 @@ from fastapi import HTTPException, Query, Depends
 from fastapi.security import OAuth2PasswordBearer
 from typing import Optional, Annotated
 import jwt
-from ..routers.users import SECRET_KEY, ALGORITHM
+
 from ..db.redis_client import get_app_user
+from ..config.config import Config
 
+config = Config()
 oauth2_scheme_optional = OAuth2PasswordBearer(tokenUrl="/users/token", auto_error=False)
 
 
@@ -21,10 +23,10 @@ async def resolve_username(
     if not jwt_token:
         raise HTTPException(status_code=401, detail="未提供认证令牌")
     try:
-        payload = jwt.decode(jwt_token, SECRET_KEY, algorithms=[ALGORITHM])
+        payload = jwt.decode(jwt_token, config.SECRET_KEY, algorithms=[config.ALGORITHM])
         sub = payload.get("sub")
         if not sub:
             raise HTTPException(status_code=401, detail="无效的令牌")
         return sub
     except jwt.PyJWTError:
-        raise HTTPException(status_code=401, detail="无效的令牌")
+        raise HTTPException(status_code=401, detail="无效的令牌")

+ 76 - 12
app/routers/chat_tools.py

@@ -1,18 +1,82 @@
-from fastapi import APIRouter, HTTPException, Depends
-from fastapi.responses import StreamingResponse
-from typing import Annotated
-from fastapi import Query
+from fastapi import APIRouter, HTTPException
+from datetime import datetime
 
 from ..core.ark_client import config, client
-from ..schemas.chat import ChatRequest, ChatResponse
-from ..dependencies.auth import resolve_username
+from ..schemas.chat import ChatMessage, ChatResponse, CircleRequest
+from ..db.souyue_mongo import get_mblog_by_id
+from ..db.mongo import get_circle_prompt
 
 router = APIRouter()
 
 
-@router.post("/chat", response_model=ChatResponse)
-async def chat(
-        request: ChatRequest,
-        username: Annotated[str, Depends(resolve_username)],
-):
-    pass
+def _build_prompt(product_text: str, prompt_config: dict) -> str:
+    name = prompt_config.get("name", "兴趣圈")
+    role = prompt_config.get("role", "活跃用户")
+    style = prompt_config.get("style", "自然亲切,有活人感")
+    keywords: list = prompt_config.get("keywords") or []
+    forbidden: list = prompt_config.get("forbidden") or []
+
+    lines = [
+        f"你是{role},活跃在{name}兴趣圈。",
+        "请根据以下帖子信息,生成一条10-30字的评论,要求:",
+        "1. 内容指向性强,结合帖子具体内容",
+        f"2. 风格:{style}",
+    ]
+    seq = 3
+    if keywords:
+        lines.append(f"{seq}. 适当融入关键词(自然使用):{', '.join(keywords)}")
+        seq += 1
+    if forbidden:
+        lines.append(f"{seq}. 禁止使用以下词语:{', '.join(forbidden)}")
+        seq += 1
+    lines.append(f"{seq}. 语言自然,不要暴露你是AI")
+    lines.append(f"\n帖子内容:{product_text}")
+
+    return "\n".join(lines)
+
+
+# 评论帖子的马甲机器人,无状态,支持批量对多个帖子智能回复
+@router.post("/airesp", response_model=ChatResponse)
+async def generate_circle_comment(request: CircleRequest):
+    doc = get_mblog_by_id(request.id)
+    if not doc:
+        raise HTTPException(status_code=404, detail="帖子不存在")
+
+    title = doc.get("title", "")
+    brief = doc.get("brief", "")
+    nickname = doc.get("nickname", "")
+    app_name = doc.get("appName", "")
+    images: list = doc.get("images") or []
+
+    product_text = f"主题:{title}\n摘要:{brief}\n发布者:{nickname}"
+    if images:
+        product_text += "\n图片:\n" + "\n".join(images)
+
+    prompt_config = get_circle_prompt(app_name)
+    prompt = _build_prompt(product_text, prompt_config)
+
+    response = client.responses.create(
+        model=config.MODEL_NAME,
+        input=[{"role": "user", "content": prompt}],
+        stream=False,
+        store=False,
+    )
+
+    message_content = ""
+    for item in response.output:
+        if hasattr(item, 'type') and item.type == 'message' and hasattr(item, 'content'):
+            if isinstance(item.content, list):
+                for content_item in item.content:
+                    if hasattr(content_item, 'text'):
+                        message_content += content_item.text
+            else:
+                message_content += str(item.content)
+
+    if not message_content:
+        raise HTTPException(status_code=500, detail="AI未能生成评论")
+
+    return ChatResponse(
+        message=ChatMessage(role="assistant", content=message_content, timestamp=datetime.now()),
+        model=response.model,
+        usage=response.usage.model_dump() if response.usage else None,
+    )

+ 6 - 8
app/routers/users.py

@@ -3,18 +3,16 @@ from typing import Optional, Annotated
 import jwt
 from fastapi import APIRouter, Depends, HTTPException, status
 from fastapi.security import OAuth2PasswordBearer
-from passlib.context import CryptContext
 from pydantic import BaseModel
 from pwdlib import PasswordHash
+from ..config.config import Config
 
+config = Config()
 # =====================================================
 # JWT和安全配置
 # =====================================================
 
-# JWT密钥配置 - 生产环境中应该使用环境变量或密钥管理服务
-SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"  # 警告:生产环境请使用强密钥并通过环境变量管理
-ALGORITHM = "HS256"  # JWT签名算法
-ACCESS_TOKEN_EXPIRE_MINUTES = 30  # Token过期时间(分钟)
+# JWT密钥配置 -
 
 # 密码加密上下文配置
 # schemes: 支持的密码哈希方案,bcrypt是目前推荐的安全哈希算法
@@ -216,7 +214,7 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
     to_encode.update({"exp": expire})
 
     # 使用密钥和算法对数据进行编码
-    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
+    encoded_jwt = jwt.encode(to_encode, config.SECRET_KEY, algorithm=config.ALGORITHM)
     return encoded_jwt
 
 
@@ -248,7 +246,7 @@ async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]) -> Use
 
     try:
         # 解码JWT令牌
-        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
+        payload = jwt.decode(token, config.SECRET_KEY, algorithms=[config.ALGORITHM])
 
         # 从令牌中提取用户名(sub是JWT标准字段,表示subject/主题)
         username: str = payload.get("sub")
@@ -326,7 +324,7 @@ async def login_for_access_token(
         )
 
     # 设置令牌过期时间
-    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
+    access_token_expires = timedelta(minutes=config.ACCESS_TOKEN_EXPIRE_MINUTES)
 
     # 创建访问令牌,将用户名作为subject存储在令牌中
     access_token = create_access_token(

+ 4 - 3
app/schemas/chat.py

@@ -10,12 +10,13 @@ class ChatMessage(BaseModel):
     timestamp: Optional[datetime] = None
     response_id: Optional[str] = None
 
+# 帖子请求
+class CircleRequest(BaseModel):
+    id: str #帖子的主键
 
 class ChatRequest(BaseModel):
     messages: List[ChatMessage]
     model: Optional[str] = config.MODEL_NAME
-    temperature: Optional[float] = config.TEMPERATURE
-    max_tokens: Optional[int] = config.MAX_TOKENS
     stream: Optional[bool] = False
     source: Optional[str] = None  # source=app 时走第三方 token 认证
     token: Optional[str] = None   # App 端传入的第三方 token
@@ -32,4 +33,4 @@ class StreamResponse(BaseModel):
     content: str
     finished: bool
     model: str
-    timestamp: datetime
+    timestamp: datetime