diff --git a/src/story_protocol_python_sdk/__init__.py b/src/story_protocol_python_sdk/__init__.py index 6107fee3..6a095b39 100644 --- a/src/story_protocol_python_sdk/__init__.py +++ b/src/story_protocol_python_sdk/__init__.py @@ -8,7 +8,11 @@ from .resources.WIP import WIP from .story_client import StoryClient from .types.common import AccessPermission -from .types.resource.IPAsset import RegistrationResponse +from .types.resource.Group import ClaimReward, ClaimRewardsResponse +from .types.resource.IPAsset import ( + RegisterPILTermsAndAttachResponse, + RegistrationResponse, +) from .utils.constants import ( DEFAULT_FUNCTION_SELECTOR, MAX_ROYALTY_TOKEN, @@ -34,6 +38,9 @@ "DerivativeDataInput", "IPMetadataInput", "RegistrationResponse", + "ClaimRewardsResponse", + "ClaimReward", + "RegisterPILTermsAndAttachResponse", # Constants "ZERO_ADDRESS", "ZERO_HASH", diff --git a/src/story_protocol_python_sdk/abi/GroupingModule/GroupingModule_client.py b/src/story_protocol_python_sdk/abi/GroupingModule/GroupingModule_client.py index ed8ffe16..1d647ed9 100644 --- a/src/story_protocol_python_sdk/abi/GroupingModule/GroupingModule_client.py +++ b/src/story_protocol_python_sdk/abi/GroupingModule/GroupingModule_client.py @@ -43,6 +43,14 @@ def build_addIp_transaction( groupIpId, ipIds, maxAllowedRewardShare ).build_transaction(tx_params) + def claimReward(self, groupId, token, ipIds): + return self.contract.functions.claimReward(groupId, token, ipIds).transact() + + def build_claimReward_transaction(self, groupId, token, ipIds, tx_params): + return self.contract.functions.claimReward( + groupId, token, ipIds + ).build_transaction(tx_params) + def registerGroup(self, groupPool): return self.contract.functions.registerGroup(groupPool).transact() diff --git a/src/story_protocol_python_sdk/resources/Group.py b/src/story_protocol_python_sdk/resources/Group.py index 2e86bb66..ce7f4204 100644 --- a/src/story_protocol_python_sdk/resources/Group.py +++ b/src/story_protocol_python_sdk/resources/Group.py @@ -1,6 +1,4 @@ -# src/story_protocol_python_sdk/resources/Group.py - -from ens.ens import HexStr +from ens.ens import Address, HexStr from web3 import Web3 from story_protocol_python_sdk.abi.CoreMetadataModule.CoreMetadataModule_client import ( @@ -28,6 +26,10 @@ PILicenseTemplateClient, ) from story_protocol_python_sdk.types.common import RevShareType +from story_protocol_python_sdk.types.resource.Group import ( + ClaimReward, + ClaimRewardsResponse, +) from story_protocol_python_sdk.utils.constants import ZERO_ADDRESS, ZERO_HASH from story_protocol_python_sdk.utils.license_terms import LicenseTerms from story_protocol_python_sdk.utils.sign import Sign @@ -531,6 +533,57 @@ def collect_and_distribute_group_royalties( f"Failed to collect and distribute group royalties: {str(e)}" ) + def claim_rewards( + self, + group_ip_id: Address, + currency_token: Address, + member_ip_ids: list[Address], + tx_options: dict | None = None, + ) -> ClaimRewardsResponse: + """ + Claim rewards for the entire group. + + :param group_ip_id str: The ID of the group IP. + :param currency_token str: The address of the currency (revenue) token to claim.. + :param member_ip_ids list: The IDs of the member IPs to distribute the rewards to. + :param tx_options dict: [Optional] The transaction options. + :return ClaimRewardsResponse: A response object with the transaction hash and claimed rewards. + """ + try: + if not self.web3.is_address(group_ip_id): + raise ValueError(f"Invalid group IP ID: {group_ip_id}") + if not self.web3.is_address(currency_token): + raise ValueError(f"Invalid currency token: {currency_token}") + for ip_id in member_ip_ids: + if not self.web3.is_address(ip_id): + raise ValueError(f"Invalid member IP ID: {ip_id}") + + claim_reward_param = { + "groupIpId": group_ip_id, + "token": currency_token, + "memberIpIds": member_ip_ids, + } + + response = build_and_send_transaction( + self.web3, + self.account, + self.grouping_module_client.build_claimReward_transaction, + *claim_reward_param.values(), + tx_options=tx_options, + ) + + claimed_rewards = self._parse_tx_claimed_reward_event( + response["tx_receipt"] + ) + + return ClaimRewardsResponse( + tx_hash=response["tx_hash"], + claimed_rewards=claimed_rewards, + ) + + except Exception as e: + raise ValueError(f"Failed to claim rewards: {str(e)}") + def _get_license_data(self, license_data: list) -> list: """ Process license data into the format expected by the contracts. @@ -695,3 +748,31 @@ def _parse_tx_royalty_paid_event(self, tx_receipt: dict) -> list: ) return royalties_distributed + + def _parse_tx_claimed_reward_event(self, tx_receipt: dict) -> list[ClaimReward]: + """ + Parse the ClaimedReward event from a transaction receipt. + + :param tx_receipt dict: The transaction receipt. + :return list: List of claimed rewards. + """ + event_signature = self.web3.keccak( + text="ClaimedReward(address,address,address,uint256)" + ).hex() + claimed_rewards = [] + + for log in tx_receipt["logs"]: + if log["topics"][0].hex() == event_signature: + ip_id = "0x" + log["topics"][0].hex()[24:] + amount = int(log["data"][:66].hex(), 16) + token = "0x" + log["topics"][2].hex()[24:] + + claimed_rewards.append( + ClaimReward( + ip_id=ip_id, + amount=amount, + token=token, + ) + ) + + return claimed_rewards diff --git a/src/story_protocol_python_sdk/scripts/config.json b/src/story_protocol_python_sdk/scripts/config.json index 58573fa3..b1d9252c 100644 --- a/src/story_protocol_python_sdk/scripts/config.json +++ b/src/story_protocol_python_sdk/scripts/config.json @@ -179,7 +179,12 @@ { "contract_name": "GroupingModule", "contract_address": "0x69D3a7aa9edb72Bc226E745A7cCdd50D947b69Ac", - "functions": ["registerGroup", "addIp", "IPGroupRegistered"] + "functions": [ + "registerGroup", + "addIp", + "IPGroupRegistered", + "claimReward" + ] }, { "contract_name": "LicenseRegistry", diff --git a/src/story_protocol_python_sdk/types/resource/Group.py b/src/story_protocol_python_sdk/types/resource/Group.py new file mode 100644 index 00000000..62d0b2ec --- /dev/null +++ b/src/story_protocol_python_sdk/types/resource/Group.py @@ -0,0 +1,22 @@ +from typing import TypedDict + +from ens.ens import Address, HexStr + + +class ClaimReward(TypedDict): + """ + Structure for a claimed reward. + """ + + ip_id: Address + amount: int + token: Address + + +class ClaimRewardsResponse(TypedDict): + """ + Response structure for Group.claim_rewards method. + """ + + tx_hash: HexStr + claimed_rewards: list[ClaimReward] diff --git a/tests/integration/test_integration_group.py b/tests/integration/test_integration_group.py index e5d645e5..58b30797 100644 --- a/tests/integration/test_integration_group.py +++ b/tests/integration/test_integration_group.py @@ -4,6 +4,8 @@ import pytest +from story_protocol_python_sdk.story_client import StoryClient + from .setup_for_integration import ( EVEN_SPLIT_GROUP_POOL, PIL_LICENSE_TEMPLATE, @@ -538,3 +540,31 @@ def test_collect_and_distribute_group_royalties( assert len(response["royalties_distributed"]) == 2 assert response["royalties_distributed"][0]["amount"] == 10 assert response["royalties_distributed"][1]["amount"] == 10 + + def test_claim_rewards(self, story_client: StoryClient, setup_royalty_collection): + """Test claiming rewards for group members.""" + group_ip_id = setup_royalty_collection["group_ip_id"] + ip_ids = setup_royalty_collection["ip_ids"] + # Collect and distribute royalties to set up rewards for claiming + story_client.Group.collect_and_distribute_group_royalties( + group_ip_id=group_ip_id, currency_tokens=[MockERC20], member_ip_ids=ip_ids + ) + + # Test claiming rewards for specific members + response = story_client.Group.claim_rewards( + group_ip_id=group_ip_id, + currency_token=MockERC20, + member_ip_ids=ip_ids, + ) + # Verify response structure + assert "tx_hash" in response + assert isinstance(response["tx_hash"], str) + assert len(response["tx_hash"]) > 0 + + # Verify claimed rewards details if any are present + if response["claimed_rewards"]: + for reward in response["claimed_rewards"]: + assert "amount" in reward + assert isinstance(reward["amount"], int) + assert "token" in reward + assert story_client.web3.is_address(reward["token"]) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 88c3b587..484fbfc6 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -149,3 +149,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): return MockContext() return _mock + + +@pytest.fixture(scope="module") +def mock_web3_is_address(mock_web3): + def _mock(is_address: bool = True): + return patch.object(mock_web3, "is_address", return_value=is_address) + + return _mock diff --git a/tests/unit/resources/test_group.py b/tests/unit/resources/test_group.py new file mode 100644 index 00000000..3bee6893 --- /dev/null +++ b/tests/unit/resources/test_group.py @@ -0,0 +1,164 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from story_protocol_python_sdk.resources.Group import Group +from tests.unit.fixtures.data import ADDRESS, CHAIN_ID, IP_ID, TX_HASH + + +@pytest.fixture(scope="class") +def group(mock_web3, mock_account): + return Group(mock_web3, mock_account, CHAIN_ID) + + +@pytest.fixture(scope="class") +def mock_grouping_module_client(group): + def _mock(): + return patch.object( + group.grouping_module_client, + "build_claimReward_transaction", + return_value=MagicMock(), + ) + + return _mock + + +@pytest.fixture(scope="class") +def mock_build_and_send_transaction(): + def _mock(): + return patch( + "story_protocol_python_sdk.resources.Group.build_and_send_transaction", + return_value={ + "tx_hash": TX_HASH, + "tx_receipt": {"status": 1, "logs": []}, + }, + ) + + return _mock + + +@pytest.fixture(scope="class") +def mock_parse_tx_claimed_reward_event(group): + def _mock(): + return patch.object( + group, + "_parse_tx_claimed_reward_event", + return_value=[{"amount": 100, "token": ADDRESS}], + ) + + return _mock + + +class TestGroupClaimRewards: + """Test class for Group.claim_rewards method""" + + def test_claim_rewards_invalid_group_ip_id( + self, group: Group, mock_web3_is_address + ): + """Test claim_rewards with invalid group IP ID.""" + invalid_group_ip_id = "invalid_group_ip_id" + with mock_web3_is_address(False): + with pytest.raises( + ValueError, + match=f"Failed to claim rewards: Invalid group IP ID: {invalid_group_ip_id}", + ): + group.claim_rewards( + group_ip_id=invalid_group_ip_id, + currency_token=ADDRESS, + member_ip_ids=[IP_ID], + ) + + def test_claim_rewards_invalid_currency_token(self, group: Group, mock_web3): + """Test claim_rewards with invalid currency token.""" + invalid_currency_token = "invalid_currency_token" + with patch.object(mock_web3, "is_address") as mock_is_address: + # group_ip_id=True, currency_token=False + mock_is_address.side_effect = [True, False] + with pytest.raises( + ValueError, + match=f"Failed to claim rewards: Invalid currency token: {invalid_currency_token}", + ): + group.claim_rewards( + group_ip_id=IP_ID, + currency_token=invalid_currency_token, + member_ip_ids=[ADDRESS], + ) + + def test_claim_rewards_invalid_member_ip_ids(self, group: Group, mock_web3): + """Test claim_rewards with invalid member IP IDs.""" + invalid_member_ip_id = "invalid_member_ip" + with patch.object(mock_web3, "is_address") as mock_is_address: + # group_ip_id=True, currency_token=True, first member_ip_id=False + mock_is_address.side_effect = [True, True, False] + with pytest.raises( + ValueError, + match=f"Failed to claim rewards: Invalid member IP ID: {invalid_member_ip_id}", + ): + group.claim_rewards( + group_ip_id=IP_ID, + currency_token=ADDRESS, + member_ip_ids=[invalid_member_ip_id], + ) + + def test_claim_rewards_mixed_valid_invalid_members(self, group: Group, mock_web3): + """Test claim_rewards with mix of valid and invalid member IP IDs.""" + invalid_member_ip_id = "invalid_member_ip" + with patch.object(mock_web3, "is_address") as mock_is_address: + # group_ip_id=True, currency_token=True, first_member=True, second_member=False + mock_is_address.side_effect = [True, True, True, False] + with pytest.raises( + ValueError, + match=f"Failed to claim rewards: Invalid member IP ID: {invalid_member_ip_id}", + ): + group.claim_rewards( + group_ip_id=IP_ID, + currency_token=ADDRESS, + member_ip_ids=[ADDRESS, invalid_member_ip_id], + ) + + def test_claim_rewards_success( + self, + group: Group, + mock_grouping_module_client, + mock_build_and_send_transaction, + mock_parse_tx_claimed_reward_event, + mock_web3_is_address, + ): + """Test successful claim_rewards operation.""" + + with ( + mock_grouping_module_client(), + mock_build_and_send_transaction(), + mock_parse_tx_claimed_reward_event(), + mock_web3_is_address(), + ): + result = group.claim_rewards( + group_ip_id=IP_ID, + currency_token=ADDRESS, + member_ip_ids=[IP_ID, ADDRESS], + ) + + # Verify response structure + assert "tx_hash" in result + assert result["tx_hash"] == TX_HASH + assert "claimed_rewards" in result + assert result["claimed_rewards"] == [{"amount": 100, "token": ADDRESS}] + + def test_claim_rewards_transaction_build_failure( + self, group: Group, mock_web3_is_address + ): + """Test claim_rewards when transaction building fails.""" + with mock_web3_is_address(True): + with patch( + "story_protocol_python_sdk.resources.Group.build_and_send_transaction", + side_effect=Exception("Transaction build failed"), + ): + with pytest.raises( + ValueError, + match="Failed to claim rewards: Transaction build failed", + ): + group.claim_rewards( + group_ip_id=IP_ID, + currency_token=ADDRESS, + member_ip_ids=[IP_ID], + )