Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions src/openai/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import inspect
from typing import TYPE_CHECKING, Any, Type, Tuple, Union, Generic, TypeVar, Callable, Optional, cast
from weakref import WeakKeyDictionary
from datetime import date, datetime
from typing_extensions import (
List,
Expand Down Expand Up @@ -77,6 +78,8 @@

ReprArgs = Sequence[Tuple[Optional[str], Any]]

_DISCRIMINATOR_CACHE: "WeakKeyDictionary[type, DiscriminatorDetails]" = WeakKeyDictionary()


@runtime_checkable
class _ConfigProtocol(Protocol):
Expand Down Expand Up @@ -593,11 +596,6 @@ def construct_type(*, value: object, type_: object, metadata: Optional[List[Any]
return value


@runtime_checkable
class CachedDiscriminatorType(Protocol):
__discriminator__: DiscriminatorDetails


class DiscriminatorDetails:
field_name: str
"""The name of the discriminator field in the variant class, e.g.
Expand Down Expand Up @@ -640,8 +638,9 @@ def __init__(


def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
if isinstance(union, CachedDiscriminatorType):
return union.__discriminator__
cached_discriminator = _DISCRIMINATOR_CACHE.get(union)
if cached_discriminator is not None:
return cached_discriminator

discriminator_field_name: str | None = None

Expand Down Expand Up @@ -694,7 +693,7 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any,
discriminator_field=discriminator_field_name,
discriminator_alias=discriminator_alias,
)
cast(CachedDiscriminatorType, union).__discriminator__ = details
_DISCRIMINATOR_CACHE[union] = details
return details


Expand Down
9 changes: 5 additions & 4 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from openai._utils import PropertyInfo
from openai._compat import PYDANTIC_V1, parse_obj, model_dump, model_json
from openai._models import BaseModel, construct_type
from openai._models import BaseModel, construct_type, _DISCRIMINATOR_CACHE


class BasicModel(BaseModel):
Expand Down Expand Up @@ -809,7 +809,7 @@ class B(BaseModel):

UnionType = cast(Any, Union[A, B])

assert not hasattr(UnionType, "__discriminator__")
assert _DISCRIMINATOR_CACHE.get(UnionType) is None

m = construct_type(
value={"type": "b", "data": "foo"}, type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")])
Expand All @@ -818,7 +818,8 @@ class B(BaseModel):
assert m.type == "b"
assert m.data == "foo" # type: ignore[comparison-overlap]

discriminator = UnionType.__discriminator__
discriminator = _DISCRIMINATOR_CACHE.get(UnionType)

assert discriminator is not None

m = construct_type(
Expand All @@ -830,7 +831,7 @@ class B(BaseModel):

# if the discriminator details object stays the same between invocations then
# we hit the cache
assert UnionType.__discriminator__ is discriminator
assert _DISCRIMINATOR_CACHE.get(UnionType) is discriminator


@pytest.mark.skipif(PYDANTIC_V1, reason="TypeAliasType is not supported in Pydantic v1")
Expand Down