Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,6 @@ def build_registerGroup_transaction(self, groupPool, tx_params):
return self.contract.functions.registerGroup(groupPool).build_transaction(
tx_params
)

def getClaimableReward(self, groupId, token, ipIds):
return self.contract.functions.getClaimableReward(groupId, token, ipIds).call()
34 changes: 34 additions & 0 deletions src/story_protocol_python_sdk/resources/Group.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,40 @@ def collect_royalties(
except Exception as e:
raise ValueError(f"Failed to collect royalties: {str(e)}")

def get_claimable_reward(
self,
group_ip_id: Address,
currency_token: Address,
member_ip_ids: list[Address],
) -> list[int]:
"""
Returns the available reward for each IP in the group.

:param group_ip_id Address: The ID of the group IP.
:param currency_token Address: The address of the currency (revenue) token to check.
:param member_ip_ids list[Address]: The IDs of the member IPs to check claimable rewards for.
:return list[int]: A list of claimable reward amounts corresponding to each member IP ID.
"""
try:
if not self.web3.is_address(group_ip_id):
raise ValueError(f"Invalid group IP ID: {group_ip_id}")
if not self.web3.is_address(currency_token):
raise ValueError(f"Invalid currency token: {currency_token}")
for ip_id in member_ip_ids:
if not self.web3.is_address(ip_id):
raise ValueError(f"Invalid member IP ID: {ip_id}")

claimable_rewards = self.grouping_module_client.getClaimableReward(
groupId=group_ip_id,
token=currency_token,
ipIds=member_ip_ids,
)

return claimable_rewards

except Exception as e:
raise ValueError(f"Failed to get claimable rewards: {str(e)}")

def _get_license_data(self, license_data: list) -> list:
"""
Process license data into the format expected by the contracts.
Expand Down
3 changes: 2 additions & 1 deletion src/story_protocol_python_sdk/scripts/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@
"addIp",
"IPGroupRegistered",
"claimReward",
"collectRoyalties"
"collectRoyalties",
"getClaimableReward"
]
},
{
Expand Down
53 changes: 53 additions & 0 deletions tests/integration/test_integration_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,56 @@ def test_collect_and_distribute_group_royalties(
assert len(response["royalties_distributed"]) == 2
assert response["royalties_distributed"][0]["amount"] == 10
assert response["royalties_distributed"][1]["amount"] == 10

def test_get_claimable_reward(
self, story_client: StoryClient, nft_collection: Address
):
"""Test getting claimable rewards for group members."""
# Register IP id
result1 = GroupTestHelper.mint_and_register_ip_asset_with_pil_terms(
story_client, nft_collection
)
ip_id1 = result1["ip_id"]
license_terms_id1 = result1["license_terms_id"]
result2 = GroupTestHelper.mint_and_register_ip_asset_with_pil_terms(
story_client, nft_collection
)
ip_id2 = result2["ip_id"]
license_terms_id2 = result2["license_terms_id"]

# Register group id
group_ip_id = GroupTestHelper.register_group_and_attach_license(
story_client, license_terms_id1, [ip_id1, ip_id2]
)
# Create a derivative IP and pay royalties
child_ip_id = GroupTestHelper.mint_and_register_ip_and_make_derivative(
story_client, nft_collection, group_ip_id, license_terms_id1
)
child_ip_id2 = GroupTestHelper.mint_and_register_ip_and_make_derivative(
story_client, nft_collection, group_ip_id, license_terms_id2
)

# Pay royalties from group IP id to child IP id
GroupTestHelper.pay_royalty_and_transfer_to_vault(
story_client, child_ip_id, group_ip_id, MockERC20, 100
)
GroupTestHelper.pay_royalty_and_transfer_to_vault(
story_client, child_ip_id2, group_ip_id, MockERC20, 100
)

# Collect royalties
story_client.Group.collect_royalties(
group_ip_id=group_ip_id,
currency_token=MockERC20,
)
# Get claimable rewards after royalties are collected
claimable_rewards = story_client.Group.get_claimable_reward(
group_ip_id=group_ip_id,
currency_token=MockERC20,
member_ip_ids=[ip_id1, ip_id2],
)

assert isinstance(claimable_rewards, list)
assert len(claimable_rewards) == 2
assert claimable_rewards[0] == 10
assert claimable_rewards[1] == 10
133 changes: 133 additions & 0 deletions tests/unit/resources/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,3 +422,136 @@ def test_claim_rewards_transaction_build_failure(
currency_token=ADDRESS,
member_ip_ids=[IP_ID],
)


class TestGroupGetClaimableReward:
"""Test class for Group.get_claimable_reward method"""

def test_get_claimable_reward_invalid_group_ip_id(
self, group: Group, mock_web3_is_address
):
"""Test get_claimable_reward with invalid group IP ID."""
invalid_group_ip_id = "invalid_group_ip_id"
with mock_web3_is_address(False):
with pytest.raises(
ValueError,
match=f"Failed to get claimable rewards: Invalid group IP ID: {invalid_group_ip_id}",
):
group.get_claimable_reward(
group_ip_id=invalid_group_ip_id,
currency_token=ADDRESS,
member_ip_ids=[IP_ID],
)

def test_get_claimable_reward_invalid_currency_token(self, group: Group, mock_web3):
"""Test get_claimable_reward with invalid currency token."""
invalid_currency_token = "invalid_currency_token"
with patch.object(mock_web3, "is_address") as mock_is_address:
# group_ip_id=True, currency_token=False, member_ip_ids=True
mock_is_address.side_effect = [True, False]
with pytest.raises(
ValueError,
match=f"Failed to get claimable rewards: Invalid currency token: {invalid_currency_token}",
):
group.get_claimable_reward(
group_ip_id=IP_ID,
currency_token=invalid_currency_token,
member_ip_ids=[IP_ID],
)

def test_get_claimable_reward_invalid_member_ip_id(self, group: Group, mock_web3):
"""Test get_claimable_reward with invalid member IP ID."""
invalid_member_ip_id = "invalid_member_ip_id"
with patch.object(mock_web3, "is_address") as mock_is_address:
# group_ip_id=True, currency_token=True, first_member=True, second_member=False
mock_is_address.side_effect = [True, True, True, False]
with pytest.raises(
ValueError,
match=f"Failed to get claimable rewards: Invalid member IP ID: {invalid_member_ip_id}",
):
group.get_claimable_reward(
group_ip_id=IP_ID,
currency_token=ADDRESS,
member_ip_ids=[ADDRESS, invalid_member_ip_id],
)

def test_get_claimable_reward_success(
self,
group: Group,
mock_web3_is_address,
):
"""Test successful get_claimable_reward operation."""
expected_claimable_rewards = [100, 200, 300]
member_ip_ids = [IP_ID, ADDRESS, ADDRESS]

with mock_web3_is_address():
with patch.object(
group.grouping_module_client,
"getClaimableReward",
return_value=expected_claimable_rewards,
) as mock_get_claimable_reward:
result = group.get_claimable_reward(
group_ip_id=IP_ID,
currency_token=ADDRESS,
member_ip_ids=member_ip_ids,
)

# Verify the result
assert result == expected_claimable_rewards
assert len(result) == len(member_ip_ids)
mock_get_claimable_reward.assert_called_once_with(
groupId=IP_ID,
token=ADDRESS,
ipIds=member_ip_ids,
)

def test_get_claimable_reward_empty_member_ip_ids(
self,
group: Group,
mock_web3_is_address,
):
"""Test get_claimable_reward with empty member IP IDs list."""
expected_claimable_rewards: list[int] = []

with mock_web3_is_address():
with patch.object(
group.grouping_module_client,
"getClaimableReward",
return_value=expected_claimable_rewards,
) as mock_get_claimable_reward:
result = group.get_claimable_reward(
group_ip_id=IP_ID,
currency_token=ADDRESS,
member_ip_ids=[],
)

# Verify the result
assert result == expected_claimable_rewards
assert len(result) == 0

# Verify the client method was called with correct parameters
mock_get_claimable_reward.assert_called_once_with(
groupId=IP_ID,
token=ADDRESS,
ipIds=[],
)

def test_get_claimable_reward_client_call_failure(
self, group: Group, mock_web3_is_address
):
"""Test get_claimable_reward when client call fails."""
with mock_web3_is_address():
with patch.object(
group.grouping_module_client,
"getClaimableReward",
side_effect=Exception("Client call failed"),
):
with pytest.raises(
ValueError,
match="Failed to get claimable rewards: Client call failed",
):
group.get_claimable_reward(
group_ip_id=IP_ID,
currency_token=ADDRESS,
member_ip_ids=[IP_ID],
)
Loading