使用微调后的大模型
from fastapi import FastAPI, HTTPException, Depends
from pydantic import BaseModel
from fastapi.responses import StreamingResponse, JSONResponse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import uvicorn
import re
import asyncio
from contextlib import asynccontextmanager
import time
from typing import List, Optional, Dict, Any
import json
# Use a lifespan handler instead of on_event hooks.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application startup and shutdown.

    Everything before the ``yield`` runs once at startup: the model is
    loaded via ``load_model()`` and progress is printed. Everything after
    the ``yield`` runs at shutdown and only prints a message.

    Args:
        app: The FastAPI application this handler is attached to (unused).
    """
    # --- startup ---
    print("正在加载微调模型...")
    await load_model()
    print("模型加载完成")

    yield

    # --- shutdown ---
    print("正在关闭服务...")
# Application instance; startup and shutdown work is delegated to `lifespan`.
app = FastAPI(title="分析API (OpenAI兼容)", version="3.0.0", lifespan=lifespan)
# Module-level state shared by the request handlers.
# `api_model` is the base-model path and also the model id reported to clients.
api_model = "./models/Qwen3-0.6B"
# Both stay None until load_model() assigns them.
model = None
tokenizer = None
# OpenAI-compatible request models
class OpenAIMessage(BaseModel):
    # One chat turn. The prompt builder in this file recognizes the roles
    # "system", "user" and "assistant".
    role: str
    content: str
class OpenAIRequest(BaseModel):
    # Request body of POST /v1/chat/completions.
    model: str = api_model
    messages: list[OpenAIMessage]
    # Sampling / length controls; each may be sent as null.
    max_tokens: int | None = 300
    temperature: float | None = 0.7
    top_p: float | None = 0.9
    # When true the endpoint answers with server-sent events.
    stream: bool | None = False
    # Accepted for compatibility; the chat handler in this file does not read it.
    n: int | None = 1
class OpenAIChoice(BaseModel):
    # One entry of the "choices" array of a chat completion.
    # NOTE(review): not referenced by the handlers visible in this file,
    # which build plain dicts instead.
    index: int
    message: dict[str, str]
    finish_reason: str
class OpenAIUsage(BaseModel):
    # Token accounting block of a chat completion response.
    # NOTE(review): not referenced by the handlers visible in this file.
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
class OpenAIResponse(BaseModel):
    # Shape of a non-streaming chat completion response.
    # NOTE(review): not referenced by the handlers visible in this file,
    # which return a JSONResponse built from a dict of the same shape.
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: list[OpenAIChoice]
    usage: OpenAIUsage
class EmbeddingRequest(BaseModel):
    # Request body of POST /v1/embeddings: one text or a list of texts.
    input: str | list[str]
    model: str = "qwen3-0.6B-embeddings"
class EmbeddingResponse(BaseModel):
    # Shape of an embeddings response.
    # NOTE(review): not referenced by the handlers visible in this file.
    object: str = "list"
    data: list[dict[str, Any]]
    model: str
    usage: dict[str, int]
async def load_model():
    """Load the tokenizer and the LoRA-adapted model into the module globals.

    The tokenizer and base model are read from ``api_model``; the LoRA
    adapter is read from ``out/qwen3-0.6B-finetuned-final`` and applied on
    top of the base model. On success the globals ``tokenizer`` and
    ``model`` are set and the model is switched to eval mode.

    Raises:
        Exception: Any error raised while loading is printed and re-raised.
    """
    global model, tokenizer

    base_path = api_model
    adapter_path = "out/qwen3-0.6B-finetuned-final"

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            base_path,
            trust_remote_code=True,
            padding_side="left",
        )
        # Fall back to the EOS token when the tokenizer defines no pad token.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        backbone = AutoModelForCausalLM.from_pretrained(
            base_path,
            trust_remote_code=True,
            dtype=torch.float32,
            low_cpu_mem_usage=True,
        )
        model = PeftModel.from_pretrained(backbone, adapter_path)
        model.eval()
        print(f"模型加载成功,模型名称: {api_model}")
    except Exception as e:
        print(f"模型加载失败: {e}")
        raise
def build_qwen_prompt(messages: "List[OpenAIMessage]") -> str:
    """Render OpenAI-style chat messages as a Qwen ``<|im_start|>`` prompt.

    Each message with role ``system``, ``user`` or ``assistant`` becomes
    ``<|im_start|>{role}\\n{content}<|im_end|>\\n``. Messages with any other
    role are skipped. The prompt always ends with an open assistant turn so
    the model continues as the assistant.

    Args:
        messages: Objects exposing ``role`` and ``content`` attributes.

    Returns:
        The concatenated prompt string.
    """
    known_roles = ("system", "user", "assistant")
    # NOTE(review): content is inserted verbatim, so a message containing the
    # literal control markers can fake additional turns.
    turns = [
        f"<|im_start|>{msg.role}\n{msg.content}<|im_end|>\n"
        for msg in messages
        if msg.role in known_roles
    ]
    # Open the assistant turn that the model is expected to complete.
    turns.append("<|im_start|>assistant\n")
    return "".join(turns)
@app.get("/")
async def root():
    """Return a static description of the service and its endpoints."""
    endpoints = {
        "openai_chat": "/v1/chat/completions (POST)",
        "openai_embeddings": "/v1/embeddings (POST)",
        "health": "/health",
        "models": "/v1/models",
    }
    clients = ["OpenAI Python SDK", "LangChain", "LlamaIndex"]
    return {
        "service": "分析API (OpenAI兼容)",
        "model": api_model,
        "endpoints": endpoints,
        "compatible_with": clients,
    }
@app.get("/health")
async def health_check():
    """Report whether the model has been loaded.

    Returns:
        A dict with ``status`` ("healthy"/"unhealthy"), ``model_loaded``
        and ``model_name``.
    """
    # One identity check feeds both fields so they can never disagree;
    # truthiness of a model object is not a reliable "loaded" signal.
    loaded = model is not None
    return {
        "status": "healthy" if loaded else "unhealthy",
        "model_loaded": loaded,
        "model_name": api_model,
    }
@app.get("/v1/models")
async def list_models():
    """Return the list of available models (OpenAI-compatible)."""
    model_ids = (api_model, "qwen3-0.6B-embeddings")
    entries = [
        {
            "id": model_id,
            "object": "model",
            "created": int(time.time()),
            "owned_by": "custom",
        }
        for model_id in model_ids
    ]
    return JSONResponse(content={"object": "list", "data": entries})
# Generation runs in a worker thread so the event loop stays responsive
# (e.g. for /health); this lock keeps it to one generate() call at a time.
_generation_lock = asyncio.Lock()

# Upper bound on the prompt length, in tokens.
_MAX_PROMPT_TOKENS = 512


def _resolve_generation_options(request: OpenAIRequest) -> tuple[int, Dict[str, Any]]:
    """Validate the sampling fields and turn them into ``generate`` kwargs.

    ``None`` selects the defaults (300 tokens, temperature 0.7, top_p 0.9).
    An explicit ``temperature`` of 0 selects greedy decoding instead of
    being replaced by the default.

    Raises:
        HTTPException: 400 when a value is out of range.
    """
    max_new_tokens = 300 if request.max_tokens is None else request.max_tokens
    temperature = 0.7 if request.temperature is None else request.temperature
    top_p = 0.9 if request.top_p is None else request.top_p

    if max_new_tokens < 1:
        raise HTTPException(status_code=400, detail="max_tokens 必须大于等于 1")
    if temperature < 0:
        raise HTTPException(status_code=400, detail="temperature 不能为负数")
    if not 0 <= top_p <= 1:
        raise HTTPException(status_code=400, detail="top_p 必须在 0 到 1 之间")

    options: Dict[str, Any] = {
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if temperature == 0:
        options["do_sample"] = False
    else:
        options.update(do_sample=True, temperature=temperature, top_p=top_p)
    return max_new_tokens, options


def _encode_prompt(prompt: str) -> torch.Tensor:
    """Tokenize ``prompt`` and keep at most the last ``_MAX_PROMPT_TOKENS`` ids.

    Trimming from the front keeps the trailing ``<|im_start|>assistant``
    marker, which truncating from the end would cut off.
    """
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
    return input_ids[:, -_MAX_PROMPT_TOKENS:]


def _generate(input_ids: torch.Tensor, max_new_tokens: int, options: Dict[str, Any]) -> torch.Tensor:
    """Run ``model.generate`` without gradients; meant to run in a worker thread.

    ``no_grad`` is entered here because grad mode is per-thread. The
    attention mask is all ones because a single unpadded sequence is passed.
    """
    with torch.no_grad():
        return model.generate(
            input_ids=input_ids,
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=max_new_tokens,
            **options,
        )


def _sse_chunk(stream_id: str, model_name: str, delta: Dict[str, str], finish_reason: Optional[str]) -> str:
    """Format one ``chat.completion.chunk`` as a server-sent-event line."""
    payload = {
        "id": stream_id,
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model_name,
        "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
    }
    return f"data: {json.dumps(payload)}\n\n"


async def _stream_completion(model_name: str, input_ids: torch.Tensor, max_new_tokens: int, options: Dict[str, Any]):
    """Yield OpenAI-style SSE chunks, generating one token per step.

    NOTE: every step re-runs ``generate`` over the whole sequence, as the
    original loop did; a streamer-based implementation would avoid that cost.
    """
    stream_id = f"chatcmpl-{int(time.time())}"
    yield _sse_chunk(stream_id, model_name, {"role": "assistant"}, None)

    prompt_len = input_ids.shape[1]
    sequence = input_ids
    sent_text = ""
    finish_reason = "length"

    for _ in range(max_new_tokens):
        async with _generation_lock:
            extended = await asyncio.to_thread(_generate, sequence, 1, options)
        if extended.shape[1] <= sequence.shape[1]:
            # No token was produced; stop instead of spinning.
            finish_reason = "stop"
            break
        sequence = extended
        if sequence[0, -1].item() == tokenizer.eos_token_id:
            finish_reason = "stop"
            break

        # Decode everything generated so far and emit only the new suffix.
        # This keeps inter-token whitespace and multi-byte characters intact.
        text = tokenizer.decode(sequence[0, prompt_len:], skip_special_tokens=True)
        if text.endswith("\ufffd"):
            # Incomplete multi-byte character; wait for the next token.
            continue
        delta = text[len(sent_text):]
        if delta:
            sent_text = text
            yield _sse_chunk(stream_id, model_name, {"content": delta}, None)

    # Exactly one terminating chunk and one [DONE] marker per stream.
    yield _sse_chunk(stream_id, model_name, {}, finish_reason)
    yield "data: [DONE]\n\n"


@app.post("/v1/chat/completions")
async def chat_completions(request: OpenAIRequest):
    """OpenAI-compatible chat completion endpoint.

    Builds a Qwen prompt from ``request.messages`` and either streams the
    reply as server-sent events (``stream`` true) or returns one JSON
    document. ``request.model`` is echoed back unchanged and ``request.n``
    is ignored: exactly one choice is produced.

    Raises:
        HTTPException: 503 when the model or tokenizer is not loaded,
            400 when a sampling parameter is out of range.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="模型未加载")

    # Validate before streaming starts so errors surface as an HTTP status.
    max_new_tokens, options = _resolve_generation_options(request)
    input_ids = _encode_prompt(build_qwen_prompt(request.messages))

    if request.stream:
        return StreamingResponse(
            _stream_completion(request.model, input_ids, max_new_tokens, options),
            media_type="text/event-stream",
        )

    async with _generation_lock:
        output = await asyncio.to_thread(_generate, input_ids, max_new_tokens, options)

    # Slice off the prompt by length rather than searching the decoded text
    # for the assistant marker.
    prompt_tokens = input_ids.shape[1]
    new_tokens = output[0, prompt_tokens:]
    completion_tokens = int(new_tokens.shape[0])
    finish_reason = "stop" if tokenizer.eos_token_id in new_tokens.tolist() else "length"

    answer = tokenizer.decode(new_tokens, skip_special_tokens=True)
    # Same post-processing as before: drop anything tag-like and collapse
    # runs of newlines.
    answer = re.sub(r'<[^>]+>', '', answer)
    answer = re.sub(r'\n+', '\n', answer).strip()

    return JSONResponse(content={
        "id": f"chatcmpl-{int(time.time())}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": request.model,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": answer,
            },
            "finish_reason": finish_reason,
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    })
@app.post("/v1/embeddings")
async def create_embeddings(request: EmbeddingRequest):
    """OpenAI-compatible embeddings endpoint (simplified placeholder).

    The vectors are NOT semantic embeddings: each text is mapped to a
    384-dimensional pseudo-random vector seeded from a SHA-256 digest of the
    text. The same text therefore always yields the same vector, including
    across process restarts. Use a dedicated embedding model for real work.

    Raises:
        HTTPException: 503 when the tokenizer (used for usage counts) is
            not loaded.
    """
    import hashlib

    import numpy as np

    if tokenizer is None:
        raise HTTPException(status_code=503, detail="模型未加载")

    texts = request.input if isinstance(request.input, list) else [request.input]

    data = []
    for index, text in enumerate(texts):
        # A content digest is stable across processes (unlike built-in hash())
        # and a local generator leaves NumPy's global RNG state alone.
        digest = hashlib.sha256(text.encode("utf-8", "surrogatepass")).digest()
        rng = np.random.default_rng(int.from_bytes(digest[:8], "big"))
        data.append({
            "object": "embedding",
            "embedding": rng.standard_normal(384).tolist(),
            "index": index,
        })

    token_count = sum(len(tokenizer.encode(text)) for text in texts)
    return JSONResponse(content={
        "object": "list",
        "data": data,
        "model": request.model,
        "usage": {
            "prompt_tokens": token_count,
            "total_tokens": token_count,
        },
    })
if __name__ == "__main__":
    # Print a short usage banner, then serve the app on port 8001.
    print("启动OpenAI兼容API服务 v3.0...")
    print("访问 http://localhost:8001/docs 查看API文档")
    print("\n📋 OpenAI兼容接口:")
    print("POST /v1/chat/completions")
    print("POST /v1/embeddings")
    print("GET /v1/models")
    print("\n💡 使用方式:")
    print("1. 使用OpenAI Python SDK:")
    # The SDK appends "/chat/completions" to base_url, so it must end in /v1.
    print(" client = OpenAI(base_url='http://localhost:8001/v1', api_key='fake-key')")
    print("2. 使用curl:")
    print(' curl http://localhost:8001/v1/chat/completions \\')
    print(' -H "Content-Type: application/json" \\')
    print(' -d \'{"model": "qwen3-0.6B-human-resources", "messages": [{"role": "user", "content": "你好"}]}\'')
    # NOTE(review): 0.0.0.0 listens on every interface and the API has no
    # authentication; confirm that is intended for the deployment.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8001,
        log_level="info",
    )
- THE END -
最后修改:2026年5月19日
非特殊说明,本博所有文章均为博主原创。