langgraph同时使用tool和mcp(二)

TwoAdmin 2025-9-26 62 9/26

cometLanggraph.py

import asyncio
import logging
import json
from typing import AsyncIterable, Dict, Any
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, AIMessage
from langgraph.graph import StateGraph, END
from tavily import TavilyClient

from sdk.langgraph.langgraph import AgentState
from tools.job_info_tools import async_job_info_node

from app.config import comet_config
from service.langchainRedisManager import LangchainRedisManager



# Agent runtime configuration constants.
class AgentConfig:
    """Tunable knobs for agent iteration, streaming pacing, and recursion."""
    MAX_ITERATIONS = 3      # hard cap on classifier/tool round-trips
    CHUNK_SIZE = 3          # characters per simulated streaming chunk
    STREAM_DELAY = 0.03     # seconds slept between streamed chunks
    RECURSION_LIMIT = 15    # presumably a langgraph recursion_limit — not referenced in this file, TODO confirm



# --- Async graph nodes ---
async def async_web_search_node(state: AgentState):
    """Run a web search for the query stored in state and record the result.

    Reads ``web_search_query`` from *state*; on success writes
    ``web_search_result``, on failure writes ``error``. Always returns the
    (mutated) state so the graph can continue.
    """
    print("--- 进入节点: Web Searching ---")

    query = state.get("web_search_query")
    if not query:
        # Nothing to search for — pass the state through unchanged.
        print("⚠️ 没有搜索查询词,跳过搜索")
        return state

    try:
        state['web_search_result'] = await web_info_func(query)
        print("✅ 网页搜索完成")
    except Exception as e:
        print(f"❌ 网页搜索失败: {e}")
        state['error'] = f"网页搜索失败: {str(e)}"

    return state

async def web_info_func(query: str) -> str:
    """Search the web for *query* via Tavily and return the raw response.

    The Tavily client is synchronous; the original code called it directly
    inside this coroutine, blocking the event loop for the whole HTTP
    round-trip. Run it in a worker thread instead.
    """
    tavily_client = TavilyClient(api_key=comet_config['tavily_key'])
    # asyncio.to_thread keeps the event loop responsive while the sync
    # client performs network I/O.
    response = await asyncio.to_thread(
        tavily_client.search, query=query, country="china"
    )
    print(response)
    return response





def _strip_json_fences(text: str) -> str:
    """Remove a Markdown ```json ... ``` wrapper, if present, from LLM output."""
    text = text.strip()
    if text.startswith("```"):
        text = text.removeprefix("```json").removeprefix("```").strip()
        if text.endswith("```"):
            text = text[:-3].strip()
    return text


async def async_classifier_node(state: AgentState):
    """Routing node: decide which node the graph should visit next.

    Increments ``iteration_count``; short-circuits to response generation
    when the iteration cap is reached or a tool already produced output.
    Otherwise asks the LLM for a JSON routing decision and stores it in
    ``next_step`` (plus ``web_search_query`` for the web-search branch).
    Returns the mutated state.
    """
    print("--- 进入节点: Classifier ---")

    iteration = state.get('iteration_count', 0) + 1
    state['iteration_count'] = iteration

    max_iterations = state.get('max_iterations', AgentConfig.MAX_ITERATIONS)
    if iteration > max_iterations:
        # Hard stop: force a final answer instead of looping further.
        state['next_step'] = 'generate_response'
        return state

    print(f"🔄 迭代计数: {iteration}/{max_iterations}")

    # If any tool has already produced output, answer instead of re-routing.
    has_tool_results = any([
        state.get('user_info'),
        state.get('job_info'),
        state.get('web_search_result')
    ])

    print(f"🔍 has_tool_results: {has_tool_results}")

    if has_tool_results and iteration > 1:
        print("🛠️ 检测到工具执行结果,直接生成响应")
        state['next_step'] = 'generate_response'
        return state

    # The most recent human message drives the routing decision.
    last_user_message = None
    for msg in reversed(state['messages']):
        if isinstance(msg, HumanMessage):
            last_user_message = msg.content
            break

    # Routing prompt (includes the MCP tools).
    prompt = f"""你是一个智能路由决策器。根据用户问题决定下一步操作。

用户问题: "{last_user_message}"

可用操作:
- 'fetch_job_info': 工作、职位、招聘相关
- 'web_search': 需要天气查询
- 'generate_response': 直接回答

重要规则: 如果已经有相关工具的执行结果,请选择 generate_response!

决策规则:
1. 工作信息 → fetch_job_info
2. 天气查询 → web_search
3. 其他情况 → generate_response

返回JSON格式:
{{
    "action": "操作名称",
    "reason": "决策理由",
    "query": "额外参数(如文件路径、计算表达式等)"
}}
"""

    try:
        # llm.invoke is synchronous; asyncio.to_thread replaces the
        # deprecated get_event_loop().run_in_executor pattern.
        response = await asyncio.to_thread(llm.invoke, prompt)
        # LLMs frequently wrap JSON in Markdown fences; strip them before
        # parsing so a cosmetic wrapper does not force the fallback path.
        decision = json.loads(_strip_json_fences(response.content))
        print(f"🤖 AI决策: {decision}")

        state['next_step'] = decision.get("action")
        # Per-tool parameters extracted from the decision.
        if state['next_step'] == 'web_search':
            # Stash the query so web_search_node knows what to look up.
            state['web_search_query'] = decision.get("query")

    except Exception as e:
        # Any failure (LLM error, malformed JSON) degrades to a direct answer.
        print(f"❌ 分类器错误: {e}")
        state['next_step'] = 'generate_response'

    return state


async def async_generate_response_node(state: AgentState):
    """Generate the final AI reply from tool context plus chat history.

    Builds a prompt containing any tool results (``user_info``, ``job_info``,
    ``web_search_result``) and the full message history, invokes the LLM, and
    appends the reply to ``state['messages']``. On failure, records ``error``.
    """
    print("--- 进入节点: Generating Response ---")

    # Tool outputs gathered so far; '暂无' marks a tool that never ran.
    context = {
        "user_info": state.get('user_info', '暂无'),
        "job_info": state.get('job_info', '暂无'),
        "web_search_result": state.get('web_search_result', '暂无')
    }

    prompt = f"""你是一个专业的、友好的AI职业顾问。
        请根据下面提供的上下文信息和完整的对话历史,回答用户的最后一个问题。

        **上下文信息:**
        {context}

        **对话历史:**
        {state['messages']}

        """

    # Fixed: the original log line contained mojibake ("<UNK>") placeholders.
    # Also removed a loop that extracted the last user message but never
    # used it.
    print(f"上下文信息: {prompt}")
    try:
        # llm.invoke is synchronous; asyncio.to_thread avoids the deprecated
        # get_event_loop().run_in_executor pattern inside a coroutine.
        response = await asyncio.to_thread(llm.invoke, prompt)
        state['messages'].append(AIMessage(content=response.content))
        print("✅ 响应生成完成")
    except Exception as e:
        print(f"❌ 响应生成失败: {e}")
        state['error'] = f"响应生成失败: {str(e)}"

    return state


async def async_error_handler_node(state: AgentState):
    """Log the recorded error and append a generic apology to the chat."""
    details = state.get("error", "未知错误")
    apology = "抱歉,处理您的请求时遇到了问题。请稍后重试。"

    # Log the technical detail; the user only sees the friendly apology.
    logging.error(f"智能体错误 - 用户ID: {state.get('userID')}, 错误: {details}")
    state['messages'].append(AIMessage(content=apology))
    return state


# --- Initialize the LLM client ---
# Streaming-capable ChatOpenAI pointed at an OpenAI-compatible endpoint
# taken from app config.
llm = ChatOpenAI(
    model=comet_config['api_model'],
    api_key=comet_config['api_key'],
    base_url=comet_config['api_uri'],
    # Provider-specific passthrough options: disable "thinking" output and
    # enable the provider's built-in search (semantics depend on the
    # backend — TODO confirm against the provider's API docs).
    extra_body={
        "chat_template_kwargs": {
            "enable_thinking": False,
            "enable_search": True,
        }
    },
    streaming=True
)


# --- Graph construction (MCP nodes integrated) ---
def decide_next_node(state: AgentState):
    """Conditional-edge router: map the classifier's decision to a node name."""
    iteration = state.get('iteration_count', 0)
    max_iterations = state.get('max_iterations', AgentConfig.MAX_ITERATIONS)

    print(f"🔀 决策路由检查 - 迭代: {iteration}/{max_iterations}")

    # Guard clauses: the iteration cap and recorded errors override the
    # classifier's decision.
    if iteration >= max_iterations:
        return "generate_response_node"
    if state.get("error"):
        return "error_handler"

    # Unknown or missing actions fall through to response generation.
    routing = {
        "fetch_job_info": "job_info_node",
        "web_search": "web_search_node",
        "generate_response": "generate_response_node",
    }
    chosen = routing.get(state.get("next_step"), "generate_response_node")
    print(f"➡️ 路由到: {chosen}")
    return chosen


# --- Build the workflow graph ---
workflow = StateGraph(AgentState)

# Register every node (all async implementations).
workflow.add_node("classifier", async_classifier_node)
workflow.add_node("job_info_node", async_job_info_node)
workflow.add_node("web_search_node", async_web_search_node)
workflow.add_node("generate_response_node", async_generate_response_node)
workflow.add_node("error_handler", async_error_handler_node)

workflow.set_entry_point("classifier")
workflow.add_edge("error_handler", END)

# Conditional edges: decide_next_node returns one of the keys below.
workflow.add_conditional_edges(
    "classifier",
    decide_next_node,
    {
        "job_info_node": "job_info_node",
        "web_search_node": "web_search_node",
        "generate_response_node": "generate_response_node",
        "error_handler": "error_handler",
    }
)

# Plain edges: tool nodes loop back to the classifier for re-evaluation.
workflow.add_edge("job_info_node", "classifier")
workflow.add_edge("web_search_node", "classifier")
workflow.add_edge("generate_response_node", END)

# Compile the graph into a runnable app.
app = workflow.compile()


# --- Async invocation helpers ---
async def async_app_invoke(input_state: Dict) -> Dict:
    """Run the compiled graph once and return its final state."""
    final_state = await app.ainvoke(input_state)
    return final_state


async def async_app_stream(input_state: Dict):
    """Yield graph execution updates as the compiled app produces them."""
    async for event in app.astream(input_state):
        yield event


# --- Usage example ---
async def test_mcp_integration():
    """Smoke-test the graph by streaming a few canned queries.

    Removed dead code: the original built an ``initial_state`` dict that was
    only consumed by a commented-out ``async_app_invoke`` call; the live path
    goes through ``chat_stream``, which builds its own state.
    """
    test_cases = [
        "计算 1 * 38 等于多少",
        "今天北京的天气怎么样",
        "帮我找一份 中级软件测试 开发工作"
    ]

    for query in test_cases:
        print(f"\n🧪 测试: {query}")
        # Exercise the full streaming interface end-to-end for user 222.
        async for chunk in chat_stream(222, query):
            print("Received:", chunk)


# Redis-backed conversation-history manager; args presumably (db index,
# key prefix) — TODO confirm against LangchainRedisManager's signature.
redis_manager = LangchainRedisManager(1, "cometOptimizeChat:history:")


async def chat_stream(userID: int, query: str) -> AsyncIterable[Dict[str, Any]]:
    """Streaming chat endpoint (async).

    Loads the user's history from Redis, runs the graph to completion, then
    re-streams the final answer in fixed-size chunks (streaming is simulated
    after the fact, not token-by-token). Yields dicts with ``type`` of
    ``content_chunk`` / ``complete`` / ``error``; the final yield always has
    ``is_complete=True``.
    """
    logging.info(f"用户[{userID}]提问: {query}")

    try:
        # Rebuild conversation context and append the new question.
        current_messages = await redis_manager.load_messages(userID)
        current_messages.append(HumanMessage(content=query))

        initial_state = {
            "userID": str(userID),
            "messages": current_messages,
            "iteration_count": 0,
            "max_iterations": AgentConfig.MAX_ITERATIONS,
        }

        # Run the whole graph first; chunked output below paces the result.
        final_state = await async_app_invoke(initial_state)

        # The reply is the most recent AIMessage, if any.
        ai_message = next(
            (msg for msg in reversed(final_state['messages'])
             if isinstance(msg, AIMessage)),
            None
        )
        if not ai_message:
            yield {"type": "error", "error": "未生成AI回复", "is_complete": True}
            return

        full_response = ai_message.content

        # Slice the answer into fixed-size chunks and drip them out.
        chunks = [full_response[i:i + AgentConfig.CHUNK_SIZE]
                  for i in range(0, len(full_response), AgentConfig.CHUNK_SIZE)]

        for chunk_id, chunk in enumerate(chunks, 1):
            await asyncio.sleep(AgentConfig.STREAM_DELAY)
            yield {
                "type": "content_chunk",
                "content": chunk,
                "chunk_id": chunk_id,
                "is_complete": False
            }

        # Persist history only after the full answer streamed successfully.
        await redis_manager.save_messages(userID, final_state['messages'])
        await redis_manager.set_user_questions_num(userID)

        yield {
            "type": "complete",
            "full_response": full_response,
            "total_chunks": len(chunks),
            "is_complete": True
        }

    except Exception as e:
        # logging.exception preserves the traceback (logging.error dropped it).
        logging.exception(f"流式聊天错误: {e}")
        yield {"type": "error", "error": f"系统错误: {str(e)}", "is_complete": True}


if __name__ == "__main__":
    # Run the integration smoke test when executed as a script.
    asyncio.run(test_mcp_integration())

 

- THE END -
Tag:

TwoAdmin

10月21日14:44

最后修改:2025年10月21日
0

非特殊说明,本博所有文章均为博主原创。