44
55from fastapi_users .db .base import BaseUserDatabase
66from fastapi_users .models import ID , OAP , UP
7- from sqlalchemy import Boolean , Column , ForeignKey , Integer , String , func , select
7+ from sqlalchemy import Boolean , ForeignKey , Integer , String , func , select
88from sqlalchemy .ext .asyncio import AsyncSession
9- from sqlalchemy .orm import declarative_mixin , declared_attr
9+ from sqlalchemy .orm import Mapped , declared_attr , mapped_column
1010from sqlalchemy .sql import Select
1111
1212from fastapi_users_db_sqlalchemy .generics import GUID
1616UUID_ID = uuid .UUID
1717
1818
19- @declarative_mixin
2019class SQLAlchemyBaseUserTable (Generic [ID ]):
2120 """Base SQLAlchemy users table definition."""
2221
@@ -30,22 +29,28 @@ class SQLAlchemyBaseUserTable(Generic[ID]):
3029 is_superuser : bool
3130 is_verified : bool
3231 else :
33- email : str = Column (String (length = 320 ), unique = True , index = True , nullable = False )
34- hashed_password : str = Column (String (length = 1024 ), nullable = False )
35- is_active : bool = Column (Boolean , default = True , nullable = False )
36- is_superuser : bool = Column (Boolean , default = False , nullable = False )
37- is_verified : bool = Column (Boolean , default = False , nullable = False )
32+ email : Mapped [str ] = mapped_column (
33+ String (length = 320 ), unique = True , index = True , nullable = False
34+ )
35+ hashed_password : Mapped [str ] = mapped_column (
36+ String (length = 1024 ), nullable = False
37+ )
38+ is_active : Mapped [bool ] = mapped_column (Boolean , default = True , nullable = False )
39+ is_superuser : Mapped [bool ] = mapped_column (
40+ Boolean , default = False , nullable = False
41+ )
42+ is_verified : Mapped [bool ] = mapped_column (
43+ Boolean , default = False , nullable = False
44+ )
3845
3946
40- @declarative_mixin
4147class SQLAlchemyBaseUserTableUUID (SQLAlchemyBaseUserTable [UUID_ID ]):
4248 if TYPE_CHECKING : # pragma: no cover
4349 id : UUID_ID
4450 else :
45- id : UUID_ID = Column (GUID , primary_key = True , default = uuid .uuid4 )
51+ id : Mapped [ UUID_ID ] = mapped_column (GUID , primary_key = True , default = uuid .uuid4 )
4652
4753
48- @declarative_mixin
4954class SQLAlchemyBaseOAuthAccountTable (Generic [ID ]):
5055 """Base SQLAlchemy OAuth account table definition."""
5156
@@ -60,24 +65,32 @@ class SQLAlchemyBaseOAuthAccountTable(Generic[ID]):
6065 account_id : str
6166 account_email : str
6267 else :
63- oauth_name : str = Column (String (length = 100 ), index = True , nullable = False )
64- access_token : str = Column (String (length = 1024 ), nullable = False )
65- expires_at : Optional [int ] = Column (Integer , nullable = True )
66- refresh_token : Optional [str ] = Column (String (length = 1024 ), nullable = True )
67- account_id : str = Column (String (length = 320 ), index = True , nullable = False )
68- account_email : str = Column (String (length = 320 ), nullable = False )
68+ oauth_name : Mapped [str ] = mapped_column (
69+ String (length = 100 ), index = True , nullable = False
70+ )
71+ access_token : Mapped [str ] = mapped_column (String (length = 1024 ), nullable = False )
72+ expires_at : Mapped [Optional [int ]] = mapped_column (Integer , nullable = True )
73+ refresh_token : Mapped [Optional [str ]] = mapped_column (
74+ String (length = 1024 ), nullable = True
75+ )
76+ account_id : Mapped [str ] = mapped_column (
77+ String (length = 320 ), index = True , nullable = False
78+ )
79+ account_email : Mapped [str ] = mapped_column (String (length = 320 ), nullable = False )
6980
7081
71- @declarative_mixin
7282class SQLAlchemyBaseOAuthAccountTableUUID (SQLAlchemyBaseOAuthAccountTable [UUID_ID ]):
7383 if TYPE_CHECKING : # pragma: no cover
7484 id : UUID_ID
85+ user_id : UUID_ID
7586 else :
76- id : UUID_ID = Column (GUID , primary_key = True , default = uuid .uuid4 )
87+ id : Mapped [ UUID_ID ] = mapped_column (GUID , primary_key = True , default = uuid .uuid4 )
7788
78- @declared_attr
79- def user_id (cls ) -> Column [GUID ]:
80- return Column (GUID , ForeignKey ("user.id" , ondelete = "cascade" ), nullable = False )
89+ @declared_attr
90+ def user_id (cls ) -> Mapped [GUID ]:
91+ return mapped_column (
92+ GUID , ForeignKey ("user.id" , ondelete = "cascade" ), nullable = False
93+ )
8194
8295
8396class SQLAlchemyUserDatabase (Generic [UP , ID ], BaseUserDatabase [UP , ID ]):
@@ -120,24 +133,22 @@ async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UP
120133 statement = (
121134 select (self .user_table )
122135 .join (self .oauth_account_table )
123- .where (self .oauth_account_table .oauth_name == oauth )
124- .where (self .oauth_account_table .account_id == account_id )
136+ .where (self .oauth_account_table .oauth_name == oauth ) # type: ignore
137+ .where (self .oauth_account_table .account_id == account_id ) # type: ignore
125138 )
126139 return await self ._get_user (statement )
127140
128141 async def create (self , create_dict : Dict [str , Any ]) -> UP :
129142 user = self .user_table (** create_dict )
130143 self .session .add (user )
131144 await self .session .commit ()
132- await self .session .refresh (user )
133145 return user
134146
135147 async def update (self , user : UP , update_dict : Dict [str , Any ]) -> UP :
136148 for key , value in update_dict .items ():
137149 setattr (user , key , value )
138150 self .session .add (user )
139151 await self .session .commit ()
140- await self .session .refresh (user )
141152 return user
142153
143154 async def delete (self , user : UP ) -> None :
@@ -148,6 +159,7 @@ async def add_oauth_account(self, user: UP, create_dict: Dict[str, Any]) -> UP:
148159 if self .oauth_account_table is None :
149160 raise NotImplementedError ()
150161
162+ await self .session .refresh (user )
151163 oauth_account = self .oauth_account_table (** create_dict )
152164 self .session .add (oauth_account )
153165 user .oauth_accounts .append (oauth_account ) # type: ignore
@@ -172,8 +184,4 @@ async def update_oauth_account(
172184
173185 async def _get_user (self , statement : Select ) -> Optional [UP ]:
174186 results = await self .session .execute (statement )
175- user = results .first ()
176- if user is None :
177- return None
178-
179- return user [0 ]
187+ return results .unique ().scalar_one_or_none ()
0 commit comments