connections.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. """Lightweight connection handling for MCP servers."""
  2. from abc import ABC, abstractmethod
  3. from contextlib import AsyncExitStack
  4. from typing import Any
  5. from mcp import ClientSession, StdioServerParameters
  6. from mcp.client.sse import sse_client
  7. from mcp.client.stdio import stdio_client
  8. from mcp.client.streamable_http import streamablehttp_client
  9. class MCPConnection(ABC):
  10. """Base class for MCP server connections."""
  11. def __init__(self):
  12. self.session = None
  13. self._stack = None
  14. @abstractmethod
  15. def _create_context(self):
  16. """Create the connection context based on connection type."""
  17. async def __aenter__(self):
  18. """Initialize MCP server connection."""
  19. self._stack = AsyncExitStack()
  20. await self._stack.__aenter__()
  21. try:
  22. ctx = self._create_context()
  23. result = await self._stack.enter_async_context(ctx)
  24. if len(result) == 2:
  25. read, write = result
  26. elif len(result) == 3:
  27. read, write, _ = result
  28. else:
  29. raise ValueError(f"Unexpected context result: {result}")
  30. session_ctx = ClientSession(read, write)
  31. self.session = await self._stack.enter_async_context(session_ctx)
  32. await self.session.initialize()
  33. return self
  34. except BaseException:
  35. await self._stack.__aexit__(None, None, None)
  36. raise
  37. async def __aexit__(self, exc_type, exc_val, exc_tb):
  38. """Clean up MCP server connection resources."""
  39. if self._stack:
  40. await self._stack.__aexit__(exc_type, exc_val, exc_tb)
  41. self.session = None
  42. self._stack = None
  43. async def list_tools(self) -> list[dict[str, Any]]:
  44. """Retrieve available tools from the MCP server."""
  45. response = await self.session.list_tools()
  46. return [
  47. {
  48. "name": tool.name,
  49. "description": tool.description,
  50. "input_schema": tool.inputSchema,
  51. }
  52. for tool in response.tools
  53. ]
  54. async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
  55. """Call a tool on the MCP server with provided arguments."""
  56. result = await self.session.call_tool(tool_name, arguments=arguments)
  57. return result.content
  58. class MCPConnectionStdio(MCPConnection):
  59. """MCP connection using standard input/output."""
  60. def __init__(self, command: str, args: list[str] = None, env: dict[str, str] = None):
  61. super().__init__()
  62. self.command = command
  63. self.args = args or []
  64. self.env = env
  65. def _create_context(self):
  66. return stdio_client(
  67. StdioServerParameters(command=self.command, args=self.args, env=self.env)
  68. )
  69. class MCPConnectionSSE(MCPConnection):
  70. """MCP connection using Server-Sent Events."""
  71. def __init__(self, url: str, headers: dict[str, str] = None):
  72. super().__init__()
  73. self.url = url
  74. self.headers = headers or {}
  75. def _create_context(self):
  76. return sse_client(url=self.url, headers=self.headers)
  77. class MCPConnectionHTTP(MCPConnection):
  78. """MCP connection using Streamable HTTP."""
  79. def __init__(self, url: str, headers: dict[str, str] = None):
  80. super().__init__()
  81. self.url = url
  82. self.headers = headers or {}
  83. def _create_context(self):
  84. return streamablehttp_client(url=self.url, headers=self.headers)
  85. def create_connection(
  86. transport: str,
  87. command: str = None,
  88. args: list[str] = None,
  89. env: dict[str, str] = None,
  90. url: str = None,
  91. headers: dict[str, str] = None,
  92. ) -> MCPConnection:
  93. """Factory function to create the appropriate MCP connection.
  94. Args:
  95. transport: Connection type ("stdio", "sse", or "http")
  96. command: Command to run (stdio only)
  97. args: Command arguments (stdio only)
  98. env: Environment variables (stdio only)
  99. url: Server URL (sse and http only)
  100. headers: HTTP headers (sse and http only)
  101. Returns:
  102. MCPConnection instance
  103. """
  104. transport = transport.lower()
  105. if transport == "stdio":
  106. if not command:
  107. raise ValueError("Command is required for stdio transport")
  108. return MCPConnectionStdio(command=command, args=args, env=env)
  109. elif transport == "sse":
  110. if not url:
  111. raise ValueError("URL is required for sse transport")
  112. return MCPConnectionSSE(url=url, headers=headers)
  113. elif transport in ["http", "streamable_http", "streamable-http"]:
  114. if not url:
  115. raise ValueError("URL is required for http transport")
  116. return MCPConnectionHTTP(url=url, headers=headers)
  117. else:
  118. raise ValueError(f"Unsupported transport type: {transport}. Use 'stdio', 'sse', or 'http'")