diff --git a/src/story_protocol_python_sdk/__init__.py b/src/story_protocol_python_sdk/__init__.py index 8520c85..590f0e4 100644 --- a/src/story_protocol_python_sdk/__init__.py +++ b/src/story_protocol_python_sdk/__init__.py @@ -7,6 +7,18 @@ from .resources.Royalty import Royalty from .resources.WIP import WIP from .story_client import StoryClient +from .types.common import AccessPermission +from .utils.constants import ( + DEFAULT_FUNCTION_SELECTOR, + MAX_ROYALTY_TOKEN, + ROYALTY_POLICY_LAP_ADDRESS, + ROYALTY_POLICY_LRP_ADDRESS, + ZERO_ADDRESS, + ZERO_FUNC, + ZERO_HASH, +) +from .utils.derivative_data import DerivativeDataInput +from .utils.ip_metadata import IPMetadataInput __all__ = [ "StoryClient", @@ -16,4 +28,15 @@ "IPAccount", "Dispute", "WIP", + "AccessPermission", + "DerivativeDataInput", + "IPMetadataInput", + # Constants + "ZERO_ADDRESS", + "ZERO_HASH", + "ROYALTY_POLICY_LAP_ADDRESS", + "ROYALTY_POLICY_LRP_ADDRESS", + "ZERO_FUNC", + "DEFAULT_FUNCTION_SELECTOR", + "MAX_ROYALTY_TOKEN", ] diff --git a/src/story_protocol_python_sdk/abi/AccessController/AccessController_client.py b/src/story_protocol_python_sdk/abi/AccessController/AccessController_client.py index 0b6118f..30259b1 100644 --- a/src/story_protocol_python_sdk/abi/AccessController/AccessController_client.py +++ b/src/story_protocol_python_sdk/abi/AccessController/AccessController_client.py @@ -48,13 +48,13 @@ def build_setAllPermissions_transaction( ipAccount, signer, permission ).build_transaction(tx_params) - def setTransientBatchPermissions(self, permissions): - return self.contract.functions.setTransientBatchPermissions( + def setBatchTransientPermissions(self, permissions): + return self.contract.functions.setBatchTransientPermissions( permissions ).transact() - def build_setTransientBatchPermissions_transaction(self, permissions, tx_params): - return self.contract.functions.setTransientBatchPermissions( + def build_setBatchTransientPermissions_transaction(self, permissions, tx_params): + return self.contract.functions.setBatchTransientPermissions( permissions ).build_transaction(tx_params) diff --git a/src/story_protocol_python_sdk/abi/DerivativeWorkflows/DerivativeWorkflows_client.py b/src/story_protocol_python_sdk/abi/DerivativeWorkflows/DerivativeWorkflows_client.py index b1162b6..8cc2331 100644 --- a/src/story_protocol_python_sdk/abi/DerivativeWorkflows/DerivativeWorkflows_client.py +++ b/src/story_protocol_python_sdk/abi/DerivativeWorkflows/DerivativeWorkflows_client.py @@ -35,3 +35,23 @@ def __init__(self, web3: Web3): with open(abi_path, "r") as abi_file: abi = json.load(abi_file) self.contract = self.web3.eth.contract(address=contract_address, abi=abi) + + def registerIpAndMakeDerivative( + self, nftContract, tokenId, derivData, ipMetadata, sigMetadataAndRegister + ): + return self.contract.functions.registerIpAndMakeDerivative( + nftContract, tokenId, derivData, ipMetadata, sigMetadataAndRegister + ).transact() + + def build_registerIpAndMakeDerivative_transaction( + self, + nftContract, + tokenId, + derivData, + ipMetadata, + sigMetadataAndRegister, + tx_params, + ): + return self.contract.functions.registerIpAndMakeDerivative( + nftContract, tokenId, derivData, ipMetadata, sigMetadataAndRegister + ).build_transaction(tx_params) diff --git a/src/story_protocol_python_sdk/abi/jsons/AccessController.json b/src/story_protocol_python_sdk/abi/jsons/AccessController.json index 85f0230..9956dcd 100644 --- a/src/story_protocol_python_sdk/abi/jsons/AccessController.json +++ b/src/story_protocol_python_sdk/abi/jsons/AccessController.json @@ -6,27 +6,15 @@ "name": "ipAccountRegistry", "type": "address" }, - { - "internalType": "address", - "name": "moduleRegistry", - "type": "address" - } + { "internalType": "address", "name": "moduleRegistry", "type": "address" } ], "stateMutability": "nonpayable", "type": "constructor" }, { "inputs": [ - { - "internalType": "address", - "name": "signer", - "type": "address" - }, - { - "internalType": "address", - "name": "to", - "type": "address" - } + { "internalType": "address", "name": "signer", "type": "address" }, + { "internalType": "address", "name": "to", "type": "address" } ], "name": "AccessController__BothCallerAndRecipientAreNotRegisteredModule", "type": "error" @@ -38,11 +26,7 @@ }, { "inputs": [ - { - "internalType": "address", - "name": "ipAccount", - "type": "address" - } + { "internalType": "address", "name": "ipAccount", "type": "address" } ], "name": "AccessController__IPAccountIsNotValid", "type": "error" @@ -54,42 +38,18 @@ }, { "inputs": [ - { - "internalType": "address", - "name": "ipAccount", - "type": "address" - }, - { - "internalType": "address", - "name": "owner", - "type": "address" - } + { "internalType": "address", "name": "ipAccount", "type": "address" }, + { "internalType": "address", "name": "owner", "type": "address" } ], "name": "AccessController__OwnerIsIPAccount", "type": "error" }, { "inputs": [ - { - "internalType": "address", - "name": "ipAccount", - "type": "address" - }, - { - "internalType": "address", - "name": "signer", - "type": "address" - }, - { - "internalType": "address", - "name": "to", - "type": "address" - }, - { - "internalType": "bytes4", - "name": "func", - "type": "bytes4" - } + { "internalType": "address", "name": "ipAccount", "type": "address" }, + { "internalType": "address", "name": "signer", "type": "address" }, + { "internalType": "address", "name": "to", "type": "address" }, + { "internalType": "bytes4", "name": "func", "type": "bytes4" } ], "name": "AccessController__PermissionDenied", "type": "error" @@ -126,106 +86,50 @@ }, { "inputs": [ - { - "internalType": "address", - "name": "authority", - "type": "address" - } + { "internalType": "address", "name": "authority", "type": "address" } ], "name": "AccessManagedInvalidAuthority", "type": "error" }, { "inputs": [ - { - "internalType": "address", - "name": "caller", - "type": "address" - }, - { - "internalType": "uint32", - "name": "delay", - "type": "uint32" - } + { "internalType": "address", "name": "caller", "type": "address" }, + { "internalType": "uint32", "name": "delay", "type": "uint32" } ], "name": "AccessManagedRequiredDelay", "type": "error" }, { "inputs": [ - { - "internalType": "address", - "name": "caller", - "type": "address" - } + { "internalType": "address", "name": "caller", "type": "address" } ], "name": "AccessManagedUnauthorized", "type": "error" }, { "inputs": [ - { - "internalType": "address", - "name": "target", - "type": "address" - } + { "internalType": "address", "name": "target", "type": "address" } ], "name": "AddressEmptyCode", "type": "error" }, { "inputs": [ - { - "internalType": "address", - "name": "implementation", - "type": "address" - } + { "internalType": "address", "name": "implementation", "type": "address" } ], "name": "ERC1967InvalidImplementation", "type": "error" }, - { - "inputs": [], - "name": "ERC1967NonPayable", - "type": "error" - }, - { - "inputs": [], - "name": "EnforcedPause", - "type": "error" - }, - { - "inputs": [], - "name": "ExpectedPause", - "type": "error" - }, - { - "inputs": [], - "name": "FailedCall", - "type": "error" - }, - { - "inputs": [], - "name": "InvalidInitialization", - "type": "error" - }, - { - "inputs": [], - "name": "NotInitializing", - "type": "error" - }, - { - "inputs": [], - "name": "UUPSUnauthorizedCallContext", - "type": "error" - }, + { "inputs": [], "name": "ERC1967NonPayable", "type": "error" }, + { "inputs": [], "name": "EnforcedPause", "type": "error" }, + { "inputs": [], "name": "ExpectedPause", "type": "error" }, + { "inputs": [], "name": "FailedCall", "type": "error" }, + { "inputs": [], "name": "InvalidInitialization", "type": "error" }, + { "inputs": [], "name": "NotInitializing", "type": "error" }, + { "inputs": [], "name": "UUPSUnauthorizedCallContext", "type": "error" }, { "inputs": [ - { - "internalType": "bytes32", - "name": "slot", - "type": "bytes32" - } + { "internalType": "bytes32", "name": "slot", "type": "bytes32" } ], "name": "UUPSUnsupportedProxiableUUID", "type": "error" @@ -410,23 +314,13 @@ { "inputs": [], "name": "UPGRADE_INTERFACE_VERSION", - "outputs": [ - { - "internalType": "string", - "name": "", - "type": "string" - } - ], + "outputs": [{ "internalType": "string", "name": "", "type": "string" }], "stateMutability": "view", "type": "function" }, { "inputs": [ - { - "internalType": "address", - "name": "accessManager", - "type": "address" - } + { "internalType": "address", "name": "accessManager", "type": "address" } ], "name": "__ProtocolPausable_init", "outputs": [], @@ -436,38 +330,16 @@ { "inputs": [], "name": "authority", - "outputs": [ - { - "internalType": "address", - "name": "", - "type": "address" - } - ], + "outputs": [{ "internalType": "address", "name": "", "type": "address" }], "stateMutability": "view", "type": "function" }, { "inputs": [ - { - "internalType": "address", - "name": "ipAccount", - "type": "address" - }, - { - "internalType": "address", - "name": "signer", - "type": "address" - }, - { - "internalType": "address", - "name": "to", - "type": "address" - }, - { - "internalType": "bytes4", - "name": "func", - "type": "bytes4" - } + { "internalType": "address", "name": "ipAccount", "type": "address" }, + { "internalType": "address", "name": "signer", "type": "address" }, + { "internalType": "address", "name": "to", "type": "address" }, + { "internalType": "bytes4", "name": "func", "type": "bytes4" } ], "name": "checkPermission", "outputs": [], @@ -476,113 +348,43 @@ }, { "inputs": [ - { - "internalType": "address", - "name": "ipAccount", - "type": "address" - }, - { - "internalType": "address", - "name": "signer", - "type": "address" - }, - { - "internalType": "address", - "name": "to", - "type": "address" - }, - { - "internalType": "bytes4", - "name": "func", - "type": "bytes4" - } + { "internalType": "address", "name": "ipAccount", "type": "address" }, + { "internalType": "address", "name": "signer", "type": "address" }, + { "internalType": "address", "name": "to", "type": "address" }, + { "internalType": "bytes4", "name": "func", "type": "bytes4" } ], "name": "getPermanentPermission", - "outputs": [ - { - "internalType": "uint8", - "name": "", - "type": "uint8" - } - ], + "outputs": [{ "internalType": "uint8", "name": "", "type": "uint8" }], "stateMutability": "view", "type": "function" }, { "inputs": [ - { - "internalType": "address", - "name": "ipAccount", - "type": "address" - }, - { - "internalType": "address", - "name": "signer", - "type": "address" - }, - { - "internalType": "address", - "name": "to", - "type": "address" - }, - { - "internalType": "bytes4", - "name": "func", - "type": "bytes4" - } + { "internalType": "address", "name": "ipAccount", "type": "address" }, + { "internalType": "address", "name": "signer", "type": "address" }, + { "internalType": "address", "name": "to", "type": "address" }, + { "internalType": "bytes4", "name": "func", "type": "bytes4" } ], "name": "getPermission", - "outputs": [ - { - "internalType": "uint8", - "name": "", - "type": "uint8" - } - ], + "outputs": [{ "internalType": "uint8", "name": "", "type": "uint8" }], "stateMutability": "view", "type": "function" }, { "inputs": [ - { - "internalType": "address", - "name": "ipAccount", - "type": "address" - }, - { - "internalType": "address", - "name": "signer", - "type": "address" - }, - { - "internalType": "address", - "name": "to", - "type": "address" - }, - { - "internalType": "bytes4", - "name": "func", - "type": "bytes4" - } + { "internalType": "address", "name": "ipAccount", "type": "address" }, + { "internalType": "address", "name": "signer", "type": "address" }, + { "internalType": "address", "name": "to", "type": "address" }, + { "internalType": "bytes4", "name": "func", "type": "bytes4" } ], "name": "getTransientPermission", - "outputs": [ - { - "internalType": "uint8", - "name": "", - "type": "uint8" - } - ], + "outputs": [{ "internalType": "uint8", "name": "", "type": "uint8" }], "stateMutability": "view", "type": "function" }, { "inputs": [ - { - "internalType": "address", - "name": "accessManager", - "type": "address" - } + { "internalType": "address", "name": "accessManager", "type": "address" } ], "name": "initialize", "outputs": [], @@ -592,13 +394,7 @@ { "inputs": [], "name": "isConsumingScheduledOp", - "outputs": [ - { - "internalType": "bytes4", - "name": "", - "type": "bytes4" - } - ], + "outputs": [{ "internalType": "bytes4", "name": "", "type": "bytes4" }], "stateMutability": "view", "type": "function" }, @@ -612,46 +408,22 @@ { "inputs": [], "name": "paused", - "outputs": [ - { - "internalType": "bool", - "name": "", - "type": "bool" - } - ], + "outputs": [{ "internalType": "bool", "name": "", "type": "bool" }], "stateMutability": "view", "type": "function" }, { "inputs": [], "name": "proxiableUUID", - "outputs": [ - { - "internalType": "bytes32", - "name": "", - "type": "bytes32" - } - ], + "outputs": [{ "internalType": "bytes32", "name": "", "type": "bytes32" }], "stateMutability": "view", "type": "function" }, { "inputs": [ - { - "internalType": "address", - "name": "ipAccount", - "type": "address" - }, - { - "internalType": "address", - "name": "signer", - "type": "address" - }, - { - "internalType": "uint8", - "name": "permission", - "type": "uint8" - } + { "internalType": "address", "name": "ipAccount", "type": "address" }, + { "internalType": "address", "name": "signer", "type": "address" }, + { "internalType": "uint8", "name": "permission", "type": "uint8" } ], "name": "setAllPermissions", "outputs": [], @@ -660,21 +432,9 @@ }, { "inputs": [ - { - "internalType": "address", - "name": "ipAccount", - "type": "address" - }, - { - "internalType": "address", - "name": "signer", - "type": "address" - }, - { - "internalType": "uint8", - "name": "permission", - "type": "uint8" - } + { "internalType": "address", "name": "ipAccount", "type": "address" }, + { "internalType": "address", "name": "signer", "type": "address" }, + { "internalType": "uint8", "name": "permission", "type": "uint8" } ], "name": "setAllTransientPermissions", "outputs": [], @@ -683,11 +443,7 @@ }, { "inputs": [ - { - "internalType": "address", - "name": "newAuthority", - "type": "address" - } + { "internalType": "address", "name": "newAuthority", "type": "address" } ], "name": "setAuthority", "outputs": [], @@ -698,31 +454,11 @@ "inputs": [ { "components": [ - { - "internalType": "address", - "name": "ipAccount", - "type": "address" - }, - { - "internalType": "address", - "name": "signer", - "type": "address" - }, - { - "internalType": "address", - "name": "to", - "type": "address" - }, - { - "internalType": "bytes4", - "name": "func", - "type": "bytes4" - }, - { - "internalType": "uint8", - "name": "permission", - "type": "uint8" - } + { "internalType": "address", "name": "ipAccount", "type": "address" }, + { "internalType": "address", "name": "signer", "type": "address" }, + { "internalType": "address", "name": "to", "type": "address" }, + { "internalType": "bytes4", "name": "func", "type": "bytes4" }, + { "internalType": "uint8", "name": "permission", "type": "uint8" } ], "internalType": "struct AccessPermission.Permission[]", "name": "permissions", @@ -738,31 +474,11 @@ "inputs": [ { "components": [ - { - "internalType": "address", - "name": "ipAccount", - "type": "address" - }, - { - "internalType": "address", - "name": "signer", - "type": "address" - }, - { - "internalType": "address", - "name": "to", - "type": "address" - }, - { - "internalType": "bytes4", - "name": "func", - "type": "bytes4" - }, - { - "internalType": "uint8", - "name": "permission", - "type": "uint8" - } + { "internalType": "address", "name": "ipAccount", "type": "address" }, + { "internalType": "address", "name": "signer", "type": "address" }, + { "internalType": "address", "name": "to", "type": "address" }, + { "internalType": "bytes4", "name": "func", "type": "bytes4" }, + { "internalType": "uint8", "name": "permission", "type": "uint8" } ], "internalType": "struct AccessPermission.Permission[]", "name": "permissions", @@ -776,31 +492,11 @@ }, { "inputs": [ - { - "internalType": "address", - "name": "ipAccount", - "type": "address" - }, - { - "internalType": "address", - "name": "signer", - "type": "address" - }, - { - "internalType": "address", - "name": "to", - "type": "address" - }, - { - "internalType": "bytes4", - "name": "func", - "type": "bytes4" - }, - { - "internalType": "uint8", - "name": "permission", - "type": "uint8" - } + { "internalType": "address", "name": "ipAccount", "type": "address" }, + { "internalType": "address", "name": "signer", "type": "address" }, + { "internalType": "address", "name": "to", "type": "address" }, + { "internalType": "bytes4", "name": "func", "type": "bytes4" }, + { "internalType": "uint8", "name": "permission", "type": "uint8" } ], "name": "setPermission", "outputs": [], @@ -809,31 +505,11 @@ }, { "inputs": [ - { - "internalType": "address", - "name": "ipAccount", - "type": "address" - }, - { - "internalType": "address", - "name": "signer", - "type": "address" - }, - { - "internalType": "address", - "name": "to", - "type": "address" - }, - { - "internalType": "bytes4", - "name": "func", - "type": "bytes4" - }, - { - "internalType": "uint8", - "name": "permission", - "type": "uint8" - } + { "internalType": "address", "name": "ipAccount", "type": "address" }, + { "internalType": "address", "name": "signer", "type": "address" }, + { "internalType": "address", "name": "to", "type": "address" }, + { "internalType": "bytes4", "name": "func", "type": "bytes4" }, + { "internalType": "uint8", "name": "permission", "type": "uint8" } ], "name": "setTransientPermission", "outputs": [], @@ -854,11 +530,7 @@ "name": "newImplementation", "type": "address" }, - { - "internalType": "bytes", - "name": "data", - "type": "bytes" - } + { "internalType": "bytes", "name": "data", "type": "bytes" } ], "name": "upgradeToAndCall", "outputs": [], diff --git a/src/story_protocol_python_sdk/resources/IPAsset.py b/src/story_protocol_python_sdk/resources/IPAsset.py index d6e6709..2aea9ff 100644 --- a/src/story_protocol_python_sdk/resources/IPAsset.py +++ b/src/story_protocol_python_sdk/resources/IPAsset.py @@ -33,7 +33,14 @@ RegistrationWorkflowsClient, ) from story_protocol_python_sdk.abi.SPGNFTImpl.SPGNFTImpl_client import SPGNFTImplClient +from story_protocol_python_sdk.types.common import AccessPermission from story_protocol_python_sdk.utils.constants import ZERO_ADDRESS, ZERO_HASH +from story_protocol_python_sdk.utils.derivative_data import ( + DerivativeData, + DerivativeDataInput, +) +from story_protocol_python_sdk.utils.function_signature import get_function_signature +from story_protocol_python_sdk.utils.ip_metadata import IPMetadata, IPMetadataInput from story_protocol_python_sdk.utils.license_terms import LicenseTerms from story_protocol_python_sdk.utils.sign import Sign from story_protocol_python_sdk.utils.transaction_utils import build_and_send_transaction @@ -173,7 +180,7 @@ def register( "signer": self.registration_workflows_client.contract.address, "to": self.core_metadata_module_client.contract.address, "func": "setAll(address,string,bytes32,bytes32)", - "permission": 1, + "permission": AccessPermission.ALLOW, } ], ) @@ -690,21 +697,21 @@ def register_ip_and_attach_pil_terms( "ipId": ip_id, "signer": self.license_attachment_workflows_client.contract.address, "to": self.core_metadata_module_client.contract.address, - "permission": 1, # ALLOW + "permission": AccessPermission.ALLOW, "func": "setAll(address,string,bytes32,bytes32)", }, { "ipId": ip_id, "signer": self.license_attachment_workflows_client.contract.address, "to": self.licensing_module_client.contract.address, - "permission": 1, # ALLOW + "permission": AccessPermission.ALLOW, "func": "attachLicenseTerms(address,address,uint256)", }, { "ipId": ip_id, "signer": self.license_attachment_workflows_client.contract.address, "to": self.licensing_module_client.contract.address, - "permission": 1, # ALLOW + "permission": AccessPermission.ALLOW, "func": "setLicensingConfig(address,address,uint256,(bool,uint256,address,bytes,uint32,bool,uint32,address))", }, ], @@ -764,143 +771,86 @@ def register_ip_and_attach_pil_terms( except Exception as e: raise e - # def register_derivative_ip( - # self, - # nft_contract: str, - # token_id: int, - # deriv_data: dict, - # metadata: dict = None, - # deadline: int = None, - # tx_options: dict = None - # ) -> dict: - # """ - # Register the given NFT as a derivative IP with metadata without using - # license tokens. - - # :param nft_contract str: The address of the NFT collection. - # :param token_id int: The ID of the NFT. - # :param deriv_data dict: The derivative data for registerDerivative. - # :param parentIpIds list: The parent IP IDs. - # :param licenseTemplate str: License template address to be used. - # :param licenseTermsIds list: The license terms IDs. - # :param metadata dict: [Optional] Desired IP metadata. - # :param metadataURI str: [Optional] Metadata URI for the IP. - # :param metadataHash str: [Optional] Metadata hash for the IP. - # :param nftMetadataHash str: [Optional] NFT metadata hash. - # :param deadline int: [Optional] Signature deadline in milliseconds. - # :param tx_options dict: [Optional] Transaction options. - # :return dict: Dictionary with the tx hash and IP ID. - # """ - # try: - # ip_id = self._get_ip_id(nft_contract, token_id) - # if self._is_registered(ip_id): - # raise ValueError( - # f"The NFT with id {token_id} is already registered as IP." - # ) - - # if len(deriv_data['parentIpIds']) != len(deriv_data['licenseTermsIds']): - # raise ValueError( - # "Parent IP IDs and license terms IDs must match in quantity." - # ) - # if len(deriv_data['parentIpIds']) not in [1, 2]: - # raise ValueError("There can only be 1 or 2 parent IP IDs.") - - # for parent_ip_id, license_terms_id in zip( - # deriv_data['parentIpIds'], - # deriv_data['licenseTermsIds'] - # ): - # if not self.license_registry_client.hasIpAttachedLicenseTerms( - # parent_ip_id, - # self.pi_license_template_client.contract.address, - # license_terms_id - # ): - # raise ValueError( - # f"License terms id {license_terms_id} must be attached to " - # f"the parent ipId {parent_ip_id} before registering " - # f"derivative." - # ) - - # calculated_deadline = self._get_deadline(deadline=deadline) - # sig_register_signature = self._get_signature( - # ip_id, - # self.licensing_module_client.contract.address, - # calculated_deadline, - # "registerDerivative(address,address[],uint256[],address,bytes)", - # 2 - # ) - - # req_object = { - # 'nftContract': nft_contract, - # 'tokenId': token_id, - # 'derivData': { - # 'parentIpIds': [ - # self.web3.to_checksum_address(id) - # for id in deriv_data['parentIpIds'] - # ], - # 'licenseTermsIds': deriv_data['licenseTermsIds'], - # 'licenseTemplate': self.pi_license_template_client.contract.address, - # 'royaltyContext': ZERO_ADDRESS, - # }, - # 'sigRegister': { - # 'signer': self.web3.to_checksum_address(self.account.address), - # 'deadline': calculated_deadline, - # 'signature': sig_register_signature, - # }, - # 'metadata': { - # 'metadataURI': "", - # 'metadataHash': ZERO_HASH, - # 'nftMetadataHash': ZERO_HASH, - # }, - # 'sigMetadata': { - # 'signer': ZERO_ADDRESS, - # 'deadline': 0, - # 'signature': ZERO_HASH, - # }, - # } - - # if metadata: - # req_object['metadata'].update({ - # 'metadataURI': metadata.get('metadataURI', ""), - # 'metadataHash': metadata.get('metadataHash', ZERO_HASH), - # 'nftMetadataHash': metadata.get('nftMetadataHash', ZERO_HASH), - # }) - - # signature = self._get_signature( - # ip_id, - # self.core_metadata_module_client.contract.address, - # calculated_deadline, - # "setAll(address,string,bytes32,bytes32)", - # 1 - # ) - - # req_object['sigMetadata'] = { - # 'signer': self.web3.to_checksum_address(self.account.address), - # 'deadline': calculated_deadline, - # 'signature': signature, - # } - - # response = build_and_send_transaction( - # self.web3, - # self.account, - # self.derivative_workflows_client.build_registerIpAndMakeDerivative_transaction, # noqa: E501 - # req_object['nftContract'], - # req_object['tokenId'], - # req_object['derivData'], - # req_object['metadata'], - # req_object['sigMetadata'], - # req_object['sigRegister'], - # tx_options=tx_options - # ) - - # ip_registered = self._parse_tx_ip_registered_event(response['tx_receipt']) - - # return { - # 'tx_hash': response['tx_hash'], - # 'ip_id': ip_registered['ip_id'] - # } - - # except Exception as e: - # raise e + def register_derivative_ip( + self, + nft_contract: str, + token_id: int, + deriv_data: DerivativeDataInput, + metadata: IPMetadataInput | None = None, + deadline: int | None = None, + tx_options: dict | None = None, + ) -> dict: + """ + Register the given NFT as a derivative IP with metadata without using + license tokens. + + :param nft_contract str: The address of the NFT collection. + :param token_id int: The ID of the NFT. + :param deriv_data `DerivativeDataInput`: The derivative data for registerDerivative. + :param metadata `IPMetadataInput`: [Optional] Desired IP metadata. + :param deadline int: [Optional] Signature deadline in milliseconds. + :param tx_options dict: [Optional] Transaction options. + :return dict: Dictionary with the tx hash and IP ID. + """ + try: + ip_id = self._get_ip_id(nft_contract, token_id) + if self._is_registered(ip_id): + raise ValueError( + f"The NFT with id {token_id} is already registered as IP." + ) + validated_deriv_data = DerivativeData.from_input( + web3=self.web3, input_data=deriv_data + ).get_validated_data() + calculated_deadline = self.sign_util.get_deadline(deadline=deadline) + sig_register_signature = self.sign_util.get_permission_signature( + ip_id=ip_id, + deadline=calculated_deadline, + state=Web3.to_bytes(0), + permissions=[ + { + "ipId": ip_id, + "signer": self.derivative_workflows_client.contract.address, + "to": self.core_metadata_module_client.contract.address, + "permission": AccessPermission.ALLOW, + "func": get_function_signature( + self.core_metadata_module_client.contract.abi, + "setAll", + ), + }, + { + "ipId": ip_id, + "signer": self.derivative_workflows_client.contract.address, + "to": self.licensing_module_client.contract.address, + "permission": AccessPermission.ALLOW, + "func": get_function_signature( + self.licensing_module_client.contract.abi, + "registerDerivative", + ), + }, + ], + ) + response = build_and_send_transaction( + self.web3, + self.account, + self.derivative_workflows_client.build_registerIpAndMakeDerivative_transaction, + nft_contract, + token_id, + validated_deriv_data, + IPMetadata.from_input(metadata).get_validated_data(), + { + "signer": self.account.address, + "deadline": calculated_deadline, + "signature": sig_register_signature["signature"], + }, + tx_options=tx_options, + ) + + ip_registered = self._parse_tx_ip_registered_event(response["tx_receipt"]) + + return {"tx_hash": response["tx_hash"], "ip_id": ip_registered["ip_id"]} + + except Exception as e: + raise e def _validate_max_rts(self, max_rts: int): """ diff --git a/src/story_protocol_python_sdk/resources/Permission.py b/src/story_protocol_python_sdk/resources/Permission.py index 845cdc9..10a0d26 100644 --- a/src/story_protocol_python_sdk/resources/Permission.py +++ b/src/story_protocol_python_sdk/resources/Permission.py @@ -13,6 +13,7 @@ IPAssetRegistryClient, ) from story_protocol_python_sdk.resources.IPAccount import IPAccount +from story_protocol_python_sdk.types.common import AccessPermission from story_protocol_python_sdk.utils.constants import DEFAULT_FUNCTION_SELECTOR from story_protocol_python_sdk.utils.sign import Sign from story_protocol_python_sdk.utils.validation import validate_address @@ -42,7 +43,7 @@ def set_permission( ip_id: str, signer: str, to: str, - permission: int, + permission: AccessPermission, func: str = DEFAULT_FUNCTION_SELECTOR, tx_options: dict | None = None, ) -> dict: @@ -56,7 +57,7 @@ def set_permission( :param ip_id str: The IP ID of the IP account that grants the permission for `signer`. :param signer str: The address that can call `to` on behalf of the `ip_id`. :param to str: The address that can be called by the `signer`. - :param permission int: The new permission level. + :param permission `AccessPermission`: The new permission level. :param func str: [Optional] The function selector string of `to` that can be called by the `signer` on behalf of the `ipAccount`. :param tx_options dict: [Optional] The transaction options. :return dict: A dictionary with the transaction hash and success status if waiting for transaction. @@ -74,7 +75,7 @@ def set_permission( self.web3.to_checksum_address(signer), self.web3.to_checksum_address(to), Web3.keccak(text=func)[:4] if func else b"\x00\x00\x00\x00", - permission, + permission.value, ], ) @@ -92,14 +93,18 @@ def set_permission( raise Exception(f"Failed to set permission for IP {ip_id}: {str(e)}") def set_all_permissions( - self, ip_id: str, signer: str, permission: int, tx_options: dict | None = None + self, + ip_id: str, + signer: str, + permission: AccessPermission, + tx_options: dict | None = None, ) -> dict: """ Sets permission to a signer for all functions across all modules. :param ip_id str: The IP ID of the IP account that grants the permission. :param signer str: The address that will receive the permissions. - :param permission int: The new permission level. + :param permission `AccessPermission`: The new permission level. :param tx_options dict: [Optional] The transaction options. :return dict: A dictionary with the transaction hash and success status if waiting for transaction. """ @@ -113,7 +118,7 @@ def set_all_permissions( args=[ self.web3.to_checksum_address(ip_id), self.web3.to_checksum_address(signer), - permission, + permission.value, ], ) @@ -137,7 +142,7 @@ def create_set_permission_signature( ip_id: str, signer: str, to: str, - permission: int, + permission: AccessPermission, func: str = DEFAULT_FUNCTION_SELECTOR, deadline: int | None = None, tx_options: dict | None = None, @@ -148,7 +153,7 @@ def create_set_permission_signature( :param ip_id str: The IP ID of the IP account that grants the permission. :param signer str: The address that can call `to` on behalf of the `ip_id`. :param to str: The address that can be called by the `signer`. - :param permission int: The new permission level. + :param permission `AccessPermission`: The new permission level. :param func str: [Optional] The function selector string. :param deadline int: [Optional] The deadline for the signature validity. :param tx_options dict: [Optional] The transaction options. @@ -174,7 +179,7 @@ def create_set_permission_signature( signer, to, Web3.keccak(text=func)[:4] if func else b"\x00\x00\x00\x00", - permission, + permission.value, ], ) diff --git a/src/story_protocol_python_sdk/scripts/config.json b/src/story_protocol_python_sdk/scripts/config.json index 8f580e4..bdc4395 100644 --- a/src/story_protocol_python_sdk/scripts/config.json +++ b/src/story_protocol_python_sdk/scripts/config.json @@ -5,11 +5,9 @@ "contract_address": "0xcCF37d0a503Ee1D4C11208672e622ed3DFB2275a", "functions": [ "PermissionSet", - "setPermission", "setAllPermissions", - "setBatchPermissions", "setTransientPermission", - "setTransientBatchPermissions" + "setBatchTransientPermissions" ] }, { @@ -230,7 +228,7 @@ { "contract_name": "DerivativeWorkflows", "contract_address": "0x9e2d496f72C547C2C535B167e06ED8729B374a4f", - "functions": [] + "functions": ["registerIpAndMakeDerivative"] } ] } diff --git a/src/story_protocol_python_sdk/types/common.py b/src/story_protocol_python_sdk/types/common.py new file mode 100644 index 0000000..314e219 --- /dev/null +++ b/src/story_protocol_python_sdk/types/common.py @@ -0,0 +1,21 @@ +from enum import Enum + + +class RevShareType(Enum): + COMMERCIAL_REVENUE_SHARE = "commercialRevShare" + MAX_REVENUE_SHARE = "maxRevenueShare" + MAX_ALLOWED_REWARD_SHARE = "maxAllowedRewardShare" + + +class AccessPermission(Enum): + """ + Permission level + """ + + # ABSTAIN means having not enough information to make decision at + # current level, deferred decision to up. + ABSTAIN = 0 + # ALLOW means the permission is granted to transaction signer to call the function. + ALLOW = 1 + # DENY means the permission is denied to transaction signer to call the function. + DENY = 2 diff --git a/src/story_protocol_python_sdk/utils/constants.py b/src/story_protocol_python_sdk/utils/constants.py index e16eb8d..b30f276 100644 --- a/src/story_protocol_python_sdk/utils/constants.py +++ b/src/story_protocol_python_sdk/utils/constants.py @@ -4,6 +4,8 @@ ZERO_HASH: HexStr = HexStr( "0x0000000000000000000000000000000000000000000000000000000000000000" ) -ROYALTY_POLICY = "0xBe54FB168b3c982b7AaE60dB6CF75Bd8447b390E" ZERO_FUNC = "0x00000000" DEFAULT_FUNCTION_SELECTOR = "0x00000000" +MAX_ROYALTY_TOKEN = 100000000 +ROYALTY_POLICY_LAP_ADDRESS = "0xBe54FB168b3c982b7AaE60dB6CF75Bd8447b390E" +ROYALTY_POLICY_LRP_ADDRESS = "0x9156E603C949481883B1D3355C6F1132D191FC41" diff --git a/src/story_protocol_python_sdk/utils/derivative_data.py b/src/story_protocol_python_sdk/utils/derivative_data.py new file mode 100644 index 0000000..47aeacc --- /dev/null +++ b/src/story_protocol_python_sdk/utils/derivative_data.py @@ -0,0 +1,159 @@ +from dataclasses import dataclass, field +from typing import List, Optional + +from ens.ens import Address +from web3 import Web3 + +from story_protocol_python_sdk.abi.IPAssetRegistry.IPAssetRegistry_client import ( + IPAssetRegistryClient, +) +from story_protocol_python_sdk.abi.LicenseRegistry.LicenseRegistry_client import ( + LicenseRegistryClient, +) +from story_protocol_python_sdk.abi.PILicenseTemplate.PILicenseTemplate_client import ( + PILicenseTemplateClient, +) +from story_protocol_python_sdk.types.common import RevShareType +from story_protocol_python_sdk.utils.constants import MAX_ROYALTY_TOKEN, ZERO_ADDRESS +from story_protocol_python_sdk.utils.validation import get_revenue_share + + +@dataclass +class DerivativeDataInput: + """ + Input data structure for creating derivative IP assets. + + This type defines the data that users need to provide when creating derivative works. + + Attributes: + parent_ip_ids: List of parent IP asset addresses that this derivative is based on. + license_terms_ids: List of license terms IDs corresponding to each parent IP. + max_minting_fee: [Optional] The maximum minting fee that the caller is willing to pay. if set to 0 then no limit. (default: 0). + max_rts: [Optional] The maximum number of royalty tokens that can be distributed to the external royalty policies. (max: 100,000,000) (default: 100,000,000). + max_revenue_share: [Optional] The maximum revenue share percentage allowed for minting the License Tokens. Must be between 0 and 100 (where 100% represents 100_000_000) (default: 100). + license_template: [Optional] The address of the license template. Defaults to [License Template](https://docs.story.foundation/docs/programmable-ip-license) address if not provided + """ + + parent_ip_ids: List[Address] + license_terms_ids: List[int] + max_minting_fee: int | float = field(default=0) + max_rts: int | float = field(default=MAX_ROYALTY_TOKEN) + max_revenue_share: int = field(default=100) + license_template: Optional[Address] = field(default=None) + + +@dataclass +class DerivativeData: + """Validated derivative data for IP creation.""" + + web3: Web3 + parent_ip_ids: List[str] + license_terms_ids: List[int] + max_minting_fee: int | float + max_rts: int | float + max_revenue_share: int + license_template: Optional[str] + + pi_license_template_client: PILicenseTemplateClient = field(init=False) + ip_asset_registry_client: IPAssetRegistryClient = field(init=False) + license_registry_client: LicenseRegistryClient = field(init=False) + + @classmethod + def from_input( + cls, web3: Web3, input_data: DerivativeDataInput + ) -> "DerivativeData": + """ + Create a DerivativeData instance from DerivativeDataInput. + + Args: + web3: Web3 instance for blockchain interaction + input_data: User-provided derivative data + + Returns: + DerivativeData instance with validated data + """ + return cls( + web3=web3, + parent_ip_ids=input_data.parent_ip_ids, + license_terms_ids=input_data.license_terms_ids, + max_minting_fee=input_data.max_minting_fee, + max_rts=input_data.max_rts, + max_revenue_share=input_data.max_revenue_share, + license_template=input_data.license_template, + ) + + def __post_init__(self): + """Initialize clients and validate data after object creation.""" + + self.pi_license_template_client = PILicenseTemplateClient(self.web3) + self.ip_asset_registry_client = IPAssetRegistryClient(self.web3) + self.license_registry_client = LicenseRegistryClient(self.web3) + + if self.license_template is None: + self.license_template = self.pi_license_template_client.contract.address + self.max_revenue_share = get_revenue_share( + self.max_revenue_share, type=RevShareType.MAX_REVENUE_SHARE + ) + + self.validate_max_minting_fee() + self.validate_max_rts() + self.validate_parent_ip_ids_and_license_terms_ids() + + def validate_parent_ip_ids_and_license_terms_ids(self): + if len(self.parent_ip_ids) == 0: + raise ValueError("The parent IP IDs must be provided.") + + if len(self.license_terms_ids) == 0: + raise ValueError("The license terms IDs must be provided.") + + if len(self.parent_ip_ids) != len(self.license_terms_ids): + raise ValueError( + "The number of parent IP IDs must match the number of license terms IDs." + ) + + total_royalty_percent = 0 + for parent_ip_id, license_terms_id in zip( + self.parent_ip_ids, self.license_terms_ids + ): + if not Web3.is_checksum_address(parent_ip_id): + raise ValueError("The parent IP ID must be a valid address.") + if not self.ip_asset_registry_client.isRegistered(parent_ip_id): + raise ValueError(f"The parent IP ID {parent_ip_id} must be registered.") + if not self.license_registry_client.hasIpAttachedLicenseTerms( + parent_ip_id, self.license_template, license_terms_id + ): + raise ValueError( + f"License terms id {license_terms_id} must be attached to the parent ipId {parent_ip_id} before registering derivative." + ) + royalty_percent = self.license_registry_client.getRoyaltyPercent( + parent_ip_id, self.license_template, license_terms_id + ) + total_royalty_percent += royalty_percent + if ( + self.max_revenue_share != 0 + and total_royalty_percent > self.max_revenue_share + ): + raise ValueError( + f"The total royalty percent for the parent IP {parent_ip_id} is greater than the maximum revenue share {self.max_revenue_share}." + ) + + def validate_max_minting_fee(self): + if self.max_minting_fee < 0: + raise ValueError("The max minting fee must be greater than 0.") + + def validate_max_rts(self): + if self.max_rts < 0 or self.max_rts > MAX_ROYALTY_TOKEN: + raise ValueError( + f"The maxRts must be greater than 0 and less than {MAX_ROYALTY_TOKEN}." + ) + + def get_validated_data(self) -> dict: + return { + "parentIpIds": self.parent_ip_ids, + "licenseTermsIds": self.license_terms_ids, + "maxMintingFee": self.max_minting_fee, + "maxRts": self.max_rts, + "maxRevenueShare": self.max_revenue_share, + "licenseTemplate": self.license_template, + "royaltyContext": ZERO_ADDRESS, + } diff --git a/src/story_protocol_python_sdk/utils/function_signature.py b/src/story_protocol_python_sdk/utils/function_signature.py new file mode 100644 index 0000000..8524cce --- /dev/null +++ b/src/story_protocol_python_sdk/utils/function_signature.py @@ -0,0 +1,57 @@ +from typing import Any, Dict, List + + +def get_function_signature( + abi: List[Dict[str, Any]], + method_name: str, +) -> str: + """ + Gets the function signature from an ABI for a given method name. + + Args: + abi: The contract ABI as a list of dictionaries + method_name: The name of the method to get the signature for + + Returns: + The function signature in standard format (e.g. "methodName(uint256,address)") + """ + + # Filter functions by name and type + functions = [ + item + for item in abi + if item.get("type") == "function" and item.get("name") == method_name + ] + + if len(functions) == 0: + raise ValueError(f"Method {method_name} not found in ABI.") + + # Get the target function + func = functions[0] + + def get_type_string(input_param: Dict[str, Any]) -> str: + """ + Recursively get the type string for a parameter. + + Args: + input_param: The ABI parameter as a dictionary + + Returns: + The type string representation + """ + param_type = input_param["type"] + + if param_type.startswith("tuple"): + components = input_param.get("components", []) + if components: + component_types = ",".join(get_type_string(comp) for comp in components) + return f"({component_types})" + else: + return "()" # Empty tuple + return param_type + + # Build the function signature + inputs = ",".join( + get_type_string(input_param) for input_param in func.get("inputs", []) + ) + return f"{method_name}({inputs})" diff --git a/src/story_protocol_python_sdk/utils/ip_metadata.py b/src/story_protocol_python_sdk/utils/ip_metadata.py new file mode 100644 index 0000000..4242d79 --- /dev/null +++ b/src/story_protocol_python_sdk/utils/ip_metadata.py @@ -0,0 +1,79 @@ +from dataclasses import dataclass, field + +from ens.ens import HexStr + +from story_protocol_python_sdk.utils.constants import ZERO_HASH + + +@dataclass +class IPMetadataInput: + """ + Input data structure for IP metadata. + + This type defines the data that users need to provide when setting IP metadata. + + Attributes: + ip_metadata_uri: [Optional] URI for IP metadata (default: ""). + ip_metadata_hash: [Optional] Hash for IP metadata (default: ZERO_HASH). + nft_metadata_uri: [Optional] URI for NFT metadata (default: ""). + nft_metadata_hash: [Optional] Hash for NFT metadata (default: ZERO_HASH). + """ + + ip_metadata_uri: str = field(default="") + ip_metadata_hash: HexStr = field(default=ZERO_HASH) + nft_metadata_uri: str = field(default="") + nft_metadata_hash: HexStr = field(default=ZERO_HASH) + + +@dataclass +class IPMetadata: + """Validated IP metadata for IP asset operations.""" + + ip_metadata_uri: str + ip_metadata_hash: HexStr + nft_metadata_uri: str + nft_metadata_hash: HexStr + + @classmethod + def from_input(cls, input_data: IPMetadataInput | None = None) -> "IPMetadata": + """ + Create an IPMetadata instance from IPMetadataInput. + + Args: + input_data: User-provided IP metadata + + Returns: + IPMetadata instance with validated data + """ + if input_data is None: + return cls( + ip_metadata_uri="", + ip_metadata_hash=ZERO_HASH, + nft_metadata_uri="", + nft_metadata_hash=ZERO_HASH, + ) + + return cls( + ip_metadata_uri=input_data.ip_metadata_uri, + ip_metadata_hash=input_data.ip_metadata_hash, + nft_metadata_uri=input_data.nft_metadata_uri, + nft_metadata_hash=input_data.nft_metadata_hash, + ) + + def __post_init__(self): + """Validate data after object creation.""" + self.get_validated_data() + + def get_validated_data(self) -> dict: + """ + Get the metadata as a dictionary in the format expected by the blockchain. + + Returns: + Dictionary with validated metadata fields + """ + return { + "ipMetadataURI": self.ip_metadata_uri, + "ipMetadataHash": self.ip_metadata_hash, + "nftMetadataURI": self.nft_metadata_uri, + "nftMetadataHash": self.nft_metadata_hash, + } diff --git a/src/story_protocol_python_sdk/utils/license_terms.py b/src/story_protocol_python_sdk/utils/license_terms.py index 05cd049..f660f1a 100644 --- a/src/story_protocol_python_sdk/utils/license_terms.py +++ b/src/story_protocol_python_sdk/utils/license_terms.py @@ -6,7 +6,10 @@ from story_protocol_python_sdk.abi.RoyaltyModule.RoyaltyModule_client import ( RoyaltyModuleClient, ) -from story_protocol_python_sdk.utils.constants import ROYALTY_POLICY, ZERO_ADDRESS +from story_protocol_python_sdk.utils.constants import ( + ROYALTY_POLICY_LAP_ADDRESS, + ZERO_ADDRESS, +) class LicenseTerms: @@ -51,7 +54,7 @@ def get_license_term_by_type(self, type, term=None): ) if term["royaltyPolicyAddress"] is None: - term["royaltyPolicyAddress"] = ROYALTY_POLICY + term["royaltyPolicyAddress"] = ROYALTY_POLICY_LAP_ADDRESS license_terms.update( { diff --git a/src/story_protocol_python_sdk/utils/sign.py b/src/story_protocol_python_sdk/utils/sign.py index 2e63bd6..e5be696 100644 --- a/src/story_protocol_python_sdk/utils/sign.py +++ b/src/story_protocol_python_sdk/utils/sign.py @@ -46,7 +46,6 @@ def get_signature( execute_data = self.ip_account_client.contract.encode_abi( abi_element_identifier="execute", args=[to, 0, encode_data] ) - # expected_state = nonce expected_state = Web3.keccak( encode( @@ -91,7 +90,7 @@ def get_signature( signed_message = Account.sign_message(signable_message, self.account.key) return { - "signature": signed_message.signature.hex(), + "signature": "0x" + signed_message.signature.hex(), "nonce": expected_state, } @@ -159,7 +158,7 @@ def get_permission_signature( if permissions[0].get("func") else b"\x00\x00\x00\x00" ), - permissions[0]["permission"], + permissions[0]["permission"].value, ], ) else: @@ -175,7 +174,7 @@ def get_permission_signature( if p.get("func") else b"\x00\x00\x00\x00" ), - "permission": p["permission"], + "permission": p["permission"].value, } formatted_permissions.append(formatted_permission) diff --git a/src/story_protocol_python_sdk/utils/validation.py b/src/story_protocol_python_sdk/utils/validation.py index fdaaee1..5d0155b 100644 --- a/src/story_protocol_python_sdk/utils/validation.py +++ b/src/story_protocol_python_sdk/utils/validation.py @@ -1,5 +1,8 @@ from web3 import Web3 +from story_protocol_python_sdk.types.common import RevShareType +from story_protocol_python_sdk.utils.constants import MAX_ROYALTY_TOKEN + def validate_address(address: str) -> str: """ @@ -12,3 +15,20 @@ def validate_address(address: str) -> str: if not Web3.is_address(address): raise ValueError(f"Invalid address: {address}.") return address + + +def get_revenue_share( + revShare: int, + type: RevShareType = RevShareType.COMMERCIAL_REVENUE_SHARE, +) -> int: + """ + Convert revenue share percentage to token amount. + + :param revShare int: Revenue share percentage between 0-100 + :param type RevShareType: Type of revenue share + :return int: Revenue share token amount + """ + if revShare < 0 or revShare > 100: + raise ValueError(f"The {type.value} must be between 0 and 100.") + + return (revShare * MAX_ROYALTY_TOKEN) // 100 diff --git a/tests/integration/test_integration_ip_asset.py b/tests/integration/test_integration_ip_asset.py index ff9a5a0..26dd29d 100644 --- a/tests/integration/test_integration_ip_asset.py +++ b/tests/integration/test_integration_ip_asset.py @@ -3,6 +3,8 @@ import pytest from story_protocol_python_sdk.story_client import StoryClient +from story_protocol_python_sdk.utils.derivative_data import DerivativeDataInput +from story_protocol_python_sdk.utils.ip_metadata import IPMetadataInput from .setup_for_integration import ( PIL_LICENSE_TEMPLATE, @@ -280,7 +282,7 @@ def parent_ip_and_license_terms(self, story_client: StoryClient, nft_collection) "commercial_attribution": False, "commercializer_checker": ZERO_ADDRESS, "commercializer_checker_data": ZERO_ADDRESS, - "commercial_rev_share": 90, + "commercial_rev_share": 50, "commercial_rev_ceiling": 0, "derivatives_allowed": True, "derivatives_attribution": True, @@ -327,25 +329,76 @@ def parent_ip_and_license_terms(self, story_client: StoryClient, nft_collection) # assert isinstance(response['ip_id'], str) # assert response['ip_id'] != '' - # def test_register_derivative_ip(self, story_client, parent_ip_id, license_terms_id): - # token_child_id = mint_by_spg(MockERC721, story_client.web3, story_client.account) - - # result = story_client.IPAsset.register_derivative_ip( - # nft_contract=MockERC721, - # token_id=token_child_id, - # deriv_data={ - # 'parentIpIds': [parent_ip_id], - # 'licenseTermsIds': [license_terms_id], - # 'maxMintingFee': 0, - # 'maxRts': 5 * 10**6, - # 'maxRevenueShare': 0 - # }, - # deadline=1000, - # tx_options={'waitForTransaction': True} - # ) + def test_register_derivative_ip( + self, story_client: StoryClient, parent_ip_and_license_terms, nft_collection + ): + token_child_id = mint_by_spg( + nft_collection, story_client.web3, story_client.account + ) + # Register another IP asset with PIL terms + second_ip_id_response = ( + story_client.IPAsset.mint_and_register_ip_asset_with_pil_terms( + spg_nft_contract=nft_collection, + terms=[ + { + "terms": { + "transferable": True, + "royalty_policy": ROYALTY_POLICY, + "default_minting_fee": 0, + "expiration": 0, + "commercial_use": True, + "commercial_attribution": False, + "commercializer_checker": ZERO_ADDRESS, + "commercializer_checker_data": ZERO_ADDRESS, + "commercial_rev_share": 50, + "commercial_rev_ceiling": 0, + "derivatives_allowed": True, + "derivatives_attribution": True, + "derivatives_approval": False, + "derivatives_reciprocal": True, + "derivative_rev_ceiling": 0, + "currency": MockERC20, + "uri": "", + }, + "licensing_config": { + "is_set": True, + "minting_fee": 0, + "hook_data": ZERO_ADDRESS, + "licensing_hook": ZERO_ADDRESS, + "commercial_rev_share": 0, + "disabled": False, + "expect_minimum_group_reward_share": 0, + "expect_group_reward_pool": ZERO_ADDRESS, + }, + } + ], + allow_duplicates=True, + ) + ) - # assert isinstance(result['tx_hash'], str) and result['tx_hash'] - # assert isinstance(result['ip_id'], str) and result['ip_id'] + result = story_client.IPAsset.register_derivative_ip( + nft_contract=nft_collection, + token_id=token_child_id, + deriv_data=DerivativeDataInput( + parent_ip_ids=[ + parent_ip_and_license_terms["parent_ip_id"], + second_ip_id_response["ip_id"], + ], + license_terms_ids=[ + parent_ip_and_license_terms["license_terms_id"], + second_ip_id_response["license_terms_ids"][0], + ], + ), + metadata=IPMetadataInput( + nft_metadata_uri="https://ipfs.io/ipfs/Qm...", + nft_metadata_hash=web3.to_hex( + web3.keccak(text="test-nft-metadata-hash") + ), + ), + deadline=1000, + ) + assert isinstance(result["tx_hash"], str) and result["tx_hash"] + assert isinstance(result["ip_id"], str) and result["ip_id"] def test_register_ip_and_attach_pil_terms( self, story_client: StoryClient, nft_collection, parent_ip_and_license_terms diff --git a/tests/integration/test_integration_permission.py b/tests/integration/test_integration_permission.py index ae0e06e..2f8c729 100644 --- a/tests/integration/test_integration_permission.py +++ b/tests/integration/test_integration_permission.py @@ -3,6 +3,7 @@ import pytest from story_protocol_python_sdk.story_client import StoryClient +from story_protocol_python_sdk.types.common import AccessPermission from .setup_for_integration import ( CORE_METADATA_MODULE, @@ -30,7 +31,7 @@ def test_set_permission(self, story_client: StoryClient, ip_id): ip_id=ip_id, signer=account.address, to=CORE_METADATA_MODULE, - permission=1, # ALLOW + permission=AccessPermission.ALLOW, func="function setAll(address,string,bytes32,bytes32)", ) @@ -42,7 +43,7 @@ def test_set_permission(self, story_client: StoryClient, ip_id): def test_set_all_permissions(self, story_client: StoryClient, ip_id): """Test setting all permissions successfully.""" response = story_client.Permission.set_all_permissions( - ip_id=ip_id, signer=account.address, permission=1 # ALLOW + ip_id=ip_id, signer=account.address, permission=AccessPermission.ALLOW ) assert response is not None @@ -59,7 +60,7 @@ def test_create_set_permission_signature(self, story_client: StoryClient, ip_id) signer=account.address, to=CORE_METADATA_MODULE, func="setAll(address,string,bytes32,bytes32)", - permission=1, # ALLOW + permission=AccessPermission.ALLOW, deadline=deadline, ) @@ -77,7 +78,7 @@ def test_set_permission_invalid_ip(self, story_client: StoryClient): ip_id=unregistered_ip, signer=account.address, to=CORE_METADATA_MODULE, - permission=1, + permission=AccessPermission.ALLOW, ) assert f"IP id with {unregistered_ip} is not registered" in str(exc_info.value) @@ -91,7 +92,7 @@ def test_set_permission_invalid_addresses(self, story_client: StoryClient, ip_id ip_id=ip_id, signer=invalid_signer, to=CORE_METADATA_MODULE, - permission=1, # ALLOW + permission=AccessPermission.ALLOW, ) assert "invalid address" in str(exc_info.value).lower() @@ -103,7 +104,7 @@ def test_set_permission_invalid_addresses(self, story_client: StoryClient, ip_id ip_id=ip_id, signer=account.address, to=invalid_to, - permission=1, # ALLOW + permission=AccessPermission.ALLOW, ) assert "invalid address" in str(exc_info.value).lower() @@ -114,7 +115,7 @@ def test_set_permission_invalid_addresses(self, story_client: StoryClient, ip_id ip_id=ip_id, signer=lowercase_address, to=CORE_METADATA_MODULE, - permission=1, + permission=AccessPermission.ALLOW, ) assert "tx_hash" in response except Exception as e: @@ -124,15 +125,11 @@ def test_set_permission_invalid_addresses(self, story_client: StoryClient, ip_id def test_different_permission_levels(self, story_client: StoryClient, ip_id): """Test setting and changing different permission levels.""" - DISALLOW = 0 - ALLOW = 1 - ABSTAIN = 2 - response = story_client.Permission.set_permission( ip_id=ip_id, signer=account.address, to=CORE_METADATA_MODULE, - permission=DISALLOW, + permission=AccessPermission.ABSTAIN, func="function setAll(address,string,bytes32,bytes32)", ) @@ -145,7 +142,7 @@ def test_different_permission_levels(self, story_client: StoryClient, ip_id): ip_id=ip_id, signer=account.address, to=CORE_METADATA_MODULE, - permission=ALLOW, + permission=AccessPermission.ALLOW, func="function setAll(address,string,bytes32,bytes32)", ) @@ -156,7 +153,7 @@ def test_different_permission_levels(self, story_client: StoryClient, ip_id): ip_id=ip_id, signer=account.address, to=CORE_METADATA_MODULE, - permission=ABSTAIN, + permission=AccessPermission.DENY, func="function setAll(address,string,bytes32,bytes32)", ) @@ -164,14 +161,14 @@ def test_different_permission_levels(self, story_client: StoryClient, ip_id): assert "tx_hash" in response response = story_client.Permission.set_all_permissions( - ip_id=ip_id, signer=account.address, permission=DISALLOW + ip_id=ip_id, signer=account.address, permission=AccessPermission.ABSTAIN ) assert response is not None assert "tx_hash" in response response = story_client.Permission.set_all_permissions( - ip_id=ip_id, signer=account.address, permission=ABSTAIN + ip_id=ip_id, signer=account.address, permission=AccessPermission.DENY ) assert response is not None @@ -179,13 +176,11 @@ def test_different_permission_levels(self, story_client: StoryClient, ip_id): def test_different_function_selectors(self, story_client: StoryClient, ip_id): """Test setting permissions with different function selectors.""" - ALLOW = 1 - response = story_client.Permission.set_permission( ip_id=ip_id, signer=account.address, to=CORE_METADATA_MODULE, - permission=1, + permission=AccessPermission.ALLOW, # No func parameter provided - should use default ) @@ -198,7 +193,7 @@ def test_different_function_selectors(self, story_client: StoryClient, ip_id): ip_id=ip_id, signer=account.address, to=CORE_METADATA_MODULE, - permission=ALLOW, + permission=AccessPermission.ALLOW, func="setAll(address,string,bytes32,bytes32)", ) @@ -209,7 +204,7 @@ def test_different_function_selectors(self, story_client: StoryClient, ip_id): ip_id=ip_id, signer=account.address, to=CORE_METADATA_MODULE, - permission=ALLOW, + permission=AccessPermission.ALLOW, func="setName(address,string)", ) @@ -220,7 +215,7 @@ def test_different_function_selectors(self, story_client: StoryClient, ip_id): ip_id=ip_id, signer=account.address, to=CORE_METADATA_MODULE, - permission=ALLOW, + permission=AccessPermission.ALLOW, func="setDescription(address,string)", ) @@ -232,7 +227,7 @@ def test_different_function_selectors(self, story_client: StoryClient, ip_id): ip_id=ip_id, signer=account.address, to=CORE_METADATA_MODULE, - permission=ALLOW, + permission=AccessPermission.ALLOW, # No func parameter provided deadline=deadline, ) @@ -244,12 +239,8 @@ def test_permission_hierarchies_and_overrides( self, story_client: StoryClient, ip_id ): """Test permission hierarchies and how permissions override each other.""" - DISALLOW = 0 - ALLOW = 1 - ABSTAIN = 2 - response = story_client.Permission.set_all_permissions( - ip_id=ip_id, signer=account.address, permission=DISALLOW + ip_id=ip_id, signer=account.address, permission=AccessPermission.ABSTAIN ) assert response is not None @@ -260,7 +251,7 @@ def test_permission_hierarchies_and_overrides( ip_id=ip_id, signer=account.address, to=CORE_METADATA_MODULE, - permission=ALLOW, + permission=AccessPermission.ALLOW, func=specific_func, ) @@ -270,7 +261,9 @@ def test_permission_hierarchies_and_overrides( alternate_signer = web3.eth.account.create() response = story_client.Permission.set_all_permissions( - ip_id=ip_id, signer=alternate_signer.address, permission=ALLOW + ip_id=ip_id, + signer=alternate_signer.address, + permission=AccessPermission.ALLOW, ) assert response is not None @@ -280,7 +273,7 @@ def test_permission_hierarchies_and_overrides( ip_id=ip_id, signer=alternate_signer.address, to=CORE_METADATA_MODULE, - permission=DISALLOW, + permission=AccessPermission.ABSTAIN, func=specific_func, ) @@ -293,7 +286,7 @@ def test_permission_hierarchies_and_overrides( ip_id=ip_id, signer=account.address, to=CORE_METADATA_MODULE, - permission=ALLOW, + permission=AccessPermission.ALLOW, func="setDescription(address,string)", deadline=deadline, ) @@ -302,14 +295,16 @@ def test_permission_hierarchies_and_overrides( assert "tx_hash" in response response = story_client.Permission.set_all_permissions( - ip_id=ip_id, signer=account.address, permission=ABSTAIN + ip_id=ip_id, signer=account.address, permission=AccessPermission.DENY ) assert response is not None assert "tx_hash" in response response = story_client.Permission.set_all_permissions( - ip_id=ip_id, signer=alternate_signer.address, permission=ABSTAIN + ip_id=ip_id, + signer=alternate_signer.address, + permission=AccessPermission.DENY, ) assert response is not None diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 0000000..08b7793 --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,152 @@ +from unittest.mock import MagicMock, Mock, patch + +import pytest +from eth_account import Account +from web3 import Web3 + +from tests.unit.fixtures.data import ADDRESS, TX_HASH + + +@pytest.fixture(scope="package") +def mock_account(): + account = MagicMock() + account.address = "0xF60cBF0Ea1A61567F1dDaf79A6219D20d189155c" + # Create a mock signed transaction object with raw_transaction attribute + mock_signed_txn = MagicMock() + mock_signed_txn.raw_transaction = b"raw_transaction_bytes" + + account.sign_transaction = MagicMock(return_value=mock_signed_txn) + account.sign_message = MagicMock(return_value=b"mock_signature") + return account + + +@pytest.fixture(scope="package") +def mock_web3(): + mock_web3 = Mock(spec=Web3) + mock_web3.to_checksum_address = MagicMock(return_value=ADDRESS) + mock_web3.to_bytes = MagicMock(return_value=b"mock_bytes") + + # Add eth attribute with contract method + mock_eth = Mock() + + # Create a function that returns a new mock contract each time + def create_mock_contract(*args, **kwargs): + """Create a new mock contract instance with address""" + mock_contract = Mock() + mock_contract.address = ADDRESS + mock_contract.encode_abi = MagicMock(return_value="0x00") + return mock_contract + + # Set up the contract method to return new mock contracts + mock_eth.contract = create_mock_contract + mock_web3.eth = mock_eth + mock_web3.eth.get_transaction_count = MagicMock(return_value=0) + mock_web3.eth.send_raw_transaction = MagicMock(return_value=TX_HASH) + mock_web3.eth.wait_for_transaction_receipt = MagicMock( + return_value={"status": 1, "logs": []} + ) + return mock_web3 + + +@pytest.fixture(scope="package") +def mock_is_checksum_address(): + def _mock(is_checksum_address: bool = True): + return patch.object( + Web3, "is_checksum_address", return_value=is_checksum_address + ) + + return _mock + + +@pytest.fixture(scope="package") +def mock_signature_related_methods(): + class SignatureMockContext: + def __init__(self): + self.patches = [] + + def __enter__(self): + # Mock the IPAccountImplClient constructor and its contract.encode_abi method + mock_client = MagicMock() + mock_contract = MagicMock() + mock_contract.encode_abi = MagicMock(return_value=b"encoded_data") + mock_client.contract = mock_contract + + # Create all the patches + mock_web3_to_bytes = patch.object( + Web3, "to_bytes", return_value=b"mock_bytes" + ) + mock_account_sign_message = patch.object( + Account, + "sign_message", + return_value=MagicMock(signature=b"mock_signature"), + ) + + # Create a mock class that behaves like IPAccountImplClient + class MockIPAccountImplClient: + def __init__(self, web3, contract_address=None): + self.web3 = web3 + self.contract_address = contract_address + self.contract = mock_contract + + # Patch the class to return our mock instance + mock_ip_account_client = patch( + "story_protocol_python_sdk.abi.IPAccountImpl.IPAccountImpl_client.IPAccountImplClient", + MockIPAccountImplClient, + ) + + # Apply all patches at once + mock_web3_to_bytes.start() + mock_account_sign_message.start() + mock_ip_account_client.start() + + # Store patches for cleanup + self.patches = [ + mock_web3_to_bytes, + mock_account_sign_message, + mock_ip_account_client, + ] + + def __exit__(self, exc_type, exc_val, exc_tb): + # Stop all patches in reverse order + for patch_obj in reversed(self.patches): + patch_obj.stop() + + return SignatureMockContext + + +@pytest.fixture(scope="package") +def mock_license_registry_client(): + """Fixture to mock LicenseRegistryClient for derivative data validation""" + + def _mock(): + # Create a mock that returns a proper value for getRoyaltyPercent + mock_client = MagicMock() + mock_client.hasIpAttachedLicenseTerms = MagicMock(return_value=True) + mock_client.getRoyaltyPercent = MagicMock(return_value=10) + + # Patch both IPAsset and derivative_data modules + patch1 = patch( + "story_protocol_python_sdk.resources.IPAsset.LicenseRegistryClient", + return_value=mock_client, + ) + patch2 = patch( + "story_protocol_python_sdk.utils.derivative_data.LicenseRegistryClient", + return_value=mock_client, + ) + + # Start both patches + patch1.start() + patch2.start() + + # Return a context manager that stops both patches + class MockContext: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + patch1.stop() + patch2.stop() + + return MockContext() + + return _mock diff --git a/tests/unit/fixtures/data.py b/tests/unit/fixtures/data.py index df5ad3f..42203f7 100644 --- a/tests/unit/fixtures/data.py +++ b/tests/unit/fixtures/data.py @@ -1,5 +1,6 @@ CHAIN_ID = 1315 ADDRESS = "0x1234567890123456789012345678901234567890" -TX_HASH = "0x0c0cce07beb64ccfbdd59da111f23084ab7c9e96a951f7381af49e792d014c04" +TX_HASH = b"tx_hash_bytes" # STATE as bytes32 (32 bytes = 64 hex characters) STATE = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +IP_ID = "0xFEB4eE75600768635010D80D56a5711268D26DaB" diff --git a/tests/unit/fixtures/web3.py b/tests/unit/fixtures/web3.py deleted file mode 100644 index 6817f8e..0000000 --- a/tests/unit/fixtures/web3.py +++ /dev/null @@ -1,25 +0,0 @@ -from unittest.mock import MagicMock, Mock - -from web3 import Web3 - -from tests.unit.fixtures.data import ADDRESS - -mock_web3 = Mock(spec=Web3) -mock_web3.to_checksum_address = MagicMock(return_value=ADDRESS) - -# Add eth attribute with contract method -mock_eth = Mock() - - -# Create a function that returns a new mock contract each time -def create_mock_contract(*args, **kwargs): - """Create a new mock contract instance with address""" - mock_contract = Mock() - mock_contract.address = ADDRESS - mock_contract.encode_abi = MagicMock(return_value="0x00") - return mock_contract - - -# Set up the contract method to return new mock contracts -mock_eth.contract = create_mock_contract -mock_web3.eth = mock_eth diff --git a/tests/unit/resources/test_ip_asset.py b/tests/unit/resources/test_ip_asset.py index 8fe3917..45c7e02 100644 --- a/tests/unit/resources/test_ip_asset.py +++ b/tests/unit/resources/test_ip_asset.py @@ -1,179 +1,174 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest -from eth_utils import is_address, to_checksum_address -from web3 import Web3 from story_protocol_python_sdk.resources.IPAsset import IPAsset +from story_protocol_python_sdk.utils.constants import ZERO_HASH +from story_protocol_python_sdk.utils.derivative_data import DerivativeDataInput +from tests.unit.fixtures.data import ADDRESS, CHAIN_ID, IP_ID, TX_HASH -ZERO_HASH = "0x0000000000000000000000000000000000000000000000000000000000000000" -ZERO_ADDRESS = "0x0000000000000000000000000000000000000000" +@pytest.fixture(scope="class") +def ip_asset(mock_web3, mock_account): + return IPAsset(mock_web3, mock_account, CHAIN_ID) -class MockWeb3: - def __init__(self): - self.eth = MagicMock() - - @staticmethod - def to_checksum_address(address): - if not is_address(address): - raise ValueError(f"Invalid address: {address}") - return to_checksum_address(address) - @staticmethod - def to_bytes(hexstr=None, **kwargs): - return Web3.to_bytes(hexstr=hexstr, **kwargs) +@pytest.fixture(scope="class") +def mock_get_ip_id(ip_asset): + def _mock(): + return patch.object( + ip_asset.ip_asset_registry_client, "ipId", return_value=IP_ID + ) - @staticmethod - def to_wei(number, unit): - return Web3.to_wei(number, unit) + return _mock - @staticmethod - def is_address(address): - return is_address(address) - @staticmethod - def keccak(text=None): - return Web3.keccak(text=text) +@pytest.fixture(scope="class") +def mock_is_registered(ip_asset): + def _mock(is_registered: bool = False): + return patch.object( + ip_asset.ip_asset_registry_client, + "isRegistered", + return_value=is_registered, + ) - def is_connected(self): - return True + return _mock -@pytest.fixture -def mock_web3(): - return MockWeb3() +@pytest.fixture(scope="class") +def mock_parse_ip_registered_event(ip_asset): + def _mock(): + return patch.object( + ip_asset, "_parse_tx_ip_registered_event", return_value={"ip_id": IP_ID} + ) + return _mock -@pytest.fixture -def mock_account(): - account = MagicMock() - account.address = "0xF60cBF0Ea1A61567F1dDaf79A6219D20d189155c" - return account +@pytest.fixture(scope="class") +def mock_get_function_signature(): + def _mock(): + return patch( + "story_protocol_python_sdk.resources.IPAsset.get_function_signature", + return_value="setAll(address,string,bytes32,bytes32)", + ) -@pytest.fixture -def ip_asset(mock_web3, mock_account): - chain_id = 1516 - return IPAsset(mock_web3, mock_account, chain_id) + return _mock class TestIPAssetRegister: - def test_register_invalid_deadline_type(self, ip_asset): - with patch.object( - ip_asset, - "_get_ip_id", - return_value="0xd142822Dc1674154EaF4DDF38bbF7EF8f0D8ECe4", - ), patch.object(ip_asset, "_is_registered", return_value=False): - with pytest.raises(ValueError): + def test_register_invalid_deadline_type( + self, ip_asset, mock_get_ip_id, mock_is_registered + ): + with mock_get_ip_id(), mock_is_registered(): + with pytest.raises(ValueError, match="Invalid deadline value."): ip_asset.register( - nft_contract="0x1daAE3197Bc469Cb97B917aa460a12dD95c6627c", + nft_contract=ADDRESS, token_id=3, deadline="error", ip_metadata={"ip_metadata_uri": "1", "ip_metadata_hash": ZERO_HASH}, ) - def test_register_already_registered(self, ip_asset): - token_contract = "0x1daAE3197Bc469Cb97B917aa460a12dD95c6627c" - token_id = 3 - ip_id = "0xd142822Dc1674154EaF4DDF38bbF7EF8f0D8ECe4" - - with patch.object(ip_asset, "_get_ip_id", return_value=ip_id), patch.object( - ip_asset, "_is_registered", return_value=True - ): - response = ip_asset.register(token_contract, token_id) - assert response["ip_id"] == ip_id + def test_register_already_registered( + self, ip_asset, mock_get_ip_id, mock_is_registered + ): + with mock_get_ip_id(), mock_is_registered(True): + response = ip_asset.register(ADDRESS, 3) + assert response["ip_id"] == IP_ID assert response["tx_hash"] is None - def test_register_successful(self, ip_asset): - token_contract = "0x1daAE3197Bc469Cb97B917aa460a12dD95c6627c" - token_id = 3 - ip_id = "0x1daAE3197Bc469Cb97B917aa460a12dD95c6627c" - tx_hash = "0x129f7dd802200f096221dd89d5b086e4bd3ad6eafb378a0c75e3b04fc375f997" - - class MockTxHash: - def hex(self): - return tx_hash - - mock_tx_hash = MockTxHash() - - mock_signed_txn = MagicMock() - mock_signed_txn.raw_transaction = b"raw_transaction_bytes" - - ip_asset.account.sign_transaction = MagicMock(return_value=mock_signed_txn) - - with patch.object(ip_asset, "_get_ip_id", return_value=ip_id), patch.object( - ip_asset, "_is_registered", return_value=False - ), patch.object( - ip_asset.web3.eth, "get_transaction_count", return_value=0 - ), patch.object( - ip_asset.web3.eth, "send_raw_transaction", return_value=mock_tx_hash - ), patch.object( - ip_asset.web3.eth, - "wait_for_transaction_receipt", - return_value={"status": 1, "logs": []}, - ), patch.object( - ip_asset, "_parse_tx_ip_registered_event", return_value={"ip_id": ip_id} - ): - - result = ip_asset.register(token_contract, token_id) - assert result["tx_hash"] == tx_hash - assert result["ip_id"] == ip_id - - def test_register_with_metadata(self, ip_asset): - token_contract = "0x1daAE3197Bc469Cb97B917aa460a12dD95c6627c" - token_id = 3 - ip_id = "0x1daAE3197Bc469Cb97B917aa460a12dD95c6627c" - tx_hash = "0x129f7dd802200f096221dd89d5b086e4bd3ad6eafb378a0c75e3b04fc375f997" - - metadata = { - "ip_metadata_uri": "", - "ip_metadata_hash": ZERO_HASH, - "nft_metadata_uri": "", - "nft_metadata_hash": ZERO_HASH, - } - - calculated_deadline = 1000 - - class MockTxHash: - def hex(self): - return tx_hash - - mock_tx_hash = MockTxHash() - - mock_signed_txn = MagicMock() - mock_signed_txn.raw_transaction = b"raw_transaction_bytes" - - ip_asset.account.sign_transaction = MagicMock(return_value=mock_signed_txn) - - with patch.object(ip_asset, "_get_ip_id", return_value=ip_id), patch.object( - ip_asset, "_is_registered", return_value=False - ), patch.object( - ip_asset.sign_util, "get_deadline", return_value=calculated_deadline - ), patch.object( - ip_asset.sign_util, - "get_permission_signature", - return_value={"signature": tx_hash}, - ), patch.object( - ip_asset.web3.eth, "get_transaction_count", return_value=0 - ), patch.object( - ip_asset.web3.eth, "send_raw_transaction", return_value=mock_tx_hash - ), patch.object( - ip_asset.web3.eth, - "wait_for_transaction_receipt", - return_value={"status": 1, "logs": []}, - ), patch.object( - ip_asset, - "_parse_tx_ip_registered_event", - return_value={"ip_id": ip_id, "token_id": token_id}, - ): - - result = ip_asset.register( - nft_contract=token_contract, - token_id=token_id, - ip_metadata=metadata, - deadline=1000, - ) - - assert result["tx_hash"] == tx_hash - assert result["ip_id"] == ip_id + def test_register_successful( + self, + ip_asset, + mock_get_ip_id, + mock_is_registered, + mock_parse_ip_registered_event, + ): + with mock_get_ip_id(), mock_is_registered(), mock_parse_ip_registered_event(): + + result = ip_asset.register(ADDRESS, 3) + assert result["tx_hash"] == TX_HASH.hex() + assert result["ip_id"] == IP_ID + + def test_register_with_metadata( + self, + ip_asset: IPAsset, + mock_get_ip_id, + mock_is_registered, + mock_parse_ip_registered_event, + mock_signature_related_methods, + ): + + with mock_get_ip_id(), mock_is_registered(), mock_parse_ip_registered_event(): + with mock_signature_related_methods(): + result = ip_asset.register( + nft_contract=ADDRESS, + token_id=3, + ip_metadata={ + "ip_metadata_uri": "", + "ip_metadata_hash": ZERO_HASH, + "nft_metadata_uri": "", + "nft_metadata_hash": ZERO_HASH, + }, + deadline=1000, + ) + + assert result["tx_hash"] == TX_HASH.hex() + assert result["ip_id"] == IP_ID + + +class TestRegisterDerivativeIp: + def test_ip_is_already_registered( + self, ip_asset, mock_get_ip_id, mock_is_registered + ): + with mock_get_ip_id(), mock_is_registered(True): + with pytest.raises( + ValueError, match="The NFT with id 3 is already registered as IP." + ): + ip_asset.register_derivative_ip( + nft_contract=ADDRESS, + token_id=3, + deriv_data={ + "max_minting_fee": 1000000000000000000, + "max_rts": 1000000000000000000, + "max_revenue_share": 1000000000000000000, + }, + ) + + def test_parent_ip_id_is_empty(self, ip_asset, mock_get_ip_id, mock_is_registered): + with mock_get_ip_id(), mock_is_registered(): + with pytest.raises(ValueError, match="The parent IP IDs must be provided."): + ip_asset.register_derivative_ip( + nft_contract=ADDRESS, + token_id=3, + deriv_data=DerivativeDataInput( + parent_ip_ids=[], + license_terms_ids=[], + ), + ) + + def test_success( + self, + ip_asset, + mock_get_ip_id, + mock_is_registered, + mock_parse_ip_registered_event, + mock_signature_related_methods, + mock_get_function_signature, + mock_license_registry_client, + ): + with mock_get_ip_id(), mock_is_registered(), mock_parse_ip_registered_event(), mock_get_function_signature(), mock_license_registry_client(): + with mock_signature_related_methods(): + result = ip_asset.register_derivative_ip( + nft_contract=ADDRESS, + token_id=3, + deriv_data=DerivativeDataInput( + parent_ip_ids=[IP_ID, IP_ID], + license_terms_ids=[1, 2], + max_minting_fee=10, + max_rts=100, + max_revenue_share=100, + ), + ) + assert result["tx_hash"] == TX_HASH.hex() + assert result["ip_id"] == IP_ID diff --git a/tests/unit/resources/test_permission.py b/tests/unit/resources/test_permission.py index 9b0687b..2b7250e 100644 --- a/tests/unit/resources/test_permission.py +++ b/tests/unit/resources/test_permission.py @@ -3,12 +3,12 @@ import pytest from story_protocol_python_sdk.resources.Permission import Permission +from story_protocol_python_sdk.types.common import AccessPermission from tests.unit.fixtures.data import ADDRESS, CHAIN_ID, STATE, TX_HASH -from tests.unit.fixtures.web3 import mock_web3 @pytest.fixture -def permission(): +def permission(mock_web3): return Permission(mock_web3, ADDRESS, CHAIN_ID) @@ -21,18 +21,24 @@ def test_unregistered_ip_account(self, permission: Permission): Exception, match="IP id with 0x1234567890123456789012345678901234567890 is not registered.", ): - permission.set_permission(ADDRESS, ADDRESS, ADDRESS, 1) + permission.set_permission( + ADDRESS, ADDRESS, ADDRESS, AccessPermission.ALLOW + ) def test_invalid_signer_address(self, permission: Permission): with patch.object( permission.ip_asset_registry_client, "isRegistered", return_value=True ): with pytest.raises(Exception, match="Invalid address: 0xInvalidAddress."): - permission.set_permission(ADDRESS, "0xInvalidAddress", ADDRESS, 1) + permission.set_permission( + ADDRESS, "0xInvalidAddress", ADDRESS, AccessPermission.ALLOW + ) def test_invalid_to_address(self, permission: Permission): with pytest.raises(Exception, match="Invalid address: 0xInvalidAddress."): - permission.set_permission(ADDRESS, ADDRESS, "0xInvalidAddress", 1) + permission.set_permission( + ADDRESS, ADDRESS, "0xInvalidAddress", AccessPermission.ALLOW + ) def test_successful_transaction(self, permission: Permission): with patch.object( @@ -40,7 +46,9 @@ def test_successful_transaction(self, permission: Permission): ), patch.object( permission.ip_account, "execute", return_value={"tx_hash": TX_HASH} ): - response = permission.set_permission(ADDRESS, ADDRESS, ADDRESS, 1) + response = permission.set_permission( + ADDRESS, ADDRESS, ADDRESS, AccessPermission.ALLOW + ) assert response["tx_hash"] == TX_HASH def test_transaction_request_fails(self, permission: Permission): @@ -52,7 +60,9 @@ def test_transaction_request_fails(self, permission: Permission): side_effect=Exception("Transaction failed"), ): with pytest.raises(Exception, match="Transaction failed"): - permission.set_permission(ADDRESS, ADDRESS, ADDRESS, 1) + permission.set_permission( + ADDRESS, ADDRESS, ADDRESS, AccessPermission.ALLOW + ) class TestSetAllPermissions: @@ -62,7 +72,9 @@ def test_successful_transaction(self, permission: Permission): ), patch.object( permission.ip_account, "execute", return_value={"tx_hash": TX_HASH} ): - response = permission.set_all_permissions(ADDRESS, ADDRESS, 1) + response = permission.set_all_permissions( + ADDRESS, ADDRESS, AccessPermission.ALLOW + ) assert response["tx_hash"] == TX_HASH def test_transaction_request_fails(self, permission: Permission): @@ -74,7 +86,7 @@ def test_transaction_request_fails(self, permission: Permission): side_effect=Exception("Transaction failed"), ): with pytest.raises(Exception, match="Transaction failed"): - permission.set_all_permissions(ADDRESS, ADDRESS, 1) + permission.set_all_permissions(ADDRESS, ADDRESS, AccessPermission.ALLOW) class TestCreateSetPermissionSignature: @@ -82,7 +94,7 @@ class TestCreateSetPermissionSignature: def test_invalid_deadline(self, permission: Permission): with pytest.raises(Exception, match="Invalid deadline value."): permission.create_set_permission_signature( - ADDRESS, ADDRESS, ADDRESS, 1, deadline=-1 + ADDRESS, ADDRESS, ADDRESS, AccessPermission.ALLOW, deadline=-1 ) def test_successful_signature(self, permission: Permission): @@ -104,6 +116,6 @@ def test_successful_signature(self, permission: Permission): ), ): response = permission.create_set_permission_signature( - ADDRESS, ADDRESS, ADDRESS, 1 + ADDRESS, ADDRESS, ADDRESS, AccessPermission.ALLOW ) assert response["tx_hash"] == TX_HASH diff --git a/tests/unit/utils/test_derivative_data.py b/tests/unit/utils/test_derivative_data.py new file mode 100644 index 0000000..e90b469 --- /dev/null +++ b/tests/unit/utils/test_derivative_data.py @@ -0,0 +1,438 @@ +from unittest.mock import MagicMock, patch + +import pytest +from _pytest.raises import raises + +from story_protocol_python_sdk.abi.IPAssetRegistry.IPAssetRegistry_client import ( + IPAssetRegistryClient, +) +from story_protocol_python_sdk.abi.LicenseRegistry.LicenseRegistry_client import ( + LicenseRegistryClient, +) +from story_protocol_python_sdk.abi.PILicenseTemplate.PILicenseTemplate_client import ( + PILicenseTemplateClient, +) +from story_protocol_python_sdk.utils.constants import MAX_ROYALTY_TOKEN +from story_protocol_python_sdk.utils.derivative_data import ( + DerivativeData, + DerivativeDataInput, +) +from tests.unit.fixtures.data import ADDRESS, IP_ID + + +@pytest.fixture(scope="module") +def mock_ip_asset_registry_client(): + """Fixture to mock IPAssetRegistryClient""" + + def _mock_ip_registered(is_registered=True): + return patch.object( + IPAssetRegistryClient, + "__new__", + return_value=MagicMock(isRegistered=MagicMock(return_value=is_registered)), + ) + + return _mock_ip_registered + + +@pytest.fixture(scope="module") +def mock_license_registry_client(): + """Fixture to mock LicenseRegistryClient""" + + def _mock_license_registry_client( + has_ip_attached_license_terms=True, get_royalty_percent=10 + ): + return patch.object( + LicenseRegistryClient, + "__new__", + return_value=MagicMock( + hasIpAttachedLicenseTerms=MagicMock( + return_value=has_ip_attached_license_terms + ), + getRoyaltyPercent=MagicMock(return_value=get_royalty_percent), + ), + ) + + return _mock_license_registry_client + + +@pytest.fixture(scope="module") +def mock_pi_license_template_client(): + """Fixture to mock PILicenseTemplateClient""" + + def _mock_pi_license_template_client(): + mock_instance = MagicMock() + mock_instance.contract = MagicMock() + mock_instance.contract.address = ADDRESS + return patch.object( + PILicenseTemplateClient, + "__new__", + return_value=mock_instance, + ) + + return _mock_pi_license_template_client + + +class TestValidateParentIpIdsAndLicenseTermsIds: + def test_validate_parent_ip_ids_is_empty(self, mock_web3): + with raises(ValueError, match="The parent IP IDs must be provided."): + DerivativeData( + web3=mock_web3, + parent_ip_ids=[], + license_terms_ids=[2], + max_minting_fee=10, + max_rts=10, + max_revenue_share=100, + license_template="0x1234567890123456789012345678901234567890", + ) + + def test_validate_license_terms_ids_is_empty(self, mock_web3): + with raises(ValueError, match="The license terms IDs must be provided."): + DerivativeData( + web3=mock_web3, + parent_ip_ids=[ADDRESS], + license_terms_ids=[], + max_minting_fee=10, + max_rts=10, + max_revenue_share=100, + license_template="0x1234567890123456789012345678901234567890", + ) + + def test_validate_parent_ip_ids_and_license_terms_ids_are_not_equal( + self, mock_web3 + ): + with raises( + ValueError, + match="The number of parent IP IDs must match the number of license terms IDs.", + ): + DerivativeData( + web3=mock_web3, + parent_ip_ids=[ADDRESS], + license_terms_ids=[2, 3], + max_minting_fee=10, + max_rts=10, + max_revenue_share=100, + license_template="0x1234567890123456789012345678901234567890", + ) + + def test_validate_parent_ip_ids_is_not_valid_address( + self, mock_web3, mock_is_checksum_address + ): + with mock_is_checksum_address(is_checksum_address=False): + with raises(ValueError, match="The parent IP ID must be a valid address."): + DerivativeData( + web3=mock_web3, + parent_ip_ids=["0x1234567890123456789012345678901234567890"], + license_terms_ids=[2], + max_minting_fee=10, + max_rts=10, + max_revenue_share=100, + license_template="0x1234567890123456789012345678901234567890", + ) + + def test_validate_parent_ip_ids_is_not_registered( + self, + mock_web3, + mock_is_checksum_address, + mock_ip_asset_registry_client, + ): + with mock_is_checksum_address(), mock_ip_asset_registry_client( + is_registered=False + ): + with raises( + ValueError, + match=f"The parent IP ID {IP_ID} must be registered.", + ): + DerivativeData( + web3=mock_web3, + parent_ip_ids=[IP_ID], + license_terms_ids=[2], + max_minting_fee=10, + max_rts=10, + max_revenue_share=100, + license_template="0x1234567890123456789012345678901234567890", + ) + + def test_validate_license_terms_not_attached( + self, + mock_web3, + mock_is_checksum_address, + mock_ip_asset_registry_client, + mock_license_registry_client, + ): + with mock_is_checksum_address(), mock_ip_asset_registry_client( + is_registered=True + ), mock_license_registry_client(has_ip_attached_license_terms=False): + with raises( + ValueError, + match=f"License terms id 2 must be attached to the parent ipId {IP_ID} before registering derivative.", + ): + DerivativeData( + web3=mock_web3, + parent_ip_ids=[IP_ID], + license_terms_ids=[2], + max_minting_fee=10, + max_rts=10, + max_revenue_share=100, + license_template="0x1234567890123456789012345678901234567890", + ) + + def test_validate_royalty_percent_exceeds_max_revenue_share( + self, + mock_web3, + mock_is_checksum_address, + mock_ip_asset_registry_client, + mock_license_registry_client, + ): + with mock_is_checksum_address(), mock_ip_asset_registry_client( + is_registered=True + ), mock_license_registry_client( + has_ip_attached_license_terms=True, get_royalty_percent=1500000000000 + ): + with raises( + ValueError, + match=f"The total royalty percent for the parent IP {IP_ID} is greater than the maximum revenue share 1000000", + ): + DerivativeData( + web3=mock_web3, + parent_ip_ids=[IP_ID], + license_terms_ids=[2], + max_minting_fee=10, + max_rts=10, + max_revenue_share=1, + license_template="0x1234567890123456789012345678901234567890", + ) + + def test_validate_royalty_percent_is_less_than_max_revenue_share( + self, + mock_web3, + mock_is_checksum_address, + mock_ip_asset_registry_client, + mock_license_registry_client, + ): + with mock_is_checksum_address(), mock_ip_asset_registry_client(), mock_license_registry_client(): + derivative_data = DerivativeData.from_input( + web3=mock_web3, + input_data=DerivativeDataInput( + parent_ip_ids=[IP_ID], + license_terms_ids=[2], + max_minting_fee=10, + max_rts=10, + license_template="0x1234567890123456789012345678901234567890", + ), + ) + assert derivative_data.max_revenue_share == MAX_ROYALTY_TOKEN + + +class TestValidateMaxMintingFee: + def test_validate_max_minting_fee_is_less_than_0( + self, + mock_web3, + mock_is_checksum_address, + mock_ip_asset_registry_client, + mock_license_registry_client, + ): + with mock_is_checksum_address(), mock_ip_asset_registry_client(), mock_license_registry_client(): + with raises( + ValueError, match="The max minting fee must be greater than 0." + ): + DerivativeData( + web3=mock_web3, + parent_ip_ids=[IP_ID], + license_terms_ids=[2], + max_minting_fee=-1, + max_rts=10, + max_revenue_share=100, + license_template="0x1234567890123456789012345678901234567890", + ) + + +class TestValidateMaxRts: + def test_validate_max_rts_is_less_than_0( + self, + mock_web3, + mock_is_checksum_address, + mock_ip_asset_registry_client, + mock_license_registry_client, + ): + with mock_is_checksum_address(), mock_ip_asset_registry_client(), mock_license_registry_client(): + with raises( + ValueError, + match="The maxRts must be greater than 0 and less than 100000000.", + ): + DerivativeData.from_input( + web3=mock_web3, + input_data=DerivativeDataInput( + parent_ip_ids=[IP_ID], + license_terms_ids=[2], + max_rts=-1, + ), + ) + + def test_validate_max_rts_is_greater_than_100_000_000( + self, mock_web3, mock_ip_asset_registry_client, mock_license_registry_client + ): + with mock_ip_asset_registry_client(), mock_license_registry_client(): + with raises( + ValueError, + match="The maxRts must be greater than 0 and less than 100000000.", + ): + DerivativeData.from_input( + web3=mock_web3, + input_data=DerivativeDataInput( + parent_ip_ids=[IP_ID], + license_terms_ids=[2], + max_rts=1000000000001, + ), + ) + + def test_validate_max_rts_default_value_is_max_rts( + self, + mock_web3, + mock_is_checksum_address, + mock_ip_asset_registry_client, + mock_license_registry_client, + ): + with mock_is_checksum_address(), mock_ip_asset_registry_client(), mock_license_registry_client(): + derivative_data = DerivativeData.from_input( + web3=mock_web3, + input_data=DerivativeDataInput( + parent_ip_ids=[IP_ID], + license_terms_ids=[2], + ), + ) + assert derivative_data.max_rts == MAX_ROYALTY_TOKEN + + +class TestValidateMaxRevenueShare: + def test_validate_max_revenue_share_is_less_than_0( + self, mock_web3, mock_ip_asset_registry_client, mock_license_registry_client + ): + with mock_ip_asset_registry_client(), mock_license_registry_client(): + with raises( + ValueError, match="The maxRevenueShare must be between 0 and 100." + ): + DerivativeData.from_input( + web3=mock_web3, + input_data=DerivativeDataInput( + parent_ip_ids=[IP_ID], + license_terms_ids=[2], + max_minting_fee=10, + max_rts=10, + max_revenue_share=-1, + ), + ) + + def test_validate_max_revenue_share_is_greater_than_100( + self, + mock_web3, + mock_is_checksum_address, + mock_ip_asset_registry_client, + mock_license_registry_client, + ): + with mock_is_checksum_address(), mock_ip_asset_registry_client(), mock_license_registry_client(): + with raises( + ValueError, match="The maxRevenueShare must be between 0 and 100." + ): + DerivativeData.from_input( + web3=mock_web3, + input_data=DerivativeDataInput( + parent_ip_ids=[IP_ID], + license_terms_ids=[2], + max_minting_fee=10, + max_rts=10, + max_revenue_share=101, + ), + ) + + def test_validate_max_revenue_share_default_value_is_100( + self, + mock_web3, + mock_is_checksum_address, + mock_ip_asset_registry_client, + mock_license_registry_client, + ): + with mock_is_checksum_address(), mock_ip_asset_registry_client(), mock_license_registry_client(): + derivative_data = DerivativeData.from_input( + web3=mock_web3, + input_data=DerivativeDataInput( + parent_ip_ids=[IP_ID], + license_terms_ids=[2], + ), + ) + assert derivative_data.max_revenue_share == MAX_ROYALTY_TOKEN + + +class TestValidateLicenseTemplate: + def test_validate_license_template_default_value_is_pi_license_template( + self, + mock_web3, + mock_is_checksum_address, + mock_pi_license_template_client, + mock_ip_asset_registry_client, + mock_license_registry_client, + ): + with mock_is_checksum_address(), mock_pi_license_template_client(), mock_ip_asset_registry_client(), mock_license_registry_client(): + derivative_data = DerivativeData.from_input( + web3=mock_web3, + input_data=DerivativeDataInput( + parent_ip_ids=[IP_ID], + license_terms_ids=[2], + ), + ) + assert derivative_data.license_template == ADDRESS + + +class TestGetValidatedData: + def test_get_validated_data_with_default_values( + self, + mock_web3, + mock_is_checksum_address, + mock_pi_license_template_client, + mock_ip_asset_registry_client, + mock_license_registry_client, + ): + with mock_is_checksum_address(), mock_pi_license_template_client(), mock_ip_asset_registry_client(), mock_license_registry_client(): + derivative_data = DerivativeData.from_input( + web3=mock_web3, + input_data=DerivativeDataInput( + parent_ip_ids=[IP_ID], + license_terms_ids=[2], + ), + ) + assert derivative_data.get_validated_data() == { + "parentIpIds": [IP_ID], + "licenseTermsIds": [2], + "maxMintingFee": 0, + "maxRts": MAX_ROYALTY_TOKEN, + "maxRevenueShare": MAX_ROYALTY_TOKEN, + "licenseTemplate": ADDRESS, + "royaltyContext": "0x0000000000000000000000000000000000000000", + } + + def test_get_validated_data_with_custom_values( + self, + mock_web3, + mock_ip_asset_registry_client, + mock_license_registry_client, + mock_pi_license_template_client, + mock_is_checksum_address, + ): + with mock_is_checksum_address(), mock_pi_license_template_client(), mock_ip_asset_registry_client(), mock_license_registry_client(): + derivative_data = DerivativeData( + web3=mock_web3, + parent_ip_ids=[IP_ID], + license_terms_ids=[2], + max_minting_fee=10, + max_rts=10, + max_revenue_share=10, + license_template="0x1234567890123456789012345678901234567890", + ) + assert derivative_data.get_validated_data() == { + "parentIpIds": [IP_ID], + "licenseTermsIds": [2], + "maxMintingFee": 10, + "maxRts": 10, + "maxRevenueShare": 10000000.0, + "licenseTemplate": "0x1234567890123456789012345678901234567890", + "royaltyContext": "0x0000000000000000000000000000000000000000", + } diff --git a/tests/unit/utils/test_ip_metadata.py b/tests/unit/utils/test_ip_metadata.py new file mode 100644 index 0000000..8649843 --- /dev/null +++ b/tests/unit/utils/test_ip_metadata.py @@ -0,0 +1,41 @@ +from ens.ens import HexStr + +from story_protocol_python_sdk.utils.constants import ZERO_HASH +from story_protocol_python_sdk.utils.ip_metadata import IPMetadata, IPMetadataInput +from tests.unit.fixtures.data import TX_HASH + + +class TestIPMetadata: + def test_from_input_with_default_values(self): + ip_metadata = IPMetadata.from_input(IPMetadataInput(ip_metadata_hash=TX_HASH)) + assert ip_metadata.get_validated_data() == { + "ipMetadataURI": "", + "ipMetadataHash": TX_HASH, + "nftMetadataURI": "", + "nftMetadataHash": ZERO_HASH, + } + + def test_from_input_with_custom_values(self): + ip_metadata = IPMetadata.from_input( + IPMetadataInput( + ip_metadata_uri="https://ipfs.io/ipfs/Qm...", + ip_metadata_hash=HexStr("0x1234567890"), + nft_metadata_uri="https://ipfs.io/ipfs/Qm...", + nft_metadata_hash=HexStr("0x1234567890"), + ) + ) + assert ip_metadata.get_validated_data() == { + "ipMetadataURI": "https://ipfs.io/ipfs/Qm...", + "ipMetadataHash": HexStr("0x1234567890"), + "nftMetadataURI": "https://ipfs.io/ipfs/Qm...", + "nftMetadataHash": HexStr("0x1234567890"), + } + + def test_from_input_with_none(self): + ip_metadata = IPMetadata.from_input(None) + assert ip_metadata.get_validated_data() == { + "ipMetadataURI": "", + "ipMetadataHash": ZERO_HASH, + "nftMetadataURI": "", + "nftMetadataHash": ZERO_HASH, + }