diff --git a/docs/extensions/litestar/api.rst b/docs/extensions/litestar/api.rst index 49f41ee45..d2c157ede 100644 --- a/docs/extensions/litestar/api.rst +++ b/docs/extensions/litestar/api.rst @@ -34,7 +34,8 @@ Configure the plugin via ``extension_config`` in database configuration: "commit_mode": "autocommit", "extra_commit_statuses": {201, 204}, "extra_rollback_statuses": {409}, - "enable_correlation_middleware": True + "enable_correlation_middleware": True, + "correlation_header": "x-correlation-id", } } ) @@ -74,10 +75,22 @@ Configuration Options - ``set[int]`` - ``None`` - Additional HTTP status codes that trigger rollbacks - * - ``enable_correlation_middleware`` + * - ``enable_correlation_middleware`` - ``bool`` - ``True`` - Enable request correlation tracking + * - ``correlation_header`` + - ``str`` + - ``"X-Request-ID"`` + - HTTP header to read when populating the correlation ID middleware + * - ``correlation_headers`` + - ``list[str]`` + - ``[]`` + - Additional headers to consider (auto-detected headers are appended unless disabled) + * - ``auto_trace_headers`` + - ``bool`` + - ``True`` + - Toggle automatic detection of standard tracing headers (`Traceparent`, `X-Cloud-Trace-Context`, etc.) Session Stores ============== diff --git a/docs/guides/architecture/architecture.md b/docs/guides/architecture/architecture.md index 2caa07e36..1dc3bc088 100644 --- a/docs/guides/architecture/architecture.md +++ b/docs/guides/architecture/architecture.md @@ -14,6 +14,7 @@ orphan: true 4. [Driver Implementation](#driver-implementation) 5. [Parameter Handling](#parameter-handling) 6. [Testing & Development](#testing--development) +7. [Observability Runtime](#observability-runtime) --- @@ -308,3 +309,12 @@ make install # Standard development installation 2. Implement the `config.py` and `driver.py` files. 3. Add integration tests for the new adapter. 4. Document any special cases or configurations. + +## Observability Runtime + +The observability subsystem (lifecycle dispatcher, statement observers, span manager, diagnostics) now sits alongside the driver architecture. Refer to the dedicated [Observability Runtime guide](./observability.md) for: + +- configuration sources (`ObservabilityConfig`, adapter overrides, and `driver_features` compatibility), +- the full list of lifecycle events emitted by SQLSpec, +- guidance on statement observers, redaction, and OpenTelemetry spans, +- the Phase 4/5 roadmap for spans + diagnostics. diff --git a/docs/guides/architecture/observability.md b/docs/guides/architecture/observability.md new file mode 100644 index 000000000..168078a01 --- /dev/null +++ b/docs/guides/architecture/observability.md @@ -0,0 +1,136 @@ +# SQLSpec Observability Runtime + +This guide explains how the consolidated observability stack works after the Lifecycle Dispatcher + Statement Observer integration. Use it as the single source of truth when wiring new adapters, features, or docs. + +## Goals + +1. **Unified Hooks** – every pool, connection, session, and query event is emitted through one dispatcher with zero work when no listeners exist. +2. **Structured Statement Events** – observers receive normalized payloads (`StatementEvent`) for printing, logging, or exporting to tracing systems. +3. **Optional OpenTelemetry Spans** – span creation is lazy and never imports `opentelemetry` unless spans are enabled. +4. **Diagnostics** – storage bridge + serializer metrics + lifecycle counters roll up under `SQLSpec.telemetry_snapshot()` (Phase 5). +5. 
**Loader & Migration Telemetry** – SQL file loader, caching, and migration runners emit metrics/spans without additional plumbing (Phase 7). + +## Configuration Sources + +There are three ways to enable observability today: + +1. **Registry-Level** – pass `observability_config=ObservabilityConfig(...)` to `SQLSpec()`. +2. **Adapter Override** – each config constructor accepts `observability_config=` for adapter-specific knobs. +3. **`driver_features` Compatibility** – existing keys such as `"on_connection_create"`, `"on_pool_destroy"`, and `"on_session_start"` are automatically promoted into lifecycle observers, so user-facing APIs do **not** change. + +```python +from sqlspec import SQLSpec +from sqlspec.adapters.duckdb import DuckDBConfig + +def ensure_extensions(connection): + connection.execute("INSTALL http_client; LOAD http_client;") + +config = DuckDBConfig( + pool_config={"database": ":memory:"}, + driver_features={ + "extensions": [{"name": "http_client"}], + "on_connection_create": ensure_extensions, # promoted to observability runtime + }, +) + +sql = SQLSpec(observability_config=ObservabilityConfig(print_sql=True)) +sql.add_config(config) +``` + +> **Implementation note:** During config initialization we inspect `driver_features` for known hook keys and wrap them into `ObservabilityConfig` callbacks. Hooks that accepted a raw resource (e.g., connection) continue to do so without additional adapter plumbing. + +## Lifecycle Events + +The dispatcher exposes the following events (all opt-in and guard-checked): + +| Event | Context contents | +| --- | --- | +| `on_pool_create` / `on_pool_destroy` | `pool`, `config`, `bind_key`, `correlation_id` | +| `on_connection_create` / `on_connection_destroy` | `connection`, plus base context | +| `on_session_start` / `on_session_end` | `session` / driver instance | +| `on_query_start` / `on_query_complete` | SQL text, parameters, metadata | +| `on_error` | `exception` plus last query context | + +`SQLSpec.provide_connection()` and `SQLSpec.provide_session()` now emit these events automatically, regardless of whether the caller uses registry helpers or adapter helpers directly. + +## Statement Observers & Print SQL + +Statement observers receive `StatementEvent` objects. Typical uses: + +* enable `print_sql=True` to attach the built-in logger. +* add custom redaction rules via `RedactionConfig` (mask parameters, mask literals, allow-list names). +* forward events to bespoke loggers or telemetry exporters. + +```python +def log_statement(event: StatementEvent) -> None: + logger.info("%s (%s) -> %ss", event.operation, event.driver, event.duration_s) + +ObservabilityConfig( + print_sql=False, + statement_observers=(log_statement,), + redaction=RedactionConfig(mask_parameters=True, parameter_allow_list=("tenant_id",)), +) +``` + +### Optional Exporters (OpenTelemetry & Prometheus) + +Two helper modules wire optional dependencies into the runtime without forcing unconditional imports: + +* `sqlspec.extensions.otel.enable_tracing()` ensures `opentelemetry-api` is installed, then returns an `ObservabilityConfig` whose `TelemetryConfig` enables spans and (optionally) injects a tracer provider factory. +* `sqlspec.extensions.prometheus.enable_metrics()` ensures `prometheus-client` is installed and appends a `PrometheusStatementObserver` that emits counters and histograms for every `StatementEvent`. + +Both helpers rely on the conditional stubs defined in `sqlspec/typing.py`, so they remain safe to import even when the extras are absent. 
+ +```python +from sqlspec.extensions import otel, prometheus + +config = otel.enable_tracing(resource_attributes={"service.name": "orders-api"}) +config = prometheus.enable_metrics(base_config=config, label_names=("driver", "operation", "adapter")) +sql = SQLSpec(observability_config=config) +``` + +You can also opt in per adapter by passing `extension_config["otel"]` or `extension_config["prometheus"]` when constructing a config; the helpers above are invoked automatically during initialization. + +## Loader & Migration Telemetry + +`SQLSpec` instantiates a dedicated `ObservabilityRuntime` for the SQL file loader and shares it with every migration command/runner. Instrumentation highlights: + +- Loader metrics such as `SQLFileLoader.loader.load.invocations`, `.cache.hit`, `.files.loaded`, `.statements.loaded`, and `.directories.scanned` fire automatically when queries are loaded or cache state is inspected. +- Migration runners publish cache stats (`{Config}.migrations.listing.cache_hit`, `.cache_miss`, `.metadata.cache_hit`), command metrics (`{Config}.migrations.command.upgrade.invocations`, `.downgrade.errors`), and per-migration execution metrics (`{Config}.migrations.upgrade.duration_ms`, `.downgrade.applied`). +- Command and migration spans (`sqlspec.migration.command.upgrade`, `sqlspec.migration.upgrade`) include version numbers, bind keys, and correlation IDs; they end with duration attributes even when exceptions occur. + +All metrics surface through `SQLSpec.telemetry_snapshot()` under the adapter key, so exporters observe a flat counter space regardless of which subsystem produced the events. + +## Span Manager & Diagnostics + +* **Span Manager:** Query spans ship today, lifecycle events emit `sqlspec.lifecycle.*` spans, storage bridge helpers wrap reads/writes with `sqlspec.storage.*` spans, and migration runners create `sqlspec.migration.*` spans for both commands and individual revisions. Mocked span tests live in `tests/unit/test_observability.py`. +* **Diagnostics:** `TelemetryDiagnostics` aggregates lifecycle counters, loader/migration metrics, storage bridge telemetry, and serializer cache stats. Storage telemetry carries backend IDs, bind key, and correlation IDs so snapshots/spans inherit the same context, and `SQLSpec.telemetry_snapshot()` exposes that data via flat counters plus a `storage_bridge.recent_jobs` list detailing the last 25 operations. + +Example snapshot payload: + +``` +{ + "storage_bridge.bytes_written": 2048, + "storage_bridge.recent_jobs": [ + { + "destination": "alias://warehouse/users.parquet", + "backend": "s3", + "bytes_processed": 2048, + "rows_processed": 16, + "config": "AsyncpgConfig", + "bind_key": "analytics", + "correlation_id": "8f64c0f6", + "format": "parquet" + } + ], + "serializer.hits": 12, + "serializer.misses": 2, + "AsyncpgConfig.lifecycle.query_start": 4 +} +``` + +## Next Steps (2025 Q4) + +1. **Exporter Validation:** Exercise the OpenTelemetry/Prometheus helpers against the new loader + migration metrics and document recommended dashboards. +2. **Adapter Audit:** Confirm every adapter’s migration tracker benefits from the instrumentation (especially Oracle/BigQuery fixtures) and extend coverage where needed. +3. **Performance Budgets:** Add guard-path benchmarks/tests to ensure disabled observability remains near-zero overhead now that migration/loader events emit metrics by default. 
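## Appendix: Per-Adapter Exporter Opt-In (Sketch)

The exporters section above notes that `extension_config["otel"]` and `extension_config["prometheus"]` invoke the helpers automatically during config initialization. The sketch below illustrates that path; the option keys inside each entry are assumed to mirror the keyword arguments of `enable_tracing()` / `enable_metrics()` and are placeholders rather than a confirmed schema.

```python
from sqlspec import SQLSpec
from sqlspec.adapters.asyncpg import AsyncpgConfig

# Assumed option keys mirroring otel.enable_tracing() / prometheus.enable_metrics();
# the helpers run automatically when these extension_config entries are present.
config = AsyncpgConfig(
    pool_config={"dsn": "postgresql://localhost/app"},  # placeholder DSN
    extension_config={
        "otel": {"resource_attributes": {"service.name": "orders-api"}},
        "prometheus": {"label_names": ("driver", "operation", "adapter")},
    },
)

sql = SQLSpec()
sql.add_config(config)
```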
diff --git a/docs/guides/extensions/litestar.md b/docs/guides/extensions/litestar.md index da141cb96..5556e297c 100644 --- a/docs/guides/extensions/litestar.md +++ b/docs/guides/extensions/litestar.md @@ -13,7 +13,7 @@ Explains how to wire SQLSpec into Litestar using the official plugin, covering d - Commit strategies: `manual`, `autocommit`, and `autocommit_include_redirect`, configured via `extension_config["litestar"]["commit_mode"]`. - Session storage uses adapter-specific stores built on `BaseSQLSpecStore` (e.g., `AsyncpgStore`, `AiosqliteStore`). - CLI support registers `litestar db ...` commands by including `database_group` in the Litestar CLI app. -- Correlation middleware emits request IDs in query logs (`enable_correlation_middleware=True` by default). +- Correlation middleware emits request IDs in query logs (`enable_correlation_middleware=True` by default). It auto-detects standard tracing headers (`X-Request-ID`, `Traceparent`, `X-Cloud-Trace-Context`, `X-Amzn-Trace-Id`, etc.) unless you override the set via `correlation_header` / `correlation_headers`. ## Installation @@ -95,6 +95,10 @@ config = AsyncpgConfig( } }, ) + +## Correlation IDs + +Enable request-level correlation tracking (on by default) to thread Litestar requests into SQLSpec's observability runtime. The plugin inspects `X-Request-ID`, `Traceparent`, `X-Cloud-Trace-Context`, `X-Amzn-Trace-Id`, `grpc-trace-bin`, and `X-Correlation-ID` automatically, then falls back to generating a UUID if none are present. Override the primary header with `correlation_header`, append more via `correlation_headers`, or set `auto_trace_headers=False` to opt out of the auto-detection list entirely. Observers (print SQL, custom hooks, OpenTelemetry spans) automatically attach the current `correlation_id` to their payloads. Disable the middleware with `enable_correlation_middleware=False` when another piece of infrastructure manages IDs. ``` ## Transaction Management @@ -162,7 +166,7 @@ Commands include `db migrate`, `db upgrade`, `db downgrade`, and `db status`. Th ## Middleware and Observability -- Correlation middleware annotates query logs with request-scoped IDs. Disable by setting `enable_correlation_middleware=False`. +- Correlation middleware annotates query logs with request-scoped IDs. Disable by setting `enable_correlation_middleware=False`, override the primary header via `correlation_header`, add more with `correlation_headers`, or disable auto-detection using `auto_trace_headers=False`. - The plugin enforces graceful shutdown by closing pools during Litestar’s lifespan events. - Combine with Litestar’s `TelemetryConfig` to emit tracing spans around database calls. diff --git a/docs/usage/configuration.rst b/docs/usage/configuration.rst index 5ede2c5d2..a53af5aa8 100644 --- a/docs/usage/configuration.rst +++ b/docs/usage/configuration.rst @@ -497,10 +497,25 @@ Litestar Plugin Configuration "pool_key": "db_pool", "commit_mode": "autocommit", "enable_correlation_middleware": True, + "correlation_header": "x-correlation-id", + "correlation_headers": ["x-custom-trace"], + "auto_trace_headers": True, # Detect Traceparent, X-Cloud-Trace-Context, etc. } } ) +Telemetry Snapshot +~~~~~~~~~~~~~~~~~~ + +Call ``SQLSpec.telemetry_snapshot()`` to inspect lifecycle counters, serializer metrics, and recent storage jobs: + +.. 
code-block:: python + + snapshot = spec.telemetry_snapshot() + print(snapshot["storage_bridge.bytes_written"]) + for job in snapshot.get("storage_bridge.recent_jobs", []): + print(job["destination"], job.get("correlation_id")) + Environment-Based Configuration ------------------------------- diff --git a/docs/usage/framework_integrations.rst b/docs/usage/framework_integrations.rst index 419268f1e..5860fd6f6 100644 --- a/docs/usage/framework_integrations.rst +++ b/docs/usage/framework_integrations.rst @@ -339,13 +339,16 @@ Enable request correlation tracking via ``extension_config``: pool_config={"dsn": "postgresql://..."}, extension_config={ "litestar": { - "enable_correlation_middleware": True # Default: True - } - } - ) - ) - - # Queries will include correlation IDs in logs + "enable_correlation_middleware": True, # Default: True + "correlation_header": "x-request-id", + "correlation_headers": ["x-client-trace"], + "auto_trace_headers": True, + } + } + ) + ) + + # Queries will include correlation IDs in logs (header or generated UUID) # Format: [correlation_id=abc123] SELECT * FROM users FastAPI Integration diff --git a/pyproject.toml b/pyproject.toml index 9ef0d316b..d7e5d4ddc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -179,6 +179,14 @@ include = [ "sqlspec/utils/fixtures.py", # File fixture loading "sqlspec/utils/data_transformation.py", # Data transformation utilities + # === OBSERVABILITY === + "sqlspec/observability/_config.py", + "sqlspec/observability/_diagnostics.py", + "sqlspec/observability/_dispatcher.py", + "sqlspec/observability/_observer.py", + "sqlspec/observability/_runtime.py", + "sqlspec/observability/_spans.py", + # === STORAGE LAYER === "sqlspec/storage/_utils.py", "sqlspec/storage/registry.py", diff --git a/sqlspec/_typing.py b/sqlspec/_typing.py index 23c628820..30ac292de 100644 --- a/sqlspec/_typing.py +++ b/sqlspec/_typing.py @@ -568,6 +568,9 @@ def get_tracer( ) -> Tracer: return Tracer() # type: ignore[abstract] # pragma: no cover + def get_tracer_provider(self) -> Any: # pragma: no cover + return None + TracerProvider = type(None) # Shim for TracerProvider if needed elsewhere StatusCode = type(None) # Shim for StatusCode Status = type(None) # Shim for Status @@ -600,6 +603,8 @@ def __init__( unit: str = "", registry: Any = None, ejemplar_fn: Any = None, + buckets: Any = None, + **_: Any, ) -> None: return None diff --git a/sqlspec/adapters/adbc/config.py b/sqlspec/adapters/adbc/config.py index 22a9dd4f4..5d503d6d8 100644 --- a/sqlspec/adapters/adbc/config.py +++ b/sqlspec/adapters/adbc/config.py @@ -9,7 +9,7 @@ from sqlspec.adapters.adbc._types import AdbcConnection from sqlspec.adapters.adbc.driver import AdbcCursor, AdbcDriver, AdbcExceptionHandler, get_adbc_statement_config -from sqlspec.config import ADKConfig, FastAPIConfig, FlaskConfig, LitestarConfig, NoPoolSyncConfig, StarletteConfig +from sqlspec.config import ExtensionConfigs, NoPoolSyncConfig from sqlspec.core import StatementConfig from sqlspec.exceptions import ImproperConfigurationError from sqlspec.utils.module_loader import import_string @@ -21,6 +21,8 @@ from sqlglot.dialects.dialect import DialectType + from sqlspec.observability import ObservabilityConfig + logger = logging.getLogger("sqlspec.adapters.adbc") @@ -116,7 +118,8 @@ def __init__( statement_config: StatementConfig | None = None, driver_features: "AdbcDriverFeatures | dict[str, Any] | None" = None, bind_key: str | None = None, - extension_config: "dict[str, dict[str, Any]] | LitestarConfig | FastAPIConfig | StarletteConfig 
| FlaskConfig | ADKConfig | None" = None, + extension_config: "ExtensionConfigs | None" = None, + observability_config: "ObservabilityConfig | None" = None, ) -> None: """Initialize configuration. @@ -127,6 +130,7 @@ def __init__( driver_features: Driver feature configuration (AdbcDriverFeatures) bind_key: Optional unique identifier for this configuration extension_config: Extension-specific configuration (e.g., Litestar plugin settings) + observability_config: Adapter-level observability overrides for lifecycle hooks and observers """ if connection_config is None: connection_config = {} @@ -168,6 +172,7 @@ def __init__( driver_features=processed_driver_features, bind_key=bind_key, extension_config=extension_config, + observability_config=observability_config, ) def _resolve_driver_name(self) -> str: @@ -366,9 +371,10 @@ def session_manager() -> "Generator[AdbcDriver, None, None]": or self.statement_config or get_adbc_statement_config(str(self._get_dialect() or "sqlite")) ) - yield self.driver_type( + driver = self.driver_type( connection=connection, statement_config=final_statement_config, driver_features=self.driver_features ) + yield self._prepare_driver(driver) return session_manager() diff --git a/sqlspec/adapters/adbc/driver.py b/sqlspec/adapters/adbc/driver.py index f32f2acdf..2da43c467 100644 --- a/sqlspec/adapters/adbc/driver.py +++ b/sqlspec/adapters/adbc/driver.py @@ -706,8 +706,8 @@ def select_to_storage( self._require_capability("arrow_export_enabled") arrow_result = self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs) sync_pipeline: SyncStoragePipeline = cast("SyncStoragePipeline", self._storage_pipeline()) - telemetry_payload = arrow_result.write_to_storage_sync( - destination, format_hint=format_hint, pipeline=sync_pipeline + telemetry_payload = self._write_result_to_storage_sync( + arrow_result, destination, format_hint=format_hint, pipeline=sync_pipeline ) self._attach_partition_telemetry(telemetry_payload, partitioner) return self._create_storage_job(telemetry_payload, telemetry) diff --git a/sqlspec/adapters/aiosqlite/config.py b/sqlspec/adapters/aiosqlite/config.py index 8bc4aa2db..68befc31e 100644 --- a/sqlspec/adapters/aiosqlite/config.py +++ b/sqlspec/adapters/aiosqlite/config.py @@ -20,13 +20,14 @@ AiosqlitePoolConnection, ) from sqlspec.adapters.sqlite._type_handlers import register_type_handlers -from sqlspec.config import ADKConfig, AsyncDatabaseConfig, FastAPIConfig, FlaskConfig, LitestarConfig, StarletteConfig +from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs from sqlspec.utils.serializers import from_json, to_json if TYPE_CHECKING: from collections.abc import AsyncGenerator, Callable from sqlspec.core import StatementConfig + from sqlspec.observability import ObservabilityConfig __all__ = ("AiosqliteConfig", "AiosqliteConnectionParams", "AiosqliteDriverFeatures", "AiosqlitePoolParams") @@ -94,7 +95,8 @@ def __init__( statement_config: "StatementConfig | None" = None, driver_features: "AiosqliteDriverFeatures | dict[str, Any] | None" = None, bind_key: "str | None" = None, - extension_config: "dict[str, dict[str, Any]] | LitestarConfig | FastAPIConfig | StarletteConfig | FlaskConfig | ADKConfig | None" = None, + extension_config: "ExtensionConfigs | None" = None, + observability_config: "ObservabilityConfig | None" = None, ) -> None: """Initialize AioSQLite configuration. @@ -106,6 +108,7 @@ def __init__( driver_features: Optional driver feature configuration. 
bind_key: Optional unique identifier for this configuration. extension_config: Extension-specific configuration (e.g., Litestar plugin settings) + observability_config: Adapter-level observability overrides for lifecycle hooks and observers """ config_dict = dict(pool_config) if pool_config else {} @@ -142,6 +145,7 @@ def __init__( driver_features=processed_driver_features, bind_key=bind_key, extension_config=extension_config, + observability_config=observability_config, ) def _get_pool_config_dict(self) -> "dict[str, Any]": @@ -206,11 +210,12 @@ async def provide_session( An AiosqliteDriver instance. """ async with self.provide_connection(*_args, **_kwargs) as connection: - yield self.driver_type( + driver = self.driver_type( connection=connection, statement_config=statement_config or self.statement_config, driver_features=self.driver_features, ) + yield self._prepare_driver(driver) async def _create_pool(self) -> AiosqliteConnectionPool: """Create the connection pool instance. diff --git a/sqlspec/adapters/aiosqlite/driver.py b/sqlspec/adapters/aiosqlite/driver.py index 8e118dc9f..3f99807e6 100644 --- a/sqlspec/adapters/aiosqlite/driver.py +++ b/sqlspec/adapters/aiosqlite/driver.py @@ -298,8 +298,8 @@ async def select_to_storage( self._require_capability("arrow_export_enabled") arrow_result = await self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs) async_pipeline: AsyncStoragePipeline = cast("AsyncStoragePipeline", self._storage_pipeline()) - telemetry_payload = await arrow_result.write_to_storage_async( - destination, format_hint=format_hint, pipeline=async_pipeline + telemetry_payload = await self._write_result_to_storage_async( + arrow_result, destination, format_hint=format_hint, pipeline=async_pipeline ) self._attach_partition_telemetry(telemetry_payload, partitioner) return self._create_storage_job(telemetry_payload, telemetry) diff --git a/sqlspec/adapters/asyncmy/config.py b/sqlspec/adapters/asyncmy/config.py index cadbc79d2..c4c51cf06 100644 --- a/sqlspec/adapters/asyncmy/config.py +++ b/sqlspec/adapters/asyncmy/config.py @@ -18,7 +18,7 @@ asyncmy_statement_config, build_asyncmy_statement_config, ) -from sqlspec.config import ADKConfig, AsyncDatabaseConfig, FastAPIConfig, FlaskConfig, LitestarConfig, StarletteConfig +from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs from sqlspec.utils.serializers import from_json, to_json if TYPE_CHECKING: @@ -28,6 +28,7 @@ from asyncmy.pool import Pool # pyright: ignore from sqlspec.core import StatementConfig + from sqlspec.observability import ObservabilityConfig __all__ = ("AsyncmyConfig", "AsyncmyConnectionParams", "AsyncmyDriverFeatures", "AsyncmyPoolParams") @@ -104,7 +105,8 @@ def __init__( statement_config: "StatementConfig | None" = None, driver_features: "AsyncmyDriverFeatures | dict[str, Any] | None" = None, bind_key: "str | None" = None, - extension_config: "dict[str, dict[str, Any]] | LitestarConfig | FastAPIConfig | StarletteConfig | FlaskConfig | ADKConfig | None" = None, + extension_config: "ExtensionConfigs | None" = None, + observability_config: "ObservabilityConfig | None" = None, ) -> None: """Initialize Asyncmy configuration. 
@@ -116,6 +118,7 @@ def __init__( driver_features: Driver feature configuration (TypedDict or dict) bind_key: Optional unique identifier for this configuration extension_config: Extension-specific configuration (e.g., Litestar plugin settings) + observability_config: Adapter-level observability overrides for lifecycle hooks and observers """ processed_pool_config: dict[str, Any] = dict(pool_config) if pool_config else {} if "extra" in processed_pool_config: @@ -141,6 +144,7 @@ def __init__( driver_features=processed_driver_features, bind_key=bind_key, extension_config=extension_config, + observability_config=observability_config, ) async def _create_pool(self) -> "AsyncmyPool": # pyright: ignore @@ -206,9 +210,10 @@ async def provide_session( """ async with self.provide_connection(*args, **kwargs) as connection: final_statement_config = statement_config or self.statement_config or asyncmy_statement_config - yield self.driver_type( + driver = self.driver_type( connection=connection, statement_config=final_statement_config, driver_features=self.driver_features ) + yield self._prepare_driver(driver) async def provide_pool(self, *args: Any, **kwargs: Any) -> "Pool": # pyright: ignore """Provide async pool instance. diff --git a/sqlspec/adapters/asyncmy/driver.py b/sqlspec/adapters/asyncmy/driver.py index aaa3d4461..d3f229255 100644 --- a/sqlspec/adapters/asyncmy/driver.py +++ b/sqlspec/adapters/asyncmy/driver.py @@ -457,8 +457,8 @@ async def select_to_storage( self._require_capability("arrow_export_enabled") arrow_result = await self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs) async_pipeline: AsyncStoragePipeline = cast("AsyncStoragePipeline", self._storage_pipeline()) - telemetry_payload = await arrow_result.write_to_storage_async( - destination, format_hint=format_hint, pipeline=async_pipeline + telemetry_payload = await self._write_result_to_storage_async( + arrow_result, destination, format_hint=format_hint, pipeline=async_pipeline ) self._attach_partition_telemetry(telemetry_payload, partitioner) return self._create_storage_job(telemetry_payload, telemetry) diff --git a/sqlspec/adapters/asyncpg/config.py b/sqlspec/adapters/asyncpg/config.py index 8f784fe58..a9af9dd39 100644 --- a/sqlspec/adapters/asyncpg/config.py +++ b/sqlspec/adapters/asyncpg/config.py @@ -20,7 +20,7 @@ asyncpg_statement_config, build_asyncpg_statement_config, ) -from sqlspec.config import ADKConfig, AsyncDatabaseConfig, FastAPIConfig, FlaskConfig, LitestarConfig, StarletteConfig +from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs from sqlspec.exceptions import ImproperConfigurationError from sqlspec.typing import ALLOYDB_CONNECTOR_INSTALLED, CLOUD_SQL_CONNECTOR_INSTALLED, PGVECTOR_INSTALLED from sqlspec.utils.serializers import from_json, to_json @@ -30,6 +30,7 @@ from collections.abc import AsyncGenerator, Awaitable from sqlspec.core import StatementConfig + from sqlspec.observability import ObservabilityConfig __all__ = ("AsyncpgConfig", "AsyncpgConnectionConfig", "AsyncpgDriverFeatures", "AsyncpgPoolConfig") @@ -154,7 +155,8 @@ def __init__( statement_config: "StatementConfig | None" = None, driver_features: "AsyncpgDriverFeatures | dict[str, Any] | None" = None, bind_key: "str | None" = None, - extension_config: "dict[str, dict[str, Any]] | LitestarConfig | FastAPIConfig | StarletteConfig | FlaskConfig | ADKConfig | None" = None, + extension_config: "ExtensionConfigs | None" = None, + observability_config: "ObservabilityConfig | None" = None, ) -> None: 
"""Initialize AsyncPG configuration. @@ -166,6 +168,7 @@ def __init__( driver_features: Driver features configuration (TypedDict or dict) bind_key: Optional unique identifier for this configuration extension_config: Extension-specific configuration (e.g., Litestar plugin settings) + observability_config: Adapter-level observability overrides for lifecycle hooks and observers """ features_dict: dict[str, Any] = dict(driver_features) if driver_features else {} @@ -188,6 +191,7 @@ def __init__( driver_features=features_dict, bind_key=bind_key, extension_config=extension_config, + observability_config=observability_config, ) self._cloud_sql_connector: Any | None = None @@ -415,9 +419,10 @@ async def provide_session( """ async with self.provide_connection(*args, **kwargs) as connection: final_statement_config = statement_config or self.statement_config or asyncpg_statement_config - yield self.driver_type( + driver = self.driver_type( connection=connection, statement_config=final_statement_config, driver_features=self.driver_features ) + yield self._prepare_driver(driver) async def provide_pool(self, *args: Any, **kwargs: Any) -> "Pool[Record]": """Provide async pool instance. diff --git a/sqlspec/adapters/asyncpg/driver.py b/sqlspec/adapters/asyncpg/driver.py index fd662ba02..263c5755d 100644 --- a/sqlspec/adapters/asyncpg/driver.py +++ b/sqlspec/adapters/asyncpg/driver.py @@ -359,8 +359,8 @@ async def select_to_storage( self._require_capability("arrow_export_enabled") arrow_result = await self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs) async_pipeline: AsyncStoragePipeline = cast("AsyncStoragePipeline", self._storage_pipeline()) - telemetry_payload = await arrow_result.write_to_storage_async( - destination, format_hint=format_hint, pipeline=async_pipeline + telemetry_payload = await self._write_result_to_storage_async( + arrow_result, destination, format_hint=format_hint, pipeline=async_pipeline ) self._attach_partition_telemetry(telemetry_payload, partitioner) return self._create_storage_job(telemetry_payload, telemetry) diff --git a/sqlspec/adapters/bigquery/config.py b/sqlspec/adapters/bigquery/config.py index 1eb67c0bb..0dc0fa45c 100644 --- a/sqlspec/adapters/bigquery/config.py +++ b/sqlspec/adapters/bigquery/config.py @@ -14,8 +14,9 @@ BigQueryExceptionHandler, build_bigquery_statement_config, ) -from sqlspec.config import ADKConfig, FastAPIConfig, FlaskConfig, LitestarConfig, NoPoolSyncConfig, StarletteConfig +from sqlspec.config import ExtensionConfigs, NoPoolSyncConfig from sqlspec.exceptions import ImproperConfigurationError +from sqlspec.observability import ObservabilityConfig from sqlspec.typing import Empty from sqlspec.utils.serializers import to_json @@ -120,7 +121,8 @@ def __init__( statement_config: "StatementConfig | None" = None, driver_features: "BigQueryDriverFeatures | dict[str, Any] | None" = None, bind_key: "str | None" = None, - extension_config: "dict[str, dict[str, Any]] | LitestarConfig | FastAPIConfig | StarletteConfig | FlaskConfig | ADKConfig | None" = None, + extension_config: "ExtensionConfigs | None" = None, + observability_config: "ObservabilityConfig | None" = None, ) -> None: """Initialize BigQuery configuration. 
@@ -131,6 +133,7 @@ def __init__( driver_features: BigQuery-specific driver features bind_key: Optional unique identifier for this configuration extension_config: Extension-specific configuration (e.g., Litestar plugin settings) + observability_config: Adapter-level observability overrides for lifecycle hooks and observers """ self.connection_config: dict[str, Any] = dict(connection_config) if connection_config else {} @@ -138,26 +141,42 @@ def __init__( extras = self.connection_config.pop("extra") self.connection_config.update(extras) - self.driver_features: dict[str, Any] = dict(driver_features) if driver_features else {} - self.driver_features.setdefault("enable_uuid_conversion", True) - serializer = self.driver_features.setdefault("json_serializer", to_json) + processed_driver_features: dict[str, Any] = dict(driver_features) if driver_features else {} + user_connection_hook = processed_driver_features.pop("on_connection_create", None) + processed_driver_features.setdefault("enable_uuid_conversion", True) + serializer = processed_driver_features.setdefault("json_serializer", to_json) - self._connection_instance: BigQueryConnection | None = self.driver_features.get("connection_instance") + self._connection_instance: BigQueryConnection | None = processed_driver_features.get("connection_instance") if "default_query_job_config" not in self.connection_config: self._setup_default_job_config() base_statement_config = statement_config or build_bigquery_statement_config(json_serializer=serializer) + local_observability = observability_config + if user_connection_hook is not None: + + def _wrap_hook(context: dict[str, Any]) -> None: + connection = context.get("connection") + if connection is None: + return + user_connection_hook(connection) + + lifecycle_override = ObservabilityConfig(lifecycle={"on_connection_create": [_wrap_hook]}) + local_observability = ObservabilityConfig.merge(local_observability, lifecycle_override) + super().__init__( connection_config=self.connection_config, migration_config=migration_config, statement_config=base_statement_config, - driver_features=self.driver_features, + driver_features=processed_driver_features, bind_key=bind_key, extension_config=extension_config, + observability_config=local_observability, ) + self.driver_features = processed_driver_features + def _setup_default_job_config(self) -> None: """Set up default job configuration.""" @@ -217,10 +236,6 @@ def create_connection(self) -> BigQueryConnection: if default_load_job_config is not None: self.driver_features["default_load_job_config"] = default_load_job_config - on_connection_create = self.driver_features.get("on_connection_create") - if on_connection_create: - on_connection_create(connection) - self._connection_instance = connection except Exception as e: project = self.connection_config.get("project", "Unknown") @@ -263,7 +278,7 @@ def provide_session( driver = self.driver_type( connection=connection, statement_config=final_statement_config, driver_features=self.driver_features ) - yield driver + yield self._prepare_driver(driver) def get_signature_namespace(self) -> "dict[str, Any]": """Get the signature namespace for BigQuery types. 
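For illustration, the `_wrap_hook` shim above makes the promoted `driver_features` hook equivalent to an explicit `observability_config` override: lifecycle callbacks receive a context dict and fetch the raw client from its `"connection"` key. A minimal sketch, assuming `BigQueryConfig` is exported from `sqlspec.adapters.bigquery` and using a placeholder project ID:

```python
from typing import Any

from sqlspec.adapters.bigquery import BigQueryConfig
from sqlspec.observability import ObservabilityConfig

def announce_client(context: dict[str, Any]) -> None:
    # Lifecycle hooks receive a context dict; the raw BigQuery client sits under "connection".
    connection = context.get("connection")
    if connection is not None:
        print("BigQuery client ready:", connection)

config = BigQueryConfig(
    connection_config={"project": "my-project"},  # placeholder project ID
    observability_config=ObservabilityConfig(lifecycle={"on_connection_create": [announce_client]}),
)
```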
diff --git a/sqlspec/adapters/bigquery/driver.py b/sqlspec/adapters/bigquery/driver.py index 26d84b4f5..465f551ce 100644 --- a/sqlspec/adapters/bigquery/driver.py +++ b/sqlspec/adapters/bigquery/driver.py @@ -812,8 +812,8 @@ def select_to_storage( self._require_capability("arrow_export_enabled") arrow_result = self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs) sync_pipeline: SyncStoragePipeline = cast("SyncStoragePipeline", self._storage_pipeline()) - telemetry_payload = arrow_result.write_to_storage_sync( - destination, format_hint=format_hint, pipeline=sync_pipeline + telemetry_payload = self._write_result_to_storage_sync( + arrow_result, destination, format_hint=format_hint, pipeline=sync_pipeline ) self._attach_partition_telemetry(telemetry_payload, partitioner) return self._create_storage_job(telemetry_payload, telemetry) diff --git a/sqlspec/adapters/duckdb/config.py b/sqlspec/adapters/duckdb/config.py index 111a4057c..d5472ddef 100644 --- a/sqlspec/adapters/duckdb/config.py +++ b/sqlspec/adapters/duckdb/config.py @@ -1,6 +1,6 @@ """DuckDB database configuration with connection pooling.""" -from collections.abc import Sequence +from collections.abc import Callable, Sequence from contextlib import contextmanager from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast @@ -14,7 +14,8 @@ build_duckdb_statement_config, ) from sqlspec.adapters.duckdb.pool import DuckDBConnectionPool -from sqlspec.config import ADKConfig, FastAPIConfig, FlaskConfig, LitestarConfig, StarletteConfig, SyncDatabaseConfig +from sqlspec.config import ExtensionConfigs, SyncDatabaseConfig +from sqlspec.observability import ObservabilityConfig from sqlspec.utils.serializers import to_json if TYPE_CHECKING: @@ -198,7 +199,8 @@ def __init__( statement_config: "StatementConfig | None" = None, driver_features: "DuckDBDriverFeatures | dict[str, Any] | None" = None, bind_key: "str | None" = None, - extension_config: "dict[str, dict[str, Any]] | LitestarConfig | FastAPIConfig | StarletteConfig | FlaskConfig | ADKConfig | None" = None, + extension_config: "ExtensionConfigs | None" = None, + observability_config: "ObservabilityConfig | None" = None, ) -> None: """Initialize DuckDB configuration. 
@@ -211,6 +213,7 @@ def __init__( and enable_uuid_conversion options bind_key: Optional unique identifier for this configuration extension_config: Extension-specific configuration (e.g., Litestar plugin settings) + observability_config: Adapter-level observability overrides for lifecycle hooks and observers """ if pool_config is None: pool_config = {} @@ -220,9 +223,24 @@ def __init__( pool_config["database"] = ":memory:shared_db" processed_features = dict(driver_features) if driver_features else {} + user_connection_hook = cast( + "Callable[[Any], None] | None", processed_features.pop("on_connection_create", None) + ) processed_features.setdefault("enable_uuid_conversion", True) serializer = processed_features.setdefault("json_serializer", to_json) + local_observability = observability_config + if user_connection_hook is not None: + + def _wrap_lifecycle_hook(context: dict[str, Any]) -> None: + connection = context.get("connection") + if connection is None: + return + user_connection_hook(connection) + + lifecycle_override = ObservabilityConfig(lifecycle={"on_connection_create": [_wrap_lifecycle_hook]}) + local_observability = ObservabilityConfig.merge(local_observability, lifecycle_override) + base_statement_config = statement_config or build_duckdb_statement_config( json_serializer=cast("Callable[[Any], str]", serializer) ) @@ -235,6 +253,7 @@ def __init__( statement_config=base_statement_config, driver_features=processed_features, extension_config=extension_config, + observability_config=local_observability, ) def _get_connection_config_dict(self) -> "dict[str, Any]": @@ -252,25 +271,11 @@ def _create_pool(self) -> DuckDBConnectionPool: extensions = self.driver_features.get("extensions", None) secrets = self.driver_features.get("secrets", None) - on_connection_create = self.driver_features.get("on_connection_create", None) - extensions_dicts = [dict(ext) for ext in extensions] if extensions else None secrets_dicts = [dict(secret) for secret in secrets] if secrets else None - pool_callback = None - if on_connection_create: - - def wrapped_callback(conn: DuckDBConnection) -> None: - on_connection_create(conn) - - pool_callback = wrapped_callback - return DuckDBConnectionPool( - connection_config=connection_config, - extensions=extensions_dicts, - secrets=secrets_dicts, - on_connection_create=pool_callback, - **self.pool_config, + connection_config=connection_config, extensions=extensions_dicts, secrets=secrets_dicts, **self.pool_config ) def _close_pool(self) -> None: @@ -333,7 +338,7 @@ def provide_session( statement_config=statement_config or self.statement_config, driver_features=self.driver_features, ) - yield driver + yield self._prepare_driver(driver) def get_signature_namespace(self) -> "dict[str, Any]": """Get the signature namespace for DuckDB types. 
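The DuckDB config follows the same layering: registry-level settings come from `SQLSpec(observability_config=...)`, while the `observability_config` parameter added above attaches adapter-specific observers. A minimal sketch (the observer body is illustrative only):

```python
from sqlspec import SQLSpec
from sqlspec.adapters.duckdb import DuckDBConfig
from sqlspec.observability import ObservabilityConfig

def audit_statement(event) -> None:
    # StatementEvent exposes operation, driver, and duration_s among other fields.
    print(event.operation, event.driver, event.duration_s)

# print_sql applies to every registered config; the statement observer only to this DuckDB config.
sql = SQLSpec(observability_config=ObservabilityConfig(print_sql=True))
sql.add_config(
    DuckDBConfig(
        pool_config={"database": ":memory:"},
        observability_config=ObservabilityConfig(statement_observers=(audit_statement,)),
    )
)
```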
diff --git a/sqlspec/adapters/duckdb/driver.py b/sqlspec/adapters/duckdb/driver.py index f93d22c76..ff13fc0e0 100644 --- a/sqlspec/adapters/duckdb/driver.py +++ b/sqlspec/adapters/duckdb/driver.py @@ -504,8 +504,8 @@ def select_to_storage( self._require_capability("arrow_export_enabled") arrow_result = self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs) sync_pipeline: SyncStoragePipeline = cast("SyncStoragePipeline", self._storage_pipeline()) - telemetry_payload = arrow_result.write_to_storage_sync( - destination, format_hint=format_hint, pipeline=sync_pipeline + telemetry_payload = self._write_result_to_storage_sync( + arrow_result, destination, format_hint=format_hint, pipeline=sync_pipeline ) self._attach_partition_telemetry(telemetry_payload, partitioner) return self._create_storage_job(telemetry_payload, telemetry) diff --git a/sqlspec/adapters/oracledb/config.py b/sqlspec/adapters/oracledb/config.py index 29092507b..3e0a5d1bb 100644 --- a/sqlspec/adapters/oracledb/config.py +++ b/sqlspec/adapters/oracledb/config.py @@ -26,15 +26,7 @@ oracledb_statement_config, ) from sqlspec.adapters.oracledb.migrations import OracleAsyncMigrationTracker, OracleSyncMigrationTracker -from sqlspec.config import ( - ADKConfig, - AsyncDatabaseConfig, - FastAPIConfig, - FlaskConfig, - LitestarConfig, - StarletteConfig, - SyncDatabaseConfig, -) +from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs, SyncDatabaseConfig from sqlspec.typing import NUMPY_INSTALLED if TYPE_CHECKING: @@ -143,7 +135,7 @@ def __init__( statement_config: "StatementConfig | None" = None, driver_features: "OracleDriverFeatures | dict[str, Any] | None" = None, bind_key: "str | None" = None, - extension_config: "dict[str, dict[str, Any]] | LitestarConfig | FastAPIConfig | StarletteConfig | FlaskConfig | ADKConfig | None" = None, + extension_config: "ExtensionConfigs | None" = None, ) -> None: """Initialize Oracle synchronous configuration. @@ -252,11 +244,12 @@ def provide_session( """ _ = (args, kwargs) # Mark as intentionally unused with self.provide_connection() as conn: - yield self.driver_type( + driver = self.driver_type( connection=conn, statement_config=statement_config or self.statement_config, driver_features=self.driver_features, ) + yield self._prepare_driver(driver) def provide_pool(self) -> "OracleSyncConnectionPool": """Provide pool instance. @@ -319,7 +312,7 @@ def __init__( statement_config: "StatementConfig | None" = None, driver_features: "OracleDriverFeatures | dict[str, Any] | None" = None, bind_key: "str | None" = None, - extension_config: "dict[str, dict[str, Any]] | LitestarConfig | FastAPIConfig | StarletteConfig | FlaskConfig | ADKConfig | None" = None, + extension_config: "ExtensionConfigs | None" = None, ) -> None: """Initialize Oracle asynchronous configuration. @@ -431,11 +424,12 @@ async def provide_session( """ _ = (args, kwargs) # Mark as intentionally unused async with self.provide_connection() as conn: - yield self.driver_type( + driver = self.driver_type( connection=conn, statement_config=statement_config or self.statement_config, driver_features=self.driver_features, ) + yield self._prepare_driver(driver) async def provide_pool(self) -> "OracleAsyncConnectionPool": """Provide async pool instance. 
diff --git a/sqlspec/adapters/oracledb/data_dictionary.py b/sqlspec/adapters/oracledb/data_dictionary.py index e8b3734bd..55193081d 100644 --- a/sqlspec/adapters/oracledb/data_dictionary.py +++ b/sqlspec/adapters/oracledb/data_dictionary.py @@ -2,7 +2,6 @@ # cspell:ignore pdbs import re -from contextlib import suppress from typing import TYPE_CHECKING, Any, cast from sqlspec.driver import ( @@ -30,6 +29,14 @@ # Compiled regex patterns ORACLE_VERSION_PATTERN = re.compile(r"Oracle Database (\d+)c?.* Release (\d+)\.(\d+)\.(\d+)") +COMPONENT_VERSION_SQL = ( + "SELECT product || ' Release ' || version AS \"banner\" " + "FROM product_component_version WHERE product LIKE 'Oracle%' " + "ORDER BY version DESC FETCH FIRST 1 ROWS ONLY" +) + +AUTONOMOUS_SERVICE_SQL = "SELECT sys_context('USERENV','CLOUD_SERVICE') AS \"service\" FROM dual" + __all__ = ("OracleAsyncDataDictionary", "OracleSyncDataDictionary", "OracleVersionInfo") @@ -130,6 +137,13 @@ def _get_columns_sql(self, table: str, schema: "str | None" = None) -> str: ORDER BY column_id """ + def _select_version_banner(self, driver: "OracleSyncDriver") -> str: + return str(driver.select_value(COMPONENT_VERSION_SQL)) + + async def _select_version_banner_async(self, driver: "OracleAsyncDriver") -> str: + result = await driver.select_value(COMPONENT_VERSION_SQL) + return str(result) + def _get_oracle_version(self, driver: "OracleAsyncDriver | OracleSyncDriver") -> "OracleVersionInfo | None": """Get Oracle database version information. @@ -139,7 +153,7 @@ def _get_oracle_version(self, driver: "OracleAsyncDriver | OracleSyncDriver") -> Returns: Oracle version information or None if detection fails """ - banner = driver.select_value("SELECT banner AS \"banner\" FROM v$version WHERE banner LIKE 'Oracle%'") + banner = self._select_version_banner(cast("OracleSyncDriver", driver)) # Parse version from banner like "Oracle Database 21c Enterprise Edition Release 21.0.0.0.0 - Production" # or "Oracle Database 19c Standard Edition 2 Release 19.0.0.0.0 - Production" @@ -220,8 +234,17 @@ def _is_oracle_autonomous(self, driver: "OracleSyncDriver") -> bool: Returns: True if this is an Autonomous Database, False otherwise """ - result = driver.select_value_or_none('SELECT COUNT(1) AS "cnt" FROM v$pdbs WHERE cloud_identity IS NOT NULL') - return bool(result and int(result) > 0) + try: + service = driver.select_value_or_none(AUTONOMOUS_SERVICE_SQL) + except Exception: + logger.debug("Unable to detect Oracle cloud service via sys_context") + return False + if service is None: + return False + normalized = str(service).strip().upper() + if not normalized: + return False + return "AUTONOMOUS" in normalized or normalized.startswith(("ATP", "ADW")) def get_version(self, driver: SyncDriverAdapterBase) -> "OracleVersionInfo | None": """Get Oracle database version information. 
@@ -348,9 +371,8 @@ async def get_version(self, driver: AsyncDriverAdapterBase) -> "OracleVersionInf Returns: Oracle version information or None if detection fails """ - banner = await cast("OracleAsyncDriver", driver).select_value( - "SELECT banner AS \"banner\" FROM v$version WHERE banner LIKE 'Oracle%'" - ) + oracle_driver = cast("OracleAsyncDriver", driver) + banner = await self._select_version_banner_async(oracle_driver) version_match = ORACLE_VERSION_PATTERN.search(str(banner)) @@ -369,7 +391,6 @@ async def get_version(self, driver: AsyncDriverAdapterBase) -> "OracleVersionInf version_info = OracleVersionInfo(release_major, minor, patch) # Enhance with additional information - oracle_driver = cast("OracleAsyncDriver", driver) compatible = await self._get_oracle_compatible_async(oracle_driver) is_autonomous = await self._is_oracle_autonomous_async(oracle_driver) @@ -407,15 +428,19 @@ async def _is_oracle_autonomous_async(self, driver: "OracleAsyncDriver") -> bool Returns: True if this is an Autonomous Database, False otherwise """ - # Check for cloud_identity in v$pdbs (most reliable for Autonomous) - with suppress(Exception): - result = await driver.execute('SELECT COUNT(1) AS "cnt" FROM v$pdbs WHERE cloud_identity IS NOT NULL') - if result.data: - count = result.data[0]["cnt"] if isinstance(result.data[0], dict) else result.data[0][0] - if int(count) > 0: - logger.debug("Detected Oracle Autonomous Database via v$pdbs") - return True - + try: + service = await driver.select_value_or_none(AUTONOMOUS_SERVICE_SQL) + except Exception: + logger.debug("Unable to detect Oracle cloud service via sys_context (async)") + return False + if service is None: + return False + normalized = str(service).strip().upper() + if not normalized: + return False + if "AUTONOMOUS" in normalized or normalized.startswith(("ATP", "ADW")): + logger.debug("Detected Oracle Autonomous Database via USERENV context") + return True logger.debug("Oracle Autonomous Database not detected") return False diff --git a/sqlspec/adapters/oracledb/driver.py b/sqlspec/adapters/oracledb/driver.py index 5de971fd7..aa639fa6d 100644 --- a/sqlspec/adapters/oracledb/driver.py +++ b/sqlspec/adapters/oracledb/driver.py @@ -506,8 +506,8 @@ def select_to_storage( self._require_capability("arrow_export_enabled") arrow_result = self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs) sync_pipeline: SyncStoragePipeline = cast("SyncStoragePipeline", self._storage_pipeline()) - telemetry_payload = arrow_result.write_to_storage_sync( - destination, format_hint=format_hint, pipeline=sync_pipeline + telemetry_payload = self._write_result_to_storage_sync( + arrow_result, destination, format_hint=format_hint, pipeline=sync_pipeline ) self._attach_partition_telemetry(telemetry_payload, partitioner) return self._create_storage_job(telemetry_payload, telemetry) @@ -859,8 +859,8 @@ async def select_to_storage( self._require_capability("arrow_export_enabled") arrow_result = await self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs) async_pipeline: AsyncStoragePipeline = cast("AsyncStoragePipeline", self._storage_pipeline()) - telemetry_payload = await arrow_result.write_to_storage_async( - destination, format_hint=format_hint, pipeline=async_pipeline + telemetry_payload = await self._write_result_to_storage_async( + arrow_result, destination, format_hint=format_hint, pipeline=async_pipeline ) self._attach_partition_telemetry(telemetry_payload, partitioner) return 
self._create_storage_job(telemetry_payload, telemetry) diff --git a/sqlspec/adapters/psqlpy/config.py b/sqlspec/adapters/psqlpy/config.py index 629671f4f..790534b58 100644 --- a/sqlspec/adapters/psqlpy/config.py +++ b/sqlspec/adapters/psqlpy/config.py @@ -15,7 +15,7 @@ PsqlpyExceptionHandler, build_psqlpy_statement_config, ) -from sqlspec.config import ADKConfig, AsyncDatabaseConfig, FastAPIConfig, FlaskConfig, LitestarConfig, StarletteConfig +from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs from sqlspec.core import StatementConfig from sqlspec.typing import PGVECTOR_INSTALLED from sqlspec.utils.serializers import to_json @@ -119,7 +119,7 @@ def __init__( statement_config: StatementConfig | None = None, driver_features: "PsqlpyDriverFeatures | dict[str, Any] | None" = None, bind_key: str | None = None, - extension_config: "dict[str, dict[str, Any]] | LitestarConfig | FastAPIConfig | StarletteConfig | FlaskConfig | ADKConfig | None" = None, + extension_config: "ExtensionConfigs | None" = None, ) -> None: """Initialize Psqlpy configuration. @@ -236,11 +236,12 @@ async def provide_session( A PsqlpyDriver instance. """ async with self.provide_connection(*args, **kwargs) as conn: - yield self.driver_type( + driver = self.driver_type( connection=conn, statement_config=statement_config or self.statement_config, driver_features=self.driver_features, ) + yield self._prepare_driver(driver) async def provide_pool(self, *args: Any, **kwargs: Any) -> ConnectionPool: """Provide async pool instance. diff --git a/sqlspec/adapters/psqlpy/driver.py b/sqlspec/adapters/psqlpy/driver.py index 1254a6f02..c72520651 100644 --- a/sqlspec/adapters/psqlpy/driver.py +++ b/sqlspec/adapters/psqlpy/driver.py @@ -513,8 +513,8 @@ async def select_to_storage( self._require_capability("arrow_export_enabled") arrow_result = await self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs) async_pipeline: AsyncStoragePipeline = cast("AsyncStoragePipeline", self._storage_pipeline()) - telemetry_payload = await arrow_result.write_to_storage_async( - destination, format_hint=format_hint, pipeline=async_pipeline + telemetry_payload = await self._write_result_to_storage_async( + arrow_result, destination, format_hint=format_hint, pipeline=async_pipeline ) self._attach_partition_telemetry(telemetry_payload, partitioner) return self._create_storage_job(telemetry_payload, telemetry) diff --git a/sqlspec/adapters/psycopg/config.py b/sqlspec/adapters/psycopg/config.py index 73419f063..d7e8d376c 100644 --- a/sqlspec/adapters/psycopg/config.py +++ b/sqlspec/adapters/psycopg/config.py @@ -21,15 +21,7 @@ build_psycopg_statement_config, psycopg_statement_config, ) -from sqlspec.config import ( - ADKConfig, - AsyncDatabaseConfig, - FastAPIConfig, - FlaskConfig, - LitestarConfig, - StarletteConfig, - SyncDatabaseConfig, -) +from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs, SyncDatabaseConfig from sqlspec.typing import PGVECTOR_INSTALLED from sqlspec.utils.serializers import to_json @@ -127,7 +119,7 @@ def __init__( statement_config: "StatementConfig | None" = None, driver_features: "dict[str, Any] | None" = None, bind_key: "str | None" = None, - extension_config: "dict[str, dict[str, Any]] | LitestarConfig | FastAPIConfig | StarletteConfig | FlaskConfig | ADKConfig | None" = None, + extension_config: "ExtensionConfigs | None" = None, ) -> None: """Initialize Psycopg synchronous configuration. 
@@ -272,9 +264,10 @@ def provide_session( """ with self.provide_connection(*args, **kwargs) as conn: final_statement_config = statement_config or self.statement_config - yield self.driver_type( + driver = self.driver_type( connection=conn, statement_config=final_statement_config, driver_features=self.driver_features ) + yield self._prepare_driver(driver) def provide_pool(self, *args: Any, **kwargs: Any) -> "ConnectionPool": """Provide pool instance. @@ -327,7 +320,7 @@ def __init__( statement_config: "StatementConfig | None" = None, driver_features: "dict[str, Any] | None" = None, bind_key: "str | None" = None, - extension_config: "dict[str, dict[str, Any]] | LitestarConfig | FastAPIConfig | StarletteConfig | FlaskConfig | ADKConfig | None" = None, + extension_config: "ExtensionConfigs | None" = None, ) -> None: """Initialize Psycopg asynchronous configuration. @@ -462,9 +455,10 @@ async def provide_session( """ async with self.provide_connection(*args, **kwargs) as conn: final_statement_config = statement_config or psycopg_statement_config - yield self.driver_type( + driver = self.driver_type( connection=conn, statement_config=final_statement_config, driver_features=self.driver_features ) + yield self._prepare_driver(driver) async def provide_pool(self, *args: Any, **kwargs: Any) -> "AsyncConnectionPool": """Provide async pool instance. diff --git a/sqlspec/adapters/psycopg/driver.py b/sqlspec/adapters/psycopg/driver.py index ee867409d..6d21f35fe 100644 --- a/sqlspec/adapters/psycopg/driver.py +++ b/sqlspec/adapters/psycopg/driver.py @@ -483,8 +483,8 @@ def select_to_storage( self._require_capability("arrow_export_enabled") arrow_result = self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs) sync_pipeline: SyncStoragePipeline = cast("SyncStoragePipeline", self._storage_pipeline()) - telemetry_payload = arrow_result.write_to_storage_sync( - destination, format_hint=format_hint, pipeline=sync_pipeline + telemetry_payload = self._write_result_to_storage_sync( + arrow_result, destination, format_hint=format_hint, pipeline=sync_pipeline ) self._attach_partition_telemetry(telemetry_payload, partitioner) return self._create_storage_job(telemetry_payload, telemetry) @@ -925,8 +925,8 @@ async def select_to_storage( self._require_capability("arrow_export_enabled") arrow_result = await self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs) async_pipeline: AsyncStoragePipeline = cast("AsyncStoragePipeline", self._storage_pipeline()) - telemetry_payload = await arrow_result.write_to_storage_async( - destination, format_hint=format_hint, pipeline=async_pipeline + telemetry_payload = await self._write_result_to_storage_async( + arrow_result, destination, format_hint=format_hint, pipeline=async_pipeline ) self._attach_partition_telemetry(telemetry_payload, partitioner) return self._create_storage_job(telemetry_payload, telemetry) diff --git a/sqlspec/adapters/sqlite/config.py b/sqlspec/adapters/sqlite/config.py index 43afd6552..f48d3969f 100644 --- a/sqlspec/adapters/sqlite/config.py +++ b/sqlspec/adapters/sqlite/config.py @@ -11,7 +11,7 @@ from sqlspec.adapters.sqlite._types import SqliteConnection from sqlspec.adapters.sqlite.driver import SqliteCursor, SqliteDriver, SqliteExceptionHandler, sqlite_statement_config from sqlspec.adapters.sqlite.pool import SqliteConnectionPool -from sqlspec.config import ADKConfig, FastAPIConfig, FlaskConfig, LitestarConfig, StarletteConfig, SyncDatabaseConfig +from sqlspec.config import 
ExtensionConfigs, SyncDatabaseConfig from sqlspec.utils.serializers import from_json, to_json logger = logging.getLogger(__name__) @@ -20,6 +20,7 @@ from collections.abc import Callable, Generator from sqlspec.core import StatementConfig + from sqlspec.observability import ObservabilityConfig class SqliteConnectionParams(TypedDict): @@ -77,7 +78,8 @@ def __init__( statement_config: "StatementConfig | None" = None, driver_features: "SqliteDriverFeatures | dict[str, Any] | None" = None, bind_key: "str | None" = None, - extension_config: "dict[str, dict[str, Any]] | LitestarConfig | FastAPIConfig | StarletteConfig | FlaskConfig | ADKConfig | None" = None, + extension_config: "ExtensionConfigs | None" = None, + observability_config: "ObservabilityConfig | None" = None, ) -> None: """Initialize SQLite configuration. @@ -89,6 +91,7 @@ def __init__( driver_features: Optional driver feature configuration bind_key: Optional bind key for the configuration extension_config: Extension-specific configuration (e.g., Litestar plugin settings) + observability_config: Adapter-level observability overrides for lifecycle hooks and observers """ if pool_config is None: pool_config = {} @@ -125,6 +128,7 @@ def __init__( statement_config=base_statement_config, driver_features=processed_driver_features, extension_config=extension_config, + observability_config=observability_config, ) def _get_connection_config_dict(self) -> "dict[str, Any]": @@ -191,11 +195,12 @@ def provide_session( SqliteDriver: A driver instance with thread-local connection """ with self.provide_connection(*args, **kwargs) as connection: - yield self.driver_type( + driver = self.driver_type( connection=connection, statement_config=statement_config or self.statement_config, driver_features=self.driver_features, ) + yield self._prepare_driver(driver) def get_signature_namespace(self) -> "dict[str, Any]": """Get the signature namespace for SQLite types. 
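SQLite gains the same `observability_config` parameter, which makes it a convenient adapter for exercising redaction locally. A minimal sketch, assuming `RedactionConfig` is exported from `sqlspec.observability` alongside `ObservabilityConfig`:

```python
from sqlspec import SQLSpec
from sqlspec.adapters.sqlite import SqliteConfig
from sqlspec.observability import ObservabilityConfig, RedactionConfig  # RedactionConfig import path assumed

# Mask bound parameters in statement events, except the allow-listed tenant_id.
config = SqliteConfig(
    pool_config={"database": ":memory:"},
    observability_config=ObservabilityConfig(
        print_sql=True,
        redaction=RedactionConfig(mask_parameters=True, parameter_allow_list=("tenant_id",)),
    ),
)

sql = SQLSpec()
sql.add_config(config)
```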
diff --git a/sqlspec/adapters/sqlite/driver.py b/sqlspec/adapters/sqlite/driver.py index 31216e895..ed34adb8d 100644 --- a/sqlspec/adapters/sqlite/driver.py +++ b/sqlspec/adapters/sqlite/driver.py @@ -365,8 +365,8 @@ def select_to_storage( self._require_capability("arrow_export_enabled") arrow_result = self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs) sync_pipeline: SyncStoragePipeline = cast("SyncStoragePipeline", self._storage_pipeline()) - telemetry_payload = arrow_result.write_to_storage_sync( - destination, format_hint=format_hint, pipeline=sync_pipeline + telemetry_payload = self._write_result_to_storage_sync( + arrow_result, destination, format_hint=format_hint, pipeline=sync_pipeline ) self._attach_partition_telemetry(telemetry_payload, partitioner) return self._create_storage_job(telemetry_payload, telemetry) diff --git a/sqlspec/base.py b/sqlspec/base.py index 8b726b9ce..fbb91c2d9 100644 --- a/sqlspec/base.py +++ b/sqlspec/base.py @@ -1,7 +1,8 @@ import asyncio import atexit -from collections.abc import Awaitable, Coroutine -from typing import TYPE_CHECKING, Any, Union, cast, overload +from collections.abc import AsyncIterator, Awaitable, Coroutine, Iterator +from contextlib import AbstractAsyncContextManager, AbstractContextManager, asynccontextmanager, contextmanager +from typing import TYPE_CHECKING, Any, TypeGuard, Union, cast, overload from sqlspec.config import ( AsyncConfigT, @@ -21,15 +22,16 @@ reset_cache_stats, update_cache_config, ) +from sqlspec.loader import SQLFileLoader +from sqlspec.observability import ObservabilityConfig, ObservabilityRuntime, TelemetryDiagnostics +from sqlspec.typing import ConnectionT from sqlspec.utils.logging import get_logger if TYPE_CHECKING: - from contextlib import AbstractAsyncContextManager, AbstractContextManager from pathlib import Path from sqlspec.core import SQL - from sqlspec.loader import SQLFileLoader - from sqlspec.typing import ConnectionT, PoolT + from sqlspec.typing import PoolT __all__ = ("SQLSpec",) @@ -37,16 +39,30 @@ logger = get_logger() +def _is_async_context_manager(obj: Any) -> TypeGuard[AbstractAsyncContextManager[Any]]: + return hasattr(obj, "__aenter__") + + +def _is_sync_context_manager(obj: Any) -> TypeGuard[AbstractContextManager[Any]]: + return hasattr(obj, "__enter__") + + class SQLSpec: """Configuration manager and registry for database connections and pools.""" - __slots__ = ("_configs", "_instance_cache_config", "_sql_loader") + __slots__ = ("_configs", "_instance_cache_config", "_loader_runtime", "_observability_config", "_sql_loader") - def __init__(self, *, loader: "SQLFileLoader | None" = None) -> None: + def __init__( + self, *, loader: "SQLFileLoader | None" = None, observability_config: "ObservabilityConfig | None" = None + ) -> None: self._configs: dict[Any, DatabaseConfigProtocol[Any, Any, Any]] = {} atexit.register(self._cleanup_sync_pools) self._instance_cache_config: CacheConfig | None = None self._sql_loader: SQLFileLoader | None = loader + self._observability_config = observability_config + self._loader_runtime = ObservabilityRuntime(observability_config, config_name="SQLFileLoader") + if self._sql_loader is not None: + self._sql_loader.set_observability_runtime(self._loader_runtime) @staticmethod def _get_config_name(obj: Any) -> str: @@ -129,6 +145,8 @@ def add_config(self, config: "SyncConfigT | AsyncConfigT") -> "type[SyncConfigT config_type = type(config) if config_type in self._configs: logger.debug("Configuration for %s already exists. 
Overwriting.", config_type.__name__) + if hasattr(config, "attach_observability"): + config.attach_observability(self._observability_config) self._configs[config_type] = config return config_type @@ -170,6 +188,30 @@ def configs(self) -> "dict[type, DatabaseConfigProtocol[Any, Any, Any]]": """ return self._configs + def telemetry_snapshot(self) -> "dict[str, Any]": + """Return aggregated diagnostics across all registered configurations.""" + + diagnostics = TelemetryDiagnostics() + loader_metrics = self._loader_runtime.metrics_snapshot() + if loader_metrics: + diagnostics.add_metric_snapshot(loader_metrics) + for config in self._configs.values(): + runtime = config.get_observability_runtime() + diagnostics.add_lifecycle_snapshot(runtime.diagnostics_key, runtime.lifecycle_snapshot()) + metrics_snapshot = runtime.metrics_snapshot() + if metrics_snapshot: + diagnostics.add_metric_snapshot(metrics_snapshot) + return diagnostics.snapshot() + + def _ensure_sql_loader(self) -> SQLFileLoader: + """Return a SQLFileLoader instance configured with observability runtime.""" + + if self._sql_loader is None: + self._sql_loader = SQLFileLoader(runtime=self._loader_runtime) + else: + self._sql_loader.set_observability_runtime(self._loader_runtime) + return self._sql_loader + @overload def get_connection( self, @@ -281,25 +323,19 @@ def get_session( async def _create_driver_async() -> "DriverT": resolved_connection = await connection_obj # pyright: ignore - return cast( # pyright: ignore - "DriverT", - config.driver_type( - connection=resolved_connection, - statement_config=config.statement_config, - driver_features=config.driver_features, - ), + driver = config.driver_type( # pyright: ignore + connection=resolved_connection, + statement_config=config.statement_config, + driver_features=config.driver_features, ) + return config._prepare_driver(driver) # pyright: ignore return _create_driver_async() - return cast( # pyright: ignore - "DriverT", - config.driver_type( - connection=connection_obj, - statement_config=config.statement_config, - driver_features=config.driver_features, - ), + driver = config.driver_type( # pyright: ignore + connection=connection_obj, statement_config=config.statement_config, driver_features=config.driver_features ) + return config._prepare_driver(driver) # pyright: ignore @overload def provide_connection( @@ -360,7 +396,41 @@ def provide_connection( config_name = self._get_config_name(name) logger.debug("Providing connection context for config: %s", config_name, extra={"config_type": config_name}) - return config.provide_connection(*args, **kwargs) + connection_context = config.provide_connection(*args, **kwargs) + runtime = config.get_observability_runtime() + + if _is_async_context_manager(connection_context): + async_context = cast("AbstractAsyncContextManager[ConnectionT]", connection_context) + + @asynccontextmanager + async def _async_wrapper() -> AsyncIterator[ConnectionT]: + connection: ConnectionT | None = None + try: + async with async_context as conn: + connection = conn + runtime.emit_connection_create(conn) + yield conn + finally: + if connection is not None: + runtime.emit_connection_destroy(connection) + + return _async_wrapper() + + sync_context = cast("AbstractContextManager[ConnectionT]", connection_context) + + @contextmanager + def _sync_wrapper() -> Iterator[ConnectionT]: + connection: ConnectionT | None = None + try: + with sync_context as conn: + connection = conn + runtime.emit_connection_create(conn) + yield conn + finally: + if connection is not None: + 
runtime.emit_connection_destroy(connection) + + return _sync_wrapper() @overload def provide_session( @@ -421,7 +491,53 @@ def provide_session( config_name = self._get_config_name(name) logger.debug("Providing session context for config: %s", config_name, extra={"config_type": config_name}) - return config.provide_session(*args, **kwargs) + session_context = config.provide_session(*args, **kwargs) + runtime = config.get_observability_runtime() + + if _is_async_context_manager(session_context): + async_session = cast("AbstractAsyncContextManager[DriverT]", session_context) + + @asynccontextmanager + async def _async_session_wrapper() -> AsyncIterator[DriverT]: + driver: DriverT | None = None + try: + async with async_session as session: + driver = config._prepare_driver(session) # pyright: ignore + connection = getattr(driver, "connection", None) + if connection is not None: + runtime.emit_connection_create(connection) + runtime.emit_session_start(driver) + yield driver + finally: + if driver is not None: + runtime.emit_session_end(driver) + connection = getattr(driver, "connection", None) + if connection is not None: + runtime.emit_connection_destroy(connection) + + return _async_session_wrapper() + + sync_session = cast("AbstractContextManager[DriverT]", session_context) + + @contextmanager + def _sync_session_wrapper() -> Iterator[DriverT]: + driver: DriverT | None = None + try: + with sync_session as session: + driver = config._prepare_driver(session) # pyright: ignore + connection = getattr(driver, "connection", None) + if connection is not None: + runtime.emit_connection_create(connection) + runtime.emit_session_start(driver) + yield driver + finally: + if driver is not None: + runtime.emit_session_end(driver) + connection = getattr(driver, "connection", None) + if connection is not None: + runtime.emit_connection_destroy(connection) + + return _sync_session_wrapper() @overload def get_pool( @@ -616,12 +732,8 @@ def load_sql_files(self, *paths: "str | Path") -> None: Args: *paths: One or more file paths or directory paths to load. """ - if self._sql_loader is None: - from sqlspec.loader import SQLFileLoader - - self._sql_loader = SQLFileLoader() - - self._sql_loader.load_sql(*paths) + loader = self._ensure_sql_loader() + loader.load_sql(*paths) logger.debug("Loaded SQL files: %s", paths) def add_named_sql(self, name: str, sql: str, dialect: "str | None" = None) -> None: @@ -632,12 +744,8 @@ def add_named_sql(self, name: str, sql: str, dialect: "str | None" = None) -> No sql: Raw SQL content. dialect: Optional dialect for the SQL statement. """ - if self._sql_loader is None: - from sqlspec.loader import SQLFileLoader - - self._sql_loader = SQLFileLoader() - - self._sql_loader.add_named_sql(name, sql, dialect) + loader = self._ensure_sql_loader() + loader.add_named_sql(name, sql, dialect) logger.debug("Added named SQL: %s", name) def get_sql(self, name: str) -> "SQL": @@ -650,12 +758,8 @@ def get_sql(self, name: str) -> "SQL": Returns: SQL object ready for execution. """ - if self._sql_loader is None: - from sqlspec.loader import SQLFileLoader - - self._sql_loader = SQLFileLoader() - - return self._sql_loader.get_sql(name) + loader = self._ensure_sql_loader() + return loader.get_sql(name) def list_sql_queries(self) -> "list[str]": """List all available query names. 
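Reviewer note: a minimal sketch of the registry-level aggregation added to `sqlspec/base.py` above. The `SqliteConfig` name and import path are assumptions; the snapshot's exact shape comes from `TelemetryDiagnostics.snapshot()` and is not spelled out here.

```python
from sqlspec import SQLSpec
from sqlspec.adapters.sqlite import SqliteConfig  # assumed import path for the adapter patched above

sql = SQLSpec()
sql.add_config(SqliteConfig(pool_config={"database": ":memory:"}))  # assumed pool key

# Rolls up SQLFileLoader metrics plus each config's lifecycle counters and
# metric snapshots; returns a plain dict even when no observability_config was supplied.
print(sql.telemetry_snapshot())
```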
diff --git a/sqlspec/config.py b/sqlspec/config.py index b5728c814..f7d2447ec 100644 --- a/sqlspec/config.py +++ b/sqlspec/config.py @@ -1,13 +1,15 @@ from abc import ABC, abstractmethod from collections.abc import Callable +from inspect import Signature, signature from pathlib import Path -from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeVar, cast +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeAlias, TypeVar, cast from typing_extensions import NotRequired, TypedDict from sqlspec.core import ParameterStyle, ParameterStyleConfig, StatementConfig from sqlspec.exceptions import MissingDependencyError from sqlspec.migrations.tracker import AsyncMigrationTracker, SyncMigrationTracker +from sqlspec.observability import ObservabilityConfig from sqlspec.utils.logging import get_logger from sqlspec.utils.module_loader import ensure_pyarrow @@ -18,6 +20,7 @@ from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase from sqlspec.loader import SQLFileLoader from sqlspec.migrations.commands import AsyncMigrationCommands, SyncMigrationCommands + from sqlspec.observability import ObservabilityRuntime from sqlspec.storage import StorageCapabilities @@ -28,6 +31,7 @@ "ConfigT", "DatabaseConfigProtocol", "DriverT", + "ExtensionConfigs", "FastAPIConfig", "FlaskConfig", "LifecycleConfig", @@ -35,6 +39,8 @@ "MigrationConfig", "NoPoolAsyncConfig", "NoPoolSyncConfig", + "OpenTelemetryConfig", + "PrometheusConfig", "StarletteConfig", "SyncConfigT", "SyncDatabaseConfig", @@ -53,6 +59,15 @@ logger = get_logger("config") +DRIVER_FEATURE_LIFECYCLE_HOOKS: dict[str, str | None] = { + "on_connection_create": "connection", + "on_connection_destroy": "connection", + "on_pool_create": "pool", + "on_pool_destroy": "pool", + "on_session_start": "session", + "on_session_end": "session", +} + class LifecycleConfig(TypedDict): """Lifecycle hooks for database adapters. @@ -173,6 +188,9 @@ class LitestarConfig(TypedDict): enable_correlation_middleware: NotRequired[bool] """Enable request correlation ID middleware. Default: True""" + correlation_header: NotRequired[str] + """HTTP header to read the request correlation ID from when middleware is enabled. Default: ``X-Request-ID``""" + extra_commit_statuses: NotRequired[set[int]] """Additional HTTP status codes that trigger commit. Default: set()""" @@ -386,16 +404,79 @@ class ADKConfig(TypedDict): """ +class OpenTelemetryConfig(TypedDict): + """Configuration options for OpenTelemetry integration. + + Use in ``extension_config["otel"]``. + """ + + enabled: NotRequired[bool] + """Enable the extension. Default: True.""" + + enable_spans: NotRequired[bool] + """Enable span emission (set False to disable while keeping other settings).""" + + resource_attributes: NotRequired[dict[str, Any]] + """Additional resource attributes passed to the tracer provider factory.""" + + tracer_provider: NotRequired[Any] + """Tracer provider instance to reuse. Mutually exclusive with ``tracer_provider_factory``.""" + + tracer_provider_factory: NotRequired[Callable[[], Any]] + """Factory returning a tracer provider. Invoked lazily when spans are needed.""" + + +class PrometheusConfig(TypedDict): + """Configuration options for Prometheus metrics. + + Use in ``extension_config["prometheus"]``. + """ + + enabled: NotRequired[bool] + """Enable the extension. Default: True.""" + + namespace: NotRequired[str] + """Prometheus metric namespace. Default: ``"sqlspec"``.""" + + subsystem: NotRequired[str] + """Prometheus metric subsystem. 
Default: ``"driver"``.""" + + registry: NotRequired[Any] + """Custom Prometheus registry (defaults to the global registry).""" + + label_names: NotRequired[tuple[str, ...]] + """Labels applied to metrics. Default: ("driver", "operation").""" + + duration_buckets: NotRequired[tuple[float, ...]] + """Histogram buckets for query duration (seconds).""" + + +ExtensionConfigs: TypeAlias = dict[ + str, + dict[str, Any] + | LitestarConfig + | FastAPIConfig + | StarletteConfig + | FlaskConfig + | ADKConfig + | OpenTelemetryConfig + | PrometheusConfig, +] + + class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]): """Protocol defining the interface for database configurations.""" __slots__ = ( "_migration_commands", "_migration_loader", + "_observability_runtime", "_storage_capabilities", "bind_key", "driver_features", + "extension_config", "migration_config", + "observability_config", "pool_instance", "statement_config", ) @@ -419,8 +500,11 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]): statement_config: "StatementConfig" pool_instance: "PoolT | None" migration_config: "dict[str, Any] | MigrationConfig" + extension_config: "ExtensionConfigs" driver_features: "dict[str, Any]" _storage_capabilities: "StorageCapabilities | None" + observability_config: "ObservabilityConfig | None" + _observability_runtime: "ObservabilityRuntime | None" def __hash__(self) -> int: return id(self) @@ -466,6 +550,126 @@ def _build_storage_capabilities(self) -> "StorageCapabilities": capabilities["default_storage_profile"] = self.default_storage_profile return capabilities + def _init_observability(self, observability_config: "ObservabilityConfig | None" = None) -> None: + """Initialize observability attributes for the configuration.""" + + self.observability_config = observability_config + self._observability_runtime = None + + def _configure_observability_extensions(self) -> None: + """Apply extension_config hooks (otel/prometheus) to ObservabilityConfig.""" + + config_map = cast("dict[str, Any]", self.extension_config) + if not config_map: + return + updated = self.observability_config + + otel_config = cast("OpenTelemetryConfig | None", config_map.get("otel")) + if otel_config and otel_config.get("enabled", True): + from sqlspec.extensions import otel as otel_extension + + updated = otel_extension.enable_tracing( + base_config=updated, + resource_attributes=otel_config.get("resource_attributes"), + tracer_provider=otel_config.get("tracer_provider"), + tracer_provider_factory=otel_config.get("tracer_provider_factory"), + enable_spans=otel_config.get("enable_spans", True), + ) + + prom_config = cast("PrometheusConfig | None", config_map.get("prometheus")) + if prom_config and prom_config.get("enabled", True): + from sqlspec.extensions import prometheus as prometheus_extension + + label_names = tuple(prom_config.get("label_names", ("driver", "operation"))) + duration_buckets = prom_config.get("duration_buckets") + if duration_buckets is not None: + duration_buckets = tuple(duration_buckets) + + updated = prometheus_extension.enable_metrics( + base_config=updated, + namespace=prom_config.get("namespace", "sqlspec"), + subsystem=prom_config.get("subsystem", "driver"), + registry=prom_config.get("registry"), + label_names=label_names, + duration_buckets=duration_buckets, + ) + + if updated is not self.observability_config: + self.observability_config = updated + + def _promote_driver_feature_hooks(self) -> None: + lifecycle_hooks: dict[str, list[Callable[[dict[str, Any]], None]]] 
= {} + + for hook_name, context_key in DRIVER_FEATURE_LIFECYCLE_HOOKS.items(): + callback = self.driver_features.pop(hook_name, None) + if callback is None: + continue + callbacks = callback if isinstance(callback, (list, tuple)) else (callback,) + wrapped_callbacks = [self._wrap_driver_feature_hook(cb, context_key) for cb in callbacks] + lifecycle_hooks.setdefault(hook_name, []).extend(wrapped_callbacks) + + if not lifecycle_hooks: + return + + lifecycle_config = cast("LifecycleConfig", lifecycle_hooks) + override = ObservabilityConfig(lifecycle=lifecycle_config) + if self.observability_config is None: + self.observability_config = override + else: + self.observability_config = ObservabilityConfig.merge(self.observability_config, override) + + @staticmethod + def _wrap_driver_feature_hook( + callback: Callable[..., Any], context_key: str | None + ) -> Callable[[dict[str, Any]], None]: + try: + hook_signature: Signature = signature(callback) + except (TypeError, ValueError): # pragma: no cover - builtins without signatures + hook_signature = Signature() + + positional_params = [ + param + for param in hook_signature.parameters.values() + if param.kind in {param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD} and param.default is param.empty + ] + expects_argument = bool(positional_params) + + def handler(context: dict[str, Any]) -> None: + if not expects_argument: + callback() + return + if context_key is None: + callback(context) + return + callback(context.get(context_key)) + + return handler + + def attach_observability(self, registry_config: "ObservabilityConfig | None") -> None: + """Attach merged observability runtime composed from registry and adapter overrides.""" + + from sqlspec.observability import ObservabilityConfig as ObservabilityConfigImpl + from sqlspec.observability import ObservabilityRuntime + + merged = ObservabilityConfigImpl.merge(registry_config, self.observability_config) + self._observability_runtime = ObservabilityRuntime( + merged, bind_key=self.bind_key, config_name=type(self).__name__ + ) + + def get_observability_runtime(self) -> "ObservabilityRuntime": + """Return the attached runtime, creating a disabled instance when missing.""" + + if self._observability_runtime is None: + self.attach_observability(None) + assert self._observability_runtime is not None + return self._observability_runtime + + def _prepare_driver(self, driver: DriverT) -> DriverT: + """Attach observability runtime to driver instances before returning them.""" + + driver.attach_observability(self.get_observability_runtime()) + return driver + @staticmethod def _dependency_available(checker: "Callable[[], None]") -> bool: try: @@ -532,7 +736,8 @@ def _initialize_migration_components(self) -> None: from sqlspec.loader import SQLFileLoader from sqlspec.migrations import create_migration_commands - self._migration_loader = SQLFileLoader() + runtime = self.get_observability_runtime() + self._migration_loader = SQLFileLoader(runtime=runtime) self._migration_commands = create_migration_commands(self) # pyright: ignore def _ensure_migration_loader(self) -> "SQLFileLoader": @@ -680,7 +885,7 @@ def fix_migrations( class NoPoolSyncConfig(DatabaseConfigProtocol[ConnectionT, None, DriverT]): """Base class for sync database configurations that do not implement a pool.""" - __slots__ = ("connection_config", "extension_config") + __slots__ = ("connection_config",) is_async: "ClassVar[bool]" = False supports_connection_pooling: "ClassVar[bool]" = False migration_tracker_type: "ClassVar[type[Any]]" = 
SyncMigrationTracker @@ -693,13 +898,15 @@ def __init__( statement_config: "StatementConfig | None" = None, driver_features: "dict[str, Any] | None" = None, bind_key: "str | None" = None, - extension_config: "dict[str, dict[str, Any]] | LitestarConfig | FastAPIConfig | StarletteConfig | FlaskConfig | ADKConfig | None" = None, + extension_config: "ExtensionConfigs | None" = None, + observability_config: "ObservabilityConfig | None" = None, ) -> None: self.bind_key = bind_key self.pool_instance = None self.connection_config = connection_config or {} - self.extension_config: dict[str, dict[str, Any]] = cast("dict[str, Any]", extension_config or {}) + self.extension_config = extension_config or {} self.migration_config: dict[str, Any] | MigrationConfig = migration_config or {} + self._init_observability(observability_config) self._initialize_migration_components() if statement_config is None: @@ -712,6 +919,8 @@ def __init__( self.driver_features = driver_features or {} self._storage_capabilities = None self.driver_features.setdefault("storage_capabilities", self.storage_capabilities()) + self._promote_driver_feature_hooks() + self._configure_observability_extensions() def create_connection(self) -> ConnectionT: """Create a database connection.""" @@ -821,7 +1030,7 @@ def fix_migrations(self, dry_run: bool = False, update_database: bool = True, ye class NoPoolAsyncConfig(DatabaseConfigProtocol[ConnectionT, None, DriverT]): """Base class for async database configurations that do not implement a pool.""" - __slots__ = ("connection_config", "extension_config") + __slots__ = ("connection_config",) is_async: "ClassVar[bool]" = True supports_connection_pooling: "ClassVar[bool]" = False migration_tracker_type: "ClassVar[type[Any]]" = AsyncMigrationTracker @@ -834,13 +1043,15 @@ def __init__( statement_config: "StatementConfig | None" = None, driver_features: "dict[str, Any] | None" = None, bind_key: "str | None" = None, - extension_config: "dict[str, dict[str, Any]] | LitestarConfig | FastAPIConfig | StarletteConfig | FlaskConfig | ADKConfig | None" = None, + extension_config: "ExtensionConfigs | None" = None, + observability_config: "ObservabilityConfig | None" = None, ) -> None: self.bind_key = bind_key self.pool_instance = None self.connection_config = connection_config or {} - self.extension_config: dict[str, dict[str, Any]] = cast("dict[str, Any]", extension_config or {}) + self.extension_config = extension_config or {} self.migration_config: dict[str, Any] | MigrationConfig = migration_config or {} + self._init_observability(observability_config) self._initialize_migration_components() if statement_config is None: @@ -851,6 +1062,8 @@ def __init__( else: self.statement_config = statement_config self.driver_features = driver_features or {} + self._promote_driver_feature_hooks() + self._configure_observability_extensions() async def create_connection(self) -> ConnectionT: """Create a database connection.""" @@ -960,7 +1173,7 @@ async def fix_migrations(self, dry_run: bool = False, update_database: bool = Tr class SyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]): """Base class for sync database configurations with connection pooling.""" - __slots__ = ("extension_config", "pool_config") + __slots__ = ("pool_config",) is_async: "ClassVar[bool]" = False supports_connection_pooling: "ClassVar[bool]" = True migration_tracker_type: "ClassVar[type[Any]]" = SyncMigrationTracker @@ -974,13 +1187,15 @@ def __init__( statement_config: "StatementConfig | None" = None, driver_features: 
"dict[str, Any] | None" = None, bind_key: "str | None" = None, - extension_config: "dict[str, dict[str, Any]] | LitestarConfig | FastAPIConfig | StarletteConfig | FlaskConfig | ADKConfig | None" = None, + extension_config: "ExtensionConfigs | None" = None, + observability_config: "ObservabilityConfig | None" = None, ) -> None: self.bind_key = bind_key self.pool_instance = pool_instance self.pool_config = pool_config or {} - self.extension_config: dict[str, dict[str, Any]] = cast("dict[str, dict[str, Any]]", extension_config or {}) + self.extension_config = extension_config or {} self.migration_config: dict[str, Any] | MigrationConfig = migration_config or {} + self._init_observability(observability_config) self._initialize_migration_components() if statement_config is None: @@ -1003,11 +1218,16 @@ def create_pool(self) -> PoolT: if self.pool_instance is not None: return self.pool_instance self.pool_instance = self._create_pool() + self.get_observability_runtime().emit_pool_create(self.pool_instance) return self.pool_instance def close_pool(self) -> None: """Close the connection pool.""" + pool = self.pool_instance self._close_pool() + if pool is not None: + self.get_observability_runtime().emit_pool_destroy(pool) + self.pool_instance = None def provide_pool(self, *args: Any, **kwargs: Any) -> PoolT: """Provide pool instance.""" @@ -1124,7 +1344,7 @@ def fix_migrations(self, dry_run: bool = False, update_database: bool = True, ye class AsyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]): """Base class for async database configurations with connection pooling.""" - __slots__ = ("extension_config", "pool_config") + __slots__ = ("pool_config",) is_async: "ClassVar[bool]" = True supports_connection_pooling: "ClassVar[bool]" = True migration_tracker_type: "ClassVar[type[Any]]" = AsyncMigrationTracker @@ -1138,13 +1358,15 @@ def __init__( statement_config: "StatementConfig | None" = None, driver_features: "dict[str, Any] | None" = None, bind_key: "str | None" = None, - extension_config: "dict[str, dict[str, Any]] | LitestarConfig | FastAPIConfig | StarletteConfig | FlaskConfig | ADKConfig | None" = None, + extension_config: "ExtensionConfigs | None" = None, + observability_config: "ObservabilityConfig | None" = None, ) -> None: self.bind_key = bind_key self.pool_instance = pool_instance self.pool_config = pool_config or {} - self.extension_config: dict[str, dict[str, Any]] = cast("dict[str, dict[str, Any]]", extension_config or {}) + self.extension_config = extension_config or {} self.migration_config: dict[str, Any] | MigrationConfig = migration_config or {} + self._init_observability(observability_config) self._initialize_migration_components() if statement_config is None: @@ -1169,11 +1391,16 @@ async def create_pool(self) -> PoolT: if self.pool_instance is not None: return self.pool_instance self.pool_instance = await self._create_pool() + self.get_observability_runtime().emit_pool_create(self.pool_instance) return self.pool_instance async def close_pool(self) -> None: """Close the connection pool.""" + pool = self.pool_instance await self._close_pool() + if pool is not None: + self.get_observability_runtime().emit_pool_destroy(pool) + self.pool_instance = None async def provide_pool(self, *args: Any, **kwargs: Any) -> PoolT: """Provide pool instance.""" diff --git a/sqlspec/core/parameters/_alignment.py b/sqlspec/core/parameters/_alignment.py index e4ca41647..bdcc1dd41 100644 --- a/sqlspec/core/parameters/_alignment.py +++ b/sqlspec/core/parameters/_alignment.py @@ 
-160,6 +160,37 @@ def _format_identifiers(identifiers: "set[tuple[str, int | str]]") -> str: return "[" + ", ".join(formatted) + "]" +def _normalize_index_identifiers(expected: "set[tuple[str, int | str]]", actual: "set[tuple[str, int | str]]") -> bool: + """Allow positional payloads to satisfy generated param_N identifiers.""" + + if not expected or not actual: + return False + + expected_named = {value for kind, value in expected if kind == "named"} + actual_indexes = {value for kind, value in actual if kind == "index"} + + if not expected_named or not actual_indexes: + return False + + normalized_expected: set[int] = set() + for name in expected_named: + if not isinstance(name, str) or not name.startswith("param_"): + return False + suffix = name[6:] + if not suffix.isdigit(): + return False + normalized_expected.add(int(suffix)) + + if not normalized_expected: + return False + + if not all(isinstance(index, int) for index in actual_indexes): + return False + + normalized_actual = {int(index) for index in actual_indexes} + return normalized_actual == normalized_expected + + def _validate_single_parameter_set( parameter_profile: "ParameterProfile", parameters: Any, batch_index: "int | None" = None ) -> None: @@ -186,7 +217,11 @@ def _validate_single_parameter_set( msg = f"{prefix}: {actual_count} parameters provided but {expected_count} placeholders detected." raise sqlspec.exceptions.SQLSpecError(msg) - if expected_identifiers != actual_identifiers: + identifiers_match = expected_identifiers == actual_identifiers or _normalize_index_identifiers( + expected_identifiers, actual_identifiers + ) + + if not identifiers_match: msg = ( f"{prefix}: expected identifiers {_format_identifiers(expected_identifiers)}, " f"received {_format_identifiers(actual_identifiers)}." 
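Reviewer note: a standalone sketch (plain sets, no SQLSpec imports) of the equivalence the new `_normalize_index_identifiers` helper accepts, i.e. generated `param_N` names on the expected side matching integer indexes on the actual side; the identifier tuples mirror the `(kind, value)` pairs used in the hunk above.

```python
# Expected identifiers for auto-generated placeholders vs. the caller's positional payload.
expected = {("named", "param_0"), ("named", "param_1")}
actual = {("index", 0), ("index", 1)}

# Strip the "param_" prefix to recover positions, then compare as sets,
# which is what the helper does before falling back to the mismatch error.
expected_positions = {int(name[len("param_"):]) for kind, name in expected if kind == "named"}
actual_positions = {int(value) for kind, value in actual if kind == "index"}

# With this patch the validator treats these as matching instead of raising SQLSpecError.
assert expected_positions == actual_positions
```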
diff --git a/sqlspec/driver/_async.py b/sqlspec/driver/_async.py index 708c719dc..d6103c203 100644 --- a/sqlspec/driver/_async.py +++ b/sqlspec/driver/_async.py @@ -1,6 +1,7 @@ """Asynchronous driver protocol implementation.""" from abc import abstractmethod +from time import perf_counter from typing import TYPE_CHECKING, Any, Final, TypeVar, overload from sqlspec.core import SQL, Statement, create_arrow_result @@ -61,19 +62,58 @@ async def dispatch_statement_execution(self, statement: "SQL", connection: "Any" Returns: The result of the SQL execution """ - async with self.handle_database_exceptions(), self.with_cursor(connection) as cursor: - special_result = await self._try_special_handling(cursor, statement) - if special_result is not None: - return special_result - - if statement.is_script: - execution_result = await self._execute_script(cursor, statement) - elif statement.is_many: - execution_result = await self._execute_many(cursor, statement) - else: - execution_result = await self._execute_statement(cursor, statement) - - return self.build_statement_result(statement, execution_result) + runtime = self.observability + compiled_sql, execution_parameters = statement.compile() + processed_state = statement.get_processed_state() + operation = getattr(processed_state, "operation_type", statement.operation_type) + query_context = { + "sql": compiled_sql, + "parameters": execution_parameters, + "driver": type(self).__name__, + "operation": operation, + "is_many": statement.is_many, + "is_script": statement.is_script, + } + runtime.emit_query_start(**query_context) + span = runtime.start_query_span(compiled_sql, operation, type(self).__name__) + started = perf_counter() + + try: + async with self.handle_database_exceptions(), self.with_cursor(connection) as cursor: + special_result = await self._try_special_handling(cursor, statement) + if special_result is not None: + result = special_result + elif statement.is_script: + execution_result = await self._execute_script(cursor, statement) + result = self.build_statement_result(statement, execution_result) + elif statement.is_many: + execution_result = await self._execute_many(cursor, statement) + result = self.build_statement_result(statement, execution_result) + else: + execution_result = await self._execute_statement(cursor, statement) + result = self.build_statement_result(statement, execution_result) + except Exception as exc: # pragma: no cover + runtime.span_manager.end_span(span, error=exc) + runtime.emit_error(exc, **query_context) + raise + + runtime.span_manager.end_span(span) + duration = perf_counter() - started + runtime.emit_query_complete(**{**query_context, "rows_affected": result.rows_affected}) + runtime.emit_statement_event( + sql=compiled_sql, + parameters=execution_parameters, + driver=type(self).__name__, + operation=operation, + execution_mode=self.statement_config.execution_mode, + is_many=statement.is_many, + is_script=statement.is_script, + rows_affected=result.rows_affected, + duration_s=duration, + storage_backend=(result.metadata or {}).get("storage_backend") if hasattr(result, "metadata") else None, + started_at=started, + ) + return result @abstractmethod def with_cursor(self, connection: Any) -> Any: diff --git a/sqlspec/driver/_common.py b/sqlspec/driver/_common.py index ac7668d25..3aa16e508 100644 --- a/sqlspec/driver/_common.py +++ b/sqlspec/driver/_common.py @@ -28,6 +28,7 @@ from collections.abc import Sequence from sqlspec.core import FilterTypeT, StatementFilter + from sqlspec.observability import 
ObservabilityRuntime from sqlspec.typing import StatementParameters @@ -287,13 +288,17 @@ class ExecutionResult(NamedTuple): class CommonDriverAttributesMixin: """Common attributes and methods for driver adapters.""" - __slots__ = ("connection", "driver_features", "statement_config") + __slots__ = ("_observability", "connection", "driver_features", "statement_config") connection: "Any" statement_config: "StatementConfig" driver_features: "dict[str, Any]" def __init__( - self, connection: "Any", statement_config: "StatementConfig", driver_features: "dict[str, Any] | None" = None + self, + connection: "Any", + statement_config: "StatementConfig", + driver_features: "dict[str, Any] | None" = None, + observability: "ObservabilityRuntime | None" = None, ) -> None: """Initialize driver adapter with connection and configuration. @@ -301,10 +306,27 @@ def __init__( connection: Database connection instance statement_config: Statement configuration for the driver driver_features: Driver-specific features like extensions, secrets, and connection callbacks + observability: Optional runtime handling lifecycle hooks, observers, and spans """ self.connection = connection self.statement_config = statement_config self.driver_features = driver_features or {} + self._observability = observability + + def attach_observability(self, runtime: "ObservabilityRuntime") -> None: + """Attach or replace the observability runtime.""" + + self._observability = runtime + + @property + def observability(self) -> "ObservabilityRuntime": + """Return the observability runtime, creating a disabled instance when absent.""" + + if self._observability is None: + from sqlspec.observability import ObservabilityRuntime + + self._observability = ObservabilityRuntime(config_name=type(self).__name__) + return self._observability def create_execution_result( self, diff --git a/sqlspec/driver/_sync.py b/sqlspec/driver/_sync.py index aed4efa75..b349aeece 100644 --- a/sqlspec/driver/_sync.py +++ b/sqlspec/driver/_sync.py @@ -1,6 +1,7 @@ """Synchronous driver protocol implementation.""" from abc import abstractmethod +from time import perf_counter from typing import TYPE_CHECKING, Any, Final, TypeVar, overload from sqlspec.core import SQL, create_arrow_result @@ -61,19 +62,58 @@ def dispatch_statement_execution(self, statement: "SQL", connection: "Any") -> " Returns: The result of the SQL execution """ - with self.handle_database_exceptions(), self.with_cursor(connection) as cursor: - special_result = self._try_special_handling(cursor, statement) - if special_result is not None: - return special_result - - if statement.is_script: - execution_result = self._execute_script(cursor, statement) - elif statement.is_many: - execution_result = self._execute_many(cursor, statement) - else: - execution_result = self._execute_statement(cursor, statement) - - return self.build_statement_result(statement, execution_result) + runtime = self.observability + compiled_sql, execution_parameters = statement.compile() + processed_state = statement.get_processed_state() + operation = getattr(processed_state, "operation_type", statement.operation_type) + query_context = { + "sql": compiled_sql, + "parameters": execution_parameters, + "driver": type(self).__name__, + "operation": operation, + "is_many": statement.is_many, + "is_script": statement.is_script, + } + runtime.emit_query_start(**query_context) + span = runtime.start_query_span(compiled_sql, operation, type(self).__name__) + started = perf_counter() + + try: + with self.handle_database_exceptions(), 
self.with_cursor(connection) as cursor: + special_result = self._try_special_handling(cursor, statement) + if special_result is not None: + result = special_result + elif statement.is_script: + execution_result = self._execute_script(cursor, statement) + result = self.build_statement_result(statement, execution_result) + elif statement.is_many: + execution_result = self._execute_many(cursor, statement) + result = self.build_statement_result(statement, execution_result) + else: + execution_result = self._execute_statement(cursor, statement) + result = self.build_statement_result(statement, execution_result) + except Exception as exc: # pragma: no cover - instrumentation path + runtime.span_manager.end_span(span, error=exc) + runtime.emit_error(exc, **query_context) + raise + + runtime.span_manager.end_span(span) + duration = perf_counter() - started + runtime.emit_query_complete(**{**query_context, "rows_affected": result.rows_affected}) + runtime.emit_statement_event( + sql=compiled_sql, + parameters=execution_parameters, + driver=type(self).__name__, + operation=operation, + execution_mode=self.statement_config.execution_mode, + is_many=statement.is_many, + is_script=statement.is_script, + rows_affected=result.rows_affected, + duration_s=duration, + storage_backend=(result.metadata or {}).get("storage_backend") if hasattr(result, "metadata") else None, + started_at=started, + ) + return result @abstractmethod def with_cursor(self, connection: Any) -> Any: diff --git a/sqlspec/driver/mixins/storage.py b/sqlspec/driver/mixins/storage.py index 7e5926dc7..99e644a11 100644 --- a/sqlspec/driver/mixins/storage.py +++ b/sqlspec/driver/mixins/storage.py @@ -1,6 +1,7 @@ """Storage bridge mixin shared by sync and async drivers.""" from collections.abc import Iterable +from pathlib import Path from typing import TYPE_CHECKING, Any, cast from mypy_extensions import trait @@ -24,6 +25,7 @@ from sqlspec.core import StatementConfig, StatementFilter from sqlspec.core.result import ArrowResult from sqlspec.core.statement import SQL + from sqlspec.observability import ObservabilityRuntime from sqlspec.typing import ArrowTable, StatementParameters __all__ = ("StorageDriverMixin",) @@ -45,6 +47,11 @@ class StorageDriverMixin: storage_pipeline_factory: "type[SyncStoragePipeline | AsyncStoragePipeline] | None" = None driver_features: dict[str, Any] + if TYPE_CHECKING: + + @property + def observability(self) -> "ObservabilityRuntime": ... 
+ def storage_capabilities(self) -> StorageCapabilities: """Return cached storage capabilities for the active driver.""" @@ -170,17 +177,89 @@ def _create_storage_job( merged["extra"] = extra return create_storage_bridge_job(status, merged) + def _write_result_to_storage_sync( + self, + result: "ArrowResult", + destination: StorageDestination, + *, + format_hint: StorageFormat | None = None, + storage_options: "dict[str, Any] | None" = None, + pipeline: "SyncStoragePipeline | None" = None, + ) -> StorageTelemetry: + runtime = self.observability + span = runtime.start_storage_span( + "write", destination=self._stringify_storage_target(destination), format_label=format_hint + ) + try: + telemetry = result.write_to_storage_sync( + destination, format_hint=format_hint, storage_options=storage_options, pipeline=pipeline + ) + except Exception as exc: # pragma: no cover - passthrough + runtime.end_storage_span(span, error=exc) + raise + telemetry = runtime.annotate_storage_telemetry(telemetry) + runtime.end_storage_span(span, telemetry=telemetry) + return telemetry + + async def _write_result_to_storage_async( + self, + result: "ArrowResult", + destination: StorageDestination, + *, + format_hint: StorageFormat | None = None, + storage_options: "dict[str, Any] | None" = None, + pipeline: "AsyncStoragePipeline | None" = None, + ) -> StorageTelemetry: + runtime = self.observability + span = runtime.start_storage_span( + "write", destination=self._stringify_storage_target(destination), format_label=format_hint + ) + try: + telemetry = await result.write_to_storage_async( + destination, format_hint=format_hint, storage_options=storage_options, pipeline=pipeline + ) + except Exception as exc: # pragma: no cover - passthrough + runtime.end_storage_span(span, error=exc) + raise + telemetry = runtime.annotate_storage_telemetry(telemetry) + runtime.end_storage_span(span, telemetry=telemetry) + return telemetry + def _read_arrow_from_storage_sync( self, source: StorageDestination, *, file_format: StorageFormat, storage_options: "dict[str, Any] | None" = None ) -> "tuple[ArrowTable, StorageTelemetry]": + runtime = self.observability + span = runtime.start_storage_span( + "read", destination=self._stringify_storage_target(source), format_label=file_format + ) pipeline = cast("SyncStoragePipeline", self._storage_pipeline()) - return pipeline.read_arrow(source, file_format=file_format, storage_options=storage_options) + try: + table, telemetry = pipeline.read_arrow(source, file_format=file_format, storage_options=storage_options) + except Exception as exc: # pragma: no cover - passthrough + runtime.end_storage_span(span, error=exc) + raise + telemetry = runtime.annotate_storage_telemetry(telemetry) + runtime.end_storage_span(span, telemetry=telemetry) + return table, telemetry async def _read_arrow_from_storage_async( self, source: StorageDestination, *, file_format: StorageFormat, storage_options: "dict[str, Any] | None" = None ) -> "tuple[ArrowTable, StorageTelemetry]": + runtime = self.observability + span = runtime.start_storage_span( + "read", destination=self._stringify_storage_target(source), format_label=file_format + ) pipeline = cast("AsyncStoragePipeline", self._storage_pipeline()) - return await pipeline.read_arrow_async(source, file_format=file_format, storage_options=storage_options) + try: + table, telemetry = await pipeline.read_arrow_async( + source, file_format=file_format, storage_options=storage_options + ) + except Exception as exc: # pragma: no cover - passthrough + 
runtime.end_storage_span(span, error=exc) + raise + telemetry = runtime.annotate_storage_telemetry(telemetry) + runtime.end_storage_span(span, telemetry=telemetry) + return table, telemetry @staticmethod def _build_ingest_telemetry(table: "ArrowTable", *, format_label: str = "arrow") -> StorageTelemetry: @@ -207,6 +286,14 @@ def _coerce_arrow_table(self, source: "ArrowResult | Any") -> "ArrowTable": msg = f"Unsupported Arrow source type: {type(source).__name__}" raise TypeError(msg) + @staticmethod + def _stringify_storage_target(target: StorageDestination | None) -> str | None: + if target is None: + return None + if isinstance(target, Path): + return target.as_posix() + return str(target) + @staticmethod def _arrow_table_to_rows( table: "ArrowTable", columns: "list[str] | None" = None diff --git a/sqlspec/extensions/litestar/plugin.py b/sqlspec/extensions/litestar/plugin.py index 57d7e6fe1..e6f107c5c 100644 --- a/sqlspec/extensions/litestar/plugin.py +++ b/sqlspec/extensions/litestar/plugin.py @@ -1,7 +1,10 @@ +from collections.abc import Iterable +from contextlib import suppress from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Literal, NoReturn, cast, overload from litestar.di import Provide +from litestar.middleware import DefineMiddleware from litestar.plugins import CLIPlugin, InitPluginProtocol from sqlspec.base import SQLSpec @@ -16,7 +19,11 @@ SyncDatabaseConfig, ) from sqlspec.exceptions import ImproperConfigurationError -from sqlspec.extensions.litestar._utils import get_sqlspec_scope_state, set_sqlspec_scope_state +from sqlspec.extensions.litestar._utils import ( + delete_sqlspec_scope_state, + get_sqlspec_scope_state, + set_sqlspec_scope_state, +) from sqlspec.extensions.litestar.handlers import ( autocommit_handler_maker, connection_provider_maker, @@ -26,6 +33,7 @@ session_provider_maker, ) from sqlspec.typing import NUMPY_INSTALLED, ConnectionT, PoolT, SchemaT +from sqlspec.utils.correlation import CorrelationContext from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import numpy_array_dec_hook, numpy_array_enc_hook, numpy_array_predicate @@ -36,7 +44,7 @@ from litestar import Litestar from litestar.config.app import AppConfig from litestar.datastructures.state import State - from litestar.types import BeforeMessageSendHookHandler, Scope + from litestar.types import ASGIApp, BeforeMessageSendHookHandler, Receive, Scope, Send from rich_click import Group from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase @@ -49,6 +57,18 @@ DEFAULT_CONNECTION_KEY = "db_connection" DEFAULT_POOL_KEY = "db_pool" DEFAULT_SESSION_KEY = "db_session" +DEFAULT_CORRELATION_HEADER = "x-request-id" +TRACE_CONTEXT_FALLBACK_HEADERS: tuple[str, ...] 
= ( + DEFAULT_CORRELATION_HEADER, + "x-correlation-id", + "traceparent", + "x-cloud-trace-context", + "grpc-trace-bin", + "x-amzn-trace-id", + "x-b3-traceid", + "x-client-trace-id", +) +CORRELATION_STATE_KEY = "sqlspec_correlation_id" __all__ = ( "DEFAULT_COMMIT_MODE", @@ -60,6 +80,78 @@ ) +def _normalize_header_list(headers: Any) -> list[str]: + if headers is None: + return [] + if isinstance(headers, str): + return [headers.lower()] + if isinstance(headers, Iterable): + normalized: list[str] = [] + for header in headers: + if not isinstance(header, str): + msg = "litestar correlation headers must be strings" + raise ImproperConfigurationError(msg) + normalized.append(header.lower()) + return normalized + msg = "litestar correlation_headers must be a string or iterable of strings" + raise ImproperConfigurationError(msg) + + +def _dedupe_headers(headers: Iterable[str]) -> list[str]: + seen: set[str] = set() + ordered: list[str] = [] + for header in headers: + lowered = header.lower() + if lowered in seen or not lowered: + continue + seen.add(lowered) + ordered.append(lowered) + return ordered + + +def _build_correlation_headers(*, primary: str, configured: list[str], auto_trace_headers: bool) -> tuple[str, ...]: + header_order: list[str] = [primary.lower()] + header_order.extend(configured) + if auto_trace_headers: + header_order.extend(TRACE_CONTEXT_FALLBACK_HEADERS) + return tuple(_dedupe_headers(header_order)) + + +class _CorrelationMiddleware: + __slots__ = ("_app", "_headers") + + def __init__(self, app: "ASGIApp", *, headers: tuple[str, ...]) -> None: + self._app = app + self._headers = headers + + async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> None: + scope_type = scope.get("type") + if str(scope_type) != "http" or not self._headers: + await self._app(scope, receive, send) + return + + header_value: str | None = None + raw_headers = scope.get("headers") or [] + for header in self._headers: + for name, value in raw_headers: + if name.decode().lower() == header: + header_value = value.decode() + break + if header_value: + break + if not header_value: + header_value = CorrelationContext.generate() + + CorrelationContext.set(header_value) + set_sqlspec_scope_state(scope, CORRELATION_STATE_KEY, header_value) + try: + await self._app(scope, receive, send) + finally: + with suppress(KeyError): + delete_sqlspec_scope_state(scope, CORRELATION_STATE_KEY) + CorrelationContext.clear() + + @dataclass class _PluginConfigState: """Internal state for each database configuration.""" @@ -72,6 +164,8 @@ class _PluginConfigState: extra_commit_statuses: "set[int] | None" extra_rollback_statuses: "set[int] | None" enable_correlation_middleware: bool + correlation_header: str + correlation_headers: tuple[str, ...] = field(init=False) disable_di: bool connection_provider: "Callable[[State, Scope], AsyncGenerator[Any, None]]" = field(init=False) pool_provider: "Callable[[State, Scope], Any]" = field(init=False) @@ -114,7 +208,7 @@ class SQLSpecPlugin(InitPluginProtocol, CLIPlugin): prevent version conflicts with application migrations. """ - __slots__ = ("_plugin_configs", "_sqlspec") + __slots__ = ("_correlation_headers", "_plugin_configs", "_sqlspec") def __init__(self, sqlspec: SQLSpec, *, loader: "SQLFileLoader | None" = None) -> None: """Initialize SQLSpec plugin. 
@@ -135,6 +229,15 @@ def __init__(self, sqlspec: SQLSpec, *, loader: "SQLFileLoader | None" = None) - state = self._create_config_state(config_union, settings) self._plugin_configs.append(state) + correlation_headers: list[str] = [] + for state in self._plugin_configs: + if not state.enable_correlation_middleware: + continue + for header in state.correlation_headers: + if header not in correlation_headers: + correlation_headers.append(header) + self._correlation_headers = tuple(correlation_headers) + def _extract_litestar_settings( self, config: "SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]", @@ -150,6 +253,10 @@ def _extract_litestar_settings( if not config.supports_connection_pooling and pool_key == DEFAULT_POOL_KEY: pool_key = f"_{DEFAULT_POOL_KEY}_{id(config)}" + correlation_header = str(litestar_config.get("correlation_header", DEFAULT_CORRELATION_HEADER)).lower() + configured_headers = _normalize_header_list(litestar_config.get("correlation_headers")) + auto_trace_headers = bool(litestar_config.get("auto_trace_headers", True)) + return { "connection_key": connection_key, "pool_key": pool_key, @@ -158,6 +265,10 @@ def _extract_litestar_settings( "extra_commit_statuses": litestar_config.get("extra_commit_statuses"), "extra_rollback_statuses": litestar_config.get("extra_rollback_statuses"), "enable_correlation_middleware": litestar_config.get("enable_correlation_middleware", True), + "correlation_header": correlation_header, + "correlation_headers": _build_correlation_headers( + primary=correlation_header, configured=configured_headers, auto_trace_headers=auto_trace_headers + ), "disable_di": litestar_config.get("disable_di", False), } @@ -176,8 +287,10 @@ def _create_config_state( extra_commit_statuses=settings.get("extra_commit_statuses"), extra_rollback_statuses=settings.get("extra_rollback_statuses"), enable_correlation_middleware=settings["enable_correlation_middleware"], + correlation_header=settings["correlation_header"], disable_di=settings["disable_di"], ) + state.correlation_headers = tuple(settings["correlation_headers"]) if not state.disable_di: self._setup_handlers(state) @@ -289,6 +402,12 @@ def store_sqlspec_in_state() -> None: decoders_list.append((numpy_array_predicate, numpy_array_dec_hook)) # type: ignore[arg-type] app_config.type_decoders = decoders_list + if self._correlation_headers: + middleware = DefineMiddleware(_CorrelationMiddleware, headers=self._correlation_headers) + existing_middleware = list(app_config.middleware or []) + existing_middleware.append(middleware) + app_config.middleware = existing_middleware + return app_config def get_annotations( diff --git a/sqlspec/extensions/otel/__init__.py b/sqlspec/extensions/otel/__init__.py new file mode 100644 index 000000000..61e3b95b3 --- /dev/null +++ b/sqlspec/extensions/otel/__init__.py @@ -0,0 +1,58 @@ +"""Optional helpers for enabling OpenTelemetry spans via ObservabilityConfig.""" + +from collections.abc import Callable +from typing import Any + +from sqlspec.observability import ObservabilityConfig, TelemetryConfig +from sqlspec.typing import trace +from sqlspec.utils.module_loader import ensure_opentelemetry + +__all__ = ("enable_tracing",) + + +def _wrap_provider(provider: Any | None) -> Callable[[], Any] | None: + if provider is None: + return None + + def _factory() -> Any: + return provider + + return _factory + + +def enable_tracing( + *, + base_config: ObservabilityConfig | None = None, + tracer_provider: Any | 
None = None, + tracer_provider_factory: Callable[[], Any] | None = None, + resource_attributes: dict[str, Any] | None = None, + enable_spans: bool = True, +) -> ObservabilityConfig: + """Return an ObservabilityConfig with OpenTelemetry spans enabled. + + Args: + base_config: Existing observability config to extend. When omitted a new instance is created. + tracer_provider: Optional provider instance to reuse. Mutually exclusive with tracer_provider_factory. + tracer_provider_factory: Callable that returns a tracer provider when spans are first used. + resource_attributes: Additional attributes to attach to every span. + enable_spans: Allow disabling spans while keeping the rest of the config. + + Returns: + ObservabilityConfig with telemetry options configured for OpenTelemetry. + """ + + ensure_opentelemetry() + + if tracer_provider is not None and tracer_provider_factory is not None: + msg = "Provide either tracer_provider or tracer_provider_factory, not both" + raise ValueError(msg) + + telemetry = TelemetryConfig( + enable_spans=enable_spans, + provider_factory=tracer_provider_factory or _wrap_provider(tracer_provider) or trace.get_tracer_provider, + resource_attributes=resource_attributes, + ) + + config = base_config.copy() if base_config else ObservabilityConfig() + config.telemetry = telemetry + return config diff --git a/sqlspec/extensions/prometheus/__init__.py b/sqlspec/extensions/prometheus/__init__.py new file mode 100644 index 000000000..62c48bca3 --- /dev/null +++ b/sqlspec/extensions/prometheus/__init__.py @@ -0,0 +1,107 @@ +"""Prometheus metrics helpers that integrate with the observability statement observers.""" + +from collections.abc import Iterable +from typing import Any + +from sqlspec.observability import ObservabilityConfig +from sqlspec.observability._observer import StatementEvent, StatementObserver +from sqlspec.typing import Counter, Histogram +from sqlspec.utils.module_loader import ensure_prometheus + +__all__ = ("PrometheusStatementObserver", "enable_metrics") + + +class PrometheusStatementObserver: + """Statement observer that records Prometheus metrics.""" + + __slots__ = ("_counters", "_duration", "_label_names", "_rows") + + def __init__( + self, + *, + namespace: str = "sqlspec", + subsystem: str = "driver", + registry: Any | None = None, + label_names: Iterable[str] = ("driver", "operation"), + duration_buckets: tuple[float, ...] 
| None = None, + ) -> None: + self._label_names = tuple(label_names) + self._counters = Counter( + "query_total", + "Total SQL statements executed", + labelnames=self._label_names, + namespace=namespace, + subsystem=subsystem, + registry=registry, + ) + histogram_kwargs: dict[str, Any] = {} + if duration_buckets is not None: + histogram_kwargs["buckets"] = duration_buckets + + self._duration = Histogram( + "query_duration_seconds", + "SQL execution time in seconds", + labelnames=self._label_names, + namespace=namespace, + subsystem=subsystem, + registry=registry, + **histogram_kwargs, + ) + self._rows = Histogram( + "query_rows", + "Rows affected per statement", + labelnames=self._label_names, + namespace=namespace, + subsystem=subsystem, + registry=registry, + ) + + def __call__(self, event: StatementEvent) -> None: + label_values = self._label_values(event) + self._counters.labels(*label_values).inc() + self._duration.labels(*label_values).observe(max(event.duration_s, 0.0)) + if event.rows_affected is not None: + self._rows.labels(*label_values).observe(float(event.rows_affected)) + + def _label_values(self, event: StatementEvent) -> tuple[str, ...]: + values: list[str] = [] + for name in self._label_names: + if name == "driver": + values.append(event.driver) + elif name == "operation": + values.append(event.operation or "EXECUTE") + elif name == "adapter": + values.append(event.adapter) + elif name == "bind_key": + values.append(event.bind_key or "default") + else: + values.append(getattr(event, name, "")) + return tuple(values) + + +def enable_metrics( + *, + base_config: ObservabilityConfig | None = None, + namespace: str = "sqlspec", + subsystem: str = "driver", + registry: Any | None = None, + label_names: Iterable[str] = ("driver", "operation"), + duration_buckets: tuple[float, ...] | None = None, +) -> ObservabilityConfig: + """Attach a Prometheus-backed statement observer to the provided config.""" + + ensure_prometheus() + + observer = PrometheusStatementObserver( + namespace=namespace, + subsystem=subsystem, + registry=registry, + label_names=label_names, + duration_buckets=duration_buckets, + ) + + config = base_config.copy() if base_config else ObservabilityConfig() + existing: list[StatementObserver] = list(config.statement_observers or ()) + existing.append(observer) + config.statement_observers = tuple(existing) + return config diff --git a/sqlspec/loader.py b/sqlspec/loader.py index 2d679f2a8..07dbcac42 100644 --- a/sqlspec/loader.py +++ b/sqlspec/loader.py @@ -25,6 +25,7 @@ from sqlspec.utils.text import slugify if TYPE_CHECKING: + from sqlspec.observability import ObservabilityRuntime from sqlspec.storage.registry import StorageRegistry __all__ = ("CachedSQLFile", "NamedStatement", "SQLFile", "SQLFileLoader") @@ -161,14 +162,21 @@ class SQLFileLoader: and retrieves them by name. """ - __slots__ = ("_files", "_queries", "_query_to_file", "encoding", "storage_registry") + __slots__ = ("_files", "_queries", "_query_to_file", "_runtime", "encoding", "storage_registry") - def __init__(self, *, encoding: str = "utf-8", storage_registry: "StorageRegistry | None" = None) -> None: + def __init__( + self, + *, + encoding: str = "utf-8", + storage_registry: "StorageRegistry | None" = None, + runtime: "ObservabilityRuntime | None" = None, + ) -> None: """Initialize the SQL file loader. Args: encoding: Text encoding for reading SQL files. storage_registry: Storage registry for handling file URIs. + runtime: Observability runtime for instrumentation. 
""" self.encoding = encoding @@ -176,6 +184,16 @@ def __init__(self, *, encoding: str = "utf-8", storage_registry: "StorageRegistr self._queries: dict[str, NamedStatement] = {} self._files: dict[str, SQLFile] = {} self._query_to_file: dict[str, str] = {} + self._runtime = runtime + + def set_observability_runtime(self, runtime: "ObservabilityRuntime | None") -> None: + """Attach an observability runtime used for instrumentation.""" + + self._runtime = runtime + + def _metric(self, name: str, amount: float = 1.0) -> None: + if self._runtime is not None: + self._runtime.increment_metric(name, amount) def _raise_file_not_found(self, path: str) -> None: """Raise SQLFileNotFoundError for nonexistent file. @@ -360,8 +378,20 @@ def load_sql(self, *paths: str | Path) -> None: Args: *paths: One or more file paths or directory paths to load. """ - correlation_id = CorrelationContext.get() + runtime = self._runtime + span = None + error: Exception | None = None start_time = time.perf_counter() + path_count = len(paths) + if runtime is not None: + runtime.increment_metric("loader.load.invocations") + runtime.increment_metric("loader.paths.requested", path_count) + span = runtime.start_span( + "sqlspec.loader.load", + attributes={"sqlspec.loader.path_count": path_count, "sqlspec.loader.encoding": self.encoding}, + ) + + correlation_id = CorrelationContext.get() try: for path in paths: @@ -377,18 +407,28 @@ def load_sql(self, *paths: str | Path) -> None: elif path_obj.suffix: self._raise_file_not_found(str(path)) - except Exception as e: + except Exception as exc: + error = exc duration = time.perf_counter() - start_time logger.exception( "Failed to load SQL files after %.3fms", duration * 1000, extra={ - "error_type": type(e).__name__, + "error_type": type(exc).__name__, "duration_ms": duration * 1000, "correlation_id": correlation_id, }, ) + if runtime is not None: + runtime.increment_metric("loader.load.errors") raise + finally: + duration_ms = (time.perf_counter() - start_time) * 1000 + if runtime is not None: + runtime.record_metric("loader.last_load_ms", duration_ms) + runtime.increment_metric("loader.load.duration_ms", duration_ms) + runtime.end_span(span, error=error) + CorrelationContext.clear() def _load_directory(self, dir_path: Path) -> None: """Load all SQL files from a directory. @@ -396,6 +436,10 @@ def _load_directory(self, dir_path: Path) -> None: Args: dir_path: Directory path to load SQL files from. """ + runtime = self._runtime + if runtime is not None: + runtime.increment_metric("loader.directories.scanned") + sql_files = list(dir_path.rglob("*.sql")) if not sql_files: return @@ -416,13 +460,20 @@ def _load_single_file(self, file_path: str | Path, namespace: str | None) -> boo True if file was newly loaded, False if already cached. 
""" path_str = str(file_path) + runtime = self._runtime + if runtime is not None: + runtime.increment_metric("loader.files.considered") if path_str in self._files: + if runtime is not None: + runtime.increment_metric("loader.cache.hit") return False cache_config = get_cache_config() if not cache_config.compiled_cache_enabled: self._load_file_without_cache(file_path, namespace) + if runtime is not None: + runtime.increment_metric("loader.cache.miss") return True cache_key_str = self._generate_file_cache_key(file_path) @@ -447,6 +498,8 @@ def _load_single_file(self, file_path: str | Path, namespace: str | None) -> boo ) self._queries[namespaced_name] = statement self._query_to_file[namespaced_name] = path_str + if runtime is not None: + runtime.increment_metric("loader.cache.hit") return True self._load_file_without_cache(file_path, namespace) @@ -463,6 +516,10 @@ def _load_single_file(self, file_path: str | Path, namespace: str | None) -> boo cached_file_data = CachedSQLFile(sql_file=sql_file, parsed_statements=file_statements) cache.put("file", cache_key_str, cached_file_data) + if runtime is not None: + runtime.increment_metric("loader.cache.miss") + runtime.increment_metric("loader.files.loaded") + runtime.increment_metric("loader.statements.loaded", len(file_statements)) return True @@ -474,7 +531,7 @@ def _load_file_without_cache(self, file_path: str | Path, namespace: str | None) namespace: Optional namespace prefix for queries. """ path_str = str(file_path) - + runtime = self._runtime content = self._read_file_content(file_path) statements = self._parse_sql_content(content, path_str) @@ -501,6 +558,9 @@ def _load_file_without_cache(self, file_path: str | Path, namespace: str | None) ) self._queries[namespaced_name] = statement self._query_to_file[namespaced_name] = path_str + if runtime is not None: + runtime.increment_metric("loader.files.loaded") + runtime.increment_metric("loader.statements.loaded", len(statements)) def add_named_sql(self, name: str, sql: str, dialect: "str | None" = None) -> None: """Add a named SQL query directly without loading from a file. 
diff --git a/sqlspec/migrations/base.py b/sqlspec/migrations/base.py index 0d214c331..da8724955 100644 --- a/sqlspec/migrations/base.py +++ b/sqlspec/migrations/base.py @@ -6,7 +6,7 @@ import hashlib from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Generic, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast from sqlspec.builder import Delete, Insert, Select, Update, sql from sqlspec.builder._ddl import CreateTable @@ -17,14 +17,17 @@ from sqlspec.utils.sync_tools import await_ from sqlspec.utils.version import parse_version +if TYPE_CHECKING: + from sqlspec.config import DatabaseConfigProtocol + from sqlspec.observability import ObservabilityRuntime + __all__ = ("BaseMigrationCommands", "BaseMigrationRunner", "BaseMigrationTracker") +DriverT = TypeVar("DriverT") +ConfigT = TypeVar("ConfigT", bound="DatabaseConfigProtocol[Any, Any, Any]") logger = get_logger("migrations.base") -DriverT = TypeVar("DriverT") -ConfigT = TypeVar("ConfigT") - class BaseMigrationTracker(ABC, Generic[DriverT]): """Base class for migration version tracking.""" @@ -488,6 +491,9 @@ def __init__(self, config: ConfigT) -> None: self.project_root = Path(migration_config["project_root"]) if "project_root" in migration_config else None self.include_extensions = migration_config.get("include_extensions", []) self.extension_configs = self._parse_extension_configs() + self._runtime: ObservabilityRuntime | None = self.config.get_observability_runtime() + self._last_command_error: Exception | None = None + self._last_command_metrics: dict[str, float] | None = None def _parse_extension_configs(self) -> "dict[str, dict[str, Any]]": """Parse extension configurations from include_extensions. @@ -635,6 +641,13 @@ def init_directory(self, directory: str, package: bool = True) -> None: console.print(f"[green]Initialized migrations in {directory}[/]") + def _record_command_metric(self, name: str, value: float) -> None: + """Accumulate per-command metrics for decorator flushing.""" + + if self._last_command_metrics is None: + self._last_command_metrics = {} + self._last_command_metrics[name] = self._last_command_metrics.get(name, 0.0) + value + @abstractmethod def init(self, directory: str, package: bool = True) -> Any: """Initialize migration directory structure.""" diff --git a/sqlspec/migrations/commands.py b/sqlspec/migrations/commands.py index dd60cec09..53a7f1a7d 100644 --- a/sqlspec/migrations/commands.py +++ b/sqlspec/migrations/commands.py @@ -3,7 +3,11 @@ This module provides the main command interface for database migrations. 
""" -from typing import TYPE_CHECKING, Any, cast +import functools +import inspect +import time +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast from rich.console import Console from rich.table import Table @@ -27,6 +31,127 @@ logger = get_logger("migrations.commands") console = Console() +P = ParamSpec("P") +R = TypeVar("R") + + +MetadataBuilder = Callable[[dict[str, Any]], tuple[str | None, dict[str, Any]]] + + +def _bind_arguments(signature: inspect.Signature, args: tuple[Any, ...], kwargs: dict[str, Any]) -> dict[str, Any]: + bound = signature.bind_partial(*args, **kwargs) + arguments = dict(bound.arguments) + arguments.pop("self", None) + return arguments + + +def _with_command_span( + event: str, metadata_fn: "MetadataBuilder | None" = None, *, dry_run_param: str | None = "dry_run" +) -> Callable[[Callable[P, R]], Callable[P, R]]: + """Attach span lifecycle and command metric management to command methods.""" + + metric_prefix = f"migrations.command.{event}" + + def decorator(func: Callable[P, R]) -> Callable[P, R]: + signature = inspect.signature(func) + + def _prepare(self: Any, args: tuple[Any, ...], kwargs: dict[str, Any]) -> tuple[Any, bool, Any]: + runtime = getattr(self, "_runtime", None) + metadata_args = _bind_arguments(signature, args, kwargs) + dry_run = False + if dry_run_param is not None: + dry_run = bool(metadata_args.get(dry_run_param, False)) + metadata: dict[str, Any] | None = None + version: str | None = None + span = None + if runtime is not None: + runtime.increment_metric(f"{metric_prefix}.invocations") + if dry_run_param is not None and dry_run: + runtime.increment_metric(f"{metric_prefix}.dry_run") + if metadata_fn is not None: + version, metadata = metadata_fn(metadata_args) + span = runtime.start_migration_span(f"command.{event}", version=version, metadata=metadata) + return runtime, dry_run, span + + def _finalize( + self: Any, + runtime: Any, + span: Any, + start: float, + error: "Exception | None", + recorded_error: bool, + dry_run: bool, + ) -> None: + command_error = getattr(self, "_last_command_error", None) + setattr(self, "_last_command_error", None) + command_metrics = getattr(self, "_last_command_metrics", None) + setattr(self, "_last_command_metrics", None) + if runtime is None: + return + if command_error is not None and not recorded_error: + runtime.increment_metric(f"{metric_prefix}.errors") + if not dry_run and command_metrics: + for metric, value in command_metrics.items(): + runtime.increment_metric(f"{metric_prefix}.{metric}", value) + duration_ms = int((time.perf_counter() - start) * 1000) + runtime.end_migration_span(span, duration_ms=duration_ms, error=error or command_error) + + if inspect.iscoroutinefunction(func): + + @functools.wraps(func) + async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + self = args[0] + runtime, dry_run, span = _prepare(self, args, kwargs) + start = time.perf_counter() + error: Exception | None = None + error_recorded = False + try: + async_func = cast("Callable[P, Awaitable[R]]", func) + return await async_func(*args, **kwargs) + except Exception as exc: # pragma: no cover - passthrough + error = exc + if runtime is not None: + runtime.increment_metric(f"{metric_prefix}.errors") + error_recorded = True + raise + finally: + _finalize(self, runtime, span, start, error, error_recorded, dry_run) + + return cast("Callable[P, R]", async_wrapper) + + @functools.wraps(func) + def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + self 
= args[0] + runtime, dry_run, span = _prepare(self, args, kwargs) + start = time.perf_counter() + error: Exception | None = None + error_recorded = False + try: + return func(*args, **kwargs) + except Exception as exc: # pragma: no cover - passthrough + error = exc + if runtime is not None: + runtime.increment_metric(f"{metric_prefix}.errors") + error_recorded = True + raise + finally: + _finalize(self, runtime, span, start, error, error_recorded, dry_run) + + return cast("Callable[P, R]", sync_wrapper) + + return decorator + + +def _upgrade_metadata(args: dict[str, Any]) -> tuple[str | None, dict[str, Any]]: + revision = cast("str | None", args.get("revision")) + metadata = {"dry_run": str(args.get("dry_run", False)).lower()} + return revision, metadata + + +def _downgrade_metadata(args: dict[str, Any]) -> tuple[str | None, dict[str, Any]]: + revision = cast("str | None", args.get("revision")) + metadata = {"dry_run": str(args.get("dry_run", False)).lower()} + return revision, metadata class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]): @@ -46,7 +171,11 @@ def __init__(self, config: "SyncConfigT") -> None: context.extension_config = self.extension_configs self.runner = SyncMigrationRunner( - self.migrations_path, self._discover_extension_migrations(), context, self.extension_configs + self.migrations_path, + self._discover_extension_migrations(), + context, + self.extension_configs, + runtime=self._runtime, ) def init(self, directory: str, package: bool = True) -> None: @@ -193,6 +322,7 @@ def _synchronize_version_records(self, driver: Any) -> int: return updated_count + @_with_command_span("upgrade", metadata_fn=_upgrade_metadata) def upgrade( self, revision: str = "head", allow_missing: bool = False, auto_sync: bool = True, dry_run: bool = False ) -> None: @@ -210,6 +340,9 @@ def upgrade( Defaults to True. Can be disabled via --no-auto-sync flag. dry_run: If True, show what would be done without making changes. 
""" + runtime = self._runtime + applied_count = 0 + if dry_run: console.print("[bold yellow]DRY RUN MODE:[/] No database changes will be applied\n") @@ -217,8 +350,7 @@ def upgrade( self.tracker.ensure_tracking_table(driver) if auto_sync: - migration_config = getattr(self.config, "migration_config", {}) or {} - config_auto_sync = migration_config.get("auto_sync", True) + config_auto_sync = self.config.migration_config.get("auto_sync", True) if config_auto_sync: self._synchronize_version_records(driver) @@ -227,6 +359,9 @@ def upgrade( applied_set = set(applied_versions) all_migrations = self.runner.get_migration_files() + if runtime is not None: + runtime.increment_metric("migrations.command.upgrade.available", float(len(all_migrations))) + pending = [] for version, file_path in all_migrations: if version not in applied_set: @@ -240,6 +375,9 @@ def upgrade( if parsed_version <= parsed_revision: pending.append((version, file_path)) + if runtime is not None: + runtime.increment_metric("migrations.command.upgrade.pending", float(len(pending))) + if not pending: if not all_migrations: console.print( @@ -275,17 +413,22 @@ def record_version(exec_time: int, migration: "dict[str, Any]" = migration) -> N ) _, execution_time = self.runner.execute_upgrade(driver, migration, on_success=record_version) + applied_count += 1 console.print(f"[green]✓ Applied in {execution_time}ms[/]") - except Exception as e: + except Exception as exc: use_txn = self.runner.should_use_transaction(migration, self.config) rollback_msg = " (transaction rolled back)" if use_txn else "" - console.print(f"[red]✗ Failed{rollback_msg}: {e}[/]") + console.print(f"[red]✗ Failed{rollback_msg}: {exc}[/]") + self._last_command_error = exc return - if dry_run: - console.print("\n[bold yellow]Dry run complete.[/] No changes were made to the database.") + if dry_run: + console.print("\n[bold yellow]Dry run complete.[/] No changes were made to the database.") + elif applied_count: + self._record_command_metric("applied", float(applied_count)) + @_with_command_span("downgrade", metadata_fn=_downgrade_metadata) def downgrade(self, revision: str = "-1", *, dry_run: bool = False) -> None: """Downgrade to a target revision. @@ -293,15 +436,21 @@ def downgrade(self, revision: str = "-1", *, dry_run: bool = False) -> None: revision: Target revision or "-1" for one step back. dry_run: If True, show what would be done without making changes. 
""" + runtime = self._runtime + reverted_count = 0 + if dry_run: console.print("[bold yellow]DRY RUN MODE:[/] No database changes will be applied\n") with self.config.provide_session() as driver: self.tracker.ensure_tracking_table(driver) applied = self.tracker.get_applied_migrations(driver) + if runtime is not None: + runtime.increment_metric("migrations.command.downgrade.available", float(len(applied))) if not applied: console.print("[yellow]No migrations to downgrade[/]") return + to_revert = [] if revision == "-1": to_revert = [applied[-1]] @@ -316,6 +465,9 @@ def downgrade(self, revision: str = "-1", *, dry_run: bool = False) -> None: if parsed_migration_version > parsed_revision: to_revert.append(migration) + if runtime is not None: + runtime.increment_metric("migrations.command.downgrade.pending", float(len(to_revert))) + if not to_revert: console.print("[yellow]Nothing to downgrade[/]") return @@ -326,6 +478,8 @@ def downgrade(self, revision: str = "-1", *, dry_run: bool = False) -> None: version = migration_record["version_num"] if version not in all_files: console.print(f"[red]Migration file not found for {version}[/]") + if runtime is not None: + runtime.increment_metric("migrations.command.downgrade.missing_files") continue migration = self.runner.load_migration(all_files[version], version) @@ -342,15 +496,19 @@ def remove_version(exec_time: int, version: str = version) -> None: self.tracker.remove_migration(driver, version) _, execution_time = self.runner.execute_downgrade(driver, migration, on_success=remove_version) + reverted_count += 1 console.print(f"[green]✓ Reverted in {execution_time}ms[/]") - except Exception as e: + except Exception as exc: use_txn = self.runner.should_use_transaction(migration, self.config) rollback_msg = " (transaction rolled back)" if use_txn else "" - console.print(f"[red]✗ Failed{rollback_msg}: {e}[/]") + console.print(f"[red]✗ Failed{rollback_msg}: {exc}[/]") + self._last_command_error = exc return - if dry_run: - console.print("\n[bold yellow]Dry run complete.[/] No changes were made to the database.") + if dry_run: + console.print("\n[bold yellow]Dry run complete.[/] No changes were made to the database.") + elif reverted_count: + self._record_command_metric("applied", float(reverted_count)) def stamp(self, revision: str) -> None: """Mark database as being at a specific revision without running migrations. @@ -487,7 +645,11 @@ def __init__(self, config: "AsyncConfigT") -> None: context.extension_config = self.extension_configs self.runner = AsyncMigrationRunner( - self.migrations_path, self._discover_extension_migrations(), context, self.extension_configs + self.migrations_path, + self._discover_extension_migrations(), + context, + self.extension_configs, + runtime=self._runtime, ) async def init(self, directory: str, package: bool = True) -> None: @@ -634,6 +796,7 @@ async def _synchronize_version_records(self, driver: Any) -> int: return updated_count + @_with_command_span("upgrade", metadata_fn=_upgrade_metadata) async def upgrade( self, revision: str = "head", allow_missing: bool = False, auto_sync: bool = True, dry_run: bool = False ) -> None: @@ -651,6 +814,9 @@ async def upgrade( Defaults to True. Can be disabled via --no-auto-sync flag. dry_run: If True, show what would be done without making changes. 
""" + runtime = self._runtime + applied_count = 0 + if dry_run: console.print("[bold yellow]DRY RUN MODE:[/] No database changes will be applied\n") @@ -668,6 +834,9 @@ async def upgrade( applied_set = set(applied_versions) all_migrations = await self.runner.get_migration_files() + if runtime is not None: + runtime.increment_metric("migrations.command.upgrade.available", float(len(all_migrations))) + pending = [] for version, file_path in all_migrations: if version not in applied_set: @@ -680,6 +849,10 @@ async def upgrade( parsed_revision = parse_version(revision) if parsed_version <= parsed_revision: pending.append((version, file_path)) + + if runtime is not None: + runtime.increment_metric("migrations.command.upgrade.pending", float(len(pending))) + if not pending: if not all_migrations: console.print( @@ -714,16 +887,21 @@ async def record_version(exec_time: int, migration: "dict[str, Any]" = migration ) _, execution_time = await self.runner.execute_upgrade(driver, migration, on_success=record_version) + applied_count += 1 console.print(f"[green]✓ Applied in {execution_time}ms[/]") - except Exception as e: + except Exception as exc: use_txn = self.runner.should_use_transaction(migration, self.config) rollback_msg = " (transaction rolled back)" if use_txn else "" - console.print(f"[red]✗ Failed{rollback_msg}: {e}[/]") + console.print(f"[red]✗ Failed{rollback_msg}: {exc}[/]") + self._last_command_error = exc return - if dry_run: - console.print("\n[bold yellow]Dry run complete.[/] No changes were made to the database.") + if dry_run: + console.print("\n[bold yellow]Dry run complete.[/] No changes were made to the database.") + elif applied_count: + self._record_command_metric("applied", float(applied_count)) + @_with_command_span("downgrade", metadata_fn=_downgrade_metadata) async def downgrade(self, revision: str = "-1", *, dry_run: bool = False) -> None: """Downgrade to a target revision. @@ -731,6 +909,9 @@ async def downgrade(self, revision: str = "-1", *, dry_run: bool = False) -> Non revision: Target revision or "-1" for one step back. dry_run: If True, show what would be done without making changes. 
""" + runtime = self._runtime + reverted_count = 0 + if dry_run: console.print("[bold yellow]DRY RUN MODE:[/] No database changes will be applied\n") @@ -738,6 +919,8 @@ async def downgrade(self, revision: str = "-1", *, dry_run: bool = False) -> Non await self.tracker.ensure_tracking_table(driver) applied = await self.tracker.get_applied_migrations(driver) + if runtime is not None: + runtime.increment_metric("migrations.command.downgrade.available", float(len(applied))) if not applied: console.print("[yellow]No migrations to downgrade[/]") return @@ -754,6 +937,10 @@ async def downgrade(self, revision: str = "-1", *, dry_run: bool = False) -> Non parsed_migration_version = parse_version(migration["version_num"]) if parsed_migration_version > parsed_revision: to_revert.append(migration) + + if runtime is not None: + runtime.increment_metric("migrations.command.downgrade.pending", float(len(to_revert))) + if not to_revert: console.print("[yellow]Nothing to downgrade[/]") return @@ -764,6 +951,8 @@ async def downgrade(self, revision: str = "-1", *, dry_run: bool = False) -> Non version = migration_record["version_num"] if version not in all_files: console.print(f"[red]Migration file not found for {version}[/]") + if runtime is not None: + runtime.increment_metric("migrations.command.downgrade.missing_files") continue migration = await self.runner.load_migration(all_files[version], version) @@ -783,15 +972,19 @@ async def remove_version(exec_time: int, version: str = version) -> None: _, execution_time = await self.runner.execute_downgrade( driver, migration, on_success=remove_version ) + reverted_count += 1 console.print(f"[green]✓ Reverted in {execution_time}ms[/]") - except Exception as e: + except Exception as exc: use_txn = self.runner.should_use_transaction(migration, self.config) rollback_msg = " (transaction rolled back)" if use_txn else "" - console.print(f"[red]✗ Failed{rollback_msg}: {e}[/]") + console.print(f"[red]✗ Failed{rollback_msg}: {exc}[/]") + self._last_command_error = exc return - if dry_run: - console.print("\n[bold yellow]Dry run complete.[/] No changes were made to the database.") + if dry_run: + console.print("\n[bold yellow]Dry run complete.[/] No changes were made to the database.") + elif reverted_count: + self._record_command_metric("applied", float(reverted_count)) async def stamp(self, revision: str) -> None: """Mark database as being at a specific revision without running migrations. diff --git a/sqlspec/migrations/runner.py b/sqlspec/migrations/runner.py index 0d5ef2d7e..d7b726a6d 100644 --- a/sqlspec/migrations/runner.py +++ b/sqlspec/migrations/runner.py @@ -23,6 +23,7 @@ from collections.abc import Awaitable, Callable, Coroutine from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase + from sqlspec.observability import ObservabilityRuntime __all__ = ("AsyncMigrationRunner", "SyncMigrationRunner", "create_migration_runner") @@ -62,6 +63,7 @@ def __init__( extension_migrations: "dict[str, Path] | None" = None, context: "MigrationContext | None" = None, extension_configs: "dict[str, dict[str, Any]] | None" = None, + runtime: "ObservabilityRuntime | None" = None, ) -> None: """Initialize the migration runner. @@ -70,12 +72,14 @@ def __init__( extension_migrations: Optional mapping of extension names to their migration paths. context: Optional migration context for Python migrations. extension_configs: Optional mapping of extension names to their configurations. + runtime: Observability runtime shared with command/context consumers. 
""" self.migrations_path = migrations_path self.extension_migrations = extension_migrations or {} from sqlspec.loader import SQLFileLoader - self.loader = SQLFileLoader() + self.runtime = runtime + self.loader = SQLFileLoader(runtime=runtime) self.project_root: Path | None = None self.context = context self.extension_configs = extension_configs or {} @@ -84,6 +88,11 @@ def __init__( self._listing_signatures: dict[str, tuple[int, int]] = {} self._metadata_cache: dict[str, _CachedMigrationMetadata] = {} + def _metric(self, name: str, amount: float = 1.0) -> None: + if self.runtime is None: + return + self.runtime.increment_metric(name, amount) + def _iter_directory_entries(self, base_path: Path, extension_name: "str | None") -> "list[_MigrationFileEntry]": """Collect migration files discovered under a base path.""" @@ -173,6 +182,13 @@ def _log_listing_invalidation( len(removed), len(modified), ) + self._metric("migrations.listing.cache_invalidations") + if added: + self._metric("migrations.listing.added", float(len(added))) + if removed: + self._metric("migrations.listing.removed", float(len(removed))) + if modified: + self._metric("migrations.listing.modified", float(len(modified))) def _extract_version(self, filename: str) -> "str | None": """Extract version from filename. @@ -238,6 +254,8 @@ def _load_migration_listing(self) -> "list[tuple[str, Path]]": cached_listing = self._listing_cache if cached_listing is not None and self._listing_digest == digest: + self._metric("migrations.listing.cache_hit") + self._metric("migrations.listing.files_cached", float(len(cached_listing))) logger.debug("Migration listing cache hit (%d files)", len(cached_listing)) return cached_listing @@ -245,6 +263,9 @@ def _load_migration_listing(self) -> "list[tuple[str, Path]]": previous_digest = self._listing_digest previous_signatures = self._listing_signatures + self._metric("migrations.listing.cache_miss") + self._metric("migrations.listing.files_scanned", float(len(files))) + self._listing_cache = files self._listing_signatures = signatures self._listing_digest = digest @@ -278,11 +299,15 @@ def _load_migration_metadata_common(self, file_path: Path, version: "str | None" and cached_metadata.mtime_ns == stat_result.st_mtime_ns and cached_metadata.size == stat_result.st_size ): + self._metric("migrations.metadata.cache_hit") logger.debug("Migration metadata cache hit: %s", cache_key) metadata = cached_metadata.clone() metadata["file_path"] = file_path return metadata + self._metric("migrations.metadata.cache_miss") + self._metric("migrations.metadata.bytes", float(stat_result.st_size)) + content = file_path.read_text(encoding="utf-8") checksum = self._calculate_checksum(content) if version is None: @@ -436,34 +461,53 @@ def execute_upgrade( """ upgrade_sql_list = self._get_migration_sql(migration, "up") if upgrade_sql_list is None: + self._metric("migrations.upgrade.skipped") return None, 0 if use_transaction is None: config = self.context.config if self.context else None use_transaction = self.should_use_transaction(migration, config) if config else False - start_time = time.time() + runtime = self.runtime + span = None + if runtime is not None: + version = cast("str | None", migration.get("version")) + span = runtime.start_migration_span("upgrade", version=version) + runtime.increment_metric("migrations.upgrade.invocations") - if use_transaction: - try: + start_time = time.perf_counter() + execution_time = 0 + + try: + if use_transaction: driver.begin() for sql_statement in upgrade_sql_list: if 
sql_statement.strip(): driver.execute_script(sql_statement) - execution_time = int((time.time() - start_time) * 1000) + execution_time = int((time.perf_counter() - start_time) * 1000) if on_success: on_success(execution_time) driver.commit() - except Exception: + else: + for sql_statement in upgrade_sql_list: + if sql_statement.strip(): + driver.execute_script(sql_statement) + execution_time = int((time.perf_counter() - start_time) * 1000) + if on_success: + on_success(execution_time) + except Exception as exc: + if use_transaction: driver.rollback() - raise - else: - for sql_statement in upgrade_sql_list: - if sql_statement.strip(): - driver.execute_script(sql_statement) - execution_time = int((time.time() - start_time) * 1000) - if on_success: - on_success(execution_time) + if runtime is not None: + duration_ms = int((time.perf_counter() - start_time) * 1000) + runtime.increment_metric("migrations.upgrade.errors") + runtime.end_migration_span(span, duration_ms=duration_ms, error=exc) + raise + + if runtime is not None: + runtime.increment_metric("migrations.upgrade.applied") + runtime.increment_metric("migrations.upgrade.duration_ms", float(execution_time)) + runtime.end_migration_span(span, duration_ms=execution_time) return None, execution_time @@ -488,34 +532,53 @@ def execute_downgrade( """ downgrade_sql_list = self._get_migration_sql(migration, "down") if downgrade_sql_list is None: + self._metric("migrations.downgrade.skipped") return None, 0 if use_transaction is None: config = self.context.config if self.context else None use_transaction = self.should_use_transaction(migration, config) if config else False - start_time = time.time() + runtime = self.runtime + span = None + if runtime is not None: + version = cast("str | None", migration.get("version")) + span = runtime.start_migration_span("downgrade", version=version) + runtime.increment_metric("migrations.downgrade.invocations") - if use_transaction: - try: + start_time = time.perf_counter() + execution_time = 0 + + try: + if use_transaction: driver.begin() for sql_statement in downgrade_sql_list: if sql_statement.strip(): driver.execute_script(sql_statement) - execution_time = int((time.time() - start_time) * 1000) + execution_time = int((time.perf_counter() - start_time) * 1000) if on_success: on_success(execution_time) driver.commit() - except Exception: + else: + for sql_statement in downgrade_sql_list: + if sql_statement.strip(): + driver.execute_script(sql_statement) + execution_time = int((time.perf_counter() - start_time) * 1000) + if on_success: + on_success(execution_time) + except Exception as exc: + if use_transaction: driver.rollback() - raise - else: - for sql_statement in downgrade_sql_list: - if sql_statement.strip(): - driver.execute_script(sql_statement) - execution_time = int((time.time() - start_time) * 1000) - if on_success: - on_success(execution_time) + if runtime is not None: + duration_ms = int((time.perf_counter() - start_time) * 1000) + runtime.increment_metric("migrations.downgrade.errors") + runtime.end_migration_span(span, duration_ms=duration_ms, error=exc) + raise + + if runtime is not None: + runtime.increment_metric("migrations.downgrade.applied") + runtime.increment_metric("migrations.downgrade.duration_ms", float(execution_time)) + runtime.end_migration_span(span, duration_ms=execution_time) return None, execution_time @@ -658,34 +721,53 @@ async def execute_upgrade( """ upgrade_sql_list = await self._get_migration_sql_async(migration, "up") if upgrade_sql_list is None: + 
self._metric("migrations.upgrade.skipped") return None, 0 if use_transaction is None: config = self.context.config if self.context else None use_transaction = self.should_use_transaction(migration, config) if config else False - start_time = time.time() + runtime = self.runtime + span = None + if runtime is not None: + version = cast("str | None", migration.get("version")) + span = runtime.start_migration_span("upgrade", version=version) + runtime.increment_metric("migrations.upgrade.invocations") - if use_transaction: - try: + start_time = time.perf_counter() + execution_time = 0 + + try: + if use_transaction: await driver.begin() for sql_statement in upgrade_sql_list: if sql_statement.strip(): await driver.execute_script(sql_statement) - execution_time = int((time.time() - start_time) * 1000) + execution_time = int((time.perf_counter() - start_time) * 1000) if on_success: await on_success(execution_time) await driver.commit() - except Exception: + else: + for sql_statement in upgrade_sql_list: + if sql_statement.strip(): + await driver.execute_script(sql_statement) + execution_time = int((time.perf_counter() - start_time) * 1000) + if on_success: + await on_success(execution_time) + except Exception as exc: + if use_transaction: await driver.rollback() - raise - else: - for sql_statement in upgrade_sql_list: - if sql_statement.strip(): - await driver.execute_script(sql_statement) - execution_time = int((time.time() - start_time) * 1000) - if on_success: - await on_success(execution_time) + if runtime is not None: + duration_ms = int((time.perf_counter() - start_time) * 1000) + runtime.increment_metric("migrations.upgrade.errors") + runtime.end_migration_span(span, duration_ms=duration_ms, error=exc) + raise + + if runtime is not None: + runtime.increment_metric("migrations.upgrade.applied") + runtime.increment_metric("migrations.upgrade.duration_ms", float(execution_time)) + runtime.end_migration_span(span, duration_ms=execution_time) return None, execution_time @@ -710,34 +792,53 @@ async def execute_downgrade( """ downgrade_sql_list = await self._get_migration_sql_async(migration, "down") if downgrade_sql_list is None: + self._metric("migrations.downgrade.skipped") return None, 0 if use_transaction is None: config = self.context.config if self.context else None use_transaction = self.should_use_transaction(migration, config) if config else False - start_time = time.time() + runtime = self.runtime + span = None + if runtime is not None: + version = cast("str | None", migration.get("version")) + span = runtime.start_migration_span("downgrade", version=version) + runtime.increment_metric("migrations.downgrade.invocations") - if use_transaction: - try: + start_time = time.perf_counter() + execution_time = 0 + + try: + if use_transaction: await driver.begin() for sql_statement in downgrade_sql_list: if sql_statement.strip(): await driver.execute_script(sql_statement) - execution_time = int((time.time() - start_time) * 1000) + execution_time = int((time.perf_counter() - start_time) * 1000) if on_success: await on_success(execution_time) await driver.commit() - except Exception: + else: + for sql_statement in downgrade_sql_list: + if sql_statement.strip(): + await driver.execute_script(sql_statement) + execution_time = int((time.perf_counter() - start_time) * 1000) + if on_success: + await on_success(execution_time) + except Exception as exc: + if use_transaction: await driver.rollback() - raise - else: - for sql_statement in downgrade_sql_list: - if sql_statement.strip(): - await 
driver.execute_script(sql_statement) - execution_time = int((time.time() - start_time) * 1000) - if on_success: - await on_success(execution_time) + if runtime is not None: + duration_ms = int((time.perf_counter() - start_time) * 1000) + runtime.increment_metric("migrations.downgrade.errors") + runtime.end_migration_span(span, duration_ms=duration_ms, error=exc) + raise + + if runtime is not None: + runtime.increment_metric("migrations.downgrade.applied") + runtime.increment_metric("migrations.downgrade.duration_ms", float(execution_time)) + runtime.end_migration_span(span, duration_ms=execution_time) return None, execution_time @@ -818,6 +919,7 @@ def create_migration_runner( context: "MigrationContext | None", extension_configs: "dict[str, Any]", is_async: "Literal[False]" = False, + runtime: "ObservabilityRuntime | None" = None, ) -> SyncMigrationRunner: ... @@ -828,6 +930,7 @@ def create_migration_runner( context: "MigrationContext | None", extension_configs: "dict[str, Any]", is_async: "Literal[True]", + runtime: "ObservabilityRuntime | None" = None, ) -> AsyncMigrationRunner: ... @@ -837,6 +940,7 @@ def create_migration_runner( context: "MigrationContext | None", extension_configs: "dict[str, Any]", is_async: bool = False, + runtime: "ObservabilityRuntime | None" = None, ) -> "SyncMigrationRunner | AsyncMigrationRunner": """Factory function to create the appropriate migration runner. @@ -846,10 +950,11 @@ def create_migration_runner( context: Migration context. extension_configs: Extension configurations. is_async: Whether to create async or sync runner. + runtime: Observability runtime shared with loaders and execution steps. Returns: Appropriate migration runner instance. """ if is_async: - return AsyncMigrationRunner(migrations_path, extension_migrations, context, extension_configs) - return SyncMigrationRunner(migrations_path, extension_migrations, context, extension_configs) + return AsyncMigrationRunner(migrations_path, extension_migrations, context, extension_configs, runtime=runtime) + return SyncMigrationRunner(migrations_path, extension_migrations, context, extension_configs, runtime=runtime) diff --git a/sqlspec/observability/__init__.py b/sqlspec/observability/__init__.py new file mode 100644 index 000000000..32a82e79a --- /dev/null +++ b/sqlspec/observability/__init__.py @@ -0,0 +1,22 @@ +"""Public observability exports.""" + +from sqlspec.observability._config import ObservabilityConfig, RedactionConfig, StatementObserver, TelemetryConfig +from sqlspec.observability._diagnostics import TelemetryDiagnostics +from sqlspec.observability._dispatcher import LifecycleDispatcher +from sqlspec.observability._observer import StatementEvent, default_statement_observer, format_statement_event +from sqlspec.observability._runtime import ObservabilityRuntime +from sqlspec.observability._spans import SpanManager + +__all__ = ( + "LifecycleDispatcher", + "ObservabilityConfig", + "ObservabilityRuntime", + "RedactionConfig", + "SpanManager", + "StatementEvent", + "StatementObserver", + "TelemetryConfig", + "TelemetryDiagnostics", + "default_statement_observer", + "format_statement_event", +) diff --git a/sqlspec/observability/_config.py b/sqlspec/observability/_config.py new file mode 100644 index 000000000..c57547cee --- /dev/null +++ b/sqlspec/observability/_config.py @@ -0,0 +1,228 @@ +"""Configuration objects for the observability suite.""" + +from collections.abc import Callable, Iterable +from typing import TYPE_CHECKING, Any, cast + +if TYPE_CHECKING: # pragma: no cover - import 
cycle guard + from sqlspec.config import LifecycleConfig + from sqlspec.observability._observer import StatementEvent + + +StatementObserver = Callable[["StatementEvent"], None] + + +class RedactionConfig: + """Controls SQL and parameter redaction before observers run.""" + + __slots__ = ("mask_literals", "mask_parameters", "parameter_allow_list") + + def __init__( + self, + *, + mask_parameters: bool | None = None, + mask_literals: bool | None = None, + parameter_allow_list: tuple[str, ...] | Iterable[str] | None = None, + ) -> None: + self.mask_parameters = mask_parameters + self.mask_literals = mask_literals + self.parameter_allow_list = tuple(parameter_allow_list) if parameter_allow_list is not None else None + + def __hash__(self) -> int: # pragma: no cover - explicit to mirror dataclass behavior + msg = "RedactionConfig objects are mutable and unhashable" + raise TypeError(msg) + + def copy(self) -> "RedactionConfig": + """Return a copy to avoid sharing mutable state.""" + + allow_list = tuple(self.parameter_allow_list) if self.parameter_allow_list else None + return RedactionConfig( + mask_parameters=self.mask_parameters, mask_literals=self.mask_literals, parameter_allow_list=allow_list + ) + + def __repr__(self) -> str: + return f"RedactionConfig(mask_parameters={self.mask_parameters!r}, mask_literals={self.mask_literals!r}, parameter_allow_list={self.parameter_allow_list!r})" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, RedactionConfig): + return NotImplemented + return ( + self.mask_parameters == other.mask_parameters + and self.mask_literals == other.mask_literals + and self.parameter_allow_list == other.parameter_allow_list + ) + + +class TelemetryConfig: + """Span emission and tracer provider settings.""" + + __slots__ = ("enable_spans", "provider_factory", "resource_attributes") + + def __init__( + self, + *, + enable_spans: bool = False, + provider_factory: Callable[[], Any] | None = None, + resource_attributes: dict[str, Any] | None = None, + ) -> None: + self.enable_spans = enable_spans + self.provider_factory = provider_factory + self.resource_attributes = dict(resource_attributes) if resource_attributes else None + + def __hash__(self) -> int: # pragma: no cover - explicit to mirror dataclass behavior + msg = "TelemetryConfig objects are mutable and unhashable" + raise TypeError(msg) + + def copy(self) -> "TelemetryConfig": + """Return a shallow copy preserving optional dictionaries.""" + + attributes = dict(self.resource_attributes) if self.resource_attributes else None + return TelemetryConfig( + enable_spans=self.enable_spans, provider_factory=self.provider_factory, resource_attributes=attributes + ) + + def __repr__(self) -> str: + return f"TelemetryConfig(enable_spans={self.enable_spans!r}, provider_factory={self.provider_factory!r}, resource_attributes={self.resource_attributes!r})" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TelemetryConfig): + return NotImplemented + return ( + self.enable_spans == other.enable_spans + and self.provider_factory == other.provider_factory + and self.resource_attributes == other.resource_attributes + ) + + +class ObservabilityConfig: + """Aggregates lifecycle hooks, observers, and telemetry toggles.""" + + __slots__ = ("lifecycle", "print_sql", "redaction", "statement_observers", "telemetry") + + def __init__( + self, + *, + lifecycle: "LifecycleConfig | None" = None, + print_sql: bool | None = None, + statement_observers: tuple[StatementObserver, ...] 
| Iterable[StatementObserver] | None = None, + telemetry: "TelemetryConfig | None" = None, + redaction: "RedactionConfig | None" = None, + ) -> None: + self.lifecycle = lifecycle + self.print_sql = print_sql + self.statement_observers = tuple(statement_observers) if statement_observers is not None else None + self.telemetry = telemetry + self.redaction = redaction + + def __hash__(self) -> int: # pragma: no cover - explicit to mirror dataclass behavior + msg = "ObservabilityConfig objects are mutable and unhashable" + raise TypeError(msg) + + def copy(self) -> "ObservabilityConfig": + """Return a deep copy of the configuration.""" + + lifecycle_copy = _normalize_lifecycle(self.lifecycle) + observers = tuple(self.statement_observers) if self.statement_observers else None + telemetry_copy = self.telemetry.copy() if self.telemetry else None + redaction_copy = self.redaction.copy() if self.redaction else None + return ObservabilityConfig( + lifecycle=lifecycle_copy, + print_sql=self.print_sql, + statement_observers=observers, + telemetry=telemetry_copy, + redaction=redaction_copy, + ) + + @classmethod + def merge( + cls, base_config: "ObservabilityConfig | None", override_config: "ObservabilityConfig | None" + ) -> "ObservabilityConfig": + """Merge registry-level and adapter-level configuration objects.""" + + if base_config is None and override_config is None: + return cls() + + base = base_config.copy() if base_config else cls() + override = override_config + if override is None: + return base + + lifecycle = _merge_lifecycle(base.lifecycle, override.lifecycle) + observers: tuple[StatementObserver, ...] | None + if base.statement_observers and override.statement_observers: + observers = base.statement_observers + tuple(override.statement_observers) + elif override.statement_observers: + observers = tuple(override.statement_observers) + else: + observers = base.statement_observers + + print_sql = base.print_sql + if override.print_sql is not None: + print_sql = override.print_sql + + telemetry = override.telemetry.copy() if override.telemetry else base.telemetry + redaction = _merge_redaction(base.redaction, override.redaction) + + return ObservabilityConfig( + lifecycle=lifecycle, + print_sql=print_sql, + statement_observers=observers, + telemetry=telemetry, + redaction=redaction, + ) + + def __repr__(self) -> str: + return f"ObservabilityConfig(lifecycle={self.lifecycle!r}, print_sql={self.print_sql!r}, statement_observers={self.statement_observers!r}, telemetry={self.telemetry!r}, redaction={self.redaction!r})" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ObservabilityConfig): + return NotImplemented + return ( + _normalize_lifecycle(self.lifecycle) == _normalize_lifecycle(other.lifecycle) + and self.print_sql == other.print_sql + and self.statement_observers == other.statement_observers + and self.telemetry == other.telemetry + and self.redaction == other.redaction + ) + + +def _merge_redaction(base: "RedactionConfig | None", override: "RedactionConfig | None") -> "RedactionConfig | None": + if base is None and override is None: + return None + if override is None: + return base.copy() if base else None + if base is None: + return override.copy() + merged = base.copy() + if override.mask_parameters is not None: + merged.mask_parameters = override.mask_parameters + if override.mask_literals is not None: + merged.mask_literals = override.mask_literals + if override.parameter_allow_list is not None: + merged.parameter_allow_list = 
tuple(override.parameter_allow_list) + return merged + + +def _normalize_lifecycle(config: "LifecycleConfig | None") -> "LifecycleConfig | None": + if config is None: + return None + normalized: dict[str, list[Any]] = {} + for event, hooks in config.items(): + normalized[event] = list(cast("Iterable[Any]", hooks)) + return cast("LifecycleConfig", normalized) + + +def _merge_lifecycle(base: "LifecycleConfig | None", override: "LifecycleConfig | None") -> "LifecycleConfig | None": + if base is None and override is None: + return None + if base is None: + return _normalize_lifecycle(override) + if override is None: + return _normalize_lifecycle(base) + merged_dict: dict[str, list[Any]] = cast("dict[str, list[Any]]", _normalize_lifecycle(base)) or {} + for event, hooks in override.items(): + merged_dict.setdefault(event, []) + merged_dict[event].extend(cast("Iterable[Any]", hooks)) + return cast("LifecycleConfig", merged_dict) + + +__all__ = ("ObservabilityConfig", "RedactionConfig", "StatementObserver", "TelemetryConfig") diff --git a/sqlspec/observability/_diagnostics.py b/sqlspec/observability/_diagnostics.py new file mode 100644 index 000000000..e2fdbd528 --- /dev/null +++ b/sqlspec/observability/_diagnostics.py @@ -0,0 +1,66 @@ +"""Diagnostics aggregation utilities for observability exports.""" + +from collections import defaultdict +from collections.abc import Iterable +from typing import Any + +from sqlspec.storage.pipeline import StorageDiagnostics, get_recent_storage_events, get_storage_bridge_diagnostics + + +class TelemetryDiagnostics: + """Aggregates lifecycle counters, custom metrics, and storage telemetry.""" + + __slots__ = ("_lifecycle_sections", "_metrics") + + def __init__(self) -> None: + self._lifecycle_sections: list[tuple[str, dict[str, int]]] = [] + self._metrics: StorageDiagnostics = {} + + def add_lifecycle_snapshot(self, config_key: str, counters: dict[str, int]) -> None: + """Store lifecycle counters for later snapshot generation.""" + + if not counters: + return + self._lifecycle_sections.append((config_key, counters)) + + def add_metric_snapshot(self, metrics: StorageDiagnostics) -> None: + """Store custom metric snapshots.""" + + for key, value in metrics.items(): + if key in self._metrics: + self._metrics[key] += value + else: + self._metrics[key] = value + + def snapshot(self) -> "dict[str, Any]": + """Return aggregated diagnostics payload.""" + + def _zero() -> float: + return 0.0 + + numeric_payload: defaultdict[str, float] = defaultdict(_zero) + for key, value in get_storage_bridge_diagnostics().items(): + numeric_payload[key] = float(value) + for _prefix, counters in self._lifecycle_sections: + for metric, value in counters.items(): + numeric_payload[metric] += float(value) + for metric, value in self._metrics.items(): + numeric_payload[metric] += float(value) + + payload: dict[str, Any] = dict(numeric_payload) + recent_jobs = get_recent_storage_events() + if recent_jobs: + payload["storage_bridge.recent_jobs"] = recent_jobs + return payload + + +def collect_diagnostics(sections: Iterable[tuple[str, dict[str, int]]]) -> dict[str, Any]: + """Convenience helper for aggregating sections without constructing a class.""" + + diag = TelemetryDiagnostics() + for prefix, counters in sections: + diag.add_lifecycle_snapshot(prefix, counters) + return diag.snapshot() + + +__all__ = ("TelemetryDiagnostics", "collect_diagnostics") diff --git a/sqlspec/observability/_dispatcher.py b/sqlspec/observability/_dispatcher.py new file mode 100644 index 000000000..2583167d4 --- 
/dev/null +++ b/sqlspec/observability/_dispatcher.py @@ -0,0 +1,129 @@ +"""Lifecycle dispatcher used by drivers and registry hooks.""" + +from typing import TYPE_CHECKING, Any, Literal + +from sqlspec.utils.logging import get_logger + +if TYPE_CHECKING: + from collections.abc import Iterable + +logger = get_logger("sqlspec.observability.lifecycle") + +LifecycleEvent = Literal[ + "on_pool_create", + "on_pool_destroy", + "on_connection_create", + "on_connection_destroy", + "on_session_start", + "on_session_end", + "on_query_start", + "on_query_complete", + "on_error", +] +EVENT_ATTRS: tuple[LifecycleEvent, ...] = ( + "on_pool_create", + "on_pool_destroy", + "on_connection_create", + "on_connection_destroy", + "on_session_start", + "on_session_end", + "on_query_start", + "on_query_complete", + "on_error", +) +GUARD_ATTRS = tuple(f"has_{name[3:]}" for name in EVENT_ATTRS) + + +class LifecycleDispatcher: + """Dispatches lifecycle hooks with guard flags and diagnostics counters.""" + + __slots__ = ("_hooks", "_counters", *GUARD_ATTRS) + + def __init__(self, hooks: "dict[str, Iterable[Any]] | None" = None) -> None: + normalized: dict[LifecycleEvent, tuple[Any, ...]] = {} + for event_name, guard_attr in zip(EVENT_ATTRS, GUARD_ATTRS, strict=False): + callables = hooks.get(event_name) if hooks else None + normalized[event_name] = tuple(callables) if callables else () + setattr(self, guard_attr, bool(normalized[event_name])) + self._hooks: dict[LifecycleEvent, tuple[Any, ...]] = normalized + self._counters: dict[LifecycleEvent, int] = dict.fromkeys(EVENT_ATTRS, 0) + + @property + def is_enabled(self) -> bool: + """Return True when at least one hook is registered.""" + + return any(self._hooks[name] for name in EVENT_ATTRS) + + def emit_pool_create(self, context: "dict[str, Any]") -> None: + """Fire pool creation hooks.""" + + self._emit("on_pool_create", context) + + def emit_pool_destroy(self, context: "dict[str, Any]") -> None: + """Fire pool destruction hooks.""" + + self._emit("on_pool_destroy", context) + + def emit_connection_create(self, context: "dict[str, Any]") -> None: + """Fire connection creation hooks.""" + + self._emit("on_connection_create", context) + + def emit_connection_destroy(self, context: "dict[str, Any]") -> None: + """Fire connection teardown hooks.""" + + self._emit("on_connection_destroy", context) + + def emit_session_start(self, context: "dict[str, Any]") -> None: + """Fire session start hooks.""" + + self._emit("on_session_start", context) + + def emit_session_end(self, context: "dict[str, Any]") -> None: + """Fire session end hooks.""" + + self._emit("on_session_end", context) + + def emit_query_start(self, context: "dict[str, Any]") -> None: + """Fire query start hooks.""" + + self._emit("on_query_start", context) + + def emit_query_complete(self, context: "dict[str, Any]") -> None: + """Fire query completion hooks.""" + + self._emit("on_query_complete", context) + + def emit_error(self, context: "dict[str, Any]") -> None: + """Fire error hooks with failure context.""" + + self._emit("on_error", context) + + def snapshot(self, *, prefix: str | None = None) -> "dict[str, int]": + """Return counter snapshot keyed for diagnostics export.""" + + metrics: dict[str, int] = {} + for event_name, count in self._counters.items(): + key = event_name.replace("on_", "lifecycle.") + if prefix: + key = f"{prefix}.{key}" + metrics[key] = count + return metrics + + def _emit(self, event: LifecycleEvent, context: "dict[str, Any]") -> None: + callbacks = self._hooks.get(event) + if not 
callbacks: + return + self._counters[event] += 1 + for callback in callbacks: + self._invoke_callback(callback, context, event) + + @staticmethod + def _invoke_callback(callback: Any, context: "dict[str, Any]", event: LifecycleEvent) -> None: + try: + callback(context) + except Exception as exc: # pragma: no cover - defensive logging + logger.warning("Lifecycle hook failed: event=%s error=%s", event, exc) + + +__all__ = ("LifecycleDispatcher",) diff --git a/sqlspec/observability/_observer.py b/sqlspec/observability/_observer.py new file mode 100644 index 000000000..08b9f5eb8 --- /dev/null +++ b/sqlspec/observability/_observer.py @@ -0,0 +1,180 @@ +"""Statement observer primitives for SQL execution events.""" + +from collections.abc import Callable +from time import time +from typing import Any + +from sqlspec.utils.logging import get_logger + +__all__ = ("StatementEvent", "create_event", "default_statement_observer", "format_statement_event") + + +logger = get_logger("sqlspec.observability") + + +StatementObserver = Callable[["StatementEvent"], None] + + +class StatementEvent: + """Structured payload describing a SQL execution.""" + + __slots__ = ( + "adapter", + "bind_key", + "correlation_id", + "driver", + "duration_s", + "execution_mode", + "is_many", + "is_script", + "operation", + "parameters", + "rows_affected", + "sql", + "started_at", + "storage_backend", + ) + + def __init__( + self, + *, + sql: str, + parameters: Any, + driver: str, + adapter: str, + bind_key: "str | None", + operation: str, + execution_mode: "str | None", + is_many: bool, + is_script: bool, + rows_affected: "int | None", + duration_s: float, + started_at: float, + correlation_id: "str | None", + storage_backend: "str | None", + ) -> None: + self.sql = sql + self.parameters = parameters + self.driver = driver + self.adapter = adapter + self.bind_key = bind_key + self.operation = operation + self.execution_mode = execution_mode + self.is_many = is_many + self.is_script = is_script + self.rows_affected = rows_affected + self.duration_s = duration_s + self.started_at = started_at + self.correlation_id = correlation_id + self.storage_backend = storage_backend + + def __hash__(self) -> int: # pragma: no cover - explicit to mirror dataclass behavior + msg = "StatementEvent objects are mutable and unhashable" + raise TypeError(msg) + + def as_dict(self) -> "dict[str, Any]": + """Return event payload as a dictionary.""" + + return { + "sql": self.sql, + "parameters": self.parameters, + "driver": self.driver, + "adapter": self.adapter, + "bind_key": self.bind_key, + "operation": self.operation, + "execution_mode": self.execution_mode, + "is_many": self.is_many, + "is_script": self.is_script, + "rows_affected": self.rows_affected, + "duration_s": self.duration_s, + "started_at": self.started_at, + "correlation_id": self.correlation_id, + "storage_backend": self.storage_backend, + } + + def __repr__(self) -> str: + return ( + f"StatementEvent(sql={self.sql!r}, parameters={self.parameters!r}, driver={self.driver!r}, adapter={self.adapter!r}, bind_key={self.bind_key!r}, " + f"operation={self.operation!r}, execution_mode={self.execution_mode!r}, is_many={self.is_many!r}, is_script={self.is_script!r}, rows_affected={self.rows_affected!r}, " + f"duration_s={self.duration_s!r}, started_at={self.started_at!r}, correlation_id={self.correlation_id!r}, storage_backend={self.storage_backend!r})" + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, StatementEvent): + return NotImplemented + return ( + self.sql == 
other.sql + and self.parameters == other.parameters + and self.driver == other.driver + and self.adapter == other.adapter + and self.bind_key == other.bind_key + and self.operation == other.operation + and self.execution_mode == other.execution_mode + and self.is_many == other.is_many + and self.is_script == other.is_script + and self.rows_affected == other.rows_affected + and self.duration_s == other.duration_s + and self.started_at == other.started_at + and self.correlation_id == other.correlation_id + and self.storage_backend == other.storage_backend + ) + + +def format_statement_event(event: StatementEvent) -> str: + """Create a concise human-readable representation of a statement event.""" + + classification = [] + if event.is_script: + classification.append("script") + if event.is_many: + classification.append("many") + mode_label = ",".join(classification) if classification else "single" + rows_label = "rows=%s" % (event.rows_affected if event.rows_affected is not None else "unknown") + duration_label = f"{event.duration_s:.6f}s" + return ( + f"[{event.driver}] {event.operation} ({mode_label}, {rows_label}, duration={duration_label})\n" + f"SQL: {event.sql}\nParameters: {event.parameters}" + ) + + +def default_statement_observer(event: StatementEvent) -> None: + """Log statement execution payload when no custom observer is supplied.""" + + logger.info(format_statement_event(event), extra={"correlation_id": event.correlation_id}) + + +def create_event( + *, + sql: str, + parameters: Any, + driver: str, + adapter: str, + bind_key: "str | None", + operation: str, + execution_mode: "str | None", + is_many: bool, + is_script: bool, + rows_affected: "int | None", + duration_s: float, + correlation_id: "str | None", + storage_backend: "str | None" = None, + started_at: float | None = None, +) -> StatementEvent: + """Factory helper used by runtime to build statement events.""" + + return StatementEvent( + sql=sql, + parameters=parameters, + driver=driver, + adapter=adapter, + bind_key=bind_key, + operation=operation, + execution_mode=execution_mode, + is_many=is_many, + is_script=is_script, + rows_affected=rows_affected, + duration_s=duration_s, + started_at=started_at if started_at is not None else time(), + correlation_id=correlation_id, + storage_backend=storage_backend, + ) diff --git a/sqlspec/observability/_runtime.py b/sqlspec/observability/_runtime.py new file mode 100644 index 000000000..e14bfd569 --- /dev/null +++ b/sqlspec/observability/_runtime.py @@ -0,0 +1,381 @@ +"""Runtime helpers that bundle lifecycle, observer, and span orchestration.""" + +import re +from typing import TYPE_CHECKING, Any, cast + +from sqlspec.observability._config import ObservabilityConfig +from sqlspec.observability._dispatcher import LifecycleDispatcher +from sqlspec.observability._observer import StatementObserver, create_event, default_statement_observer +from sqlspec.observability._spans import SpanManager +from sqlspec.utils.correlation import CorrelationContext + +_LITERAL_PATTERN = re.compile(r"'(?:''|[^'])*'") + +if TYPE_CHECKING: + from collections.abc import Iterable + + from sqlspec.storage import StorageTelemetry + + +class ObservabilityRuntime: + """Aggregates dispatchers, observers, spans, and custom metrics.""" + + __slots__ = ( + "_metrics", + "_redaction", + "_statement_observers", + "bind_key", + "config", + "config_name", + "lifecycle", + "span_manager", + ) + + def __init__( + self, config: ObservabilityConfig | None = None, *, bind_key: str | None = None, config_name: str | None = 
None + ) -> None: + config = config.copy() if config else ObservabilityConfig() + self.config = config + self.bind_key = bind_key + self.config_name = config_name or "SQLSpecConfig" + lifecycle_config = cast("dict[str, Iterable[Any]] | None", config.lifecycle) + self.lifecycle = LifecycleDispatcher(lifecycle_config) + self.span_manager = SpanManager(config.telemetry) + observers: list[StatementObserver] = [] + if config.statement_observers: + observers.extend(config.statement_observers) + if config.print_sql: + observers.append(default_statement_observer) + self._statement_observers = tuple(observers) + self._redaction = config.redaction.copy() if config.redaction else None + self._metrics: dict[str, float] = {} + + @property + def has_statement_observers(self) -> bool: + """Return True when any observers are registered.""" + + return bool(self._statement_observers) + + @property + def diagnostics_key(self) -> str: + """Derive diagnostics key from bind key or configuration name.""" + + if self.bind_key: + return self.bind_key + return self.config_name + + def base_context(self) -> dict[str, Any]: + """Return the base payload for lifecycle events.""" + + context = {"config": self.config_name} + if self.bind_key: + context["bind_key"] = self.bind_key + correlation_id = CorrelationContext.get() + if correlation_id: + context["correlation_id"] = correlation_id + return context + + def _build_context(self, **extras: Any) -> dict[str, Any]: + context = self.base_context() + context.update({key: value for key, value in extras.items() if value is not None}) + return context + + def lifecycle_snapshot(self) -> dict[str, int]: + """Return lifecycle counters keyed under the diagnostics prefix.""" + + return self.lifecycle.snapshot(prefix=self.diagnostics_key) + + def metrics_snapshot(self) -> dict[str, float]: + """Return accumulated custom metrics with diagnostics prefix.""" + + if not self._metrics: + return {} + prefix = self.diagnostics_key + return {f"{prefix}.{name}": value for name, value in self._metrics.items()} + + def increment_metric(self, name: str, amount: float = 1.0) -> None: + """Increment a custom metric counter.""" + + self._metrics[name] = self._metrics.get(name, 0.0) + amount + + def record_metric(self, name: str, value: float) -> None: + """Set a custom metric to an explicit value.""" + + self._metrics[name] = value + + def start_migration_span( + self, event: str, *, version: "str | None" = None, metadata: "dict[str, Any] | None" = None + ) -> Any: + """Start a migration span when telemetry is enabled.""" + + if not getattr(self.span_manager, "is_enabled", False): + return None + attributes: dict[str, Any] = {"sqlspec.migration.event": event, "sqlspec.config": self.config_name} + if self.bind_key: + attributes["sqlspec.bind_key"] = self.bind_key + correlation_id = CorrelationContext.get() + if correlation_id: + attributes["sqlspec.correlation_id"] = correlation_id + if version: + attributes["sqlspec.migration.version"] = version + if metadata: + for key, value in metadata.items(): + if value is not None: + attributes[f"sqlspec.migration.{key}"] = value + return self.span_manager.start_span(f"sqlspec.migration.{event}", attributes) + + def end_migration_span( + self, span: Any, *, duration_ms: "int | None" = None, error: "Exception | None" = None + ) -> None: + """Finish a migration span, attaching optional duration metadata.""" + + if span is None: + return + setter = getattr(span, "set_attribute", None) + if setter is not None and duration_ms is not None: + 
setter("sqlspec.migration.duration_ms", duration_ms) + self.span_manager.end_span(span, error=error) + + def emit_pool_create(self, pool: Any) -> None: + span = self._start_lifecycle_span("pool.create", subject=pool) + try: + if getattr(self.lifecycle, "has_pool_create", False): + self.lifecycle.emit_pool_create(self._build_context(pool=pool)) + finally: + self.span_manager.end_span(span) + + def emit_pool_destroy(self, pool: Any) -> None: + span = self._start_lifecycle_span("pool.destroy", subject=pool) + try: + if getattr(self.lifecycle, "has_pool_destroy", False): + self.lifecycle.emit_pool_destroy(self._build_context(pool=pool)) + finally: + self.span_manager.end_span(span) + + def emit_connection_create(self, connection: Any) -> None: + span = self._start_lifecycle_span("connection.create", subject=connection) + try: + if getattr(self.lifecycle, "has_connection_create", False): + self.lifecycle.emit_connection_create(self._build_context(connection=connection)) + finally: + self.span_manager.end_span(span) + + def emit_connection_destroy(self, connection: Any) -> None: + span = self._start_lifecycle_span("connection.destroy", subject=connection) + try: + if getattr(self.lifecycle, "has_connection_destroy", False): + self.lifecycle.emit_connection_destroy(self._build_context(connection=connection)) + finally: + self.span_manager.end_span(span) + + def emit_session_start(self, session: Any) -> None: + span = self._start_lifecycle_span("session.start", subject=session) + try: + if getattr(self.lifecycle, "has_session_start", False): + self.lifecycle.emit_session_start(self._build_context(session=session)) + finally: + self.span_manager.end_span(span) + + def emit_session_end(self, session: Any) -> None: + span = self._start_lifecycle_span("session.end", subject=session) + try: + if getattr(self.lifecycle, "has_session_end", False): + self.lifecycle.emit_session_end(self._build_context(session=session)) + finally: + self.span_manager.end_span(span) + + def emit_query_start(self, **extras: Any) -> None: + if getattr(self.lifecycle, "has_query_start", False): + self.lifecycle.emit_query_start(self._build_context(**extras)) + + def emit_query_complete(self, **extras: Any) -> None: + if getattr(self.lifecycle, "has_query_complete", False): + self.lifecycle.emit_query_complete(self._build_context(**extras)) + + def emit_error(self, exception: Exception, **extras: Any) -> None: + if getattr(self.lifecycle, "has_error", False): + payload = self._build_context(exception=exception) + payload.update({key: value for key, value in extras.items() if value is not None}) + self.lifecycle.emit_error(payload) + self.increment_metric("errors", 1.0) + + def emit_statement_event( + self, + *, + sql: str, + parameters: Any, + driver: str, + operation: str, + execution_mode: str | None, + is_many: bool, + is_script: bool, + rows_affected: int | None, + duration_s: float, + storage_backend: str | None, + started_at: float | None = None, + ) -> None: + """Emit a statement event to all registered observers.""" + + if not self._statement_observers: + return + sanitized_sql = self._redact_sql(sql) + sanitized_params = self._redact_parameters(parameters) + correlation_id = CorrelationContext.get() + event = create_event( + sql=sanitized_sql, + parameters=sanitized_params, + driver=driver, + adapter=self.config_name, + bind_key=self.bind_key, + operation=operation, + execution_mode=execution_mode, + is_many=is_many, + is_script=is_script, + rows_affected=rows_affected, + duration_s=duration_s, + 
correlation_id=correlation_id, + storage_backend=storage_backend, + started_at=started_at, + ) + for observer in self._statement_observers: + observer(event) + + def start_query_span(self, sql: str, operation: str, driver: str) -> Any: + """Start a query span with runtime metadata.""" + + correlation_id = CorrelationContext.get() + return self.span_manager.start_query_span( + driver=driver, + adapter=self.config_name, + bind_key=self.bind_key, + sql=sql, + operation=operation, + correlation_id=correlation_id, + ) + + def start_storage_span( + self, operation: str, *, destination: str | None = None, format_label: str | None = None + ) -> Any: + """Start a storage bridge span for read/write operations.""" + + if not getattr(self.span_manager, "is_enabled", False): + return None + attributes: dict[str, Any] = {"sqlspec.storage.operation": operation, "sqlspec.config": self.config_name} + if self.bind_key: + attributes["sqlspec.bind_key"] = self.bind_key + correlation_id = CorrelationContext.get() + if correlation_id: + attributes["sqlspec.correlation_id"] = correlation_id + if destination: + attributes["sqlspec.storage.destination"] = destination + if format_label: + attributes["sqlspec.storage.format"] = format_label + return self.span_manager.start_span(f"sqlspec.storage.{operation}", attributes) + + def start_span(self, name: str, *, attributes: dict[str, Any] | None = None) -> Any: + """Start a custom span enriched with configuration context.""" + + if not getattr(self.span_manager, "is_enabled", False): + return None + merged: dict[str, Any] = attributes.copy() if attributes else {} + merged.setdefault("sqlspec.config", self.config_name) + if self.bind_key: + merged.setdefault("sqlspec.bind_key", self.bind_key) + correlation_id = CorrelationContext.get() + if correlation_id: + merged.setdefault("sqlspec.correlation_id", correlation_id) + return self.span_manager.start_span(name, merged) + + def end_span(self, span: Any, *, error: Exception | None = None) -> None: + """Finish a custom span.""" + + self.span_manager.end_span(span, error=error) + + def end_storage_span( + self, span: Any, *, telemetry: "StorageTelemetry | None" = None, error: Exception | None = None + ) -> None: + """Finish a storage span, attaching telemetry metadata when available.""" + + if span is None: + return + if telemetry: + telemetry = self.annotate_storage_telemetry(telemetry) + self._attach_storage_telemetry(span, telemetry) + self.span_manager.end_span(span, error=error) + + def annotate_storage_telemetry(self, telemetry: "StorageTelemetry") -> "StorageTelemetry": + """Add bind key / config / correlation metadata to telemetry payloads.""" + + annotated = telemetry + base = self.base_context() + correlation_id = base.get("correlation_id") + if correlation_id and not annotated.get("correlation_id"): + annotated["correlation_id"] = correlation_id + annotated.setdefault("config", self.config_name) + if self.bind_key and not annotated.get("bind_key"): + annotated["bind_key"] = self.bind_key + return annotated + + def _start_lifecycle_span(self, event: str, subject: Any | None = None) -> Any: + if not getattr(self.span_manager, "is_enabled", False): + return None + attributes: dict[str, Any] = {"sqlspec.lifecycle.event": event, "sqlspec.config": self.config_name} + if self.bind_key: + attributes["sqlspec.bind_key"] = self.bind_key + correlation_id = CorrelationContext.get() + if correlation_id: + attributes["sqlspec.correlation_id"] = correlation_id + if subject is not None: + 
attributes["sqlspec.lifecycle.subject_type"] = type(subject).__name__ + return self.span_manager.start_span(f"sqlspec.lifecycle.{event}", attributes) + + def _attach_storage_telemetry(self, span: Any, telemetry: "StorageTelemetry") -> None: + setter = getattr(span, "set_attribute", None) + if setter is None: + return + if "backend" in telemetry and telemetry["backend"] is not None: + setter("sqlspec.storage.backend", telemetry["backend"]) + if "bytes_processed" in telemetry and telemetry["bytes_processed"] is not None: + setter("sqlspec.storage.bytes_processed", telemetry["bytes_processed"]) + if "rows_processed" in telemetry and telemetry["rows_processed"] is not None: + setter("sqlspec.storage.rows_processed", telemetry["rows_processed"]) + if "destination" in telemetry and telemetry["destination"] is not None: + setter("sqlspec.storage.destination", telemetry["destination"]) + if "format" in telemetry and telemetry["format"] is not None: + setter("sqlspec.storage.format", telemetry["format"]) + if "duration_s" in telemetry and telemetry["duration_s"] is not None: + setter("sqlspec.storage.duration_s", telemetry["duration_s"]) + if "correlation_id" in telemetry and telemetry["correlation_id"] is not None: + setter("sqlspec.correlation_id", telemetry["correlation_id"]) + + def _redact_sql(self, sql: str) -> str: + config = self._redaction + if config is None or not config.mask_literals: + return sql + return _LITERAL_PATTERN.sub("'***'", sql) + + def _redact_parameters(self, parameters: Any) -> Any: + config = self._redaction + if config is None or not config.mask_parameters: + return parameters + allow_list = set(config.parameter_allow_list or ()) + return _mask_parameters(parameters, allow_list) + + +def _mask_parameters(value: Any, allow_list: set[str]) -> Any: + if isinstance(value, dict): + masked: dict[str, Any] = {} + for key, item in value.items(): + if allow_list and key in allow_list: + masked[key] = _mask_parameters(item, allow_list) + else: + masked[key] = "***" + return masked + if isinstance(value, list): + return [_mask_parameters(item, allow_list) for item in value] + if isinstance(value, tuple): + return tuple(_mask_parameters(item, allow_list) for item in value) + return "***" + + +__all__ = ("ObservabilityRuntime",) diff --git a/sqlspec/observability/_spans.py b/sqlspec/observability/_spans.py new file mode 100644 index 000000000..b54ba0874 --- /dev/null +++ b/sqlspec/observability/_spans.py @@ -0,0 +1,148 @@ +"""Optional OpenTelemetry span helpers.""" + +from importlib import import_module +from typing import Any + +from sqlspec.observability._config import TelemetryConfig +from sqlspec.utils.logging import get_logger + +logger = get_logger("sqlspec.observability.spans") + + +class SpanManager: + """Lazy OpenTelemetry span manager with graceful degradation.""" + + __slots__ = ( + "_enabled", + "_provider_factory", + "_resource_attributes", + "_span_kind", + "_status_cls", + "_status_code_cls", + "_trace_api", + "_tracer", + ) + + def __init__(self, telemetry: TelemetryConfig | None = None) -> None: + telemetry = telemetry or TelemetryConfig() + self._enabled = bool(telemetry.enable_spans) + self._provider_factory = telemetry.provider_factory + self._resource_attributes = dict(telemetry.resource_attributes or {}) + self._trace_api: Any | None = None + self._status_cls: Any | None = None + self._status_code_cls: Any | None = None + self._span_kind: Any | None = None + self._tracer: Any | None = None + if self._enabled: + self._resolve_api() + + @property + def 
is_enabled(self) -> bool: + """Return True once OpenTelemetry spans are available.""" + + return bool(self._enabled and self._tracer) + + def start_query_span( + self, + *, + driver: str, + adapter: str, + bind_key: str | None, + sql: str, + operation: str, + connection_info: dict[str, Any] | None = None, + storage_backend: str | None = None, + correlation_id: str | None = None, + ) -> Any: + """Start a query span with SQLSpec semantic attributes.""" + + if not self._enabled: + return None + attributes: dict[str, Any] = { + "db.system": adapter.lower(), + "db.operation": operation, + "db.statement": sql, + "sqlspec.driver": driver, + } + if bind_key: + attributes["sqlspec.bind_key"] = bind_key + if storage_backend: + attributes["sqlspec.storage_backend"] = storage_backend + if correlation_id: + attributes["sqlspec.correlation_id"] = correlation_id + if connection_info: + attributes.update(connection_info) + attributes.update(self._resource_attributes) + return self._start_span("sqlspec.query", attributes) + + def start_span(self, name: str, attributes: dict[str, Any] | None = None) -> Any: + """Start a generic span when instrumentation needs a custom name.""" + + if not self._enabled: + return None + merged = dict(self._resource_attributes) + if attributes: + merged.update(attributes) + return self._start_span(name, merged) + + def end_span(self, span: Any, error: Exception | None = None) -> None: + """Close a span and record errors when provided.""" + + if span is None: + return + try: + if error and self._status_cls and self._status_code_cls: + span.record_exception(error) + status = self._status_cls(self._status_code_cls.ERROR, str(error)) + span.set_status(status) + span.end() + except Exception as exc: # pragma: no cover - defensive logging + logger.debug("Failed to finish span: %s", exc) + + def _start_span(self, name: str, attributes: dict[str, Any]) -> Any: + tracer = self._get_tracer() + if tracer is None: + return None + span_kind = self._span_kind + if span_kind is None: + return tracer.start_span(name=name, attributes=attributes) + return tracer.start_span(name=name, attributes=attributes, kind=span_kind) + + def _get_tracer(self) -> Any: + if not self._enabled: + return None + if self._tracer is None: + self._resolve_api() + return self._tracer + + def _resolve_api(self) -> None: + try: + trace = import_module("opentelemetry.trace") + status_module = import_module("opentelemetry.trace.status") + except ImportError: + logger.debug("OpenTelemetry dependency missing - disabling spans") + self._enabled = False + self._tracer = None + return + + span_kind_cls = trace.SpanKind + status_cls = status_module.Status + status_code_cls = status_module.StatusCode + + provider = None + if self._provider_factory is not None: + try: + provider = self._provider_factory() + except Exception as exc: # pragma: no cover - defensive logging + logger.debug("Tracer provider factory failed: %s", exc) + if provider and hasattr(provider, "get_tracer"): + self._tracer = provider.get_tracer("sqlspec.observability") + else: + self._tracer = trace.get_tracer("sqlspec.observability") + self._trace_api = trace + self._status_cls = status_cls + self._status_code_cls = status_code_cls + self._span_kind = span_kind_cls.CLIENT + + +__all__ = ("SpanManager",) diff --git a/sqlspec/storage/pipeline.py b/sqlspec/storage/pipeline.py index 98b179651..9d21559f6 100644 --- a/sqlspec/storage/pipeline.py +++ b/sqlspec/storage/pipeline.py @@ -1,5 +1,6 @@ """Storage pipeline scaffolding for driver-aware storage bridge.""" 
+from collections import deque from functools import partial from pathlib import Path from time import perf_counter, time @@ -28,18 +29,23 @@ "StorageBridgeJob", "StorageCapabilities", "StorageDestination", + "StorageDiagnostics", "StorageFormat", "StorageLoadRequest", "StorageTelemetry", "SyncStoragePipeline", "create_storage_bridge_job", + "get_recent_storage_events", "get_storage_bridge_diagnostics", "get_storage_bridge_metrics", + "record_storage_diagnostic_event", + "reset_storage_bridge_events", "reset_storage_bridge_metrics", ) StorageFormat = Literal["jsonl", "json", "parquet", "arrow-ipc"] StorageDestination: TypeAlias = str | Path +StorageDiagnostics: TypeAlias = dict[str, float] class StorageCapabilities(TypedDict): @@ -96,6 +102,9 @@ class StorageTelemetry(TypedDict, total=False): format: str extra: "dict[str, Any]" backend: str + correlation_id: str + config: str + bind_key: str class StorageBridgeJob(NamedTuple): @@ -131,6 +140,7 @@ def reset(self) -> None: _METRICS = _StorageBridgeMetrics() +_RECENT_STORAGE_EVENTS: "deque[StorageTelemetry]" = deque(maxlen=25) def get_storage_bridge_metrics() -> "dict[str, int]": @@ -145,19 +155,39 @@ def reset_storage_bridge_metrics() -> None: _METRICS.reset() +def record_storage_diagnostic_event(telemetry: StorageTelemetry) -> None: + """Record telemetry for inclusion in diagnostics snapshots.""" + + _RECENT_STORAGE_EVENTS.append(cast("StorageTelemetry", dict(telemetry))) + + +def get_recent_storage_events() -> "list[StorageTelemetry]": + """Return recent storage telemetry events (most recent first).""" + + return [cast("StorageTelemetry", dict(entry)) for entry in _RECENT_STORAGE_EVENTS] + + +def reset_storage_bridge_events() -> None: + """Clear recorded storage telemetry events.""" + + _RECENT_STORAGE_EVENTS.clear() + + def create_storage_bridge_job(status: str, telemetry: StorageTelemetry) -> StorageBridgeJob: """Create a storage bridge job handle with a unique identifier.""" - return StorageBridgeJob(job_id=str(uuid4()), status=status, telemetry=telemetry) + job = StorageBridgeJob(job_id=str(uuid4()), status=status, telemetry=telemetry) + record_storage_diagnostic_event(job.telemetry) + return job -def get_storage_bridge_diagnostics() -> "dict[str, int]": +def get_storage_bridge_diagnostics() -> "StorageDiagnostics": """Return aggregated storage bridge + serializer cache metrics.""" - diagnostics = dict(get_storage_bridge_metrics()) + diagnostics: dict[str, float] = {key: float(value) for key, value in get_storage_bridge_metrics().items()} serializer_metrics = get_serializer_metrics() for key, value in serializer_metrics.items(): - diagnostics[f"serializer.{key}"] = value + diagnostics[f"serializer.{key}"] = float(value) return diagnostics diff --git a/tests/integration/test_adapters/test_aiosqlite/test_connection.py b/tests/integration/test_adapters/test_aiosqlite/test_connection.py index c1605df04..86bfdca38 100644 --- a/tests/integration/test_adapters/test_aiosqlite/test_connection.py +++ b/tests/integration/test_adapters/test_aiosqlite/test_connection.py @@ -4,12 +4,16 @@ from __future__ import annotations from pathlib import Path +from typing import Any, cast from uuid import uuid4 import pytest +from sqlspec import SQLSpec from sqlspec.adapters.aiosqlite import AiosqliteConfig, AiosqliteDriver +from sqlspec.config import LifecycleConfig from sqlspec.core import SQLResult +from sqlspec.observability import ObservabilityConfig pytestmark = pytest.mark.xdist_group("sqlite") @@ -180,6 +184,44 @@ async def 
test_config_with_kwargs_override(tmp_path: Path) -> None: await config.close_pool() +async def test_aiosqlite_disabled_observability_has_zero_counts() -> None: + """Lifecycle counters remain zero when observability hooks are disabled.""" + + spec = SQLSpec() + config = AiosqliteConfig() + spec.add_config(config) + + async with spec.provide_session(config) as driver: + await driver.execute("SELECT 1") + + runtime = config.get_observability_runtime() + assert all(value == 0 for value in runtime.lifecycle_snapshot().values()) + await config.close_pool() + + +async def test_aiosqlite_observability_hook_tracks_queries() -> None: + """Lifecycle hooks should record query counts in async drivers.""" + + captured: list[dict[str, Any]] = [] + + def hook(ctx: dict[str, Any]) -> None: + captured.append(ctx) + + spec = SQLSpec() + config = AiosqliteConfig( + observability_config=ObservabilityConfig(lifecycle=cast("LifecycleConfig", {"on_query_start": [hook]})) + ) + spec.add_config(config) + + async with spec.provide_session(config) as driver: + await driver.execute("SELECT 1") + + runtime = config.get_observability_runtime() + assert runtime.lifecycle_snapshot()["AiosqliteConfig.lifecycle.query_start"] == 1 + assert captured + await config.close_pool() + + async def test_config_memory_database_conversion() -> None: """Test that :memory: databases are converted to shared memory.""" diff --git a/tests/integration/test_adapters/test_asyncpg/test_extensions/test_adk/test_owner_id_column.py b/tests/integration/test_adapters/test_asyncpg/test_extensions/test_adk/test_owner_id_column.py index cfe06b2fd..d2c884e9f 100644 --- a/tests/integration/test_adapters/test_asyncpg/test_extensions/test_adk/test_owner_id_column.py +++ b/tests/integration/test_adapters/test_asyncpg/test_extensions/test_adk/test_owner_id_column.py @@ -1,13 +1,14 @@ """Tests for AsyncPG ADK store owner_id_column support.""" from collections.abc import AsyncGenerator -from typing import Any +from typing import Any, cast import asyncpg import pytest from sqlspec.adapters.asyncpg import AsyncpgConfig from sqlspec.adapters.asyncpg.adk import AsyncpgADKStore +from sqlspec.config import ADKConfig, ExtensionConfigs pytestmark = [pytest.mark.xdist_group("postgres"), pytest.mark.asyncpg, pytest.mark.integration] @@ -19,11 +20,10 @@ def _make_config_with_owner_id( events_table: str = "adk_events", ) -> AsyncpgConfig: """Helper to create config with ADK extension config.""" - extension_config: dict[str, dict[str, Any]] = { - "adk": {"session_table": session_table, "events_table": events_table} - } + extension_config = cast("ExtensionConfigs", {"adk": {"session_table": session_table, "events_table": events_table}}) + adk_settings = cast("ADKConfig", extension_config["adk"]) if owner_id_column is not None: - extension_config["adk"]["owner_id_column"] = owner_id_column + adk_settings["owner_id_column"] = owner_id_column return AsyncpgConfig( pool_config={ diff --git a/tests/integration/test_adapters/test_duckdb/test_connection.py b/tests/integration/test_adapters/test_duckdb/test_connection.py index 5a12464a4..ff12364fb 100644 --- a/tests/integration/test_adapters/test_duckdb/test_connection.py +++ b/tests/integration/test_adapters/test_duckdb/test_connection.py @@ -5,13 +5,16 @@ import tempfile import time from pathlib import Path -from typing import Any +from typing import Any, cast from uuid import uuid4 import pytest +from sqlspec import SQLSpec from sqlspec.adapters.duckdb import DuckDBConfig, DuckDBConnection +from sqlspec.config import 
LifecycleConfig from sqlspec.core import SQLResult +from sqlspec.observability import ObservabilityConfig pytestmark = pytest.mark.xdist_group("duckdb") @@ -129,16 +132,19 @@ def test_connection_with_hook() -> None: """Test DuckDB connection with connection creation hook.""" hook_executed = False - def connection_hook(conn: DuckDBConnection) -> None: + def connection_hook(connection: DuckDBConnection) -> None: nonlocal hook_executed hook_executed = True - conn.execute("SET threads = 1") + connection.execute("SET threads = 1") config = DuckDBConfig( pool_config={"database": ":memory:"}, driver_features={"on_connection_create": connection_hook} ) - with config.provide_session() as session: + registry = SQLSpec() + registry.add_config(config) + + with registry.provide_session(config) as session: assert hook_executed is True result = session.execute("SELECT current_setting('threads')") @@ -196,6 +202,42 @@ def test_connection_with_logging_settings() -> None: assert result.data[0]["message"] == "logging_test" +def test_duckdb_disabled_observability_has_zero_lifecycle_counts() -> None: + """Ensure lifecycle counters stay zero when no hooks are registered.""" + + registry = SQLSpec() + config = create_permissive_config() + registry.add_config(config) + + with registry.provide_session(config) as session: + session.execute("SELECT 1") + + runtime = config.get_observability_runtime() + assert all(value == 0 for value in runtime.lifecycle_snapshot().values()) + + +def test_duckdb_observability_hook_records_query_counts() -> None: + """Lifecycle hooks should increment counters when configured.""" + + queries: list[dict[str, Any]] = [] + + def hook(context: dict[str, Any]) -> None: + queries.append(context) + + registry = SQLSpec() + config = create_permissive_config( + observability_config=ObservabilityConfig(lifecycle=cast(LifecycleConfig, {"on_query_start": [hook]})) + ) + registry.add_config(config) + + with registry.provide_session(config) as session: + session.execute("SELECT 1") + + runtime = config.get_observability_runtime() + assert runtime.lifecycle_snapshot()["DuckDBConfig.lifecycle.query_start"] == 1 + assert queries, "Lifecycle hook should capture context" + + def test_connection_with_extension_settings() -> None: """Test DuckDB connection with extension-related settings.""" config = create_permissive_config( diff --git a/tests/integration/test_extensions/test_litestar/test_correlation_middleware.py b/tests/integration/test_extensions/test_litestar/test_correlation_middleware.py new file mode 100644 index 000000000..c5c7b59b4 --- /dev/null +++ b/tests/integration/test_extensions/test_litestar/test_correlation_middleware.py @@ -0,0 +1,91 @@ +from typing import Any, cast + +from litestar import Litestar, get +from litestar.testing import TestClient + +from sqlspec import SQLSpec +from sqlspec.adapters.sqlite import SqliteConfig +from sqlspec.config import ExtensionConfigs +from sqlspec.extensions.litestar import SQLSpecPlugin +from sqlspec.utils.correlation import CorrelationContext + + +@get("/correlation") +async def correlation_handler() -> dict[str, str | None]: + return {"correlation_id": CorrelationContext.get()} + + +def _build_app( + *, + enable: bool = True, + header: str | None = None, + headers: list[str] | None = None, + auto_trace_headers: bool | None = None, +) -> Litestar: + extension_config = cast("ExtensionConfigs", {"litestar": {"enable_correlation_middleware": enable}}) + litestar_settings = cast("dict[str, Any]", extension_config["litestar"]) + if header is not None: + 
litestar_settings["correlation_header"] = header + if headers is not None: + litestar_settings["correlation_headers"] = headers + if auto_trace_headers is not None: + litestar_settings["auto_trace_headers"] = auto_trace_headers + + spec = SQLSpec() + spec.add_config(SqliteConfig(pool_config={"database": ":memory:"}, extension_config=extension_config)) + + return Litestar(route_handlers=[correlation_handler], plugins=[SQLSpecPlugin(sqlspec=spec)]) + + +def test_correlation_middleware_uses_default_header() -> None: + app = _build_app() + + with TestClient(app) as client: + response = client.get("/correlation", headers={"X-Request-ID": "abc-123"}) + assert response.json()["correlation_id"] == "abc-123" + + +def test_correlation_middleware_custom_header() -> None: + app = _build_app(header="x-correlation-id") + + with TestClient(app) as client: + response = client.get("/correlation", headers={"X-Correlation-ID": "custom-id"}) + assert response.json()["correlation_id"] == "custom-id" + + +def test_correlation_middleware_can_be_disabled() -> None: + app = _build_app(enable=False) + + with TestClient(app) as client: + response = client.get("/correlation", headers={"X-Request-ID": "should-not-stick"}) + assert response.json()["correlation_id"] is None + + +def test_correlation_middleware_detects_traceparent_header() -> None: + app = _build_app() + traceparent = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01" + + with TestClient(app) as client: + response = client.get("/correlation", headers={"traceparent": traceparent}) + assert response.json()["correlation_id"] == traceparent + + +def test_correlation_middleware_detects_x_cloud_trace_context_header() -> None: + app = _build_app() + header_value = "105445aa7843bc8bf206b120001000/1;o=1" + + with TestClient(app) as client: + response = client.get("/correlation", headers={"X-Cloud-Trace-Context": header_value}) + assert response.json()["correlation_id"] == header_value + + +def test_correlation_middleware_auto_detection_can_be_disabled() -> None: + app = _build_app(header="x-custom-id", auto_trace_headers=False) + traceparent = "00-11111111111111111111111111111111-2222222222222222-01" + + with TestClient(app) as client: + response = client.get("/correlation", headers={"traceparent": traceparent}) + assert response.json()["correlation_id"] != traceparent + + response = client.get("/correlation", headers={"X-Custom-ID": "custom-value"}) + assert response.json()["correlation_id"] == "custom-value" diff --git a/tests/unit/test_adapters/test_extension_config.py b/tests/unit/test_adapters/test_extension_config.py index f032c7f7a..36bca448c 100644 --- a/tests/unit/test_adapters/test_extension_config.py +++ b/tests/unit/test_adapters/test_extension_config.py @@ -1,6 +1,6 @@ """Test extension_config parameter support across all adapters.""" -from typing import Any +from typing import Any, cast import pytest @@ -14,31 +14,38 @@ from sqlspec.adapters.psqlpy import PsqlpyConfig from sqlspec.adapters.psycopg import PsycopgAsyncConfig, PsycopgSyncConfig from sqlspec.adapters.sqlite import SqliteConfig +from sqlspec.config import ExtensionConfigs def test_sqlite_extension_config() -> None: """Test SqliteConfig accepts and stores extension_config.""" - extension_config = {"litestar": {"session_key": "custom_session", "commit_mode": "manual"}} + extension_config = cast( + "ExtensionConfigs", {"litestar": {"session_key": "custom_session", "commit_mode": "manual"}} + ) config = SqliteConfig(pool_config={"database": ":memory:"}, 
extension_config=extension_config) assert config.extension_config == extension_config - assert config.extension_config["litestar"]["session_key"] == "custom_session" + litestar_settings = cast("dict[str, Any]", config.extension_config["litestar"]) + assert litestar_settings["session_key"] == "custom_session" def test_aiosqlite_extension_config() -> None: """Test AiosqliteConfig accepts and stores extension_config.""" - extension_config = {"litestar": {"pool_key": "db_pool", "enable_correlation_middleware": False}} + extension_config = cast( + ExtensionConfigs, {"litestar": {"pool_key": "db_pool", "enable_correlation_middleware": False}} + ) config = AiosqliteConfig(pool_config={"database": ":memory:"}, extension_config=extension_config) assert config.extension_config == extension_config - assert config.extension_config["litestar"]["pool_key"] == "db_pool" + litestar_settings = cast("dict[str, Any]", config.extension_config["litestar"]) + assert litestar_settings["pool_key"] == "db_pool" def test_duckdb_extension_config() -> None: """Test DuckDBConfig accepts and stores extension_config.""" - extension_config = {"litestar": {"connection_key": "duckdb_conn"}} + extension_config = cast("ExtensionConfigs", {"litestar": {"connection_key": "duckdb_conn"}}) config = DuckDBConfig(pool_config={"database": ":memory:"}, extension_config=extension_config) @@ -47,7 +54,7 @@ def test_duckdb_extension_config() -> None: def test_asyncpg_extension_config() -> None: """Test AsyncpgConfig accepts and stores extension_config.""" - extension_config = {"litestar": {"commit_mode": "autocommit"}} + extension_config = cast("ExtensionConfigs", {"litestar": {"commit_mode": "autocommit"}}) config = AsyncpgConfig(pool_config={"host": "localhost", "database": "test"}, extension_config=extension_config) @@ -56,7 +63,7 @@ def test_asyncpg_extension_config() -> None: def test_psycopg_sync_extension_config() -> None: """Test PsycopgSyncConfig accepts and stores extension_config.""" - extension_config = {"litestar": {"session_key": "psycopg_session"}} + extension_config = cast("ExtensionConfigs", {"litestar": {"session_key": "psycopg_session"}}) config = PsycopgSyncConfig(pool_config={"host": "localhost", "dbname": "test"}, extension_config=extension_config) @@ -65,7 +72,7 @@ def test_psycopg_sync_extension_config() -> None: def test_psycopg_async_extension_config() -> None: """Test PsycopgAsyncConfig accepts and stores extension_config.""" - extension_config = {"litestar": {"extra_commit_statuses": {201, 202}}} + extension_config = cast("ExtensionConfigs", {"litestar": {"extra_commit_statuses": {201, 202}}}) config = PsycopgAsyncConfig(pool_config={"host": "localhost", "dbname": "test"}, extension_config=extension_config) @@ -74,7 +81,7 @@ def test_psycopg_async_extension_config() -> None: def test_asyncmy_extension_config() -> None: """Test AsyncmyConfig accepts and stores extension_config.""" - extension_config = {"litestar": {"commit_mode": "autocommit_include_redirect"}} + extension_config = cast("ExtensionConfigs", {"litestar": {"commit_mode": "autocommit_include_redirect"}}) config = AsyncmyConfig(pool_config={"host": "localhost", "database": "test"}, extension_config=extension_config) @@ -83,7 +90,7 @@ def test_asyncmy_extension_config() -> None: def test_psqlpy_extension_config() -> None: """Test PsqlpyConfig accepts and stores extension_config.""" - extension_config = {"litestar": {"extra_rollback_statuses": {400, 500}}} + extension_config = cast("ExtensionConfigs", {"litestar": {"extra_rollback_statuses": {400, 
500}}}) config = PsqlpyConfig(pool_config={"host": "localhost", "db_name": "test"}, extension_config=extension_config) @@ -92,7 +99,7 @@ def test_psqlpy_extension_config() -> None: def test_oracle_sync_extension_config() -> None: """Test OracleSyncConfig accepts and stores extension_config.""" - extension_config = {"litestar": {"enable_correlation_middleware": True}} + extension_config = cast("ExtensionConfigs", {"litestar": {"enable_correlation_middleware": True}}) config = OracleSyncConfig(pool_config={"user": "test", "password": "test"}, extension_config=extension_config) @@ -101,7 +108,7 @@ def test_oracle_sync_extension_config() -> None: def test_oracle_async_extension_config() -> None: """Test OracleAsyncConfig accepts and stores extension_config.""" - extension_config = {"litestar": {"connection_key": "oracle_async"}} + extension_config = cast("ExtensionConfigs", {"litestar": {"connection_key": "oracle_async"}}) config = OracleAsyncConfig(pool_config={"user": "test", "password": "test"}, extension_config=extension_config) @@ -110,7 +117,7 @@ def test_oracle_async_extension_config() -> None: def test_adbc_extension_config() -> None: """Test AdbcConfig accepts and stores extension_config.""" - extension_config = {"litestar": {"session_key": "adbc_session"}} + extension_config = cast("ExtensionConfigs", {"litestar": {"session_key": "adbc_session"}}) config = AdbcConfig( connection_config={"driver_name": "sqlite", "uri": "sqlite://:memory:"}, extension_config=extension_config @@ -121,7 +128,7 @@ def test_adbc_extension_config() -> None: def test_bigquery_extension_config() -> None: """Test BigQueryConfig accepts and stores extension_config.""" - extension_config = {"litestar": {"pool_key": "bigquery_pool"}} + extension_config = cast("ExtensionConfigs", {"litestar": {"pool_key": "bigquery_pool"}}) config = BigQueryConfig(connection_config={"project": "test-project"}, extension_config=extension_config) @@ -152,11 +159,14 @@ def test_extension_config_defaults_to_empty_dict() -> None: def test_extension_config_with_multiple_extensions() -> None: """Test extension_config can hold multiple extension configurations.""" - extension_config: dict[str, dict[str, Any]] = { - "litestar": {"session_key": "db_session", "commit_mode": "manual"}, - "custom_extension": {"setting1": "value1", "setting2": 42}, - "another_ext": {"enabled": True}, - } + extension_config = cast( + ExtensionConfigs, + { + "litestar": {"session_key": "db_session", "commit_mode": "manual"}, + "custom_extension": {"setting1": "value1", "setting2": 42}, + "another_ext": {"enabled": True}, + }, + ) config = SqliteConfig(pool_config={"database": ":memory:"}, extension_config=extension_config) @@ -186,7 +196,7 @@ def test_extension_config_with_multiple_extensions() -> None: ) def test_all_adapters_accept_extension_config(config_class: type, init_kwargs: dict) -> None: """Parameterized test ensuring all adapters accept extension_config.""" - extension_config = {"test_extension": {"test_key": "test_value"}} + extension_config = cast("ExtensionConfigs", {"test_extension": {"test_key": "test_value"}}) config = config_class(**init_kwargs, extension_config=extension_config) diff --git a/tests/unit/test_config/test_observability_extensions.py b/tests/unit/test_config/test_observability_extensions.py new file mode 100644 index 000000000..8faa2764d --- /dev/null +++ b/tests/unit/test_config/test_observability_extensions.py @@ -0,0 +1,49 @@ +"""Tests for extension_config-driven observability hooks.""" + +from typing import Any + +from 
sqlspec.adapters.asyncpg import AsyncpgDriver +from sqlspec.config import NoPoolSyncConfig + + +class _DummySyncConfig(NoPoolSyncConfig[Any, AsyncpgDriver]): + driver_type = AsyncpgDriver + connection_type = object + + def create_connection(self) -> Any: + raise NotImplementedError + + def provide_connection(self, *args: Any, **kwargs: Any): # type: ignore[override] + raise NotImplementedError + + def provide_session(self, *args: Any, **kwargs: Any): # type: ignore[override] + raise NotImplementedError + + +def test_otel_extension_config_enables_spans(monkeypatch): + monkeypatch.setattr("sqlspec.utils.module_loader.OPENTELEMETRY_INSTALLED", True, raising=False) + + config = _DummySyncConfig(extension_config={"otel": {"resource_attributes": {"service.name": "api"}}}) + + assert config.observability_config is not None + telemetry = config.observability_config.telemetry + assert telemetry is not None + assert telemetry.resource_attributes == {"service.name": "api"} + + +def test_prometheus_extension_registers_observer(monkeypatch): + monkeypatch.setattr("sqlspec.utils.module_loader.PROMETHEUS_INSTALLED", True, raising=False) + + config = _DummySyncConfig( + extension_config={"prometheus": {"namespace": "custom", "label_names": ("driver", "operation", "adapter")}} + ) + + assert config.observability_config is not None + observers = config.observability_config.statement_observers + assert observers is not None and observers, "expected prometheus observer to be registered" + + +def test_disabled_extensions_are_ignored(monkeypatch): + monkeypatch.setattr("sqlspec.utils.module_loader.OPENTELEMETRY_INSTALLED", True, raising=False) + config = _DummySyncConfig(extension_config={"otel": {"enabled": False}}) + assert config.observability_config is None diff --git a/tests/unit/test_extensions/test_observability_integrations.py b/tests/unit/test_extensions/test_observability_integrations.py new file mode 100644 index 000000000..329dcceac --- /dev/null +++ b/tests/unit/test_extensions/test_observability_integrations.py @@ -0,0 +1,42 @@ +"""Unit tests for observability helper extensions.""" + +from sqlspec.observability._observer import create_event + + +def test_enable_tracing_sets_telemetry(monkeypatch): + monkeypatch.setattr("sqlspec.utils.module_loader.OPENTELEMETRY_INSTALLED", True, raising=False) + + from sqlspec.extensions import otel + + config = otel.enable_tracing() + assert config.telemetry is not None + assert config.telemetry.enable_spans is True + provider = config.telemetry.provider_factory() if config.telemetry.provider_factory else None + assert provider is not None + + +def test_enable_metrics_registers_observer(monkeypatch): + monkeypatch.setattr("sqlspec.utils.module_loader.PROMETHEUS_INSTALLED", True, raising=False) + + from sqlspec.extensions import prometheus + + config = prometheus.enable_metrics() + assert config.statement_observers is not None + observer = config.statement_observers[-1] + + event = create_event( + sql="SELECT 1", + parameters=(), + driver="TestDriver", + adapter="test", + bind_key=None, + operation="SELECT", + execution_mode="sync", + is_many=False, + is_script=False, + rows_affected=1, + duration_s=0.05, + correlation_id=None, + ) + + observer(event) diff --git a/tests/unit/test_observability.py b/tests/unit/test_observability.py new file mode 100644 index 000000000..5511ca090 --- /dev/null +++ b/tests/unit/test_observability.py @@ -0,0 +1,461 @@ +"""Unit tests for observability helpers.""" + +from collections.abc import Iterable +from contextlib import 
contextmanager, nullcontext +from pathlib import Path +from typing import Any, cast + +from sqlspec import SQLSpec +from sqlspec.adapters.sqlite import SqliteConfig +from sqlspec.config import LifecycleConfig +from sqlspec.core import SQL, ArrowResult, StatementConfig +from sqlspec.driver._sync import SyncDataDictionaryBase, SyncDriverAdapterBase +from sqlspec.observability import ( + LifecycleDispatcher, + ObservabilityConfig, + ObservabilityRuntime, + RedactionConfig, + StatementObserver, +) +from sqlspec.storage import StorageTelemetry +from sqlspec.storage.pipeline import ( + record_storage_diagnostic_event, + reset_storage_bridge_events, + reset_storage_bridge_metrics, +) +from sqlspec.utils.correlation import CorrelationContext + + +def _lifecycle_config(hooks: dict[str, list[Any]]) -> "LifecycleConfig": + return cast("LifecycleConfig", hooks) + + +class _FakeSpan: + def __init__(self, name: str, attributes: dict[str, Any]) -> None: + self.name = name + self.attributes = attributes + self.closed = False + self.exception: Exception | None = None + + def end(self) -> None: + self.closed = True + + def record_exception(self, error: Exception) -> None: + self.exception = error + + def set_attribute(self, name: str, value: Any) -> None: + self.attributes[name] = value + + +class _FakeSpanManager: + def __init__(self) -> None: + self.is_enabled = True + self.started: list[_FakeSpan] = [] + self.finished: list[_FakeSpan] = [] + + def start_span(self, name: str, attributes: dict[str, Any]) -> _FakeSpan: + correlation = attributes.get("correlation_id") + if correlation is not None: + attributes.setdefault("sqlspec.correlation_id", correlation) + span = _FakeSpan(name, dict(attributes)) + self.started.append(span) + return span + + def start_query_span(self, **attributes: Any) -> _FakeSpan: + return self.start_span("sqlspec.query", attributes) + + def end_span(self, span: _FakeSpan | None, error: Exception | None = None) -> None: + if span is None: + return + if error is not None: + span.record_exception(error) + span.end() + self.finished.append(span) + + +class _ArrowResultStub: + def __init__(self) -> None: + self.calls: list[tuple[Any, Any, Any, Any]] = [] + + def write_to_storage_sync( + self, destination: Any, *, format_hint: Any = None, storage_options: Any = None, pipeline: Any = None + ) -> dict[str, Any]: + self.calls.append((destination, format_hint, storage_options, pipeline)) + return { + "destination": str(destination), + "backend": "local", + "bytes_processed": 1, + "rows_processed": 1, + "format": format_hint or "jsonl", + } + + +class _FakeSyncPipeline: + def __init__(self) -> None: + self.calls: list[tuple[Any, Any]] = [] + + def read_arrow(self, source: Any, *, file_format: str, storage_options: Any = None) -> tuple[str, dict[str, Any]]: + _ = storage_options + self.calls.append((source, file_format)) + return ( + "table", + { + "destination": str(source), + "backend": "s3", + "bytes_processed": 10, + "rows_processed": 5, + "format": file_format, + }, + ) + + +class _DummyDictionary(SyncDataDictionaryBase): + def get_version(self, driver: SyncDriverAdapterBase) -> None: + _ = driver + + def get_feature_flag(self, driver: SyncDriverAdapterBase, feature: str) -> bool: + _ = driver, feature + return False + + def get_optimal_type(self, driver: SyncDriverAdapterBase, type_category: str) -> str: + _ = driver, type_category + return "TEXT" + + +class _DummyCursor: + rowcount = 1 + + +class _DummyDriver(SyncDriverAdapterBase): + def __init__(self, *args: Any, **kwargs: Any) -> 
None: + self._dictionary = _DummyDictionary() + super().__init__(*args, **kwargs) + + @property + def data_dictionary(self) -> SyncDataDictionaryBase: + return self._dictionary + + def with_cursor(self, connection: Any): + @contextmanager + def _cursor() -> Any: + yield _DummyCursor() + + return _cursor() + + def handle_database_exceptions(self): + return nullcontext() + + def begin(self) -> None: # pragma: no cover - unused in tests + return None + + def rollback(self) -> None: # pragma: no cover - unused in tests + return None + + def commit(self) -> None: # pragma: no cover - unused in tests + return None + + def _try_special_handling(self, cursor: Any, statement: SQL): + _ = cursor, statement + + def _execute_statement(self, cursor: Any, statement: SQL): + _ = cursor, statement + return self.create_execution_result( + cursor_result=None, + rowcount_override=1, + special_data={}, + selected_data=None, + column_names=None, + data_row_count=None, + statement_count=None, + successful_statements=None, + is_script_result=False, + is_select_result=False, + is_many_result=False, + ) + + def _execute_many(self, cursor: Any, statement: SQL): + _ = cursor, statement + return self.create_execution_result( + cursor_result=None, + rowcount_override=1, + special_data={}, + selected_data=None, + column_names=None, + data_row_count=None, + statement_count=None, + successful_statements=None, + is_script_result=False, + is_select_result=False, + is_many_result=True, + ) + + +def test_observability_config_merge_combines_hooks_and_observers() -> None: + """Merged configs should merge lifecycle hooks and observers.""" + + base = ObservabilityConfig(lifecycle=_lifecycle_config({"on_query_start": [lambda ctx: ctx]}), print_sql=False) + observer_called = [] + + def observer(_event: Any) -> None: + observer_called.append(True) + + override = ObservabilityConfig( + lifecycle=_lifecycle_config({"on_query_start": [lambda ctx: ctx]}), + print_sql=True, + statement_observers=(observer,), + ) + + merged = ObservabilityConfig.merge(base, override) + + assert merged.print_sql is True + assert merged.statement_observers is not None + assert len(merged.statement_observers) == 1 + dispatcher = LifecycleDispatcher(cast("dict[str, Iterable[Any]]", merged.lifecycle)) + assert getattr(dispatcher, "has_query_start") is True + dispatcher.emit_query_start({"foo": "bar"}) + assert observer_called == [] # observers run via runtime, dispatcher unaffected + + +def test_lifecycle_dispatcher_counts_events() -> None: + """Lifecycle dispatcher should count emitted events for diagnostics.""" + + dispatcher = LifecycleDispatcher( + cast("dict[str, Iterable[Any]]", {"on_query_start": [lambda ctx: ctx], "on_query_complete": [lambda ctx: ctx]}) + ) + dispatcher.emit_query_start({}) + dispatcher.emit_query_complete({}) + dispatcher.emit_query_complete({}) + snapshot = dispatcher.snapshot(prefix="test-config") + assert snapshot["test-config.lifecycle.query_start"] == 1 + assert snapshot["test-config.lifecycle.query_complete"] == 2 + + +def test_runtime_statement_event_redaction() -> None: + """Runtime should redact SQL and parameters before notifying observers.""" + + observed: list[dict[str, Any]] = [] + + def observer(event: Any) -> None: + observed.append(event.as_dict()) + + config = ObservabilityConfig( + redaction=RedactionConfig(mask_literals=True, mask_parameters=True), + statement_observers=(cast(StatementObserver, observer),), + ) + runtime = ObservabilityRuntime(config, bind_key="primary", config_name="TestConfig") + + 
runtime.emit_statement_event( + sql="select * from users where email='secret'", + parameters={"email": "secret@example.com"}, + driver="DummyDriver", + operation="SELECT", + execution_mode="single", + is_many=False, + is_script=False, + rows_affected=1, + duration_s=0.01, + storage_backend=None, + ) + + assert observed, "Observer should capture at least one event" + event = observed[0] + assert "'***'" in event["sql"] + assert event["parameters"] == {"email": "***"} + + +def test_runtime_emits_pool_events_with_context() -> None: + """Emit helpers should forward base context to lifecycle hooks.""" + + captured: list[dict[str, Any]] = [] + + def hook(context: dict[str, Any]) -> None: + captured.append(context) + + runtime = ObservabilityRuntime( + ObservabilityConfig(lifecycle=_lifecycle_config({"on_pool_create": [hook], "on_pool_destroy": [hook]})), + bind_key="primary", + config_name="TestConfig", + ) + + runtime.emit_pool_create("pool-obj") + runtime.emit_pool_destroy("pool-obj") + + assert len(captured) == 2 + assert captured[0]["config"] == "TestConfig" + assert captured[0]["bind_key"] == "primary" + assert captured[1]["config"] == "TestConfig" + + +def test_lifecycle_spans_emit_even_without_hooks() -> None: + """Lifecycle emissions should still create spans when no hooks exist.""" + + runtime = ObservabilityRuntime(ObservabilityConfig(), bind_key="primary", config_name="DummyAdapter") + fake_manager = _FakeSpanManager() + runtime.span_manager = cast(Any, fake_manager) + + runtime.emit_connection_create(object()) + runtime.emit_connection_destroy(object()) + + span_names = [span.name for span in fake_manager.finished] + assert "sqlspec.lifecycle.connection.create" in span_names + assert "sqlspec.lifecycle.connection.destroy" in span_names + + +def test_driver_dispatch_records_query_span() -> None: + """Driver dispatch should start and finish query spans.""" + + span_manager = _FakeSpanManager() + runtime = ObservabilityRuntime(ObservabilityConfig(), config_name="DummyAdapter") + runtime.span_manager = cast(Any, span_manager) + + statement_config = StatementConfig() + driver = _DummyDriver(connection=object(), statement_config=statement_config, observability=runtime) + statement = SQL("SELECT 1", statement_config=statement_config) + + with CorrelationContext.context("query-correlation"): + driver.dispatch_statement_execution(statement, driver.connection) + + assert span_manager.started, "Query span should start" + assert span_manager.finished, "Query span should finish" + assert span_manager.started[0].name == "sqlspec.query" + assert span_manager.started[0].attributes["adapter"] == "DummyAdapter" + assert span_manager.started[0].attributes["sqlspec.correlation_id"] == "query-correlation" + assert span_manager.finished[0].closed is True + + +def test_storage_span_records_telemetry_attributes() -> None: + """Storage spans should capture telemetry attributes when ending.""" + + runtime = ObservabilityRuntime(ObservabilityConfig(), config_name="TestConfig") + span_manager = _FakeSpanManager() + runtime.span_manager = cast(Any, span_manager) + span = runtime.start_storage_span("write", destination="alias://foo", format_label="parquet") + telemetry: StorageTelemetry = { + "destination": "alias://foo", + "backend": "s3", + "bytes_processed": 1024, + "rows_processed": 8, + "format": "parquet", + "duration_s": 0.5, + } + runtime.end_storage_span(span, telemetry=telemetry) + + assert span_manager.finished, "Storage span should finish" + assert 
span_manager.finished[0].attributes["sqlspec.storage.backend"] == "s3" + + +def test_write_storage_helper_emits_span() -> None: + """Storage driver helper should wrap sync writes with spans.""" + + runtime = ObservabilityRuntime(ObservabilityConfig(), config_name="DummyAdapter") + span_manager = _FakeSpanManager() + runtime.span_manager = cast(Any, span_manager) + statement_config = StatementConfig() + driver = _DummyDriver(connection=object(), statement_config=statement_config, observability=runtime) + result_stub = _ArrowResultStub() + + with CorrelationContext.context("test-correlation"): + telemetry = driver._write_result_to_storage_sync( # pyright: ignore[reportPrivateUsage] + cast(ArrowResult, result_stub), "alias://bucket/object" + ) + + assert telemetry["backend"] == "local" + assert telemetry["correlation_id"] == "test-correlation" + assert any(span.name == "sqlspec.storage.write" for span in span_manager.finished) + + +def test_read_storage_helper_emits_span() -> None: + """Reading from storage via helper should emit spans and return telemetry.""" + + runtime = ObservabilityRuntime(ObservabilityConfig(), config_name="DummyAdapter") + span_manager = _FakeSpanManager() + runtime.span_manager = cast(Any, span_manager) + statement_config = StatementConfig() + driver = _DummyDriver(connection=object(), statement_config=statement_config, observability=runtime) + pipeline = _FakeSyncPipeline() + driver.storage_pipeline_factory = lambda: pipeline # type: ignore[assignment] + + with CorrelationContext.context("read-correlation"): + _table, telemetry = driver._read_arrow_from_storage_sync( # pyright: ignore[reportPrivateUsage] + "alias://bucket/data", file_format="parquet" + ) + + assert telemetry["backend"] == "s3" + assert telemetry["correlation_id"] == "read-correlation" + assert pipeline.calls, "Pipeline should be invoked" + assert any(span.name == "sqlspec.storage.read" for span in span_manager.finished) + + +def test_telemetry_snapshot_includes_recent_storage_jobs() -> None: + """Telemetry snapshot should surface recent storage jobs with correlation metadata.""" + + reset_storage_bridge_metrics() + reset_storage_bridge_events() + + spec = SQLSpec() + spec.add_config(SqliteConfig(pool_config={"database": ":memory:"})) + + record_storage_diagnostic_event({ + "destination": "alias://bucket/path", + "backend": "s3", + "bytes_processed": 512, + "rows_processed": 8, + "config": "SqliteConfig", + "bind_key": "default", + "correlation_id": "diag-test", + }) + + snapshot = spec.telemetry_snapshot() + recent_jobs = snapshot.get("storage_bridge.recent_jobs") + assert recent_jobs, "Recent storage jobs should be included in diagnostics" + assert recent_jobs[0]["correlation_id"] == "diag-test" + + +def test_telemetry_snapshot_includes_loader_metrics(tmp_path: "Path") -> None: + """Telemetry snapshot should expose loader metric counters after a load.""" + + sql_path = tmp_path / "queries.sql" + sql_path.write_text("-- name: example\nSELECT 1;\n", encoding="utf-8") + + spec = SQLSpec() + spec.load_sql_files(sql_path) + + snapshot = spec.telemetry_snapshot() + assert "SQLFileLoader.loader.load.invocations" in snapshot + assert snapshot["SQLFileLoader.loader.files.loaded"] >= 1 + + +def test_disabled_runtime_avoids_lifecycle_counters() -> None: + """Drivers should skip lifecycle hooks entirely when none are registered.""" + + runtime = ObservabilityRuntime() + statement_config = StatementConfig() + driver = _DummyDriver(connection=object(), statement_config=statement_config, observability=runtime) + 
statement = SQL("SELECT 1", statement_config=statement_config) + + driver.dispatch_statement_execution(statement, driver.connection) + + snapshot = runtime.lifecycle_snapshot() + assert all(value == 0 for value in snapshot.values()) + + +def test_runtime_with_lifecycle_hooks_records_counters() -> None: + """Lifecycle counters should increment when hooks are configured.""" + + captured: list[dict[str, Any]] = [] + + def hook(ctx: dict[str, Any]) -> None: + captured.append(ctx) + + runtime = ObservabilityRuntime( + ObservabilityConfig(lifecycle=_lifecycle_config({"on_query_start": [hook]})), config_name="DummyConfig" + ) + statement_config = StatementConfig() + driver = _DummyDriver(connection=object(), statement_config=statement_config, observability=runtime) + statement = SQL("SELECT 1", statement_config=statement_config) + + driver.dispatch_statement_execution(statement, driver.connection) + + snapshot = runtime.lifecycle_snapshot() + assert snapshot["DummyConfig.lifecycle.query_start"] == 1 + assert captured, "Hook should have been invoked" diff --git a/uv.lock b/uv.lock index c8fd924e6..92335cb5b 100644 --- a/uv.lock +++ b/uv.lock @@ -1689,7 +1689,7 @@ grpc = [ [[package]] name = "google-api-python-client" -version = "2.186.0" +version = "2.187.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-api-core" }, @@ -1698,23 +1698,23 @@ dependencies = [ { name = "httplib2" }, { name = "uritemplate" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/47/cf/d167fec8be9e65768133be83a8d182350195840e14d1c203565383834614/google_api_python_client-2.186.0.tar.gz", hash = "sha256:01b8ff446adbc10f495188400a9f7c3e88e5e75741663a25822f41e788475333", size = 13937230, upload-time = "2025-10-30T22:13:20.971Z" } +sdist = { url = "https://files.pythonhosted.org/packages/75/83/60cdacf139d768dd7f0fcbe8d95b418299810068093fdf8228c6af89bb70/google_api_python_client-2.187.0.tar.gz", hash = "sha256:e98e8e8f49e1b5048c2f8276473d6485febc76c9c47892a8b4d1afa2c9ec8278", size = 14068154, upload-time = "2025-11-06T01:48:53.274Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/21/5a/b00b944eb9cd0f2e39daf3bcce006cb503a89532f507e87e038e04bbea8c/google_api_python_client-2.186.0-py3-none-any.whl", hash = "sha256:2ea4beba93e193d3a632c7bf865b6ccace42b0017269a964566e39b7e1f3cf79", size = 14507868, upload-time = "2025-10-30T22:13:18.426Z" }, + { url = "https://files.pythonhosted.org/packages/96/58/c1e716be1b055b504d80db2c8413f6c6a890a6ae218a65f178b63bc30356/google_api_python_client-2.187.0-py3-none-any.whl", hash = "sha256:d8d0f6d85d7d1d10bdab32e642312ed572bdc98919f72f831b44b9a9cebba32f", size = 14641434, upload-time = "2025-11-06T01:48:50.763Z" }, ] [[package]] name = "google-auth" -version = "2.42.1" +version = "2.43.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cachetools" }, { name = "pyasn1-modules" }, { name = "rsa" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/25/6b/22a77135757c3a7854c9f008ffed6bf4e8851616d77faf13147e9ab5aae6/google_auth-2.42.1.tar.gz", hash = "sha256:30178b7a21aa50bffbdc1ffcb34ff770a2f65c712170ecd5446c4bef4dc2b94e", size = 295541, upload-time = "2025-10-30T16:42:19.381Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ff/ef/66d14cf0e01b08d2d51ffc3c20410c4e134a1548fc246a6081eae585a4fe/google_auth-2.43.0.tar.gz", hash = "sha256:88228eee5fc21b62a1b5fe773ca15e67778cb07dc8363adcb4a8827b52d81483", size = 296359, upload-time = "2025-11-06T00:13:36.587Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/92/05/adeb6c495aec4f9d93f9e2fc29eeef6e14d452bba11d15bdb874ce1d5b10/google_auth-2.42.1-py2.py3-none-any.whl", hash = "sha256:eb73d71c91fc95dbd221a2eb87477c278a355e7367a35c0d84e6b0e5f9b4ad11", size = 222550, upload-time = "2025-10-30T16:42:17.878Z" }, + { url = "https://files.pythonhosted.org/packages/6f/d1/385110a9ae86d91cc14c5282c61fe9f4dc41c0b9f7d423c6ad77038c4448/google_auth-2.43.0-py2.py3-none-any.whl", hash = "sha256:af628ba6fa493f75c7e9dbe9373d148ca9f4399b5ea29976519e0a3848eddd16", size = 223114, upload-time = "2025-11-06T00:13:35.209Z" }, ] [[package]] @@ -1732,7 +1732,7 @@ wheels = [ [[package]] name = "google-cloud-aiplatform" -version = "1.125.0" +version = "1.126.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "docstring-parser" }, @@ -1749,9 +1749,9 @@ dependencies = [ { name = "shapely" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0f/d7/5c2df60dbcc68f292a4a7e7b0e17e17b4808788be0fdbdb150d439bc70e6/google_cloud_aiplatform-1.125.0.tar.gz", hash = "sha256:2cafa6222c78155c209893458706942dffa16d5647496257e916405a19e75a63", size = 9772810, upload-time = "2025-11-05T00:32:00.927Z" } +sdist = { url = "https://files.pythonhosted.org/packages/44/37/4f963ad4c2ea5f4ab68e5bf83de80b8c7622bb3add81189b217095963d06/google_cloud_aiplatform-1.126.0.tar.gz", hash = "sha256:032d3551acd1f51cbafb096c13a940df761d01f19cbd92a6dfa800aa222c9517", size = 9777797, upload-time = "2025-11-05T23:16:48.616Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/7f/fa/72c8d14b8abb119a15e31864fc61ab2ff5565f1e1166a137d62968935e00/google_cloud_aiplatform-1.125.0-py2.py3-none-any.whl", hash = "sha256:956058c138ba668f7e1365489dbbfca8606e4d35e9186fa429434950817a78d5", size = 8122978, upload-time = "2025-11-05T00:31:57.74Z" }, + { url = "https://files.pythonhosted.org/packages/d1/44/84c470e23c66af14f1c2bba88f902b2264c6e554da2c88366683de81d2bd/google_cloud_aiplatform-1.126.0-py2.py3-none-any.whl", hash = "sha256:6008e134f0a93c1d310ea6628051653dde35415cfc0c5fe60cf822bdcccf4d49", size = 8123670, upload-time = "2025-11-05T23:16:45.565Z" }, ] [package.optional-dependencies] @@ -1759,6 +1759,7 @@ agent-engines = [ { name = "cloudpickle" }, { name = "google-cloud-logging" }, { name = "google-cloud-trace" }, + { name = "opentelemetry-exporter-gcp-logging" }, { name = "opentelemetry-exporter-gcp-trace" }, { name = "opentelemetry-exporter-otlp-proto-http" }, { name = "opentelemetry-sdk" },