From 3b435faa424c97fe5a6771d827252150037986ee Mon Sep 17 00:00:00 2001 From: Bonnie Date: Wed, 3 Sep 2025 14:53:03 +0800 Subject: [PATCH] feat: add get_claimable_reward method to Group and GroupingModuleClient --- .../GroupingModule/GroupingModule_client.py | 3 + .../resources/Group.py | 34 +++++ .../scripts/config.json | 3 +- tests/integration/test_integration_group.py | 53 +++++++ tests/unit/resources/test_group.py | 133 ++++++++++++++++++ 5 files changed, 225 insertions(+), 1 deletion(-) diff --git a/src/story_protocol_python_sdk/abi/GroupingModule/GroupingModule_client.py b/src/story_protocol_python_sdk/abi/GroupingModule/GroupingModule_client.py index e207d8d..a3b17c8 100644 --- a/src/story_protocol_python_sdk/abi/GroupingModule/GroupingModule_client.py +++ b/src/story_protocol_python_sdk/abi/GroupingModule/GroupingModule_client.py @@ -66,3 +66,6 @@ def build_registerGroup_transaction(self, groupPool, tx_params): return self.contract.functions.registerGroup(groupPool).build_transaction( tx_params ) + + def getClaimableReward(self, groupId, token, ipIds): + return self.contract.functions.getClaimableReward(groupId, token, ipIds).call() diff --git a/src/story_protocol_python_sdk/resources/Group.py b/src/story_protocol_python_sdk/resources/Group.py index 18cb4d5..7836e8c 100644 --- a/src/story_protocol_python_sdk/resources/Group.py +++ b/src/story_protocol_python_sdk/resources/Group.py @@ -647,6 +647,40 @@ def collect_royalties( except Exception as e: raise ValueError(f"Failed to collect royalties: {str(e)}") + def get_claimable_reward( + self, + group_ip_id: Address, + currency_token: Address, + member_ip_ids: list[Address], + ) -> list[int]: + """ + Returns the available reward for each IP in the group. + + :param group_ip_id Address: The ID of the group IP. + :param currency_token Address: The address of the currency (revenue) token to check. + :param member_ip_ids list[Address]: The IDs of the member IPs to check claimable rewards for. + :return list[int]: A list of claimable reward amounts corresponding to each member IP ID. + """ + try: + if not self.web3.is_address(group_ip_id): + raise ValueError(f"Invalid group IP ID: {group_ip_id}") + if not self.web3.is_address(currency_token): + raise ValueError(f"Invalid currency token: {currency_token}") + for ip_id in member_ip_ids: + if not self.web3.is_address(ip_id): + raise ValueError(f"Invalid member IP ID: {ip_id}") + + claimable_rewards = self.grouping_module_client.getClaimableReward( + groupId=group_ip_id, + token=currency_token, + ipIds=member_ip_ids, + ) + + return claimable_rewards + + except Exception as e: + raise ValueError(f"Failed to get claimable rewards: {str(e)}") + def _get_license_data(self, license_data: list) -> list: """ Process license data into the format expected by the contracts. diff --git a/src/story_protocol_python_sdk/scripts/config.json b/src/story_protocol_python_sdk/scripts/config.json index bf797ad..b28b428 100644 --- a/src/story_protocol_python_sdk/scripts/config.json +++ b/src/story_protocol_python_sdk/scripts/config.json @@ -184,7 +184,8 @@ "addIp", "IPGroupRegistered", "claimReward", - "collectRoyalties" + "collectRoyalties", + "getClaimableReward" ] }, { diff --git a/tests/integration/test_integration_group.py b/tests/integration/test_integration_group.py index 2f81928..6f8f300 100644 --- a/tests/integration/test_integration_group.py +++ b/tests/integration/test_integration_group.py @@ -307,3 +307,56 @@ def test_collect_and_distribute_group_royalties( assert len(response["royalties_distributed"]) == 2 assert response["royalties_distributed"][0]["amount"] == 10 assert response["royalties_distributed"][1]["amount"] == 10 + + def test_get_claimable_reward( + self, story_client: StoryClient, nft_collection: Address + ): + """Test getting claimable rewards for group members.""" + # Register IP id + result1 = GroupTestHelper.mint_and_register_ip_asset_with_pil_terms( + story_client, nft_collection + ) + ip_id1 = result1["ip_id"] + license_terms_id1 = result1["license_terms_id"] + result2 = GroupTestHelper.mint_and_register_ip_asset_with_pil_terms( + story_client, nft_collection + ) + ip_id2 = result2["ip_id"] + license_terms_id2 = result2["license_terms_id"] + + # Register group id + group_ip_id = GroupTestHelper.register_group_and_attach_license( + story_client, license_terms_id1, [ip_id1, ip_id2] + ) + # Create a derivative IP and pay royalties + child_ip_id = GroupTestHelper.mint_and_register_ip_and_make_derivative( + story_client, nft_collection, group_ip_id, license_terms_id1 + ) + child_ip_id2 = GroupTestHelper.mint_and_register_ip_and_make_derivative( + story_client, nft_collection, group_ip_id, license_terms_id2 + ) + + # Pay royalties from group IP id to child IP id + GroupTestHelper.pay_royalty_and_transfer_to_vault( + story_client, child_ip_id, group_ip_id, MockERC20, 100 + ) + GroupTestHelper.pay_royalty_and_transfer_to_vault( + story_client, child_ip_id2, group_ip_id, MockERC20, 100 + ) + + # Collect royalties + story_client.Group.collect_royalties( + group_ip_id=group_ip_id, + currency_token=MockERC20, + ) + # Get claimable rewards after royalties are collected + claimable_rewards = story_client.Group.get_claimable_reward( + group_ip_id=group_ip_id, + currency_token=MockERC20, + member_ip_ids=[ip_id1, ip_id2], + ) + + assert isinstance(claimable_rewards, list) + assert len(claimable_rewards) == 2 + assert claimable_rewards[0] == 10 + assert claimable_rewards[1] == 10 diff --git a/tests/unit/resources/test_group.py b/tests/unit/resources/test_group.py index bfbcf45..94984de 100644 --- a/tests/unit/resources/test_group.py +++ b/tests/unit/resources/test_group.py @@ -422,3 +422,136 @@ def test_claim_rewards_transaction_build_failure( currency_token=ADDRESS, member_ip_ids=[IP_ID], ) + + +class TestGroupGetClaimableReward: + """Test class for Group.get_claimable_reward method""" + + def test_get_claimable_reward_invalid_group_ip_id( + self, group: Group, mock_web3_is_address + ): + """Test get_claimable_reward with invalid group IP ID.""" + invalid_group_ip_id = "invalid_group_ip_id" + with mock_web3_is_address(False): + with pytest.raises( + ValueError, + match=f"Failed to get claimable rewards: Invalid group IP ID: {invalid_group_ip_id}", + ): + group.get_claimable_reward( + group_ip_id=invalid_group_ip_id, + currency_token=ADDRESS, + member_ip_ids=[IP_ID], + ) + + def test_get_claimable_reward_invalid_currency_token(self, group: Group, mock_web3): + """Test get_claimable_reward with invalid currency token.""" + invalid_currency_token = "invalid_currency_token" + with patch.object(mock_web3, "is_address") as mock_is_address: + # group_ip_id=True, currency_token=False, member_ip_ids=True + mock_is_address.side_effect = [True, False] + with pytest.raises( + ValueError, + match=f"Failed to get claimable rewards: Invalid currency token: {invalid_currency_token}", + ): + group.get_claimable_reward( + group_ip_id=IP_ID, + currency_token=invalid_currency_token, + member_ip_ids=[IP_ID], + ) + + def test_get_claimable_reward_invalid_member_ip_id(self, group: Group, mock_web3): + """Test get_claimable_reward with invalid member IP ID.""" + invalid_member_ip_id = "invalid_member_ip_id" + with patch.object(mock_web3, "is_address") as mock_is_address: + # group_ip_id=True, currency_token=True, first_member=True, second_member=False + mock_is_address.side_effect = [True, True, True, False] + with pytest.raises( + ValueError, + match=f"Failed to get claimable rewards: Invalid member IP ID: {invalid_member_ip_id}", + ): + group.get_claimable_reward( + group_ip_id=IP_ID, + currency_token=ADDRESS, + member_ip_ids=[ADDRESS, invalid_member_ip_id], + ) + + def test_get_claimable_reward_success( + self, + group: Group, + mock_web3_is_address, + ): + """Test successful get_claimable_reward operation.""" + expected_claimable_rewards = [100, 200, 300] + member_ip_ids = [IP_ID, ADDRESS, ADDRESS] + + with mock_web3_is_address(): + with patch.object( + group.grouping_module_client, + "getClaimableReward", + return_value=expected_claimable_rewards, + ) as mock_get_claimable_reward: + result = group.get_claimable_reward( + group_ip_id=IP_ID, + currency_token=ADDRESS, + member_ip_ids=member_ip_ids, + ) + + # Verify the result + assert result == expected_claimable_rewards + assert len(result) == len(member_ip_ids) + mock_get_claimable_reward.assert_called_once_with( + groupId=IP_ID, + token=ADDRESS, + ipIds=member_ip_ids, + ) + + def test_get_claimable_reward_empty_member_ip_ids( + self, + group: Group, + mock_web3_is_address, + ): + """Test get_claimable_reward with empty member IP IDs list.""" + expected_claimable_rewards: list[int] = [] + + with mock_web3_is_address(): + with patch.object( + group.grouping_module_client, + "getClaimableReward", + return_value=expected_claimable_rewards, + ) as mock_get_claimable_reward: + result = group.get_claimable_reward( + group_ip_id=IP_ID, + currency_token=ADDRESS, + member_ip_ids=[], + ) + + # Verify the result + assert result == expected_claimable_rewards + assert len(result) == 0 + + # Verify the client method was called with correct parameters + mock_get_claimable_reward.assert_called_once_with( + groupId=IP_ID, + token=ADDRESS, + ipIds=[], + ) + + def test_get_claimable_reward_client_call_failure( + self, group: Group, mock_web3_is_address + ): + """Test get_claimable_reward when client call fails.""" + with mock_web3_is_address(): + with patch.object( + group.grouping_module_client, + "getClaimableReward", + side_effect=Exception("Client call failed"), + ): + with pytest.raises( + ValueError, + match="Failed to get claimable rewards: Client call failed", + ): + group.get_claimable_reward( + group_ip_id=IP_ID, + currency_token=ADDRESS, + member_ip_ids=[IP_ID], + )