From 3b07703455e51b7c97c4519a9d0a088b645b1d7b Mon Sep 17 00:00:00 2001 From: amalcaraz Date: Mon, 27 Oct 2025 11:17:33 +0100 Subject: [PATCH 1/6] feat: new endpoint for calculating consumed credits by item_hash --- src/aleph/db/accessors/balances.py | 27 ++++++++++++++++++++++++++ src/aleph/schemas/api/accounts.py | 5 +++++ src/aleph/web/controllers/accounts.py | 28 +++++++++++++++++++++++++++ src/aleph/web/controllers/routes.py | 4 ++++ 4 files changed, 64 insertions(+) diff --git a/src/aleph/db/accessors/balances.py b/src/aleph/db/accessors/balances.py index 70964a111..c48315b75 100644 --- a/src/aleph/db/accessors/balances.py +++ b/src/aleph/db/accessors/balances.py @@ -646,3 +646,30 @@ def count_address_credit_history( ) return session.execute(query).scalar_one() + + +def get_resource_consumed_credits( + session: DbSession, + item_hash: str, +) -> int: + """ + Calculate the total credits consumed by a specific resource. + + Aggregates all credit_history entries where: + - payment_method = 'credit_expense' + - origin = item_hash (the resource identifier) + + Args: + session: Database session + item_hash: The item hash of the resource (message hash) + + Returns: + Total credits consumed by the resource + """ + query = select(func.sum(func.abs(AlephCreditHistoryDb.amount))).where( + (AlephCreditHistoryDb.payment_method == "credit_expense") + & (AlephCreditHistoryDb.origin == item_hash) + ) + + result = session.execute(query).scalar() + return result or 0 diff --git a/src/aleph/schemas/api/accounts.py b/src/aleph/schemas/api/accounts.py index 7fb26a9b4..4f56af8ec 100644 --- a/src/aleph/schemas/api/accounts.py +++ b/src/aleph/schemas/api/accounts.py @@ -149,3 +149,8 @@ class GetAccountCreditHistoryResponse(BaseModel): pagination_page: int pagination_total: int pagination_per_page: int + + +class GetResourceConsumedCreditsResponse(BaseModel): + item_hash: str + consumed_credits: int diff --git a/src/aleph/web/controllers/accounts.py b/src/aleph/web/controllers/accounts.py index 1491514c1..069d9d322 100644 --- a/src/aleph/web/controllers/accounts.py +++ b/src/aleph/web/controllers/accounts.py @@ -15,6 +15,7 @@ get_balances_by_chain, get_credit_balance, get_credit_balances, + get_resource_consumed_credits, get_total_detailed_balance, ) from aleph.db.accessors.cost import get_total_cost_for_address @@ -33,6 +34,7 @@ GetAccountQueryParams, GetBalancesChainsQueryParams, GetCreditBalancesQueryParams, + GetResourceConsumedCreditsResponse, ) from aleph.types.db_session import DbSessionFactory from aleph.web.controllers.app_state_getters import get_session_factory_from_request @@ -84,6 +86,13 @@ def _get_chain_from_request(request: web.Request) -> str: return chain +def _get_item_hash_from_request(request: web.Request) -> str: + item_hash = request.match_info.get("item_hash") + if item_hash is None: + raise web.HTTPUnprocessableEntity(text="Item hash must be specified.") + return item_hash + + async def get_account_balance(request: web.Request): address = _get_address_from_request(request) @@ -268,3 +277,22 @@ async def get_account_credit_history(request: web.Request) -> web.Response: ) return web.json_response(text=response.model_dump_json()) + + +async def get_resource_consumed_credits(request: web.Request) -> web.Response: + """Returns the total credits consumed by a specific resource (item_hash).""" + item_hash = _get_item_hash_from_request(request) + + session_factory: DbSessionFactory = get_session_factory_from_request(request) + + with session_factory() as session: + consumed_credits = get_resource_consumed_credits( + session=session, item_hash=item_hash + ) + + response = GetResourceConsumedCreditsResponse( + item_hash=item_hash, + consumed_credits=consumed_credits, + ) + + return web.json_response(text=response.model_dump_json()) diff --git a/src/aleph/web/controllers/routes.py b/src/aleph/web/controllers/routes.py index 06951ab59..a8c76be13 100644 --- a/src/aleph/web/controllers/routes.py +++ b/src/aleph/web/controllers/routes.py @@ -83,6 +83,10 @@ def register_routes(app: web.Application): "/api/v0/addresses/{address}/credit_history", accounts.get_account_credit_history, ) + app.router.add_get( + "/api/v0/messages/{item_hash}/consumed_credits", + accounts.get_resource_consumed_credits, + ) app.router.add_post("/api/v0/ipfs/add_json", storage.add_ipfs_json_controller) app.router.add_post("/api/v0/storage/add_json", storage.add_storage_json_controller) From 987969c3a9026c2f9ea8e0cf3d05cd2e26d8e984 Mon Sep 17 00:00:00 2001 From: amalcaraz Date: Wed, 29 Oct 2025 12:35:18 +0000 Subject: [PATCH 2/6] fix: linting --- src/aleph/web/controllers/accounts.py | 4 +++- src/aleph/web/controllers/routes.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/aleph/web/controllers/accounts.py b/src/aleph/web/controllers/accounts.py index 069d9d322..c0793230b 100644 --- a/src/aleph/web/controllers/accounts.py +++ b/src/aleph/web/controllers/accounts.py @@ -279,7 +279,9 @@ async def get_account_credit_history(request: web.Request) -> web.Response: return web.json_response(text=response.model_dump_json()) -async def get_resource_consumed_credits(request: web.Request) -> web.Response: +async def get_resource_consumed_credits_controller( + request: web.Request, +) -> web.Response: """Returns the total credits consumed by a specific resource (item_hash).""" item_hash = _get_item_hash_from_request(request) diff --git a/src/aleph/web/controllers/routes.py b/src/aleph/web/controllers/routes.py index a8c76be13..399a0586d 100644 --- a/src/aleph/web/controllers/routes.py +++ b/src/aleph/web/controllers/routes.py @@ -85,7 +85,7 @@ def register_routes(app: web.Application): ) app.router.add_get( "/api/v0/messages/{item_hash}/consumed_credits", - accounts.get_resource_consumed_credits, + accounts.get_resource_consumed_credits_controller, ) app.router.add_post("/api/v0/ipfs/add_json", storage.add_ipfs_json_controller) From fd835a55af884a3d5dc3d8951d04d85be729ea20 Mon Sep 17 00:00:00 2001 From: amalcaraz Date: Thu, 30 Oct 2025 12:38:39 +0100 Subject: [PATCH 3/6] refactor: make reusable ItemHash parsing util function --- src/aleph/web/controllers/accounts.py | 10 ++------- src/aleph/web/controllers/messages.py | 32 +++++---------------------- src/aleph/web/controllers/prices.py | 25 ++++++++++++--------- src/aleph/web/controllers/utils.py | 31 ++++++++++++++++++++++++++ 4 files changed, 53 insertions(+), 45 deletions(-) diff --git a/src/aleph/web/controllers/accounts.py b/src/aleph/web/controllers/accounts.py index c0793230b..578d0affd 100644 --- a/src/aleph/web/controllers/accounts.py +++ b/src/aleph/web/controllers/accounts.py @@ -38,6 +38,7 @@ ) from aleph.types.db_session import DbSessionFactory from aleph.web.controllers.app_state_getters import get_session_factory_from_request +from aleph.web.controllers.utils import get_item_hash_str_from_request def make_stats_dict(stats) -> Dict[str, Any]: @@ -86,13 +87,6 @@ def _get_chain_from_request(request: web.Request) -> str: return chain -def _get_item_hash_from_request(request: web.Request) -> str: - item_hash = request.match_info.get("item_hash") - if item_hash is None: - raise web.HTTPUnprocessableEntity(text="Item hash must be specified.") - return item_hash - - async def get_account_balance(request: web.Request): address = _get_address_from_request(request) @@ -283,7 +277,7 @@ async def get_resource_consumed_credits_controller( request: web.Request, ) -> web.Response: """Returns the total credits consumed by a specific resource (item_hash).""" - item_hash = _get_item_hash_from_request(request) + item_hash = get_item_hash_str_from_request(request) session_factory: DbSessionFactory = get_session_factory_from_request(request) diff --git a/src/aleph/web/controllers/messages.py b/src/aleph/web/controllers/messages.py index fa7a60cc4..a41eb632b 100644 --- a/src/aleph/web/controllers/messages.py +++ b/src/aleph/web/controllers/messages.py @@ -58,6 +58,7 @@ DEFAULT_MESSAGES_PER_PAGE, DEFAULT_PAGE, LIST_FIELD_SEPARATOR, + get_item_hash_from_request, mq_make_aleph_message_topic_queue, ) @@ -611,14 +612,7 @@ def _get_message_with_status( async def view_message(request: web.Request): - item_hash_str = request.match_info.get("item_hash") - if not item_hash_str: - raise web.HTTPUnprocessableEntity(text=f"Invalid message hash: {item_hash_str}") - - try: - item_hash = ItemHash(item_hash_str) - except ValueError: - raise web.HTTPBadRequest(body=f"Invalid message hash: {item_hash_str}") + item_hash = get_item_hash_from_request(request) session_factory: DbSessionFactory = request.app["session_factory"] with session_factory() as session: @@ -633,14 +627,7 @@ async def view_message(request: web.Request): async def view_message_content(request: web.Request): - item_hash_str = request.match_info.get("item_hash") - if not item_hash_str: - raise web.HTTPUnprocessableEntity(text=f"Invalid message hash: {item_hash_str}") - - try: - item_hash = ItemHash(item_hash_str) - except ValueError: - raise web.HTTPBadRequest(body=f"Invalid message hash: {item_hash_str}") + item_hash = get_item_hash_from_request(request) session_factory: DbSessionFactory = request.app["session_factory"] with session_factory() as session: @@ -658,13 +645,13 @@ async def view_message_content(request: web.Request): or not isinstance(message_with_status.message, PostMessage) ): raise web.HTTPUnprocessableEntity( - text=f"Invalid message hash status {status} for hash {item_hash_str}" + text=f"Invalid message hash status {status} for hash {item_hash}" ) message_type = message_with_status.message.type if message_type != MessageType.post: raise web.HTTPUnprocessableEntity( - text=f"Invalid message hash type {message_type} for hash {item_hash_str}" + text=f"Invalid message hash type {message_type} for hash {item_hash}" ) content = message_with_status.message.content.content @@ -672,14 +659,7 @@ async def view_message_content(request: web.Request): async def view_message_status(request: web.Request): - item_hash_str = request.match_info.get("item_hash") - if not item_hash_str: - raise web.HTTPUnprocessableEntity(text=f"Invalid message hash: {item_hash_str}") - - try: - item_hash = ItemHash(item_hash_str) - except ValueError: - raise web.HTTPBadRequest(body=f"Invalid message hash: {item_hash_str}") + item_hash = get_item_hash_from_request(request) session_factory: DbSessionFactory = request.app["session_factory"] with session_factory() as session: diff --git a/src/aleph/web/controllers/prices.py b/src/aleph/web/controllers/prices.py index 4968fe58c..1e3c23cf6 100644 --- a/src/aleph/web/controllers/prices.py +++ b/src/aleph/web/controllers/prices.py @@ -36,6 +36,7 @@ get_session_factory_from_request, get_storage_service_from_request, ) +from aleph.web.controllers.utils import get_item_hash_from_request LOGGER = logging.getLogger(__name__) @@ -67,17 +68,11 @@ class MessagePrice(DataClassJsonMixin): required_tokens: Optional[Decimal] = None -async def get_executable_message(session: DbSession, item_hash_str: str) -> MessageDb: +async def get_executable_message(session: DbSession, item_hash: ItemHash) -> MessageDb: """Attempt to get an executable message from the database. Raises an HTTP exception if the message is not found, not processed or is not an executable message. """ - # Parse the item_hash_str into an ItemHash object - try: - item_hash = ItemHash(item_hash_str) - except ValueError: - raise web.HTTPBadRequest(body=f"Invalid message hash: {item_hash_str}") - # Get the message status from the database message_status_db = get_message_status(session=session, item_hash=item_hash) if not message_status_db: @@ -85,7 +80,7 @@ async def get_executable_message(session: DbSession, item_hash_str: str) -> Mess # Loop through the status_exceptions to find a match and raise the corresponding exception if message_status_db.status in MESSAGE_STATUS_EXCEPTIONS: exception, error_message = MESSAGE_STATUS_EXCEPTIONS[message_status_db.status] - raise exception(body=f"{error_message}: {item_hash_str}") + raise exception(body=f"{error_message}: {item_hash}") assert message_status_db.status == MessageStatus.PROCESSED # Get the message from the database @@ -98,7 +93,7 @@ async def get_executable_message(session: DbSession, item_hash_str: str) -> Mess MessageType.store, ): raise web.HTTPBadRequest( - body=f"Message is not an executable or store message: {item_hash_str}" + body=f"Message is not an executable or store message: {item_hash}" ) return message @@ -109,7 +104,7 @@ async def message_price(request: web.Request): session_factory = get_session_factory_from_request(request) with session_factory() as session: - item_hash = request.match_info["item_hash"] + item_hash = get_item_hash_from_request(request) message = await get_executable_message(session, item_hash) content: ExecutableContent = message.parsed_content @@ -201,7 +196,15 @@ async def recalculate_message_costs(request: web.Request): if item_hash_param: # Recalculate costs for a specific message try: - message = await get_executable_message(session, item_hash_param) + # Parse the item_hash_param into an ItemHash object + try: + item_hash = ItemHash(item_hash_param) + except ValueError: + raise web.HTTPBadRequest( + body=f"Invalid message hash: {item_hash_param}" + ) + + message = await get_executable_message(session, item_hash) messages_to_recalculate = [message] except HTTPException: raise diff --git a/src/aleph/web/controllers/utils.py b/src/aleph/web/controllers/utils.py index 7e0a4ac7b..5d3c5d7a5 100644 --- a/src/aleph/web/controllers/utils.py +++ b/src/aleph/web/controllers/utils.py @@ -11,6 +11,7 @@ import aiohttp_jinja2 from aiohttp import web from aiohttp.web_request import FileField +from aleph_message.models import ItemHash from aleph_p2p_client import AlephP2PServiceClient from configmanager import Config from pydantic import BaseModel @@ -414,3 +415,33 @@ def add_grace_period_for_file(session: DbSession, file_hash: str, hours: int): created=utc_now(), delete_by=delete_by, ) + + +def get_item_hash_str_from_request(request: web.Request) -> str: + """ + Extract and validate item_hash string from request path parameters. + Raises HTTPUnprocessableEntity if item_hash is missing. + """ + item_hash_str = request.match_info.get("item_hash") + if not item_hash_str: + raise web.HTTPUnprocessableEntity(text="Item hash must be specified.") + return item_hash_str + + +def get_item_hash_from_request(request: web.Request) -> ItemHash: + """ + Extract and validate item_hash from request path parameters. + Returns an ItemHash object. + Raises HTTPUnprocessableEntity if item_hash is missing. + Raises HTTPBadRequest if item_hash format is invalid. + """ + item_hash_str = request.match_info.get("item_hash") + if not item_hash_str: + raise web.HTTPUnprocessableEntity(text=f"Invalid message hash: {item_hash_str}") + + try: + item_hash = ItemHash(item_hash_str) + except ValueError: + raise web.HTTPBadRequest(body=f"Invalid message hash: {item_hash_str}") + + return item_hash From 8682c94f744cf12d612fed9633a4339c02d4cacb Mon Sep 17 00:00:00 2001 From: amalcaraz Date: Thu, 30 Oct 2025 13:23:25 +0100 Subject: [PATCH 4/6] chore: added tests --- tests/db/test_credit_balances.py | 276 ++++++++++++++++++ .../controllers/test_accounts_controllers.py | 117 ++++++++ 2 files changed, 393 insertions(+) create mode 100644 tests/web/controllers/test_accounts_controllers.py diff --git a/tests/db/test_credit_balances.py b/tests/db/test_credit_balances.py index 3e8d1820d..8c0d68ea0 100644 --- a/tests/db/test_credit_balances.py +++ b/tests/db/test_credit_balances.py @@ -872,3 +872,279 @@ def test_cache_invalidation_on_credit_expiration(session_factory: DbSessionFacto session.refresh(cached_balance) assert cached_balance.balance == 0 assert cached_balance.last_update > cache_time + + +def test_get_resource_consumed_credits_no_records(session_factory: DbSessionFactory): + """Test get_resource_consumed_credits returns 0 when no records exist.""" + from aleph.db.accessors.balances import get_resource_consumed_credits + + with session_factory() as session: + consumed_credits = get_resource_consumed_credits( + session=session, item_hash="nonexistent_hash" + ) + assert consumed_credits == 0 + + +def test_get_resource_consumed_credits_single_record(session_factory: DbSessionFactory): + """Test get_resource_consumed_credits with a single expense record.""" + from aleph.db.accessors.balances import get_resource_consumed_credits + + # Create a credit expense record + expense_credits = [ + { + "address": "0xtest_user", + "amount": 150, + "ref": "resource_123", + } + ] + + message_timestamp = dt.datetime(2023, 1, 1, 12, 0, 0, tzinfo=dt.timezone.utc) + + with session_factory() as session: + # Add the expense record with origin set to the resource hash + update_credit_balances_expense( + session=session, + credits_list=expense_credits, + message_hash="expense_msg_123", + message_timestamp=message_timestamp, + ) + + # Manually set the origin field to the item_hash we want to test + # Since update_credit_balances_expense doesn't set origin by default + from aleph.db.models import AlephCreditHistoryDb + from sqlalchemy import update as sql_update + + session.execute( + sql_update(AlephCreditHistoryDb) + .where(AlephCreditHistoryDb.credit_ref == "expense_msg_123") + .values(origin="resource_123") + ) + session.commit() + + consumed_credits = get_resource_consumed_credits( + session=session, item_hash="resource_123" + ) + assert consumed_credits == 150 + + +def test_get_resource_consumed_credits_multiple_records(session_factory: DbSessionFactory): + """Test get_resource_consumed_credits with multiple expense records for the same resource.""" + from aleph.db.accessors.balances import get_resource_consumed_credits + + message_timestamp = dt.datetime(2023, 1, 1, 12, 0, 0, tzinfo=dt.timezone.utc) + + with session_factory() as session: + # Create multiple expense records for the same resource + expense_batches = [ + { + "credits": [{"address": "0xuser1", "amount": 100, "ref": "resource_456"}], + "message_hash": "expense_msg_1", + }, + { + "credits": [{"address": "0xuser2", "amount": 250, "ref": "resource_456"}], + "message_hash": "expense_msg_2", + }, + { + "credits": [{"address": "0xuser3", "amount": 75, "ref": "resource_456"}], + "message_hash": "expense_msg_3", + }, + ] + + for batch in expense_batches: + update_credit_balances_expense( + session=session, + credits_list=batch["credits"], + message_hash=batch["message_hash"], + message_timestamp=message_timestamp, + ) + + # Set origin for all records + from aleph.db.models import AlephCreditHistoryDb + from sqlalchemy import update as sql_update + + for batch in expense_batches: + session.execute( + sql_update(AlephCreditHistoryDb) + .where(AlephCreditHistoryDb.credit_ref == batch["message_hash"]) + .values(origin="resource_456") + ) + session.commit() + + consumed_credits = get_resource_consumed_credits( + session=session, item_hash="resource_456" + ) + # Total: 100 + 250 + 75 = 425 + assert consumed_credits == 425 + + +def test_get_resource_consumed_credits_filters_by_payment_method(session_factory: DbSessionFactory): + """Test that get_resource_consumed_credits only counts credit_expense payments.""" + from aleph.db.accessors.balances import get_resource_consumed_credits + + message_timestamp = dt.datetime(2023, 1, 1, 12, 0, 0, tzinfo=dt.timezone.utc) + + with session_factory() as session: + # Add credit distribution (should be ignored) + distribution_credits = [ + { + "address": "0xuser1", + "amount": 500, + "ratio": "1.0", + "tx_hash": "0xdist", + "provider": "test_provider", + "expiration": 2000000000000, + } + ] + update_credit_balances_distribution( + session=session, + credits_list=distribution_credits, + token="TEST", + chain="ETH", + message_hash="distribution_msg", + message_timestamp=message_timestamp, + ) + + # Add credit transfer (should be ignored) + transfer_credits = [{"address": "0xuser2", "amount": 200}] + update_credit_balances_transfer( + session=session, + credits_list=transfer_credits, + sender_address="0xsender", + whitelisted_addresses=[], + message_hash="transfer_msg", + message_timestamp=message_timestamp, + ) + + # Add credit expense (should be counted) + expense_credits = [{"address": "0xuser3", "amount": 150, "ref": "resource_789"}] + update_credit_balances_expense( + session=session, + credits_list=expense_credits, + message_hash="expense_msg", + message_timestamp=message_timestamp, + ) + + # Set origin for all records to the same resource + from aleph.db.models import AlephCreditHistoryDb + from sqlalchemy import update as sql_update + + for msg_hash in ["distribution_msg", "transfer_msg", "expense_msg"]: + session.execute( + sql_update(AlephCreditHistoryDb) + .where(AlephCreditHistoryDb.credit_ref == msg_hash) + .values(origin="resource_789") + ) + session.commit() + + consumed_credits = get_resource_consumed_credits( + session=session, item_hash="resource_789" + ) + # Only the expense (150) should be counted, not distribution or transfer + assert consumed_credits == 150 + + +def test_get_resource_consumed_credits_filters_by_origin(session_factory: DbSessionFactory): + """Test that get_resource_consumed_credits only counts records with matching origin.""" + from aleph.db.accessors.balances import get_resource_consumed_credits + + message_timestamp = dt.datetime(2023, 1, 1, 12, 0, 0, tzinfo=dt.timezone.utc) + + with session_factory() as session: + # Create expense records for different resources + expenses = [ + { + "credits": [{"address": "0xuser1", "amount": 100}], + "message_hash": "expense_resource_a", + "origin": "resource_aaa", + }, + { + "credits": [{"address": "0xuser2", "amount": 200}], + "message_hash": "expense_resource_b", + "origin": "resource_bbb", + }, + { + "credits": [{"address": "0xuser3", "amount": 300}], + "message_hash": "expense_resource_a_2", + "origin": "resource_aaa", + }, + ] + + for expense in expenses: + update_credit_balances_expense( + session=session, + credits_list=expense["credits"], + message_hash=expense["message_hash"], + message_timestamp=message_timestamp, + ) + + # Set the origin for this expense + from aleph.db.models import AlephCreditHistoryDb + from sqlalchemy import update as sql_update + + session.execute( + sql_update(AlephCreditHistoryDb) + .where(AlephCreditHistoryDb.credit_ref == expense["message_hash"]) + .values(origin=expense["origin"]) + ) + + session.commit() + + # Test resource_aaa (should get 100 + 300 = 400) + consumed_credits_a = get_resource_consumed_credits( + session=session, item_hash="resource_aaa" + ) + assert consumed_credits_a == 400 + + # Test resource_bbb (should get 200) + consumed_credits_b = get_resource_consumed_credits( + session=session, item_hash="resource_bbb" + ) + assert consumed_credits_b == 200 + + # Test nonexistent resource (should get 0) + consumed_credits_none = get_resource_consumed_credits( + session=session, item_hash="resource_nonexistent" + ) + assert consumed_credits_none == 0 + + +def test_get_resource_consumed_credits_uses_absolute_values(session_factory: DbSessionFactory): + """Test that get_resource_consumed_credits uses absolute values of amounts.""" + from aleph.db.accessors.balances import get_resource_consumed_credits + + message_timestamp = dt.datetime(2023, 1, 1, 12, 0, 0, tzinfo=dt.timezone.utc) + + with session_factory() as session: + # Create expense record + expense_credits = [{"address": "0xuser", "amount": 250}] + update_credit_balances_expense( + session=session, + credits_list=expense_credits, + message_hash="expense_msg", + message_timestamp=message_timestamp, + ) + + # Set origin + from aleph.db.models import AlephCreditHistoryDb + from sqlalchemy import update as sql_update + + session.execute( + sql_update(AlephCreditHistoryDb) + .where(AlephCreditHistoryDb.credit_ref == "expense_msg") + .values(origin="resource_abs") + ) + session.commit() + + # Verify that the expense record has negative amount (as expected from expense) + expense_record = session.execute( + select(AlephCreditHistoryDb).where( + AlephCreditHistoryDb.credit_ref == "expense_msg" + ) + ).scalar_one() + assert expense_record.amount == -250 # Expenses are stored as negative + + # But get_resource_consumed_credits should return the absolute value + consumed_credits = get_resource_consumed_credits( + session=session, item_hash="resource_abs" + ) + assert consumed_credits == 250 diff --git a/tests/web/controllers/test_accounts_controllers.py b/tests/web/controllers/test_accounts_controllers.py new file mode 100644 index 000000000..3513bbc0d --- /dev/null +++ b/tests/web/controllers/test_accounts_controllers.py @@ -0,0 +1,117 @@ +import json +import pytest +import pytest_asyncio +from unittest.mock import MagicMock, patch + +from aleph.web.controllers.accounts import get_resource_consumed_credits_controller +from aleph.types.db_session import DbSessionFactory + + +@pytest.mark.asyncio +async def test_get_resource_consumed_credits_controller_success(): + """Test successful retrieval of consumed credits for a resource.""" + # Mock request object + mock_request = MagicMock() + mock_request.match_info = {"item_hash": "test_hash_123"} + + # Mock session factory and session + mock_session = MagicMock() + mock_session_factory = MagicMock() + mock_session_factory.return_value.__enter__.return_value = mock_session + mock_session_factory.return_value.__exit__.return_value = None + + # Mock consumed credits value + expected_consumed_credits = 42 + + with patch("aleph.web.controllers.accounts.get_item_hash_str_from_request") as mock_get_hash, \ + patch("aleph.web.controllers.accounts.get_session_factory_from_request") as mock_get_factory, \ + patch("aleph.web.controllers.accounts.get_resource_consumed_credits") as mock_get_credits: + + # Set up mocks + mock_get_hash.return_value = "test_hash_123" + mock_get_factory.return_value = mock_session_factory + mock_get_credits.return_value = expected_consumed_credits + + # Call the controller + response = await get_resource_consumed_credits_controller(mock_request) + + # Verify the response + assert response.status == 200 + response_data = json.loads(response.text) + assert response_data["item_hash"] == "test_hash_123" + assert response_data["consumed_credits"] == expected_consumed_credits + + # Verify mocks were called correctly + mock_get_hash.assert_called_once_with(mock_request) + mock_get_factory.assert_called_once_with(mock_request) + mock_get_credits.assert_called_once_with(session=mock_session, item_hash="test_hash_123") + + +@pytest.mark.asyncio +async def test_get_resource_consumed_credits_controller_zero_credits(): + """Test retrieval when resource has zero consumed credits.""" + # Mock request object + mock_request = MagicMock() + mock_request.match_info = {"item_hash": "empty_hash_456"} + + # Mock session factory and session + mock_session = MagicMock() + mock_session_factory = MagicMock() + mock_session_factory.return_value.__enter__.return_value = mock_session + mock_session_factory.return_value.__exit__.return_value = None + + # Mock zero consumed credits + expected_consumed_credits = 0 + + with patch("aleph.web.controllers.accounts.get_item_hash_str_from_request") as mock_get_hash, \ + patch("aleph.web.controllers.accounts.get_session_factory_from_request") as mock_get_factory, \ + patch("aleph.web.controllers.accounts.get_resource_consumed_credits") as mock_get_credits: + + # Set up mocks + mock_get_hash.return_value = "empty_hash_456" + mock_get_factory.return_value = mock_session_factory + mock_get_credits.return_value = expected_consumed_credits + + # Call the controller + response = await get_resource_consumed_credits_controller(mock_request) + + # Verify the response + assert response.status == 200 + response_data = json.loads(response.text) + assert response_data["item_hash"] == "empty_hash_456" + assert response_data["consumed_credits"] == 0 + + +@pytest.mark.asyncio +async def test_get_resource_consumed_credits_controller_large_credits(): + """Test retrieval with a large consumed credits value.""" + # Mock request object + mock_request = MagicMock() + mock_request.match_info = {"item_hash": "large_hash_789"} + + # Mock session factory and session + mock_session = MagicMock() + mock_session_factory = MagicMock() + mock_session_factory.return_value.__enter__.return_value = mock_session + mock_session_factory.return_value.__exit__.return_value = None + + # Mock large consumed credits value + expected_consumed_credits = 999999 + + with patch("aleph.web.controllers.accounts.get_item_hash_str_from_request") as mock_get_hash, \ + patch("aleph.web.controllers.accounts.get_session_factory_from_request") as mock_get_factory, \ + patch("aleph.web.controllers.accounts.get_resource_consumed_credits") as mock_get_credits: + + # Set up mocks + mock_get_hash.return_value = "large_hash_789" + mock_get_factory.return_value = mock_session_factory + mock_get_credits.return_value = expected_consumed_credits + + # Call the controller + response = await get_resource_consumed_credits_controller(mock_request) + + # Verify the response + assert response.status == 200 + response_data = json.loads(response.text) + assert response_data["item_hash"] == "large_hash_789" + assert response_data["consumed_credits"] == expected_consumed_credits \ No newline at end of file From 37d38f6b4f5615c02973adda096fcd1fd8c8efb4 Mon Sep 17 00:00:00 2001 From: amalcaraz Date: Thu, 30 Oct 2025 13:31:30 +0100 Subject: [PATCH 5/6] fix: linting issues --- tests/db/test_credit_balances.py | 76 +++++++++++++------ .../controllers/test_accounts_controllers.py | 53 +++++++++---- 2 files changed, 92 insertions(+), 37 deletions(-) diff --git a/tests/db/test_credit_balances.py b/tests/db/test_credit_balances.py index 8c0d68ea0..f61f78025 100644 --- a/tests/db/test_credit_balances.py +++ b/tests/db/test_credit_balances.py @@ -1,6 +1,7 @@ import datetime as dt import time from decimal import Decimal +from typing import Any, Dict, List from sqlalchemy import select @@ -911,9 +912,10 @@ def test_get_resource_consumed_credits_single_record(session_factory: DbSessionF # Manually set the origin field to the item_hash we want to test # Since update_credit_balances_expense doesn't set origin by default - from aleph.db.models import AlephCreditHistoryDb from sqlalchemy import update as sql_update + from aleph.db.models import AlephCreditHistoryDb + session.execute( sql_update(AlephCreditHistoryDb) .where(AlephCreditHistoryDb.credit_ref == "expense_msg_123") @@ -927,7 +929,9 @@ def test_get_resource_consumed_credits_single_record(session_factory: DbSessionF assert consumed_credits == 150 -def test_get_resource_consumed_credits_multiple_records(session_factory: DbSessionFactory): +def test_get_resource_consumed_credits_multiple_records( + session_factory: DbSessionFactory, +): """Test get_resource_consumed_credits with multiple expense records for the same resource.""" from aleph.db.accessors.balances import get_resource_consumed_credits @@ -935,37 +939,49 @@ def test_get_resource_consumed_credits_multiple_records(session_factory: DbSessi with session_factory() as session: # Create multiple expense records for the same resource - expense_batches = [ + expense_batches: List[Dict[str, Any]] = [ { - "credits": [{"address": "0xuser1", "amount": 100, "ref": "resource_456"}], + "credits": [ + {"address": "0xuser1", "amount": 100, "ref": "resource_456"} + ], "message_hash": "expense_msg_1", }, { - "credits": [{"address": "0xuser2", "amount": 250, "ref": "resource_456"}], + "credits": [ + {"address": "0xuser2", "amount": 250, "ref": "resource_456"} + ], "message_hash": "expense_msg_2", }, { - "credits": [{"address": "0xuser3", "amount": 75, "ref": "resource_456"}], + "credits": [ + {"address": "0xuser3", "amount": 75, "ref": "resource_456"} + ], "message_hash": "expense_msg_3", }, ] + # Import required modules + from sqlalchemy import update as sql_update + + from aleph.db.models import AlephCreditHistoryDb + for batch in expense_batches: + credits_list: List[Dict[str, Any]] = batch["credits"] + message_hash: str = batch["message_hash"] update_credit_balances_expense( session=session, - credits_list=batch["credits"], - message_hash=batch["message_hash"], + credits_list=credits_list, + message_hash=message_hash, message_timestamp=message_timestamp, ) # Set origin for all records - from aleph.db.models import AlephCreditHistoryDb - from sqlalchemy import update as sql_update for batch in expense_batches: + batch_message_hash = batch["message_hash"] session.execute( sql_update(AlephCreditHistoryDb) - .where(AlephCreditHistoryDb.credit_ref == batch["message_hash"]) + .where(AlephCreditHistoryDb.credit_ref == batch_message_hash) .values(origin="resource_456") ) session.commit() @@ -977,7 +993,9 @@ def test_get_resource_consumed_credits_multiple_records(session_factory: DbSessi assert consumed_credits == 425 -def test_get_resource_consumed_credits_filters_by_payment_method(session_factory: DbSessionFactory): +def test_get_resource_consumed_credits_filters_by_payment_method( + session_factory: DbSessionFactory, +): """Test that get_resource_consumed_credits only counts credit_expense payments.""" from aleph.db.accessors.balances import get_resource_consumed_credits @@ -1025,9 +1043,10 @@ def test_get_resource_consumed_credits_filters_by_payment_method(session_factory ) # Set origin for all records to the same resource - from aleph.db.models import AlephCreditHistoryDb from sqlalchemy import update as sql_update + from aleph.db.models import AlephCreditHistoryDb + for msg_hash in ["distribution_msg", "transfer_msg", "expense_msg"]: session.execute( sql_update(AlephCreditHistoryDb) @@ -1043,7 +1062,9 @@ def test_get_resource_consumed_credits_filters_by_payment_method(session_factory assert consumed_credits == 150 -def test_get_resource_consumed_credits_filters_by_origin(session_factory: DbSessionFactory): +def test_get_resource_consumed_credits_filters_by_origin( + session_factory: DbSessionFactory, +): """Test that get_resource_consumed_credits only counts records with matching origin.""" from aleph.db.accessors.balances import get_resource_consumed_credits @@ -1051,7 +1072,7 @@ def test_get_resource_consumed_credits_filters_by_origin(session_factory: DbSess with session_factory() as session: # Create expense records for different resources - expenses = [ + expenses: List[Dict[str, Any]] = [ { "credits": [{"address": "0xuser1", "amount": 100}], "message_hash": "expense_resource_a", @@ -1069,22 +1090,28 @@ def test_get_resource_consumed_credits_filters_by_origin(session_factory: DbSess }, ] + # Import required modules + from sqlalchemy import update as sql_update + + from aleph.db.models import AlephCreditHistoryDb + for expense in expenses: + credits_list: List[Dict[str, Any]] = expense["credits"] + message_hash: str = expense["message_hash"] + origin: str = expense["origin"] update_credit_balances_expense( session=session, - credits_list=expense["credits"], - message_hash=expense["message_hash"], + credits_list=credits_list, + message_hash=message_hash, message_timestamp=message_timestamp, ) # Set the origin for this expense - from aleph.db.models import AlephCreditHistoryDb - from sqlalchemy import update as sql_update session.execute( sql_update(AlephCreditHistoryDb) - .where(AlephCreditHistoryDb.credit_ref == expense["message_hash"]) - .values(origin=expense["origin"]) + .where(AlephCreditHistoryDb.credit_ref == message_hash) + .values(origin=origin) ) session.commit() @@ -1108,7 +1135,9 @@ def test_get_resource_consumed_credits_filters_by_origin(session_factory: DbSess assert consumed_credits_none == 0 -def test_get_resource_consumed_credits_uses_absolute_values(session_factory: DbSessionFactory): +def test_get_resource_consumed_credits_uses_absolute_values( + session_factory: DbSessionFactory, +): """Test that get_resource_consumed_credits uses absolute values of amounts.""" from aleph.db.accessors.balances import get_resource_consumed_credits @@ -1125,9 +1154,10 @@ def test_get_resource_consumed_credits_uses_absolute_values(session_factory: DbS ) # Set origin - from aleph.db.models import AlephCreditHistoryDb from sqlalchemy import update as sql_update + from aleph.db.models import AlephCreditHistoryDb + session.execute( sql_update(AlephCreditHistoryDb) .where(AlephCreditHistoryDb.credit_ref == "expense_msg") diff --git a/tests/web/controllers/test_accounts_controllers.py b/tests/web/controllers/test_accounts_controllers.py index 3513bbc0d..8be9bdc2b 100644 --- a/tests/web/controllers/test_accounts_controllers.py +++ b/tests/web/controllers/test_accounts_controllers.py @@ -1,10 +1,9 @@ import json -import pytest -import pytest_asyncio from unittest.mock import MagicMock, patch +import pytest + from aleph.web.controllers.accounts import get_resource_consumed_credits_controller -from aleph.types.db_session import DbSessionFactory @pytest.mark.asyncio @@ -23,9 +22,17 @@ async def test_get_resource_consumed_credits_controller_success(): # Mock consumed credits value expected_consumed_credits = 42 - with patch("aleph.web.controllers.accounts.get_item_hash_str_from_request") as mock_get_hash, \ - patch("aleph.web.controllers.accounts.get_session_factory_from_request") as mock_get_factory, \ - patch("aleph.web.controllers.accounts.get_resource_consumed_credits") as mock_get_credits: + with ( + patch( + "aleph.web.controllers.accounts.get_item_hash_str_from_request" + ) as mock_get_hash, + patch( + "aleph.web.controllers.accounts.get_session_factory_from_request" + ) as mock_get_factory, + patch( + "aleph.web.controllers.accounts.get_resource_consumed_credits" + ) as mock_get_credits, + ): # Set up mocks mock_get_hash.return_value = "test_hash_123" @@ -44,7 +51,9 @@ async def test_get_resource_consumed_credits_controller_success(): # Verify mocks were called correctly mock_get_hash.assert_called_once_with(mock_request) mock_get_factory.assert_called_once_with(mock_request) - mock_get_credits.assert_called_once_with(session=mock_session, item_hash="test_hash_123") + mock_get_credits.assert_called_once_with( + session=mock_session, item_hash="test_hash_123" + ) @pytest.mark.asyncio @@ -63,9 +72,17 @@ async def test_get_resource_consumed_credits_controller_zero_credits(): # Mock zero consumed credits expected_consumed_credits = 0 - with patch("aleph.web.controllers.accounts.get_item_hash_str_from_request") as mock_get_hash, \ - patch("aleph.web.controllers.accounts.get_session_factory_from_request") as mock_get_factory, \ - patch("aleph.web.controllers.accounts.get_resource_consumed_credits") as mock_get_credits: + with ( + patch( + "aleph.web.controllers.accounts.get_item_hash_str_from_request" + ) as mock_get_hash, + patch( + "aleph.web.controllers.accounts.get_session_factory_from_request" + ) as mock_get_factory, + patch( + "aleph.web.controllers.accounts.get_resource_consumed_credits" + ) as mock_get_credits, + ): # Set up mocks mock_get_hash.return_value = "empty_hash_456" @@ -98,9 +115,17 @@ async def test_get_resource_consumed_credits_controller_large_credits(): # Mock large consumed credits value expected_consumed_credits = 999999 - with patch("aleph.web.controllers.accounts.get_item_hash_str_from_request") as mock_get_hash, \ - patch("aleph.web.controllers.accounts.get_session_factory_from_request") as mock_get_factory, \ - patch("aleph.web.controllers.accounts.get_resource_consumed_credits") as mock_get_credits: + with ( + patch( + "aleph.web.controllers.accounts.get_item_hash_str_from_request" + ) as mock_get_hash, + patch( + "aleph.web.controllers.accounts.get_session_factory_from_request" + ) as mock_get_factory, + patch( + "aleph.web.controllers.accounts.get_resource_consumed_credits" + ) as mock_get_credits, + ): # Set up mocks mock_get_hash.return_value = "large_hash_789" @@ -114,4 +139,4 @@ async def test_get_resource_consumed_credits_controller_large_credits(): assert response.status == 200 response_data = json.loads(response.text) assert response_data["item_hash"] == "large_hash_789" - assert response_data["consumed_credits"] == expected_consumed_credits \ No newline at end of file + assert response_data["consumed_credits"] == expected_consumed_credits From bfb78dd69a7cb9e1bfd0a692560d808ebc5a22e4 Mon Sep 17 00:00:00 2001 From: amalcaraz Date: Thu, 30 Oct 2025 16:36:43 +0100 Subject: [PATCH 6/6] chore: move imports to the top of the file --- tests/db/test_credit_balances.py | 29 +++---------------- .../test_check_sender_authorization.py | 13 ++++----- 2 files changed, 9 insertions(+), 33 deletions(-) diff --git a/tests/db/test_credit_balances.py b/tests/db/test_credit_balances.py index f61f78025..48b6f0877 100644 --- a/tests/db/test_credit_balances.py +++ b/tests/db/test_credit_balances.py @@ -4,14 +4,17 @@ from typing import Any, Dict, List from sqlalchemy import select +from sqlalchemy import update as sql_update from aleph.db.accessors.balances import ( get_credit_balance, + get_resource_consumed_credits, update_credit_balances_distribution, update_credit_balances_expense, update_credit_balances_transfer, + validate_credit_transfer_balance, ) -from aleph.db.models import AlephCreditHistoryDb +from aleph.db.models import AlephCreditBalanceDb, AlephCreditHistoryDb from aleph.types.db_session import DbSessionFactory @@ -274,7 +277,6 @@ def test_whitelisted_sender_transfer(session_factory: DbSessionFactory): def test_balance_validation_insufficient_credits(session_factory: DbSessionFactory): """Test balance validation fails when sender has insufficient credits.""" - from aleph.db.accessors.balances import validate_credit_transfer_balance # Create initial balance of 500 credits_list = [ @@ -353,7 +355,6 @@ def test_expired_credits_excluded_from_transfers(session_factory: DbSessionFacto assert balance == 200 # Transfer validation should only consider valid credits (200) - from aleph.db.accessors.balances import validate_credit_transfer_balance assert validate_credit_transfer_balance(session, "0xexpired_user", 200) assert not validate_credit_transfer_balance(session, "0xexpired_user", 300) @@ -850,7 +851,6 @@ def test_cache_invalidation_on_credit_expiration(session_factory: DbSessionFacto # Verify that a cache entry was created and manually update its timestamp # to simulate it being created at T2 (cache_time) - from aleph.db.models import AlephCreditBalanceDb cached_balance = session.execute( select(AlephCreditBalanceDb).where( @@ -877,7 +877,6 @@ def test_cache_invalidation_on_credit_expiration(session_factory: DbSessionFacto def test_get_resource_consumed_credits_no_records(session_factory: DbSessionFactory): """Test get_resource_consumed_credits returns 0 when no records exist.""" - from aleph.db.accessors.balances import get_resource_consumed_credits with session_factory() as session: consumed_credits = get_resource_consumed_credits( @@ -888,7 +887,6 @@ def test_get_resource_consumed_credits_no_records(session_factory: DbSessionFact def test_get_resource_consumed_credits_single_record(session_factory: DbSessionFactory): """Test get_resource_consumed_credits with a single expense record.""" - from aleph.db.accessors.balances import get_resource_consumed_credits # Create a credit expense record expense_credits = [ @@ -912,9 +910,6 @@ def test_get_resource_consumed_credits_single_record(session_factory: DbSessionF # Manually set the origin field to the item_hash we want to test # Since update_credit_balances_expense doesn't set origin by default - from sqlalchemy import update as sql_update - - from aleph.db.models import AlephCreditHistoryDb session.execute( sql_update(AlephCreditHistoryDb) @@ -933,7 +928,6 @@ def test_get_resource_consumed_credits_multiple_records( session_factory: DbSessionFactory, ): """Test get_resource_consumed_credits with multiple expense records for the same resource.""" - from aleph.db.accessors.balances import get_resource_consumed_credits message_timestamp = dt.datetime(2023, 1, 1, 12, 0, 0, tzinfo=dt.timezone.utc) @@ -961,9 +955,6 @@ def test_get_resource_consumed_credits_multiple_records( ] # Import required modules - from sqlalchemy import update as sql_update - - from aleph.db.models import AlephCreditHistoryDb for batch in expense_batches: credits_list: List[Dict[str, Any]] = batch["credits"] @@ -997,7 +988,6 @@ def test_get_resource_consumed_credits_filters_by_payment_method( session_factory: DbSessionFactory, ): """Test that get_resource_consumed_credits only counts credit_expense payments.""" - from aleph.db.accessors.balances import get_resource_consumed_credits message_timestamp = dt.datetime(2023, 1, 1, 12, 0, 0, tzinfo=dt.timezone.utc) @@ -1043,9 +1033,6 @@ def test_get_resource_consumed_credits_filters_by_payment_method( ) # Set origin for all records to the same resource - from sqlalchemy import update as sql_update - - from aleph.db.models import AlephCreditHistoryDb for msg_hash in ["distribution_msg", "transfer_msg", "expense_msg"]: session.execute( @@ -1066,7 +1053,6 @@ def test_get_resource_consumed_credits_filters_by_origin( session_factory: DbSessionFactory, ): """Test that get_resource_consumed_credits only counts records with matching origin.""" - from aleph.db.accessors.balances import get_resource_consumed_credits message_timestamp = dt.datetime(2023, 1, 1, 12, 0, 0, tzinfo=dt.timezone.utc) @@ -1091,9 +1077,6 @@ def test_get_resource_consumed_credits_filters_by_origin( ] # Import required modules - from sqlalchemy import update as sql_update - - from aleph.db.models import AlephCreditHistoryDb for expense in expenses: credits_list: List[Dict[str, Any]] = expense["credits"] @@ -1139,7 +1122,6 @@ def test_get_resource_consumed_credits_uses_absolute_values( session_factory: DbSessionFactory, ): """Test that get_resource_consumed_credits uses absolute values of amounts.""" - from aleph.db.accessors.balances import get_resource_consumed_credits message_timestamp = dt.datetime(2023, 1, 1, 12, 0, 0, tzinfo=dt.timezone.utc) @@ -1154,9 +1136,6 @@ def test_get_resource_consumed_credits_uses_absolute_values( ) # Set origin - from sqlalchemy import update as sql_update - - from aleph.db.models import AlephCreditHistoryDb session.execute( sql_update(AlephCreditHistoryDb) diff --git a/tests/permissions/test_check_sender_authorization.py b/tests/permissions/test_check_sender_authorization.py index dd13119c2..f0a4e5b45 100644 --- a/tests/permissions/test_check_sender_authorization.py +++ b/tests/permissions/test_check_sender_authorization.py @@ -4,11 +4,15 @@ import pytest from message_test_helpers import make_validated_message_from_dict -from aleph.db.models import AggregateDb, AggregateElementDb +from aleph.chains.signature_verifier import SignatureVerifier +from aleph.db.models import AggregateDb, AggregateElementDb, PendingMessageDb +from aleph.handlers.message_handler import MessageHandler from aleph.permissions import check_sender_authorization +from aleph.storage import StorageService from aleph.toolkit.timestamp import timestamp_to_datetime from aleph.types.channel import Channel from aleph.types.db_session import DbSessionFactory +from aleph.types.message_status import PermissionDenied @pytest.mark.asyncio @@ -145,13 +149,6 @@ async def test_message_processing_should_fail_on_permission( An attacker can send a message with victim's address, and it should be rejected during message processing, but currently it's accepted due to the bug. """ - import datetime as dt - - from aleph.chains.signature_verifier import SignatureVerifier - from aleph.db.models import PendingMessageDb - from aleph.handlers.message_handler import MessageHandler - from aleph.storage import StorageService - from aleph.types.message_status import PermissionDenied # Mock the storage and signature verification to focus on permission testing storage_service = mocker.Mock(spec=StorageService)