diff --git a/src/story_protocol_python_sdk/__init__.py b/src/story_protocol_python_sdk/__init__.py index 8a94af2..062f1c3 100644 --- a/src/story_protocol_python_sdk/__init__.py +++ b/src/story_protocol_python_sdk/__init__.py @@ -16,6 +16,7 @@ from .types.resource.IPAsset import ( RegisterPILTermsAndAttachResponse, RegistrationResponse, + RegistrationWithRoyaltyVaultResponse, ) from .utils.constants import ( DEFAULT_FUNCTION_SELECTOR, @@ -29,6 +30,7 @@ from .utils.derivative_data import DerivativeDataInput from .utils.ip_metadata import IPMetadataInput from .utils.licensing_config_data import LicensingConfig +from .utils.royalty_shares import RoyaltyShareInput __all__ = [ "StoryClient", @@ -43,11 +45,13 @@ "DerivativeDataInput", "IPMetadataInput", "RegistrationResponse", + "RegistrationWithRoyaltyVaultResponse", "ClaimRewardsResponse", "ClaimReward", "CollectRoyaltiesResponse", "LicensingConfig", "RegisterPILTermsAndAttachResponse", + "RoyaltyShareInput", # Constants "ZERO_ADDRESS", "ZERO_HASH", diff --git a/src/story_protocol_python_sdk/resources/IPAsset.py b/src/story_protocol_python_sdk/resources/IPAsset.py index 394ccf8..487fed2 100644 --- a/src/story_protocol_python_sdk/resources/IPAsset.py +++ b/src/story_protocol_python_sdk/resources/IPAsset.py @@ -36,11 +36,18 @@ from story_protocol_python_sdk.abi.RegistrationWorkflows.RegistrationWorkflows_client import ( RegistrationWorkflowsClient, ) +from story_protocol_python_sdk.abi.RoyaltyModule.RoyaltyModule_client import ( + RoyaltyModuleClient, +) +from story_protocol_python_sdk.abi.RoyaltyTokenDistributionWorkflows.RoyaltyTokenDistributionWorkflows_client import ( + RoyaltyTokenDistributionWorkflowsClient, +) from story_protocol_python_sdk.abi.SPGNFTImpl.SPGNFTImpl_client import SPGNFTImplClient from story_protocol_python_sdk.types.common import AccessPermission from story_protocol_python_sdk.types.resource.IPAsset import ( RegisterPILTermsAndAttachResponse, RegistrationResponse, + RegistrationWithRoyaltyVaultResponse, ) from story_protocol_python_sdk.utils.constants import ( MAX_ROYALTY_TOKEN, @@ -54,6 +61,10 @@ from story_protocol_python_sdk.utils.function_signature import get_function_signature from story_protocol_python_sdk.utils.ip_metadata import IPMetadata, IPMetadataInput from story_protocol_python_sdk.utils.license_terms import LicenseTerms +from story_protocol_python_sdk.utils.royalty_shares import ( + RoyaltyShare, + RoyaltyShareInput, +) from story_protocol_python_sdk.utils.sign import Sign from story_protocol_python_sdk.utils.transaction_utils import build_and_send_transaction from story_protocol_python_sdk.utils.validation import ( @@ -89,6 +100,10 @@ def __init__(self, web3: Web3, account, chain_id: int): self.core_metadata_module_client = CoreMetadataModuleClient(web3) self.access_controller_client = AccessControllerClient(web3) self.pi_license_template_client = PILicenseTemplateClient(web3) + self.royalty_token_distribution_workflows_client = ( + RoyaltyTokenDistributionWorkflowsClient(web3) + ) + self.royalty_module_client = RoyaltyModuleClient(web3) self.license_terms_util = LicenseTerms(web3) self.sign_util = Sign(web3, self.chain_id, self.account) @@ -457,7 +472,7 @@ def mint_and_register_ip_asset_with_pil_terms( self.account, self.license_attachment_workflows_client.build_mintAndRegisterIpAndAttachPILTerms_transaction, spg_nft_contract, - recipient if recipient else self.account.address, + self._validate_recipient(recipient), metadata, license_terms, allow_duplicates, @@ -531,7 +546,7 @@ def mint_and_register_ip( self.account, self.registration_workflows_client.build_mintAndRegisterIp_transaction, spg_nft_contract, - recipient if recipient else self.account.address, + self._validate_recipient(recipient), metadata, allow_duplicates, tx_options=tx_options, @@ -817,11 +832,7 @@ def mint_and_register_ip_and_make_derivative( validate_address(spg_nft_contract), validated_deriv_data, IPMetadata.from_input(ip_metadata).get_validated_data(), - ( - validate_address(recipient) - if recipient is not None - else self.account.address - ), + self._validate_recipient(recipient), allow_duplicates, tx_options=tx_options, ) @@ -870,11 +881,7 @@ def mint_and_register_ip_and_make_derivative_with_license_tokens( ZERO_ADDRESS, max_rts, IPMetadata.from_input(ip_metadata).get_validated_data(), - ( - validate_address(recipient) - if recipient is not None - else self.account.address - ), + self._validate_recipient(recipient), allow_duplicates, tx_options=tx_options, ) @@ -988,6 +995,66 @@ def register_ip_and_make_derivative_with_license_tokens( f"Failed to register IP and make derivative with license tokens: {str(e)}" ) from e + def mint_and_register_ip_and_make_derivative_and_distribute_royalty_tokens( + self, + spg_nft_contract: Address, + deriv_data: DerivativeDataInput, + royalty_shares: list[RoyaltyShareInput], + ip_metadata: IPMetadataInput | None = None, + recipient: Address | None = None, + allow_duplicates: bool = True, + tx_options: dict | None = None, + ) -> RegistrationWithRoyaltyVaultResponse: + """ + Mint an NFT and register the IP, make a derivative, and distribute royalty tokens. + + :param spg_nft_contract Address: The address of the SPGNFT collection. + :param deriv_data `DerivativeDataInput`: The derivative data to be used for register derivative. + :param royalty_shares `list[RoyaltyShareInput]`: The royalty shares to be distributed. + :param ip_metadata `IPMetadataInput`: [Optional] The desired metadata for the newly minted NFT and newly registered IP. + :param recipient Address: [Optional] The address to receive the minted NFT. If not provided, the client's own wallet address will be used. + :param allow_duplicates bool: [Optional] Set to true to allow minting an NFT with a duplicate metadata hash. (default: True) + :param tx_options dict: [Optional] Transaction options. + :return `RegistrationWithRoyaltyVaultResponse`: Dictionary with the tx hash, IP ID and token ID, royalty vault. + """ + try: + validated_royalty_shares_obj = RoyaltyShare.get_royalty_shares( + royalty_shares + ) + validated_deriv_data = DerivativeData.from_input( + web3=self.web3, input_data=deriv_data + ).get_validated_data() + + response = build_and_send_transaction( + self.web3, + self.account, + self.royalty_token_distribution_workflows_client.build_mintAndRegisterIpAndMakeDerivativeAndDistributeRoyaltyTokens_transaction, + validate_address(spg_nft_contract), + self._validate_recipient(recipient), + IPMetadata.from_input(ip_metadata).get_validated_data(), + validated_deriv_data, + validated_royalty_shares_obj["royalty_shares"], + allow_duplicates, + tx_options=tx_options, + ) + + ip_registered = self._parse_tx_ip_registered_event(response["tx_receipt"]) + royalty_vault = self.get_royalty_vault_address_by_ip_id( + response["tx_receipt"], + ip_registered["ip_id"], + ) + + return RegistrationWithRoyaltyVaultResponse( + tx_hash=response["tx_hash"], + ip_id=ip_registered["ip_id"], + token_id=ip_registered["token_id"], + royalty_vault=royalty_vault, + ) + except Exception as e: + raise ValueError( + f"Failed to mint, register IP, make derivative and distribute royalty tokens: {str(e)}" + ) from e + def register_pil_terms_and_attach( self, ip_id: Address, @@ -1244,3 +1311,37 @@ def _parse_tx_license_terms_attached_event(self, tx_receipt: dict) -> list[int]: license_terms_ids.append(license_terms_id) return license_terms_ids + + def get_royalty_vault_address_by_ip_id( + self, tx_receipt: dict, ipId: Address + ) -> Address: + """ + Parse the IpRoyaltyVaultDeployed event from a transaction receipt and return the royalty vault address for a given IP ID. + + :param tx_receipt dict: The transaction receipt. + :param ipId Address: The IP ID. + :return Address: The royalty vault address. + """ + event_signature = self.web3.keccak( + text="IpRoyaltyVaultDeployed(address,address)" + ).hex() + for log in tx_receipt["logs"]: + if log["topics"][0].hex() == event_signature: + event_result = self.royalty_module_client.contract.events.IpRoyaltyVaultDeployed.process_log( + log + ) + if event_result["args"]["ipId"] == ipId: + return event_result["args"]["ipRoyaltyVault"] + + raise ValueError("RoyaltyVaultDeployed event not found in transaction receipt.") + + def _validate_recipient(self, recipient: Address | None) -> Address: + """ + Validate the recipient address. + + :param recipient Address: The recipient address to validate. + :return Address: The validated recipient address. + """ + if recipient is None: + return self.account.address + return validate_address(recipient) diff --git a/src/story_protocol_python_sdk/types/resource/IPAsset.py b/src/story_protocol_python_sdk/types/resource/IPAsset.py index b7424f9..5e0f18a 100644 --- a/src/story_protocol_python_sdk/types/resource/IPAsset.py +++ b/src/story_protocol_python_sdk/types/resource/IPAsset.py @@ -1,4 +1,4 @@ -from typing import Optional, TypedDict +from typing import TypedDict from ens.ens import Address, HexStr @@ -15,7 +15,20 @@ class RegistrationResponse(TypedDict): ip_id: Address tx_hash: HexStr - token_id: Optional[int] + token_id: int + + +class RegistrationWithRoyaltyVaultResponse(RegistrationResponse): + """ + Response structure for IP asset registration operations with royalty vault. + + Extends `RegistrationResponse` with royalty vault information. + + Attributes: + royalty_vault: The royalty vault address of the registered IP asset + """ + + royalty_vault: Address class RegisterPILTermsAndAttachResponse(TypedDict): diff --git a/src/story_protocol_python_sdk/utils/royalty_shares.py b/src/story_protocol_python_sdk/utils/royalty_shares.py new file mode 100644 index 0000000..b64de92 --- /dev/null +++ b/src/story_protocol_python_sdk/utils/royalty_shares.py @@ -0,0 +1,71 @@ +"""Module for handling royalty shares data structure and validation.""" + +from dataclasses import dataclass +from typing import List + +from ens.ens import Address + +from story_protocol_python_sdk.utils.validation import validate_address + + +@dataclass +class RoyaltyShareInput: + """Input data structure for a single royalty share. + + Attributes: + recipient: The address of the recipient. + percentage: The percentage of the total royalty share. Supports up to 6 decimal places precision. For example, a value of 10 represents 10% of max royalty shares, which is 10,000,000. + """ + + recipient: Address + percentage: float | int + + +@dataclass +class RoyaltyShare: + """Validated royalty share data.""" + + @classmethod + def get_royalty_shares(cls, royalty_shares: List[RoyaltyShareInput]): + """ + Validate and convert royalty shares. + + :param royalty_shares: List of `RoyaltyShareInput` + :return: Dictionary with validated royalty_shares and total_amount + """ + if len(royalty_shares) == 0: + raise ValueError("Royalty shares must be provided.") + + actual_total = 0 + sum_percentage = 0.0 + converted_shares: List[dict] = [] + + for share_dict in royalty_shares: + recipient = validate_address(share_dict.recipient) + percentage = share_dict.percentage + + if percentage < 0: + raise ValueError( + "The percentage of the royalty shares must be greater than or equal to 0." + ) + + if percentage > 100: + raise ValueError( + "The percentage of the royalty shares must be less than or equal to 100." + ) + + sum_percentage += percentage + if sum_percentage > 100: + raise ValueError("The sum of the royalty shares cannot exceeds 100.") + + value = int(percentage * 10**6) + actual_total += value + + converted_shares.append( + { + "recipient": recipient, + "percentage": value, + } + ) + + return {"royalty_shares": converted_shares, "total_amount": actual_total} diff --git a/tests/integration/test_integration_ip_asset.py b/tests/integration/test_integration_ip_asset.py index 251a34e..152c168 100644 --- a/tests/integration/test_integration_ip_asset.py +++ b/tests/integration/test_integration_ip_asset.py @@ -10,6 +10,7 @@ from story_protocol_python_sdk.utils.constants import ROYALTY_POLICY_LAP_ADDRESS from story_protocol_python_sdk.utils.derivative_data import DerivativeDataInput from story_protocol_python_sdk.utils.ip_metadata import IPMetadataInput +from story_protocol_python_sdk.utils.royalty_shares import RoyaltyShareInput from tests.integration.config.test_config import account_2 from tests.integration.config.utils import approve @@ -988,3 +989,54 @@ def test_successful_registration( assert response is not None assert isinstance(response["tx_hash"], str) assert len(response["license_terms_ids"]) == 2 + + +class TestMintAndRegisterIpAndMakeDerivativeAndDistributeRoyaltyTokens: + def test_mint_register_ip_make_derivative_distribute_royalty_tokens_default_value( + self, story_client: StoryClient, nft_collection, parent_ip_and_license_terms + ): + response = story_client.IPAsset.mint_and_register_ip_and_make_derivative_and_distribute_royalty_tokens( + spg_nft_contract=nft_collection, + deriv_data=DerivativeDataInput( + parent_ip_ids=[parent_ip_and_license_terms["parent_ip_id"]], + license_terms_ids=[parent_ip_and_license_terms["license_terms_id"]], + ), + royalty_shares=[ + RoyaltyShareInput(recipient=account.address, percentage=50.000032222), + RoyaltyShareInput(recipient=account_2.address, percentage=30.000032222), + ], + ) + assert isinstance(response["tx_hash"], str) + assert isinstance(response["ip_id"], str) + assert isinstance(response["token_id"], int) + assert isinstance(response["royalty_vault"], str) + + def test_mint_register_ip_make_derivative_distribute_royalty_tokens_with_custom_values( + self, story_client: StoryClient, nft_collection, parent_ip_and_license_terms + ): + response = story_client.IPAsset.mint_and_register_ip_and_make_derivative_and_distribute_royalty_tokens( + spg_nft_contract=nft_collection, + deriv_data=DerivativeDataInput( + parent_ip_ids=[parent_ip_and_license_terms["parent_ip_id"]], + license_terms_ids=[parent_ip_and_license_terms["license_terms_id"]], + max_minting_fee=10000, + max_rts=10, + max_revenue_share=100, + ), + royalty_shares=[ + RoyaltyShareInput(recipient=account.address, percentage=60), + RoyaltyShareInput(recipient=account_2.address, percentage=40), + ], + ip_metadata=IPMetadataInput( + ip_metadata_uri="https://example.com/ip-metadata", + ip_metadata_hash=web3.keccak(text="ip_metadata_hash"), + nft_metadata_uri="https://example.com/nft-metadata", + nft_metadata_hash=web3.keccak(text="nft_metadata_hash"), + ), + recipient=account_2.address, + allow_duplicates=False, + ) + assert isinstance(response["tx_hash"], str) + assert isinstance(response["ip_id"], str) + assert isinstance(response["token_id"], int) + assert isinstance(response["royalty_vault"], str) diff --git a/tests/unit/resources/test_ip_asset.py b/tests/unit/resources/test_ip_asset.py index a65bf1d..baa3986 100644 --- a/tests/unit/resources/test_ip_asset.py +++ b/tests/unit/resources/test_ip_asset.py @@ -10,6 +10,10 @@ from story_protocol_python_sdk.utils.constants import ZERO_HASH from story_protocol_python_sdk.utils.derivative_data import DerivativeDataInput from story_protocol_python_sdk.utils.ip_metadata import IPMetadata, IPMetadataInput +from story_protocol_python_sdk.utils.royalty_shares import ( + RoyaltyShare, + RoyaltyShareInput, +) from tests.integration.config.utils import ZERO_ADDRESS from tests.unit.fixtures.data import ( ACCOUNT_ADDRESS, @@ -96,6 +100,18 @@ def _mock(): return _mock +@pytest.fixture +def mock_get_royalty_vault_address_by_ip_id(ip_asset): + def _mock(): + return patch.object( + ip_asset, + "get_royalty_vault_address_by_ip_id", + return_value=ADDRESS, + ) + + return _mock + + class TestIPAssetRegister: def test_register_invalid_deadline_type( self, ip_asset, mock_get_ip_id, mock_is_registered @@ -988,3 +1004,189 @@ def test_registration_with_transaction_failed( }, ], ) + + +class TestMintAndRegisterIpAndMakeDerivativeAndDistributeRoyaltyTokens: + def test_throw_error_when_royalty_shares_empty(self, ip_asset: IPAsset): + with pytest.raises(ValueError, match="Royalty shares must be provided."): + ip_asset.mint_and_register_ip_and_make_derivative_and_distribute_royalty_tokens( + spg_nft_contract=ADDRESS, + deriv_data=DerivativeDataInput( + parent_ip_ids=[IP_ID], + license_terms_ids=[1], + ), + royalty_shares=[], + ) + + def test_throw_error_when_deriv_data_is_invalid(self, ip_asset: IPAsset): + with pytest.raises(ValueError, match="The parent IP IDs must be provided."): + ip_asset.mint_and_register_ip_and_make_derivative_and_distribute_royalty_tokens( + spg_nft_contract=ADDRESS, + deriv_data=DerivativeDataInput( + parent_ip_ids=[], + license_terms_ids=[1], + ), + royalty_shares=[ + RoyaltyShareInput(recipient=ACCOUNT_ADDRESS, percentage=50.0) + ], + ) + + def test_success_with_default_values( + self, + ip_asset: IPAsset, + mock_license_registry_client, + mock_parse_ip_registered_event, + mock_get_royalty_vault_address_by_ip_id, + ): + royalty_shares = [ + RoyaltyShareInput(recipient=ACCOUNT_ADDRESS, percentage=50.0), + RoyaltyShareInput(recipient=ADDRESS, percentage=30.0), + ] + + with mock_parse_ip_registered_event(), mock_license_registry_client(), mock_get_royalty_vault_address_by_ip_id(): + with patch.object( + ip_asset.royalty_token_distribution_workflows_client, + "build_mintAndRegisterIpAndMakeDerivativeAndDistributeRoyaltyTokens_transaction", + return_value={"tx_hash": TX_HASH.hex()}, + ) as mock_build_transaction: + result = ip_asset.mint_and_register_ip_and_make_derivative_and_distribute_royalty_tokens( + spg_nft_contract=ADDRESS, + deriv_data=DerivativeDataInput( + parent_ip_ids=[IP_ID, IP_ID], + license_terms_ids=[1, 2], + ), + royalty_shares=royalty_shares, + ) + called_args = mock_build_transaction.call_args[0] + assert called_args[2] == IPMetadata.from_input().get_validated_data() + assert ( + called_args[4] + == RoyaltyShare.get_royalty_shares(royalty_shares)["royalty_shares"] + ) + assert called_args[5] is True + + assert result["tx_hash"] == TX_HASH.hex() + assert result["ip_id"] == IP_ID + assert result["token_id"] == 3 + assert result["royalty_vault"] == ADDRESS + + def test_royalty_vault_address( + self, + ip_asset: IPAsset, + mock_license_registry_client, + mock_parse_ip_registered_event, + mock_get_royalty_vault_address_by_ip_id, + ): + royalty_shares = [ + RoyaltyShareInput(recipient=ACCOUNT_ADDRESS, percentage=50.0), + RoyaltyShareInput(recipient=ADDRESS, percentage=30.0), + ] + + with mock_parse_ip_registered_event(), mock_license_registry_client(), mock_get_royalty_vault_address_by_ip_id(): + with patch( + "story_protocol_python_sdk.resources.IPAsset.build_and_send_transaction", + return_value={ + "tx_hash": TX_HASH, + "tx_receipt": { + "logs": [ + { + "topics": [ + ip_asset.web3.keccak( + text="IpRoyaltyVaultDeployed(address,address)" + ) + ], + "data": IP_ID + ADDRESS, + } + ] + }, + }, + ): + result = ip_asset.mint_and_register_ip_and_make_derivative_and_distribute_royalty_tokens( + spg_nft_contract=ADDRESS, + deriv_data=DerivativeDataInput( + parent_ip_ids=[IP_ID, IP_ID], + license_terms_ids=[1, 2], + ), + royalty_shares=royalty_shares, + ) + assert result["royalty_vault"] == ADDRESS + + def test_success_with_custom_values( + self, + ip_asset: IPAsset, + mock_license_registry_client, + mock_parse_ip_registered_event, + mock_get_royalty_vault_address_by_ip_id, + ): + royalty_shares = [ + RoyaltyShareInput(recipient=ACCOUNT_ADDRESS, percentage=60.0), + ] + ip_metadata = IPMetadataInput( + ip_metadata_uri="https://example.com/ip-metadata", + ip_metadata_hash="0x1234567890abcdef", + nft_metadata_uri="https://example.com/nft-metadata", + nft_metadata_hash="0xabcdef1234567890", + ) + with mock_parse_ip_registered_event(), mock_license_registry_client(), mock_get_royalty_vault_address_by_ip_id(): + with patch.object( + ip_asset.royalty_token_distribution_workflows_client, + "build_mintAndRegisterIpAndMakeDerivativeAndDistributeRoyaltyTokens_transaction", + return_value={"tx_hash": TX_HASH.hex()}, + ) as mock_build_transaction: + + result = ip_asset.mint_and_register_ip_and_make_derivative_and_distribute_royalty_tokens( + spg_nft_contract=ADDRESS, + deriv_data=DerivativeDataInput( + parent_ip_ids=[IP_ID], + license_terms_ids=[1], + max_minting_fee=10000, + max_rts=10, + max_revenue_share=100, + ), + royalty_shares=royalty_shares, + ip_metadata=ip_metadata, + recipient=ACCOUNT_ADDRESS, + allow_duplicates=False, + ) + + called_args = mock_build_transaction.call_args[0] + assert ( + called_args[2] + == IPMetadata.from_input(ip_metadata).get_validated_data() + ) + assert ( + called_args[4] + == RoyaltyShare.get_royalty_shares(royalty_shares)["royalty_shares"] + ) + assert called_args[5] is False + + assert result["tx_hash"] == TX_HASH.hex() + assert result["ip_id"] == IP_ID + assert result["token_id"] == 3 + assert result["royalty_vault"] == ADDRESS + + def test_throw_error_when_transaction_failed( + self, + ip_asset: IPAsset, + mock_license_registry_client, + mock_parse_ip_registered_event, + ): + with mock_parse_ip_registered_event(), mock_license_registry_client(): + with patch.object( + ip_asset.royalty_token_distribution_workflows_client, + "build_mintAndRegisterIpAndMakeDerivativeAndDistributeRoyaltyTokens_transaction", + side_effect=Exception("Transaction failed."), + ): + with pytest.raises(Exception, match="Transaction failed."): + ip_asset.mint_and_register_ip_and_make_derivative_and_distribute_royalty_tokens( + spg_nft_contract=ADDRESS, + deriv_data=DerivativeDataInput( + parent_ip_ids=[IP_ID], + license_terms_ids=[1], + ), + royalty_shares=[ + RoyaltyShareInput( + recipient=ACCOUNT_ADDRESS, percentage=50.0 + ) + ], + ) diff --git a/tests/unit/utils/test_royalty_shares.py b/tests/unit/utils/test_royalty_shares.py new file mode 100644 index 0000000..332c94a --- /dev/null +++ b/tests/unit/utils/test_royalty_shares.py @@ -0,0 +1,317 @@ +"""Tests for royalty_shares module.""" + +import pytest + +from story_protocol_python_sdk.utils.royalty_shares import ( + RoyaltyShare, + RoyaltyShareInput, +) + + +class TestRoyaltyShareGetRoyaltyShares: + """Test RoyaltyShare.get_royalty_shares method.""" + + def test_get_royalty_shares_success(self): + """Test successful processing of valid royalty shares.""" + shares = [ + RoyaltyShareInput( + recipient="0x1234567890123456789012345678901234567890", percentage=50.0 + ), + RoyaltyShareInput( + recipient="0xabcdefabcdefabcdefabcdefabcdefabcdefabcd", percentage=30.0 + ), + ] + + result = RoyaltyShare.get_royalty_shares(shares) + + expected_shares = [ + { + "recipient": "0x1234567890123456789012345678901234567890", + "percentage": 50_000_000, # 50.0 * 10^6 + }, + { + "recipient": "0xabcdefabcdefabcdefabcdefabcdefabcdefabcd", + "percentage": 30_000_000, # 30.0 * 10^6 + }, + ] + + assert result["royalty_shares"] == expected_shares + assert result["total_amount"] == 80_000_000 + + def test_get_royalty_shares_with_integer_percentages(self): + """Test processing royalty shares with integer percentages.""" + shares = [ + RoyaltyShareInput( + recipient="0x1234567890123456789012345678901234567890", percentage=25 + ), + RoyaltyShareInput( + recipient="0xabcdefabcdefabcdefabcdefabcdefabcdefabcd", percentage=75 + ), + ] + + result = RoyaltyShare.get_royalty_shares(shares) + expected_shares = [ + { + "recipient": "0x1234567890123456789012345678901234567890", + "percentage": 25_000_000, # 25 * 10^6 + }, + { + "recipient": "0xabcdefabcdefabcdefabcdefabcdefabcdefabcd", + "percentage": 75_000_000, # 75 * 10^6 + }, + ] + assert result["total_amount"] == 100_000_000 + assert result["royalty_shares"] == expected_shares + + def test_get_royalty_shares_precision_handling_6_decimals(self): + """Test precision handling with exactly 6 decimal places.""" + shares = [ + RoyaltyShareInput( + recipient="0x1234567890123456789012345678901234567890", + percentage=33.333333, # Exactly 6 decimal places + ), + RoyaltyShareInput( + recipient="0xabcdefabcdefabcdefabcdefabcdefabcdefabcd", + percentage=66.666667, # Exactly 6 decimal places + ), + ] + + result = RoyaltyShare.get_royalty_shares(shares) + + # 33.333333 * 10^6 = 33333333 + # 66.666667 * 10^6 = 66666667 + expected_shares = [ + { + "recipient": "0x1234567890123456789012345678901234567890", + "percentage": 33_333_333, + }, + { + "recipient": "0xabcdefabcdefabcdefabcdefabcdefabcdefabcd", + "percentage": 66_666_667, + }, + ] + + assert result["royalty_shares"] == expected_shares + assert result["total_amount"] == 100_000_000 + + def test_get_royalty_shares_precision_loss_more_than_6_decimals(self): + """Test precision loss with more than 6 decimal places.""" + shares = [ + RoyaltyShareInput( + recipient="0x1234567890123456789012345678901234567890", + percentage=33.3333333333, # More than 6 decimal places + ), + RoyaltyShareInput( + recipient="0xabcdefabcdefabcdefabcdefabcdefabcdefabcd", + percentage=66.6666666667, # More than 6 decimal places + ), + ] + + result = RoyaltyShare.get_royalty_shares(shares) + + # Due to floating point precision and int() truncation: + # 33.3333333333 * 10^6 = 33333333.3333, int() = 33333333 + # 66.6666666667 * 10^6 = 66666666.6667, int() = 66666666 + # Total would be 99999999, not 100000000 (precision loss) + + assert result["royalty_shares"][0]["percentage"] == 33_333_333 + assert result["royalty_shares"][1]["percentage"] == 66_666_666 + assert result["total_amount"] == 99_999_999 # Precision loss evident + + def test_get_royalty_shares_very_small_percentages(self): + """Test handling of very small percentages that might lose precision.""" + shares = [ + RoyaltyShareInput( + recipient="0x1234567890123456789012345678901234567890", + percentage=0.000001, # 1 part per million + ), + RoyaltyShareInput( + recipient="0xabcdefabcdefabcdefabcdefabcdefabcdefabcd", + percentage=99.999999, # Rest + ), + ] + + result = RoyaltyShare.get_royalty_shares(shares) + + # 0.000001 * 10^6 = 1 + # 99.999999 * 10^6 = 99999999 + assert result["royalty_shares"][0]["percentage"] == 1 + assert result["royalty_shares"][1]["percentage"] == 99_999_999 + assert result["total_amount"] == 100_000_000 + + def test_get_royalty_shares_boundary_case_exactly_100_percent(self): + """Test boundary case with exactly 100% total.""" + shares = [ + RoyaltyShareInput( + recipient="0x1234567890123456789012345678901234567890", percentage=100.0 + ) + ] + + result = RoyaltyShare.get_royalty_shares(shares) + + assert result["royalty_shares"][0]["percentage"] == 100_000_000 + assert result["total_amount"] == 100_000_000 + + def test_get_royalty_shares_boundary_case_minimum_percentage(self): + """Test boundary case with minimum valid percentage.""" + shares = [ + RoyaltyShareInput( + recipient="0x1234567890123456789012345678901234567890", + percentage=0.000001, # Minimum that results in 1 after conversion + ) + ] + + result = RoyaltyShare.get_royalty_shares(shares) + + assert result["royalty_shares"][0]["percentage"] == 1 + assert result["total_amount"] == 1 + + def test_get_royalty_shares_empty_list_error(self): + """Test error when providing empty royalty shares list.""" + with pytest.raises(ValueError, match="Royalty shares must be provided."): + RoyaltyShare.get_royalty_shares([]) + + def test_get_royalty_shares_zero_percentage(self): + """Test error when percentage is zero.""" + shares = [ + RoyaltyShareInput( + recipient="0x1234567890123456789012345678901234567890", percentage=0 + ) + ] + + result = RoyaltyShare.get_royalty_shares(shares) + + assert result["royalty_shares"][0]["percentage"] == 0 + assert result["total_amount"] == 0 + + def test_get_royalty_shares_negative_percentage_error(self): + """Test error when percentage is negative.""" + shares = [ + RoyaltyShareInput( + recipient="0x1234567890123456789012345678901234567890", percentage=-10 + ) + ] + + with pytest.raises( + ValueError, + match="he percentage of the royalty shares must be greater than or equal to 0.", + ): + RoyaltyShare.get_royalty_shares(shares) + + def test_get_royalty_shares_percentage_100(self): + """Test when percentage is 100.""" + shares = [ + RoyaltyShareInput( + recipient="0x1234567890123456789012345678901234567890", percentage=100 + ) + ] + + result = RoyaltyShare.get_royalty_shares(shares) + + assert result["royalty_shares"][0]["percentage"] == 100_000_000 + assert result["total_amount"] == 100_000_000 + + def test_get_royalty_shares_percentage_over_100(self): + """Test error when single percentage exceeds 100.""" + shares = [ + RoyaltyShareInput( + recipient="0x1234567890123456789012345678901234567890", percentage=101 + ) + ] + + with pytest.raises( + ValueError, + match="The percentage of the royalty shares must be less than or equal to 100.", + ): + RoyaltyShare.get_royalty_shares(shares) + + def test_get_royalty_shares_invalid_address_error(self): + """Test error when address is invalid.""" + shares = [RoyaltyShareInput(recipient="invalid_address", percentage=50)] + + with pytest.raises(ValueError, match="Invalid address"): + RoyaltyShare.get_royalty_shares(shares) + + def test_get_royalty_shares_cumulative_precision_boundary(self): + """Test cumulative precision at the boundary of 100%.""" + # This tests a scenario where individual percentages are valid + # but cumulative floating point errors might cause issues + shares = [ + RoyaltyShareInput( + recipient="0x1234567890123456789012345678901234567890", + percentage=33.333333, + ), + RoyaltyShareInput( + recipient="0xabcdefabcdefabcdefabcdefabcdefabcdefabcd", + percentage=33.333333, + ), + RoyaltyShareInput( + recipient="0x9876543210987654321098765432109876543210", + percentage=33.333334, + ), + ] + + # This should work because 33.333333 + 33.333333 + 33.333334 = 100.0 + result = RoyaltyShare.get_royalty_shares(shares) + + assert len(result["royalty_shares"]) == 3 + assert result["total_amount"] == 100_000_000 + + def test_get_royalty_shares_precision_edge_case_just_over_100(self): + """Test precision edge case where floating point arithmetic results in just over 100%.""" + shares = [ + RoyaltyShareInput( + recipient="0x1234567890123456789012345678901234567890", + percentage=50.0000001, + ), + RoyaltyShareInput( + recipient="0xabcdefabcdefabcdefabcdefabcdefabcdefabcd", + percentage=50.0000001, + ), + ] + + # 50.000001 + 50.000001 = 100.000002, which is > 100 + with pytest.raises( + ValueError, match="The sum of the royalty shares cannot exceeds 100." + ): + RoyaltyShare.get_royalty_shares(shares) + + def test_get_royalty_shares_single_recipient_multiple_entries(self): + """Test multiple entries for the same recipient.""" + shares = [ + RoyaltyShareInput( + recipient="0x1234567890123456789012345678901234567890", percentage=25.5 + ), + RoyaltyShareInput( + recipient="0x1234567890123456789012345678901234567890", percentage=24.5 + ), + RoyaltyShareInput( + recipient="0xabcdefabcdefabcdefabcdefabcdefabcdefabcd", percentage=50.0 + ), + ] + + result = RoyaltyShare.get_royalty_shares(shares) + + # Should treat each entry separately, not merge them + assert len(result["royalty_shares"]) == 3 + assert result["royalty_shares"][0]["percentage"] == 25_500_000 + assert result["royalty_shares"][1]["percentage"] == 24_500_000 + assert result["royalty_shares"][2]["percentage"] == 50_000_000 + assert result["total_amount"] == 100_000_000 + + def test_get_royalty_shares_mixed_data_types(self): + """Test mixing int and float percentages.""" + shares = [ + RoyaltyShareInput( + recipient="0x1234567890123456789012345678901234567890", percentage=25 + ), # int + RoyaltyShareInput( + recipient="0xabcdefabcdefabcdefabcdefabcdefabcdefabcd", percentage=75.0 + ), # float + ] + + result = RoyaltyShare.get_royalty_shares(shares) + + assert result["royalty_shares"][0]["percentage"] == 25_000_000 + assert result["royalty_shares"][1]["percentage"] == 75_000_000 + assert result["total_amount"] == 100_000_000