From f89d8bcef578c5f8f4d5e46b90237a0b619a4e09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20C=2E=20Andersen?= Date: Sat, 18 Oct 2025 11:39:17 +0200 Subject: [PATCH 1/8] Add support for multiple database connections with descriptions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add DATABASE_URI_* env vars for multiple connections (e.g., DATABASE_URI_APP → "app") - Add DATABASE_DESC_* env vars for connection descriptions visible to AI - Update all tools to accept required conn_name parameter - Add ConnectionRegistry to manage multiple connection pools - Display available connections in server context - Update tests for new connection architecture --- README.md | 53 ++++ src/postgres_mcp/server.py | 98 +++++--- src/postgres_mcp/sql/__init__.py | 2 + src/postgres_mcp/sql/sql_driver.py | 166 +++++++++++++ tests/unit/explain/test_server.py | 8 +- tests/unit/explain/test_server_integration.py | 10 +- tests/unit/sql/test_readonly_enforcement.py | 166 ++++++------- tests/unit/test_access_mode.py | 228 +++++++++--------- 8 files changed, 495 insertions(+), 236 deletions(-) diff --git a/README.md b/README.md index f2303b9..b2b200f 100644 --- a/README.md +++ b/README.md @@ -199,6 +199,59 @@ The Postgres MCP Pro Docker image will automatically remap the hostname `localho Replace `postgresql://...` with your [Postgres database connection URI](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING-URIS). +##### Multiple Database Connections + +Postgres MCP Pro supports connecting to multiple databases simultaneously. This is useful when you need to work across different databases (e.g., application database, ETL database, analytics database). + +To configure multiple connections, define additional environment variables with the pattern `DATABASE_URI_`: + +```json +{ + "mcpServers": { + "postgres": { + "command": "docker", + "args": [ + "run", + "-i", + "--rm", + "-e", "DATABASE_URI_APP", + "-e", "DATABASE_URI_ETL", + "-e", "DATABASE_URI_ANALYTICS", + "-e", "DATABASE_DESC_APP", + "-e", "DATABASE_DESC_ETL", + "-e", "DATABASE_DESC_ANALYTICS", + "crystaldba/postgres-mcp", + "--access-mode=unrestricted" + ], + "env": { + "DATABASE_URI_APP": "postgresql://user:pass@localhost:5432/app_db", + "DATABASE_URI_ETL": "postgresql://user:pass@localhost:5432/etl_db", + "DATABASE_URI_ANALYTICS": "postgresql://user:pass@localhost:5432/analytics_db", + "DATABASE_DESC_APP": "Main application database with user data and transactions", + "DATABASE_DESC_ETL": "ETL staging database for data processing pipelines", + "DATABASE_DESC_ANALYTICS": "Read-only analytics database with aggregated metrics" + } + } + } +} +``` + +Each connection is identified by its name (the part after `DATABASE_URI_`, converted to lowercase): +- `DATABASE_URI_APP` → connection name: `"app"` +- `DATABASE_URI_ETL` → connection name: `"etl"` +- `DATABASE_URI_ANALYTICS` → connection name: `"analytics"` + +**Connection Descriptions**: You can optionally provide descriptions for each connection using `DATABASE_DESC_` environment variables. These descriptions help the AI assistant understand which database to use for different tasks. The descriptions are: +- Automatically displayed in the server context (visible to the AI without requiring a tool call) +- Useful for guiding the AI to select the appropriate database + +When using tools, you'll specify which connection to use via the `conn_name` parameter: +- `list_schemas(conn_name="app")` - Lists schemas in the app database +- `explain_query(conn_name="etl", sql="SELECT ...")` - Explains query in the ETL database + +For backward compatibility, `DATABASE_URI` (without a suffix) maps to the connection name `"default"`. + + ##### Access Mode Postgres MCP Pro supports multiple *access modes* to give you control over the operations that the AI agent can perform on the database: diff --git a/src/postgres_mcp/server.py b/src/postgres_mcp/server.py index af5669a..37cccf6 100644 --- a/src/postgres_mcp/server.py +++ b/src/postgres_mcp/server.py @@ -26,6 +26,7 @@ from .index.index_opt_base import MAX_NUM_INDEX_TUNING_QUERIES from .index.llm_opt import LLMOptimizerTool from .index.presentation import TextPresentation +from .sql import ConnectionRegistry from .sql import DbConnPool from .sql import SafeSqlDriver from .sql import SqlDriver @@ -34,6 +35,7 @@ from .top_queries import TopQueriesCalc # Initialize FastMCP with default settings +# Note: Server instructions will be updated after database connections are discovered mcp = FastMCP("postgres-mcp") # Constants @@ -53,20 +55,32 @@ class AccessMode(str, Enum): # Global variables -db_connection = DbConnPool() +connection_registry = ConnectionRegistry() current_access_mode = AccessMode.UNRESTRICTED shutdown_in_progress = False -async def get_sql_driver() -> Union[SqlDriver, SafeSqlDriver]: - """Get the appropriate SQL driver based on the current access mode.""" +async def get_sql_driver(conn_name: str) -> Union[SqlDriver, SafeSqlDriver]: + """ + Get the appropriate SQL driver based on the current access mode. + + Args: + conn_name: Connection name (e.g., "default", "app", "etl") + + Returns: + SqlDriver or SafeSqlDriver instance + + Raises: + ValueError: If connection name doesn't exist + """ + db_connection = connection_registry.get_connection(conn_name) base_driver = SqlDriver(conn=db_connection) if current_access_mode == AccessMode.RESTRICTED: - logger.debug("Using SafeSqlDriver with restrictions (RESTRICTED mode)") + logger.debug(f"Using SafeSqlDriver with restrictions for '{conn_name}' (RESTRICTED mode)") return SafeSqlDriver(sql_driver=base_driver, timeout=30) # 30 second timeout else: - logger.debug("Using unrestricted SqlDriver (UNRESTRICTED mode)") + logger.debug(f"Using unrestricted SqlDriver for '{conn_name}' (UNRESTRICTED mode)") return base_driver @@ -81,10 +95,12 @@ def format_error_response(error: str) -> ResponseType: @mcp.tool(description="List all schemas in the database") -async def list_schemas() -> ResponseType: +async def list_schemas( + conn_name: str = Field(description="Connection name (e.g., 'default', 'app', 'etl')"), +) -> ResponseType: """List all schemas in the database.""" try: - sql_driver = await get_sql_driver() + sql_driver = await get_sql_driver(conn_name) rows = await sql_driver.execute_query( """ SELECT @@ -108,12 +124,13 @@ async def list_schemas() -> ResponseType: @mcp.tool(description="List objects in a schema") async def list_objects( + conn_name: str = Field(description="Connection name (e.g., 'default', 'app', 'etl')"), schema_name: str = Field(description="Schema name"), object_type: str = Field(description="Object type: 'table', 'view', 'sequence', or 'extension'", default="table"), ) -> ResponseType: """List objects of a given type in a schema.""" try: - sql_driver = await get_sql_driver() + sql_driver = await get_sql_driver(conn_name) if object_type in ("table", "view"): table_type = "BASE TABLE" if object_type == "table" else "VIEW" @@ -176,13 +193,14 @@ async def list_objects( @mcp.tool(description="Show detailed information about a database object") async def get_object_details( + conn_name: str = Field(description="Connection name (e.g., 'default', 'app', 'etl')"), schema_name: str = Field(description="Schema name"), object_name: str = Field(description="Object name"), object_type: str = Field(description="Object type: 'table', 'view', 'sequence', or 'extension'", default="table"), ) -> ResponseType: """Get detailed information about a database object.""" try: - sql_driver = await get_sql_driver() + sql_driver = await get_sql_driver(conn_name) if object_type in ("table", "view"): # Get columns @@ -309,6 +327,7 @@ async def get_object_details( @mcp.tool(description="Explains the execution plan for a SQL query, showing how the database will execute it and provides detailed cost estimates.") async def explain_query( + conn_name: str = Field(description="Connection name (e.g., 'default', 'app', 'etl')"), sql: str = Field(description="SQL query to explain"), analyze: bool = Field( description="When True, actually runs the query to show real execution statistics instead of estimates. " @@ -333,12 +352,13 @@ async def explain_query( Explains the execution plan for a SQL query. Args: + conn_name: Connection name to use sql: The SQL query to explain analyze: When True, actually runs the query for real statistics hypothetical_indexes: Optional list of indexes to simulate """ try: - sql_driver = await get_sql_driver() + sql_driver = await get_sql_driver(conn_name) explain_tool = ExplainPlanTool(sql_driver=sql_driver) result: ExplainPlanArtifact | ErrorResult | None = None @@ -388,11 +408,12 @@ async def explain_query( # Query function declaration without the decorator - we'll add it dynamically based on access mode async def execute_sql( + conn_name: str = Field(description="Connection name (e.g., 'default', 'app', 'etl')"), sql: str = Field(description="SQL to run", default="all"), ) -> ResponseType: """Executes a SQL query against the database.""" try: - sql_driver = await get_sql_driver() + sql_driver = await get_sql_driver(conn_name) rows = await sql_driver.execute_query(sql) # type: ignore if rows is None: return format_text_response("No results") @@ -405,12 +426,13 @@ async def execute_sql( @mcp.tool(description="Analyze frequently executed queries in the database and recommend optimal indexes") @validate_call async def analyze_workload_indexes( + conn_name: str = Field(description="Connection name (e.g., 'default', 'app', 'etl')"), max_index_size_mb: int = Field(description="Max index size in MB", default=10000), method: Literal["dta", "llm"] = Field(description="Method to use for analysis", default="dta"), ) -> ResponseType: """Analyze frequently executed queries in the database and recommend optimal indexes.""" try: - sql_driver = await get_sql_driver() + sql_driver = await get_sql_driver(conn_name) if method == "dta": index_tuning = DatabaseTuningAdvisor(sql_driver) else: @@ -426,6 +448,7 @@ async def analyze_workload_indexes( @mcp.tool(description="Analyze a list of (up to 10) SQL queries and recommend optimal indexes") @validate_call async def analyze_query_indexes( + conn_name: str = Field(description="Connection name (e.g., 'default', 'app', 'etl')"), queries: list[str] = Field(description="List of Query strings to analyze"), max_index_size_mb: int = Field(description="Max index size in MB", default=10000), method: Literal["dta", "llm"] = Field(description="Method to use for analysis", default="dta"), @@ -437,7 +460,7 @@ async def analyze_query_indexes( return format_error_response(f"Please provide a list of up to {MAX_NUM_INDEX_TUNING_QUERIES} queries to analyze.") try: - sql_driver = await get_sql_driver() + sql_driver = await get_sql_driver(conn_name) if method == "dta": index_tuning = DatabaseTuningAdvisor(sql_driver) else: @@ -463,6 +486,7 @@ async def analyze_query_indexes( "You can optionally specify a single health check or a comma-separated list of health checks. The default is 'all' checks." ) async def analyze_db_health( + conn_name: str = Field(description="Connection name (e.g., 'default', 'app', 'etl')"), health_type: str = Field( description=f"Optional. Valid values are: {', '.join(sorted([t.value for t in HealthType]))}.", default="all", @@ -471,10 +495,11 @@ async def analyze_db_health( """Analyze database health for specified components. Args: + conn_name: Connection name to use health_type: Comma-separated list of health check types to perform. Valid values: index, connection, vacuum, sequence, replication, buffer, constraint, all """ - health_tool = DatabaseHealthTool(await get_sql_driver()) + health_tool = DatabaseHealthTool(await get_sql_driver(conn_name)) result = await health_tool.health(health_type=health_type) return format_text_response(result) @@ -484,6 +509,7 @@ async def analyze_db_health( description=f"Reports the slowest or most resource-intensive queries using data from the '{PG_STAT_STATEMENTS}' extension.", ) async def get_top_queries( + conn_name: str = Field(description="Connection name (e.g., 'default', 'app', 'etl')"), sort_by: str = Field( description="Ranking criteria: 'total_time' for total execution time or 'mean_time' for mean execution time per call, or 'resources' " "for resource-intensive queries", @@ -492,7 +518,7 @@ async def get_top_queries( limit: int = Field(description="Number of queries to return when ranking based on mean_time or total_time", default=10), ) -> ResponseType: try: - sql_driver = await get_sql_driver() + sql_driver = await get_sql_driver(conn_name) top_queries_tool = TopQueriesCalc(sql_driver=sql_driver) if sort_by == "resources": @@ -554,24 +580,36 @@ async def main(): logger.info(f"Starting PostgreSQL MCP Server in {current_access_mode.upper()} mode") - # Get database URL from environment variable or command line - database_url = os.environ.get("DATABASE_URI", args.database_url) - - if not database_url: - raise ValueError( - "Error: No database URL provided. Please specify via 'DATABASE_URI' environment variable or command-line argument.", - ) + # Initialize database connection registry + # For backwards compatibility, support command-line database_url argument + if args.database_url and "DATABASE_URI" not in os.environ: + os.environ["DATABASE_URI"] = args.database_url + logger.info("Using command-line database URL as DATABASE_URI") - # Initialize database connection pool try: - await db_connection.pool_connect(database_url) - logger.info("Successfully connected to database and initialized connection pool") + await connection_registry.discover_and_connect() + conn_names = connection_registry.get_connection_names() + logger.info(f"Successfully initialized {len(conn_names)} connection(s): {', '.join(conn_names)}") + + # Update server context with available connections + conn_info = connection_registry.get_connection_info() + if conn_info: + instructions = ["Available database connections:"] + for info in conn_info: + if "description" in info: + instructions.append(f"- {info['name']}: {info['description']}") + else: + instructions.append(f"- {info['name']}") + + # Set the server instructions to include connection information + mcp._instructions = "\n".join(instructions) + logger.info(f"Updated server context with {len(conn_info)} connection(s)") except Exception as e: logger.warning( - f"Could not connect to database: {obfuscate_password(str(e))}", + f"Could not initialize database connections: {obfuscate_password(str(e))}", ) logger.warning( - "The MCP server will start but database operations will fail until a valid connection is established.", + "The MCP server will start but database operations will fail until valid connections are established.", ) # Set up proper shutdown handling @@ -609,10 +647,10 @@ async def shutdown(sig=None): if sig: logger.info(f"Received exit signal {sig.name}") - # Close database connections + # Close all database connections try: - await db_connection.close() - logger.info("Closed database connections") + await connection_registry.close_all() + logger.info("Closed all database connections") except Exception as e: logger.error(f"Error closing database connections: {e}") diff --git a/src/postgres_mcp/sql/__init__.py b/src/postgres_mcp/sql/__init__.py index 1fded3b..921d0c3 100644 --- a/src/postgres_mcp/sql/__init__.py +++ b/src/postgres_mcp/sql/__init__.py @@ -10,12 +10,14 @@ from .extension_utils import reset_postgres_version_cache from .index import IndexDefinition from .safe_sql import SafeSqlDriver +from .sql_driver import ConnectionRegistry from .sql_driver import DbConnPool from .sql_driver import SqlDriver from .sql_driver import obfuscate_password __all__ = [ "ColumnCollector", + "ConnectionRegistry", "DbConnPool", "IndexDefinition", "SafeSqlDriver", diff --git a/src/postgres_mcp/sql/sql_driver.py b/src/postgres_mcp/sql/sql_driver.py index 5beacb0..2e61253 100644 --- a/src/postgres_mcp/sql/sql_driver.py +++ b/src/postgres_mcp/sql/sql_driver.py @@ -1,6 +1,8 @@ """SQL driver adapter for PostgreSQL connections.""" +import asyncio import logging +import os import re from dataclasses import dataclass from typing import Any @@ -136,6 +138,170 @@ def last_error(self) -> Optional[str]: return self._last_error +class ConnectionRegistry: + """Registry for managing multiple database connections.""" + + def __init__(self): + self.connections: Dict[str, DbConnPool] = {} + self._connection_urls: Dict[str, str] = {} + self._connection_descriptions: Dict[str, str] = {} + + def discover_connections(self) -> Dict[str, str]: + """ + Discover all DATABASE_URI_* environment variables. + + Returns: + Dict mapping connection names to connection URLs + - DATABASE_URI -> "default" + - DATABASE_URI_APP -> "app" + - DATABASE_URI_ETL -> "etl" + """ + discovered = {} + + for env_var, url in os.environ.items(): + if env_var == "DATABASE_URI": + discovered["default"] = url + elif env_var.startswith("DATABASE_URI_"): + # Extract postfix and lowercase it + postfix = env_var[len("DATABASE_URI_"):] + conn_name = postfix.lower() + discovered[conn_name] = url + + return discovered + + def discover_descriptions(self) -> Dict[str, str]: + """ + Discover all DATABASE_DESC_* environment variables. + + Returns: + Dict mapping connection names to descriptions + - DATABASE_DESC -> "default" + - DATABASE_DESC_APP -> "app" + - DATABASE_DESC_ETL -> "etl" + """ + descriptions = {} + + for env_var, desc in os.environ.items(): + if env_var == "DATABASE_DESC": + descriptions["default"] = desc + elif env_var.startswith("DATABASE_DESC_"): + # Extract postfix and lowercase it + postfix = env_var[len("DATABASE_DESC_"):] + conn_name = postfix.lower() + descriptions[conn_name] = desc + + return descriptions + + async def discover_and_connect(self) -> None: + """ + Discover all DATABASE_URI_* environment variables and connect to them. + Connections are initialized in parallel for efficiency. + """ + discovered = self.discover_connections() + + if not discovered: + raise ValueError( + "No database connections found. Please set DATABASE_URI or DATABASE_URI_* environment variables." + ) + + logger.info(f"Discovered {len(discovered)} database connection(s): {', '.join(discovered.keys())}") + + # Store URLs and descriptions for reference + self._connection_urls = discovered.copy() + self._connection_descriptions = self.discover_descriptions() + + # Create connection pools + for conn_name, url in discovered.items(): + self.connections[conn_name] = DbConnPool(url) + + # Connect to all databases in parallel + async def connect_single(conn_name: str, pool: DbConnPool) -> tuple[str, bool, Optional[str]]: + """Connect to a single database and return status.""" + try: + await pool.pool_connect() + return (conn_name, True, None) + except Exception as e: + error_msg = obfuscate_password(str(e)) + logger.warning(f"Failed to connect to '{conn_name}': {error_msg}") + return (conn_name, False, error_msg) + + # Execute all connections in parallel + results = await asyncio.gather( + *[connect_single(name, pool) for name, pool in self.connections.items()], + return_exceptions=False + ) + + # Log results + for conn_name, success, error in results: + if success: + logger.info(f"Successfully connected to '{conn_name}'") + else: + logger.warning(f"Connection '{conn_name}' failed: {error}") + + def get_connection(self, conn_name: str) -> DbConnPool: + """ + Get a connection pool by name. + + Args: + conn_name: Connection name (e.g., "default", "app", "etl") + + Returns: + DbConnPool instance + + Raises: + ValueError: If connection name doesn't exist + """ + if conn_name not in self.connections: + available = ", ".join(f"'{name}'" for name in sorted(self.connections.keys())) + raise ValueError( + f"Connection '{conn_name}' not found. Available connections: {available}" + ) + + pool = self.connections[conn_name] + + # Check if connection is valid + if not pool.is_valid: + error_msg = pool.last_error or "Unknown error" + raise ValueError( + f"Connection '{conn_name}' is not available: {obfuscate_password(error_msg)}" + ) + + return pool + + async def close_all(self) -> None: + """Close all database connections.""" + close_tasks = [] + for conn_name, pool in self.connections.items(): + logger.info(f"Closing connection '{conn_name}'...") + close_tasks.append(pool.close()) + + # Close all connections in parallel + await asyncio.gather(*close_tasks, return_exceptions=True) + + self.connections.clear() + self._connection_urls.clear() + self._connection_descriptions.clear() + + def get_connection_names(self) -> List[str]: + """Get list of all connection names.""" + return list(self.connections.keys()) + + def get_connection_info(self) -> List[Dict[str, str]]: + """ + Get information about all configured connections. + + Returns: + List of dicts with 'name' and optional 'description' for each connection + """ + info = [] + for conn_name in sorted(self.connections.keys()): + conn_info = {"name": conn_name} + if conn_name in self._connection_descriptions: + conn_info["description"] = self._connection_descriptions[conn_name] + info.append(conn_info) + return info + + class SqlDriver: """Adapter class that wraps a PostgreSQL connection with the interface expected by DTA.""" diff --git a/tests/unit/explain/test_server.py b/tests/unit/explain/test_server.py index 0fea0e0..5b4602f 100644 --- a/tests/unit/explain/test_server.py +++ b/tests/unit/explain/test_server.py @@ -46,7 +46,7 @@ async def test_explain_query_basic(): # Use patch to replace the actual explain_query function with our own mock with patch.object(server, "explain_query", return_value=[mock_response]): # Call the patched function - result = await server.explain_query("SELECT * FROM users") + result = await server.explain_query(conn_name="default", sql="SELECT * FROM users") # Verify we get the expected result assert isinstance(result, list) @@ -74,7 +74,7 @@ async def test_explain_query_analyze(): # Use patch to replace the actual explain_query function with our own mock with patch.object(server, "explain_query", return_value=[mock_response]): # Call the patched function with analyze=True - result = await server.explain_query("SELECT * FROM users", analyze=True) + result = await server.explain_query(conn_name="default", sql="SELECT * FROM users", analyze=True) # Verify we get the expected result assert isinstance(result, list) @@ -104,7 +104,7 @@ async def test_explain_query_hypothetical_indexes(): # Use patch to replace the actual explain_query function with our own mock with patch.object(server, "explain_query", return_value=[mock_response]): # Call the patched function with hypothetical_indexes - result = await server.explain_query(test_sql, hypothetical_indexes=test_indexes) + result = await server.explain_query(conn_name="default", sql=test_sql, hypothetical_indexes=test_indexes) # Verify we get the expected result assert isinstance(result, list) @@ -123,7 +123,7 @@ async def test_explain_query_error_handling(): # Use patch to replace the actual function with our mock that returns an error with patch.object(server, "explain_query", return_value=[mock_response]): # Call the patched function - result = await server.explain_query("INVALID SQL") + result = await server.explain_query(conn_name="default", sql="INVALID SQL") # Verify error is formatted correctly assert isinstance(result, list) diff --git a/tests/unit/explain/test_server_integration.py b/tests/unit/explain/test_server_integration.py index aa8d704..43c39ea 100644 --- a/tests/unit/explain/test_server_integration.py +++ b/tests/unit/explain/test_server_integration.py @@ -45,7 +45,7 @@ async def test_explain_query_integration(): with patch("postgres_mcp.server.get_sql_driver"): # Patch the ExplainPlanTool with patch("postgres_mcp.server.ExplainPlanTool"): - result = await explain_query("SELECT * FROM users", hypothetical_indexes=None) + result = await explain_query(conn_name="default", sql="SELECT * FROM users", hypothetical_indexes=None) # Verify result matches our expected plan data assert isinstance(result, list) @@ -67,7 +67,7 @@ async def test_explain_query_with_analyze_integration(): with patch("postgres_mcp.server.get_sql_driver"): # Patch the ExplainPlanTool with patch("postgres_mcp.server.ExplainPlanTool"): - result = await explain_query("SELECT * FROM users", analyze=True, hypothetical_indexes=None) + result = await explain_query(conn_name="default", sql="SELECT * FROM users", analyze=True, hypothetical_indexes=None) # Verify result matches our expected plan data assert isinstance(result, list) @@ -98,7 +98,7 @@ async def test_explain_query_with_hypothetical_indexes_integration(): with patch("postgres_mcp.server.get_sql_driver", return_value=mock_safe_driver): # Patch the ExplainPlanTool with patch("postgres_mcp.server.ExplainPlanTool"): - result = await explain_query(test_sql, hypothetical_indexes=test_indexes) + result = await explain_query(conn_name="default", sql=test_sql, hypothetical_indexes=test_indexes) # Verify result matches our expected plan data assert isinstance(result, list) @@ -129,7 +129,7 @@ async def test_explain_query_missing_hypopg_integration(): with patch("postgres_mcp.server.get_sql_driver", return_value=mock_safe_driver): # Patch the ExplainPlanTool with patch("postgres_mcp.server.ExplainPlanTool"): - result = await explain_query(test_sql, hypothetical_indexes=test_indexes) + result = await explain_query(conn_name="default", sql=test_sql, hypothetical_indexes=test_indexes) # Verify result assert isinstance(result, list) @@ -152,7 +152,7 @@ async def test_explain_query_error_handling_integration(): "postgres_mcp.server.get_sql_driver", side_effect=Exception(error_message), ): - result = await explain_query("INVALID SQL") + result = await explain_query(conn_name="default", sql="INVALID SQL") # Verify error is correctly formatted assert isinstance(result, list) diff --git a/tests/unit/sql/test_readonly_enforcement.py b/tests/unit/sql/test_readonly_enforcement.py index e079c02..334d5f0 100644 --- a/tests/unit/sql/test_readonly_enforcement.py +++ b/tests/unit/sql/test_readonly_enforcement.py @@ -1,83 +1,83 @@ -from unittest.mock import AsyncMock -from unittest.mock import MagicMock -from unittest.mock import patch - -import pytest - -from postgres_mcp.server import AccessMode -from postgres_mcp.server import get_sql_driver -from postgres_mcp.sql import SafeSqlDriver -from postgres_mcp.sql import SqlDriver - - -@pytest.mark.asyncio -async def test_force_readonly_enforcement(): - """ - Test that force_readonly is properly enforced based on access mode: - - In RESTRICTED mode: force_readonly is always True regardless of what's passed - - In UNRESTRICTED mode: force_readonly respects the passed value (default False) - """ - # Create mock for connection pool - mock_conn_pool = MagicMock() - mock_conn_pool._is_valid = True - - # Create a mock for the base SqlDriver._execute_with_connection - mock_execute = AsyncMock() - mock_execute.return_value = [SqlDriver.RowResult(cells={"test": "value"})] - - # Test UNRESTRICTED mode - with patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), patch( - "postgres_mcp.server.db_connection", mock_conn_pool - ), patch.object(SqlDriver, "_execute_with_connection", mock_execute): - driver = await get_sql_driver() - assert isinstance(driver, SqlDriver) - assert not isinstance(driver, SafeSqlDriver) - - # Test default behavior (should be False) - mock_execute.reset_mock() - await driver.execute_query("SELECT 1") - assert mock_execute.call_count == 1 - # Check that force_readonly is False by default - assert mock_execute.call_args[1]["force_readonly"] is False - - # Test explicit True - mock_execute.reset_mock() - await driver.execute_query("SELECT 1", force_readonly=True) - assert mock_execute.call_count == 1 - # Check that force_readonly=True is respected - assert mock_execute.call_args[1]["force_readonly"] is True - - # Test explicit False - mock_execute.reset_mock() - await driver.execute_query("SELECT 1", force_readonly=False) - assert mock_execute.call_count == 1 - # Check that force_readonly=False is respected - assert mock_execute.call_args[1]["force_readonly"] is False - - # Test RESTRICTED mode - with patch("postgres_mcp.server.current_access_mode", AccessMode.RESTRICTED), patch( - "postgres_mcp.server.db_connection", mock_conn_pool - ), patch.object(SqlDriver, "_execute_with_connection", mock_execute): - driver = await get_sql_driver() - assert isinstance(driver, SafeSqlDriver) - - # Test default behavior - mock_execute.reset_mock() - await driver.execute_query("SELECT 1") - assert mock_execute.call_count == 1 - # Check that force_readonly is always True - assert mock_execute.call_args[1]["force_readonly"] is True - - # Test explicit False (should still be True) - mock_execute.reset_mock() - await driver.execute_query("SELECT 1", force_readonly=False) - assert mock_execute.call_count == 1 - # Check that force_readonly is True despite passing False - assert mock_execute.call_args[1]["force_readonly"] is True - - # Test explicit True - mock_execute.reset_mock() - await driver.execute_query("SELECT 1", force_readonly=True) - assert mock_execute.call_count == 1 - # Check that force_readonly remains True - assert mock_execute.call_args[1]["force_readonly"] is True +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest + +from postgres_mcp.server import AccessMode +from postgres_mcp.server import get_sql_driver +from postgres_mcp.sql import SafeSqlDriver +from postgres_mcp.sql import SqlDriver + + +@pytest.mark.asyncio +async def test_force_readonly_enforcement(): + """ + Test that force_readonly is properly enforced based on access mode: + - In RESTRICTED mode: force_readonly is always True regardless of what's passed + - In UNRESTRICTED mode: force_readonly respects the passed value (default False) + """ + # Create mock for connection pool + mock_conn_pool = MagicMock() + mock_conn_pool._is_valid = True + + # Create a mock for the base SqlDriver._execute_with_connection + mock_execute = AsyncMock() + mock_execute.return_value = [SqlDriver.RowResult(cells={"test": "value"})] + + # Test UNRESTRICTED mode + with patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), patch( + "postgres_mcp.server.connection_registry.get_connection", return_value=mock_conn_pool + ), patch.object(SqlDriver, "_execute_with_connection", mock_execute): + driver = await get_sql_driver(conn_name="default") + assert isinstance(driver, SqlDriver) + assert not isinstance(driver, SafeSqlDriver) + + # Test default behavior (should be False) + mock_execute.reset_mock() + await driver.execute_query("SELECT 1") + assert mock_execute.call_count == 1 + # Check that force_readonly is False by default + assert mock_execute.call_args[1]["force_readonly"] is False + + # Test explicit True + mock_execute.reset_mock() + await driver.execute_query("SELECT 1", force_readonly=True) + assert mock_execute.call_count == 1 + # Check that force_readonly=True is respected + assert mock_execute.call_args[1]["force_readonly"] is True + + # Test explicit False + mock_execute.reset_mock() + await driver.execute_query("SELECT 1", force_readonly=False) + assert mock_execute.call_count == 1 + # Check that force_readonly=False is respected + assert mock_execute.call_args[1]["force_readonly"] is False + + # Test RESTRICTED mode + with patch("postgres_mcp.server.current_access_mode", AccessMode.RESTRICTED), patch( + "postgres_mcp.server.connection_registry.get_connection", return_value=mock_conn_pool + ), patch.object(SqlDriver, "_execute_with_connection", mock_execute): + driver = await get_sql_driver(conn_name="default") + assert isinstance(driver, SafeSqlDriver) + + # Test default behavior + mock_execute.reset_mock() + await driver.execute_query("SELECT 1") + assert mock_execute.call_count == 1 + # Check that force_readonly is always True + assert mock_execute.call_args[1]["force_readonly"] is True + + # Test explicit False (should still be True) + mock_execute.reset_mock() + await driver.execute_query("SELECT 1", force_readonly=False) + assert mock_execute.call_count == 1 + # Check that force_readonly is True despite passing False + assert mock_execute.call_args[1]["force_readonly"] is True + + # Test explicit True + mock_execute.reset_mock() + await driver.execute_query("SELECT 1", force_readonly=True) + assert mock_execute.call_count == 1 + # Check that force_readonly remains True + assert mock_execute.call_args[1]["force_readonly"] is True diff --git a/tests/unit/test_access_mode.py b/tests/unit/test_access_mode.py index f7d3b80..84c2b99 100644 --- a/tests/unit/test_access_mode.py +++ b/tests/unit/test_access_mode.py @@ -1,114 +1,114 @@ -import asyncio -from unittest.mock import AsyncMock -from unittest.mock import MagicMock -from unittest.mock import patch - -import pytest - -from postgres_mcp.server import AccessMode -from postgres_mcp.server import get_sql_driver -from postgres_mcp.sql.safe_sql import SafeSqlDriver -from postgres_mcp.sql.sql_driver import DbConnPool -from postgres_mcp.sql.sql_driver import SqlDriver - - -@pytest.fixture -def mock_db_connection(): - """Mock database connection pool.""" - conn = MagicMock(spec=DbConnPool) - conn.is_valid = True - return conn - - -@pytest.mark.parametrize( - "access_mode,expected_driver_type", - [ - (AccessMode.UNRESTRICTED, SqlDriver), - (AccessMode.RESTRICTED, SafeSqlDriver), - ], -) -@pytest.mark.asyncio -async def test_get_sql_driver_returns_correct_driver(access_mode, expected_driver_type, mock_db_connection): - """Test that get_sql_driver returns the correct driver type based on access mode.""" - with ( - patch("postgres_mcp.server.current_access_mode", access_mode), - patch("postgres_mcp.server.db_connection", mock_db_connection), - ): - driver = await get_sql_driver() - assert isinstance(driver, expected_driver_type) - - # When in RESTRICTED mode, verify timeout is set - if access_mode == AccessMode.RESTRICTED: - assert isinstance(driver, SafeSqlDriver) - assert driver.timeout == 30 - - -@pytest.mark.asyncio -async def test_get_sql_driver_sets_timeout_in_restricted_mode(mock_db_connection): - """Test that get_sql_driver sets the timeout in restricted mode.""" - with ( - patch("postgres_mcp.server.current_access_mode", AccessMode.RESTRICTED), - patch("postgres_mcp.server.db_connection", mock_db_connection), - ): - driver = await get_sql_driver() - assert isinstance(driver, SafeSqlDriver) - assert driver.timeout == 30 - assert hasattr(driver, "sql_driver") - - -@pytest.mark.asyncio -async def test_get_sql_driver_in_unrestricted_mode_no_timeout(mock_db_connection): - """Test that get_sql_driver in unrestricted mode is a regular SqlDriver.""" - with ( - patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), - patch("postgres_mcp.server.db_connection", mock_db_connection), - ): - driver = await get_sql_driver() - assert isinstance(driver, SqlDriver) - assert not hasattr(driver, "timeout") - - -@pytest.mark.asyncio -async def test_command_line_parsing(): - """Test that command-line arguments correctly set the access mode.""" - import sys - - from postgres_mcp.server import main - - # Mock sys.argv and asyncio.run - original_argv = sys.argv - original_run = asyncio.run - - try: - # Test with --access-mode=restricted - sys.argv = [ - "postgres_mcp", - "postgresql://user:password@localhost/db", - "--access-mode=restricted", - ] - asyncio.run = AsyncMock() - - with ( - patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), - patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), - patch("postgres_mcp.server.mcp.run_stdio_async", AsyncMock()), - patch("postgres_mcp.server.shutdown", AsyncMock()), - ): - # Reset the current_access_mode to UNRESTRICTED - import postgres_mcp.server - - postgres_mcp.server.current_access_mode = AccessMode.UNRESTRICTED - - # Run main (partially mocked to avoid actual connection) - try: - await main() - except Exception: - pass - - # Verify the mode was changed to RESTRICTED - assert postgres_mcp.server.current_access_mode == AccessMode.RESTRICTED - - finally: - # Restore original values - sys.argv = original_argv - asyncio.run = original_run +import asyncio +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest + +from postgres_mcp.server import AccessMode +from postgres_mcp.server import get_sql_driver +from postgres_mcp.sql.safe_sql import SafeSqlDriver +from postgres_mcp.sql.sql_driver import DbConnPool +from postgres_mcp.sql.sql_driver import SqlDriver + + +@pytest.fixture +def mock_db_connection(): + """Mock database connection pool.""" + conn = MagicMock(spec=DbConnPool) + conn.is_valid = True + return conn + + +@pytest.mark.parametrize( + "access_mode,expected_driver_type", + [ + (AccessMode.UNRESTRICTED, SqlDriver), + (AccessMode.RESTRICTED, SafeSqlDriver), + ], +) +@pytest.mark.asyncio +async def test_get_sql_driver_returns_correct_driver(access_mode, expected_driver_type, mock_db_connection): + """Test that get_sql_driver returns the correct driver type based on access mode.""" + with ( + patch("postgres_mcp.server.current_access_mode", access_mode), + patch("postgres_mcp.server.connection_registry.get_connection", return_value=mock_db_connection), + ): + driver = await get_sql_driver(conn_name="default") + assert isinstance(driver, expected_driver_type) + + # When in RESTRICTED mode, verify timeout is set + if access_mode == AccessMode.RESTRICTED: + assert isinstance(driver, SafeSqlDriver) + assert driver.timeout == 30 + + +@pytest.mark.asyncio +async def test_get_sql_driver_sets_timeout_in_restricted_mode(mock_db_connection): + """Test that get_sql_driver sets the timeout in restricted mode.""" + with ( + patch("postgres_mcp.server.current_access_mode", AccessMode.RESTRICTED), + patch("postgres_mcp.server.connection_registry.get_connection", return_value=mock_db_connection), + ): + driver = await get_sql_driver(conn_name="default") + assert isinstance(driver, SafeSqlDriver) + assert driver.timeout == 30 + assert hasattr(driver, "sql_driver") + + +@pytest.mark.asyncio +async def test_get_sql_driver_in_unrestricted_mode_no_timeout(mock_db_connection): + """Test that get_sql_driver in unrestricted mode is a regular SqlDriver.""" + with ( + patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), + patch("postgres_mcp.server.connection_registry.get_connection", return_value=mock_db_connection), + ): + driver = await get_sql_driver(conn_name="default") + assert isinstance(driver, SqlDriver) + assert not hasattr(driver, "timeout") + + +@pytest.mark.asyncio +async def test_command_line_parsing(): + """Test that command-line arguments correctly set the access mode.""" + import sys + + from postgres_mcp.server import main + + # Mock sys.argv and asyncio.run + original_argv = sys.argv + original_run = asyncio.run + + try: + # Test with --access-mode=restricted + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + "--access-mode=restricted", + ] + asyncio.run = AsyncMock() + + with ( + patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), + patch("postgres_mcp.server.connection_registry.discover_and_connect", AsyncMock()), + patch("postgres_mcp.server.mcp.run_stdio_async", AsyncMock()), + patch("postgres_mcp.server.shutdown", AsyncMock()), + ): + # Reset the current_access_mode to UNRESTRICTED + import postgres_mcp.server + + postgres_mcp.server.current_access_mode = AccessMode.UNRESTRICTED + + # Run main (partially mocked to avoid actual connection) + try: + await main() + except Exception: + pass + + # Verify the mode was changed to RESTRICTED + assert postgres_mcp.server.current_access_mode == AccessMode.RESTRICTED + + finally: + # Restore original values + sys.argv = original_argv + asyncio.run = original_run From 0b9fdca8223eeca9ffec4430c6767e42cc12bc9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20C=2E=20Andersen?= Date: Sat, 18 Oct 2025 12:27:30 +0200 Subject: [PATCH 2/8] Fixing line endings. --- tests/unit/sql/test_readonly_enforcement.py | 166 +++++++------- tests/unit/test_access_mode.py | 228 ++++++++++---------- 2 files changed, 197 insertions(+), 197 deletions(-) diff --git a/tests/unit/sql/test_readonly_enforcement.py b/tests/unit/sql/test_readonly_enforcement.py index 334d5f0..4f2ba97 100644 --- a/tests/unit/sql/test_readonly_enforcement.py +++ b/tests/unit/sql/test_readonly_enforcement.py @@ -1,83 +1,83 @@ -from unittest.mock import AsyncMock -from unittest.mock import MagicMock -from unittest.mock import patch - -import pytest - -from postgres_mcp.server import AccessMode -from postgres_mcp.server import get_sql_driver -from postgres_mcp.sql import SafeSqlDriver -from postgres_mcp.sql import SqlDriver - - -@pytest.mark.asyncio -async def test_force_readonly_enforcement(): - """ - Test that force_readonly is properly enforced based on access mode: - - In RESTRICTED mode: force_readonly is always True regardless of what's passed - - In UNRESTRICTED mode: force_readonly respects the passed value (default False) - """ - # Create mock for connection pool - mock_conn_pool = MagicMock() - mock_conn_pool._is_valid = True - - # Create a mock for the base SqlDriver._execute_with_connection - mock_execute = AsyncMock() - mock_execute.return_value = [SqlDriver.RowResult(cells={"test": "value"})] - - # Test UNRESTRICTED mode - with patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), patch( - "postgres_mcp.server.connection_registry.get_connection", return_value=mock_conn_pool - ), patch.object(SqlDriver, "_execute_with_connection", mock_execute): - driver = await get_sql_driver(conn_name="default") - assert isinstance(driver, SqlDriver) - assert not isinstance(driver, SafeSqlDriver) - - # Test default behavior (should be False) - mock_execute.reset_mock() - await driver.execute_query("SELECT 1") - assert mock_execute.call_count == 1 - # Check that force_readonly is False by default - assert mock_execute.call_args[1]["force_readonly"] is False - - # Test explicit True - mock_execute.reset_mock() - await driver.execute_query("SELECT 1", force_readonly=True) - assert mock_execute.call_count == 1 - # Check that force_readonly=True is respected - assert mock_execute.call_args[1]["force_readonly"] is True - - # Test explicit False - mock_execute.reset_mock() - await driver.execute_query("SELECT 1", force_readonly=False) - assert mock_execute.call_count == 1 - # Check that force_readonly=False is respected - assert mock_execute.call_args[1]["force_readonly"] is False - - # Test RESTRICTED mode - with patch("postgres_mcp.server.current_access_mode", AccessMode.RESTRICTED), patch( - "postgres_mcp.server.connection_registry.get_connection", return_value=mock_conn_pool - ), patch.object(SqlDriver, "_execute_with_connection", mock_execute): - driver = await get_sql_driver(conn_name="default") - assert isinstance(driver, SafeSqlDriver) - - # Test default behavior - mock_execute.reset_mock() - await driver.execute_query("SELECT 1") - assert mock_execute.call_count == 1 - # Check that force_readonly is always True - assert mock_execute.call_args[1]["force_readonly"] is True - - # Test explicit False (should still be True) - mock_execute.reset_mock() - await driver.execute_query("SELECT 1", force_readonly=False) - assert mock_execute.call_count == 1 - # Check that force_readonly is True despite passing False - assert mock_execute.call_args[1]["force_readonly"] is True - - # Test explicit True - mock_execute.reset_mock() - await driver.execute_query("SELECT 1", force_readonly=True) - assert mock_execute.call_count == 1 - # Check that force_readonly remains True - assert mock_execute.call_args[1]["force_readonly"] is True +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest + +from postgres_mcp.server import AccessMode +from postgres_mcp.server import get_sql_driver +from postgres_mcp.sql import SafeSqlDriver +from postgres_mcp.sql import SqlDriver + + +@pytest.mark.asyncio +async def test_force_readonly_enforcement(): + """ + Test that force_readonly is properly enforced based on access mode: + - In RESTRICTED mode: force_readonly is always True regardless of what's passed + - In UNRESTRICTED mode: force_readonly respects the passed value (default False) + """ + # Create mock for connection pool + mock_conn_pool = MagicMock() + mock_conn_pool._is_valid = True + + # Create a mock for the base SqlDriver._execute_with_connection + mock_execute = AsyncMock() + mock_execute.return_value = [SqlDriver.RowResult(cells={"test": "value"})] + + # Test UNRESTRICTED mode + with patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), patch( + "postgres_mcp.server.connection_registry.get_connection", return_value=mock_conn_pool + ), patch.object(SqlDriver, "_execute_with_connection", mock_execute): + driver = await get_sql_driver(conn_name="default") + assert isinstance(driver, SqlDriver) + assert not isinstance(driver, SafeSqlDriver) + + # Test default behavior (should be False) + mock_execute.reset_mock() + await driver.execute_query("SELECT 1") + assert mock_execute.call_count == 1 + # Check that force_readonly is False by default + assert mock_execute.call_args[1]["force_readonly"] is False + + # Test explicit True + mock_execute.reset_mock() + await driver.execute_query("SELECT 1", force_readonly=True) + assert mock_execute.call_count == 1 + # Check that force_readonly=True is respected + assert mock_execute.call_args[1]["force_readonly"] is True + + # Test explicit False + mock_execute.reset_mock() + await driver.execute_query("SELECT 1", force_readonly=False) + assert mock_execute.call_count == 1 + # Check that force_readonly=False is respected + assert mock_execute.call_args[1]["force_readonly"] is False + + # Test RESTRICTED mode + with patch("postgres_mcp.server.current_access_mode", AccessMode.RESTRICTED), patch( + "postgres_mcp.server.connection_registry.get_connection", return_value=mock_conn_pool + ), patch.object(SqlDriver, "_execute_with_connection", mock_execute): + driver = await get_sql_driver(conn_name="default") + assert isinstance(driver, SafeSqlDriver) + + # Test default behavior + mock_execute.reset_mock() + await driver.execute_query("SELECT 1") + assert mock_execute.call_count == 1 + # Check that force_readonly is always True + assert mock_execute.call_args[1]["force_readonly"] is True + + # Test explicit False (should still be True) + mock_execute.reset_mock() + await driver.execute_query("SELECT 1", force_readonly=False) + assert mock_execute.call_count == 1 + # Check that force_readonly is True despite passing False + assert mock_execute.call_args[1]["force_readonly"] is True + + # Test explicit True + mock_execute.reset_mock() + await driver.execute_query("SELECT 1", force_readonly=True) + assert mock_execute.call_count == 1 + # Check that force_readonly remains True + assert mock_execute.call_args[1]["force_readonly"] is True diff --git a/tests/unit/test_access_mode.py b/tests/unit/test_access_mode.py index 84c2b99..b772e1d 100644 --- a/tests/unit/test_access_mode.py +++ b/tests/unit/test_access_mode.py @@ -1,114 +1,114 @@ -import asyncio -from unittest.mock import AsyncMock -from unittest.mock import MagicMock -from unittest.mock import patch - -import pytest - -from postgres_mcp.server import AccessMode -from postgres_mcp.server import get_sql_driver -from postgres_mcp.sql.safe_sql import SafeSqlDriver -from postgres_mcp.sql.sql_driver import DbConnPool -from postgres_mcp.sql.sql_driver import SqlDriver - - -@pytest.fixture -def mock_db_connection(): - """Mock database connection pool.""" - conn = MagicMock(spec=DbConnPool) - conn.is_valid = True - return conn - - -@pytest.mark.parametrize( - "access_mode,expected_driver_type", - [ - (AccessMode.UNRESTRICTED, SqlDriver), - (AccessMode.RESTRICTED, SafeSqlDriver), - ], -) -@pytest.mark.asyncio -async def test_get_sql_driver_returns_correct_driver(access_mode, expected_driver_type, mock_db_connection): - """Test that get_sql_driver returns the correct driver type based on access mode.""" - with ( - patch("postgres_mcp.server.current_access_mode", access_mode), - patch("postgres_mcp.server.connection_registry.get_connection", return_value=mock_db_connection), - ): - driver = await get_sql_driver(conn_name="default") - assert isinstance(driver, expected_driver_type) - - # When in RESTRICTED mode, verify timeout is set - if access_mode == AccessMode.RESTRICTED: - assert isinstance(driver, SafeSqlDriver) - assert driver.timeout == 30 - - -@pytest.mark.asyncio -async def test_get_sql_driver_sets_timeout_in_restricted_mode(mock_db_connection): - """Test that get_sql_driver sets the timeout in restricted mode.""" - with ( - patch("postgres_mcp.server.current_access_mode", AccessMode.RESTRICTED), - patch("postgres_mcp.server.connection_registry.get_connection", return_value=mock_db_connection), - ): - driver = await get_sql_driver(conn_name="default") - assert isinstance(driver, SafeSqlDriver) - assert driver.timeout == 30 - assert hasattr(driver, "sql_driver") - - -@pytest.mark.asyncio -async def test_get_sql_driver_in_unrestricted_mode_no_timeout(mock_db_connection): - """Test that get_sql_driver in unrestricted mode is a regular SqlDriver.""" - with ( - patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), - patch("postgres_mcp.server.connection_registry.get_connection", return_value=mock_db_connection), - ): - driver = await get_sql_driver(conn_name="default") - assert isinstance(driver, SqlDriver) - assert not hasattr(driver, "timeout") - - -@pytest.mark.asyncio -async def test_command_line_parsing(): - """Test that command-line arguments correctly set the access mode.""" - import sys - - from postgres_mcp.server import main - - # Mock sys.argv and asyncio.run - original_argv = sys.argv - original_run = asyncio.run - - try: - # Test with --access-mode=restricted - sys.argv = [ - "postgres_mcp", - "postgresql://user:password@localhost/db", - "--access-mode=restricted", - ] - asyncio.run = AsyncMock() - - with ( - patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), - patch("postgres_mcp.server.connection_registry.discover_and_connect", AsyncMock()), - patch("postgres_mcp.server.mcp.run_stdio_async", AsyncMock()), - patch("postgres_mcp.server.shutdown", AsyncMock()), - ): - # Reset the current_access_mode to UNRESTRICTED - import postgres_mcp.server - - postgres_mcp.server.current_access_mode = AccessMode.UNRESTRICTED - - # Run main (partially mocked to avoid actual connection) - try: - await main() - except Exception: - pass - - # Verify the mode was changed to RESTRICTED - assert postgres_mcp.server.current_access_mode == AccessMode.RESTRICTED - - finally: - # Restore original values - sys.argv = original_argv - asyncio.run = original_run +import asyncio +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest + +from postgres_mcp.server import AccessMode +from postgres_mcp.server import get_sql_driver +from postgres_mcp.sql.safe_sql import SafeSqlDriver +from postgres_mcp.sql.sql_driver import DbConnPool +from postgres_mcp.sql.sql_driver import SqlDriver + + +@pytest.fixture +def mock_db_connection(): + """Mock database connection pool.""" + conn = MagicMock(spec=DbConnPool) + conn.is_valid = True + return conn + + +@pytest.mark.parametrize( + "access_mode,expected_driver_type", + [ + (AccessMode.UNRESTRICTED, SqlDriver), + (AccessMode.RESTRICTED, SafeSqlDriver), + ], +) +@pytest.mark.asyncio +async def test_get_sql_driver_returns_correct_driver(access_mode, expected_driver_type, mock_db_connection): + """Test that get_sql_driver returns the correct driver type based on access mode.""" + with ( + patch("postgres_mcp.server.current_access_mode", access_mode), + patch("postgres_mcp.server.connection_registry.get_connection", return_value=mock_db_connection), + ): + driver = await get_sql_driver(conn_name="default") + assert isinstance(driver, expected_driver_type) + + # When in RESTRICTED mode, verify timeout is set + if access_mode == AccessMode.RESTRICTED: + assert isinstance(driver, SafeSqlDriver) + assert driver.timeout == 30 + + +@pytest.mark.asyncio +async def test_get_sql_driver_sets_timeout_in_restricted_mode(mock_db_connection): + """Test that get_sql_driver sets the timeout in restricted mode.""" + with ( + patch("postgres_mcp.server.current_access_mode", AccessMode.RESTRICTED), + patch("postgres_mcp.server.connection_registry.get_connection", return_value=mock_db_connection), + ): + driver = await get_sql_driver(conn_name="default") + assert isinstance(driver, SafeSqlDriver) + assert driver.timeout == 30 + assert hasattr(driver, "sql_driver") + + +@pytest.mark.asyncio +async def test_get_sql_driver_in_unrestricted_mode_no_timeout(mock_db_connection): + """Test that get_sql_driver in unrestricted mode is a regular SqlDriver.""" + with ( + patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), + patch("postgres_mcp.server.connection_registry.get_connection", return_value=mock_db_connection), + ): + driver = await get_sql_driver(conn_name="default") + assert isinstance(driver, SqlDriver) + assert not hasattr(driver, "timeout") + + +@pytest.mark.asyncio +async def test_command_line_parsing(): + """Test that command-line arguments correctly set the access mode.""" + import sys + + from postgres_mcp.server import main + + # Mock sys.argv and asyncio.run + original_argv = sys.argv + original_run = asyncio.run + + try: + # Test with --access-mode=restricted + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + "--access-mode=restricted", + ] + asyncio.run = AsyncMock() + + with ( + patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), + patch("postgres_mcp.server.connection_registry.discover_and_connect", AsyncMock()), + patch("postgres_mcp.server.mcp.run_stdio_async", AsyncMock()), + patch("postgres_mcp.server.shutdown", AsyncMock()), + ): + # Reset the current_access_mode to UNRESTRICTED + import postgres_mcp.server + + postgres_mcp.server.current_access_mode = AccessMode.UNRESTRICTED + + # Run main (partially mocked to avoid actual connection) + try: + await main() + except Exception: + pass + + # Verify the mode was changed to RESTRICTED + assert postgres_mcp.server.current_access_mode == AccessMode.RESTRICTED + + finally: + # Restore original values + sys.argv = original_argv + asyncio.run = original_run From a826b78bda2060ae45421d80d6026258c7b26c8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20C=2E=20Andersen?= Date: Sat, 18 Oct 2025 13:09:23 +0200 Subject: [PATCH 3/8] Remove unnecessary connection pool invalidation in execute_query() The psycopg AsyncConnectionPool automatically handles connection recovery by discarding broken connections and creating new ones. Manual invalidation during query execution is redundant and can cause misleading error messages for SQL errors (missing columns, tables, etc.). The _is_valid flag should only be managed during pool lifecycle events (creation in pool_connect() and closure in close()), not during normal query execution. --- src/postgres_mcp/sql/sql_driver.py | 36 +++++++++++------------------- 1 file changed, 13 insertions(+), 23 deletions(-) diff --git a/src/postgres_mcp/sql/sql_driver.py b/src/postgres_mcp/sql/sql_driver.py index 2e61253..f15a69b 100644 --- a/src/postgres_mcp/sql/sql_driver.py +++ b/src/postgres_mcp/sql/sql_driver.py @@ -362,30 +362,20 @@ async def execute_query( Returns: List of RowResult objects or None on error """ - try: + if self.conn is None: + self.connect() if self.conn is None: - self.connect() - if self.conn is None: - raise ValueError("Connection not established") - - # Handle connection pool vs direct connection - if self.is_pool: - # For pools, get a connection from the pool - pool = await self.conn.pool_connect() - async with pool.connection() as connection: - return await self._execute_with_connection(connection, query, params, force_readonly=force_readonly) - else: - # Direct connection approach - return await self._execute_with_connection(self.conn, query, params, force_readonly=force_readonly) - except Exception as e: - # Mark pool as invalid if there was a connection issue - if self.conn and self.is_pool: - self.conn._is_valid = False # type: ignore - self.conn._last_error = str(e) # type: ignore - elif self.conn and not self.is_pool: - self.conn = None - - raise e + raise ValueError("Connection not established") + + # Handle connection pool vs direct connection + if self.is_pool: + # For pools, get a connection from the pool + pool = await self.conn.pool_connect() + async with pool.connection() as connection: + return await self._execute_with_connection(connection, query, params, force_readonly=force_readonly) + else: + # Direct connection approach + return await self._execute_with_connection(self.conn, query, params, force_readonly=force_readonly) async def _execute_with_connection(self, connection, query, params, force_readonly) -> Optional[List[RowResult]]: """Execute query with the given connection.""" From a1269788f5c2b8cf4821be291d0425013f31bcd0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20C=2E=20Andersen?= Date: Sat, 18 Oct 2025 13:23:39 +0200 Subject: [PATCH 4/8] Minor adjustment to README. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index b2b200f..e5e30dd 100644 --- a/README.md +++ b/README.md @@ -245,7 +245,7 @@ Each connection is identified by its name (the part after `DATABASE_URI_`, conve - Automatically displayed in the server context (visible to the AI without requiring a tool call) - Useful for guiding the AI to select the appropriate database -When using tools, you'll specify which connection to use via the `conn_name` parameter: +When using tools, the LLM will specify which connection to use via the `conn_name` parameter: - `list_schemas(conn_name="app")` - Lists schemas in the app database - `explain_query(conn_name="etl", sql="SELECT ...")` - Explains query in the ETL database From e6cc147d94e72e864e0ce0cdc56ab6968b139b24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20C=2E=20Andersen?= Date: Mon, 20 Oct 2025 15:20:20 +0200 Subject: [PATCH 5/8] Fetching configured connection info early. This avoids sneaking the instructions via private variables. (It was even done incorrectly, in my earlier implementation.) --- src/postgres_mcp/env_utils.py | 51 ++++++++++++++++++++++++++++++ src/postgres_mcp/server.py | 37 ++++++++++++++++++++-- src/postgres_mcp/sql/sql_driver.py | 29 +++-------------- 3 files changed, 90 insertions(+), 27 deletions(-) create mode 100644 src/postgres_mcp/env_utils.py diff --git a/src/postgres_mcp/env_utils.py b/src/postgres_mcp/env_utils.py new file mode 100644 index 0000000..88e73c3 --- /dev/null +++ b/src/postgres_mcp/env_utils.py @@ -0,0 +1,51 @@ +"""Utility functions for environment variable handling.""" + +import os + + +def discover_database_connections() -> dict[str, str]: + """ + Discover all DATABASE_URI_* environment variables. + + Returns: + Dict mapping connection names to connection URLs + - DATABASE_URI -> "default" + - DATABASE_URI_APP -> "app" + - DATABASE_URI_ETL -> "etl" + """ + discovered = {} + + for env_var, url in os.environ.items(): + if env_var == "DATABASE_URI": + discovered["default"] = url + elif env_var.startswith("DATABASE_URI_"): + # Extract postfix and lowercase it + postfix = env_var[len("DATABASE_URI_") :] + conn_name = postfix.lower() + discovered[conn_name] = url + + return discovered + + +def discover_database_descriptions() -> dict[str, str]: + """ + Discover all DATABASE_DESC_* environment variables. + + Returns: + Dict mapping connection names to descriptions + - DATABASE_DESC -> "default" + - DATABASE_DESC_APP -> "app" + - DATABASE_DESC_ETL -> "etl" + """ + descriptions = {} + + for env_var, desc in os.environ.items(): + if env_var == "DATABASE_DESC": + descriptions["default"] = desc + elif env_var.startswith("DATABASE_DESC_"): + # Extract postfix and lowercase it + postfix = env_var[len("DATABASE_DESC_") :] + conn_name = postfix.lower() + descriptions[conn_name] = desc + + return descriptions diff --git a/src/postgres_mcp/server.py b/src/postgres_mcp/server.py index 37cccf6..170bc14 100644 --- a/src/postgres_mcp/server.py +++ b/src/postgres_mcp/server.py @@ -33,10 +33,41 @@ from .sql import check_hypopg_installation_status from .sql import obfuscate_password from .top_queries import TopQueriesCalc +from .env_utils import discover_database_connections +from .env_utils import discover_database_descriptions -# Initialize FastMCP with default settings -# Note: Server instructions will be updated after database connections are discovered -mcp = FastMCP("postgres-mcp") + +INSTRUCTIONS_TEMPLATE = """\ +This PostgreSQL MCP Lite server gives (un)restricted DB access via one or more connection strings. + +Available database connections: +{conn_list} +""" + + +def build_instructions() -> str: + """Build server instructions including available connections.""" + # Discover connections from environment variables + conn_urls = discover_database_connections() + conn_descs = discover_database_descriptions() + + # Build connection list + if not conn_urls: + conn_list = "- No connections configured (set DATABASE_URI environment variable)" + else: + conn_items = [] + for name in sorted(conn_urls.keys()): + desc = conn_descs.get(name, "") + if desc: + conn_items.append(f"- '{name}': {desc}") + else: + conn_items.append(f"- '{name}'") + conn_list = "\n".join(conn_items) + + instructions = INSTRUCTIONS_TEMPLATE.format(conn_list=conn_list) + return instructions + +mcp = FastMCP("postgres-mcp", instructions=build_instructions()) # Constants PG_STAT_STATEMENTS = "pg_stat_statements" diff --git a/src/postgres_mcp/sql/sql_driver.py b/src/postgres_mcp/sql/sql_driver.py index f15a69b..61bb6b2 100644 --- a/src/postgres_mcp/sql/sql_driver.py +++ b/src/postgres_mcp/sql/sql_driver.py @@ -16,6 +16,9 @@ from psycopg_pool import AsyncConnectionPool from typing_extensions import LiteralString +from ..env_utils import discover_database_connections +from ..env_utils import discover_database_descriptions + logger = logging.getLogger(__name__) @@ -156,18 +159,7 @@ def discover_connections(self) -> Dict[str, str]: - DATABASE_URI_APP -> "app" - DATABASE_URI_ETL -> "etl" """ - discovered = {} - - for env_var, url in os.environ.items(): - if env_var == "DATABASE_URI": - discovered["default"] = url - elif env_var.startswith("DATABASE_URI_"): - # Extract postfix and lowercase it - postfix = env_var[len("DATABASE_URI_"):] - conn_name = postfix.lower() - discovered[conn_name] = url - - return discovered + return discover_database_connections() def discover_descriptions(self) -> Dict[str, str]: """ @@ -179,18 +171,7 @@ def discover_descriptions(self) -> Dict[str, str]: - DATABASE_DESC_APP -> "app" - DATABASE_DESC_ETL -> "etl" """ - descriptions = {} - - for env_var, desc in os.environ.items(): - if env_var == "DATABASE_DESC": - descriptions["default"] = desc - elif env_var.startswith("DATABASE_DESC_"): - # Extract postfix and lowercase it - postfix = env_var[len("DATABASE_DESC_"):] - conn_name = postfix.lower() - descriptions[conn_name] = desc - - return descriptions + return discover_database_descriptions() async def discover_and_connect(self) -> None: """ From f35856260caf2cf1c7c77bcf5fbdb73609b96ca3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20C=2E=20Andersen?= Date: Mon, 20 Oct 2025 15:47:32 +0200 Subject: [PATCH 6/8] Fixing ruff issues. --- src/postgres_mcp/server.py | 7 +++---- src/postgres_mcp/sql/sql_driver.py | 18 ++++-------------- 2 files changed, 7 insertions(+), 18 deletions(-) diff --git a/src/postgres_mcp/server.py b/src/postgres_mcp/server.py index 170bc14..1d87c4d 100644 --- a/src/postgres_mcp/server.py +++ b/src/postgres_mcp/server.py @@ -22,20 +22,18 @@ from .artifacts import ExplainPlanArtifact from .database_health import DatabaseHealthTool from .database_health import HealthType +from .env_utils import discover_database_connections +from .env_utils import discover_database_descriptions from .explain import ExplainPlanTool from .index.index_opt_base import MAX_NUM_INDEX_TUNING_QUERIES from .index.llm_opt import LLMOptimizerTool from .index.presentation import TextPresentation from .sql import ConnectionRegistry -from .sql import DbConnPool from .sql import SafeSqlDriver from .sql import SqlDriver from .sql import check_hypopg_installation_status from .sql import obfuscate_password from .top_queries import TopQueriesCalc -from .env_utils import discover_database_connections -from .env_utils import discover_database_descriptions - INSTRUCTIONS_TEMPLATE = """\ This PostgreSQL MCP Lite server gives (un)restricted DB access via one or more connection strings. @@ -67,6 +65,7 @@ def build_instructions() -> str: instructions = INSTRUCTIONS_TEMPLATE.format(conn_list=conn_list) return instructions + mcp = FastMCP("postgres-mcp", instructions=build_instructions()) # Constants diff --git a/src/postgres_mcp/sql/sql_driver.py b/src/postgres_mcp/sql/sql_driver.py index 61bb6b2..2d0683a 100644 --- a/src/postgres_mcp/sql/sql_driver.py +++ b/src/postgres_mcp/sql/sql_driver.py @@ -2,7 +2,6 @@ import asyncio import logging -import os import re from dataclasses import dataclass from typing import Any @@ -181,9 +180,7 @@ async def discover_and_connect(self) -> None: discovered = self.discover_connections() if not discovered: - raise ValueError( - "No database connections found. Please set DATABASE_URI or DATABASE_URI_* environment variables." - ) + raise ValueError("No database connections found. Please set DATABASE_URI or DATABASE_URI_* environment variables.") logger.info(f"Discovered {len(discovered)} database connection(s): {', '.join(discovered.keys())}") @@ -207,10 +204,7 @@ async def connect_single(conn_name: str, pool: DbConnPool) -> tuple[str, bool, O return (conn_name, False, error_msg) # Execute all connections in parallel - results = await asyncio.gather( - *[connect_single(name, pool) for name, pool in self.connections.items()], - return_exceptions=False - ) + results = await asyncio.gather(*[connect_single(name, pool) for name, pool in self.connections.items()], return_exceptions=False) # Log results for conn_name, success, error in results: @@ -234,18 +228,14 @@ def get_connection(self, conn_name: str) -> DbConnPool: """ if conn_name not in self.connections: available = ", ".join(f"'{name}'" for name in sorted(self.connections.keys())) - raise ValueError( - f"Connection '{conn_name}' not found. Available connections: {available}" - ) + raise ValueError(f"Connection '{conn_name}' not found. Available connections: {available}") pool = self.connections[conn_name] # Check if connection is valid if not pool.is_valid: error_msg = pool.last_error or "Unknown error" - raise ValueError( - f"Connection '{conn_name}' is not available: {obfuscate_password(error_msg)}" - ) + raise ValueError(f"Connection '{conn_name}' is not available: {obfuscate_password(error_msg)}") return pool From 045b0e4d70909beafdc2761439fde66d361681e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20C=2E=20Andersen?= Date: Mon, 20 Oct 2025 16:00:42 +0200 Subject: [PATCH 7/8] Removing old instruction overwriting. --- src/postgres_mcp/server.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/src/postgres_mcp/server.py b/src/postgres_mcp/server.py index 1d87c4d..646c0a0 100644 --- a/src/postgres_mcp/server.py +++ b/src/postgres_mcp/server.py @@ -620,20 +620,6 @@ async def main(): await connection_registry.discover_and_connect() conn_names = connection_registry.get_connection_names() logger.info(f"Successfully initialized {len(conn_names)} connection(s): {', '.join(conn_names)}") - - # Update server context with available connections - conn_info = connection_registry.get_connection_info() - if conn_info: - instructions = ["Available database connections:"] - for info in conn_info: - if "description" in info: - instructions.append(f"- {info['name']}: {info['description']}") - else: - instructions.append(f"- {info['name']}") - - # Set the server instructions to include connection information - mcp._instructions = "\n".join(instructions) - logger.info(f"Updated server context with {len(conn_info)} connection(s)") except Exception as e: logger.warning( f"Could not initialize database connections: {obfuscate_password(str(e))}", From e641e18d73239bf055c84e90959df6b2d66328af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20C=2E=20Andersen?= Date: Mon, 20 Oct 2025 19:38:48 +0200 Subject: [PATCH 8/8] Adds multiple database connection support Enables the application to connect to multiple databases using the `--db` command-line argument. Validates the database connection names and URLs. Supports setting database connections via environment variables, with environment variables taking precedence over command-line arguments. For backwards compatibility, continues to support a default `DATABASE_URI` from a positional argument. --- src/postgres_mcp/server.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/src/postgres_mcp/server.py b/src/postgres_mcp/server.py index 646c0a0..27fa572 100644 --- a/src/postgres_mcp/server.py +++ b/src/postgres_mcp/server.py @@ -3,6 +3,7 @@ import asyncio import logging import os +import re import signal import sys from enum import Enum @@ -595,6 +596,12 @@ async def main(): default=8000, help="Port for SSE server (default: 8000)", ) + parser.add_argument( + "--db", + action="append", + metavar="NAME=URL", + help="Database connection (can be repeated): --db prod=postgresql://... --db staging=postgresql://...", + ) args = parser.parse_args() @@ -614,7 +621,30 @@ async def main(): # For backwards compatibility, support command-line database_url argument if args.database_url and "DATABASE_URI" not in os.environ: os.environ["DATABASE_URI"] = args.database_url - logger.info("Using command-line database URL as DATABASE_URI") + logger.info("Set default database connection from positional argument") + + # Process --db arguments with validation + if args.db: + for db_spec in args.db: + if "=" not in db_spec: + logger.error(f"Invalid --db format: '{db_spec}'. Expected NAME=URL") + sys.exit(1) + + name, url = db_spec.split("=", 1) + name = name.strip().upper() + + # Validate name contains only alphanumeric and underscore + if not re.match(r"^[A-Z0-9_]+$", name): + logger.error(f"Invalid connection name '{name}'. Only alphanumeric characters and underscores allowed.") + sys.exit(1) + + # Check if already set in environment (env vars take precedence) + env_var = f"DATABASE_URI_{name}" if name != "DEFAULT" else "DATABASE_URI" + if env_var in os.environ: + logger.info(f"Skipping --db {name}=... (already set via {env_var} environment variable)") + else: + os.environ[env_var] = url + logger.info(f"Set database connection '{name.lower()}' from command-line argument") try: await connection_registry.discover_and_connect()