| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151 |
- """Lightweight connection handling for MCP servers."""
- from abc import ABC, abstractmethod
- from contextlib import AsyncExitStack
- from typing import Any
- from mcp import ClientSession, StdioServerParameters
- from mcp.client.sse import sse_client
- from mcp.client.stdio import stdio_client
- from mcp.client.streamable_http import streamablehttp_client
class MCPConnection(ABC):
    """Base class for MCP server connections.

    Instances are async context managers. Entering opens the transport
    context produced by ``_create_context``, wraps its streams in a
    ``ClientSession`` and initializes that session; exiting closes
    everything that was opened. Subclasses only supply the transport.
    """

    def __init__(self):
        # Active ClientSession; None while the connection is not open.
        self.session = None
        # AsyncExitStack owning the transport and session contexts; None
        # while the connection is not open.
        self._stack = None

    @abstractmethod
    def _create_context(self):
        """Create the connection context based on connection type.

        Returns:
            An async context manager whose entered value is a sequence of
            two or three items, the first two being the read and write
            streams handed to ``ClientSession``.
        """

    async def __aenter__(self):
        """Initialize MCP server connection.

        Returns:
            This connection, with ``session`` ready for use.

        Raises:
            ValueError: If the transport context yields something other
                than two or three items.
        """
        stack = AsyncExitStack()
        await stack.__aenter__()
        try:
            result = await stack.enter_async_context(self._create_context())
            if len(result) == 2:
                read, write = result
            elif len(result) == 3:
                # The third item is not needed by the session; drop it.
                read, write, _ = result
            else:
                raise ValueError(f"Unexpected context result: {result}")
            session = await stack.enter_async_context(ClientSession(read, write))
            await session.initialize()
        except BaseException as exc:
            # Unwind with the real exception so the already-entered contexts
            # see the failure instead of a clean exit.
            await stack.__aexit__(type(exc), exc, exc.__traceback__)
            raise
        # Publish state only once everything succeeded, so a failed connect
        # never leaves a closed stack or dead session on the instance.
        self.session = session
        self._stack = stack
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Clean up MCP server connection resources."""
        try:
            if self._stack is not None:
                await self._stack.__aexit__(exc_type, exc_val, exc_tb)
        finally:
            # Reset even if closing raised, so the instance never keeps
            # references to a torn-down session.
            self.session = None
            self._stack = None

    def _require_session(self):
        """Return the active session or raise if the connection is not open."""
        if self.session is None:
            raise RuntimeError(
                "MCP connection is not open; use 'async with' before calling it"
            )
        return self.session

    async def list_tools(self) -> list[dict[str, Any]]:
        """Retrieve available tools from the MCP server.

        Returns:
            One dict per tool with the keys ``name``, ``description`` and
            ``input_schema``.

        Raises:
            RuntimeError: If the connection has not been opened.
        """
        response = await self._require_session().list_tools()
        return [
            {
                "name": tool.name,
                "description": tool.description,
                "input_schema": tool.inputSchema,
            }
            for tool in response.tools
        ]

    async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """Call a tool on the MCP server with provided arguments.

        Returns:
            The ``content`` attribute of the server's tool result.

        Raises:
            RuntimeError: If the connection has not been opened.
        """
        result = await self._require_session().call_tool(tool_name, arguments=arguments)
        return result.content
class MCPConnectionStdio(MCPConnection):
    """MCP connection that talks to the server over standard input/output.

    Stores the command, its arguments and the environment that are later
    passed to ``StdioServerParameters``.
    """

    def __init__(self, command: str, args: list[str] = None, env: dict[str, str] = None):
        super().__init__()
        self.command = command
        # A missing or empty argument list is stored as a fresh empty list.
        self.args = args if args else []
        self.env = env

    def _create_context(self):
        """Return the stdio client context for the configured command."""
        server_params = StdioServerParameters(
            command=self.command,
            args=self.args,
            env=self.env,
        )
        return stdio_client(server_params)
class MCPConnectionSSE(MCPConnection):
    """MCP connection that uses the Server-Sent Events transport."""

    def __init__(self, url: str, headers: dict[str, str] = None):
        super().__init__()
        # Missing or empty headers are stored as a fresh empty dict.
        self.headers = headers if headers else {}
        self.url = url

    def _create_context(self):
        """Return the SSE client context for the stored URL and headers."""
        endpoint = self.url
        extra_headers = self.headers
        return sse_client(url=endpoint, headers=extra_headers)
class MCPConnectionHTTP(MCPConnection):
    """MCP connection that uses the Streamable HTTP transport."""

    def __init__(self, url: str, headers: dict[str, str] = None):
        super().__init__()
        # Missing or empty headers are stored as a fresh empty dict.
        self.headers = headers if headers else {}
        self.url = url

    def _create_context(self):
        """Return the Streamable HTTP client context for the stored URL and headers."""
        endpoint = self.url
        extra_headers = self.headers
        return streamablehttp_client(url=endpoint, headers=extra_headers)
def create_connection(
    transport: str,
    command: str = None,
    args: list[str] = None,
    env: dict[str, str] = None,
    url: str = None,
    headers: dict[str, str] = None,
) -> MCPConnection:
    """Build the MCP connection object that matches a transport name.

    The transport name is compared case-insensitively. The returned
    connection is only constructed here, not entered.

    Args:
        transport: "stdio", "sse", or "http" ("streamable_http" and
            "streamable-http" are accepted as spellings of "http").
        command: Command to run (stdio only).
        args: Command arguments (stdio only).
        env: Environment variables (stdio only).
        url: Server URL (sse and http only).
        headers: HTTP headers (sse and http only).

    Returns:
        The MCPConnection subclass instance for the requested transport.

    Raises:
        ValueError: If the transport is not supported, or the command/URL
            it requires is missing or empty.
    """
    transport = transport.lower()

    if transport == "stdio":
        if command:
            return MCPConnectionStdio(command=command, args=args, env=env)
        raise ValueError("Command is required for stdio transport")

    if transport == "sse":
        if url:
            return MCPConnectionSSE(url=url, headers=headers)
        raise ValueError("URL is required for sse transport")

    if transport in ("http", "streamable_http", "streamable-http"):
        if url:
            return MCPConnectionHTTP(url=url, headers=headers)
        raise ValueError("URL is required for http transport")

    # Nothing matched: report the (lower-cased) name that was requested.
    raise ValueError(f"Unsupported transport type: {transport}. Use 'stdio', 'sse', or 'http'")
|