diff --git a/src/story_protocol_python_sdk/resources/Group.py b/src/story_protocol_python_sdk/resources/Group.py index 31be4a4..2e86bb6 100644 --- a/src/story_protocol_python_sdk/resources/Group.py +++ b/src/story_protocol_python_sdk/resources/Group.py @@ -1,5 +1,6 @@ # src/story_protocol_python_sdk/resources/Group.py +from ens.ens import HexStr from web3 import Web3 from story_protocol_python_sdk.abi.CoreMetadataModule.CoreMetadataModule_client import ( @@ -26,10 +27,12 @@ from story_protocol_python_sdk.abi.PILicenseTemplate.PILicenseTemplate_client import ( PILicenseTemplateClient, ) +from story_protocol_python_sdk.types.common import RevShareType from story_protocol_python_sdk.utils.constants import ZERO_ADDRESS, ZERO_HASH from story_protocol_python_sdk.utils.license_terms import LicenseTerms from story_protocol_python_sdk.utils.sign import Sign from story_protocol_python_sdk.utils.transaction_utils import build_and_send_transaction +from story_protocol_python_sdk.utils.validation import get_revenue_share class Group: @@ -137,7 +140,7 @@ def mint_and_register_ip_and_attach_license_and_add_to_group( :param group_id str: The ID of the group to add the IP to. :param spg_nft_contract str: The address of the SPG NFT contract. :param license_data list: List of license data objects with terms and config. - :param max_allowed_reward_share int: Maximum allowed reward share percentage. + :param max_allowed_reward_share int: Maximum allowed reward share percentage. Must be between 0 and 100 (where 100% represents 100,000,000). :param ip_metadata dict: [Optional] The metadata for the IP. :param recipient str: [Optional] The recipient of the NFT (defaults to caller). :param allow_duplicates bool: [Optional] Whether to allow duplicate IPs. @@ -185,8 +188,8 @@ def mint_and_register_ip_and_attach_license_and_add_to_group( metadata = self._get_ip_metadata(ip_metadata) - max_allowed_reward_share = self.license_terms_util.get_revenue_share( - max_allowed_reward_share + max_allowed_reward_share = get_revenue_share( + max_allowed_reward_share, type=RevShareType.MAX_ALLOWED_REWARD_SHARE ) # Set recipient to caller if not provided @@ -249,7 +252,7 @@ def register_ip_and_attach_license_and_add_to_group( :param nft_contract str: The address of the NFT contract. :param token_id int: The token ID of the NFT. :param license_data list: List of license data objects with terms and config. - :param max_allowed_reward_share int: Maximum allowed reward share percentage. + :param max_allowed_reward_share int: Maximum allowed reward share percentage. Must be between 0 and 100 (where 100% represents 100,000,000). :param ip_metadata dict: [Optional] The metadata for the IP. :param deadline int: [Optional] The deadline for the signature in milliseconds. :param tx_options dict: [Optional] The transaction options. @@ -299,7 +302,7 @@ def register_ip_and_attach_license_and_add_to_group( sig_metadata_and_attach = self.sign_util.get_permission_signature( ip_id=ip_id, deadline=calculated_deadline, - state=self.web3.to_bytes(hexstr=ZERO_HASH), + state=self.web3.to_bytes(hexstr=HexStr(ZERO_HASH)), permissions=[ { "ipId": ip_id, @@ -338,7 +341,9 @@ def register_ip_and_attach_license_and_add_to_group( nft_contract, token_id, group_id, - self.license_terms_util.get_revenue_share(max_allowed_reward_share), + get_revenue_share( + max_allowed_reward_share, type=RevShareType.MAX_ALLOWED_REWARD_SHARE + ), licenses_data, metadata, { @@ -389,7 +394,7 @@ def register_group_and_attach_license_and_add_ips( :param group_pool str: The address of the group pool. :param ip_ids list: List of IP IDs to add to the group. :param license_data dict: License data object with terms and config. - :param max_allowed_reward_share int: Maximum allowed reward share percentage. + :param max_allowed_reward_share int: Maximum allowed reward share percentage. Must be between 0 and 100 (where 100% represents 100,000,000). :param tx_options dict: [Optional] The transaction options. :return dict: A dictionary with the transaction hash and group ID. """ @@ -427,7 +432,9 @@ def register_group_and_attach_license_and_add_ips( self.grouping_workflows_client.build_registerGroupAndAttachLicenseAndAddIps_transaction, group_pool, ip_ids, - self.license_terms_util.get_revenue_share(max_allowed_reward_share), + get_revenue_share( + max_allowed_reward_share, type=RevShareType.MAX_ALLOWED_REWARD_SHARE + ), license_data_processed, tx_options=tx_options, ) diff --git a/src/story_protocol_python_sdk/resources/IPAsset.py b/src/story_protocol_python_sdk/resources/IPAsset.py index 7940fca..a9cb446 100644 --- a/src/story_protocol_python_sdk/resources/IPAsset.py +++ b/src/story_protocol_python_sdk/resources/IPAsset.py @@ -35,7 +35,11 @@ ) from story_protocol_python_sdk.abi.SPGNFTImpl.SPGNFTImpl_client import SPGNFTImplClient from story_protocol_python_sdk.types.common import AccessPermission -from story_protocol_python_sdk.utils.constants import ZERO_ADDRESS, ZERO_HASH +from story_protocol_python_sdk.utils.constants import ( + MAX_ROYALTY_TOKEN, + ZERO_ADDRESS, + ZERO_HASH, +) from story_protocol_python_sdk.utils.derivative_data import ( DerivativeData, DerivativeDataInput, @@ -228,8 +232,8 @@ def register_derivative( parent_ip_ids: list, license_terms_ids: list, max_minting_fee: int = 0, - max_rts: int = 0, - max_revenue_share: int = 0, + max_rts: int = MAX_ROYALTY_TOKEN, + max_revenue_share: int = 100, license_template: str | None = None, tx_options: dict | None = None, ) -> dict: @@ -244,11 +248,11 @@ def register_derivative( :param parent_ip_ids list: The parent IP IDs :param license_terms_ids list: The IDs of the license terms that the parent IP supports :param max_minting_fee int: The maximum minting fee that the caller is willing to pay. - if set to 0 then no limit + if set to 0 then no limit. (default: 0) :param max_rts int: The maximum number of royalty tokens that can be distributed - (max: 100,000,000) - :param max_revenue_share int: The maximum revenue share percentage allowed (0-100,000,000) - :param license_template str: [Optional] The license template address + (max: 100,000,000) (default: 100,000,000) + :param max_revenue_share int: The maximum revenue share percentage allowed. Must be between 0 and 100 (where 100% represents 100,000,000). (default: 100) + :param license_template str: [Optional] The license template address. Defaults to [License Template](https://docs.story.foundation/docs/programmable-ip-license) address if not provided. :param tx_options dict: [Optional] Transaction options :return dict: A dictionary with the transaction hash """ @@ -257,24 +261,23 @@ def register_derivative( raise ValueError( f"The child IP with id {child_ip_id} is not registered." ) - - derivative_data = self._validate_derivative_data( - { - "childIpId": child_ip_id, - "parentIpIds": parent_ip_ids, - "licenseTermsIds": license_terms_ids, - "maxMintingFee": max_minting_fee, - "maxRts": max_rts, - "maxRevenueShare": max_revenue_share, - "licenseTemplate": license_template, - } - ) + derivative_data = DerivativeData.from_input( + web3=self.web3, + input_data=DerivativeDataInput( + parent_ip_ids=parent_ip_ids, + license_terms_ids=license_terms_ids, + max_minting_fee=max_minting_fee, + max_rts=max_rts, + max_revenue_share=max_revenue_share, + license_template=license_template, + ), + ).get_validated_data() response = build_and_send_transaction( self.web3, self.account, self.licensing_module_client.build_registerDerivative_transaction, - derivative_data["childIpId"], + child_ip_id, derivative_data["parentIpIds"], derivative_data["licenseTermsIds"], derivative_data["licenseTemplate"], @@ -367,7 +370,7 @@ def mint_and_register_ip_asset_with_pil_terms( :param commercializer_checker str: Allowed commercializers or zero address for none. :param commercializer_checker_data str: Data for checker contract. - :param commercial_rev_share int: The commercial revenue share percentage (from 0 to 100%, represented as 100_000_000). + :param commercial_rev_share int: Percentage of revenue that must be shared with the licensor. Must be between 0 and 100 (where 100% represents 100,000,000). :param commercial_rev_ceiling int: Maximum commercial revenue. :param derivatives_allowed bool: Whether derivatives are allowed. :param derivatives_attribution bool: Whether attribution is needed @@ -385,12 +388,10 @@ def mint_and_register_ip_asset_with_pil_terms( :param hook_data str: The data used by the licensing hook. :param licensing_hook str: The licensing hook contract address or address(0) if none. - :param commercial_rev_share int: The commercial revenue share percentage (from 0 to 100%, represented as 100_000_000). + :param commercial_rev_share int: Percentage of revenue that must be shared with the licensor. Must be between 0 and 100 (where 100% represents 100,000,000). :param disabled bool: Whether the license is disabled. - :param expect_minimum_group_reward_share int: Minimum group reward - share percentage (from 0 to 100%, represented as 100_000_000). - :param expect_group_reward_pool str: Address of the expected group - reward pool. + :param expect_minimum_group_reward_share int: Minimum group reward share percentage. Must be between 0 and 100 (where 100% represents 100,000,000). + :param expect_group_reward_pool str: Address of the expected group reward pool. :param ip_metadata dict: [Optional] NFT and IP metadata. :param ip_metadata_uri str: [Optional] IP metadata URI. :param ip_metadata_hash str: [Optional] IP metadata hash. @@ -560,7 +561,7 @@ def register_ip_and_attach_pil_terms( :param commercial_attribution bool: Whether attribution is required when reproducing the work commercially or not. :param commercializer_checker str: Commercializers that are allowed to commercially exploit the work. :param commercializer_checker_data str: The data to be passed to the commercializer checker contract. - :param commercial_rev_share int: Percentage of revenue that must be shared with the licensor. + :param commercial_rev_share int: Percentage of revenue that must be shared with the licensor. Must be between 0 and 100 (where 100% represents 100,000,000). :param commercial_rev_ceiling int: The maximum revenue that can be generated from the commercial use of the work. :param derivatives_allowed bool: Indicates whether the licensee can create derivatives of his work or not. :param derivatives_attribution bool: Indicates whether attribution is required for derivatives of the work or not. @@ -574,9 +575,9 @@ def register_ip_and_attach_pil_terms( :param minting_fee int: The minting fee to be paid when minting license tokens. :param licensing_hook str: The hook contract address for the licensing module. :param hook_data str: The data to be used by the licensing hook. - :param commercial_rev_share int: The commercial revenue share percentage. + :param commercial_rev_share int: Percentage of revenue that must be shared with the licensor. Must be between 0 and 100 (where 100% represents 100,000,000). :param disabled bool: Whether the licensing is disabled or not. - :param expect_minimum_group_reward_share int: The minimum percentage of the group's reward share. + :param expect_minimum_group_reward_share int: The minimum percentage of the group's reward share. Must be between 0 and 100 (where 100% represents 100,000,000). :param expect_group_reward_pool str: The address of the expected group reward pool. :param ip_metadata dict: [Optional] The metadata for the newly registered IP. :param ip_metadata_uri str: [Optional] The URI of the metadata for the IP. diff --git a/src/story_protocol_python_sdk/resources/License.py b/src/story_protocol_python_sdk/resources/License.py index 2859654..0f553ac 100644 --- a/src/story_protocol_python_sdk/resources/License.py +++ b/src/story_protocol_python_sdk/resources/License.py @@ -18,9 +18,11 @@ from story_protocol_python_sdk.abi.PILicenseTemplate.PILicenseTemplate_client import ( PILicenseTemplateClient, ) +from story_protocol_python_sdk.types.common import RevShareType from story_protocol_python_sdk.utils.constants import ZERO_ADDRESS from story_protocol_python_sdk.utils.license_terms import LicenseTerms from story_protocol_python_sdk.utils.transaction_utils import build_and_send_transaction +from story_protocol_python_sdk.utils.validation import get_revenue_share class License: @@ -86,7 +88,7 @@ def register_pil_terms( :param commercial_attribution bool: Whether attribution is required when reproducing the work commercially or not. :param commercializer_checker str: Commercializers that are allowed to commercially exploit the work. If zero address, then no restrictions is enforced. :param commercializer_checker_data str: The data to be passed to the commercializer checker contract. - :param commercial_rev_share int: Percentage of revenue that must be shared with the licensor. + :param commercial_rev_share int: Percentage of revenue that must be shared with the licensor. Must be between 0 and 100 (where 100% represents 100,000,000). :param commercial_rev_ceiling int: The maximum revenue that can be generated from the commercial use of the work. :param derivatives_allowed bool: Indicates whether the licensee can create derivatives of his work or not. :param derivatives_attribution bool: Indicates whether attribution is required for derivatives of the work or not. @@ -237,7 +239,7 @@ def register_commercial_remix_pil( :param default_minting_fee int: The fee to be paid when minting a license. :param currency str: The ERC20 token to be used to pay the minting fee. - :param commercial_rev_share int: Percentage of revenue that must be shared with the licensor. + :param commercial_rev_share int: Percentage of revenue that must be shared with the licensor. Must be between 0 and 100 (where 100% represents 100,000,000). :param royalty_policy str: The address of the royalty policy contract. :param tx_options dict: [Optional] The transaction options. :return dict: A dictionary with the transaction hash and the license terms ID. @@ -356,7 +358,7 @@ def mint_license_tokens( amount: int, receiver: str, max_minting_fee: int = 0, - max_revenue_share: int = 0, + max_revenue_share: int = 100, tx_options: dict | None = None, ) -> dict: """ @@ -367,8 +369,8 @@ def mint_license_tokens( :param license_terms_id int: The ID of the license terms within the license template. :param amount int: The amount of license tokens to mint. :param receiver str: The address of the receiver. - :param max_minting_fee int: [Optional] The maximum minting fee that the caller is willing to pay. If set to 0 then no limit. Defaults to 0. - :param max_revenue_share int: [Optional] The maximum revenue share percentage allowed for minting the License Tokens. Must be between 0 and 100,000,000 (where 100,000,000 represents 100%). Defaults to 0. + :param max_minting_fee int: [Optional] The maximum minting fee that the caller is willing to pay. If set to 0 then no limit. (default: 0) + :param max_revenue_share int: [Optional] The maximum revenue share percentage allowed for minting the License Tokens. Must be between 0 and 100,000,000 (where 100,000,000 represents 100%). (default: 100) :param tx_options dict: [Optional] The transaction options. :return dict: A dictionary with the transaction hash and the license token IDs. """ @@ -410,7 +412,10 @@ def mint_license_tokens( receiver, ZERO_ADDRESS, # Zero address for royalty context max_minting_fee, - self.license_terms_util.get_revenue_share(max_revenue_share), + get_revenue_share( + max_revenue_share, + RevShareType.MAX_REVENUE_SHARE, + ), tx_options=tx_options, ) diff --git a/src/story_protocol_python_sdk/utils/derivative_data.py b/src/story_protocol_python_sdk/utils/derivative_data.py index 47aeacc..7927e78 100644 --- a/src/story_protocol_python_sdk/utils/derivative_data.py +++ b/src/story_protocol_python_sdk/utils/derivative_data.py @@ -30,7 +30,7 @@ class DerivativeDataInput: license_terms_ids: List of license terms IDs corresponding to each parent IP. max_minting_fee: [Optional] The maximum minting fee that the caller is willing to pay. if set to 0 then no limit. (default: 0). max_rts: [Optional] The maximum number of royalty tokens that can be distributed to the external royalty policies. (max: 100,000,000) (default: 100,000,000). - max_revenue_share: [Optional] The maximum revenue share percentage allowed for minting the License Tokens. Must be between 0 and 100 (where 100% represents 100_000_000) (default: 100). + max_revenue_share: [Optional] The maximum revenue share percentage allowed for minting the License Tokens. Must be between 0 and 100 (where 100% represents 100,000,000) (default: 100). license_template: [Optional] The address of the license template. Defaults to [License Template](https://docs.story.foundation/docs/programmable-ip-license) address if not provided """ diff --git a/src/story_protocol_python_sdk/utils/license_terms.py b/src/story_protocol_python_sdk/utils/license_terms.py index 584efdd..7b190b3 100644 --- a/src/story_protocol_python_sdk/utils/license_terms.py +++ b/src/story_protocol_python_sdk/utils/license_terms.py @@ -92,9 +92,7 @@ def get_license_term_by_type(self, type, term=None): "currency": term["currency"], "commercialUse": True, "commercialAttribution": True, - "commercialRevShare": int( - (term["commercialRevShare"] / 100) * 100000000 - ), + "commercialRevShare": get_revenue_share(term["commercialRevShare"]), "derivativesReciprocal": True, "royaltyPolicy": term["royaltyPolicyAddress"], } @@ -125,16 +123,6 @@ def validate_license_terms(self, params): if commercial_rev_share < 0 or commercial_rev_share > 100: raise ValueError("commercial_rev_share should be between 0 and 100.") - expect_minimum_group_reward_share = params.get( - "expect_minimum_group_reward_share", 0 - ) - if ( - expect_minimum_group_reward_share < 0 - or expect_minimum_group_reward_share > 100 - ): - raise ValueError( - "Expect minimum group reward share must be between 0 and 100" - ) validated_params = { "transferable": params.get("transferable"), "royaltyPolicy": params.get("royalty_policy"), @@ -147,8 +135,7 @@ def validate_license_terms(self, params): hexstr=HexStr(params.get("commercializer_checker_data", ZERO_ADDRESS)) ), "commercialRevShare": get_revenue_share( - params.get("commercial_rev_share", 0), - RevShareType.COMMERCIAL_REVENUE_SHARE, + params.get("commercial_rev_share", 0) ), "commercialRevCeiling": int(params.get("commercial_rev_ceiling", 0)), "derivativesAllowed": params.get("derivatives_allowed"), diff --git a/src/story_protocol_python_sdk/utils/validation.py b/src/story_protocol_python_sdk/utils/validation.py index 5d0155b..85a603d 100644 --- a/src/story_protocol_python_sdk/utils/validation.py +++ b/src/story_protocol_python_sdk/utils/validation.py @@ -1,7 +1,6 @@ from web3 import Web3 from story_protocol_python_sdk.types.common import RevShareType -from story_protocol_python_sdk.utils.constants import MAX_ROYALTY_TOKEN def validate_address(address: str) -> str: @@ -25,10 +24,10 @@ def get_revenue_share( Convert revenue share percentage to token amount. :param revShare int: Revenue share percentage between 0-100 - :param type RevShareType: Type of revenue share + :param type RevShareType: Type of revenue share, default is commercial revenue share :return int: Revenue share token amount """ if revShare < 0 or revShare > 100: raise ValueError(f"The {type.value} must be between 0 and 100.") - return (revShare * MAX_ROYALTY_TOKEN) // 100 + return revShare * 10**6 diff --git a/tests/unit/resources/test_ip_asset.py b/tests/unit/resources/test_ip_asset.py index 4f9d974..fcbc0fd 100644 --- a/tests/unit/resources/test_ip_asset.py +++ b/tests/unit/resources/test_ip_asset.py @@ -342,3 +342,68 @@ def test_success( "license_terms_ids": [1, 2], "token_id": 1, } + + +class TestRegisterDerivative: + def test_default_value_when_not_provided( + self, + ip_asset: IPAsset, + mock_get_ip_id, + mock_is_registered, + mock_parse_ip_registered_event, + mock_license_registry_client, + ): + with mock_get_ip_id(), mock_is_registered( + True + ), mock_parse_ip_registered_event(), mock_license_registry_client(): + with patch.object( + ip_asset.licensing_module_client, + "build_registerDerivative_transaction", + ) as mock_build_registerDerivative_transaction: + + ip_asset.register_derivative( + child_ip_id=IP_ID, + parent_ip_ids=[IP_ID, IP_ID], + license_terms_ids=[1, 2], + ) + call_args = mock_build_registerDerivative_transaction.call_args[0] + print(call_args) + assert ( + call_args[3] == "0x1234567890123456789012345678901234567890" + ) # license_template + assert ( + call_args[4] == "0x0000000000000000000000000000000000000000" + ) # royalty_context + assert call_args[5] == 0 # max_minting_fee + assert call_args[6] == 100000000 # max_rts + assert call_args[7] == 100 * 10**6 # max_revenue_share + + def test_call_value_when_provided( + self, + ip_asset: IPAsset, + mock_get_ip_id, + mock_is_registered, + mock_parse_ip_registered_event, + mock_license_registry_client, + ): + with mock_get_ip_id(), mock_is_registered( + True + ), mock_parse_ip_registered_event(), mock_license_registry_client(): + with patch.object( + ip_asset.licensing_module_client, + "build_registerDerivative_transaction", + ) as mock_build_registerDerivative_transaction: + ip_asset.register_derivative( + child_ip_id=IP_ID, + parent_ip_ids=[IP_ID, IP_ID], + license_terms_ids=[1, 2], + max_revenue_share=10, + max_minting_fee=10, + max_rts=100, + license_template=ADDRESS, + ) + call_args = mock_build_registerDerivative_transaction.call_args[0] + assert call_args[7] == 10 * 10**6 # max_revenue_share + assert call_args[5] == 10 # max_minting_fee + assert call_args[6] == 100 # max_rts + assert call_args[3] == ADDRESS # license_template diff --git a/tests/unit/resources/test_license.py b/tests/unit/resources/test_license.py index f2151f0..8934200 100644 --- a/tests/unit/resources/test_license.py +++ b/tests/unit/resources/test_license.py @@ -1,10 +1,12 @@ from unittest.mock import MagicMock, patch import pytest +from _pytest.fixtures import fixture from eth_utils import is_address, to_checksum_address from web3 import Web3 from story_protocol_python_sdk.resources.License import License +from tests.unit.fixtures.data import ADDRESS, CHAIN_ID, IP_ID ZERO_ADDRESS = "0x0000000000000000000000000000000000000000" VALID_ADDRESS = "0x1daAE3197Bc469Cb97B917aa460a12dD95c6627c" @@ -832,3 +834,96 @@ def test_set_licensing_config_zero_address_with_rev_share(self, license_client): license_template=ZERO_ADDRESS, licensing_config=config, ) + + +######################################################################################## +##TODO: Need to refactor the previous test case + + +@fixture +def license(mock_web3, mock_account): + return License(web3=mock_web3, account=mock_account, chain_id=CHAIN_ID) + + +@fixture +def patch_is_registered(license): + def _patch(is_registered=True): + return patch.object( + license.ip_asset_registry_client, "isRegistered", return_value=is_registered + ) + + return _patch + + +@fixture +def patch_exists(license): + def _patch(exists=True): + return patch.object( + license.license_template_client, "exists", return_value=exists + ) + + return _patch + + +@fixture +def patch_has_ip_attached_license_terms(license): + def _patch(has_ip_attached_license_terms=True): + return patch.object( + license.license_registry_client, + "hasIpAttachedLicenseTerms", + return_value=has_ip_attached_license_terms, + ) + + return _patch + + +class TestMintLicenseTokens: + def test_default_value_when_not_provided( + self, + license: License, + patch_is_registered, + patch_exists, + patch_has_ip_attached_license_terms, + ): + with patch_is_registered(), patch_exists(), patch_has_ip_attached_license_terms(): + with patch.object( + license.licensing_module_client, + "build_mintLicenseTokens_transaction", + ) as mock_build_mintLicenseTokens_transaction: + + license.mint_license_tokens( + licensor_ip_id=IP_ID, + license_template=ADDRESS, + license_terms_id=1, + amount=1, + receiver=ZERO_ADDRESS, + ) + call_args = mock_build_mintLicenseTokens_transaction.call_args[0] + assert call_args[6] == 0 # max_minting_fee + assert call_args[7] == 100 * 10**6 # max_revenue_share + + def test_call_value_when_provided( + self, + license: License, + patch_is_registered, + patch_exists, + patch_has_ip_attached_license_terms, + ): + with patch_is_registered(), patch_exists(), patch_has_ip_attached_license_terms(): + with patch.object( + license.licensing_module_client, + "build_mintLicenseTokens_transaction", + ) as mock_build_mintLicenseTokens_transaction: + + license.mint_license_tokens( + licensor_ip_id=IP_ID, + license_template=ADDRESS, + license_terms_id=1, + amount=1, + receiver=ZERO_ADDRESS, + max_revenue_share=10, + max_minting_fee=10, + ) + call_args = mock_build_mintLicenseTokens_transaction.call_args[0] + assert call_args[6] == 10 # max_minting_fee + assert call_args[7] == 10 * 10**6 # max_revenue_share diff --git a/tests/unit/utils/test_validation.py b/tests/unit/utils/test_validation.py new file mode 100644 index 0000000..1e458cd --- /dev/null +++ b/tests/unit/utils/test_validation.py @@ -0,0 +1,66 @@ +import pytest + +from story_protocol_python_sdk.types.common import RevShareType +from story_protocol_python_sdk.utils.validation import ( + get_revenue_share, + validate_address, +) + + +class TestValidateAddress: + def test_valid_address(self): + address = "0x1234567890123456789012345678901234567890" + assert validate_address(address) == address + + def test_invalid_address(self): + with pytest.raises(ValueError, match="Invalid address: invalid_address."): + validate_address("invalid_address") + + +class TestGetRevenueShare: + def test_valid_revenue_share_of_100(self): + assert get_revenue_share(100) == 100 * 10**6 + + def test_valid_revenue_share_of_0(self): + assert get_revenue_share(0, RevShareType.COMMERCIAL_REVENUE_SHARE) == 0 + + def test_valid_revenue_share_of_50(self): + assert ( + get_revenue_share(50, RevShareType.MAX_ALLOWED_REWARD_SHARE) == 50 * 10**6 + ) + + def test_valid_revenue_share_with_type_of_100(self): + assert get_revenue_share(100) == 100 * 10**6 + + def test_revenue_share_less_than_0_with_commercial_revenue_share(self): + with pytest.raises( + ValueError, match="The commercial_rev_share must be between 0 and 100." + ): + get_revenue_share(-1) + + def test_revenue_share_greater_than_100_with_max_allowed_reward_share(self): + with pytest.raises( + ValueError, match="The max_allowed_reward_share must be between 0 and 100." + ): + get_revenue_share(101, RevShareType.MAX_ALLOWED_REWARD_SHARE) + + def test_revenue_share_greater_than_100_with_commercial_revenue_share(self): + with pytest.raises( + ValueError, match="The commercial_rev_share must be between 0 and 100." + ): + get_revenue_share(101, RevShareType.COMMERCIAL_REVENUE_SHARE) + + def test_revenue_share_less_than_0_with_max_revenue_share(self): + with pytest.raises( + ValueError, match="The max_revenue_share must be between 0 and 100." + ): + get_revenue_share(-1, RevShareType.MAX_REVENUE_SHARE) + + def test_revenue_share_greater_than_100_with_expect_minimum_group_reward_share( + self, + ): + with pytest.raises( + ValueError, + match="The expect_minimum_group_reward_share must be between 0 and 100.", + ): + get_revenue_share(101, RevShareType.EXPECT_MINIMUM_GROUP_REWARD_SHARE)