Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions airbyte_cdk/sources/declarative/auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
from dataclasses import InitVar, dataclass, field
from datetime import datetime, timedelta
from typing import Any, List, Mapping, Optional, Union
from typing import Any, List, Mapping, Optional, Tuple, Union

from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator
from airbyte_cdk.sources.declarative.interpolation.interpolated_boolean import InterpolatedBoolean
Expand Down Expand Up @@ -46,6 +46,9 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut
refresh_request_headers (Optional[Mapping[str, Any]]): The request headers to send in the refresh request
grant_type: The grant_type to request for access_token. If set to refresh_token, the refresh_token parameter has to be provided
message_repository (MessageRepository): the message repository used to emit logs on HTTP requests
refresh_token_error_status_codes (Tuple[int, ...]): Status codes to identify refresh token errors in response
refresh_token_error_key (str): Key to identify refresh token error in response
refresh_token_error_values (Tuple[str, ...]): List of values to check for exception during token refresh process
"""

config: Mapping[str, Any]
Expand All @@ -72,9 +75,16 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut
message_repository: MessageRepository = NoopMessageRepository()
profile_assertion: Optional[DeclarativeAuthenticator] = None
use_profile_assertion: Optional[Union[InterpolatedBoolean, str, bool]] = False
refresh_token_error_status_codes: Tuple[int, ...] = ()
refresh_token_error_key: str = ""
refresh_token_error_values: Tuple[str, ...] = ()

def __post_init__(self, parameters: Mapping[str, Any]) -> None:
super().__init__()
super().__init__(
refresh_token_error_status_codes=self.refresh_token_error_status_codes,
refresh_token_error_key=self.refresh_token_error_key,
refresh_token_error_values=self.refresh_token_error_values,
)
if self.token_refresh_endpoint is not None:
self._token_refresh_endpoint: Optional[InterpolatedString] = InterpolatedString.create(
self.token_refresh_endpoint, parameters=parameters
Expand Down
28 changes: 25 additions & 3 deletions airbyte_cdk/sources/declarative/declarative_component_schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1427,6 +1427,28 @@ definitions:
type: string
examples:
- "%Y-%m-%d %H:%M:%S.%f+00:00"
refresh_token_error_status_codes:
title: Refresh Token Error Status Codes
description: Status Codes to Identify refresh token error in response (Refresh Token Error Key and Refresh Token Error Values should be also specified). Responses with one of the error status code and containing an error value will be flagged as a config error
type: array
items:
type: integer
examples:
- [400, 500]
refresh_token_error_key:
title: Refresh Token Error Key
description: Key to Identify refresh token error in response (Refresh Token Error Status Codes and Refresh Token Error Values should be also specified).
type: string
examples:
- "error"
refresh_token_error_values:
title: Refresh Token Error Values
description: 'List of values to check for exception during token refresh process. Used to check if the error found in the response matches the key from the Refresh Token Error Key field (e.g. response={"error": "invalid_grant"}). Only responses with one of the error status code and containing an error value will be flagged as a config error'
type: array
items:
type: string
examples:
- ["invalid_grant", "invalid_permissions"]
refresh_token_updater:
title: Refresh Token Updater
description: When the refresh token updater is defined, new refresh tokens, access tokens and the access token expiry date are written back from the authentication response to the config object. This is important if the refresh token can only used once.
Expand Down Expand Up @@ -1468,7 +1490,7 @@ definitions:
examples:
- ["credentials", "token_expiry_date"]
refresh_token_error_status_codes:
title: Refresh Token Error Status Codes
title: (Deprecated - Use the same field on the OAuthAuthenticator level) Refresh Token Error Status Codes
description: Status Codes to Identify refresh token error in response (Refresh Token Error Key and Refresh Token Error Values should be also specified). Responses with one of the error status code and containing an error value will be flagged as a config error
type: array
items:
Expand All @@ -1477,14 +1499,14 @@ definitions:
examples:
- [400, 500]
refresh_token_error_key:
title: Refresh Token Error Key
title: (Deprecated - Use the same field on the OAuthAuthenticator level) Refresh Token Error Key
description: Key to Identify refresh token error in response (Refresh Token Error Status Codes and Refresh Token Error Values should be also specified).
type: string
default: ""
examples:
- "error"
refresh_token_error_values:
title: Refresh Token Error Values
title: (Deprecated - Use the same field on the OAuthAuthenticator level) Refresh Token Error Values
description: 'List of values to check for exception during token refresh process. Used to check if the error found in the response matches the key from the Refresh Token Error Key field (e.g. response={"error": "invalid_grant"}). Only responses with one of the error status code and containing an error value will be flagged as a config error'
type: array
items:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# Copyright (c) 2025 Airbyte, Inc., all rights reserved.

# generated by datamodel-codegen:
# filename: declarative_component_schema.yaml

Expand Down Expand Up @@ -426,19 +424,19 @@ class RefreshTokenUpdater(BaseModel):
[],
description="Status Codes to Identify refresh token error in response (Refresh Token Error Key and Refresh Token Error Values should be also specified). Responses with one of the error status code and containing an error value will be flagged as a config error",
examples=[[400, 500]],
title="Refresh Token Error Status Codes",
title="(Deprecated - Use the same field on the OAuthAuthenticator level) Refresh Token Error Status Codes",
)
refresh_token_error_key: Optional[str] = Field(
"",
description="Key to Identify refresh token error in response (Refresh Token Error Status Codes and Refresh Token Error Values should be also specified).",
examples=["error"],
title="Refresh Token Error Key",
title="(Deprecated - Use the same field on the OAuthAuthenticator level) Refresh Token Error Key",
)
refresh_token_error_values: Optional[List[str]] = Field(
[],
description='List of values to check for exception during token refresh process. Used to check if the error found in the response matches the key from the Refresh Token Error Key field (e.g. response={"error": "invalid_grant"}). Only responses with one of the error status code and containing an error value will be flagged as a config error',
examples=[["invalid_grant", "invalid_permissions"]],
title="Refresh Token Error Values",
title="(Deprecated - Use the same field on the OAuthAuthenticator level) Refresh Token Error Values",
)


Expand Down Expand Up @@ -1900,6 +1898,24 @@ class OAuthAuthenticator(BaseModel):
examples=["%Y-%m-%d %H:%M:%S.%f+00:00"],
title="Token Expiry Date Format",
)
refresh_token_error_status_codes: Optional[List[int]] = Field(
None,
description="Status Codes to Identify refresh token error in response (Refresh Token Error Key and Refresh Token Error Values should be also specified). Responses with one of the error status code and containing an error value will be flagged as a config error",
examples=[[400, 500]],
title="Refresh Token Error Status Codes",
)
refresh_token_error_key: Optional[str] = Field(
None,
description="Key to Identify refresh token error in response (Refresh Token Error Status Codes and Refresh Token Error Values should be also specified).",
examples=["error"],
title="Refresh Token Error Key",
)
refresh_token_error_values: Optional[List[str]] = Field(
None,
description='List of values to check for exception during token refresh process. Used to check if the error found in the response matches the key from the Refresh Token Error Key field (e.g. response={"error": "invalid_grant"}). Only responses with one of the error status code and containing an error value will be flagged as a config error',
examples=[["invalid_grant", "invalid_permissions"]],
title="Refresh Token Error Values",
)
refresh_token_updater: Optional[RefreshTokenUpdater] = Field(
None,
description="When the refresh token updater is defined, new refresh tokens, access tokens and the access token expiry date are written back from the authentication response to the config object. This is important if the refresh token can only used once.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Mapping,
MutableMapping,
Optional,
Tuple,
Type,
Union,
cast,
Expand Down Expand Up @@ -400,6 +401,9 @@
from airbyte_cdk.sources.declarative.models.declarative_component_schema import (
RecordSelector as RecordSelectorModel,
)
from airbyte_cdk.sources.declarative.models.declarative_component_schema import (
RefreshTokenUpdater as RefreshTokenUpdaterModel,
)
from airbyte_cdk.sources.declarative.models.declarative_component_schema import (
RemoveFields as RemoveFieldsModel,
)
Expand Down Expand Up @@ -2789,6 +2793,9 @@ def create_oauth_authenticator(
else None
)

refresh_token_error_status_codes, refresh_token_error_key, refresh_token_error_values = (
self._get_refresh_token_error_information(model)
)
if model.refresh_token_updater:
# ignore type error because fixing it would have a lot of dependencies, revisit later
return DeclarativeSingleUseRefreshTokenOauth2Authenticator( # type: ignore
Expand Down Expand Up @@ -2839,9 +2846,9 @@ def create_oauth_authenticator(
token_expiry_date_format=model.token_expiry_date_format,
token_expiry_is_time_of_expiration=bool(model.token_expiry_date_format),
message_repository=self._message_repository,
refresh_token_error_status_codes=model.refresh_token_updater.refresh_token_error_status_codes,
refresh_token_error_key=model.refresh_token_updater.refresh_token_error_key,
refresh_token_error_values=model.refresh_token_updater.refresh_token_error_values,
refresh_token_error_status_codes=refresh_token_error_status_codes,
refresh_token_error_key=refresh_token_error_key,
refresh_token_error_values=refresh_token_error_values,
)
# ignore type error because fixing it would have a lot of dependencies, revisit later
return DeclarativeOauth2Authenticator( # type: ignore
Expand All @@ -2868,8 +2875,59 @@ def create_oauth_authenticator(
message_repository=self._message_repository,
profile_assertion=profile_assertion,
use_profile_assertion=model.use_profile_assertion,
refresh_token_error_status_codes=refresh_token_error_status_codes,
refresh_token_error_key=refresh_token_error_key,
refresh_token_error_values=refresh_token_error_values,
)

@staticmethod
def _get_refresh_token_error_information(
model: OAuthAuthenticatorModel,
) -> Tuple[Tuple[int, ...], str, Tuple[str, ...]]:
"""
In a previous version of the CDK, auth errors were only flagged as config_error if a refresh token updater was
defined. As a transition, we added those fields on the OAuthAuthenticatorModel. This method ensures that the
information is defined only once and returns the right fields.
"""
refresh_token_updater = model.refresh_token_updater
is_defined_on_refresh_token_updated = refresh_token_updater and (
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I struggle a bit with is_defined_on_refresh_token_updated and is_defined_on_oauth_authenticator: if someone wanted to disable this error handling so that these errors are flagged as system_errors, they would have to set impossible values instead of simply setting the fields to None. That being said, it seems like a weird case to handle and there is an escape path, so I'm not too worried.

refresh_token_updater.refresh_token_error_status_codes
or refresh_token_updater.refresh_token_error_key
or refresh_token_updater.refresh_token_error_values
)
is_defined_on_oauth_authenticator = (
model.refresh_token_error_status_codes
or model.refresh_token_error_key
or model.refresh_token_error_values
)
if is_defined_on_refresh_token_updated and is_defined_on_oauth_authenticator:
raise ValueError(
"refresh_token_error should either be defined on the OAuthAuthenticatorModel or the RefreshTokenUpdaterModel, not both"
)

if is_defined_on_refresh_token_updated:
not_optional_refresh_token_updater: RefreshTokenUpdaterModel = refresh_token_updater # type: ignore # we know from the condition that this is not None
return (
tuple(not_optional_refresh_token_updater.refresh_token_error_status_codes)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that this has the same behavior as the previous code, although the implementation is a bit different. The previous code passed refresh_token_updater.refresh_token_error_status_codes directly, which felt very dangerous because the typing was different (Tuple[int, ...] vs Optional[List[int]]). I fear that the previous code could have raised an exception on a None value here if refresh_token_error_status_codes was not defined. I don't have any evidence that this actually happens in prod, though...

if not_optional_refresh_token_updater.refresh_token_error_status_codes
else (),
not_optional_refresh_token_updater.refresh_token_error_key or "",
tuple(not_optional_refresh_token_updater.refresh_token_error_values)
if not_optional_refresh_token_updater.refresh_token_error_values
else (),
)
elif is_defined_on_oauth_authenticator:
return (
tuple(model.refresh_token_error_status_codes)
if model.refresh_token_error_status_codes
else (),
model.refresh_token_error_key or "",
tuple(model.refresh_token_error_values) if model.refresh_token_error_values else (),
)

# returning default values we think cover most cases
return (400,), "error", ("invalid_grant", "invalid_permissions")

def create_offset_increment(
self,
model: OffsetIncrementModel,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -636,9 +636,9 @@ def test_single_use_oauth_branch():
# default values
assert authenticator._access_token_config_path == ["credentials", "access_token"]
assert authenticator._token_expiry_date_config_path == ["credentials", "token_expiry_date"]
assert authenticator._refresh_token_error_status_codes == [400]
assert authenticator._refresh_token_error_status_codes == (400,)
assert authenticator._refresh_token_error_key == "error"
assert authenticator._refresh_token_error_values == ["invalid_grant"]
assert authenticator._refresh_token_error_values == ("invalid_grant",)


def test_list_based_stream_slicer_with_values_refd():
Expand Down
Loading