|
17 | 17 | from typing import Optional |
18 | 18 | from unittest.mock import MagicMock |
19 | 19 |
|
| 20 | +import pydantic |
| 21 | +import pytest |
| 22 | + |
20 | 23 | from google.adk.agents.invocation_context import InvocationContext |
21 | 24 | from google.adk.sessions.session import Session |
22 | 25 | from google.adk.tools.function_tool import FunctionTool |
23 | 26 | from google.adk.tools.tool_context import ToolContext |
24 | | -import pydantic |
25 | | -import pytest |
26 | 27 |
|
27 | 28 |
|
28 | 29 | class UserModel(pydantic.BaseModel): |
@@ -280,5 +281,121 @@ async def test_run_async_with_optional_pydantic_models(): |
280 | 281 | assert result["theme"] == "dark" |
281 | 282 | assert result["notifications"] is True |
282 | 283 | assert result["preferences_type"] == "PreferencesModel" |
283 | | - assert result["preferences_type"] == "PreferencesModel" |
284 | | - assert result["preferences_type"] == "PreferencesModel" |
| 284 | + |
| 285 | + |
| 286 | +def function_with_list_of_pydantic_models(users: list[UserModel]) -> dict: |
| 287 | + """Function that takes a list of Pydantic models.""" |
| 288 | + return { |
| 289 | + "count": len(users), |
| 290 | + "names": [user.name for user in users], |
| 291 | + "ages": [user.age for user in users], |
| 292 | + "types": [type(user).__name__ for user in users], |
| 293 | + } |
| 294 | + |
| 295 | + |
| 296 | +def function_with_optional_list_of_pydantic_models( |
| 297 | + users: Optional[list[UserModel]] = None, |
| 298 | +) -> dict: |
| 299 | + """Function that takes an optional list of Pydantic models.""" |
| 300 | + if users is None: |
| 301 | + return {"count": 0, "names": []} |
| 302 | + return { |
| 303 | + "count": len(users), |
| 304 | + "names": [user.name for user in users], |
| 305 | + } |
| 306 | + |
| 307 | + |
| 308 | +def test_preprocess_args_with_list_of_dicts_to_pydantic_models(): |
| 309 | + """Test _preprocess_args converts list of dicts to list of Pydantic models.""" |
| 310 | + tool = FunctionTool(function_with_list_of_pydantic_models) |
| 311 | + |
| 312 | + input_args = { |
| 313 | + "users": [ |
| 314 | + {"name": "Alice", "age": 30, "email": "alice@example.com"}, |
| 315 | + {"name": "Bob", "age": 25}, |
| 316 | + {"name": "Charlie", "age": 35, "email": "charlie@example.com"}, |
| 317 | + ] |
| 318 | + } |
| 319 | + |
| 320 | + processed_args = tool._preprocess_args(input_args) |
| 321 | + |
| 322 | + # Check that the list of dicts was converted to a list of Pydantic models |
| 323 | + assert "users" in processed_args |
| 324 | + users = processed_args["users"] |
| 325 | + assert isinstance(users, list) |
| 326 | + assert len(users) == 3 |
| 327 | + |
| 328 | + # Check each element is a Pydantic model with correct data |
| 329 | + assert isinstance(users[0], UserModel) |
| 330 | + assert users[0].name == "Alice" |
| 331 | + assert users[0].age == 30 |
| 332 | + assert users[0].email == "alice@example.com" |
| 333 | + |
| 334 | + assert isinstance(users[1], UserModel) |
| 335 | + assert users[1].name == "Bob" |
| 336 | + assert users[1].age == 25 |
| 337 | + assert users[1].email is None |
| 338 | + |
| 339 | + assert isinstance(users[2], UserModel) |
| 340 | + assert users[2].name == "Charlie" |
| 341 | + assert users[2].age == 35 |
| 342 | + assert users[2].email == "charlie@example.com" |
| 343 | + |
| 344 | + |
| 345 | +def test_preprocess_args_with_optional_list_of_pydantic_models_none(): |
| 346 | + """Test _preprocess_args handles None for optional list parameter.""" |
| 347 | + tool = FunctionTool(function_with_optional_list_of_pydantic_models) |
| 348 | + |
| 349 | + input_args = {"users": None} |
| 350 | + |
| 351 | + processed_args = tool._preprocess_args(input_args) |
| 352 | + |
| 353 | + # Check that None is preserved |
| 354 | + assert "users" in processed_args |
| 355 | + assert processed_args["users"] is None |
| 356 | + |
| 357 | + |
| 358 | +def test_preprocess_args_with_optional_list_of_pydantic_models_with_data(): |
| 359 | + """Test _preprocess_args converts list for optional list parameter.""" |
| 360 | + tool = FunctionTool(function_with_optional_list_of_pydantic_models) |
| 361 | + |
| 362 | + input_args = { |
| 363 | + "users": [ |
| 364 | + {"name": "Alice", "age": 30}, |
| 365 | + {"name": "Bob", "age": 25}, |
| 366 | + ] |
| 367 | + } |
| 368 | + |
| 369 | + processed_args = tool._preprocess_args(input_args) |
| 370 | + |
| 371 | + # Check conversion |
| 372 | + assert "users" in processed_args |
| 373 | + users = processed_args["users"] |
| 374 | + assert len(users) == 2 |
| 375 | + assert all(isinstance(user, UserModel) for user in users) |
| 376 | + assert users[0].name == "Alice" |
| 377 | + assert users[1].name == "Bob" |
| 378 | + |
| 379 | + |
| 380 | +def test_preprocess_args_with_list_skips_invalid_items(): |
| 381 | + """Test _preprocess_args skips items that fail validation.""" |
| 382 | + tool = FunctionTool(function_with_list_of_pydantic_models) |
| 383 | + |
| 384 | + input_args = { |
| 385 | + "users": [ |
| 386 | + {"name": "Alice", "age": 30}, |
| 387 | + {"name": "Invalid"}, # Missing required 'age' field |
| 388 | + {"name": "Bob", "age": 25}, |
| 389 | + ] |
| 390 | + } |
| 391 | + |
| 392 | + processed_args = tool._preprocess_args(input_args) |
| 393 | + |
| 394 | + # Check that invalid item was skipped |
| 395 | + assert "users" in processed_args |
| 396 | + users = processed_args["users"] |
| 397 | + assert len(users) == 2 # Only 2 valid items |
| 398 | + assert users[0].name == "Alice" |
| 399 | + assert users[0].age == 30 |
| 400 | + assert users[1].name == "Bob" |
| 401 | + assert users[1].age == 25 |
0 commit comments