diff --git a/src/story_protocol_python_sdk/resources/Group.py b/src/story_protocol_python_sdk/resources/Group.py index ce7f420..1819234 100644 --- a/src/story_protocol_python_sdk/resources/Group.py +++ b/src/story_protocol_python_sdk/resources/Group.py @@ -543,9 +543,9 @@ def claim_rewards( """ Claim rewards for the entire group. - :param group_ip_id str: The ID of the group IP. - :param currency_token str: The address of the currency (revenue) token to claim.. - :param member_ip_ids list: The IDs of the member IPs to distribute the rewards to. + :param group_ip_id Address: The ID of the group IP. + :param currency_token Address: The address of the currency (revenue) token to claim. + :param member_ip_ids list[Address]: The IDs of the member IPs to distribute the rewards to. :param tx_options dict: [Optional] The transaction options. :return ClaimRewardsResponse: A response object with the transaction hash and claimed rewards. """ @@ -571,14 +571,27 @@ def claim_rewards( *claim_reward_param.values(), tx_options=tx_options, ) - - claimed_rewards = self._parse_tx_claimed_reward_event( - response["tx_receipt"] - ) - + event_signature = self.web3.keccak( + text="ClaimedReward(address,address,address[],uint256[])" + ).hex() + claimed_rewards = None + for log in response["tx_receipt"]["logs"]: + if log["topics"][0].hex() == event_signature: + event_result = self.grouping_module_client.contract.events.ClaimedReward.process_log( + log + ) + claimed_rewards = event_result["args"] + break + if not claimed_rewards: + raise ValueError("Not found ClaimedReward event in transaction logs.") return ClaimRewardsResponse( tx_hash=response["tx_hash"], - claimed_rewards=claimed_rewards, + claimed_rewards=ClaimReward( + ip_ids=claimed_rewards["ipId"], + amounts=claimed_rewards["amount"], + token=claimed_rewards["token"], + group_id=claimed_rewards["groupId"], + ), ) except Exception as e: @@ -748,31 +761,3 @@ def _parse_tx_royalty_paid_event(self, tx_receipt: dict) -> list: ) return royalties_distributed - - def _parse_tx_claimed_reward_event(self, tx_receipt: dict) -> list[ClaimReward]: - """ - Parse the ClaimedReward event from a transaction receipt. - - :param tx_receipt dict: The transaction receipt. - :return list: List of claimed rewards. - """ - event_signature = self.web3.keccak( - text="ClaimedReward(address,address,address,uint256)" - ).hex() - claimed_rewards = [] - - for log in tx_receipt["logs"]: - if log["topics"][0].hex() == event_signature: - ip_id = "0x" + log["topics"][0].hex()[24:] - amount = int(log["data"][:66].hex(), 16) - token = "0x" + log["topics"][2].hex()[24:] - - claimed_rewards.append( - ClaimReward( - ip_id=ip_id, - amount=amount, - token=token, - ) - ) - - return claimed_rewards diff --git a/src/story_protocol_python_sdk/types/resource/Group.py b/src/story_protocol_python_sdk/types/resource/Group.py index 62d0b2e..e6c225e 100644 --- a/src/story_protocol_python_sdk/types/resource/Group.py +++ b/src/story_protocol_python_sdk/types/resource/Group.py @@ -1,6 +1,6 @@ from typing import TypedDict -from ens.ens import Address, HexStr +from ens.ens import Address, HexBytes class ClaimReward(TypedDict): @@ -8,9 +8,10 @@ class ClaimReward(TypedDict): Structure for a claimed reward. """ - ip_id: Address - amount: int + ip_ids: list[Address] + amounts: list[int] token: Address + group_id: Address class ClaimRewardsResponse(TypedDict): @@ -18,5 +19,5 @@ class ClaimRewardsResponse(TypedDict): Response structure for Group.claim_rewards method. """ - tx_hash: HexStr - claimed_rewards: list[ClaimReward] + tx_hash: HexBytes + claimed_rewards: ClaimReward diff --git a/tests/integration/test_integration_group.py b/tests/integration/test_integration_group.py index 58b3079..e932781 100644 --- a/tests/integration/test_integration_group.py +++ b/tests/integration/test_integration_group.py @@ -549,22 +549,16 @@ def test_claim_rewards(self, story_client: StoryClient, setup_royalty_collection story_client.Group.collect_and_distribute_group_royalties( group_ip_id=group_ip_id, currency_tokens=[MockERC20], member_ip_ids=ip_ids ) - # Test claiming rewards for specific members response = story_client.Group.claim_rewards( group_ip_id=group_ip_id, currency_token=MockERC20, member_ip_ids=ip_ids, ) - # Verify response structure assert "tx_hash" in response assert isinstance(response["tx_hash"], str) - assert len(response["tx_hash"]) > 0 - - # Verify claimed rewards details if any are present - if response["claimed_rewards"]: - for reward in response["claimed_rewards"]: - assert "amount" in reward - assert isinstance(reward["amount"], int) - assert "token" in reward - assert story_client.web3.is_address(reward["token"]) + assert "claimed_rewards" in response + assert len(response["claimed_rewards"]["ip_ids"]) == 2 + assert len(response["claimed_rewards"]["amounts"]) == 2 + assert response["claimed_rewards"]["token"] == MockERC20 + assert response["claimed_rewards"]["group_id"] == group_ip_id diff --git a/tests/unit/resources/test_group.py b/tests/unit/resources/test_group.py index 3bee689..7ef683b 100644 --- a/tests/unit/resources/test_group.py +++ b/tests/unit/resources/test_group.py @@ -1,8 +1,12 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest from story_protocol_python_sdk.resources.Group import Group +from story_protocol_python_sdk.types.resource.Group import ( + ClaimReward, + ClaimRewardsResponse, +) from tests.unit.fixtures.data import ADDRESS, CHAIN_ID, IP_ID, TX_HASH @@ -11,44 +15,6 @@ def group(mock_web3, mock_account): return Group(mock_web3, mock_account, CHAIN_ID) -@pytest.fixture(scope="class") -def mock_grouping_module_client(group): - def _mock(): - return patch.object( - group.grouping_module_client, - "build_claimReward_transaction", - return_value=MagicMock(), - ) - - return _mock - - -@pytest.fixture(scope="class") -def mock_build_and_send_transaction(): - def _mock(): - return patch( - "story_protocol_python_sdk.resources.Group.build_and_send_transaction", - return_value={ - "tx_hash": TX_HASH, - "tx_receipt": {"status": 1, "logs": []}, - }, - ) - - return _mock - - -@pytest.fixture(scope="class") -def mock_parse_tx_claimed_reward_event(group): - def _mock(): - return patch.object( - group, - "_parse_tx_claimed_reward_event", - return_value=[{"amount": 100, "token": ADDRESS}], - ) - - return _mock - - class TestGroupClaimRewards: """Test class for Group.claim_rewards method""" @@ -119,30 +85,194 @@ def test_claim_rewards_mixed_valid_invalid_members(self, group: Group, mock_web3 def test_claim_rewards_success( self, group: Group, - mock_grouping_module_client, - mock_build_and_send_transaction, - mock_parse_tx_claimed_reward_event, mock_web3_is_address, ): """Test successful claim_rewards operation.""" + with mock_web3_is_address(): + with patch( + "story_protocol_python_sdk.resources.Group.build_and_send_transaction", + return_value={ + "tx_hash": TX_HASH, + "tx_receipt": { + "logs": [ + { + "topics": [ + group.web3.keccak( + text="ClaimedReward(address,address,address[],uint256[])" + ) + ] + } + ] + }, + }, + ), patch.object( + group.grouping_module_client.contract.events.ClaimedReward, + "process_log", + return_value={ + "args": { + "ipId": [IP_ID, ADDRESS], + "amount": [100, 200], + "token": ADDRESS, + "groupId": IP_ID, + } + }, + ): + result = group.claim_rewards( + group_ip_id=IP_ID, + currency_token=ADDRESS, + member_ip_ids=[IP_ID, ADDRESS], + ) + + assert result == ClaimRewardsResponse( + tx_hash=TX_HASH, + claimed_rewards=ClaimReward( + ip_ids=[IP_ID, ADDRESS], + amounts=[100, 200], + token=ADDRESS, + group_id=IP_ID, + ), + ) + + def test_claim_rewards_with_tx_options( + self, + group: Group, + mock_web3_is_address, + ): + """Test claim_rewards with transaction options.""" + with mock_web3_is_address(): + with patch( + "story_protocol_python_sdk.resources.Group.build_and_send_transaction", + return_value={ + "tx_hash": TX_HASH, + "tx_receipt": { + "logs": [ + { + "topics": [ + group.web3.keccak( + text="ClaimedReward(address,address,address[],uint256[])" + ) + ] + } + ] + }, + }, + ) as mock_build_and_send, patch.object( + group.grouping_module_client.contract.events.ClaimedReward, + "process_log", + return_value={ + "args": { + "ipId": [IP_ID, ADDRESS], + "amount": [100, 200], + "token": ADDRESS, + "groupId": IP_ID, + } + }, + ): + tx_options = {"gas": 200000, "gasPrice": 20000000000} + result = group.claim_rewards( + group_ip_id=IP_ID, + currency_token=ADDRESS, + member_ip_ids=[IP_ID, ADDRESS], + tx_options=tx_options, + ) - with ( - mock_grouping_module_client(), - mock_build_and_send_transaction(), - mock_parse_tx_claimed_reward_event(), - mock_web3_is_address(), - ): - result = group.claim_rewards( - group_ip_id=IP_ID, - currency_token=ADDRESS, - member_ip_ids=[IP_ID, ADDRESS], - ) - - # Verify response structure - assert "tx_hash" in result - assert result["tx_hash"] == TX_HASH - assert "claimed_rewards" in result - assert result["claimed_rewards"] == [{"amount": 100, "token": ADDRESS}] + # Verify tx_options were passed to build_and_send_transaction + mock_build_and_send.assert_called_once() + call_args = mock_build_and_send.call_args + assert call_args[1]["tx_options"] == tx_options + + # Verify response with tx_options + assert result["tx_hash"] == TX_HASH + assert result["claimed_rewards"] == ClaimReward( + ip_ids=[IP_ID, ADDRESS], + amounts=[100, 200], + token=ADDRESS, + group_id=IP_ID, + ) + + def test_claim_rewards_no_event_found( + self, + group: Group, + mock_web3_is_address, + ): + """Test claim_rewards when no ClaimedReward event is found.""" + with mock_web3_is_address(): + with patch( + "story_protocol_python_sdk.resources.Group.build_and_send_transaction", + return_value={ + "tx_hash": TX_HASH, + "tx_receipt": { + "logs": [ + {"topics": [group.web3.keccak(text="DifferentEvent()")]} + ] + }, + }, + ), patch.object( + group.grouping_module_client.contract.events.ClaimedReward, + "process_log", + return_value={"args": {}}, + ): + with pytest.raises( + ValueError, + match="Failed to claim rewards: Not found ClaimedReward event in transaction logs.", + ): + group.claim_rewards( + group_ip_id=IP_ID, + currency_token=ADDRESS, + member_ip_ids=[IP_ID], + ) + + def test_claim_rewards_empty_member_ip_ids( + self, + group: Group, + mock_web3_is_address, + ): + """Test claim_rewards with empty member IP IDs list.""" + with mock_web3_is_address(): + with patch( + "story_protocol_python_sdk.resources.Group.build_and_send_transaction", + return_value={ + "tx_hash": TX_HASH, + "tx_receipt": { + "logs": [ + { + "topics": [ + group.web3.keccak( + text="ClaimedReward(address,address,address[],uint256[])" + ) + ] + } + ] + }, + }, + ), patch.object( + group.grouping_module_client.contract.events.ClaimedReward, + "process_log", + return_value={ + "args": { + "ipId": [], + "amount": [], + "token": ADDRESS, + "groupId": IP_ID, + } + }, + ): + result = group.claim_rewards( + group_ip_id=IP_ID, + currency_token=ADDRESS, + member_ip_ids=[], + ) + + # Verify response structure + assert "tx_hash" in result + assert result["tx_hash"] == TX_HASH + assert "claimed_rewards" in result + assert result["claimed_rewards"] == ClaimReward( + ip_ids=[], + amounts=[], + token=ADDRESS, + group_id=IP_ID, + ) def test_claim_rewards_transaction_build_failure( self, group: Group, mock_web3_is_address