From 08a6b568f0eb977a3b22a9bf7f951d5279d9383e Mon Sep 17 00:00:00 2001 From: Bonnie Date: Tue, 12 Aug 2025 19:29:22 +0800 Subject: [PATCH 1/2] -mRefactor license terms and license config --- .../resources/Group.py | 28 +--- .../resources/IPAsset.py | 123 +++-------------- .../resources/License.py | 63 +++------ src/story_protocol_python_sdk/types/common.py | 7 +- .../utils/constants.py | 6 +- .../utils/license_terms.py | 126 ++++++++++-------- tests/integration/test_integration_group.py | 2 +- tests/unit/fixtures/data.py | 30 +++++ tests/unit/resources/test_license.py | 4 +- tests/unit/utils/test_derivative_data.py | 4 +- 10 files changed, 160 insertions(+), 233 deletions(-) diff --git a/src/story_protocol_python_sdk/resources/Group.py b/src/story_protocol_python_sdk/resources/Group.py index 9776d199..31be4a4c 100644 --- a/src/story_protocol_python_sdk/resources/Group.py +++ b/src/story_protocol_python_sdk/resources/Group.py @@ -547,34 +547,12 @@ def _get_license_data(self, license_data: list) -> list: f'License template address "{license_template}" is invalid.' ) - # Validate licensing config - licensing_config = item.get("licensing_config", {}) - - try: - self.license_terms_util.validate_licensing_config(licensing_config) - except Exception as e: - raise ValueError(f"Licensing config validation failed: {str(e)}") - - # Convert to camelCase for contract interaction - camelcase_config = { - "isSet": licensing_config.get("is_set", True), - "mintingFee": licensing_config.get("minting_fee", 0), - "hookData": licensing_config.get("hook_data", ZERO_ADDRESS), - "licensingHook": licensing_config.get("licensing_hook", ZERO_ADDRESS), - "commercialRevShare": licensing_config.get("commercial_rev_share", 0), - "disabled": licensing_config.get("disabled", False), - "expectMinimumGroupRewardShare": licensing_config.get( - "expect_minimum_group_reward_share", 0 - ), - "expectGroupRewardPool": licensing_config.get( - "expect_group_reward_pool", ZERO_ADDRESS - ), - } - processed_item = { "licenseTemplate": license_template, "licenseTermsId": item["license_terms_id"], - "licensingConfig": camelcase_config, + "licensingConfig": self.license_terms_util.validate_licensing_config( + item.get("licensing_config", {}) + ), } result.append(processed_item) diff --git a/src/story_protocol_python_sdk/resources/IPAsset.py b/src/story_protocol_python_sdk/resources/IPAsset.py index 2aea9ff0..7940fca4 100644 --- a/src/story_protocol_python_sdk/resources/IPAsset.py +++ b/src/story_protocol_python_sdk/resources/IPAsset.py @@ -1,5 +1,6 @@ """Module for handling IP Account operations and transactions.""" +from ens.ens import HexStr from web3 import Web3 from story_protocol_python_sdk.abi.AccessController.AccessController_client import ( @@ -173,7 +174,7 @@ def register( signature_response = self.sign_util.get_permission_signature( ip_id=ip_id, deadline=calculated_deadline, - state=self.web3.to_bytes(hexstr=ZERO_HASH), + state=self.web3.to_bytes(hexstr=HexStr(ZERO_HASH)), permissions=[ { "ipId": ip_id, @@ -366,7 +367,7 @@ def mint_and_register_ip_asset_with_pil_terms( :param commercializer_checker str: Allowed commercializers or zero address for none. :param commercializer_checker_data str: Data for checker contract. - :param commercial_rev_share int: Revenue share percentage. + :param commercial_rev_share int: The commercial revenue share percentage (from 0 to 100%, represented as 100_000_000). :param commercial_rev_ceiling int: Maximum commercial revenue. :param derivatives_allowed bool: Whether derivatives are allowed. :param derivatives_attribution bool: Whether attribution is needed @@ -384,10 +385,10 @@ def mint_and_register_ip_asset_with_pil_terms( :param hook_data str: The data used by the licensing hook. :param licensing_hook str: The licensing hook contract address or address(0) if none. - :param commercial_rev_share int: Commercial revenue share percent. + :param commercial_rev_share int: The commercial revenue share percentage (from 0 to 100%, represented as 100_000_000). :param disabled bool: Whether the license is disabled. :param expect_minimum_group_reward_share int: Minimum group reward - share (0-100%, as 100 * 10^6). + share percentage (from 0 to 100%, represented as 100_000_000). :param expect_group_reward_pool str: Address of the expected group reward pool. :param ip_metadata dict: [Optional] NFT and IP metadata. @@ -405,57 +406,17 @@ def mint_and_register_ip_asset_with_pil_terms( raise ValueError( f"The NFT contract address {spg_nft_contract} is not valid." ) - license_terms = [] for term in terms: - self.license_terms_util.validate_license_terms(term["terms"]) - validated_licensing_config = ( - self.license_terms_util.validate_licensing_config( - term["licensing_config"] - ) - ) - - camelcase_term = { - "transferable": term["terms"]["transferable"], - "royaltyPolicy": term["terms"]["royalty_policy"], - "defaultMintingFee": term["terms"]["default_minting_fee"], - "expiration": term["terms"]["expiration"], - "commercialUse": term["terms"]["commercial_use"], - "commercialAttribution": term["terms"]["commercial_attribution"], - "commercializerChecker": term["terms"]["commercializer_checker"], - "commercializerCheckerData": term["terms"][ - "commercializer_checker_data" - ], - "commercialRevShare": term["terms"]["commercial_rev_share"], - "commercialRevCeiling": term["terms"]["commercial_rev_ceiling"], - "derivativesAllowed": term["terms"]["derivatives_allowed"], - "derivativesAttribution": term["terms"]["derivatives_attribution"], - "derivativesApproval": term["terms"]["derivatives_approval"], - "derivativesReciprocal": term["terms"]["derivatives_reciprocal"], - "derivativeRevCeiling": term["terms"]["derivative_rev_ceiling"], - "currency": term["terms"]["currency"], - "uri": term["terms"]["uri"], - } - - camelcase_config = { - "isSet": validated_licensing_config["is_set"], - "mintingFee": validated_licensing_config["minting_fee"], - "hookData": validated_licensing_config["hook_data"], - "licensingHook": validated_licensing_config["licensing_hook"], - "commercialRevShare": validated_licensing_config[ - "commercial_rev_share" - ], - "disabled": validated_licensing_config["disabled"], - "expectMinimumGroupRewardShare": validated_licensing_config[ - "expect_minimum_group_reward_share" - ], - "expectGroupRewardPool": validated_licensing_config[ - "expect_group_reward_pool" - ], - } - license_terms.append( - {"terms": camelcase_term, "licensingConfig": camelcase_config} + { + "terms": self.license_terms_util.validate_license_terms( + term["terms"] + ), + "licensingConfig": self.license_terms_util.validate_licensing_config( + term["licensing_config"] + ), + } ) metadata = { @@ -632,57 +593,17 @@ def register_ip_and_attach_pil_terms( raise ValueError( f"The NFT with id {token_id} is already registered as IP." ) - license_terms = [] for term in license_terms_data: - self.license_terms_util.validate_license_terms(term["terms"]) - validated_licensing_config = ( - self.license_terms_util.validate_licensing_config( - term["licensing_config"] - ) - ) - - camelcase_term = { - "transferable": term["terms"]["transferable"], - "royaltyPolicy": term["terms"]["royalty_policy"], - "defaultMintingFee": term["terms"]["default_minting_fee"], - "expiration": term["terms"]["expiration"], - "commercialUse": term["terms"]["commercial_use"], - "commercialAttribution": term["terms"]["commercial_attribution"], - "commercializerChecker": term["terms"]["commercializer_checker"], - "commercializerCheckerData": term["terms"][ - "commercializer_checker_data" - ], - "commercialRevShare": term["terms"]["commercial_rev_share"], - "commercialRevCeiling": term["terms"]["commercial_rev_ceiling"], - "derivativesAllowed": term["terms"]["derivatives_allowed"], - "derivativesAttribution": term["terms"]["derivatives_attribution"], - "derivativesApproval": term["terms"]["derivatives_approval"], - "derivativesReciprocal": term["terms"]["derivatives_reciprocal"], - "derivativeRevCeiling": term["terms"]["derivative_rev_ceiling"], - "currency": term["terms"]["currency"], - "uri": term["terms"]["uri"], - } - - camelcase_config = { - "isSet": validated_licensing_config["is_set"], - "mintingFee": validated_licensing_config["minting_fee"], - "hookData": validated_licensing_config["hook_data"], - "licensingHook": validated_licensing_config["licensing_hook"], - "commercialRevShare": validated_licensing_config[ - "commercial_rev_share" - ], - "disabled": validated_licensing_config["disabled"], - "expectMinimumGroupRewardShare": validated_licensing_config[ - "expect_minimum_group_reward_share" - ], - "expectGroupRewardPool": validated_licensing_config[ - "expect_group_reward_pool" - ], - } - license_terms.append( - {"terms": camelcase_term, "licensingConfig": camelcase_config} + { + "terms": self.license_terms_util.validate_license_terms( + term["terms"] + ), + "licensingConfig": self.license_terms_util.validate_licensing_config( + term["licensing_config"] + ), + } ) calculated_deadline = self.sign_util.get_deadline(deadline=deadline) @@ -691,7 +612,7 @@ def register_ip_and_attach_pil_terms( signature_response = self.sign_util.get_permission_signature( ip_id=ip_id, deadline=calculated_deadline, - state=self.web3.to_bytes(hexstr=ZERO_HASH), + state=self.web3.to_bytes(hexstr=HexStr(ZERO_HASH)), permissions=[ { "ipId": ip_id, diff --git a/src/story_protocol_python_sdk/resources/License.py b/src/story_protocol_python_sdk/resources/License.py index 3700c08a..2859654f 100644 --- a/src/story_protocol_python_sdk/resources/License.py +++ b/src/story_protocol_python_sdk/resources/License.py @@ -99,48 +99,27 @@ def register_pil_terms( :return dict: A dictionary with the transaction hash and license terms ID. """ try: - license_terms = { - "transferable": transferable, - "royaltyPolicy": royalty_policy, - "defaultMintingFee": default_minting_fee, - "expiration": expiration, - "commercialUse": commercial_use, - "commercialAttribution": commercial_attribution, - "commercializerChecker": commercializer_checker, - "commercializerCheckerData": commercializer_checker_data, - "commercialRevShare": commercial_rev_share, - "commercialRevCeiling": commercial_rev_ceiling, - "derivativesAllowed": derivatives_allowed, - "derivativesAttribution": derivatives_attribution, - "derivativesApproval": derivatives_approval, - "derivativesReciprocal": derivatives_reciprocal, - "derivativeRevCeiling": derivative_rev_ceiling, - "currency": currency, - "uri": uri, - } - - license_terms_snake = { - "transferable": transferable, - "royalty_policy": royalty_policy, - "default_minting_fee": default_minting_fee, - "expiration": expiration, - "commercial_use": commercial_use, - "commercial_attribution": commercial_attribution, - "commercializer_checker": commercializer_checker, - "commercializer_checker_data": commercializer_checker_data, - "commercial_rev_share": commercial_rev_share, - "commercial_rev_ceiling": commercial_rev_ceiling, - "derivatives_allowed": derivatives_allowed, - "derivatives_attribution": derivatives_attribution, - "derivatives_approval": derivatives_approval, - "derivatives_reciprocal": derivatives_reciprocal, - "derivative_rev_ceiling": derivative_rev_ceiling, - "currency": currency, - "uri": uri, - } - - # Validate the license terms - self.license_terms_util.validate_license_terms(license_terms_snake) + license_terms = self.license_terms_util.validate_license_terms( + { + "transferable": transferable, + "royalty_policy": royalty_policy, + "default_minting_fee": default_minting_fee, + "expiration": expiration, + "commercial_use": commercial_use, + "commercial_attribution": commercial_attribution, + "commercializer_checker": commercializer_checker, + "commercializer_checker_data": commercializer_checker_data, + "commercial_rev_share": commercial_rev_share, + "commercial_rev_ceiling": commercial_rev_ceiling, + "derivatives_allowed": derivatives_allowed, + "derivatives_attribution": derivatives_attribution, + "derivatives_approval": derivatives_approval, + "derivatives_reciprocal": derivatives_reciprocal, + "derivative_rev_ceiling": derivative_rev_ceiling, + "currency": currency, + "uri": uri, + } + ) license_terms_id = self._get_license_terms_id(license_terms) if (license_terms_id is not None) and (license_terms_id != 0): diff --git a/src/story_protocol_python_sdk/types/common.py b/src/story_protocol_python_sdk/types/common.py index 314e2196..d78d9179 100644 --- a/src/story_protocol_python_sdk/types/common.py +++ b/src/story_protocol_python_sdk/types/common.py @@ -2,9 +2,10 @@ class RevShareType(Enum): - COMMERCIAL_REVENUE_SHARE = "commercialRevShare" - MAX_REVENUE_SHARE = "maxRevenueShare" - MAX_ALLOWED_REWARD_SHARE = "maxAllowedRewardShare" + COMMERCIAL_REVENUE_SHARE = "commercial_rev_share" + MAX_REVENUE_SHARE = "max_revenue_share" + MAX_ALLOWED_REWARD_SHARE = "max_allowed_reward_share" + EXPECT_MINIMUM_GROUP_REWARD_SHARE = "expect_minimum_group_reward_share" class AccessPermission(Enum): diff --git a/src/story_protocol_python_sdk/utils/constants.py b/src/story_protocol_python_sdk/utils/constants.py index b30f2769..a0f31171 100644 --- a/src/story_protocol_python_sdk/utils/constants.py +++ b/src/story_protocol_python_sdk/utils/constants.py @@ -1,9 +1,5 @@ -from eth_typing import HexStr - ZERO_ADDRESS = "0x0000000000000000000000000000000000000000" -ZERO_HASH: HexStr = HexStr( - "0x0000000000000000000000000000000000000000000000000000000000000000" -) +ZERO_HASH = "0x0000000000000000000000000000000000000000000000000000000000000000" ZERO_FUNC = "0x00000000" DEFAULT_FUNCTION_SELECTOR = "0x00000000" MAX_ROYALTY_TOKEN = 100000000 diff --git a/src/story_protocol_python_sdk/utils/license_terms.py b/src/story_protocol_python_sdk/utils/license_terms.py index f660f1a9..584efdd4 100644 --- a/src/story_protocol_python_sdk/utils/license_terms.py +++ b/src/story_protocol_python_sdk/utils/license_terms.py @@ -6,10 +6,12 @@ from story_protocol_python_sdk.abi.RoyaltyModule.RoyaltyModule_client import ( RoyaltyModuleClient, ) +from story_protocol_python_sdk.types.common import RevShareType from story_protocol_python_sdk.utils.constants import ( ROYALTY_POLICY_LAP_ADDRESS, ZERO_ADDRESS, ) +from story_protocol_python_sdk.utils.validation import get_revenue_share class LicenseTerms: @@ -119,34 +121,48 @@ def validate_license_terms(self, params): if royalty_policy != ZERO_ADDRESS and currency == ZERO_ADDRESS: raise ValueError("Royalty policy requires currency token.") - params["default_minting_fee"] = int(params.get("default_minting_fee", 0)) - params["expiration"] = int(params.get("expiration", 0)) - params["commercial_rev_ceiling"] = int(params.get("commercial_rev_ceiling", 0)) - params["derivative_rev_ceiling"] = int(params.get("derivative_rev_ceiling", 0)) - - self.verify_commercial_use(params) - self.verify_derivatives(params) - commercial_rev_share = params.get("commercial_rev_share", 0) if commercial_rev_share < 0 or commercial_rev_share > 100: - raise ValueError("CommercialRevShare should be between 0 and 100.") - else: - params["commercial_rev_share"] = int( - (commercial_rev_share / 100) * 100000000 - ) + raise ValueError("commercial_rev_share should be between 0 and 100.") - commercializer_checker_data = params.get( - "commercializer_checker_data", ZERO_ADDRESS + expect_minimum_group_reward_share = params.get( + "expect_minimum_group_reward_share", 0 ) - if isinstance(commercializer_checker_data, str): - params["commercializer_checker_data"] = Web3.to_bytes( - hexstr=HexStr(commercializer_checker_data) + if ( + expect_minimum_group_reward_share < 0 + or expect_minimum_group_reward_share > 100 + ): + raise ValueError( + "Expect minimum group reward share must be between 0 and 100" ) + validated_params = { + "transferable": params.get("transferable"), + "royaltyPolicy": params.get("royalty_policy"), + "defaultMintingFee": int(params.get("default_minting_fee", 0)), + "expiration": int(params.get("expiration", 0)), + "commercialUse": params.get("commercial_use"), + "commercialAttribution": params.get("commercial_attribution"), + "commercializerChecker": params.get("commercializer_checker"), + "commercializerCheckerData": Web3.to_bytes( + hexstr=HexStr(params.get("commercializer_checker_data", ZERO_ADDRESS)) + ), + "commercialRevShare": get_revenue_share( + params.get("commercial_rev_share", 0), + RevShareType.COMMERCIAL_REVENUE_SHARE, + ), + "commercialRevCeiling": int(params.get("commercial_rev_ceiling", 0)), + "derivativesAllowed": params.get("derivatives_allowed"), + "derivativesAttribution": params.get("derivatives_attribution"), + "derivativesApproval": params.get("derivatives_approval"), + "derivativesReciprocal": params.get("derivatives_reciprocal"), + "derivativeRevCeiling": int(params.get("derivative_rev_ceiling", 0)), + "currency": params.get("currency"), + "uri": params.get("uri"), + } - params["expect_minimum_group_reward_share"] = int( - params.get("expect_minimum_group_reward_share", 0) - ) - return params + self.verify_commercial_use(validated_params) + self.verify_derivatives(validated_params) + return validated_params def validate_licensing_config(self, params): if not isinstance(params, dict): @@ -169,14 +185,14 @@ def validate_licensing_config(self, params): raise TypeError(f"{param} must be of type {expected_type.__name__}") default_params = { - "is_set": False, - "minting_fee": 0, - "hook_data": ZERO_ADDRESS, - "licensing_hook": ZERO_ADDRESS, - "commercial_rev_share": 0, + "isSet": False, + "mintingFee": 0, + "hookData": ZERO_ADDRESS, + "licensingHook": ZERO_ADDRESS, + "commercialRevShare": 0, "disabled": False, - "expect_minimum_group_reward_share": 0, - "expect_group_reward_pool": ZERO_ADDRESS, + "expectMinimumGroupRewardShare": 0, + "expectGroupRewardPool": ZERO_ADDRESS, } if not params.get("is_set", False): @@ -190,11 +206,6 @@ def validate_licensing_config(self, params): or params.get("commercial_rev_share", 0) > 100 ): raise ValueError("Commercial revenue share must be between 0 and 100") - else: - params["commercial_rev_share"] = int( - (params["commercial_rev_share"] / 100) * 100000000 - ) - if ( params.get("expect_minimum_group_reward_share", 0) < 0 or params.get("expect_minimum_group_reward_share", 0) > 100 @@ -202,60 +213,71 @@ def validate_licensing_config(self, params): raise ValueError( "Expect minimum group reward share must be between 0 and 100" ) + validated_params = { + "isSet": params.get("is_set", False), + "mintingFee": params.get("minting_fee", 0), + "hookData": Web3.to_bytes(hexstr=HexStr(params["hook_data"])), + "licensingHook": params.get("licensing_hook", ZERO_ADDRESS), + "commercialRevShare": get_revenue_share(params["commercial_rev_share"]), + "disabled": params.get("disabled", False), + "expectMinimumGroupRewardShare": get_revenue_share( + params["expect_minimum_group_reward_share"], + RevShareType.EXPECT_MINIMUM_GROUP_REWARD_SHARE, + ), + "expectGroupRewardPool": params.get( + "expect_group_reward_pool", ZERO_ADDRESS + ), + } - params["hook_data"] = Web3.to_bytes(hexstr=params["hook_data"]) - - default_params.update(params) - - return default_params + return validated_params def verify_commercial_use(self, terms): - if not terms.get("commercial_use", False): - if terms.get("commercial_attribution"): + if not terms.get("commercialUse", False): + if terms.get("commercialAttribution", False): raise ValueError( "Cannot add commercial attribution when commercial use is disabled." ) - if terms.get("commercializer_checker") != ZERO_ADDRESS: + if terms.get("commercializerChecker") != ZERO_ADDRESS: raise ValueError( "Cannot add commercializerChecker when commercial use is disabled." ) - if terms.get("commercial_rev_share", 0) > 0: + if terms.get("commercialRevShare", 0) > 0: raise ValueError( "Cannot add commercial revenue share when commercial use is disabled." ) - if terms.get("commercial_rev_ceiling", 0) > 0: + if terms.get("commercialRevCeiling", 0) > 0: raise ValueError( "Cannot add commercial revenue ceiling when commercial use is disabled." ) - if terms.get("derivative_rev_ceiling", 0) > 0: + if terms.get("derivativeRevCeiling", 0) > 0: raise ValueError( "Cannot add derivative revenue ceiling when commercial use is disabled." ) - if terms.get("royalty_policy") != ZERO_ADDRESS: + if terms.get("royaltyPolicy") != ZERO_ADDRESS: raise ValueError( "Cannot add commercial royalty policy when commercial use is disabled." ) else: - if terms.get("royalty_policy") == ZERO_ADDRESS: + if terms.get("royaltyPolicy") == ZERO_ADDRESS: raise ValueError( "Royalty policy is required when commercial use is enabled." ) def verify_derivatives(self, terms): - if not terms.get("derivatives_allowed", False): - if terms.get("derivatives_attribution"): + if not terms.get("derivativesAllowed", False): + if terms.get("derivativesAttribution", False): raise ValueError( "Cannot add derivative attribution when derivative use is disabled." ) - if terms.get("derivatives_approval"): + if terms.get("derivativesApproval", False): raise ValueError( "Cannot add derivative approval when derivative use is disabled." ) - if terms.get("derivatives_reciprocal"): + if terms.get("derivativesReciprocal", False): raise ValueError( "Cannot add derivative reciprocal when derivative use is disabled." ) - if terms.get("derivative_rev_ceiling", 0) > 0: + if terms.get("derivativeRevCeiling", 0) > 0: raise ValueError( "Cannot add derivative revenue ceiling when derivative use is disabled." ) diff --git a/tests/integration/test_integration_group.py b/tests/integration/test_integration_group.py index a60f790b..fce3a173 100644 --- a/tests/integration/test_integration_group.py +++ b/tests/integration/test_integration_group.py @@ -369,7 +369,7 @@ def setup_royalty_collection(self, story_client, nft_collection): "licensing_hook": ZERO_ADDRESS, "commercial_rev_share": 10, "disabled": False, - "expect_minimum_group_reward_share": 0, + "expect_minimum_group_reward_share": 10, "expect_group_reward_pool": EVEN_SPLIT_GROUP_POOL, }, } diff --git a/tests/unit/fixtures/data.py b/tests/unit/fixtures/data.py index 42203f70..9e2505d9 100644 --- a/tests/unit/fixtures/data.py +++ b/tests/unit/fixtures/data.py @@ -4,3 +4,33 @@ # STATE as bytes32 (32 bytes = 64 hex characters) STATE = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" IP_ID = "0xFEB4eE75600768635010D80D56a5711268D26DaB" +LICENSE_TERMS = { + "royalty_policy": ADDRESS, + "commercial_rev_share": 19, + "currency": ADDRESS, + "default_minting_fee": 10, + "expiration": 100, + "commercial_use": True, + "commercial_attribution": True, + "commercializer_checker": True, + "commercializer_checker_data": ADDRESS, + "derivatives_allowed": True, + "derivatives_attribution": True, + "derivatives_approval": True, + "derivatives_reciprocal": True, + "derivative_rev_ceiling": 100, + "uri": "https://example.com", + "transferable": True, + "expect_minimum_group_reward_share": 10, +} + +LICENSING_CONFIG = { + "is_set": True, + "minting_fee": 10, + "licensing_hook": ADDRESS, + "hook_data": ADDRESS, + "commercial_rev_share": 10, + "disabled": False, + "expect_minimum_group_reward_share": 10, + "expect_group_reward_pool": ADDRESS, +} diff --git a/tests/unit/resources/test_license.py b/tests/unit/resources/test_license.py index 5ce788e7..f2151f0f 100644 --- a/tests/unit/resources/test_license.py +++ b/tests/unit/resources/test_license.py @@ -187,7 +187,7 @@ def test_register_pil_terms_commercial_rev_share_error_more_than_100( ): with pytest.raises( - ValueError, match="CommercialRevShare should be between 0 and 100." + ValueError, match="commercial_rev_share should be between 0 and 100." ): license_client.register_pil_terms( transferable=False, @@ -225,7 +225,7 @@ def test_register_pil_terms_commercial_rev_share_error_less_than_0( ): with pytest.raises( - ValueError, match="CommercialRevShare should be between 0 and 100." + ValueError, match="commercial_rev_share should be between 0 and 100." ): license_client.register_pil_terms( transferable=False, diff --git a/tests/unit/utils/test_derivative_data.py b/tests/unit/utils/test_derivative_data.py index e90b4695..7014ea61 100644 --- a/tests/unit/utils/test_derivative_data.py +++ b/tests/unit/utils/test_derivative_data.py @@ -309,7 +309,7 @@ def test_validate_max_revenue_share_is_less_than_0( ): with mock_ip_asset_registry_client(), mock_license_registry_client(): with raises( - ValueError, match="The maxRevenueShare must be between 0 and 100." + ValueError, match="max_revenue_share must be between 0 and 100." ): DerivativeData.from_input( web3=mock_web3, @@ -331,7 +331,7 @@ def test_validate_max_revenue_share_is_greater_than_100( ): with mock_is_checksum_address(), mock_ip_asset_registry_client(), mock_license_registry_client(): with raises( - ValueError, match="The maxRevenueShare must be between 0 and 100." + ValueError, match="max_revenue_share must be between 0 and 100." ): DerivativeData.from_input( web3=mock_web3, From eccee640d05d1a8b1b3b50d88fc6f2474d7d29a9 Mon Sep 17 00:00:00 2001 From: Bonnie Date: Tue, 12 Aug 2025 19:30:39 +0800 Subject: [PATCH 2/2] Add unit tests including mint and register_ip_and_attach_pil_terms --- tests/unit/resources/test_ip_asset.py | 174 +++++++++++++++++++++++++- 1 file changed, 172 insertions(+), 2 deletions(-) diff --git a/tests/unit/resources/test_ip_asset.py b/tests/unit/resources/test_ip_asset.py index 45c7e02e..4f9d9744 100644 --- a/tests/unit/resources/test_ip_asset.py +++ b/tests/unit/resources/test_ip_asset.py @@ -5,7 +5,15 @@ from story_protocol_python_sdk.resources.IPAsset import IPAsset from story_protocol_python_sdk.utils.constants import ZERO_HASH from story_protocol_python_sdk.utils.derivative_data import DerivativeDataInput -from tests.unit.fixtures.data import ADDRESS, CHAIN_ID, IP_ID, TX_HASH +from story_protocol_python_sdk.utils.ip_metadata import IPMetadata +from tests.unit.fixtures.data import ( + ADDRESS, + CHAIN_ID, + IP_ID, + LICENSE_TERMS, + LICENSING_CONFIG, + TX_HASH, +) @pytest.fixture(scope="class") @@ -39,7 +47,21 @@ def _mock(is_registered: bool = False): def mock_parse_ip_registered_event(ip_asset): def _mock(): return patch.object( - ip_asset, "_parse_tx_ip_registered_event", return_value={"ip_id": IP_ID} + ip_asset, + "_parse_tx_ip_registered_event", + return_value={"ip_id": IP_ID, "token_id": 1}, + ) + + return _mock + + +@pytest.fixture(scope="class") +def mock_parse_tx_license_terms_attached_event(ip_asset): + def _mock(): + return patch.object( + ip_asset, + "_parse_tx_license_terms_attached_event", + return_value=[1, 2], ) return _mock @@ -172,3 +194,151 @@ def test_success( ) assert result["tx_hash"] == TX_HASH.hex() assert result["ip_id"] == IP_ID + + +class TestMint: + def test_mint_successful(self, ip_asset): + result = ip_asset.mint( + nft_contract=ADDRESS, + to_address=ADDRESS, + metadata_uri="", + metadata_hash=ZERO_HASH, + ) + assert result == f"0x{TX_HASH.hex()}" + + def test_mint_failed_transaction(self, ip_asset): + with patch.object(ip_asset.web3.eth, "send_raw_transaction") as mock_send: + mock_send.side_effect = Exception("Transaction failed") + with pytest.raises(Exception, match="Transaction failed"): + ip_asset.mint( + nft_contract=ADDRESS, + to_address=ADDRESS, + metadata_uri="", + metadata_hash=ZERO_HASH, + allow_duplicates=False, + ) + + +class TestRegisterIpAndAttachPilTerms: + def test_token_id_is_already_registered( + self, ip_asset, mock_get_ip_id, mock_is_registered + ): + with mock_get_ip_id(), mock_is_registered(True): + with pytest.raises( + ValueError, match="The NFT with id 3 is already registered as IP." + ): + ip_asset.register_ip_and_attach_pil_terms( + nft_contract=ADDRESS, + token_id=3, + license_terms_data=[], + ) + + def test_royalty_policy_commercial_rev_share_is_less_than_0( + self, ip_asset, mock_get_ip_id, mock_is_registered + ): + with mock_get_ip_id(), mock_is_registered(): + with pytest.raises( + ValueError, match="commercial_rev_share should be between 0 and 100." + ): + ip_asset.register_ip_and_attach_pil_terms( + nft_contract=ADDRESS, + token_id=3, + license_terms_data=[ + { + "terms": { + **LICENSE_TERMS, + "commercial_rev_share": -1, + }, + } + ], + ) + + def test_transaction_to_be_called_with_correct_parameters( + self, + ip_asset: IPAsset, + mock_get_ip_id, + mock_is_registered, + mock_parse_ip_registered_event, + mock_parse_tx_license_terms_attached_event, + mock_signature_related_methods, + ): + with mock_get_ip_id(), mock_is_registered(), mock_parse_ip_registered_event(), mock_parse_tx_license_terms_attached_event(), mock_signature_related_methods(): + with patch.object( + ip_asset.license_attachment_workflows_client, + "build_registerIpAndAttachPILTerms_transaction", + ) as mock_build_registerIpAndAttachPILTerms_transaction: + + ip_asset.register_ip_and_attach_pil_terms( + nft_contract=ADDRESS, + token_id=3, + license_terms_data=[ + { + "terms": LICENSE_TERMS, + "licensing_config": LICENSING_CONFIG, + } + ], + ) + call_args = mock_build_registerIpAndAttachPILTerms_transaction.call_args[0] + assert call_args[0] == ADDRESS + assert call_args[1] == 3 + assert call_args[2] == IPMetadata.from_input().get_validated_data() + assert call_args[3] == [ + { + "terms": { + "transferable": True, + "royaltyPolicy": "0x1234567890123456789012345678901234567890", + "defaultMintingFee": 10, + "expiration": 100, + "commercialUse": True, + "commercialAttribution": True, + "commercializerChecker": True, + "commercializerCheckerData": b"mock_bytes", + "commercialRevShare": 19000000, + "commercialRevCeiling": 0, + "derivativesAllowed": True, + "derivativesAttribution": True, + "derivativesApproval": True, + "derivativesReciprocal": True, + "derivativeRevCeiling": 100, + "currency": "0x1234567890123456789012345678901234567890", + "uri": "https://example.com", + }, + "licensingConfig": { + "isSet": True, + "mintingFee": 10, + "hookData": b"mock_bytes", + "licensingHook": "0x1234567890123456789012345678901234567890", + "commercialRevShare": 10000000, + "disabled": False, + "expectMinimumGroupRewardShare": 10000000, + "expectGroupRewardPool": "0x1234567890123456789012345678901234567890", + }, + } + ] + + def test_success( + self, + ip_asset: IPAsset, + mock_get_ip_id, + mock_is_registered, + mock_parse_ip_registered_event, + mock_signature_related_methods, + mock_parse_tx_license_terms_attached_event, + ): + with mock_get_ip_id(), mock_is_registered(), mock_parse_ip_registered_event(), mock_parse_tx_license_terms_attached_event(), mock_signature_related_methods(): + result = ip_asset.register_ip_and_attach_pil_terms( + nft_contract=ADDRESS, + token_id=3, + license_terms_data=[ + { + "terms": LICENSE_TERMS, + "licensing_config": LICENSING_CONFIG, + } + ], + ) + assert result == { + "tx_hash": TX_HASH.hex(), + "ip_id": IP_ID, + "license_terms_ids": [1, 2], + "token_id": 1, + }