Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/story_protocol_python_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
from .resources.WIP import WIP
from .story_client import StoryClient
from .types.common import AccessPermission
from .types.resource.IPAsset import RegistrationResponse
from .types.resource.Group import ClaimReward, ClaimRewardsResponse
from .types.resource.IPAsset import (
RegisterPILTermsAndAttachResponse,
RegistrationResponse,
)
from .utils.constants import (
DEFAULT_FUNCTION_SELECTOR,
MAX_ROYALTY_TOKEN,
Expand All @@ -34,6 +38,9 @@
"DerivativeDataInput",
"IPMetadataInput",
"RegistrationResponse",
"ClaimRewardsResponse",
"ClaimReward",
"RegisterPILTermsAndAttachResponse",
# Constants
"ZERO_ADDRESS",
"ZERO_HASH",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ def build_addIp_transaction(
groupIpId, ipIds, maxAllowedRewardShare
).build_transaction(tx_params)

def claimReward(self, groupId, token, ipIds):
return self.contract.functions.claimReward(groupId, token, ipIds).transact()

def build_claimReward_transaction(self, groupId, token, ipIds, tx_params):
return self.contract.functions.claimReward(
groupId, token, ipIds
).build_transaction(tx_params)

def registerGroup(self, groupPool):
return self.contract.functions.registerGroup(groupPool).transact()

Expand Down
87 changes: 84 additions & 3 deletions src/story_protocol_python_sdk/resources/Group.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
# src/story_protocol_python_sdk/resources/Group.py

from ens.ens import HexStr
from ens.ens import Address, HexStr
from web3 import Web3

from story_protocol_python_sdk.abi.CoreMetadataModule.CoreMetadataModule_client import (
Expand Down Expand Up @@ -28,6 +26,10 @@
PILicenseTemplateClient,
)
from story_protocol_python_sdk.types.common import RevShareType
from story_protocol_python_sdk.types.resource.Group import (
ClaimReward,
ClaimRewardsResponse,
)
from story_protocol_python_sdk.utils.constants import ZERO_ADDRESS, ZERO_HASH
from story_protocol_python_sdk.utils.license_terms import LicenseTerms
from story_protocol_python_sdk.utils.sign import Sign
Expand Down Expand Up @@ -531,6 +533,57 @@ def collect_and_distribute_group_royalties(
f"Failed to collect and distribute group royalties: {str(e)}"
)

def claim_rewards(
self,
group_ip_id: Address,
currency_token: Address,
member_ip_ids: list[Address],
tx_options: dict | None = None,
) -> ClaimRewardsResponse:
"""
Claim rewards for the entire group.

:param group_ip_id str: The ID of the group IP.
:param currency_token str: The address of the currency (revenue) token to claim..
:param member_ip_ids list: The IDs of the member IPs to distribute the rewards to.
:param tx_options dict: [Optional] The transaction options.
:return ClaimRewardsResponse: A response object with the transaction hash and claimed rewards.
"""
try:
if not self.web3.is_address(group_ip_id):
raise ValueError(f"Invalid group IP ID: {group_ip_id}")
if not self.web3.is_address(currency_token):
raise ValueError(f"Invalid currency token: {currency_token}")
for ip_id in member_ip_ids:
if not self.web3.is_address(ip_id):
raise ValueError(f"Invalid member IP ID: {ip_id}")

claim_reward_param = {
"groupIpId": group_ip_id,
"token": currency_token,
"memberIpIds": member_ip_ids,
}

response = build_and_send_transaction(
self.web3,
self.account,
self.grouping_module_client.build_claimReward_transaction,
*claim_reward_param.values(),
tx_options=tx_options,
)

claimed_rewards = self._parse_tx_claimed_reward_event(
response["tx_receipt"]
)

return ClaimRewardsResponse(
tx_hash=response["tx_hash"],
claimed_rewards=claimed_rewards,
)

except Exception as e:
raise ValueError(f"Failed to claim rewards: {str(e)}")

def _get_license_data(self, license_data: list) -> list:
"""
Process license data into the format expected by the contracts.
Expand Down Expand Up @@ -695,3 +748,31 @@ def _parse_tx_royalty_paid_event(self, tx_receipt: dict) -> list:
)

return royalties_distributed

def _parse_tx_claimed_reward_event(self, tx_receipt: dict) -> list[ClaimReward]:
"""
Parse the ClaimedReward event from a transaction receipt.

:param tx_receipt dict: The transaction receipt.
:return list: List of claimed rewards.
"""
event_signature = self.web3.keccak(
text="ClaimedReward(address,address,address,uint256)"
).hex()
claimed_rewards = []

for log in tx_receipt["logs"]:
if log["topics"][0].hex() == event_signature:
ip_id = "0x" + log["topics"][0].hex()[24:]
amount = int(log["data"][:66].hex(), 16)
token = "0x" + log["topics"][2].hex()[24:]

claimed_rewards.append(
ClaimReward(
ip_id=ip_id,
amount=amount,
token=token,
)
)

return claimed_rewards
7 changes: 6 additions & 1 deletion src/story_protocol_python_sdk/scripts/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,12 @@
{
"contract_name": "GroupingModule",
"contract_address": "0x69D3a7aa9edb72Bc226E745A7cCdd50D947b69Ac",
"functions": ["registerGroup", "addIp", "IPGroupRegistered"]
"functions": [
"registerGroup",
"addIp",
"IPGroupRegistered",
"claimReward"
]
},
{
"contract_name": "LicenseRegistry",
Expand Down
22 changes: 22 additions & 0 deletions src/story_protocol_python_sdk/types/resource/Group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import TypedDict

from ens.ens import Address, HexStr


class ClaimReward(TypedDict):
"""
Structure for a claimed reward.
"""

ip_id: Address
amount: int
token: Address


class ClaimRewardsResponse(TypedDict):
"""
Response structure for Group.claim_rewards method.
"""

tx_hash: HexStr
claimed_rewards: list[ClaimReward]
30 changes: 30 additions & 0 deletions tests/integration/test_integration_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import pytest

from story_protocol_python_sdk.story_client import StoryClient

from .setup_for_integration import (
EVEN_SPLIT_GROUP_POOL,
PIL_LICENSE_TEMPLATE,
Expand Down Expand Up @@ -538,3 +540,31 @@ def test_collect_and_distribute_group_royalties(
assert len(response["royalties_distributed"]) == 2
assert response["royalties_distributed"][0]["amount"] == 10
assert response["royalties_distributed"][1]["amount"] == 10

def test_claim_rewards(self, story_client: StoryClient, setup_royalty_collection):
"""Test claiming rewards for group members."""
group_ip_id = setup_royalty_collection["group_ip_id"]
ip_ids = setup_royalty_collection["ip_ids"]
# Collect and distribute royalties to set up rewards for claiming
story_client.Group.collect_and_distribute_group_royalties(
group_ip_id=group_ip_id, currency_tokens=[MockERC20], member_ip_ids=ip_ids
)

# Test claiming rewards for specific members
response = story_client.Group.claim_rewards(
group_ip_id=group_ip_id,
currency_token=MockERC20,
member_ip_ids=ip_ids,
)
# Verify response structure
assert "tx_hash" in response
assert isinstance(response["tx_hash"], str)
assert len(response["tx_hash"]) > 0

# Verify claimed rewards details if any are present
if response["claimed_rewards"]:
for reward in response["claimed_rewards"]:
assert "amount" in reward
assert isinstance(reward["amount"], int)
assert "token" in reward
assert story_client.web3.is_address(reward["token"])
8 changes: 8 additions & 0 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,11 @@ def __exit__(self, exc_type, exc_val, exc_tb):
return MockContext()

return _mock


@pytest.fixture(scope="module")
def mock_web3_is_address(mock_web3):
def _mock(is_address: bool = True):
return patch.object(mock_web3, "is_address", return_value=is_address)

return _mock
Loading
Loading