diff --git a/src/story_protocol_python_sdk/__init__.py b/src/story_protocol_python_sdk/__init__.py index 95a383af..30589db6 100644 --- a/src/story_protocol_python_sdk/__init__.py +++ b/src/story_protocol_python_sdk/__init__.py @@ -16,8 +16,12 @@ from .types.resource.IPAsset import ( BatchMintAndRegisterIPInput, BatchMintAndRegisterIPResponse, + BatchRegisterIpAssetsWithOptimizedWorkflowsResponse, + BatchRegistrationResult, + IpRegistrationWorkflowRequest, LicenseTermsDataInput, LinkDerivativeResponse, + MintAndRegisterRequest, MintedNFT, MintNFT, RegisterAndAttachAndDistributeRoyaltyTokensResponse, @@ -26,6 +30,7 @@ RegisteredIP, RegisterIpAssetResponse, RegisterPILTermsAndAttachResponse, + RegisterRegistrationRequest, RegistrationResponse, RegistrationWithRoyaltyVaultAndLicenseTermsResponse, RegistrationWithRoyaltyVaultResponse, @@ -82,6 +87,11 @@ "RegisterIpAssetResponse", "RegisterDerivativeIpAssetResponse", "LinkDerivativeResponse", + "MintAndRegisterRequest", + "RegisterRegistrationRequest", + "IpRegistrationWorkflowRequest", + "BatchRegistrationResult", + "BatchRegisterIpAssetsWithOptimizedWorkflowsResponse", # Constants "ZERO_ADDRESS", "ZERO_HASH", diff --git a/src/story_protocol_python_sdk/abi/DerivativeWorkflows/DerivativeWorkflows_client.py b/src/story_protocol_python_sdk/abi/DerivativeWorkflows/DerivativeWorkflows_client.py index 5881aca0..dd065977 100644 --- a/src/story_protocol_python_sdk/abi/DerivativeWorkflows/DerivativeWorkflows_client.py +++ b/src/story_protocol_python_sdk/abi/DerivativeWorkflows/DerivativeWorkflows_client.py @@ -101,6 +101,12 @@ def build_mintAndRegisterIpAndMakeDerivativeWithLicenseTokens_transaction( ).build_transaction(tx_params) ) + def multicall(self, data): + return self.contract.functions.multicall(data).transact() + + def build_multicall_transaction(self, data, tx_params): + return self.contract.functions.multicall(data).build_transaction(tx_params) + def registerIpAndMakeDerivative( self, nftContract, tokenId, derivData, ipMetadata, sigMetadataAndRegister ): diff --git a/src/story_protocol_python_sdk/abi/Multicall3/Multicall3_client.py b/src/story_protocol_python_sdk/abi/Multicall3/Multicall3_client.py index fb28b71a..7e9ab2ca 100644 --- a/src/story_protocol_python_sdk/abi/Multicall3/Multicall3_client.py +++ b/src/story_protocol_python_sdk/abi/Multicall3/Multicall3_client.py @@ -29,6 +29,12 @@ def __init__(self, web3: Web3): abi = json.load(abi_file) self.contract = self.web3.eth.contract(address=contract_address, abi=abi) + def aggregate3(self, calls): + return self.contract.functions.aggregate3(calls).transact() + + def build_aggregate3_transaction(self, calls, tx_params): + return self.contract.functions.aggregate3(calls).build_transaction(tx_params) + def aggregate3Value(self, calls): return self.contract.functions.aggregate3Value(calls).transact() diff --git a/src/story_protocol_python_sdk/abi/RoyaltyTokenDistributionWorkflows/RoyaltyTokenDistributionWorkflows_client.py b/src/story_protocol_python_sdk/abi/RoyaltyTokenDistributionWorkflows/RoyaltyTokenDistributionWorkflows_client.py index 651e3bbf..a33e4c69 100644 --- a/src/story_protocol_python_sdk/abi/RoyaltyTokenDistributionWorkflows/RoyaltyTokenDistributionWorkflows_client.py +++ b/src/story_protocol_python_sdk/abi/RoyaltyTokenDistributionWorkflows/RoyaltyTokenDistributionWorkflows_client.py @@ -126,6 +126,12 @@ def build_mintAndRegisterIpAndMakeDerivativeAndDistributeRoyaltyTokens_transacti tx_params ) + def multicall(self, data): + return self.contract.functions.multicall(data).transact() + + def build_multicall_transaction(self, data, tx_params): + return self.contract.functions.multicall(data).build_transaction(tx_params) + def registerIpAndAttachPILTermsAndDeployRoyaltyVault( self, nftContract, diff --git a/src/story_protocol_python_sdk/abi/SPGNFTImpl/SPGNFTImpl_client.py b/src/story_protocol_python_sdk/abi/SPGNFTImpl/SPGNFTImpl_client.py index 4490654e..e29db284 100644 --- a/src/story_protocol_python_sdk/abi/SPGNFTImpl/SPGNFTImpl_client.py +++ b/src/story_protocol_python_sdk/abi/SPGNFTImpl/SPGNFTImpl_client.py @@ -19,3 +19,6 @@ def mintFee(self): def mintFeeToken(self): return self.contract.functions.mintFeeToken().call() + + def publicMinting(self): + return self.contract.functions.publicMinting().call() diff --git a/src/story_protocol_python_sdk/resources/IPAsset.py b/src/story_protocol_python_sdk/resources/IPAsset.py index 2b0e7935..a69a0d77 100644 --- a/src/story_protocol_python_sdk/resources/IPAsset.py +++ b/src/story_protocol_python_sdk/resources/IPAsset.py @@ -1,6 +1,6 @@ """Module for handling IP Account operations and transactions.""" -from dataclasses import asdict, is_dataclass, replace +from collections.abc import Sequence from typing import cast from ens.ens import Address, HexStr @@ -22,9 +22,6 @@ from story_protocol_python_sdk.abi.IPAssetRegistry.IPAssetRegistry_client import ( IPAssetRegistryClient, ) -from story_protocol_python_sdk.abi.IpRoyaltyVaultImpl.IpRoyaltyVaultImpl_client import ( - IpRoyaltyVaultImplClient, -) from story_protocol_python_sdk.abi.LicenseAttachmentWorkflows.LicenseAttachmentWorkflows_client import ( LicenseAttachmentWorkflowsClient, ) @@ -58,22 +55,32 @@ from story_protocol_python_sdk.types.resource.IPAsset import ( BatchMintAndRegisterIPInput, BatchMintAndRegisterIPResponse, + BatchRegisterIpAssetsWithOptimizedWorkflowsResponse, + BatchRegistrationResult, + IpRegistrationWorkflowRequest, LicenseTermsDataInput, LinkDerivativeResponse, + MintAndRegisterRequest, MintedNFT, MintNFT, RegisterAndAttachAndDistributeRoyaltyTokensResponse, RegisterDerivativeIPAndAttachAndDistributeRoyaltyTokensResponse, RegisterDerivativeIpAssetResponse, RegisteredIP, + RegisteredIPWithLicenseTermsIds, RegisterIpAssetResponse, RegisterPILTermsAndAttachResponse, + RegisterRegistrationRequest, RegistrationResponse, RegistrationWithRoyaltyVaultAndLicenseTermsResponse, RegistrationWithRoyaltyVaultResponse, ) -from story_protocol_python_sdk.types.resource.License import LicenseTermsInput from story_protocol_python_sdk.types.resource.Royalty import RoyaltyShareInput +from story_protocol_python_sdk.types.utils import ( + AggregatedRequestData, + ExtraData, + TransformedRegistrationRequest, +) from story_protocol_python_sdk.utils.constants import ( DEADLINE, MAX_ROYALTY_TOKEN, @@ -91,14 +98,18 @@ get_ip_metadata_dict, is_initial_ip_metadata, ) -from story_protocol_python_sdk.utils.licensing_config_data import LicensingConfigData -from story_protocol_python_sdk.utils.pil_flavor import PILFlavor -from story_protocol_python_sdk.utils.royalty import get_royalty_shares +from story_protocol_python_sdk.utils.registration.registration_utils import ( + prepare_distribute_royalty_tokens_requests, + send_transactions, +) +from story_protocol_python_sdk.utils.registration.transform_registration_request import ( + transform_distribute_royalty_tokens_request, + transform_request, + validate_license_terms_data, +) from story_protocol_python_sdk.utils.sign import Sign from story_protocol_python_sdk.utils.transaction_utils import build_and_send_transaction -from story_protocol_python_sdk.utils.util import convert_dict_keys_to_camel_case from story_protocol_python_sdk.utils.validation import ( - get_revenue_share, validate_address, validate_max_rts, ) @@ -275,9 +286,7 @@ def register( tx_options=tx_options, ) - ip_registered = self._parse_tx_ip_registered_event(response["tx_receipt"])[ - 0 - ] + ip_registered = self._get_registered_ips(response["tx_receipt"])[0] return {"tx_hash": response["tx_hash"], "ip_id": ip_registered["ip_id"]} @@ -531,47 +540,36 @@ def mint_and_register_ip_asset_with_pil_terms( :return dict: Dictionary with tx hash, IP ID, token ID, and license term IDs. """ try: - if not self.web3.is_address(spg_nft_contract): - raise ValueError( - f"The NFT contract address {spg_nft_contract} is not valid." - ) - license_terms = self._validate_license_terms_data(terms) - metadata = { - "ipMetadataURI": "", - "ipMetadataHash": ZERO_HASH, - "nftMetadataURI": "", - "nftMetadataHash": ZERO_HASH, - } - - if ip_metadata: - metadata.update( - { - "ipMetadataURI": ip_metadata.get("ip_metadata_uri", ""), - "ipMetadataHash": ip_metadata.get( - "ip_metadata_hash", ZERO_HASH - ), - "nftMetadataURI": ip_metadata.get("nft_metadata_uri", ""), - "nftMetadataHash": ip_metadata.get( - "nft_metadata_hash", ZERO_HASH - ), - } - ) + transformed_request = transform_request( + request=MintAndRegisterRequest( + spg_nft_contract=spg_nft_contract, + recipient=recipient, + ip_metadata=( + IPMetadataInput( + ip_metadata_uri=ip_metadata["ip_metadata_uri"], + ip_metadata_hash=ip_metadata["ip_metadata_hash"], + nft_metadata_uri=ip_metadata["nft_metadata_uri"], + nft_metadata_hash=ip_metadata["nft_metadata_hash"], + ) + if ip_metadata + else None + ), + license_terms_data=terms, + allow_duplicates=allow_duplicates, + ), + web3=self.web3, + account=self.account, + chain_id=self.chain_id, + ) response = build_and_send_transaction( self.web3, self.account, self.license_attachment_workflows_client.build_mintAndRegisterIpAndAttachPILTerms_transaction, - spg_nft_contract, - self._validate_recipient(recipient), - metadata, - license_terms, - allow_duplicates, + *transformed_request.validated_request, tx_options=tx_options, ) - - ip_registered = self._parse_tx_ip_registered_event(response["tx_receipt"])[ - 0 - ] + ip_registered = self._get_registered_ips(response["tx_receipt"])[0] license_terms_ids = self._parse_tx_license_terms_attached_event( response["tx_receipt"] ) @@ -645,9 +643,7 @@ def mint_and_register_ip( tx_options=tx_options, ) - ip_registered = self._parse_tx_ip_registered_event(response["tx_receipt"])[ - 0 - ] + ip_registered = self._get_registered_ips(response["tx_receipt"])[0] return { "tx_hash": response["tx_hash"], @@ -694,7 +690,7 @@ def batch_mint_and_register_ip( encoded_data, tx_options=tx_options, ) - registered_ips = self._parse_tx_ip_registered_event(response["tx_receipt"]) + registered_ips = self._get_registered_ips(response["tx_receipt"]) return BatchMintAndRegisterIPResponse( tx_hash=response["tx_hash"], registered_ips=registered_ips, @@ -756,84 +752,39 @@ def register_ip_and_attach_pil_terms( :return dict: A dictionary with the transaction hash, license terms ID, and IP ID. """ try: - ip_id = self._get_ip_id(nft_contract, token_id) - if self._is_registered(ip_id): - raise ValueError( - f"The NFT with id {token_id} is already registered as IP." - ) - license_terms = self._validate_license_terms_data(license_terms_data) - calculated_deadline = self.sign_util.get_deadline(deadline=deadline) + if not license_terms_data: + raise ValueError("License terms data must be provided.") - # Get permission signature for all required permissions - signature_response = self.sign_util.get_permission_signature( - ip_id=ip_id, - deadline=calculated_deadline, - state=self.web3.to_bytes(hexstr=HexStr(ZERO_HASH)), - permissions=[ - { - "ipId": ip_id, - "signer": self.license_attachment_workflows_client.contract.address, - "to": self.core_metadata_module_client.contract.address, - "permission": AccessPermission.ALLOW, - "func": "setAll(address,string,bytes32,bytes32)", - }, - { - "ipId": ip_id, - "signer": self.license_attachment_workflows_client.contract.address, - "to": self.licensing_module_client.contract.address, - "permission": AccessPermission.ALLOW, - "func": "attachLicenseTerms(address,address,uint256)", - }, - { - "ipId": ip_id, - "signer": self.license_attachment_workflows_client.contract.address, - "to": self.licensing_module_client.contract.address, - "permission": AccessPermission.ALLOW, - "func": "setLicensingConfig(address,address,uint256,(bool,uint256,address,bytes,uint32,bool,uint32,address))", - }, - ], + transformed_request = transform_request( + request=RegisterRegistrationRequest( + nft_contract=nft_contract, + token_id=token_id, + license_terms_data=license_terms_data, + ip_metadata=( + IPMetadataInput( + ip_metadata_uri=ip_metadata["ip_metadata_uri"], + ip_metadata_hash=ip_metadata["ip_metadata_hash"], + nft_metadata_uri=ip_metadata["nft_metadata_uri"], + nft_metadata_hash=ip_metadata["nft_metadata_hash"], + ) + if ip_metadata + else None + ), + deadline=deadline, + ), + web3=self.web3, + account=self.account, + chain_id=self.chain_id, ) - - metadata = { - "ipMetadataURI": "", - "ipMetadataHash": ZERO_HASH, - "nftMetadataURI": "", - "nftMetadataHash": ZERO_HASH, - } - - if ip_metadata: - metadata.update( - { - "ipMetadataURI": ip_metadata.get("ip_metadata_uri", ""), - "ipMetadataHash": ip_metadata.get( - "ip_metadata_hash", ZERO_HASH - ), - "nftMetadataURI": ip_metadata.get("nft_metadata_uri", ""), - "nftMetadataHash": ip_metadata.get( - "nft_metadata_hash", ZERO_HASH - ), - } - ) - response = build_and_send_transaction( self.web3, self.account, self.license_attachment_workflows_client.build_registerIpAndAttachPILTerms_transaction, - nft_contract, - token_id, - metadata, - license_terms, - { - "signer": self.web3.to_checksum_address(self.account.address), - "deadline": calculated_deadline, - "signature": signature_response["signature"], - }, + *transformed_request.validated_request, tx_options=tx_options, ) - ip_registered = self._parse_tx_ip_registered_event(response["tx_receipt"])[ - 0 - ] + ip_registered = self._get_registered_ips(response["tx_receipt"])[0] license_terms_ids = self._parse_tx_license_terms_attached_event( response["tx_receipt"] ) @@ -871,61 +822,27 @@ def register_derivative_ip( :return dict: Dictionary with the tx hash and IP ID. """ try: - ip_id = self._get_ip_id(nft_contract, token_id) - if self._is_registered(ip_id): - raise ValueError( - f"The NFT with id {token_id} is already registered as IP." - ) - validated_deriv_data = DerivativeData.from_input( - web3=self.web3, input_data=deriv_data - ).get_validated_data() - calculated_deadline = self.sign_util.get_deadline(deadline=deadline) - sig_register_signature = self.sign_util.get_permission_signature( - ip_id=ip_id, - deadline=calculated_deadline, - state=Web3.to_bytes(0), - permissions=[ - { - "ipId": ip_id, - "signer": self.derivative_workflows_client.contract.address, - "to": self.core_metadata_module_client.contract.address, - "permission": AccessPermission.ALLOW, - "func": get_function_signature( - self.core_metadata_module_client.contract.abi, - "setAll", - ), - }, - { - "ipId": ip_id, - "signer": self.derivative_workflows_client.contract.address, - "to": self.licensing_module_client.contract.address, - "permission": AccessPermission.ALLOW, - "func": get_function_signature( - self.licensing_module_client.contract.abi, - "registerDerivative", - ), - }, - ], + transformed_request = transform_request( + request=RegisterRegistrationRequest( + nft_contract=nft_contract, + token_id=token_id, + deriv_data=deriv_data, + ip_metadata=metadata, + deadline=deadline, + ), + web3=self.web3, + account=self.account, + chain_id=self.chain_id, ) response = build_and_send_transaction( self.web3, self.account, self.derivative_workflows_client.build_registerIpAndMakeDerivative_transaction, - nft_contract, - token_id, - validated_deriv_data, - IPMetadata.from_input(metadata).get_validated_data(), - { - "signer": self.account.address, - "deadline": calculated_deadline, - "signature": sig_register_signature["signature"], - }, + *transformed_request.validated_request, tx_options=tx_options, ) - ip_registered = self._parse_tx_ip_registered_event(response["tx_receipt"])[ - 0 - ] + ip_registered = self._get_registered_ips(response["tx_receipt"])[0] return { "tx_hash": response["tx_hash"], @@ -959,23 +876,26 @@ def mint_and_register_ip_and_make_derivative( """ try: - validated_deriv_data = DerivativeData.from_input( - web3=self.web3, input_data=deriv_data - ).get_validated_data() + transformed_request = transform_request( + request=MintAndRegisterRequest( + spg_nft_contract=spg_nft_contract, + recipient=recipient, + ip_metadata=ip_metadata, + deriv_data=deriv_data, + allow_duplicates=allow_duplicates, + ), + web3=self.web3, + account=self.account, + chain_id=self.chain_id, + ) response = build_and_send_transaction( self.web3, self.account, self.derivative_workflows_client.build_mintAndRegisterIpAndMakeDerivative_transaction, - validate_address(spg_nft_contract), - validated_deriv_data, - IPMetadata.from_input(ip_metadata).get_validated_data(), - self._validate_recipient(recipient), - allow_duplicates, + *transformed_request.validated_request, tx_options=tx_options, ) - ip_registered = self._parse_tx_ip_registered_event(response["tx_receipt"])[ - 0 - ] + ip_registered = self._get_registered_ips(response["tx_receipt"])[0] return RegistrationResponse( tx_hash=response["tx_hash"], ip_id=ip_registered["ip_id"], @@ -1025,9 +945,7 @@ def mint_and_register_ip_and_make_derivative_with_license_tokens( allow_duplicates, tx_options=tx_options, ) - ip_registered = self._parse_tx_ip_registered_event(response["tx_receipt"])[ - 0 - ] + ip_registered = self._get_registered_ips(response["tx_receipt"])[0] return RegistrationResponse( tx_hash=response["tx_hash"], ip_id=ip_registered["ip_id"], @@ -1123,9 +1041,7 @@ def register_ip_and_make_derivative_with_license_tokens( tx_options=tx_options, ) - ip_registered = self._parse_tx_ip_registered_event(response["tx_receipt"])[ - 0 - ] + ip_registered = self._get_registered_ips(response["tx_receipt"])[0] return RegistrationResponse( tx_hash=response["tx_hash"], @@ -1162,31 +1078,37 @@ def mint_and_register_ip_and_attach_pil_terms_and_distribute_royalty_tokens( :return `RegistrationWithRoyaltyVaultAndLicenseTermsResponse`: Response with tx hash, IP ID, token ID, license terms IDs, and royalty vault address. """ try: - validated_royalty_shares = get_royalty_shares(royalty_shares)[ - "royalty_shares" - ] - license_terms = self._validate_license_terms_data(license_terms_data) - + if not license_terms_data: + raise ValueError("License terms data must be provided.") + if not royalty_shares: + raise ValueError("Royalty shares must be provided.") + + transformed_request = transform_request( + request=MintAndRegisterRequest( + spg_nft_contract=spg_nft_contract, + license_terms_data=license_terms_data, + royalty_shares=royalty_shares, + ip_metadata=ip_metadata, + recipient=recipient, + allow_duplicates=allow_duplicates, + ), + web3=self.web3, + account=self.account, + chain_id=self.chain_id, + ) response = build_and_send_transaction( self.web3, self.account, self.royalty_token_distribution_workflows_client.build_mintAndRegisterIpAndAttachPILTermsAndDistributeRoyaltyTokens_transaction, - validate_address(spg_nft_contract), - self._validate_recipient(recipient), - IPMetadata.from_input(ip_metadata).get_validated_data(), - license_terms, - validated_royalty_shares, - allow_duplicates, + *transformed_request.validated_request, tx_options=tx_options, ) - ip_registered = self._parse_tx_ip_registered_event(response["tx_receipt"])[ - 0 - ] + ip_registered = self._get_registered_ips(response["tx_receipt"])[0] license_terms_ids = self._parse_tx_license_terms_attached_event( response["tx_receipt"] ) - royalty_vault = self.get_royalty_vault_address_by_ip_id( + royalty_vault = self._get_royalty_vault_address_by_ip_id( response["tx_receipt"], ip_registered["ip_id"], ) @@ -1227,28 +1149,32 @@ def mint_and_register_ip_and_make_derivative_and_distribute_royalty_tokens( :return `RegistrationWithRoyaltyVaultResponse`: Dictionary with the tx hash, IP ID and token ID, royalty vault. """ try: - validated_royalty_shares_obj = get_royalty_shares(royalty_shares) - validated_deriv_data = DerivativeData.from_input( - web3=self.web3, input_data=deriv_data - ).get_validated_data() + if not royalty_shares: + raise ValueError("Royalty shares must be provided.") + transformed_request = transform_request( + request=MintAndRegisterRequest( + spg_nft_contract=spg_nft_contract, + deriv_data=deriv_data, + royalty_shares=royalty_shares, + ip_metadata=ip_metadata, + recipient=recipient, + allow_duplicates=allow_duplicates, + ), + web3=self.web3, + account=self.account, + chain_id=self.chain_id, + ) response = build_and_send_transaction( self.web3, self.account, self.royalty_token_distribution_workflows_client.build_mintAndRegisterIpAndMakeDerivativeAndDistributeRoyaltyTokens_transaction, - validate_address(spg_nft_contract), - self._validate_recipient(recipient), - IPMetadata.from_input(ip_metadata).get_validated_data(), - validated_deriv_data, - validated_royalty_shares_obj["royalty_shares"], - allow_duplicates, + *transformed_request.validated_request, tx_options=tx_options, ) - ip_registered = self._parse_tx_ip_registered_event(response["tx_receipt"])[ - 0 - ] - royalty_vault = self.get_royalty_vault_address_by_ip_id( + ip_registered = self._get_registered_ips(response["tx_receipt"])[0] + royalty_vault = self._get_royalty_vault_address_by_ip_id( response["tx_receipt"], ip_registered["ip_id"], ) @@ -1288,73 +1214,45 @@ def register_derivative_ip_and_attach_pil_terms_and_distribute_royalty_tokens( :return `RegisterDerivativeIPAndAttachAndDistributeRoyaltyTokensResponse`: Response with tx hash, IP ID, token ID, royalty vault address, and distribute royalty tokens transaction hash. """ try: - nft_contract = validate_address(nft_contract) - ip_id = self._get_ip_id(nft_contract, token_id) - if self._is_registered(ip_id): - raise ValueError( - f"The NFT with id {token_id} is already registered as IP." - ) + if not royalty_shares: + raise ValueError("Royalty shares must be provided.") - validated_deriv_data = DerivativeData.from_input( - web3=self.web3, input_data=deriv_data - ).get_validated_data() - calculated_deadline = self.sign_util.get_deadline(deadline=deadline) - royalty_shares_obj = get_royalty_shares(royalty_shares) - - signature_response = self.sign_util.get_permission_signature( - ip_id=ip_id, - deadline=calculated_deadline, - state=self.web3.to_bytes(hexstr=HexStr(ZERO_HASH)), - permissions=[ - { - "ipId": ip_id, - "signer": self.royalty_token_distribution_workflows_client.contract.address, - "to": self.core_metadata_module_client.contract.address, - "permission": AccessPermission.ALLOW, - "func": "setAll(address,string,bytes32,bytes32)", - }, - { - "ipId": ip_id, - "signer": self.royalty_token_distribution_workflows_client.contract.address, - "to": self.licensing_module_client.contract.address, - "permission": AccessPermission.ALLOW, - "func": "registerDerivative(address,address[],uint256[],address,bytes,uint256,uint32,uint32)", - }, - ], + transformed_request = transform_request( + request=RegisterRegistrationRequest( + nft_contract=nft_contract, + token_id=token_id, + deriv_data=deriv_data, + royalty_shares=royalty_shares, + ip_metadata=ip_metadata, + deadline=deadline, + ), + web3=self.web3, + account=self.account, + chain_id=self.chain_id, ) response = build_and_send_transaction( self.web3, self.account, self.royalty_token_distribution_workflows_client.build_registerIpAndMakeDerivativeAndDeployRoyaltyVault_transaction, - nft_contract, - token_id, - IPMetadata.from_input(ip_metadata).get_validated_data(), - validated_deriv_data, - { - "signer": self.web3.to_checksum_address(self.account.address), - "deadline": calculated_deadline, - "signature": signature_response["signature"], - }, + *transformed_request.validated_request, tx_options=tx_options, ) - ip_registered = self._parse_tx_ip_registered_event(response["tx_receipt"])[ - 0 - ] - royalty_vault = self.get_royalty_vault_address_by_ip_id( + ip_registered = self._get_registered_ips(response["tx_receipt"])[0] + royalty_vault = self._get_royalty_vault_address_by_ip_id( response["tx_receipt"], ip_registered["ip_id"], ) - + extra_data = cast(ExtraData, transformed_request.extra_data) # Distribute royalty tokens distribute_tx_hash = self._distribute_royalty_tokens( ip_id=ip_registered["ip_id"], - royalty_shares=royalty_shares_obj["royalty_shares"], + royalty_shares=extra_data["royalty_shares"], royalty_vault=royalty_vault, - total_amount=royalty_shares_obj["total_amount"], + total_amount=extra_data["royalty_total_amount"], tx_options=tx_options, - deadline=calculated_deadline, + deadline=extra_data["deadline"], ) return RegisterDerivativeIPAndAttachAndDistributeRoyaltyTokensResponse( @@ -1395,79 +1293,48 @@ def register_ip_and_attach_pil_terms_and_distribute_royalty_tokens( :return `RegisterAndAttachAndDistributeRoyaltyTokensResponse`: Response with tx hash, license terms IDs, royalty vault address, and distribute royalty tokens transaction hash. """ try: - nft_contract = validate_address(nft_contract) - ip_id = self._get_ip_id(nft_contract, token_id) - if self._is_registered(ip_id): - raise ValueError( - f"The NFT with id {token_id} is already registered as IP." - ) - - license_terms = self._validate_license_terms_data(license_terms_data) - calculated_deadline = self.sign_util.get_deadline(deadline=deadline) - royalty_shares_obj = get_royalty_shares(royalty_shares) - signature_response = self.sign_util.get_permission_signature( - ip_id=ip_id, - deadline=calculated_deadline, - state=self.web3.to_bytes(hexstr=HexStr(ZERO_HASH)), - permissions=[ - { - "ipId": ip_id, - "signer": self.royalty_token_distribution_workflows_client.contract.address, - "to": self.core_metadata_module_client.contract.address, - "permission": AccessPermission.ALLOW, - "func": "setAll(address,string,bytes32,bytes32)", - }, - { - "ipId": ip_id, - "signer": self.royalty_token_distribution_workflows_client.contract.address, - "to": self.licensing_module_client.contract.address, - "permission": AccessPermission.ALLOW, - "func": "attachLicenseTerms(address,address,uint256)", - }, - { - "ipId": ip_id, - "signer": self.royalty_token_distribution_workflows_client.contract.address, - "to": self.licensing_module_client.contract.address, - "permission": AccessPermission.ALLOW, - "func": "setLicensingConfig(address,address,uint256,(bool,uint256,address,bytes,uint32,bool,uint32,address))", - }, - ], + if not royalty_shares: + raise ValueError("Royalty shares must be provided.") + if not license_terms_data: + raise ValueError("License terms data must be provided.") + + transformed_request = transform_request( + request=RegisterRegistrationRequest( + nft_contract=nft_contract, + token_id=token_id, + license_terms_data=license_terms_data, + royalty_shares=royalty_shares, + ip_metadata=ip_metadata, + ), + web3=self.web3, + account=self.account, + chain_id=self.chain_id, ) response = build_and_send_transaction( self.web3, self.account, self.royalty_token_distribution_workflows_client.build_registerIpAndAttachPILTermsAndDeployRoyaltyVault_transaction, - nft_contract, - token_id, - IPMetadata.from_input(ip_metadata).get_validated_data(), - license_terms, - { - "signer": self.web3.to_checksum_address(self.account.address), - "deadline": calculated_deadline, - "signature": signature_response["signature"], - }, + *transformed_request.validated_request, tx_options=tx_options, ) - ip_registered = self._parse_tx_ip_registered_event(response["tx_receipt"])[ - 0 - ] + ip_registered = self._get_registered_ips(response["tx_receipt"])[0] license_terms_ids = self._parse_tx_license_terms_attached_event( response["tx_receipt"] ) - royalty_vault = self.get_royalty_vault_address_by_ip_id( + royalty_vault = self._get_royalty_vault_address_by_ip_id( response["tx_receipt"], ip_registered["ip_id"], ) - + extra_data = cast(ExtraData, transformed_request.extra_data) # Distribute royalty tokens distribute_tx_hash = self._distribute_royalty_tokens( ip_id=ip_registered["ip_id"], - royalty_shares=royalty_shares_obj["royalty_shares"], + royalty_shares=extra_data["royalty_shares"], royalty_vault=royalty_vault, - total_amount=royalty_shares_obj["total_amount"], + total_amount=extra_data["royalty_total_amount"], tx_options=tx_options, - deadline=calculated_deadline, + deadline=extra_data["deadline"], ) return RegisterAndAttachAndDistributeRoyaltyTokensResponse( @@ -1505,7 +1372,7 @@ def register_pil_terms_and_attach( calculated_deadline = self.sign_util.get_deadline(deadline=deadline) ip_account_impl_client = IPAccountImplClient(self.web3, ip_id) state = ip_account_impl_client.state() - license_terms = self._validate_license_terms_data(license_terms_data) + license_terms = validate_license_terms_data(license_terms_data, self.web3) signature_response = self.sign_util.get_permission_signature( ip_id=ip_id, deadline=calculated_deadline, @@ -1625,6 +1492,208 @@ def register_ip_asset( except Exception as e: raise ValueError(f"Failed to register IP Asset: {str(e)}") from e + def batch_ip_asset_with_optimized_workflows( + self, + requests: Sequence[IpRegistrationWorkflowRequest], + is_use_multicall: bool = True, + tx_options: dict | None = None, + ) -> BatchRegisterIpAssetsWithOptimizedWorkflowsResponse: + """ + Batch register IP assets with optimized workflow selection. + + This method automatically selects the appropriate workflow based on input parameters and provides + intelligent transaction batching for better gas efficiency. + + **Request Types:** + + - `MintAndRegisterRequest`: Mint a new NFT from an SPG NFT contract and register as IP ID + - `RegisterRegistrationRequest`: Register an already minted NFT as IP ID + + **Workflow Selection:** + + For `MintAndRegisterRequest` (supports Multicall3 when `spg_nft_contract` has public minting enabled): + 1. `license_terms_data` + `royalty_shares` → `mintAndRegisterIpAndAttachPILTermsAndDistributeRoyaltyTokens` (contract method) + - Note: Always uses workflow's native multicall due to `msg.sender` limitation + 2. `license_terms_data` → `mintAndRegisterIpAndAttachPILTerms` + 3. `deriv_data` + `royalty_shares` → `mintAndRegisterIpAndMakeDerivativeAndDistributeRoyaltyTokens` (contract method) + 4. `deriv_data` → `mintAndRegisterIpAndMakeDerivative` (contract method) + 5. Other combinations throw `Invalid mint and register request type` error + + For `RegisterRegistrationRequest` (always uses workflow's native multicall due to signature requirements): + 1. `license_terms_data` + `royalty_shares` → `registerIpAndAttachPILTermsAndDeployRoyaltyVault` (contract method) + 2. `deriv_data` + `royalty_shares` → `registerIpAndMakeDerivativeAndDeployRoyaltyVault` (contract method) + 3. `license_terms_data` → `registerIpAndAttachPILTerms` (contract method) + 4. `deriv_data` → `registerIpAndMakeDerivative` (contract method) + 5. Other combinations throw `Invalid register request type` error + + **Multicall Strategy:** + + - Multicall3: Used when `is_use_multicall=True`, request is `MintAndRegisterRequest`, `spg_nft_contract` has public minting except for `mintAndRegisterIpAndAttachPILTermsAndDistributeRoyaltyTokens`. + - Workflow's native multicall: Used for all other cases + - Requests using the same workflow are aggregated into a single multicall transaction + + **Special Handling:** + + Royalty token distribution is handled in a separate transaction because it requires a signature with the + royalty vault address, which is only available after initial registration completes. + + :param requests `Sequence[IpRegistrationWorkflowRequest]`: The list of registration requests. + :param is_use_multicall `bool`: [Optional] Whether to use multicall3 for eligible workflows. (default: True) + :param tx_options `dict`: [Optional] Transaction options. + :return `BatchRegisterIpAssetsWithOptimizedWorkflowsResponse`: Response with registration results and distribute royalty tokens transaction hashes. + + **Example:** + + ```python + response = client.ip_asset.batch_ip_asset_with_optimized_workflows( + requests=[ + MintAndRegisterRequest( + spg_nft_contract="0x...", + license_terms_data=[...], + ip_metadata={...} + ), + RegisterRegistrationRequest( + nft_contract="0x...", + token_id=123, + deriv_data={...}, + ) + ], + ) + ``` + """ + try: + # Transform registration requests and send them into transaction + transformed_requests: list[TransformedRegistrationRequest] = [ + transform_request(request, self.web3, self.account, self.chain_id) + for request in requests + ] + tx_responses, aggregated_requests = send_transactions( + transformed_requests=transformed_requests, + is_use_multicall3=is_use_multicall, + web3=self.web3, + account=self.account, + tx_options=tx_options, + ) + # Extract royalty distribution requests from workflow responses that contain royalty shares + # We need to handle `distributeRoyaltyTokens` separately because this method requires + # a signature with the royalty vault address, which is only available after the initial registration + distribute_royalty_tokens_requests: list[TransformedRegistrationRequest] = ( + [] + ) + royalty_distribution_requests: list[ExtraData] = [ + tr.extra_data + for tr in transformed_requests + if tr.extra_data is not None + and tr.extra_data.get("royalty_shares", None) is not None + ] + # Parse the response of the registration responses and collect distribute royalty tokens requests + response_list: list[BatchRegistrationResult] = [] + for tx_response in tx_responses: + ip_registered_events = self._parse_tx_ip_registered_event( + tx_response["tx_receipt"] + ) + ip_royalty_vault_deployed_events = ( + self._parse_all_ip_royalty_vault_deployed_events( + tx_response["tx_receipt"] + ) + ) + transferred_distribute_royalty_tokens_requests, matching_vaults = ( + prepare_distribute_royalty_tokens_requests( + extra_data_list=royalty_distribution_requests, + web3=self.web3, + ip_registered=ip_registered_events, + royalty_vault=ip_royalty_vault_deployed_events, + account=self.account, + chain_id=self.chain_id, + ) + ) + + distribute_royalty_tokens_requests.extend( + transferred_distribute_royalty_tokens_requests + ) + response_list.append( + BatchRegistrationResult( + tx_hash=tx_response["tx_hash"], + registered_ips=[ + RegisteredIPWithLicenseTermsIds( + ip_id=log["ipId"], + token_id=log["tokenId"], + license_terms_ids=[], + ) + for log in ip_registered_events + ], + ip_royalty_vaults=matching_vaults, + ) + ) + # Send distribute royalty tokens requests + distribute_royalty_tokens_tx_responses, _ = ( + send_transactions( + transformed_requests=distribute_royalty_tokens_requests, + is_use_multicall3=is_use_multicall, + web3=self.web3, + account=self.account, + tx_options=tx_options, + ) + if distribute_royalty_tokens_requests + else ([], {}) + ) + + # Populate the license terms ids into the response + response_list_with_license_terms_ids = ( + self._populate_license_terms_ids_into_response( + response_list, aggregated_requests + ) + ) + + return BatchRegisterIpAssetsWithOptimizedWorkflowsResponse( + registration_results=response_list_with_license_terms_ids, + distribute_royalty_tokens_tx_hashes=[ + response["tx_hash"] + for response in distribute_royalty_tokens_tx_responses + ], + ) + except ValueError as e: + raise ValueError( + f"Failed to batch register IP assets with optimized workflows: {str(e)}" + ) from e + + def _populate_license_terms_ids_into_response( + self, + registration_results: list[BatchRegistrationResult], + aggregated_requests: dict[Address, AggregatedRequestData], + ) -> list[BatchRegistrationResult]: + # Flatten all license_terms_data from aggregated requests into a single list + all_license_terms_data = [ + license_terms_data + for value in aggregated_requests.values() + for license_terms_data in value["license_terms_data"] + ] + + # Create an iterator to automatically consume license_terms_data + license_terms_iter = iter(all_license_terms_data) + + # Populate license terms ids for each registered IP + for registration_result in registration_results: + for registered_ip in registration_result["registered_ips"]: + license_terms_data = next(license_terms_iter, None) + if license_terms_data: + registered_ip["license_terms_ids"] = self._get_license_terms_id( + license_terms_data + ) + + return registration_results + + def _get_license_terms_id(self, license_terms_data: list[dict]) -> list[int]: + """ + Get the license terms ids from the license terms data. + :param license_terms_data: The license terms data. + :return: The license terms ids. + """ + return [ + self.pi_license_template_client.getLicenseTermsId(license_terms["terms"]) + for license_terms in license_terms_data + ] + def _handle_minted_nft_registration( self, nft: MintedNFT, @@ -1967,75 +2036,6 @@ def _handle_mint_nft_derivative_registration( token_id=token_result["token_id"], ) - def _validate_derivative_data(self, derivative_data: dict) -> dict: - """ - Validates the derivative data and returns processed internal data. - - :param derivative_data dict: The derivative data to validate - :return dict: The processed internal derivative data - :raises ValueError: If validation fails - """ - internal_data = { - "childIpId": derivative_data["childIpId"], - "parentIpIds": derivative_data["parentIpIds"], - "licenseTermsIds": [int(id) for id in derivative_data["licenseTermsIds"]], - "licenseTemplate": ( - derivative_data.get("licenseTemplate") - if derivative_data.get("licenseTemplate") is not None - else self.pi_license_template_client.contract.address - ), - "royaltyContext": ZERO_ADDRESS, - "maxMintingFee": int(derivative_data.get("maxMintingFee", 0)), - "maxRts": int(derivative_data.get("maxRts", 0)), - "maxRevenueShare": int(derivative_data.get("maxRevenueShare", 0)), - } - - if not internal_data["parentIpIds"]: - raise ValueError("The parent IP IDs must be provided.") - - if not internal_data["licenseTermsIds"]: - raise ValueError("The license terms IDs must be provided.") - - if len(internal_data["parentIpIds"]) != len(internal_data["licenseTermsIds"]): - raise ValueError( - "The number of parent IP IDs must match the number of license terms IDs." - ) - - if internal_data["maxMintingFee"] < 0: - raise ValueError("The maxMintingFee must be greater than 0.") - - validate_max_rts(internal_data["maxRts"]) - - for parent_id, terms_id in zip( - internal_data["parentIpIds"], internal_data["licenseTermsIds"] - ): - if not self._is_registered(parent_id): - raise ValueError( - f"The parent IP with id {parent_id} is not registered." - ) - - if not self.license_registry_client.hasIpAttachedLicenseTerms( - parent_id, internal_data["licenseTemplate"], terms_id - ): - raise ValueError( - f"License terms id {terms_id} must be attached to the parent ipId " - f"{parent_id} before registering derivative." - ) - - royalty_percent = self.license_registry_client.getRoyaltyPercent( - parent_id, internal_data["licenseTemplate"], terms_id - ) - if ( - internal_data["maxRevenueShare"] != 0 - and royalty_percent > internal_data["maxRevenueShare"] - ): - raise ValueError( - f"The royalty percent for the parent IP with id {parent_id} is greater " - f"than the maximum revenue share {internal_data['maxRevenueShare']}." - ) - - return internal_data - def _validate_license_token_ids(self, license_token_ids: list) -> list: """ Validates the license token IDs and checks ownership. @@ -2084,36 +2084,21 @@ def _distribute_royalty_tokens( :return HexStr: The transaction hash. """ try: - ip_account_impl_client = IPAccountImplClient(self.web3, ip_id) - state = ip_account_impl_client.state() - - ip_royalty_vault_client = IpRoyaltyVaultImplClient(self.web3, royalty_vault) - - signature_response = self.sign_util.get_signature( - state=state, - to=royalty_vault, - encode_data=ip_royalty_vault_client.contract.encode_abi( - abi_element_identifier="approve", - args=[ - self.royalty_token_distribution_workflows_client.contract.address, - total_amount, - ], - ), - verifying_contract=ip_id, + transformed_request = transform_distribute_royalty_tokens_request( + ip_id=ip_id, + royalty_vault=royalty_vault, deadline=deadline, + web3=self.web3, + account=self.account, + chain_id=self.chain_id, + royalty_shares=royalty_shares, + total_amount=total_amount, ) - response = build_and_send_transaction( self.web3, self.account, self.royalty_token_distribution_workflows_client.build_distributeRoyaltyTokens_transaction, - ip_id, - royalty_shares, - { - "signer": self.web3.to_checksum_address(self.account.address), - "deadline": deadline, - "signature": signature_response["signature"], - }, + *transformed_request.validated_request, tx_options=tx_options, ) @@ -2143,29 +2128,39 @@ def _is_registered(self, ip_id: str) -> bool: """ return self.ip_asset_registry_client.isRegistered(ip_id) - def _parse_tx_ip_registered_event(self, tx_receipt: dict) -> list[RegisteredIP]: + def _parse_tx_ip_registered_event( + self, tx_receipt: dict + ) -> list[dict[str, int | Address]]: """ Parse the IPRegistered event from a transaction receipt. :param tx_receipt dict: The transaction receipt. - :return int: The IP ID and token ID from the event, or None. + :return list[dict[str, int | Address]]: The list of IPRegistered event logs. """ event_signature = self.web3.keccak( text="IPRegistered(address,uint256,address,uint256,string,string,uint256)" ).hex() - registered_ips: list[RegisteredIP] = [] - for log in tx_receipt["logs"]: - if log["topics"][0].hex() == event_signature: - event_result = self.ip_asset_registry_client.contract.events.IPRegistered.process_log( - log - ) - registered_ips.append( - RegisteredIP( - ip_id=event_result["args"]["ipId"], - token_id=event_result["args"]["tokenId"], - ) - ) - return registered_ips + registered_ip_logs: list[dict[str, int | Address]] = [ + self.ip_asset_registry_client.contract.events.IPRegistered.process_log(log)[ + "args" + ] + for log in tx_receipt["logs"] + if log["topics"][0].hex() == event_signature + ] + return registered_ip_logs + + def _get_registered_ips(self, tx_receipt: dict) -> list[RegisteredIP]: + """ + Get the registered IPs from a transaction receipt. + + :param tx_receipt dict: The transaction receipt. + :return list[RegisteredIP]: The list of registered IPs. + """ + registered_ip_logs = self._parse_tx_ip_registered_event(tx_receipt) + return [ + RegisteredIP(ip_id=log["ipId"], token_id=log["tokenId"]) + for log in registered_ip_logs + ] def _parse_tx_license_term_attached_event(self, tx_receipt: dict) -> int | None: """ @@ -2205,8 +2200,8 @@ def _parse_tx_license_terms_attached_event(self, tx_receipt: dict) -> list[int]: return license_terms_ids - def get_royalty_vault_address_by_ip_id( - self, tx_receipt: dict, ipId: Address + def _get_royalty_vault_address_by_ip_id( + self, tx_receipt: dict, ip_id: Address ) -> Address: """ Parse the IpRoyaltyVaultDeployed event from a transaction receipt and return the royalty vault address for a given IP ID. @@ -2215,16 +2210,17 @@ def get_royalty_vault_address_by_ip_id( :param ipId Address: The IP ID. :return Address: The royalty vault address. """ - event_signature = Web3.keccak( - text="IpRoyaltyVaultDeployed(address,address)" - ).hex() - for log in tx_receipt["logs"]: - if log["topics"][0].hex() == event_signature: - event_result = self.royalty_module_client.contract.events.IpRoyaltyVaultDeployed.process_log( - log - ) - if event_result["args"]["ipId"] == ipId: - return event_result["args"]["ipRoyaltyVault"] + ip_royalty_vault_deployed_events = ( + self._parse_all_ip_royalty_vault_deployed_events(tx_receipt) + ) + return next( + ( + event["ipRoyaltyVault"] + for event in ip_royalty_vault_deployed_events + if event["ipId"] == ip_id + ), + None, + ) def _validate_recipient(self, recipient: Address | None) -> Address: """ @@ -2237,54 +2233,24 @@ def _validate_recipient(self, recipient: Address | None) -> Address: return self.account.address return validate_address(recipient) - def _validate_license_terms_data( - self, license_terms_data: list[LicenseTermsDataInput] | list[dict] - ) -> list: + def _parse_all_ip_royalty_vault_deployed_events( + self, tx_receipt: dict + ) -> list[dict[str, Address]]: """ - Validate the license terms data. + Parse all IpRoyaltyVaultDeployed events from a transaction receipt. - :param license_terms_data `list[LicenseTermsDataInput]` or `list[dict]`: The license terms data to validate. - :return list: The validated license terms data. + :param tx_receipt dict: The transaction receipt. + :return list[dict[str, Address]]: List of dicts with keys "ipId" and "ipRoyaltyVault". """ - - validated_license_terms_data = [] - for term in license_terms_data: - if is_dataclass(term): - terms_dict = asdict(term.terms) - licensing_config_dict = term.licensing_config - else: - terms_dict = term["terms"] - licensing_config_dict = term["licensing_config"] - - license_terms = PILFlavor.validate_license_terms( - LicenseTermsInput(**terms_dict) - ) - license_terms = replace( - license_terms, - commercial_rev_share=get_revenue_share( - license_terms.commercial_rev_share - ), - ) - if license_terms.royalty_policy != ZERO_ADDRESS: - is_whitelisted = self.royalty_module_client.isWhitelistedRoyaltyPolicy( - license_terms.royalty_policy - ) - if not is_whitelisted: - raise ValueError("The royalty_policy is not whitelisted.") - - if license_terms.currency != ZERO_ADDRESS: - is_whitelisted = self.royalty_module_client.isWhitelistedRoyaltyToken( - license_terms.currency - ) - if not is_whitelisted: - raise ValueError("The currency is not whitelisted.") - - validated_license_terms_data.append( - { - "terms": convert_dict_keys_to_camel_case(asdict(license_terms)), - "licensingConfig": LicensingConfigData.validate_license_config( - self.module_registry_client, licensing_config_dict - ), - } - ) - return validated_license_terms_data + event_signature = Web3.keccak( + text="IpRoyaltyVaultDeployed(address,address)" + ).hex() + return [ + self.royalty_module_client.contract.events.IpRoyaltyVaultDeployed.process_log( + log + )[ + "args" + ] + for log in tx_receipt["logs"] + if log["topics"][0].hex() == event_signature + ] diff --git a/src/story_protocol_python_sdk/scripts/config.json b/src/story_protocol_python_sdk/scripts/config.json index 91df2b58..b8ea2801 100644 --- a/src/story_protocol_python_sdk/scripts/config.json +++ b/src/story_protocol_python_sdk/scripts/config.json @@ -163,7 +163,8 @@ "mintAndRegisterIpAndMakeDerivativeAndDistributeRoyaltyTokens", "registerIpAndAttachPILTermsAndDeployRoyaltyVault", "distributeRoyaltyTokens", - "registerIpAndMakeDerivativeAndDeployRoyaltyVault" + "registerIpAndMakeDerivativeAndDeployRoyaltyVault", + "multicall" ] }, { @@ -235,7 +236,7 @@ { "contract_name": "SPGNFTImpl", "contract_address": "0xc09e3788Fdfbd3dd8CDaa2aa481B52CcFAb74a42", - "functions": ["mintFeeToken", "mintFee"] + "functions": ["mintFeeToken", "mintFee", "publicMinting"] }, { "contract_name": "DerivativeWorkflows", @@ -244,13 +245,14 @@ "registerIpAndMakeDerivative", "mintAndRegisterIpAndMakeDerivative", "mintAndRegisterIpAndMakeDerivativeWithLicenseTokens", - "registerIpAndMakeDerivativeWithLicenseTokens" + "registerIpAndMakeDerivativeWithLicenseTokens", + "multicall" ] }, { "contract_name": "Multicall3", "contract_address": "0xcA11bde05977b3631167028862bE2a173976CA11", - "functions": ["aggregate3Value"] + "functions": ["aggregate3Value", "aggregate3"] }, { "contract_name": "WrappedIP", diff --git a/src/story_protocol_python_sdk/types/resource/IPAsset.py b/src/story_protocol_python_sdk/types/resource/IPAsset.py index d698a84b..2751c20c 100644 --- a/src/story_protocol_python_sdk/types/resource/IPAsset.py +++ b/src/story_protocol_python_sdk/types/resource/IPAsset.py @@ -4,6 +4,8 @@ from ens.ens import Address, HexStr from story_protocol_python_sdk.types.resource.License import LicenseTermsInput +from story_protocol_python_sdk.types.resource.Royalty import RoyaltyShareInput +from story_protocol_python_sdk.utils.derivative_data import DerivativeDataInput from story_protocol_python_sdk.utils.ip_metadata import IPMetadataInput from story_protocol_python_sdk.utils.licensing_config_data import LicensingConfig @@ -237,3 +239,124 @@ class LinkDerivativeResponse(TypedDict): """ tx_hash: HexStr + + +# ============================================================================= +# Batch Registration Types for batch_register_ip_assets_with_optimized_workflows +# ============================================================================= + + +@dataclass +class MintAndRegisterRequest: + """ + Request for mint and register IP operations. + + Used for(contract method): + - mintAndRegisterIpAssetWithPilTerms + - mintAndRegisterIpAndMakeDerivative + - mintAndRegisterIpAndAttachPilTermsAndDistributeRoyaltyTokens + - mintAndRegisterIpAndMakeDerivativeAndDistributeRoyaltyTokens + + Attributes: + spg_nft_contract: The address of the SPG NFT contract. + recipient: [Optional] The address to receive the NFT. Defaults to caller's wallet address. + allow_duplicates: [Optional] Set to true to allow minting an NFT with a duplicate metadata hash. (default: True) + ip_metadata: [Optional] The metadata for the newly minted NFT and registered IP. + license_terms_data: [Optional] The license terms data to attach. Required if not using `deriv_data`. + deriv_data: [Optional] The derivative data for creating derivative IP. Required if not using `license_terms_data`. + royalty_shares: [Optional] The royalty shares for distributing royalty tokens. Must be specified together with either `license_terms_data` or `deriv_data`. + """ + + spg_nft_contract: Address + recipient: Address | None = None + allow_duplicates: bool | None = True + ip_metadata: IPMetadataInput | None = None + license_terms_data: list[LicenseTermsDataInput] | None = None + deriv_data: DerivativeDataInput | None = None + royalty_shares: list[RoyaltyShareInput] | None = None + + +@dataclass +class RegisterRegistrationRequest: + """ + Request for register IP operations (already minted NFT). + + Used for(contract method): + - registerIpAndAttachPilTerms + - registerIpAndMakeDerivative + - registerIpAndAttachPilTermsAndDeployRoyaltyVault + - registerIpAndMakeDerivativeAndDeployRoyaltyVault + + Attributes: + nft_contract: The address of the NFT contract. + token_id: The token ID of the NFT. + ip_metadata: [Optional] The metadata for the registered IP. + deadline: [Optional] The deadline for the signature in seconds. (default: 1000) + license_terms_data: [Optional] The license terms data to attach. Required if not using `deriv_data`. + deriv_data: [Optional] The derivative data for creating derivative IP. Required if not using `license_terms_data`. + royalty_shares: [Optional] The royalty shares for distributing royalty tokens. Must be specified together with either `license_terms_data` or `deriv_data`. + """ + + nft_contract: Address + token_id: int + ip_metadata: IPMetadataInput | None = None + deadline: int | None = None + license_terms_data: list[LicenseTermsDataInput] | None = None + deriv_data: DerivativeDataInput | None = None + royalty_shares: list[RoyaltyShareInput] | None = None + + +# Union type for all registration requests +IpRegistrationWorkflowRequest = MintAndRegisterRequest | RegisterRegistrationRequest + + +class IPRoyaltyVault(TypedDict): + """ + IP royalty vault. + + Attributes: + ip_id: The IP ID. + royalty_vault: The royalty vault address. + """ + + ip_id: Address + royalty_vault: Address + + +class RegisteredIPWithLicenseTermsIds(RegisteredIP): + """ + Data structure for IP and token ID with license terms IDs. + + Attributes: + license_terms_ids: The license terms IDs of the registered IP asset. + """ + + license_terms_ids: list[int] + + +class BatchRegistrationResult(TypedDict, total=False): + """ + Result of a single batch registration transaction. + + Attributes: + tx_hash: The transaction hash. + registered_ips: List of registered IP assets (ip_id, token_id, license_terms_ids). + ip_royalty_vaults: [Optional] List of IP royalty vaults for deployed royalty vaults. + """ + + tx_hash: HexStr + registered_ips: list[RegisteredIPWithLicenseTermsIds] + ip_royalty_vaults: list[IPRoyaltyVault] + + +class BatchRegisterIpAssetsWithOptimizedWorkflowsResponse(TypedDict, total=False): + """ + Response for batch register IP assets with optimized workflows. + + Attributes: + registration_results: List of batch registration results. + distribute_royalty_tokens_tx_hashes: [Optional] Transaction hashes for royalty token distribution. + """ + + registration_results: list[BatchRegistrationResult] + distribute_royalty_tokens_tx_hashes: list[HexStr] diff --git a/src/story_protocol_python_sdk/types/utils.py b/src/story_protocol_python_sdk/types/utils.py new file mode 100644 index 00000000..a4e97e10 --- /dev/null +++ b/src/story_protocol_python_sdk/types/utils.py @@ -0,0 +1,68 @@ +from dataclasses import dataclass +from typing import TypedDict + +from ens.ens import Address, HexStr +from typing_extensions import Callable + +from story_protocol_python_sdk.types.resource.Royalty import RoyaltyShareInput + + +class Multicall3Call(TypedDict): + target: Address + allowFailure: bool + value: int + callData: bytes + + +class AggregatedRequestData(TypedDict): + """Aggregated request data structure.""" + + call_data: list[bytes | Multicall3Call] + license_terms_data: list[list[dict]] + method_reference: Callable[[list[bytes], dict], HexStr] + + +# ============================================================================= +# Transform Registration Request Types +# ============================================================================= +class ExtraData(TypedDict, total=False): + """ + Extra data for post-processing after registration. + + Attributes: + royalty_shares: [Optional] The royalty shares for distribution. + deadline: [Optional] The deadline for the signature. + royalty_total_amount: [Optional] The total amount of royalty tokens to distribute. + nft_contract: [Optional] The NFT contract address. + token_id: [Optional] The token ID. + license_terms_data: [Optional] The license terms data. + """ + + royalty_shares: list[RoyaltyShareInput] + deadline: int + royalty_total_amount: int + nft_contract: Address + token_id: int + license_terms_data: list[dict] | None + + +@dataclass +class TransformedRegistrationRequest: + """ + Transformed registration request with encoded data and multicall info. + + Attributes: + encoded_tx_data: The encoded transaction data. + is_use_multicall3: Whether to use multicall3 or SPG's native multicall. + workflow_address: The workflow contract address. + validated_request: The validated request arguments for the contract method. + workflow_multicall_reference: The multicall reference for the workflow. + extra_data: [Optional] Extra data for post-processing. + """ + + encoded_tx_data: bytes + is_use_multicall3: bool + workflow_address: Address + validated_request: list[Address | int | str | bytes | dict | bool] + workflow_multicall_reference: Callable[..., HexStr] + extra_data: ExtraData | None = None diff --git a/src/story_protocol_python_sdk/utils/registration/registration_utils.py b/src/story_protocol_python_sdk/utils/registration/registration_utils.py new file mode 100644 index 00000000..6ef57ae7 --- /dev/null +++ b/src/story_protocol_python_sdk/utils/registration/registration_utils.py @@ -0,0 +1,161 @@ +"""Registration utilities for IP asset operations.""" + +from ens.ens import Address, HexStr +from eth_account.signers.local import LocalAccount +from web3 import Web3 + +from story_protocol_python_sdk.abi.Multicall3.Multicall3_client import Multicall3Client +from story_protocol_python_sdk.types.resource.IPAsset import IPRoyaltyVault +from story_protocol_python_sdk.types.utils import ( + AggregatedRequestData, + ExtraData, + TransformedRegistrationRequest, +) +from story_protocol_python_sdk.utils.registration.transform_registration_request import ( + transform_distribute_royalty_tokens_request, +) +from story_protocol_python_sdk.utils.transaction_utils import build_and_send_transaction + + +def aggregate_multicall_requests( + requests: list[TransformedRegistrationRequest], + is_use_multicall3: bool, + web3: Web3, +) -> dict[Address, AggregatedRequestData]: + """ + Aggregate multicall requests by grouping them by target address. + + Groups requests that should be sent to the same multicall address together, + collecting their encoded transaction data and method references. + """ + aggregated_requests: dict[Address, AggregatedRequestData] = {} + multicall3_client = Multicall3Client(web3) + + for request in requests: + # Determine the target address for this request + target_address = ( + multicall3_client.contract.address + if request.is_use_multicall3 and is_use_multicall3 + else request.workflow_address + ) + + # Initialize entry if it doesn't exist + if target_address not in aggregated_requests: + aggregated_requests[target_address] = { + "call_data": [], + "license_terms_data": [], + "method_reference": ( + multicall3_client.build_aggregate3_transaction + if target_address == multicall3_client.contract.address + else request.workflow_multicall_reference + ), + } + if target_address == multicall3_client.contract.address: + aggregated_requests[target_address]["call_data"].append( + { + "target": request.workflow_address, + "allowFailure": False, + "value": 0, + "callData": request.encoded_tx_data, + } + ) + else: + aggregated_requests[target_address]["call_data"].append( + request.encoded_tx_data + ) + license_terms_data = ( + request.extra_data.get("license_terms_data") or [] + if request.extra_data is not None + else [] + ) + aggregated_requests[target_address]["license_terms_data"].append( + license_terms_data + ) + + return aggregated_requests + + +def prepare_distribute_royalty_tokens_requests( + extra_data_list: list[ExtraData], + web3: Web3, + ip_registered: list[dict[str, int | Address]], + royalty_vault: list[dict[str, Address]], + account: LocalAccount, + chain_id: int, +) -> tuple[list[TransformedRegistrationRequest], list[IPRoyaltyVault]]: + if not extra_data_list: + return [], [] + transformed_requests: list[TransformedRegistrationRequest] = [] + matching_vaults: list[IPRoyaltyVault] = [] + for extra_data in extra_data_list: + # Find matching IP registration + ip_registered_match = next( + ( + x + for x in ip_registered + if x["tokenContract"] == extra_data["nft_contract"] + and x["tokenId"] == extra_data["token_id"] + ), + None, + ) + if not ip_registered_match: + continue + + ip_id = ip_registered_match["ipId"] + + # Find matching royalty vault + matching_vault = next( + (x for x in royalty_vault if x["ipId"] == ip_id), + None, + ) + if not matching_vault: + continue + + ip_royalty_vault = matching_vault["ipRoyaltyVault"] + matching_vaults.append( + IPRoyaltyVault(ip_id=ip_id, royalty_vault=ip_royalty_vault) + ) + transformed_request = transform_distribute_royalty_tokens_request( + ip_id=ip_id, + royalty_vault=ip_royalty_vault, + deadline=extra_data["deadline"], + web3=web3, + account=account, + chain_id=chain_id, + royalty_shares=extra_data["royalty_shares"], + total_amount=extra_data["royalty_total_amount"], + ) + transformed_requests.append(transformed_request) + return transformed_requests, matching_vaults + + +def send_transactions( + transformed_requests: list[TransformedRegistrationRequest], + is_use_multicall3: bool, + web3: Web3, + account: LocalAccount, + tx_options: dict | None = None, +) -> tuple[list[dict[str, HexStr | dict]], dict[Address, AggregatedRequestData]]: + aggregated_requests: dict[Address, AggregatedRequestData] = ( + aggregate_multicall_requests( + requests=transformed_requests, + is_use_multicall3=is_use_multicall3, + web3=web3, + ) + ) + tx_results: list[dict[str, HexStr | dict]] = [] + for request_data in aggregated_requests.values(): + response = build_and_send_transaction( + web3, + account, + request_data["method_reference"], + request_data["call_data"], + tx_options=tx_options, + ) + tx_results.append( + { + "tx_hash": response["tx_hash"], + "tx_receipt": response["tx_receipt"], + } + ) + return tx_results, aggregated_requests diff --git a/src/story_protocol_python_sdk/utils/registration/transform_registration_request.py b/src/story_protocol_python_sdk/utils/registration/transform_registration_request.py new file mode 100644 index 00000000..0cd079e5 --- /dev/null +++ b/src/story_protocol_python_sdk/utils/registration/transform_registration_request.py @@ -0,0 +1,905 @@ +"""Transform registration request utilities.""" + +from dataclasses import asdict, is_dataclass, replace + +from ens.ens import Address, HexStr +from eth_account.signers.local import LocalAccount +from typing_extensions import cast +from web3 import Web3 + +from story_protocol_python_sdk.abi.CoreMetadataModule.CoreMetadataModule_client import ( + CoreMetadataModuleClient, +) +from story_protocol_python_sdk.abi.DerivativeWorkflows.DerivativeWorkflows_client import ( + DerivativeWorkflowsClient, +) +from story_protocol_python_sdk.abi.IPAccountImpl.IPAccountImpl_client import ( + IPAccountImplClient, +) +from story_protocol_python_sdk.abi.IPAssetRegistry.IPAssetRegistry_client import ( + IPAssetRegistryClient, +) +from story_protocol_python_sdk.abi.IpRoyaltyVaultImpl.IpRoyaltyVaultImpl_client import ( + IpRoyaltyVaultImplClient, +) +from story_protocol_python_sdk.abi.LicenseAttachmentWorkflows.LicenseAttachmentWorkflows_client import ( + LicenseAttachmentWorkflowsClient, +) +from story_protocol_python_sdk.abi.LicensingModule.LicensingModule_client import ( + LicensingModuleClient, +) +from story_protocol_python_sdk.abi.ModuleRegistry.ModuleRegistry_client import ( + ModuleRegistryClient, +) +from story_protocol_python_sdk.abi.RoyaltyModule.RoyaltyModule_client import ( + RoyaltyModuleClient, +) +from story_protocol_python_sdk.abi.RoyaltyTokenDistributionWorkflows.RoyaltyTokenDistributionWorkflows_client import ( + RoyaltyTokenDistributionWorkflowsClient, +) +from story_protocol_python_sdk.abi.SPGNFTImpl.SPGNFTImpl_client import SPGNFTImplClient +from story_protocol_python_sdk.types.common import AccessPermission +from story_protocol_python_sdk.types.resource.IPAsset import ( + IpRegistrationWorkflowRequest, + LicenseTermsDataInput, + MintAndRegisterRequest, + RegisterRegistrationRequest, +) +from story_protocol_python_sdk.types.resource.License import LicenseTermsInput +from story_protocol_python_sdk.types.resource.Royalty import RoyaltyShareInput +from story_protocol_python_sdk.types.utils import ( + ExtraData, + TransformedRegistrationRequest, +) +from story_protocol_python_sdk.utils.constants import ZERO_ADDRESS, ZERO_HASH +from story_protocol_python_sdk.utils.derivative_data import DerivativeData +from story_protocol_python_sdk.utils.function_signature import get_function_signature +from story_protocol_python_sdk.utils.ip_metadata import IPMetadata +from story_protocol_python_sdk.utils.licensing_config_data import LicensingConfigData +from story_protocol_python_sdk.utils.pil_flavor import PILFlavor +from story_protocol_python_sdk.utils.royalty import get_royalty_shares +from story_protocol_python_sdk.utils.sign import Sign +from story_protocol_python_sdk.utils.util import convert_dict_keys_to_camel_case +from story_protocol_python_sdk.utils.validation import ( + get_revenue_share, + validate_address, +) + + +def get_public_minting(spg_nft_contract: Address, web3: Web3) -> bool: + """ + Check if SPG NFT contract has public minting enabled. + + Args: + spg_nft_contract: The address of the SPG NFT contract. + web3: Web3 instance. + + Returns: + True if public minting is enabled, False otherwise. + """ + spg_client = SPGNFTImplClient( + web3, contract_address=validate_address(spg_nft_contract) + ) + return spg_client.publicMinting() + + +def validate_license_terms_data( + license_terms_data: list[LicenseTermsDataInput] | list[dict], + web3: Web3, +) -> list[dict]: + """ + Validate the license terms data. + + Args: + license_terms_data: The license terms data to validate. + web3: Web3 instance. + + Returns: + The validated license terms data. + """ + royalty_module_client = RoyaltyModuleClient(web3) + module_registry_client = ModuleRegistryClient(web3) + + validated_license_terms_data = [] + for term in license_terms_data: + if is_dataclass(term): + terms_dict = asdict(term.terms) + licensing_config_dict = term.licensing_config + else: + terms_dict = term["terms"] + licensing_config_dict = term["licensing_config"] + + license_terms = PILFlavor.validate_license_terms( + LicenseTermsInput(**terms_dict) + ) + license_terms = replace( + license_terms, + commercial_rev_share=get_revenue_share(license_terms.commercial_rev_share), + ) + if license_terms.royalty_policy != ZERO_ADDRESS: + is_whitelisted = royalty_module_client.isWhitelistedRoyaltyPolicy( + license_terms.royalty_policy + ) + if not is_whitelisted: + raise ValueError("The royalty_policy is not whitelisted.") + + if license_terms.currency != ZERO_ADDRESS: + is_whitelisted = royalty_module_client.isWhitelistedRoyaltyToken( + license_terms.currency + ) + if not is_whitelisted: + raise ValueError("The currency is not whitelisted.") + + validated_license_terms_data.append( + { + "terms": convert_dict_keys_to_camel_case(asdict(license_terms)), + "licensingConfig": LicensingConfigData.validate_license_config( + module_registry_client, licensing_config_dict + ), + } + ) + return validated_license_terms_data + + +def transform_request( + request: IpRegistrationWorkflowRequest, + web3: Web3, + account: LocalAccount, + chain_id: int, +) -> TransformedRegistrationRequest: + """ + Transform a registration request into encoded transaction data with multicall info. + + This is the main entry point for processing registration requests. It: + 1. Validates all input parameters + 2. Generates required signatures (for register* methods) + 3. Encodes the transaction data + 4. Determines whether to use multicall3 or SPG's native multicall + + """ + if hasattr(request, "spg_nft_contract"): + return _handle_mint_and_register_request( + cast(MintAndRegisterRequest, request), web3, account.address + ) + elif hasattr(request, "nft_contract") and hasattr(request, "token_id"): + return _handle_register_request(request, web3, account, chain_id) + else: + raise ValueError("Invalid registration request type") + + +def transform_distribute_royalty_tokens_request( + ip_id: Address, + royalty_vault: Address, + deadline: int, + web3: Web3, + account: LocalAccount, + chain_id: int, + royalty_shares: list[RoyaltyShareInput], + total_amount: int, +) -> TransformedRegistrationRequest: + ip_account_impl_client = IPAccountImplClient(web3, ip_id) + state = ip_account_impl_client.state() + royalty_token_distribution_workflows_client = ( + RoyaltyTokenDistributionWorkflowsClient(web3) + ) + ip_royalty_vault_client = IpRoyaltyVaultImplClient(web3, royalty_vault) + signature_response = Sign(web3, chain_id, account).get_signature( + state=state, + to=royalty_vault, + encode_data=ip_royalty_vault_client.contract.encode_abi( + abi_element_identifier="approve", + args=[ + RoyaltyTokenDistributionWorkflowsClient(web3).contract.address, + total_amount, + ], + ), + verifying_contract=ip_id, + deadline=deadline, + ) + validated_request = [ + ip_id, + royalty_shares, + { + "signer": web3.to_checksum_address(account.address), + "deadline": deadline, + "signature": signature_response["signature"], + }, + ] + encoded_data = royalty_token_distribution_workflows_client.contract.encode_abi( + abi_element_identifier="distributeRoyaltyTokens", + args=validated_request, + ) + return TransformedRegistrationRequest( + encoded_tx_data=encoded_data, + is_use_multicall3=False, + workflow_address=royalty_token_distribution_workflows_client.contract.address, + validated_request=validated_request, + workflow_multicall_reference=royalty_token_distribution_workflows_client.build_multicall_transaction, + extra_data=None, + ) + + +# ============================================================================= +# Mint and Register Request Handlers +# ============================================================================= + + +def _handle_mint_and_register_request( + request: MintAndRegisterRequest, + web3: Web3, + wallet_address: Address, +) -> TransformedRegistrationRequest: + """ + Handle mintAndRegister* workflow requests. + + Supports (contract method): + - mintAndRegisterIpAndAttachPILTermsAndDistributeRoyaltyTokens + - mintAndRegisterIpAndMakeDerivativeAndDistributeRoyaltyTokens + - mintAndRegisterIpAndAttachPILTerms + - mintAndRegisterIpAndMakeDerivative + + Multicall strategy: + - Public minting enabled: Uses multicall3 + - Public minting disabled: Uses SPG's native multicall + """ + spg_nft_contract = validate_address(request.spg_nft_contract) + is_public_minting = get_public_minting(spg_nft_contract, web3) + recipient = ( + validate_address(request.recipient) if request.recipient else wallet_address + ) + license_terms_data = ( + validate_license_terms_data(request.license_terms_data, web3) + if request.license_terms_data + else None + ) + deriv_data = ( + DerivativeData.from_input( + web3=web3, input_data=request.deriv_data + ).get_validated_data() + if request.deriv_data + else None + ) + royalty_shares = ( + get_royalty_shares(request.royalty_shares)["royalty_shares"] + if request.royalty_shares + else None + ) + metadata = IPMetadata.from_input(request.ip_metadata).get_validated_data() + # Build encoded data based on request type + if license_terms_data and royalty_shares: + return _handle_mint_and_register_with_license_terms_and_royalty_tokens( + web3=web3, + spg_nft_contract=spg_nft_contract, + recipient=recipient, + metadata=metadata, + license_terms_data=license_terms_data, + royalty_shares=royalty_shares, + allow_duplicates=request.allow_duplicates, + ) + + elif deriv_data and royalty_shares: + return _handle_mint_and_register_with_derivative_and_royalty_tokens( + web3=web3, + spg_nft_contract=spg_nft_contract, + recipient=recipient, + metadata=metadata, + deriv_data=deriv_data, + royalty_shares=royalty_shares, + allow_duplicates=request.allow_duplicates, + is_public_minting=is_public_minting, + ) + + elif license_terms_data: + return _handle_mint_and_register_with_license_terms( + web3=web3, + spg_nft_contract=spg_nft_contract, + recipient=recipient, + metadata=metadata, + license_terms_data=license_terms_data, + allow_duplicates=request.allow_duplicates, + is_public_minting=is_public_minting, + ) + + elif deriv_data: + return _handle_mint_and_register_with_derivative( + web3=web3, + spg_nft_contract=spg_nft_contract, + recipient=recipient, + metadata=metadata, + deriv_data=deriv_data, + allow_duplicates=request.allow_duplicates, + is_public_minting=is_public_minting, + ) + + else: + raise ValueError("Invalid mint and register request type") + + +def _handle_mint_and_register_with_license_terms_and_royalty_tokens( + web3: Web3, + spg_nft_contract: Address, + recipient: Address, + metadata: dict, + license_terms_data: list[dict], + royalty_shares: list[dict], + allow_duplicates: bool | None, +) -> TransformedRegistrationRequest: + royalty_token_distribution_workflows_client = ( + RoyaltyTokenDistributionWorkflowsClient(web3) + ) + royalty_token_distribution_workflows_address = ( + royalty_token_distribution_workflows_client.contract.address + ) + + validated_request = [ + spg_nft_contract, + recipient, + metadata, + license_terms_data, + royalty_shares, + allow_duplicates, + ] + encoded_data = royalty_token_distribution_workflows_client.contract.encode_abi( + abi_element_identifier="mintAndRegisterIpAndAttachPILTermsAndDistributeRoyaltyTokens", + args=validated_request, + ) + + return TransformedRegistrationRequest( + encoded_tx_data=encoded_data, + # Because mint tokens is given `msg.sender` as the recipient, so we need to set `useMulticall3` to false. + is_use_multicall3=False, + workflow_address=royalty_token_distribution_workflows_address, + workflow_multicall_reference=royalty_token_distribution_workflows_client.build_multicall_transaction, + validated_request=validated_request, + extra_data=ExtraData( + license_terms_data=license_terms_data, + ), + ) + + +def _handle_mint_and_register_with_derivative_and_royalty_tokens( + web3: Web3, + spg_nft_contract: Address, + recipient: Address, + metadata: dict, + deriv_data: dict, + royalty_shares: list[dict], + allow_duplicates: bool | None, + is_public_minting: bool, +) -> TransformedRegistrationRequest: + royalty_token_distribution_workflows_client = ( + RoyaltyTokenDistributionWorkflowsClient(web3) + ) + royalty_token_distribution_workflows_address = ( + royalty_token_distribution_workflows_client.contract.address + ) + + validated_request = [ + spg_nft_contract, + recipient, + metadata, + deriv_data, + royalty_shares, + allow_duplicates, + ] + encoded_data = royalty_token_distribution_workflows_client.contract.encode_abi( + abi_element_identifier="mintAndRegisterIpAndMakeDerivativeAndDistributeRoyaltyTokens", + args=validated_request, + ) + + return TransformedRegistrationRequest( + encoded_tx_data=encoded_data, + is_use_multicall3=is_public_minting, + workflow_address=royalty_token_distribution_workflows_address, + validated_request=validated_request, + extra_data=None, + workflow_multicall_reference=royalty_token_distribution_workflows_client.build_multicall_transaction, + ) + + +def _handle_mint_and_register_with_license_terms( + web3: Web3, + spg_nft_contract: Address, + recipient: Address, + metadata: dict, + license_terms_data: list[dict], + allow_duplicates: bool | None, + is_public_minting: bool, +) -> TransformedRegistrationRequest: + license_attachment_workflows_client = LicenseAttachmentWorkflowsClient(web3) + license_attachment_workflows_address = ( + license_attachment_workflows_client.contract.address + ) + validated_request = [ + spg_nft_contract, + recipient, + metadata, + license_terms_data, + allow_duplicates, + ] + encoded_data = license_attachment_workflows_client.contract.encode_abi( + abi_element_identifier="mintAndRegisterIpAndAttachPILTerms", + args=validated_request, + ) + + return TransformedRegistrationRequest( + encoded_tx_data=encoded_data, + is_use_multicall3=is_public_minting, + workflow_address=license_attachment_workflows_address, + validated_request=validated_request, + extra_data=ExtraData( + license_terms_data=license_terms_data, + ), + workflow_multicall_reference=license_attachment_workflows_client.build_multicall_transaction, + ) + + +def _handle_mint_and_register_with_derivative( + web3: Web3, + spg_nft_contract: Address, + recipient: Address, + metadata: dict, + deriv_data: dict, + allow_duplicates: bool | None, + is_public_minting: bool, +) -> TransformedRegistrationRequest: + derivative_workflows_client = DerivativeWorkflowsClient(web3) + derivative_workflows_address = derivative_workflows_client.contract.address + validated_request = [ + spg_nft_contract, + deriv_data, + metadata, + recipient, + allow_duplicates, + ] + encoded_data = derivative_workflows_client.contract.encode_abi( + abi_element_identifier="mintAndRegisterIpAndMakeDerivative", + args=validated_request, + ) + + return TransformedRegistrationRequest( + encoded_tx_data=encoded_data, + is_use_multicall3=is_public_minting, + workflow_address=derivative_workflows_address, + validated_request=validated_request, + extra_data=None, + workflow_multicall_reference=derivative_workflows_client.build_multicall_transaction, + ) + + +# ============================================================================= +# Register Request Handlers +# ============================================================================= + + +def _handle_register_request( + request: RegisterRegistrationRequest, + web3: Web3, + account: LocalAccount, + chain_id: int, +) -> TransformedRegistrationRequest: + """ + Handle register* workflow requests (already minted NFTs). + + Supports (contract method): + - registerIpAndAttachPILTermsAndDeployRoyaltyVault + - registerIpAndMakeDerivativeAndDeployRoyaltyVault + - registerIpAndAttachPILTerms + - registerIpAndMakeDerivative + + Note: register* methods always use SPG's native multicall because + signatures require msg.sender preservation. + """ + ip_asset_registry_client = IPAssetRegistryClient(web3) + ip_id = ip_asset_registry_client.ipId( + chain_id, request.nft_contract, request.token_id + ) + if ip_asset_registry_client.isRegistered(ip_id): + raise ValueError( + f"The NFT with id {request.token_id} is already registered as IP." + ) + + nft_contract = validate_address(request.nft_contract) + sign_util = Sign(web3=web3, chain_id=chain_id, account=account) + core_metadata_module_client = CoreMetadataModuleClient(web3) + licensing_module_client = LicensingModuleClient(web3) + license_terms_data = ( + validate_license_terms_data(request.license_terms_data, web3) + if request.license_terms_data + else None + ) + deriv_data = ( + DerivativeData.from_input( + web3=web3, input_data=request.deriv_data + ).get_validated_data() + if request.deriv_data + else None + ) + royalty_shares = ( + get_royalty_shares(request.royalty_shares) if request.royalty_shares else None + ) + state = web3.to_bytes(hexstr=HexStr(ZERO_HASH)) + metadata = IPMetadata.from_input(request.ip_metadata).get_validated_data() + calculated_deadline = sign_util.get_deadline(deadline=request.deadline) + wallet_address = account.address + if license_terms_data and royalty_shares: + return _handle_register_with_license_terms_and_royalty_vault( + web3=web3, + nft_contract=nft_contract, + token_id=request.token_id, + metadata=metadata, + license_terms_data=license_terms_data, + royalty_shares=royalty_shares["royalty_shares"], + royalty_total_amount=royalty_shares["total_amount"], + ip_id=ip_id, + wallet_address=wallet_address, + calculated_deadline=calculated_deadline, + sign_util=sign_util, + core_metadata_module_client=core_metadata_module_client, + licensing_module_client=licensing_module_client, + state=state, + ) + + elif deriv_data and royalty_shares: + return _handle_register_with_derivative_and_royalty_vault( + web3=web3, + nft_contract=nft_contract, + token_id=request.token_id, + metadata=metadata, + deriv_data=deriv_data, + royalty_shares=royalty_shares["royalty_shares"], + ip_id=ip_id, + wallet_address=wallet_address, + calculated_deadline=calculated_deadline, + sign_util=sign_util, + core_metadata_module_client=core_metadata_module_client, + licensing_module_client=licensing_module_client, + state=state, + royalty_total_amount=royalty_shares["total_amount"], + ) + + elif license_terms_data: + return _handle_register_with_license_terms( + web3=web3, + nft_contract=nft_contract, + token_id=request.token_id, + metadata=metadata, + license_terms_data=license_terms_data, + ip_id=ip_id, + wallet_address=wallet_address, + calculated_deadline=calculated_deadline, + sign_util=sign_util, + core_metadata_module_client=core_metadata_module_client, + licensing_module_client=licensing_module_client, + state=state, + ) + + elif deriv_data: + return _handle_register_with_derivative( + web3=web3, + nft_contract=nft_contract, + deriv_data=deriv_data, + metadata=metadata, + token_id=request.token_id, + wallet_address=wallet_address, + ip_id=ip_id, + calculated_deadline=calculated_deadline, + sign_util=sign_util, + core_metadata_module_client=core_metadata_module_client, + licensing_module_client=licensing_module_client, + state=state, + ) + + else: + raise ValueError("Invalid register request type") + + +def _handle_register_with_license_terms_and_royalty_vault( + web3: Web3, + nft_contract: Address, + token_id: int, + metadata: dict, + license_terms_data: list[dict], + royalty_shares: list[dict], + ip_id: Address, + wallet_address: Address, + calculated_deadline: int, + sign_util: Sign, + core_metadata_module_client: CoreMetadataModuleClient, + licensing_module_client: LicensingModuleClient, + state: bytes, + royalty_total_amount: int, +) -> TransformedRegistrationRequest: + royalty_token_distribution_workflows_client = ( + RoyaltyTokenDistributionWorkflowsClient(web3) + ) + royalty_token_distribution_workflows_address = ( + royalty_token_distribution_workflows_client.contract.address + ) + signature_data = sign_util.get_permission_signature( + ip_id=ip_id, + deadline=calculated_deadline, + state=state, + permissions=_get_license_terms_permissions( + ip_id=ip_id, + signer_address=royalty_token_distribution_workflows_address, + core_metadata_client=core_metadata_module_client, + licensing_module_client=licensing_module_client, + ), + ) + validated_request = [ + nft_contract, + token_id, + metadata, + license_terms_data, + { + "signer": wallet_address, + "deadline": calculated_deadline, + "signature": signature_data["signature"], + }, + ] + encoded_data = royalty_token_distribution_workflows_client.contract.encode_abi( + abi_element_identifier="registerIpAndAttachPILTermsAndDeployRoyaltyVault", + args=validated_request, + ) + + return TransformedRegistrationRequest( + encoded_tx_data=encoded_data, + is_use_multicall3=False, + workflow_address=royalty_token_distribution_workflows_address, + validated_request=validated_request, + extra_data=ExtraData( + royalty_shares=cast(list[RoyaltyShareInput], royalty_shares), + deadline=calculated_deadline, + royalty_total_amount=royalty_total_amount, + nft_contract=nft_contract, + token_id=token_id, + license_terms_data=license_terms_data, + ), + workflow_multicall_reference=royalty_token_distribution_workflows_client.build_multicall_transaction, + ) + + +def _handle_register_with_derivative_and_royalty_vault( + web3: Web3, + nft_contract: Address, + token_id: int, + metadata: dict, + deriv_data: dict, + royalty_shares: list[dict], + ip_id: Address, + wallet_address: Address, + calculated_deadline: int, + sign_util: Sign, + core_metadata_module_client: CoreMetadataModuleClient, + licensing_module_client: LicensingModuleClient, + state: bytes, + royalty_total_amount: int, +) -> TransformedRegistrationRequest: + royalty_token_distribution_workflows_client = ( + RoyaltyTokenDistributionWorkflowsClient(web3) + ) + royalty_token_distribution_workflows_address = ( + royalty_token_distribution_workflows_client.contract.address + ) + signature_response = sign_util.get_permission_signature( + ip_id=ip_id, + deadline=calculated_deadline, + state=state, + permissions=_get_derivative_permissions( + ip_id=ip_id, + signer_address=royalty_token_distribution_workflows_address, + core_metadata_client=core_metadata_module_client, + licensing_module_client=licensing_module_client, + ), + ) + validated_request = [ + nft_contract, + token_id, + metadata, + deriv_data, + { + "signer": wallet_address, + "deadline": calculated_deadline, + "signature": signature_response["signature"], + }, + ] + encoded_data = royalty_token_distribution_workflows_client.contract.encode_abi( + abi_element_identifier="registerIpAndMakeDerivativeAndDeployRoyaltyVault", + args=validated_request, + ) + + return TransformedRegistrationRequest( + encoded_tx_data=encoded_data, + is_use_multicall3=False, + workflow_address=royalty_token_distribution_workflows_address, + validated_request=validated_request, + extra_data=ExtraData( + royalty_shares=cast(list[RoyaltyShareInput], royalty_shares), + deadline=calculated_deadline, + royalty_total_amount=royalty_total_amount, + nft_contract=nft_contract, + token_id=token_id, + ), + workflow_multicall_reference=royalty_token_distribution_workflows_client.build_multicall_transaction, + ) + + +def _handle_register_with_license_terms( + web3: Web3, + nft_contract: Address, + token_id: int, + metadata: dict, + license_terms_data: list[dict], + ip_id: Address, + wallet_address: Address, + calculated_deadline: int, + sign_util: Sign, + core_metadata_module_client: CoreMetadataModuleClient, + licensing_module_client: LicensingModuleClient, + state: bytes, +) -> TransformedRegistrationRequest: + """Handle registerIpAndAttachPILTerms.""" + license_attachment_workflows_client = LicenseAttachmentWorkflowsClient(web3) + license_attachment_workflows_address = ( + license_attachment_workflows_client.contract.address + ) + signature_data = sign_util.get_permission_signature( + ip_id=ip_id, + deadline=calculated_deadline, + state=state, + permissions=_get_license_terms_permissions( + ip_id=ip_id, + signer_address=license_attachment_workflows_address, + core_metadata_client=core_metadata_module_client, + licensing_module_client=licensing_module_client, + ), + ) + validated_request = [ + nft_contract, + token_id, + metadata, + license_terms_data, + { + "signer": wallet_address, + "deadline": calculated_deadline, + "signature": signature_data["signature"], + }, + ] + encoded_data = license_attachment_workflows_client.contract.encode_abi( + abi_element_identifier="registerIpAndAttachPILTerms", + args=validated_request, + ) + + return TransformedRegistrationRequest( + encoded_tx_data=encoded_data, + is_use_multicall3=False, + workflow_address=license_attachment_workflows_address, + validated_request=validated_request, + extra_data=ExtraData( + license_terms_data=license_terms_data, + ), + workflow_multicall_reference=license_attachment_workflows_client.build_multicall_transaction, + ) + + +def _handle_register_with_derivative( + web3: Web3, + nft_contract: Address, + deriv_data: dict, + metadata: dict, + token_id: int, + ip_id: Address, + wallet_address: Address, + calculated_deadline: int, + sign_util: Sign, + core_metadata_module_client: CoreMetadataModuleClient, + licensing_module_client: LicensingModuleClient, + state: bytes, +) -> TransformedRegistrationRequest: + """Handle registerIpAndMakeDerivative.""" + derivative_workflows_client = DerivativeWorkflowsClient(web3) + derivative_workflows_address = derivative_workflows_client.contract.address + signature_data = sign_util.get_permission_signature( + ip_id=ip_id, + deadline=calculated_deadline, + state=state, + permissions=_get_derivative_permissions( + ip_id=ip_id, + signer_address=derivative_workflows_address, + core_metadata_client=core_metadata_module_client, + licensing_module_client=licensing_module_client, + ), + ) + validated_request = [ + nft_contract, + token_id, + deriv_data, + metadata, + { + "signer": wallet_address, + "deadline": calculated_deadline, + "signature": signature_data["signature"], + }, + ] + encoded_data = derivative_workflows_client.contract.encode_abi( + abi_element_identifier="registerIpAndMakeDerivative", + args=validated_request, + ) + + return TransformedRegistrationRequest( + encoded_tx_data=encoded_data, + is_use_multicall3=False, + workflow_address=derivative_workflows_address, + validated_request=validated_request, + extra_data=None, + workflow_multicall_reference=derivative_workflows_client.build_multicall_transaction, + ) + + +# ============================================================================= +# Internal Helper Methods +# ============================================================================= + + +def _get_license_terms_permissions( + ip_id: Address, + signer_address: Address, + core_metadata_client: CoreMetadataModuleClient, + licensing_module_client: LicensingModuleClient, +) -> list[dict]: + """Get permissions for license terms operations.""" + return [ + { + "ipId": ip_id, + "signer": signer_address, + "to": core_metadata_client.contract.address, + "permission": AccessPermission.ALLOW, + "func": get_function_signature(core_metadata_client.contract.abi, "setAll"), + }, + { + "ipId": ip_id, + "signer": signer_address, + "to": licensing_module_client.contract.address, + "permission": AccessPermission.ALLOW, + "func": get_function_signature( + licensing_module_client.contract.abi, "attachLicenseTerms" + ), + }, + { + "ipId": ip_id, + "signer": signer_address, + "to": licensing_module_client.contract.address, + "permission": AccessPermission.ALLOW, + "func": get_function_signature( + licensing_module_client.contract.abi, "setLicensingConfig" + ), + }, + ] + + +def _get_derivative_permissions( + ip_id: Address, + signer_address: Address, + core_metadata_client: CoreMetadataModuleClient, + licensing_module_client: LicensingModuleClient, +) -> list[dict]: + """Get permissions for derivative operations.""" + return [ + { + "ipId": ip_id, + "signer": signer_address, + "to": core_metadata_client.contract.address, + "permission": AccessPermission.ALLOW, + "func": get_function_signature(core_metadata_client.contract.abi, "setAll"), + }, + { + "ipId": ip_id, + "signer": signer_address, + "to": licensing_module_client.contract.address, + "permission": AccessPermission.ALLOW, + "func": get_function_signature( + licensing_module_client.contract.abi, "registerDerivative" + ), + }, + ] diff --git a/tests/integration/test_integration_group.py b/tests/integration/test_integration_group.py index 6f8f3001..896546d0 100644 --- a/tests/integration/test_integration_group.py +++ b/tests/integration/test_integration_group.py @@ -231,10 +231,8 @@ def test_claim_reward(self, story_client: StoryClient, nft_collection: Address): licensor_ip_id=ip_id, license_template=PIL_LICENSE_TEMPLATE, license_terms_id=license_terms_id, - amount=100, + amount=1, receiver=ip_id, - max_minting_fee=1, - max_revenue_share=100, ) # Claim reward diff --git a/tests/integration/test_integration_ip_asset.py b/tests/integration/test_integration_ip_asset.py index f1ec1bf8..5b4cecf6 100644 --- a/tests/integration/test_integration_ip_asset.py +++ b/tests/integration/test_integration_ip_asset.py @@ -8,14 +8,17 @@ BatchMintAndRegisterIPInput, DerivativeDataInput, IPMetadataInput, + IpRegistrationWorkflowRequest, LicenseTermsDataInput, LicenseTermsInput, LicenseTermsOverride, LicensingConfig, + MintAndRegisterRequest, MintedNFT, MintNFT, NativeRoyaltyPolicy, PILFlavor, + RegisterRegistrationRequest, RoyaltyShareInput, StoryClient, ) @@ -53,6 +56,34 @@ ) +@pytest.fixture(scope="module") +def public_nft_collection(story_client: StoryClient): + tx_data = story_client.NFTClient.create_nft_collection( + name="test-public-collection", + symbol="TEST", + max_supply=100, + is_public_minting=True, + mint_open=True, + contract_uri="test-uri", + mint_fee_recipient=account.address, + ) + return tx_data["nft_contract"] + + +@pytest.fixture(scope="module") +def private_nft_collection(story_client: StoryClient): + tx_data = story_client.NFTClient.create_nft_collection( + name="test-private-collection", + symbol="TEST", + max_supply=100, + is_public_minting=False, + mint_open=True, + contract_uri="test-uri", + mint_fee_recipient=account.address, + ) + return tx_data["nft_contract"] + + class TestIPAssetRegistration: @pytest.fixture(scope="module") def child_ip_id(self, story_client: StoryClient): @@ -308,19 +339,6 @@ def test_register_ip_and_make_derivative_with_license_tokens_with_metadata( class TestIPAssetMinting: - @pytest.fixture(scope="module") - def nft_collection(self, story_client: StoryClient): - tx_data = story_client.NFTClient.create_nft_collection( - name="test-collection", - symbol="TEST", - max_supply=100, - is_public_minting=True, - mint_open=True, - contract_uri="test-uri", - mint_fee_recipient=account.address, - ) - return tx_data["nft_contract"] - def test_mint_register_attach_terms( self, story_client: StoryClient, nft_collection ): @@ -1758,3 +1776,929 @@ def test_link_derivative_with_license_tokens( assert "tx_hash" in response assert isinstance(response["tx_hash"], str) assert len(response["tx_hash"]) > 0 + + +class TestBatchRegisterIpAssetsWithOptimizedWorkflows: + def test_batch_register_ip_assets_with_optimized_workflows_with_register_registration_request( + self, + story_client: StoryClient, + nft_collection, + ): + """Test batch register IP assets with optimized workflows.""" + token_id_1 = get_token_id(MockERC721, story_client.web3, story_client.account) + token_id_2 = get_token_id(MockERC721, story_client.web3, story_client.account) + token_id_3 = get_token_id(MockERC721, story_client.web3, story_client.account) + token_id_4 = get_token_id(MockERC721, story_client.web3, story_client.account) + token_id_5 = get_token_id(MockERC721, story_client.web3, story_client.account) + parent_ip_and_license_terms_1 = create_parent_ip_and_license_terms( + story_client, nft_collection, account + ) + parent_ip_and_license_terms_2 = create_parent_ip_and_license_terms( + story_client, nft_collection, account + ) + requests = [ + # LicenseAttachmentWorkflowsClient + RegisterRegistrationRequest( + nft_contract=MockERC721, + token_id=token_id_1, + ip_metadata=COMMON_IP_METADATA, + deadline=100000, + license_terms_data=[ + LicenseTermsDataInput( + terms=PILFlavor.commercial_use( + default_minting_fee=1, + currency=MockERC20, + royalty_policy=NativeRoyaltyPolicy.LAP, + ), + licensing_config=LicensingConfig( + is_set=True, + minting_fee=1, + licensing_hook=ZERO_ADDRESS, + hook_data=ZERO_HASH, + commercial_rev_share=50, + disabled=False, + expect_minimum_group_reward_share=0, + expect_group_reward_pool=ZERO_ADDRESS, + ), + ) + ], + ), + # RoyaltyTokenDistributionWorkflowsClient + RegisterRegistrationRequest( + nft_contract=MockERC721, + token_id=token_id_2, + license_terms_data=[ + LicenseTermsDataInput( + terms=PILFlavor.non_commercial_social_remixing(), + licensing_config=LicensingConfig( + is_set=True, + minting_fee=0, + licensing_hook=ZERO_ADDRESS, + hook_data=ZERO_HASH, + commercial_rev_share=0, + disabled=False, + expect_minimum_group_reward_share=0, + expect_group_reward_pool=ZERO_ADDRESS, + ), + ), + LicenseTermsDataInput( + terms=PILFlavor.commercial_use( + default_minting_fee=10, + currency=MockERC20, + royalty_policy=NativeRoyaltyPolicy.LAP, + ), + licensing_config=LicensingConfig( + is_set=True, + minting_fee=10, + licensing_hook=ZERO_ADDRESS, + hook_data=ZERO_HASH, + commercial_rev_share=50, + disabled=False, + expect_minimum_group_reward_share=0, + expect_group_reward_pool=ZERO_ADDRESS, + ), + ), + ], + royalty_shares=[ + RoyaltyShareInput(recipient=account.address, percentage=50.0), + RoyaltyShareInput(recipient=account_2.address, percentage=50.0), + ], + ), + # DerivativeWorkflowsClient + RegisterRegistrationRequest( + nft_contract=MockERC721, + token_id=token_id_3, + deriv_data=DerivativeDataInput( + parent_ip_ids=[parent_ip_and_license_terms_1["parent_ip_id"]], + license_terms_ids=[ + parent_ip_and_license_terms_1["license_terms_id"] + ], + ), + ), + # RoyaltyTokenDistributionWorkflowsClient + RegisterRegistrationRequest( + nft_contract=MockERC721, + token_id=token_id_4, + deriv_data=DerivativeDataInput( + parent_ip_ids=[parent_ip_and_license_terms_2["parent_ip_id"]], + license_terms_ids=[ + parent_ip_and_license_terms_2["license_terms_id"] + ], + ), + royalty_shares=[ + RoyaltyShareInput(recipient=account.address, percentage=60.0), + RoyaltyShareInput(recipient=account_2.address, percentage=40.0), + ], + ), + # RoyaltyTokenDistributionWorkflowsClient + RegisterRegistrationRequest( + nft_contract=MockERC721, + token_id=token_id_5, + ip_metadata=COMMON_IP_METADATA, + deadline=100000, + license_terms_data=[ + LicenseTermsDataInput( + terms=PILFlavor.commercial_use( + default_minting_fee=10, + currency=MockERC20, + royalty_policy=NativeRoyaltyPolicy.LAP, + ), + licensing_config=LicensingConfig( + is_set=True, + minting_fee=10, + licensing_hook=ZERO_ADDRESS, + hook_data=ZERO_HASH, + commercial_rev_share=50, + disabled=False, + expect_minimum_group_reward_share=0, + expect_group_reward_pool=ZERO_ADDRESS, + ), + ) + ], + royalty_shares=[ + RoyaltyShareInput(recipient=account.address, percentage=50.0), + RoyaltyShareInput(recipient=account_2.address, percentage=50.0), + ], + ), + ] + # Enhanced: Thoroughly verify transaction aggregation, registration output, + # and cross-check the actual on-chain registered asset state with expectations. + # + # Expectations: + # - 3 total blockchain transactions (aggregated by workflow): + # 1. LicenseAttachmentWorkflowsClient: attaches license terms (1 tx) +1 license terms id + # 2. RoyaltyTokenDistributionWorkflowsClient: batch of 3, merged to 1 tx (with vault creation) + # 2.1: 2 license terms id + # 2.2: 0 license terms id + # 2.3: 1 license terms id + # 3. DerivativeWorkflowsClient: creates derivatives (1 tx) + # - Only 1 distribute_royalty_tokens_tx_hash, even for multiple assets with royalty shares. + response = story_client.IPAsset.batch_ip_asset_with_optimized_workflows( + requests=requests, + ) + # Assert batch-level structure and invariants + assert isinstance(response, dict) + assert "registration_results" in response + assert "distribute_royalty_tokens_tx_hashes" in response + registration_results = response["registration_results"] + assert registration_results[0]["tx_hash"] is not None + assert len(registration_results[0]["registered_ips"]) == 1 + assert ( + len(registration_results[0]["registered_ips"][0]["license_terms_ids"]) == 1 + ) + assert len(registration_results[0]["ip_royalty_vaults"]) == 0 + + assert registration_results[1]["tx_hash"] is not None + assert len(registration_results[1]["registered_ips"]) == 3 + assert ( + len(registration_results[1]["registered_ips"][0]["license_terms_ids"]) == 2 + ) + assert ( + len(registration_results[1]["registered_ips"][1]["license_terms_ids"]) == 0 + ) + assert ( + len(registration_results[1]["registered_ips"][2]["license_terms_ids"]) == 1 + ) + assert len(registration_results[1]["ip_royalty_vaults"]) == 3 + + assert registration_results[2]["tx_hash"] is not None + assert len(registration_results[2]["registered_ips"]) == 1 + assert ( + len(registration_results[2]["registered_ips"][0]["license_terms_ids"]) == 0 + ) + assert len(registration_results[2]["ip_royalty_vaults"]) == 0 + + def test_batch_register_ip_assets_with_optimized_workflows_with_mint_and_register_ip_request( + self, + story_client: StoryClient, + public_nft_collection, + private_nft_collection, + ): + """Test batch register IP assets with optimized workflows with mint and register IP request.""" + parent_ip_and_license_terms_1 = create_parent_ip_and_license_terms( + story_client, public_nft_collection, account + ) + parent_ip_and_license_terms_2 = create_parent_ip_and_license_terms( + story_client, private_nft_collection, account + ) + requests = [ + MintAndRegisterRequest( + spg_nft_contract=public_nft_collection, + recipient=account.address, + allow_duplicates=True, + ip_metadata=COMMON_IP_METADATA, + license_terms_data=[ + LicenseTermsDataInput( + terms=PILFlavor.commercial_use( + default_minting_fee=10, + currency=MockERC20, + royalty_policy=NativeRoyaltyPolicy.LAP, + ), + licensing_config=LicensingConfig( + is_set=True, + minting_fee=10, + licensing_hook=ZERO_ADDRESS, + hook_data=ZERO_HASH, + commercial_rev_share=50, + disabled=False, + expect_minimum_group_reward_share=0, + expect_group_reward_pool=ZERO_ADDRESS, + ), + ), + ], + ), + MintAndRegisterRequest( + spg_nft_contract=public_nft_collection, + recipient=account.address, + allow_duplicates=True, + deriv_data=DerivativeDataInput( + parent_ip_ids=[parent_ip_and_license_terms_1["parent_ip_id"]], + license_terms_ids=[ + parent_ip_and_license_terms_1["license_terms_id"] + ], + ), + ), + MintAndRegisterRequest( + spg_nft_contract=public_nft_collection, + recipient=account.address, + allow_duplicates=True, + deriv_data=DerivativeDataInput( + parent_ip_ids=[parent_ip_and_license_terms_2["parent_ip_id"]], + license_terms_ids=[ + parent_ip_and_license_terms_2["license_terms_id"] + ], + ), + royalty_shares=[ + RoyaltyShareInput(recipient=account.address, percentage=60.0), + RoyaltyShareInput(recipient=account_2.address, percentage=40.0), + ], + ), + # public minting + royalty_token_distribution_workflows_client+ workflow_multicall + MintAndRegisterRequest( + spg_nft_contract=public_nft_collection, + license_terms_data=[ + LicenseTermsDataInput( + terms=PILFlavor.commercial_use( + default_minting_fee=0, + currency=MockERC20, + royalty_policy=NativeRoyaltyPolicy.LAP, + ), + licensing_config=LicensingConfig( + is_set=True, + minting_fee=0, + licensing_hook=ZERO_ADDRESS, + hook_data=ZERO_HASH, + commercial_rev_share=50, + disabled=False, + expect_minimum_group_reward_share=0, + expect_group_reward_pool=ZERO_ADDRESS, + ), + ), + ], + royalty_shares=[ + RoyaltyShareInput(recipient=account.address, percentage=50.0), + RoyaltyShareInput(recipient=account_2.address, percentage=50.0), + ], + ), + # private minting + license_attachment_workflows_client + MintAndRegisterRequest( + spg_nft_contract=private_nft_collection, + allow_duplicates=True, + license_terms_data=[ + LicenseTermsDataInput( + terms=PILFlavor.commercial_use( + default_minting_fee=10, + currency=MockERC20, + royalty_policy=NativeRoyaltyPolicy.LAP, + ), + licensing_config=LicensingConfig( + is_set=True, + minting_fee=10, + licensing_hook=ZERO_ADDRESS, + hook_data=ZERO_HASH, + commercial_rev_share=50, + disabled=False, + expect_minimum_group_reward_share=0, + expect_group_reward_pool=ZERO_ADDRESS, + ), + ), + LicenseTermsDataInput( + terms=PILFlavor.non_commercial_social_remixing(), + licensing_config=LicensingConfig( + is_set=True, + minting_fee=0, + licensing_hook=ZERO_ADDRESS, + hook_data=ZERO_HASH, + commercial_rev_share=0, + disabled=False, + expect_minimum_group_reward_share=0, + expect_group_reward_pool=ZERO_ADDRESS, + ), + ), + ], + ), + # private minting + royalty_token_distribution_workflows_client + MintAndRegisterRequest( + spg_nft_contract=private_nft_collection, + license_terms_data=[ + LicenseTermsDataInput( + terms=PILFlavor.commercial_use( + default_minting_fee=20, + currency=MockERC20, + royalty_policy=NativeRoyaltyPolicy.LAP, + ), + licensing_config=LicensingConfig( + is_set=True, + minting_fee=20, + licensing_hook=ZERO_ADDRESS, + hook_data=ZERO_HASH, + commercial_rev_share=50, + disabled=False, + expect_minimum_group_reward_share=0, + expect_group_reward_pool=ZERO_ADDRESS, + ), + ), + ], + royalty_shares=[ + RoyaltyShareInput(recipient=account.address, percentage=50.0), + RoyaltyShareInput(recipient=account_2.address, percentage=50.0), + ], + ), + # # private minting + royalty_token_distribution_workflows_client + MintAndRegisterRequest( + spg_nft_contract=private_nft_collection, + deriv_data=DerivativeDataInput( + parent_ip_ids=[parent_ip_and_license_terms_1["parent_ip_id"]], + license_terms_ids=[ + parent_ip_and_license_terms_1["license_terms_id"] + ], + ), + royalty_shares=[ + RoyaltyShareInput(recipient=account.address, percentage=50.0), + RoyaltyShareInput(recipient=account_2.address, percentage=50.0), + ], + ), + # private minting + derivative_workflows_client + MintAndRegisterRequest( + spg_nft_contract=private_nft_collection, + deriv_data=DerivativeDataInput( + parent_ip_ids=[parent_ip_and_license_terms_2["parent_ip_id"]], + license_terms_ids=[ + parent_ip_and_license_terms_2["license_terms_id"] + ], + ), + ), + ] + # Enhanced: Thoroughly verify transaction aggregation, registration output, + # and cross-check the actual on-chain registered asset state with expectations. + # + # Enhanced verification of batch registration logic and on-chain state: + # + # Expectations: + # - 4 total blockchain transactions: + # 1. Multicall3: 3 ip_ids + # 1.1: 1 license_terms_id + # 2. royalty_token_distribution_workflows_client: 3 ip_ids + # 2.1: 1 license_terms_id + # 2.2: 1 license_terms_id + # 2.3: + # 3. license_attachment_workflows_client 1 ip_id + # 2.1: 2 license_terms_id + # 4. derivative_workflows_client 1 ip_id + # + # --- Enhanced assertions and on-chain checks follow below --- + response = story_client.IPAsset.batch_ip_asset_with_optimized_workflows( + requests=requests, + ) + assert isinstance(response, dict) + assert "registration_results" in response + assert "distribute_royalty_tokens_tx_hashes" in response + + registration_results = response["registration_results"] + assert registration_results[0]["tx_hash"] is not None + assert len(registration_results[0]["registered_ips"]) == 3 + assert ( + len(registration_results[0]["registered_ips"][0]["license_terms_ids"]) == 1 + ) + assert ( + len(registration_results[0]["registered_ips"][1]["license_terms_ids"]) == 0 + ) + assert ( + len(registration_results[0]["registered_ips"][2]["license_terms_ids"]) == 0 + ) + + assert len(registration_results[1]["registered_ips"]) == 3 + assert ( + len(registration_results[1]["registered_ips"][0]["license_terms_ids"]) == 1 + ) + assert ( + len(registration_results[1]["registered_ips"][1]["license_terms_ids"]) == 1 + ) + assert ( + len(registration_results[1]["registered_ips"][2]["license_terms_ids"]) == 0 + ) + + assert len(registration_results[2]["registered_ips"]) == 1 + assert ( + len(registration_results[2]["registered_ips"][0]["license_terms_ids"]) == 2 + ) + + assert len(registration_results[3]["registered_ips"]) == 1 + assert ( + len(registration_results[3]["registered_ips"][0]["license_terms_ids"]) == 0 + ) + + assert len(response["distribute_royalty_tokens_tx_hashes"]) == 0 + + def test_batch_register_ip_assets_with_optimized_workflows_with_mint_and_register_registration_request( + self, + story_client: StoryClient, + public_nft_collection, + private_nft_collection, + ): + """Test batch register IP assets with optimized workflows with mint and register registration request.""" + parent_ip_and_license_terms_1 = create_parent_ip_and_license_terms( + story_client, public_nft_collection, account + ) + parent_ip_and_license_terms_2 = create_parent_ip_and_license_terms( + story_client, private_nft_collection, account + ) + token_id_1 = get_token_id(MockERC721, story_client.web3, story_client.account) + token_id_2 = get_token_id(MockERC721, story_client.web3, story_client.account) + token_id_3 = get_token_id(MockERC721, story_client.web3, story_client.account) + token_id_4 = get_token_id(MockERC721, story_client.web3, story_client.account) + requests: list[IpRegistrationWorkflowRequest] = [ + # derivative_workflows_client+ workflow_multicall + RegisterRegistrationRequest( + nft_contract=MockERC721, + token_id=token_id_1, + deriv_data=DerivativeDataInput( + parent_ip_ids=[parent_ip_and_license_terms_1["parent_ip_id"]], + license_terms_ids=[ + parent_ip_and_license_terms_1["license_terms_id"] + ], + ), + ), + # royalty_token_distribution_workflows_client+ workflow_multicall+distribute royalty tokens + RegisterRegistrationRequest( + nft_contract=MockERC721, + token_id=token_id_2, + deriv_data=DerivativeDataInput( + parent_ip_ids=[parent_ip_and_license_terms_2["parent_ip_id"]], + license_terms_ids=[ + parent_ip_and_license_terms_2["license_terms_id"] + ], + ), + royalty_shares=[ + RoyaltyShareInput(recipient=account.address, percentage=50.0), + RoyaltyShareInput(recipient=account_2.address, percentage=50.0), + ], + ), + # public minting + royalty_token_distribution_workflows_client + workflow_multicall + MintAndRegisterRequest( + spg_nft_contract=public_nft_collection, + recipient=account.address, + allow_duplicates=True, + ip_metadata=COMMON_IP_METADATA, + license_terms_data=[ + LicenseTermsDataInput( + terms=PILFlavor.commercial_use( + default_minting_fee=10, + currency=MockERC20, + royalty_policy=NativeRoyaltyPolicy.LAP, + ), + licensing_config=LicensingConfig( + is_set=True, + minting_fee=10, + licensing_hook=ZERO_ADDRESS, + hook_data=ZERO_HASH, + commercial_rev_share=50, + disabled=False, + expect_minimum_group_reward_share=0, + expect_group_reward_pool=ZERO_ADDRESS, + ), + ), + ], + royalty_shares=[ + RoyaltyShareInput(recipient=account.address, percentage=50.0), + RoyaltyShareInput(recipient=account_2.address, percentage=50.0), + ], + ), + # multicall3 + MintAndRegisterRequest( + spg_nft_contract=public_nft_collection, + recipient=account.address, + allow_duplicates=True, + ip_metadata=COMMON_IP_METADATA, + license_terms_data=[ + LicenseTermsDataInput( + terms=PILFlavor.non_commercial_social_remixing(), + licensing_config=LicensingConfig( + is_set=True, + minting_fee=0, + licensing_hook=ZERO_ADDRESS, + hook_data=ZERO_HASH, + commercial_rev_share=0, + disabled=False, + expect_minimum_group_reward_share=0, + expect_group_reward_pool=ZERO_ADDRESS, + ), + ), + LicenseTermsDataInput( + terms=PILFlavor.commercial_use( + default_minting_fee=10, + currency=MockERC20, + royalty_policy=NativeRoyaltyPolicy.LAP, + ), + licensing_config=LicensingConfig( + is_set=True, + minting_fee=10, + licensing_hook=ZERO_ADDRESS, + hook_data=ZERO_HASH, + commercial_rev_share=50, + disabled=False, + expect_minimum_group_reward_share=0, + expect_group_reward_pool=ZERO_ADDRESS, + ), + ), + ], + ), + # royalty_token_distribution_workflows_client+ workflow_multicall + MintAndRegisterRequest( + spg_nft_contract=private_nft_collection, + recipient=account.address, + allow_duplicates=True, + ip_metadata=COMMON_IP_METADATA, + license_terms_data=[ + LicenseTermsDataInput( + terms=PILFlavor.commercial_use( + default_minting_fee=30, + currency=MockERC20, + royalty_policy=NativeRoyaltyPolicy.LAP, + ), + licensing_config=LicensingConfig( + is_set=True, + minting_fee=30, + licensing_hook=ZERO_ADDRESS, + hook_data=ZERO_HASH, + commercial_rev_share=50, + disabled=False, + expect_minimum_group_reward_share=0, + expect_group_reward_pool=ZERO_ADDRESS, + ), + ), + ], + royalty_shares=[ + RoyaltyShareInput(recipient=account.address, percentage=50.0), + RoyaltyShareInput(recipient=account_2.address, percentage=50.0), + ], + ), + # public minting + multicall3 + MintAndRegisterRequest( + spg_nft_contract=public_nft_collection, + recipient=account.address, + allow_duplicates=True, + deriv_data=DerivativeDataInput( + parent_ip_ids=[parent_ip_and_license_terms_1["parent_ip_id"]], + license_terms_ids=[ + parent_ip_and_license_terms_1["license_terms_id"] + ], + ), + royalty_shares=[ + RoyaltyShareInput(recipient=account.address, percentage=50.0), + RoyaltyShareInput(recipient=account_2.address, percentage=50.0), + ], + ), + # private minting + royalty_token_distribution_workflows_client + workflow_multicall + MintAndRegisterRequest( + spg_nft_contract=private_nft_collection, + recipient=account.address, + allow_duplicates=True, + deriv_data=DerivativeDataInput( + parent_ip_ids=[parent_ip_and_license_terms_2["parent_ip_id"]], + license_terms_ids=[ + parent_ip_and_license_terms_2["license_terms_id"] + ], + ), + royalty_shares=[ + RoyaltyShareInput(recipient=account.address, percentage=50.0), + RoyaltyShareInput(recipient=account_2.address, percentage=50.0), + ], + ), + # derivative_workflows_client+ workflow_multicall + RegisterRegistrationRequest( + nft_contract=MockERC721, + token_id=token_id_3, + deriv_data=DerivativeDataInput( + parent_ip_ids=[parent_ip_and_license_terms_1["parent_ip_id"]], + license_terms_ids=[ + parent_ip_and_license_terms_1["license_terms_id"] + ], + ), + ), + # royalty_token_distribution_workflows_client+ workflow_multicall+distribute royalty tokens + RegisterRegistrationRequest( + nft_contract=MockERC721, + token_id=token_id_4, + deriv_data=DerivativeDataInput( + parent_ip_ids=[parent_ip_and_license_terms_2["parent_ip_id"]], + license_terms_ids=[ + parent_ip_and_license_terms_2["license_terms_id"] + ], + ), + royalty_shares=[ + RoyaltyShareInput(recipient=account.address, percentage=50.0), + RoyaltyShareInput(recipient=account_2.address, percentage=50.0), + ], + ), + ] + + # Enhanced: Thoroughly verify transaction aggregation, registration output, + # and cross-check the actual on-chain registered asset state with expectations. + # + # Enhanced verification of batch registration logic and on-chain state: + # + # Expectations: + # - 3 total blockchain transactions: + # 1. derivative_workflows_client: 2 ip_ids + # 2. royalty_token_distribution_workflows_client: 5 ip_ids+ 2 royalty vaults + # 2.1: distribute royalty tokens + # 2.2: 1 license_terms_id + # 2.3: 1 license_terms_id + # 2.4: + # 2.5: distribute royalty tokens + # 3. multicall3 2 ip_ids + # 2.1: 2 license_terms_id + # + response = story_client.IPAsset.batch_ip_asset_with_optimized_workflows( + requests=requests, + ) + + assert isinstance(response, dict) + assert "registration_results" in response + assert "distribute_royalty_tokens_tx_hashes" in response + + registration_results = response["registration_results"] + assert registration_results[0]["tx_hash"] is not None + assert len(registration_results[0]["registered_ips"]) == 2 + assert ( + len(registration_results[0]["registered_ips"][0]["license_terms_ids"]) == 0 + ) + assert ( + len(registration_results[0]["registered_ips"][1]["license_terms_ids"]) == 0 + ) + + assert registration_results[1]["tx_hash"] is not None + assert len(registration_results[1]["registered_ips"]) == 5 + assert ( + len(registration_results[1]["registered_ips"][0]["license_terms_ids"]) == 0 + ) + assert ( + len(registration_results[1]["registered_ips"][1]["license_terms_ids"]) == 1 + ) + assert ( + len(registration_results[1]["registered_ips"][2]["license_terms_ids"]) == 1 + ) + assert ( + len(registration_results[1]["registered_ips"][3]["license_terms_ids"]) == 0 + ) + assert ( + len(registration_results[1]["registered_ips"][4]["license_terms_ids"]) == 0 + ) + + assert len(registration_results[1]["ip_royalty_vaults"]) == 2 + + assert registration_results[2]["tx_hash"] is not None + assert len(registration_results[2]["registered_ips"]) == 2 + assert ( + len(registration_results[2]["registered_ips"][0]["license_terms_ids"]) == 2 + ) + assert ( + len(registration_results[2]["registered_ips"][1]["license_terms_ids"]) == 0 + ) + + assert len(response["distribute_royalty_tokens_tx_hashes"]) == 1 + + def test_batch_register_ip_assets_with_optimized_workflows_without_multicall( + self, + story_client: StoryClient, + public_nft_collection, + private_nft_collection, + ): + """Test batch register IP assets with optimized workflows without using multicall3.""" + # Create parent IP assets for derivative tests + parent_ip_and_license_terms_1 = create_parent_ip_and_license_terms( + story_client, public_nft_collection, account + ) + parent_ip_and_license_terms_2 = create_parent_ip_and_license_terms( + story_client, private_nft_collection, account + ) + + # Create token IDs for RegisterRegistrationRequest + token_id_1 = get_token_id(MockERC721, story_client.web3, story_client.account) + token_id_2 = get_token_id(MockERC721, story_client.web3, story_client.account) + token_id_3 = get_token_id(MockERC721, story_client.web3, story_client.account) + + requests: list[IpRegistrationWorkflowRequest] = [ + # MintAndRegisterRequest with license terms data - LicenseAttachmentWorkflows + MintAndRegisterRequest( + spg_nft_contract=public_nft_collection, + recipient=account.address, + allow_duplicates=True, + ip_metadata=COMMON_IP_METADATA, + license_terms_data=[ + LicenseTermsDataInput( + terms=PILFlavor.commercial_use( + default_minting_fee=10, + currency=MockERC20, + royalty_policy=NativeRoyaltyPolicy.LAP, + ), + licensing_config=LicensingConfig( + is_set=True, + minting_fee=10, + licensing_hook=ZERO_ADDRESS, + hook_data=ZERO_HASH, + commercial_rev_share=50, + disabled=False, + expect_minimum_group_reward_share=0, + expect_group_reward_pool=ZERO_ADDRESS, + ), + ), + ], + ), + # MintAndRegisterRequest with derivative data - DerivativeWorkflows + MintAndRegisterRequest( + spg_nft_contract=public_nft_collection, + recipient=account.address, + allow_duplicates=True, + deriv_data=DerivativeDataInput( + parent_ip_ids=[parent_ip_and_license_terms_1["parent_ip_id"]], + license_terms_ids=[ + parent_ip_and_license_terms_1["license_terms_id"] + ], + ), + ), + # MintAndRegisterRequest with license terms data + royalty shares - RoyaltyTokenDistributionWorkflows + MintAndRegisterRequest( + spg_nft_contract=public_nft_collection, + recipient=account.address, + allow_duplicates=True, + license_terms_data=[ + LicenseTermsDataInput( + terms=PILFlavor.commercial_use( + default_minting_fee=20, + currency=MockERC20, + royalty_policy=NativeRoyaltyPolicy.LAP, + ), + licensing_config=LicensingConfig( + is_set=True, + minting_fee=20, + licensing_hook=ZERO_ADDRESS, + hook_data=ZERO_HASH, + commercial_rev_share=50, + disabled=False, + expect_minimum_group_reward_share=0, + expect_group_reward_pool=ZERO_ADDRESS, + ), + ), + ], + royalty_shares=[ + RoyaltyShareInput(recipient=account.address, percentage=70.0), + RoyaltyShareInput(recipient=account_2.address, percentage=30.0), + ], + ), + # MintAndRegisterRequest with derivative data + royalty shares - RoyaltyTokenDistributionWorkflows + MintAndRegisterRequest( + spg_nft_contract=public_nft_collection, + recipient=account.address, + allow_duplicates=True, + deriv_data=DerivativeDataInput( + parent_ip_ids=[parent_ip_and_license_terms_2["parent_ip_id"]], + license_terms_ids=[ + parent_ip_and_license_terms_2["license_terms_id"] + ], + ), + royalty_shares=[ + RoyaltyShareInput(recipient=account.address, percentage=60.0), + RoyaltyShareInput(recipient=account_2.address, percentage=40.0), + ], + ), + # RegisterRegistrationRequest with license terms data - LicenseAttachmentWorkflows + RegisterRegistrationRequest( + nft_contract=MockERC721, + token_id=token_id_1, + ip_metadata=COMMON_IP_METADATA, + deadline=100000, + license_terms_data=[ + LicenseTermsDataInput( + terms=PILFlavor.non_commercial_social_remixing(), + licensing_config=LicensingConfig( + is_set=True, + minting_fee=0, + licensing_hook=ZERO_ADDRESS, + hook_data=ZERO_HASH, + commercial_rev_share=0, + disabled=False, + expect_minimum_group_reward_share=0, + expect_group_reward_pool=ZERO_ADDRESS, + ), + ), + ], + ), + # RegisterRegistrationRequest with derivative data - DerivativeWorkflows + RegisterRegistrationRequest( + nft_contract=MockERC721, + token_id=token_id_2, + deriv_data=DerivativeDataInput( + parent_ip_ids=[parent_ip_and_license_terms_1["parent_ip_id"]], + license_terms_ids=[ + parent_ip_and_license_terms_1["license_terms_id"] + ], + ), + ), + # RegisterRegistrationRequest with license terms data + royalty shares - RoyaltyTokenDistributionWorkflows + RegisterRegistrationRequest( + nft_contract=MockERC721, + token_id=token_id_3, + license_terms_data=[ + LicenseTermsDataInput( + terms=PILFlavor.commercial_use( + default_minting_fee=50, + currency=MockERC20, + royalty_policy=NativeRoyaltyPolicy.LAP, + ), + licensing_config=LicensingConfig( + is_set=True, + minting_fee=50, + licensing_hook=ZERO_ADDRESS, + hook_data=ZERO_HASH, + commercial_rev_share=50, + disabled=False, + expect_minimum_group_reward_share=0, + expect_group_reward_pool=ZERO_ADDRESS, + ), + ), + ], + royalty_shares=[ + RoyaltyShareInput(recipient=account.address, percentage=50.0), + RoyaltyShareInput(recipient=account_2.address, percentage=50.0), + ], + ), + ] + + # Test batch register IP assets with optimized workflows without using multicall3. + # Expectations: + # - 3 total blockchain transactions: + # 1. LicenseAttachmentWorkflows: 2 ip_ids + # 1.1: 1 license_terms_id + # 1.2: 1 license_terms_id + # 2. DerivativeWorkflows: 2 ip_ids + # 3. RoyaltyTokenDistributionWorkflows: 3 ip_ids + 1 royalty vaults + # 3.1: 1 license_terms_id + # 3.2: + # 3.3: 1 license_terms_id+distribute royalty tokens + response = story_client.IPAsset.batch_ip_asset_with_optimized_workflows( + requests=requests, + is_use_multicall=False, + ) + # Verify response structure + assert isinstance(response, dict) + assert "registration_results" in response + assert "distribute_royalty_tokens_tx_hashes" in response + + registration_results = response["registration_results"] + assert len(registration_results) == 3 + assert registration_results[0]["tx_hash"] is not None + assert len(registration_results[0]["registered_ips"]) == 2 + assert ( + len(registration_results[0]["registered_ips"][0]["license_terms_ids"]) == 1 + ) + assert ( + len(registration_results[0]["registered_ips"][1]["license_terms_ids"]) == 1 + ) + assert len(registration_results[0]["ip_royalty_vaults"]) == 0 + + assert registration_results[1]["tx_hash"] is not None + assert len(registration_results[1]["registered_ips"]) == 2 + assert ( + len(registration_results[1]["registered_ips"][0]["license_terms_ids"]) == 0 + ) + assert ( + len(registration_results[1]["registered_ips"][1]["license_terms_ids"]) == 0 + ) + assert len(registration_results[1]["ip_royalty_vaults"]) == 0 + + assert registration_results[2]["tx_hash"] is not None + assert len(registration_results[2]["registered_ips"]) == 3 + assert ( + len(registration_results[2]["registered_ips"][0]["license_terms_ids"]) == 1 + ) + assert ( + len(registration_results[2]["registered_ips"][1]["license_terms_ids"]) == 0 + ) + assert ( + len(registration_results[2]["registered_ips"][2]["license_terms_ids"]) == 1 + ) + assert len(registration_results[2]["ip_royalty_vaults"]) == 1 + + assert len(response["distribute_royalty_tokens_tx_hashes"]) == 1 diff --git a/tests/unit/fixtures/data.py b/tests/unit/fixtures/data.py index 136ded96..efd15fc6 100644 --- a/tests/unit/fixtures/data.py +++ b/tests/unit/fixtures/data.py @@ -1,3 +1,5 @@ +from dataclasses import asdict, replace + from ens.ens import HexStr from story_protocol_python_sdk import ( @@ -6,6 +8,7 @@ LicenseTermsInput, ) from story_protocol_python_sdk.utils.constants import ZERO_ADDRESS, ZERO_HASH +from story_protocol_python_sdk.utils.util import convert_dict_keys_to_camel_case CHAIN_ID = 1315 ADDRESS = "0x1234567890123456789012345678901234567890" @@ -78,6 +81,25 @@ }, ) ] + +# camel case version of LICENSE_TERMS_DATA +LICENSE_TERMS_DATA_CAMEL_CASE = { + "terms": convert_dict_keys_to_camel_case( + asdict(replace(LICENSE_TERMS_DATA[0].terms, commercial_rev_share=10 * 10**6)) + ), + "licensingConfig": { + "isSet": True, + "mintingFee": 10, + "licensingHook": ADDRESS, + "hookData": ZERO_HASH, + "commercialRevShare": 10 * 10**6, + "disabled": False, + "expectMinimumGroupRewardShare": 0, + "expectGroupRewardPool": ZERO_ADDRESS, + }, +} + + IP_METADATA = IPMetadataInput( ip_metadata_uri="https://example.com/ip-metadata.json", ip_metadata_hash=HexStr("0x" + "a" * 64), diff --git a/tests/unit/resources/test_ip_asset.py b/tests/unit/resources/test_ip_asset.py index e1a6467e..5dbfe15d 100644 --- a/tests/unit/resources/test_ip_asset.py +++ b/tests/unit/resources/test_ip_asset.py @@ -1,4 +1,5 @@ -from unittest.mock import patch +from dataclasses import asdict +from unittest.mock import MagicMock, patch import pytest from ens.ens import HexStr @@ -6,14 +7,17 @@ from story_protocol_python_sdk import ( MAX_ROYALTY_TOKEN, + IpRegistrationWorkflowRequest, LicenseTermsDataInput, LicenseTermsOverride, LicensingConfig, + MintAndRegisterRequest, MintedNFT, MintNFT, NativeRoyaltyPolicy, PILFlavor, PILFlavorError, + RegisterRegistrationRequest, RoyaltyShareInput, ) from story_protocol_python_sdk.abi.IPAccountImpl.IPAccountImpl_client import ( @@ -21,6 +25,7 @@ ) from story_protocol_python_sdk.resources.IPAsset import IPAsset from story_protocol_python_sdk.types.resource.IPAsset import BatchMintAndRegisterIPInput +from story_protocol_python_sdk.types.utils import TransformedRegistrationRequest from story_protocol_python_sdk.utils.derivative_data import DerivativeDataInput from story_protocol_python_sdk.utils.ip_metadata import IPMetadata, IPMetadataInput from story_protocol_python_sdk.utils.royalty import get_royalty_shares @@ -71,7 +76,7 @@ def mock_parse_ip_registered_event(ip_asset): def _mock(): return patch.object( ip_asset, - "_parse_tx_ip_registered_event", + "_get_registered_ips", return_value=[ {"ip_id": IP_ID, "token_id": 3}, {"ip_id": ADDRESS, "token_id": 4}, @@ -121,7 +126,7 @@ def mock_get_royalty_vault_address_by_ip_id(ip_asset): def _mock(royalty_vault=ADDRESS): return patch.object( ip_asset, - "get_royalty_vault_address_by_ip_id", + "_get_royalty_vault_address_by_ip_id", return_value=royalty_vault, ) @@ -138,6 +143,131 @@ def _mock(owner=ACCOUNT_ADDRESS): return _mock +@pytest.fixture(scope="class") +def mock_transform_request_dependencies(): + """ + In order to coverage edge cases, we need to mock all dependencies of transform_request. + Mock all dependencies of transform_request for detailed testing. + + This fixture mocks all the internal dependencies of transform_request: + - IPAssetRegistryClient (for IP ID and registration check) + - Sign utility (for signatures) + - CoreMetadataModuleClient (for metadata operations) + - LicensingModuleClient (for licensing operations) + - RoyaltyTokenDistributionWorkflowsClient (for transaction building) + - RoyaltyModuleClient (for validate_license_terms_data) + - ModuleRegistryClient (for validate_license_terms_data) + + This allows testing the actual transform_request logic while mocking only + the external dependencies. The validate_license_terms_data function will + use the real implementation since its dependencies are mocked. + + Usage: + with mock_transform_request_dependencies( + is_registered=False, + ip_id=IP_ID, + deadline=1000, + signature=b"signature", + license_terms_data=LICENSE_TERMS_DATA, + ): + """ + + def _mock( + is_registered: bool = False, + ip_id: str = IP_ID, + deadline: int = 1000, + signature: bytes = b"signature", + ): + # Mock IPAssetRegistryClient + mock_ip_registry_client = MagicMock() + mock_ip_registry_client.ipId = MagicMock(return_value=ip_id) + mock_ip_registry_client.isRegistered = MagicMock(return_value=is_registered) + + # Mock Sign utility + mock_sign_util = MagicMock() + mock_sign_util.get_deadline = MagicMock(return_value=deadline) + mock_sign_util.get_permission_signature = MagicMock( + return_value={"signature": signature} + ) + + # Mock CoreMetadataModuleClient + mock_core_metadata_contract = MagicMock() + mock_core_metadata_contract.address = ADDRESS + mock_core_metadata_client = MagicMock() + mock_core_metadata_client.contract = mock_core_metadata_contract + + # Mock LicensingModuleClient + mock_licensing_contract = MagicMock() + mock_licensing_contract.address = ADDRESS + mock_licensing_client = MagicMock() + mock_licensing_client.contract = mock_licensing_contract + + # Mock RoyaltyTokenDistributionWorkflowsClient + mock_royalty_workflows_contract = MagicMock() + mock_royalty_workflows_contract.address = ADDRESS + mock_royalty_workflows_contract.encode_abi = MagicMock( + return_value=b"encoded_data" + ) + mock_royalty_workflows_client = MagicMock() + mock_royalty_workflows_client.contract = mock_royalty_workflows_contract + + # Mock RoyaltyModuleClient (for validate_license_terms_data) + mock_royalty_module_client = MagicMock() + mock_royalty_module_client.isWhitelistedRoyaltyPolicy = MagicMock( + return_value=True + ) + mock_royalty_module_client.isWhitelistedRoyaltyToken = MagicMock( + return_value=True + ) + + # Create patches + patches = [ + patch( + "story_protocol_python_sdk.utils.registration.transform_registration_request.IPAssetRegistryClient", + return_value=mock_ip_registry_client, + ), + patch( + "story_protocol_python_sdk.utils.registration.transform_registration_request.Sign", + return_value=mock_sign_util, + ), + patch( + "story_protocol_python_sdk.utils.registration.transform_registration_request.CoreMetadataModuleClient", + return_value=mock_core_metadata_client, + ), + patch( + "story_protocol_python_sdk.utils.registration.transform_registration_request.LicensingModuleClient", + return_value=mock_licensing_client, + ), + patch( + "story_protocol_python_sdk.utils.registration.transform_registration_request.RoyaltyTokenDistributionWorkflowsClient", + return_value=mock_royalty_workflows_client, + ), + patch( + "story_protocol_python_sdk.utils.registration.transform_registration_request.RoyaltyModuleClient", + return_value=mock_royalty_module_client, + ), + patch( + "story_protocol_python_sdk.utils.registration.transform_registration_request.get_function_signature", + return_value="", + ), + ] + + # Return context manager that applies all patches + class MockContext: + def __enter__(self): + for p in patches: + p.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + for p in reversed(patches): + p.stop() + + return MockContext() + + return _mock + + class TestIPAssetRegister: def test_register_invalid_deadline_type( self, ip_asset, mock_get_ip_id, mock_is_registered @@ -200,11 +330,10 @@ def test_register_with_metadata( @pytest.fixture(scope="class") -def mock_is_whitelisted_royalty_policy(ip_asset): +def mock_is_whitelisted_royalty_policy(): def _mock(is_whitelisted: bool = True): - return patch.object( - ip_asset.royalty_module_client, - "isWhitelistedRoyaltyPolicy", + return patch( + "story_protocol_python_sdk.abi.RoyaltyModule.RoyaltyModule_client.RoyaltyModuleClient.isWhitelistedRoyaltyPolicy", return_value=is_whitelisted, ) @@ -212,11 +341,10 @@ def _mock(is_whitelisted: bool = True): @pytest.fixture(scope="class") -def mock_is_whitelisted_royalty_token(ip_asset): +def mock_is_whitelisted_royalty_token(): def _mock(is_whitelisted: bool = True): - return patch.object( - ip_asset.royalty_module_client, - "isWhitelistedRoyaltyToken", + return patch( + "story_protocol_python_sdk.abi.RoyaltyModule.RoyaltyModule_client.RoyaltyModuleClient.isWhitelistedRoyaltyToken", return_value=is_whitelisted, ) @@ -241,8 +369,8 @@ def test_ip_is_already_registered( }, ) - def test_parent_ip_id_is_empty(self, ip_asset, mock_get_ip_id, mock_is_registered): - with mock_get_ip_id(), mock_is_registered(): + def test_parent_ip_id_is_empty(self, ip_asset, mock_transform_request_dependencies): + with mock_transform_request_dependencies(): with pytest.raises(ValueError, match="The parent IP IDs must be provided."): ip_asset.register_derivative_ip( nft_contract=ADDRESS, @@ -256,16 +384,14 @@ def test_parent_ip_id_is_empty(self, ip_asset, mock_get_ip_id, mock_is_registere def test_success( self, ip_asset, - mock_get_ip_id, - mock_is_registered, + mock_transform_request_dependencies, mock_parse_ip_registered_event, mock_signature_related_methods, mock_get_function_signature, mock_license_registry_client, ): with ( - mock_get_ip_id(), - mock_is_registered(), + mock_transform_request_dependencies(), mock_parse_ip_registered_event(), mock_get_function_signature(), mock_license_registry_client(), @@ -310,23 +436,39 @@ def test_mint_failed_transaction(self, ip_asset): class TestRegisterIpAndAttachPilTerms: + def test_throw_error_when_license_terms_data_is_not_provided( + self, ip_asset: IPAsset, mock_transform_request_dependencies + ): + with mock_transform_request_dependencies(): + with pytest.raises( + ValueError, match="License terms data must be provided." + ): + ip_asset.register_ip_and_attach_pil_terms( + nft_contract=ADDRESS, + token_id=3, + license_terms_data=[], + ) + def test_token_id_is_already_registered( - self, ip_asset, mock_get_ip_id, mock_is_registered + self, ip_asset, mock_transform_request_dependencies, mock_is_registered ): - with mock_get_ip_id(), mock_is_registered(True): + with ( + mock_transform_request_dependencies(is_registered=True), + mock_is_registered(True), + ): with pytest.raises( ValueError, match="The NFT with id 3 is already registered as IP." ): ip_asset.register_ip_and_attach_pil_terms( nft_contract=ADDRESS, token_id=3, - license_terms_data=[], + license_terms_data=LICENSE_TERMS_DATA, ) def test_royalty_policy_commercial_rev_share_is_less_than_0( - self, ip_asset: IPAsset, mock_get_ip_id, mock_is_registered + self, ip_asset: IPAsset, mock_transform_request_dependencies ): - with mock_get_ip_id(), mock_is_registered(): + with mock_transform_request_dependencies(): with pytest.raises( PILFlavorError, match="commercial_rev_share must be between 0 and 100." ): @@ -347,18 +489,14 @@ def test_royalty_policy_commercial_rev_share_is_less_than_0( def test_transaction_to_be_called_with_correct_parameters( self, ip_asset: IPAsset, - mock_get_ip_id, - mock_is_registered, + mock_transform_request_dependencies, mock_parse_ip_registered_event, mock_parse_tx_license_terms_attached_event, - mock_signature_related_methods, ): with ( - mock_get_ip_id(), - mock_is_registered(), + mock_transform_request_dependencies(), mock_parse_ip_registered_event(), mock_parse_tx_license_terms_attached_event(), - mock_signature_related_methods(), ): with patch.object( ip_asset.license_attachment_workflows_client, @@ -415,41 +553,50 @@ def test_transaction_to_be_called_with_correct_parameters( def test_success( self, ip_asset: IPAsset, - mock_get_ip_id, - mock_is_registered, + mock_transform_request_dependencies, mock_parse_ip_registered_event, - mock_signature_related_methods, mock_parse_tx_license_terms_attached_event, ): with ( - mock_get_ip_id(), - mock_is_registered(), + mock_transform_request_dependencies(), mock_parse_ip_registered_event(), mock_parse_tx_license_terms_attached_event(), - mock_signature_related_methods(), ): - result = ip_asset.register_ip_and_attach_pil_terms( - nft_contract=ADDRESS, - token_id=3, - license_terms_data=[ - { - "terms": LICENSE_TERMS, - "licensing_config": LICENSING_CONFIG, - } - ], - ip_metadata={ - "ip_metadata_uri": "https://example.com/metadata/custom-value.json", - "ip_metadata_hash": "ip_metadata_hash", - "nft_metadata_uri": "https://example.com/metadata/custom-value.json", - "nft_metadata_hash": "nft_metadata_hash", - }, - ) - assert result == { - "tx_hash": TX_HASH.hex(), - "ip_id": IP_ID, - "license_terms_ids": [1, 2], - "token_id": 3, - } + with patch.object( + ip_asset.license_attachment_workflows_client, + "build_registerIpAndAttachPILTerms_transaction", + ) as mock_build_registerIpAndAttachPILTerms_transaction: + result = ip_asset.register_ip_and_attach_pil_terms( + nft_contract=ADDRESS, + token_id=3, + license_terms_data=[ + { + "terms": LICENSE_TERMS, + "licensing_config": LICENSING_CONFIG, + } + ], + ip_metadata={ + "ip_metadata_uri": "https://example.com/metadata/custom-value.json", + "ip_metadata_hash": "ip_metadata_hash", + "nft_metadata_uri": "https://example.com/metadata/custom-value.json", + "nft_metadata_hash": "nft_metadata_hash", + }, + ) + call_args = ( + mock_build_registerIpAndAttachPILTerms_transaction.call_args[0] + ) + assert call_args[2] == { + "ipMetadataURI": "https://example.com/metadata/custom-value.json", + "ipMetadataHash": "ip_metadata_hash", + "nftMetadataURI": "https://example.com/metadata/custom-value.json", + "nftMetadataHash": "nft_metadata_hash", + } + assert result == { + "tx_hash": TX_HASH.hex(), + "ip_id": IP_ID, + "license_terms_ids": [1, 2], + "token_id": 3, + } class TestRegisterDerivative: @@ -555,8 +702,13 @@ def test_success_and_expect_value_when_default_values_not_provided( ip_asset: IPAsset, mock_license_registry_client, mock_parse_ip_registered_event, + mock_transform_request_dependencies, ): - with mock_parse_ip_registered_event(), mock_license_registry_client(): + with ( + mock_transform_request_dependencies(), + mock_parse_ip_registered_event(), + mock_license_registry_client(), + ): with patch.object( ip_asset.derivative_workflows_client, "build_mintAndRegisterIpAndMakeDerivative_transaction", @@ -596,10 +748,15 @@ def test_success_and_expect_value_when_default_values_not_provided( def test_with_custom_value( self, ip_asset: IPAsset, + mock_transform_request_dependencies, mock_license_registry_client, mock_parse_ip_registered_event, ): - with mock_parse_ip_registered_event(), mock_license_registry_client(): + with ( + mock_transform_request_dependencies(), + mock_parse_ip_registered_event(), + mock_license_registry_client(), + ): with patch.object( ip_asset.derivative_workflows_client, "build_mintAndRegisterIpAndMakeDerivative_transaction", @@ -616,10 +773,8 @@ def test_with_custom_value( license_template=ADDRESS, ), ip_metadata=IPMetadataInput( - ip_metadata_uri="https://example.com/metadata/custom-value.json", ip_metadata_hash=HexStr("ip_metadata_hash"), nft_metadata_uri="https://example.com/metadata/custom-value.json", - nft_metadata_hash=HexStr("nft_metadata_hash"), ), recipient=ADDRESS, allow_duplicates=False, @@ -638,10 +793,10 @@ def test_with_custom_value( "licenseTemplate": ADDRESS, } assert mock_build_transaction.call_args[0][2] == { - "ipMetadataURI": "https://example.com/metadata/custom-value.json", - "ipMetadataHash": "ip_metadata_hash", + "ipMetadataURI": "", + "ipMetadataHash": HexStr("ip_metadata_hash"), "nftMetadataURI": "https://example.com/metadata/custom-value.json", - "nftMetadataHash": "nft_metadata_hash", + "nftMetadataHash": ZERO_HASH, } assert mock_build_transaction.call_args[0][3] == ADDRESS # recipient assert not mock_build_transaction.call_args[0][4] # allowDuplicates @@ -1102,22 +1257,26 @@ def test_throw_error_when_royalty_shares_empty(self, ip_asset: IPAsset): royalty_shares=[], ) - def test_throw_error_when_deriv_data_is_invalid(self, ip_asset: IPAsset): - with pytest.raises(ValueError, match="The parent IP IDs must be provided."): - ip_asset.mint_and_register_ip_and_make_derivative_and_distribute_royalty_tokens( - spg_nft_contract=ADDRESS, - deriv_data=DerivativeDataInput( - parent_ip_ids=[], - license_terms_ids=[1], - ), - royalty_shares=[ - RoyaltyShareInput(recipient=ACCOUNT_ADDRESS, percentage=50.0) - ], - ) + def test_throw_error_when_deriv_data_is_invalid( + self, ip_asset: IPAsset, mock_transform_request_dependencies + ): + with mock_transform_request_dependencies(): + with pytest.raises(ValueError, match="The parent IP IDs must be provided."): + ip_asset.mint_and_register_ip_and_make_derivative_and_distribute_royalty_tokens( + spg_nft_contract=ADDRESS, + deriv_data=DerivativeDataInput( + parent_ip_ids=[], + license_terms_ids=[1], + ), + royalty_shares=[ + RoyaltyShareInput(recipient=ACCOUNT_ADDRESS, percentage=50.0) + ], + ) def test_success_with_default_values( self, ip_asset: IPAsset, + mock_transform_request_dependencies, mock_license_registry_client, mock_parse_ip_registered_event, mock_get_royalty_vault_address_by_ip_id, @@ -1128,6 +1287,7 @@ def test_success_with_default_values( ] with ( + mock_transform_request_dependencies(), mock_parse_ip_registered_event(), mock_license_registry_client(), mock_get_royalty_vault_address_by_ip_id(), @@ -1160,6 +1320,7 @@ def test_success_with_default_values( def test_royalty_vault_address( self, ip_asset: IPAsset, + mock_transform_request_dependencies, mock_license_registry_client, mock_parse_ip_registered_event, ): @@ -1169,6 +1330,7 @@ def test_royalty_vault_address( ] with ( + mock_transform_request_dependencies(), mock_parse_ip_registered_event(), mock_license_registry_client(), ): @@ -1209,6 +1371,7 @@ def test_royalty_vault_address( def test_success_with_custom_values( self, ip_asset: IPAsset, + mock_transform_request_dependencies, mock_license_registry_client, mock_parse_ip_registered_event, mock_get_royalty_vault_address_by_ip_id, @@ -1223,6 +1386,7 @@ def test_success_with_custom_values( nft_metadata_hash="0xabcdef1234567890", ) with ( + mock_transform_request_dependencies(), mock_parse_ip_registered_event(), mock_license_registry_client(), mock_get_royalty_vault_address_by_ip_id(), @@ -1266,10 +1430,15 @@ def test_success_with_custom_values( def test_throw_error_when_transaction_failed( self, ip_asset: IPAsset, + mock_transform_request_dependencies, mock_license_registry_client, mock_parse_ip_registered_event, ): - with mock_parse_ip_registered_event(), mock_license_registry_client(): + with ( + mock_transform_request_dependencies(), + mock_parse_ip_registered_event(), + mock_license_registry_client(), + ): with patch.object( ip_asset.royalty_token_distribution_workflows_client, "build_mintAndRegisterIpAndMakeDerivativeAndDistributeRoyaltyTokens_transaction", @@ -1302,9 +1471,24 @@ def test_throw_error_when_royalty_shares_empty(self, ip_asset: IPAsset): royalty_shares=[], ) + def test_throw_error_when_license_terms_data_is_empty(self, ip_asset: IPAsset): + + with pytest.raises( + ValueError, + match="Failed to mint, register IP, attach PIL terms and distribute royalty tokens: License terms data must be provided.", + ): + ip_asset.mint_and_register_ip_and_attach_pil_terms_and_distribute_royalty_tokens( + spg_nft_contract=ADDRESS, + license_terms_data=[], + royalty_shares=[ + RoyaltyShareInput(recipient=ACCOUNT_ADDRESS, percentage=50.0) + ], + ) + def test_success_with_default_values( self, ip_asset: IPAsset, + mock_transform_request_dependencies, mock_license_registry_client, mock_parse_ip_registered_event, mock_parse_tx_license_terms_attached_event, @@ -1316,6 +1500,7 @@ def test_success_with_default_values( ] with ( + mock_transform_request_dependencies(), mock_parse_ip_registered_event(), mock_parse_tx_license_terms_attached_event(), mock_license_registry_client(), @@ -1460,13 +1645,27 @@ def test_token_id_is_already_registered( ], ) + def test_throw_error_when_license_terms_data_is_empty( + self, ip_asset: IPAsset, mock_transform_request_dependencies + ): + with (mock_transform_request_dependencies(),): + with pytest.raises( + ValueError, + match="Failed to register IP, attach PIL terms and distribute royalty tokens: License terms data must be provided.", + ): + ip_asset.register_ip_and_attach_pil_terms_and_distribute_royalty_tokens( + nft_contract=ADDRESS, + token_id=3, + license_terms_data=[], + royalty_shares=[ + RoyaltyShareInput(recipient=ACCOUNT_ADDRESS, percentage=50.0) + ], + ) + def test_throw_error_when_royalty_shares_empty( - self, ip_asset: IPAsset, mock_get_ip_id, mock_is_registered + self, ip_asset: IPAsset, mock_transform_request_dependencies ): - with ( - mock_get_ip_id(), - mock_is_registered(), - ): + with (mock_transform_request_dependencies(),): with pytest.raises( ValueError, match="Failed to register IP, attach PIL terms and distribute royalty tokens: Royalty shares must be provided.", @@ -1481,13 +1680,10 @@ def test_throw_error_when_royalty_shares_empty( def test_success_with_default_values( self, ip_asset: IPAsset, - mock_get_ip_id, - mock_is_registered, + mock_transform_request_dependencies, mock_parse_ip_registered_event, mock_parse_tx_license_terms_attached_event, - mock_signature_related_methods, mock_get_royalty_vault_address_by_ip_id, - mock_ip_account_impl_client, ): royalty_shares = [ RoyaltyShareInput(recipient=ACCOUNT_ADDRESS, percentage=50.0), @@ -1495,13 +1691,10 @@ def test_success_with_default_values( ] with ( - mock_get_ip_id(), - mock_is_registered(), + mock_transform_request_dependencies(), mock_parse_ip_registered_event(), mock_parse_tx_license_terms_attached_event(), - mock_signature_related_methods(), mock_get_royalty_vault_address_by_ip_id(), - mock_ip_account_impl_client(), ): with ( patch.object( @@ -1546,26 +1739,20 @@ def test_success_with_default_values( def test_success_with_custom_values( self, ip_asset: IPAsset, - mock_get_ip_id, - mock_is_registered, + mock_transform_request_dependencies, mock_parse_ip_registered_event, mock_parse_tx_license_terms_attached_event, - mock_signature_related_methods, mock_get_royalty_vault_address_by_ip_id, - mock_ip_account_impl_client, ): royalty_shares = [ RoyaltyShareInput(recipient=ACCOUNT_ADDRESS, percentage=60.0), ] with ( - mock_get_ip_id(), - mock_is_registered(), + mock_transform_request_dependencies(), mock_parse_ip_registered_event(), mock_parse_tx_license_terms_attached_event(), - mock_signature_related_methods(), mock_get_royalty_vault_address_by_ip_id(), - mock_ip_account_impl_client(), ): with ( patch.object( @@ -1607,15 +1794,9 @@ def test_success_with_custom_values( def test_throw_error_when_transaction_failed( self, ip_asset: IPAsset, - mock_get_ip_id, - mock_is_registered, - mock_signature_related_methods, + mock_transform_request_dependencies, ): - with ( - mock_get_ip_id(), - mock_is_registered(), - mock_signature_related_methods(), - ): + with (mock_transform_request_dependencies(),): with patch.object( ip_asset.royalty_token_distribution_workflows_client, "build_registerIpAndAttachPILTermsAndDeployRoyaltyVault_transaction", @@ -1639,12 +1820,9 @@ def test_throw_error_when_transaction_failed( class TestRegisterDerivativeIpAndAttachPilTermsAndDistributeRoyaltyTokens: def test_token_id_is_already_registered( - self, ip_asset: IPAsset, mock_get_ip_id, mock_is_registered + self, ip_asset: IPAsset, mock_transform_request_dependencies ): - with ( - mock_get_ip_id(), - mock_is_registered(True), - ): + with (mock_transform_request_dependencies(is_registered=True),): with pytest.raises( ValueError, match="Failed to register derivative IP and distribute royalty tokens: The NFT with id 3 is already registered as IP.", @@ -1664,13 +1842,11 @@ def test_token_id_is_already_registered( def test_throw_error_when_royalty_shares_empty( self, ip_asset: IPAsset, - mock_get_ip_id, - mock_is_registered, + mock_transform_request_dependencies, mock_license_registry_client, ): with ( - mock_get_ip_id(), - mock_is_registered(), + mock_transform_request_dependencies(), mock_license_registry_client(), ): with pytest.raises( @@ -1690,12 +1866,9 @@ def test_throw_error_when_royalty_shares_empty( def test_success_with_default_values( self, ip_asset: IPAsset, - mock_get_ip_id, - mock_is_registered, + mock_transform_request_dependencies, mock_parse_ip_registered_event, - mock_signature_related_methods, mock_get_royalty_vault_address_by_ip_id, - mock_ip_account_impl_client, mock_license_registry_client, ): royalty_shares = [ @@ -1704,12 +1877,9 @@ def test_success_with_default_values( ] with ( - mock_get_ip_id(), - mock_is_registered(), + mock_transform_request_dependencies(), mock_parse_ip_registered_event(), - mock_signature_related_methods(), mock_get_royalty_vault_address_by_ip_id(), - mock_ip_account_impl_client(), mock_license_registry_client(), ): with ( @@ -1758,12 +1928,9 @@ def test_success_with_default_values( def test_success_with_custom_values( self, ip_asset: IPAsset, - mock_get_ip_id, - mock_is_registered, + mock_transform_request_dependencies, mock_parse_ip_registered_event, - mock_signature_related_methods, mock_get_royalty_vault_address_by_ip_id, - mock_ip_account_impl_client, mock_license_registry_client, ): royalty_shares = [ @@ -1771,12 +1938,9 @@ def test_success_with_custom_values( ] with ( - mock_get_ip_id(), - mock_is_registered(), + mock_transform_request_dependencies(), mock_parse_ip_registered_event(), - mock_signature_related_methods(), mock_get_royalty_vault_address_by_ip_id(), - mock_ip_account_impl_client(), mock_license_registry_client(), ): with ( @@ -1831,15 +1995,11 @@ def test_success_with_custom_values( def test_throw_error_when_transaction_failed( self, ip_asset: IPAsset, - mock_get_ip_id, - mock_is_registered, - mock_signature_related_methods, + mock_transform_request_dependencies, mock_license_registry_client, ): with ( - mock_get_ip_id(), - mock_is_registered(), - mock_signature_related_methods(), + mock_transform_request_dependencies(), mock_license_registry_client(), ): with patch.object( @@ -1868,13 +2028,12 @@ def test_throw_error_when_transaction_failed( def test_success_with_tx_options( self, ip_asset: IPAsset, - mock_get_ip_id, - mock_is_registered, + mock_transform_request_dependencies, mock_parse_ip_registered_event, - mock_signature_related_methods, mock_get_royalty_vault_address_by_ip_id, - mock_ip_account_impl_client, mock_license_registry_client, + mock_ip_account_impl_client, + mock_signature_related_methods, ): royalty_shares = [ RoyaltyShareInput(recipient=ACCOUNT_ADDRESS, percentage=60.0), @@ -1886,13 +2045,12 @@ def test_success_with_tx_options( "chainId": 1, } with ( - mock_get_ip_id(), - mock_is_registered(), + mock_transform_request_dependencies(), mock_parse_ip_registered_event(), - mock_signature_related_methods(), mock_get_royalty_vault_address_by_ip_id(), - mock_ip_account_impl_client(), mock_license_registry_client(), + mock_ip_account_impl_client(), + mock_signature_related_methods(), ): with patch( "story_protocol_python_sdk.resources.IPAsset.build_and_send_transaction" @@ -2138,11 +2296,8 @@ def test_throw_not_provided_license_terms_data_when_royalty_shares_provided_for_ def test_success_when_license_terms_data_and_royalty_shares_provided_for_minted_nft( self, ip_asset: IPAsset, - mock_is_registered, - mock_get_ip_id, - mock_signature_related_methods, + mock_transform_request_dependencies, mock_parse_ip_registered_event, - mock_ip_account_impl_client, mock_parse_tx_license_terms_attached_event, mock_get_royalty_vault_address_by_ip_id, ): @@ -2152,11 +2307,8 @@ def test_success_when_license_terms_data_and_royalty_shares_provided_for_minted_ royalty_shares_obj = get_royalty_shares(royalty_shares) royalty_vault = HexStr("0x" + "a" * 64) with ( - mock_is_registered(is_registered=False), - mock_get_ip_id(), - mock_signature_related_methods(), + mock_transform_request_dependencies(), mock_parse_ip_registered_event(), - mock_ip_account_impl_client(), mock_parse_tx_license_terms_attached_event(), mock_get_royalty_vault_address_by_ip_id(royalty_vault), patch.object( @@ -2599,16 +2751,12 @@ def test_success_when_license_terms_data_provided_for_minted_nft( ip_asset: IPAsset, mock_parse_ip_registered_event, mock_parse_tx_license_terms_attached_event, - mock_get_ip_id, - mock_signature_related_methods, - mock_is_registered, + mock_transform_request_dependencies, ): with ( mock_parse_ip_registered_event(), mock_parse_tx_license_terms_attached_event(), - mock_get_ip_id(), - mock_signature_related_methods(), - mock_is_registered(is_registered=False), + mock_transform_request_dependencies(), patch.object( ip_asset.license_attachment_workflows_client, "build_registerIpAndAttachPILTerms_transaction", @@ -2638,16 +2786,12 @@ def test_success_when_license_terms_data_is_commercial_use_for_minted_nft( ip_asset: IPAsset, mock_parse_ip_registered_event, mock_parse_tx_license_terms_attached_event, - mock_get_ip_id, - mock_signature_related_methods, - mock_is_registered, + mock_transform_request_dependencies, ): with ( mock_parse_ip_registered_event(), mock_parse_tx_license_terms_attached_event(), - mock_get_ip_id(), - mock_signature_related_methods(), - mock_is_registered(is_registered=False), + mock_transform_request_dependencies(), patch.object( ip_asset.license_attachment_workflows_client, "build_registerIpAndAttachPILTerms_transaction", @@ -2890,21 +3034,19 @@ def test_throw_error_when_deriv_data_is_not_provided_and_royalty_shares_are_prov def test_success_when_deriv_data_and_royalty_shares_are_provided_for_minted_nft( self, ip_asset: IPAsset, + mock_transform_request_dependencies, + mock_license_registry_client, mock_parse_ip_registered_event, - mock_get_ip_id, - mock_signature_related_methods, - mock_is_registered, mock_get_royalty_vault_address_by_ip_id, - mock_license_registry_client, + mock_signature_related_methods, mock_ip_account_impl_client, ): with ( - mock_get_ip_id(), - mock_is_registered(is_registered=False), + mock_transform_request_dependencies(), + mock_license_registry_client(), mock_parse_ip_registered_event(), - mock_signature_related_methods(), mock_get_royalty_vault_address_by_ip_id(), - mock_license_registry_client(), + mock_signature_related_methods(), mock_ip_account_impl_client(), patch.object( ip_asset.royalty_token_distribution_workflows_client, @@ -3066,16 +3208,14 @@ def test_throw_error_when_license_token_ids_and_deriv_data_are_not_provided_for_ def test_success_when_deriv_data_only_are_provided_for_minted_nft( self, ip_asset: IPAsset, + mock_transform_request_dependencies, mock_parse_ip_registered_event, - mock_get_ip_id, mock_license_registry_client, mock_signature_related_methods, - mock_is_registered, mock_get_function_signature, ): with ( - mock_get_ip_id(), - mock_is_registered(is_registered=False), + mock_transform_request_dependencies(), mock_parse_ip_registered_event(), mock_license_registry_client(), mock_signature_related_methods(), @@ -3544,3 +3684,508 @@ def test_throw_error_when_license_token_ids_are_not_owned_by_caller( match="Failed to link derivative: Failed to register derivative with license tokens: License token id 1 must be owned by the caller.", ): ip_asset.link_derivative(license_token_ids=[1], child_ip_id=IP_ID) + + +class TestBatchIpAssetWithOptimizedWorkflows: + """Test batch_ip_asset_with_optimized_workflows method.""" + + @pytest.fixture + def mock_transform_request(self): + """Mock transform_request function.""" + + def _mock(): + return patch( + "story_protocol_python_sdk.resources.IPAsset.transform_request" + ) + + return _mock + + @pytest.fixture + def mock_send_transactions(self): + """Mock send_transactions function.""" + + def _mock(): + return patch( + "story_protocol_python_sdk.resources.IPAsset.send_transactions" + ) + + return _mock + + @pytest.fixture + def mock_prepare_distribute_royalty_tokens_requests(self): + """Mock prepare_distribute_royalty_tokens_requests function.""" + + def _mock(): + return patch( + "story_protocol_python_sdk.resources.IPAsset.prepare_distribute_royalty_tokens_requests" + ) + + return _mock + + def test_batch_mint_and_register_with_license_terms( + self, + ip_asset: IPAsset, + mock_transform_request, + mock_send_transactions, + mock_prepare_distribute_royalty_tokens_requests, + ): + """Test batch registration with MintAndRegisterRequest and license terms.""" + requests = [ + MintAndRegisterRequest( + spg_nft_contract=ADDRESS, + license_terms_data=LICENSE_TERMS_DATA, + ip_metadata=IP_METADATA, + recipient=ACCOUNT_ADDRESS, + allow_duplicates=True, + ), + MintAndRegisterRequest( + spg_nft_contract=ADDRESS, + license_terms_data=LICENSE_TERMS_DATA, + ), + ] + + with ( + mock_transform_request() as mock_transform, + mock_send_transactions() as mock_send, + mock_prepare_distribute_royalty_tokens_requests() as mock_prepare, + patch.object( + ip_asset, + "_parse_tx_ip_registered_event", + return_value=[ + {"ipId": IP_ID, "tokenId": 1}, + {"ipId": ADDRESS, "tokenId": 2}, + ], + ), + patch.object( + ip_asset, "_parse_all_ip_royalty_vault_deployed_events", return_value=[] + ), + patch.object( + ip_asset.pi_license_template_client, + "getLicenseTermsId", + side_effect=[1, 2], + ), + ): + # Mock transform_request to return TransformedRegistrationRequest + mock_transformed_1 = TransformedRegistrationRequest( + encoded_tx_data=b"encoded_data_1", + is_use_multicall3=True, + workflow_address=ADDRESS, + validated_request=[], + workflow_multicall_reference=MagicMock(), + extra_data=None, + ) + mock_transformed_2 = TransformedRegistrationRequest( + encoded_tx_data=b"encoded_data_2", + is_use_multicall3=True, + workflow_address=ADDRESS, + validated_request=[], + workflow_multicall_reference=MagicMock(), + extra_data=None, + ) + mock_transform.side_effect = [mock_transformed_1, mock_transformed_2] + + license_terms_dict = asdict(LICENSE_TERMS_DATA[0]) + mock_send.side_effect = [ + ( + [{"tx_hash": TX_HASH.hex(), "tx_receipt": {"logs": []}}], + { + ADDRESS: { + "license_terms_data": [ + [license_terms_dict], + [license_terms_dict], + ] + } + }, + ), + ([], {}), # No distribute royalty tokens requests + ] + + # Mock prepare_distribute_royalty_tokens_requests + mock_prepare.return_value = ([], []) + + result = ip_asset.batch_ip_asset_with_optimized_workflows( + requests=requests, is_use_multicall=True + ) + + # Verify transform_request was called for each request + assert mock_transform.call_count == 2 + + assert mock_send.call_count == 1 + + # Verify response structure + assert isinstance(result, dict) + assert "registration_results" in result + assert "distribute_royalty_tokens_tx_hashes" in result + assert len(result["registration_results"]) == 1 + assert result["registration_results"][0]["tx_hash"] == TX_HASH.hex() + assert len(result["distribute_royalty_tokens_tx_hashes"]) == 0 + + assert len(result["registration_results"][0]["registered_ips"]) == 2 + assert result["registration_results"][0]["registered_ips"][0][ + "license_terms_ids" + ] == [1] + # Second IP gets license_terms_ids [2] from the second LICENSE_TERMS_DATA + assert result["registration_results"][0]["registered_ips"][1][ + "license_terms_ids" + ] == [2] + + def test_batch_register_with_royalty_shares( + self, + ip_asset: IPAsset, + mock_transform_request, + mock_send_transactions, + mock_prepare_distribute_royalty_tokens_requests, + mock_parse_ip_registered_event, + ): + """Test batch registration with RegisterRegistrationRequest and royalty shares.""" + royalty_shares = [ + RoyaltyShareInput(recipient=ACCOUNT_ADDRESS, percentage=50.0), + RoyaltyShareInput(recipient=ADDRESS, percentage=30.0), + ] + + requests = [ + RegisterRegistrationRequest( + nft_contract=ADDRESS, + token_id=1, + license_terms_data=LICENSE_TERMS_DATA, + royalty_shares=royalty_shares, + ip_metadata=IP_METADATA, + deadline=1000, + ), + ] + + with ( + mock_transform_request() as mock_transform, + mock_send_transactions() as mock_send, + mock_prepare_distribute_royalty_tokens_requests() as mock_prepare, + mock_parse_ip_registered_event(), + ): + # Mock transform_request with extra_data containing royalty_shares + mock_transformed = TransformedRegistrationRequest( + encoded_tx_data=b"encoded_data", + is_use_multicall3=False, + workflow_address=ADDRESS, + validated_request=[], + workflow_multicall_reference=MagicMock(), + extra_data={"royalty_shares": royalty_shares}, + ) + mock_transform.return_value = mock_transformed + + # Mock send_transactions + mock_send.side_effect = [ + ( + [{"tx_hash": TX_HASH.hex(), "tx_receipt": {"logs": []}}], + {ADDRESS: {"license_terms_data": [LICENSE_TERMS_DATA]}}, + ), + ( + [{"tx_hash": "0xDistributeTxHash", "tx_receipt": {"logs": []}}], + {}, + ), + ] + + # Mock prepare_distribute_royalty_tokens_requests + mock_distribute_request = TransformedRegistrationRequest( + encoded_tx_data=b"distribute_data", + is_use_multicall3=False, + workflow_address=ADDRESS, + validated_request=[], + workflow_multicall_reference=MagicMock(), + extra_data=None, + ) + mock_prepare.return_value = ([mock_distribute_request], []) + + result = ip_asset.batch_ip_asset_with_optimized_workflows( + requests=requests, is_use_multicall=True + ) + + # Verify response + assert len(result["registration_results"]) == 1 + assert len(result["distribute_royalty_tokens_tx_hashes"]) == 1 + assert ( + result["distribute_royalty_tokens_tx_hashes"][0] == "0xDistributeTxHash" + ) + + def test_batch_mixed_requests( + self, + ip_asset: IPAsset, + mock_transform_request, + mock_send_transactions, + mock_prepare_distribute_royalty_tokens_requests, + mock_parse_ip_registered_event, + ): + """Test batch registration with mixed MintAndRegisterRequest and RegisterRegistrationRequest.""" + requests: list[IpRegistrationWorkflowRequest] = [ + MintAndRegisterRequest( + spg_nft_contract=ADDRESS, + deriv_data=DerivativeDataInput( + parent_ip_ids=[IP_ID], + license_terms_ids=[1], + ), + ), + RegisterRegistrationRequest( + nft_contract=ADDRESS, + token_id=2, + license_terms_data=LICENSE_TERMS_DATA, + ), + ] + + with ( + mock_transform_request() as mock_transform, + mock_send_transactions() as mock_send, + mock_prepare_distribute_royalty_tokens_requests() as mock_prepare, + mock_parse_ip_registered_event(), + ): + mock_transformed_1 = TransformedRegistrationRequest( + encoded_tx_data=b"encoded_data_1", + is_use_multicall3=True, + workflow_address=ADDRESS, + validated_request=[], + workflow_multicall_reference=MagicMock(), + extra_data=None, + ) + mock_transformed_2 = TransformedRegistrationRequest( + encoded_tx_data=b"encoded_data_2", + is_use_multicall3=True, + workflow_address=ADDRESS, + validated_request=[], + workflow_multicall_reference=MagicMock(), + extra_data=None, + ) + mock_transform.side_effect = [mock_transformed_1, mock_transformed_2] + + mock_send.side_effect = [ + ( + [ + {"tx_hash": TX_HASH.hex(), "tx_receipt": {"logs": []}}, + {"tx_hash": "0xTxHash2", "tx_receipt": {"logs": []}}, + ], + { + ADDRESS: { + "license_terms_data": [[], LICENSE_TERMS_DATA], + } + }, + ), + ([], {}), + ] + + mock_prepare.return_value = ([], []) + + result = ip_asset.batch_ip_asset_with_optimized_workflows( + requests=requests, is_use_multicall=True + ) + + # Verify multiple registrations + assert len(result["registration_results"]) == 2 + assert result["registration_results"][0]["tx_hash"] == TX_HASH.hex() + assert result["registration_results"][1]["tx_hash"] == "0xTxHash2" + + def test_batch_with_multicall_disabled( + self, + ip_asset: IPAsset, + mock_transform_request, + mock_send_transactions, + mock_prepare_distribute_royalty_tokens_requests, + mock_parse_ip_registered_event, + ): + """Test batch registration with is_use_multicall=False.""" + requests: list[IpRegistrationWorkflowRequest] = [ + MintAndRegisterRequest( + spg_nft_contract=ADDRESS, + license_terms_data=LICENSE_TERMS_DATA, + ), + RegisterRegistrationRequest( + nft_contract=ADDRESS, + token_id=1, + license_terms_data=LICENSE_TERMS_DATA, + ), + ] + + with ( + mock_transform_request() as mock_transform, + mock_send_transactions() as mock_send, + mock_prepare_distribute_royalty_tokens_requests() as mock_prepare, + mock_parse_ip_registered_event(), + ): + mock_transformed_1 = TransformedRegistrationRequest( + encoded_tx_data=b"encoded_data_1", + is_use_multicall3=True, + workflow_address=ADDRESS, + validated_request=[], + workflow_multicall_reference=MagicMock(), + extra_data=None, + ) + mock_transformed_2 = TransformedRegistrationRequest( + encoded_tx_data=b"encoded_data_2", + is_use_multicall3=True, + workflow_address=ADDRESS, + validated_request=[], + workflow_multicall_reference=MagicMock(), + extra_data=None, + ) + mock_transform.side_effect = [mock_transformed_1, mock_transformed_2] + + mock_send.side_effect = [ + ( + [{"tx_hash": TX_HASH.hex(), "tx_receipt": {"logs": []}}], + {ADDRESS: {"license_terms_data": [LICENSE_TERMS_DATA]}}, + ), + ([], {}), + ] + + mock_prepare.return_value = ([], []) + + result = ip_asset.batch_ip_asset_with_optimized_workflows( + requests=requests, is_use_multicall=False + ) + + # Verify is_use_multicall3 was passed correctly + assert mock_send.call_args_list[0][1]["is_use_multicall3"] is False + # Verify result + assert len(result["registration_results"]) == 1 + + def test_batch_with_royalty_shares_and_license_terms( + self, + ip_asset: IPAsset, + mock_transform_request, + mock_send_transactions, + mock_prepare_distribute_royalty_tokens_requests, + mock_parse_ip_registered_event, + ): + """Test batch registration with both royalty shares and license terms.""" + royalty_shares = [ + RoyaltyShareInput(recipient=ACCOUNT_ADDRESS, percentage=60.0), + ] + + requests = [ + MintAndRegisterRequest( + spg_nft_contract=ADDRESS, + license_terms_data=LICENSE_TERMS_DATA, + royalty_shares=royalty_shares, + ip_metadata=IP_METADATA, + ), + ] + + with ( + mock_transform_request() as mock_transform, + mock_send_transactions() as mock_send, + mock_prepare_distribute_royalty_tokens_requests() as mock_prepare, + mock_parse_ip_registered_event(), + ): + mock_transformed = TransformedRegistrationRequest( + encoded_tx_data=b"encoded_data", + is_use_multicall3=False, + workflow_address=ADDRESS, + validated_request=[], + workflow_multicall_reference=MagicMock(), + extra_data={"royalty_shares": royalty_shares}, + ) + mock_transform.return_value = mock_transformed + + mock_send.side_effect = [ + ( + [{"tx_hash": TX_HASH.hex(), "tx_receipt": {"logs": []}}], + {ADDRESS: {"license_terms_data": [LICENSE_TERMS_DATA]}}, + ), + ( + [{"tx_hash": "0xDistributeTxHash", "tx_receipt": {"logs": []}}], + {}, + ), + ] + + mock_distribute_request = TransformedRegistrationRequest( + encoded_tx_data=b"distribute_data", + is_use_multicall3=False, + workflow_address=ADDRESS, + validated_request=[], + workflow_multicall_reference=MagicMock(), + extra_data=None, + ) + mock_prepare.return_value = ( + [mock_distribute_request], + [{"ip_id": IP_ID, "royalty_vault": ADDRESS}], + ) + + result = ip_asset.batch_ip_asset_with_optimized_workflows( + requests=requests, is_use_multicall=True + ) + + # Verify royalty distribution was handled + assert len(result["distribute_royalty_tokens_tx_hashes"]) == 1 + assert len(result["registration_results"][0]["ip_royalty_vaults"]) == 1 + assert ( + result["registration_results"][0]["ip_royalty_vaults"][0][ + "royalty_vault" + ] + == ADDRESS + ) + + def test_batch_empty_requests( + self, + ip_asset: IPAsset, + mock_transform_request, + mock_send_transactions, + mock_prepare_distribute_royalty_tokens_requests, + ): + """Test batch registration with empty requests list.""" + requests: list[MintAndRegisterRequest | RegisterRegistrationRequest] = [] + + with ( + mock_transform_request() as mock_transform, + mock_send_transactions() as mock_send, + mock_prepare_distribute_royalty_tokens_requests() as mock_prepare, + ): + mock_send.side_effect = [ + ([], {}), + ([], {}), + ] + + mock_prepare.return_value = ([], []) + + result = ip_asset.batch_ip_asset_with_optimized_workflows( + requests=requests, is_use_multicall=True + ) + + # Verify no transform was called + mock_transform.assert_not_called() + # Verify empty response + assert len(result["registration_results"]) == 0 + assert len(result["distribute_royalty_tokens_tx_hashes"]) == 0 + + def test_batch_transaction_failure( + self, + ip_asset: IPAsset, + mock_transform_request, + mock_send_transactions, + ): + """Test batch registration when transaction fails.""" + requests = [ + MintAndRegisterRequest( + spg_nft_contract=ADDRESS, + license_terms_data=LICENSE_TERMS_DATA, + ), + ] + + with ( + mock_transform_request() as mock_transform, + mock_send_transactions() as mock_send, + ): + mock_transformed = TransformedRegistrationRequest( + encoded_tx_data=b"encoded_data", + is_use_multicall3=True, + workflow_address=ADDRESS, + validated_request=[], + workflow_multicall_reference=MagicMock(), + extra_data=None, + ) + mock_transform.return_value = mock_transformed + + # Mock send_transactions to raise an error + mock_send.side_effect = ValueError("Transaction failed") + + with pytest.raises( + ValueError, + match="Failed to batch register IP assets with optimized workflows: Transaction failed", + ): + ip_asset.batch_ip_asset_with_optimized_workflows( + requests=requests, is_use_multicall=True + ) diff --git a/tests/unit/utils/test_registration_utils.py b/tests/unit/utils/test_registration_utils.py new file mode 100644 index 00000000..fa818fab --- /dev/null +++ b/tests/unit/utils/test_registration_utils.py @@ -0,0 +1,834 @@ +from unittest.mock import MagicMock, patch + +import pytest +from ens.ens import HexStr + +from story_protocol_python_sdk import RoyaltyShareInput +from story_protocol_python_sdk.types.resource.IPAsset import IPRoyaltyVault +from story_protocol_python_sdk.types.utils import ( + ExtraData, + TransformedRegistrationRequest, +) +from story_protocol_python_sdk.utils.registration.registration_utils import ( + aggregate_multicall_requests, + prepare_distribute_royalty_tokens_requests, + send_transactions, +) +from tests.unit.fixtures.data import ADDRESS, LICENSE_TERMS_DATA_CAMEL_CASE + + +@pytest.fixture +def mock_multicall3_client(): + """Mock Multicall3Client.""" + + def _mock(): + return patch( + "story_protocol_python_sdk.utils.registration.registration_utils.Multicall3Client", + return_value=MagicMock( + contract=MagicMock( + address="multicall3", + ), + ), + ) + + return _mock + + +class TestAggregateMulticallRequests: + def test_aggregates_single_request(self, mock_web3, mock_multicall3_client): + """Test aggregating a single request.""" + with mock_multicall3_client(): + encoded_data = b"encoded_data_1" + contract_call = MagicMock(return_value=HexStr("0x123")) + request = TransformedRegistrationRequest( + encoded_tx_data=encoded_data, + is_use_multicall3=False, + workflow_address=ADDRESS, + validated_request=[], + workflow_multicall_reference=contract_call, + ) + result = aggregate_multicall_requests( + requests=[request], + is_use_multicall3=False, + web3=mock_web3, + ) + assert len(result) == 1 + assert ADDRESS in result + aggregated_request_data = result[ADDRESS] + assert aggregated_request_data["call_data"] == [encoded_data] + assert aggregated_request_data["license_terms_data"] == [[]] + assert aggregated_request_data["method_reference"] == contract_call + + def test_aggregates_multiple_requests_same_address( + self, mock_web3, mock_multicall3_client + ): + """Test aggregating multiple requests to the same address.""" + with mock_multicall3_client(): + encoded_data_1 = b"encoded_data_1" + encoded_data_2 = b"encoded_data_2" + contract_call = MagicMock(return_value=HexStr("0x111")) + + request_1 = TransformedRegistrationRequest( + encoded_tx_data=encoded_data_1, + is_use_multicall3=False, + workflow_address=ADDRESS, + validated_request=[], + workflow_multicall_reference=contract_call, + ) + request_2 = TransformedRegistrationRequest( + encoded_tx_data=encoded_data_2, + is_use_multicall3=False, + workflow_address=ADDRESS, + validated_request=[], + workflow_multicall_reference=contract_call, + extra_data=ExtraData( + license_terms_data=[LICENSE_TERMS_DATA_CAMEL_CASE], + ), + ) + + result = aggregate_multicall_requests( + requests=[request_1, request_2], + is_use_multicall3=False, + web3=mock_web3, + ) + + assert len(result) == 1 + assert ADDRESS in result + aggregated_request_data = result[ADDRESS] + assert aggregated_request_data["call_data"] == [ + encoded_data_1, + encoded_data_2, + ] + assert aggregated_request_data["license_terms_data"] == [ + [], + [LICENSE_TERMS_DATA_CAMEL_CASE], + ] + assert aggregated_request_data["method_reference"] == contract_call + + def test_aggregates_multiple_requests_different_addresses( + self, mock_web3, mock_multicall3_client + ): + """Test aggregating multiple requests to different addresses.""" + with mock_multicall3_client(): + workflow_address_1 = ADDRESS + workflow_address_2 = "0xDifferentAddress" + encoded_data_1 = b"encoded_data_1" + encoded_data_2 = b"encoded_data_2" + encoded_data_3 = b"encoded_data_3" + contract_call_1 = MagicMock(return_value=HexStr("0x111")) + contract_call_2 = MagicMock(return_value=HexStr("0x222")) + royalty_shares = [RoyaltyShareInput(recipient=ADDRESS, percentage=50)] + + request_1 = TransformedRegistrationRequest( + encoded_tx_data=encoded_data_1, + is_use_multicall3=False, + workflow_address=workflow_address_1, + validated_request=[], + workflow_multicall_reference=contract_call_1, + ) + request_2 = TransformedRegistrationRequest( + encoded_tx_data=encoded_data_2, + is_use_multicall3=False, + workflow_address=workflow_address_2, + validated_request=[], + workflow_multicall_reference=contract_call_2, + extra_data=ExtraData( + license_terms_data=[LICENSE_TERMS_DATA_CAMEL_CASE], + ), + ) + request_3 = TransformedRegistrationRequest( + encoded_tx_data=encoded_data_3, + is_use_multicall3=False, + workflow_address=workflow_address_2, + validated_request=[], + workflow_multicall_reference=contract_call_2, + extra_data=ExtraData( + royalty_shares=royalty_shares, + ), + ) + + result = aggregate_multicall_requests( + requests=[request_1, request_2, request_3], + is_use_multicall3=False, + web3=mock_web3, + ) + + assert len(result) == 2 + assert workflow_address_1 in result + assert workflow_address_2 in result + + aggregated_request_data = result[workflow_address_1] + assert aggregated_request_data["call_data"] == [encoded_data_1] + assert aggregated_request_data["license_terms_data"] == [[]] + assert aggregated_request_data["method_reference"] == contract_call_1 + + aggregated_request_data = result[workflow_address_2] + assert aggregated_request_data["call_data"] == [ + encoded_data_2, + encoded_data_3, + ] + assert aggregated_request_data["license_terms_data"] == [ + [LICENSE_TERMS_DATA_CAMEL_CASE], + [], + ] + assert aggregated_request_data["method_reference"] == contract_call_2 + + def test_uses_multicall3_address_when_enabled( + self, mock_web3, mock_multicall3_client + ): + """Test using multicall3 address when is_use_multicall3 is True.""" + with mock_multicall3_client() as mock_patch: + multicall3_instance = mock_patch.return_value + encoded_data_1 = b"encoded_data1" + encoded_data_2 = b"encoded_data2" + contract_call1 = MagicMock(return_value=HexStr("0x111")) + contract_call2 = MagicMock(return_value=HexStr("0x222")) + + request1 = TransformedRegistrationRequest( + encoded_tx_data=encoded_data_1, + is_use_multicall3=True, + workflow_address="workflow1", + validated_request=[], + workflow_multicall_reference=contract_call1, + ) + request2 = TransformedRegistrationRequest( + encoded_tx_data=encoded_data_2, + is_use_multicall3=True, + workflow_address="workflow2", + validated_request=[], + workflow_multicall_reference=contract_call2, + extra_data=ExtraData( + license_terms_data=[LICENSE_TERMS_DATA_CAMEL_CASE], + ), + ) + + result = aggregate_multicall_requests( + requests=[request1, request2], + is_use_multicall3=True, + web3=mock_web3, + ) + + assert len(result) == 1 + assert "multicall3" in result + assert "workflow1" not in result + assert "workflow2" not in result + + aggregated_request_data = result["multicall3"] + # When using multicall3, call_data should be Multicall3Call structure + expected_call_data = [ + { + "target": "workflow1", + "allowFailure": False, + "value": 0, + "callData": encoded_data_1, + }, + { + "target": "workflow2", + "allowFailure": False, + "value": 0, + "callData": encoded_data_2, + }, + ] + assert aggregated_request_data["call_data"] == expected_call_data + assert aggregated_request_data["license_terms_data"] == [ + [], + [LICENSE_TERMS_DATA_CAMEL_CASE], + ] + # Method reference should be multicall3's method + assert ( + aggregated_request_data["method_reference"] + == multicall3_instance.build_aggregate3_transaction + ) + + def test_uses_workflow_address_when_multicall3_disabled( + self, mock_web3, mock_multicall3_client + ): + """Test using workflow address when is_use_multicall3 is False.""" + with mock_multicall3_client() as mock_patch: + multicall3_instance = mock_patch.return_value + multicall3_address = multicall3_instance.contract.address + + encoded_data_1 = b"encoded_data1" + encoded_data_2 = b"encoded_data2" + contract_call1 = MagicMock(return_value=HexStr("0x111")) + contract_call2 = MagicMock(return_value=HexStr("0x222")) + + request1 = TransformedRegistrationRequest( + encoded_tx_data=encoded_data_1, + is_use_multicall3=True, # Request wants to use multicall3 + workflow_address="workflow1", + validated_request=[], + workflow_multicall_reference=contract_call1, + extra_data=ExtraData( + license_terms_data=[LICENSE_TERMS_DATA_CAMEL_CASE], + ), + ) + request2 = TransformedRegistrationRequest( + encoded_tx_data=encoded_data_2, + is_use_multicall3=True, + workflow_address="workflow2", + validated_request=[], + workflow_multicall_reference=contract_call2, + ) + result = aggregate_multicall_requests( + requests=[request1, request2], + is_use_multicall3=False, + web3=mock_web3, + ) + + assert len(result) == 2 + assert "workflow1" in result + assert "workflow2" in result + assert multicall3_address not in result + + aggregated_request_data = result["workflow1"] + assert aggregated_request_data["call_data"] == [encoded_data_1] + assert aggregated_request_data["license_terms_data"] == [ + [LICENSE_TERMS_DATA_CAMEL_CASE] + ] + assert aggregated_request_data["method_reference"] == contract_call1 + + aggregated_request_data = result["workflow2"] + assert aggregated_request_data["call_data"] == [encoded_data_2] + assert aggregated_request_data["license_terms_data"] == [[]] + assert aggregated_request_data["method_reference"] == contract_call2 + + def test_aggregates_mixed_requests_with_multicall3( + self, mock_web3, mock_multicall3_client + ): + """Test aggregating mixed requests where some use multicall3 and some don't.""" + with mock_multicall3_client() as mock_patch: + multicall3_instance = mock_patch.return_value + multicall3_address = multicall3_instance.contract.address + + workflow_address_1 = ADDRESS + workflow_address_2 = "0xDifferentWorkflow" + + encoded_data_1 = b"encoded_data_1" + encoded_data_2 = b"encoded_data_2" + encoded_data_3 = b"encoded_data_3" + contract_call_1 = MagicMock(return_value=HexStr("0x111")) + contract_call_2 = MagicMock(return_value=HexStr("0x222")) + contract_call_3 = MagicMock(return_value=HexStr("0x333")) + + # Request 1: uses multicall3 + request_1 = TransformedRegistrationRequest( + encoded_tx_data=encoded_data_1, + is_use_multicall3=True, + workflow_address=workflow_address_1, + validated_request=[], + workflow_multicall_reference=contract_call_1, + extra_data=ExtraData( + license_terms_data=[LICENSE_TERMS_DATA_CAMEL_CASE], + ), + ) + # Request 2: uses multicall3 + request_2 = TransformedRegistrationRequest( + encoded_tx_data=encoded_data_2, + is_use_multicall3=True, + workflow_address=workflow_address_2, + validated_request=[], + workflow_multicall_reference=contract_call_2, + ) + # Request 3: doesn't use multicall3 + request_3 = TransformedRegistrationRequest( + encoded_tx_data=encoded_data_3, + is_use_multicall3=False, + workflow_address=workflow_address_1, + validated_request=[], + workflow_multicall_reference=contract_call_3, + ) + + result = aggregate_multicall_requests( + requests=[request_1, request_2, request_3], + is_use_multicall3=True, + web3=mock_web3, + ) + + # Request 1 and 2 should be aggregated to multicall3_address + # Request 3 should use its workflow_address + assert len(result) == 2 + assert multicall3_address in result + assert workflow_address_1 in result + + # Check multicall3 aggregation (request_1 and request_2) + multicall_data = result[multicall3_address]["call_data"] + assert len(multicall_data) == 2 + # Check that multicall structures are correct + assert multicall_data == [ + { + "target": workflow_address_1, + "allowFailure": False, + "value": 0, + "callData": encoded_data_1, + }, + { + "target": workflow_address_2, + "allowFailure": False, + "value": 0, + "callData": encoded_data_2, + }, + ] + assert result[multicall3_address]["license_terms_data"] == [ + [LICENSE_TERMS_DATA_CAMEL_CASE], + [], + ] + assert ( + result[multicall3_address]["method_reference"] + == multicall3_instance.build_aggregate3_transaction + ) + + # Check workflow address (request_3) + aggregated_request_data = result[workflow_address_1] + assert aggregated_request_data["call_data"] == [encoded_data_3] + assert aggregated_request_data["license_terms_data"] == [[]] + assert aggregated_request_data["method_reference"] == contract_call_3 + + def test_aggregates_empty_requests_list(self, mock_web3, mock_multicall3_client): + """Test aggregating an empty list of requests.""" + with mock_multicall3_client(): + result = aggregate_multicall_requests( + requests=[], + is_use_multicall3=False, + web3=mock_web3, + ) + + assert len(result) == 0 + assert isinstance(result, dict) + + +@pytest.fixture +def mock_transform_distribute_royalty_tokens_request(): + """Mock dependencies needed by transform_distribute_royalty_tokens_request.""" + + def _mock(): + return patch( + "story_protocol_python_sdk.utils.registration.registration_utils.transform_distribute_royalty_tokens_request", + return_value=TransformedRegistrationRequest( + encoded_tx_data=b"mock_encoded_data", + is_use_multicall3=False, + workflow_address=ADDRESS, + validated_request=[], + workflow_multicall_reference=MagicMock(), + ), + ) + + return _mock + + +class TestPrepareDistributeRoyaltyTokensRequests: + def test_returns_empty_lists_when_extra_data_list_is_empty( + self, mock_web3, mock_account + ): + """Test that empty lists are returned when extra_data_list is empty.""" + result = prepare_distribute_royalty_tokens_requests( + extra_data_list=[], + web3=mock_web3, + ip_registered=[], + royalty_vault=[], + account=mock_account, + chain_id=1, + ) + + transformed_requests, matching_vaults = result + assert transformed_requests == [] + assert matching_vaults == [] + + def test_filters_and_matches_ip_and_vault_data( + self, mock_web3, mock_account, mock_transform_distribute_royalty_tokens_request + ): + """Test successful filtering and matching of IP and vault data.""" + with mock_transform_distribute_royalty_tokens_request(): + nft_contract = "0xNFTContract" + token_id = 123 + ip_id = "ip_id" + ip_registered = [ + { + "tokenContract": nft_contract, + "tokenId": token_id, + "ipId": ip_id, + } + ] + ip_royalty_vault = "ip_royalty_vault" + royalty_vault = [ + { + "ipId": ip_id, + "ipRoyaltyVault": ip_royalty_vault, + } + ] + result = prepare_distribute_royalty_tokens_requests( + extra_data_list=[ + ExtraData( + nft_contract=nft_contract, + token_id=token_id, + deadline=1000, + royalty_total_amount=5000, + royalty_shares=[ + RoyaltyShareInput(recipient="0xRecipient", percentage=50) + ], + ) + ], + web3=mock_web3, + ip_registered=ip_registered, + royalty_vault=royalty_vault, + account=mock_account, + chain_id=1, + ) + transformed_requests, matching_vaults = result + assert len(transformed_requests) == 1 + assert len(matching_vaults) == 1 + assert matching_vaults == [ + IPRoyaltyVault(ip_id=ip_id, royalty_vault=ip_royalty_vault) + ] + + def test_skips_when_no_matching_ip_registered( + self, mock_web3, mock_account, mock_transform_distribute_royalty_tokens_request + ): + """Test that items are skipped when no matching IP is registered.""" + with mock_transform_distribute_royalty_tokens_request() as mock_transform: + nft_contract = "0xNonExistentContract" + token_id = 999 + ip_registered = [ + { + "tokenContract": "0xDifferentContract", + "tokenId": 123, + "ipId": "0xIPID", + } + ] + royalty_vault = [ + { + "ipId": "0xIPID", + "ipRoyaltyVault": "0xRoyaltyVault", + } + ] + result = prepare_distribute_royalty_tokens_requests( + extra_data_list=[ + ExtraData( + nft_contract=nft_contract, + token_id=token_id, + deadline=1000, + royalty_total_amount=5000, + royalty_shares=[ + RoyaltyShareInput(recipient="0xRecipient", percentage=50) + ], + ) + ], + web3=mock_web3, + ip_registered=ip_registered, + royalty_vault=royalty_vault, + account=mock_account, + chain_id=1, + ) + transformed_requests, matching_vaults = result + assert transformed_requests == [] + assert matching_vaults == [] + # Verify transform was not called since no IP matched + mock_transform.assert_not_called() + + def test_skips_when_no_matching_vault( + self, mock_web3, mock_account, mock_transform_distribute_royalty_tokens_request + ): + """Test that items are skipped when no matching vault is found.""" + with mock_transform_distribute_royalty_tokens_request() as mock_transform: + nft_contract = "0xNFTContract" + token_id = 123 + ip_id = "0xIPID" + ip_registered = [ + { + "tokenContract": nft_contract, + "tokenId": token_id, + "ipId": ip_id, + } + ] + # Vault for different IP ID + royalty_vault = [ + { + "ipId": "0xDifferentIPID", + "ipRoyaltyVault": "0xRoyaltyVault", + } + ] + result = prepare_distribute_royalty_tokens_requests( + extra_data_list=[ + ExtraData( + nft_contract=nft_contract, + token_id=token_id, + deadline=1000, + royalty_total_amount=5000, + royalty_shares=[ + RoyaltyShareInput(recipient="0xRecipient", percentage=50) + ], + ) + ], + web3=mock_web3, + ip_registered=ip_registered, + royalty_vault=royalty_vault, + account=mock_account, + chain_id=1, + ) + transformed_requests, matching_vaults = result + assert transformed_requests == [] + assert matching_vaults == [] + # Verify transform was not called since no vault matched + mock_transform.assert_not_called() + + def test_processes_multiple_extra_data_items( + self, mock_web3, mock_account, mock_transform_distribute_royalty_tokens_request + ): + """Test processing multiple extra_data items with mixed matching results.""" + with mock_transform_distribute_royalty_tokens_request() as mock_transform: + # Test data - 3 items: 2 should match, 1 should not + ip_registered = [ + {"tokenContract": "0xContract1", "tokenId": 1, "ipId": "0xIPID1"}, + {"tokenContract": "0xContract2", "tokenId": 2, "ipId": "0xIPID2"}, + {"tokenContract": "0xContract3", "tokenId": 3, "ipId": "0xIPID3"}, + ] + # Only vaults for first two items + royalty_vault = [ + {"ipId": "0xIPID1", "ipRoyaltyVault": "0xVault1"}, + {"ipId": "0xIPID2", "ipRoyaltyVault": "0xVault2"}, + # No vault for 0xIPID3 + ] + result = prepare_distribute_royalty_tokens_requests( + extra_data_list=[ + # Item 1: Should match + ExtraData( + nft_contract="0xContract1", + token_id=1, + deadline=1000, + royalty_total_amount=5000, + royalty_shares=[ + RoyaltyShareInput(recipient="0xRecipient1", percentage=30) + ], + ), + # Item 2: Should match + ExtraData( + nft_contract="0xContract2", + token_id=2, + deadline=2000, + royalty_total_amount=6000, + royalty_shares=[ + RoyaltyShareInput(recipient="0xRecipient2", percentage=40) + ], + ), + # Item 3: Should not match (no vault) + ExtraData( + nft_contract="0xContract3", + token_id=3, + deadline=3000, + royalty_total_amount=7000, + royalty_shares=[ + RoyaltyShareInput(recipient="0xRecipient3", percentage=50) + ], + ), + ], + web3=mock_web3, + ip_registered=ip_registered, + royalty_vault=royalty_vault, + account=mock_account, + chain_id=1, + ) + transformed_requests, matching_vaults = result + # Should have 2 results (first two items matched) + assert len(transformed_requests) == 2 + assert len(matching_vaults) == 2 + assert matching_vaults == [ + IPRoyaltyVault(ip_id="0xIPID1", royalty_vault="0xVault1"), + IPRoyaltyVault(ip_id="0xIPID2", royalty_vault="0xVault2"), + ] + # Verify transform was called twice (once for each matched item) + assert mock_transform.call_count == 2 + + +@pytest.fixture +def mock_build_and_send_transaction(): + """Mock build_and_send_transaction function.""" + + def _mock(): + return patch( + "story_protocol_python_sdk.utils.registration.registration_utils.build_and_send_transaction", + ) + + return _mock + + +class TestSendTransactions: + def test_sends_single_transaction( + self, + mock_web3, + mock_account, + mock_multicall3_client, + mock_build_and_send_transaction, + ): + """Test sending a single transaction.""" + with mock_multicall3_client(), mock_build_and_send_transaction() as mock_build: + # Setup test data + encoded_data = b"encoded_data" + method_reference = MagicMock() + workflow_address = ADDRESS + + transformed_request = TransformedRegistrationRequest( + encoded_tx_data=encoded_data, + is_use_multicall3=False, + workflow_address=workflow_address, + validated_request=[], + workflow_multicall_reference=method_reference, + ) + + # Mock build_and_send_transaction return value + tx_hash = "0xTxHash" + tx_receipt = {"status": 1} + mock_build.return_value = { + "tx_hash": tx_hash, + "tx_receipt": tx_receipt, + } + + result = send_transactions( + transformed_requests=[transformed_request], + is_use_multicall3=False, + web3=mock_web3, + account=mock_account, + ) + + tx_results, aggregated_requests = result + + # Verify results + assert len(tx_results) == 1 + assert tx_results[0]["tx_hash"] == tx_hash + assert tx_results[0]["tx_receipt"] == tx_receipt + + # Verify aggregated_requests structure (from real aggregate function) + assert len(aggregated_requests) == 1 + assert workflow_address in aggregated_requests + assert aggregated_requests[workflow_address]["call_data"] == [encoded_data] + assert ( + aggregated_requests[workflow_address]["method_reference"] + == method_reference + ) + + # Verify build_and_send_transaction was called correctly + mock_build.assert_called_once_with( + mock_web3, + mock_account, + method_reference, + [encoded_data], + tx_options=None, + ) + + def test_sends_multiple_transactions_to_different_addresses( + self, + mock_web3, + mock_account, + mock_multicall3_client, + mock_build_and_send_transaction, + ): + """Test sending multiple transactions to different addresses.""" + with mock_multicall3_client(), mock_build_and_send_transaction() as mock_build: + # Setup test data + workflow_address_1 = ADDRESS + workflow_address_2 = "0xWorkflowAddress2" + method_reference_1 = MagicMock() + method_reference_2 = MagicMock() + + transformed_request_1 = TransformedRegistrationRequest( + encoded_tx_data=b"data1", + is_use_multicall3=True, + workflow_address=workflow_address_1, + validated_request=[], + workflow_multicall_reference=method_reference_1, + ) + transformed_request_2 = TransformedRegistrationRequest( + encoded_tx_data=b"data2", + is_use_multicall3=False, + workflow_address=workflow_address_2, + validated_request=[], + workflow_multicall_reference=method_reference_2, + ) + transformed_request_3 = TransformedRegistrationRequest( + encoded_tx_data=b"data3", + is_use_multicall3=True, + workflow_address=workflow_address_1, + validated_request=[], + workflow_multicall_reference=method_reference_1, + extra_data=ExtraData( + license_terms_data=[LICENSE_TERMS_DATA_CAMEL_CASE], + ), + ) + + # Mock build_and_send_transaction return values + mock_build.side_effect = [ + {"tx_hash": "0xHash1", "tx_receipt": {"status": 1}}, + {"tx_hash": "0xHash2", "tx_receipt": {"status": 1}}, + ] + + result = send_transactions( + transformed_requests=[ + transformed_request_1, + transformed_request_2, + transformed_request_3, + ], + is_use_multicall3=True, + web3=mock_web3, + account=mock_account, + ) + + tx_results, aggregated_requests = result + + # Verify results + assert len(tx_results) == 2 + assert tx_results[0]["tx_hash"] == "0xHash1" + assert tx_results[1]["tx_hash"] == "0xHash2" + + # Verify aggregated_requests structure (from real aggregate function) + assert len(aggregated_requests) == 2 + assert "multicall3" in aggregated_requests + assert aggregated_requests["multicall3"]["call_data"] == [ + { + "target": workflow_address_1, + "allowFailure": False, + "value": 0, + "callData": b"data1", + }, + { + "target": workflow_address_1, + "allowFailure": False, + "value": 0, + "callData": b"data3", + }, + ] + assert aggregated_requests["multicall3"]["license_terms_data"] == [ + [], + [LICENSE_TERMS_DATA_CAMEL_CASE], + ] + assert workflow_address_2 in aggregated_requests + workflow_address_2_data = aggregated_requests[workflow_address_2] + assert workflow_address_2_data["call_data"] == [b"data2"] + assert workflow_address_2_data["method_reference"] == method_reference_2 + assert workflow_address_2_data["license_terms_data"] == [[]] + + # Verify build_and_send_transaction was called twice + assert mock_build.call_count == 2 + + def test_sends_empty_requests_list( + self, + mock_web3, + mock_account, + mock_multicall3_client, + mock_build_and_send_transaction, + ): + """Test sending empty requests list.""" + with mock_multicall3_client(), mock_build_and_send_transaction() as mock_build: + result = send_transactions( + transformed_requests=[], + is_use_multicall3=False, + web3=mock_web3, + account=mock_account, + ) + + tx_results, aggregated_requests = result + + # Verify results + assert tx_results == [] + assert aggregated_requests == {} + + # Verify build_and_send_transaction was not called + mock_build.assert_not_called() diff --git a/tests/unit/utils/test_transform_registration_request.py b/tests/unit/utils/test_transform_registration_request.py new file mode 100644 index 00000000..01505ec8 --- /dev/null +++ b/tests/unit/utils/test_transform_registration_request.py @@ -0,0 +1,1060 @@ +from dataclasses import asdict, replace +from unittest.mock import MagicMock, patch + +import pytest +from typing_extensions import cast +from web3 import Account + +from story_protocol_python_sdk import ( + DerivativeDataInput, + MintAndRegisterRequest, + RegisterRegistrationRequest, + RoyaltyShareInput, +) +from story_protocol_python_sdk.abi.DerivativeWorkflows.DerivativeWorkflows_client import ( + DerivativeWorkflowsClient, +) +from story_protocol_python_sdk.abi.LicenseAttachmentWorkflows.LicenseAttachmentWorkflows_client import ( + LicenseAttachmentWorkflowsClient, +) +from story_protocol_python_sdk.abi.RoyaltyTokenDistributionWorkflows.RoyaltyTokenDistributionWorkflows_client import ( + RoyaltyTokenDistributionWorkflowsClient, +) +from story_protocol_python_sdk.utils.ip_metadata import IPMetadata +from story_protocol_python_sdk.utils.registration.transform_registration_request import ( + get_public_minting, + transform_distribute_royalty_tokens_request, + transform_request, + validate_license_terms_data, +) +from tests.unit.fixtures.data import ( + ADDRESS, + CHAIN_ID, + IP_ID, + IP_METADATA, + LICENSE_TERMS_DATA, + LICENSE_TERMS_DATA_CAMEL_CASE, +) + + +@pytest.fixture +def mock_get_public_minting(): + """Mock get_public_minting function.""" + + def _mock(public_minting: bool = True): + return patch( + "story_protocol_python_sdk.utils.registration.transform_registration_request.SPGNFTImplClient", + return_value=MagicMock( + publicMinting=MagicMock(return_value=public_minting) + ), + ) + + return _mock + + +@pytest.fixture +def mock_royalty_module_client(): + """Mock RoyaltyModuleClient for validate_license_terms_data.""" + + def _mock(is_whitelisted_policy: bool = True, is_whitelisted_token: bool = True): + return patch( + "story_protocol_python_sdk.utils.registration.transform_registration_request.RoyaltyModuleClient", + return_value=MagicMock( + isWhitelistedRoyaltyPolicy=MagicMock( + return_value=is_whitelisted_policy + ), + isWhitelistedRoyaltyToken=MagicMock(return_value=is_whitelisted_token), + ), + ) + + return _mock + + +@pytest.fixture +def mock_module_registry_client(): + """Mock ModuleRegistryClient for validate_license_terms_data.""" + return patch( + "story_protocol_python_sdk.utils.registration.transform_registration_request.ModuleRegistryClient", + return_value=MagicMock(), + ) + + +@pytest.fixture +def mock_pi_license_template_client(): + """Mock PILicenseTemplateClient for DerivativeData.""" + + def _mock(): + mock_instance = MagicMock() + mock_instance.contract = MagicMock() + mock_instance.contract.address = ADDRESS + return patch( + "story_protocol_python_sdk.utils.derivative_data.PILicenseTemplateClient", + return_value=mock_instance, + ) + + return _mock + + +@pytest.fixture +def mock_derivative_ip_asset_registry_client(): + """Mock IPAssetRegistryClient for DerivativeData.""" + + def _mock(is_registered: bool = True): + return patch( + "story_protocol_python_sdk.utils.derivative_data.IPAssetRegistryClient", + return_value=MagicMock(isRegistered=MagicMock(return_value=is_registered)), + ) + + return _mock + + +@pytest.fixture +def mock_license_registry_client(): + """Mock LicenseRegistryClient for DerivativeData.""" + + def _mock(has_attached_license_terms: bool = True, royalty_percent: int = 1000000): + return patch( + "story_protocol_python_sdk.utils.derivative_data.LicenseRegistryClient", + return_value=MagicMock( + hasIpAttachedLicenseTerms=MagicMock( + return_value=has_attached_license_terms + ), + getRoyaltyPercent=MagicMock(return_value=royalty_percent), + ), + ) + + return _mock + + +@pytest.fixture +def mock_workflow_clients(mock_web3): + """Mock workflow clients (RoyaltyTokenDistributionWorkflowsClient, LicenseAttachmentWorkflowsClient, DerivativeWorkflowsClient). + + Returns real client instances so encode_abi can produce real encoding results. + """ + + def _mock(): + + # Create real client instances with mock_web3 + royalty_token_distribution_client = RoyaltyTokenDistributionWorkflowsClient( + mock_web3 + ) + license_attachment_client = LicenseAttachmentWorkflowsClient(mock_web3) + derivative_workflows_client = DerivativeWorkflowsClient(mock_web3) + royalty_token_distribution_client.contract.address = ( + "royalty_token_distribution_client_address" + ) + license_attachment_client.contract.address = "license_attachment_client_address" + derivative_workflows_client.contract.address = ( + "derivative_workflows_client_address" + ) + return { + "royalty_token_distribution_patch": patch( + "story_protocol_python_sdk.utils.registration.transform_registration_request.RoyaltyTokenDistributionWorkflowsClient", + return_value=royalty_token_distribution_client, + ), + "license_attachment_patch": patch( + "story_protocol_python_sdk.utils.registration.transform_registration_request.LicenseAttachmentWorkflowsClient", + return_value=license_attachment_client, + ), + "derivative_workflows_patch": patch( + "story_protocol_python_sdk.utils.registration.transform_registration_request.DerivativeWorkflowsClient", + return_value=derivative_workflows_client, + ), + "royalty_token_distribution_client": royalty_token_distribution_client, + "license_attachment_client": license_attachment_client, + "derivative_workflows_client": derivative_workflows_client, + "get_function_signature": patch( + "story_protocol_python_sdk.utils.registration.transform_registration_request.get_function_signature", + return_value="", + ), + "royalty_token_distribution_workflows_client": patch( + "story_protocol_python_sdk.utils.registration.transform_registration_request.RoyaltyTokenDistributionWorkflowsClient", + return_value=royalty_token_distribution_client, + ), + } + + return _mock + + +@pytest.fixture +def mock_ip_asset_registry_client(): + """Mock IPAssetRegistryClient.""" + + def _mock(is_registered: bool = False, ip_id: str = IP_ID): + mock_client = MagicMock() + mock_client.ipId = MagicMock(return_value=ip_id) + mock_client.isRegistered = MagicMock(return_value=is_registered) + return patch( + "story_protocol_python_sdk.utils.registration.transform_registration_request.IPAssetRegistryClient", + return_value=mock_client, + ) + + return _mock + + +@pytest.fixture +def mock_sign_util(): + """Mock Sign utility.""" + + def _mock(deadline: int = 1000, signature: bytes = b"signature"): + mock_sign = MagicMock() + mock_sign.get_deadline = MagicMock(return_value=deadline) + mock_sign.get_permission_signature = MagicMock( + return_value={"signature": signature} + ) + mock_sign.get_signature = MagicMock(return_value={"signature": signature}) + return patch( + "story_protocol_python_sdk.utils.registration.transform_registration_request.Sign", + return_value=mock_sign, + ) + + return _mock + + +@pytest.fixture +def mock_module_clients(): + """Mock CoreMetadataModuleClient and LicensingModuleClient.""" + + def _mock(): + mock_core_metadata_contract = MagicMock() + mock_core_metadata_contract.address = ADDRESS + mock_core_metadata_client = MagicMock() + mock_core_metadata_client.contract = mock_core_metadata_contract + + mock_licensing_contract = MagicMock() + mock_licensing_contract.address = ADDRESS + mock_licensing_client = MagicMock() + mock_licensing_client.contract = mock_licensing_contract + + return { + "core_metadata_module_patch": patch( + "story_protocol_python_sdk.utils.registration.transform_registration_request.CoreMetadataModuleClient", + return_value=mock_core_metadata_client, + ), + "licensing_module_patch": patch( + "story_protocol_python_sdk.utils.registration.transform_registration_request.LicensingModuleClient", + return_value=mock_licensing_client, + ), + } + + return _mock + + +@pytest.fixture +def mock_spg_nft_client(): + """Mock SPGNFTImplClient.""" + + def _mock(public_minting: bool = True): + return patch( + "story_protocol_python_sdk.utils.registration.transform_registration_request.SPGNFTImplClient", + return_value=MagicMock( + publicMinting=MagicMock(return_value=public_minting) + ), + ) + + return _mock + + +account = Account.from_key( + "0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" +) +ACCOUNT_ADDRESS = account.address + + +class TestGetPublicMinting: + def test_returns_true_when_public_minting_enabled( + self, mock_web3, mock_spg_nft_client + ): + with mock_spg_nft_client(public_minting=True): + result = get_public_minting(ADDRESS, mock_web3) + assert result is True + + def test_returns_false_when_public_minting_disabled( + self, mock_web3, mock_spg_nft_client + ): + with mock_spg_nft_client(public_minting=False): + result = get_public_minting(ADDRESS, mock_web3) + assert result is False + + def test_throws_error_when_spg_nft_contract_invalid(self, mock_web3): + with pytest.raises(Exception): + get_public_minting("invalid_address", mock_web3) + + +class TestValidateLicenseTermsData: + def test_validates_license_terms_with_dataclass_input( + self, + mock_web3, + mock_royalty_module_client, + mock_module_registry_client, + ): + with ( + mock_royalty_module_client(), + mock_module_registry_client, + ): + result = validate_license_terms_data(LICENSE_TERMS_DATA, mock_web3) + assert isinstance(result, list) + assert len(result) == len(LICENSE_TERMS_DATA) + assert result[0] == LICENSE_TERMS_DATA_CAMEL_CASE + + def test_validates_license_terms_with_dict_input( + self, + mock_web3, + mock_royalty_module_client, + mock_module_registry_client, + ): + with ( + mock_royalty_module_client(), + mock_module_registry_client, + ): + result = validate_license_terms_data( + [ + { + "terms": asdict(LICENSE_TERMS_DATA[0].terms), + "licensing_config": LICENSE_TERMS_DATA[0].licensing_config, + } + ], + mock_web3, + ) + assert result[0] == LICENSE_TERMS_DATA_CAMEL_CASE + + def test_throws_error_when_royalty_policy_not_whitelisted( + self, + mock_web3, + mock_royalty_module_client, + mock_module_registry_client, + ): + with ( + mock_royalty_module_client(is_whitelisted_policy=False), + mock_module_registry_client, + pytest.raises(ValueError, match="The royalty_policy is not whitelisted."), + ): + validate_license_terms_data(LICENSE_TERMS_DATA, mock_web3) + + def test_throws_error_when_currency_not_whitelisted( + self, + mock_web3, + mock_royalty_module_client, + mock_module_registry_client, + ): + with ( + mock_royalty_module_client(is_whitelisted_token=False), + mock_module_registry_client, + pytest.raises(ValueError, match="The currency is not whitelisted."), + ): + validate_license_terms_data(LICENSE_TERMS_DATA, mock_web3) + + def test_validates_multiple_license_terms( + self, + mock_web3, + mock_royalty_module_client, + mock_module_registry_client, + ): + # Use LICENSE_TERMS_DATA twice to test multiple terms + license_terms_data = LICENSE_TERMS_DATA + [ + replace( + LICENSE_TERMS_DATA[0], + terms=replace(LICENSE_TERMS_DATA[0].terms, commercial_rev_share=20), + ) + ] + + with ( + mock_royalty_module_client(), + mock_module_registry_client, + ): + result = validate_license_terms_data(license_terms_data, mock_web3) + assert result[0] == LICENSE_TERMS_DATA_CAMEL_CASE + assert result[1] == { + "terms": { + **LICENSE_TERMS_DATA_CAMEL_CASE["terms"], + "commercialRevShare": 20 * 10**6, + }, + "licensingConfig": LICENSE_TERMS_DATA_CAMEL_CASE["licensingConfig"], + } + + +class TestTransformRegistrationRequest: + def test_routes_to_mint_and_register_attach_pil_terms_when_spg_nft_contract_present( + self, + mock_web3, + mock_get_public_minting, + mock_royalty_module_client, + mock_module_registry_client, + mock_workflow_clients, + ): + request = MintAndRegisterRequest( + spg_nft_contract=ADDRESS, + ip_metadata=IP_METADATA, + license_terms_data=LICENSE_TERMS_DATA, + ) + workflow_mocks = mock_workflow_clients() + license_attachment_client = workflow_mocks["license_attachment_client"] + with ( + mock_get_public_minting(), + mock_royalty_module_client(), + mock_module_registry_client, + workflow_mocks["royalty_token_distribution_patch"], + workflow_mocks["license_attachment_patch"], + workflow_mocks["derivative_workflows_patch"], + ): + result = transform_request(request, mock_web3, account, CHAIN_ID) + # Assert real encoding result (not mock value) + license_attachment_client.contract.encode_abi.assert_called_once() + call_args = license_attachment_client.contract.encode_abi.call_args + assert call_args[1]["abi_element_identifier"] == ( + "mintAndRegisterIpAndAttachPILTerms" + ) + # Verify args + args = call_args[1]["args"] + assert args[0] == ADDRESS # spg_nft_contract + assert args[1] == ACCOUNT_ADDRESS # recipient + assert ( + args[2] == IPMetadata.from_input(IP_METADATA).get_validated_data() + ) # metadata + assert args[3][0] == LICENSE_TERMS_DATA_CAMEL_CASE # license_terms_data + assert args[4] is True # allow_duplicates + assert result.workflow_address == "license_attachment_client_address" + assert result.is_use_multicall3 is True + assert result.extra_data is not None + license_terms_data = result.extra_data.get("license_terms_data") + assert license_terms_data is not None + assert license_terms_data[0] == LICENSE_TERMS_DATA_CAMEL_CASE + assert result.workflow_multicall_reference is not None + + def test_routes_to_register_ip_and_attach_pil_terms_when_nft_contract_and_token_id_present( + self, + mock_web3, + mock_ip_asset_registry_client, + mock_sign_util, + mock_module_clients, + mock_royalty_module_client, + mock_module_registry_client, + mock_workflow_clients, + ): + request = RegisterRegistrationRequest( + nft_contract=ADDRESS, + token_id=1, + license_terms_data=LICENSE_TERMS_DATA, + ) + workflow_mocks = mock_workflow_clients() + module_patches = mock_module_clients() + license_attachment_client = workflow_mocks["license_attachment_client"] + with ( + mock_ip_asset_registry_client(), + mock_sign_util(), + mock_royalty_module_client(), + mock_module_registry_client, + workflow_mocks["royalty_token_distribution_patch"], + workflow_mocks["license_attachment_patch"], + workflow_mocks["derivative_workflows_patch"], + workflow_mocks["get_function_signature"], + module_patches["core_metadata_module_patch"], + module_patches["licensing_module_patch"], + ): + result = transform_request(request, mock_web3, account, CHAIN_ID) + # Assert real encoding result (not mock value) + license_attachment_client.contract.encode_abi.assert_called_once() + assert result.workflow_address == "license_attachment_client_address" + assert result.is_use_multicall3 is False + call_args = license_attachment_client.contract.encode_abi.call_args + assert call_args[1]["abi_element_identifier"] == ( + "registerIpAndAttachPILTerms" + ) + # Verify args + args = call_args[1]["args"] + assert args[0] == ADDRESS # nft_contract + assert args[1] == 1 # token_id + assert args[2] == IPMetadata.from_input().get_validated_data() # metadata + assert args[3][0] == LICENSE_TERMS_DATA_CAMEL_CASE # license_terms_data + assert args[4]["signer"] == ACCOUNT_ADDRESS # signature data + assert args[4]["deadline"] == 1000 + assert args[4]["signature"] == b"signature" + assert result.extra_data is not None + license_terms_data = result.extra_data.get("license_terms_data") + assert license_terms_data is not None + assert license_terms_data[0] == LICENSE_TERMS_DATA_CAMEL_CASE + assert result.is_use_multicall3 is False + assert result.workflow_multicall_reference is not None + + def test_raises_error_for_invalid_request_type( + self, mock_web3, mock_ip_asset_registry_client + ): + with mock_ip_asset_registry_client(): + with pytest.raises(ValueError, match="Invalid register request type"): + transform_request( + RegisterRegistrationRequest( + nft_contract=ADDRESS, + token_id=1, + ), + mock_web3, + account, + CHAIN_ID, + ) + + def test_raises_error_for_invalid_registration_request_type(self, mock_web3): + """Test that ValueError is raised when request doesn't match any known type.""" + with pytest.raises(ValueError, match="Invalid registration request type"): + transform_request( + None, # type: ignore[arg-type] + mock_web3, + account, + CHAIN_ID, + ) + + +class TestHandleMintAndRegisterRequest: + def test_mint_and_register_ip_and_attach_pil_terms_and_distribute_royalty_tokens( + self, + mock_web3, + mock_get_public_minting, + mock_royalty_module_client, + mock_module_registry_client, + mock_workflow_clients, + ): + request = MintAndRegisterRequest( + spg_nft_contract=ADDRESS, + recipient=ADDRESS, + ip_metadata=IP_METADATA, + license_terms_data=LICENSE_TERMS_DATA, + royalty_shares=[RoyaltyShareInput(recipient=ADDRESS, percentage=50.0)], + ) + workflow_mocks = mock_workflow_clients() + royalty_token_distribution_client = workflow_mocks[ + "royalty_token_distribution_client" + ] + with ( + mock_get_public_minting(public_minting=True), + mock_royalty_module_client(), + mock_module_registry_client, + workflow_mocks["royalty_token_distribution_patch"], + workflow_mocks["license_attachment_patch"], + workflow_mocks["derivative_workflows_patch"], + ): + result = transform_request(request, mock_web3, account, CHAIN_ID) + + royalty_token_distribution_client.contract.encode_abi.assert_called_once() + call_args = royalty_token_distribution_client.contract.encode_abi.call_args + assert result.is_use_multicall3 is False + assert ( + result.workflow_address == "royalty_token_distribution_client_address" + ) + assert result.extra_data is not None + license_terms_data = result.extra_data.get("license_terms_data") + assert license_terms_data is not None + assert license_terms_data[0] == LICENSE_TERMS_DATA_CAMEL_CASE + # Verify encode_abi was called with correct method and arguments + assert call_args[1]["abi_element_identifier"] == ( + "mintAndRegisterIpAndAttachPILTermsAndDistributeRoyaltyTokens" + ) + # Verify args + args = call_args[1]["args"] + assert args[0] == ADDRESS # spg_nft_contract + assert args[1] == ADDRESS # recipient + assert ( + args[2] == IPMetadata.from_input(IP_METADATA).get_validated_data() + ) # metadata + assert args[3][0] == LICENSE_TERMS_DATA_CAMEL_CASE # license_terms_data + assert args[4][0]["recipient"] == ADDRESS + assert args[4][0]["percentage"] == 50 * 10**6 + assert args[5] is True # allow_duplicates (default for this method) + assert result.workflow_multicall_reference is not None + + def test_mint_and_register_ip_and_make_derivative_and_distribute_royalty_tokens( + self, + mock_web3, + mock_get_public_minting, + mock_pi_license_template_client, + mock_derivative_ip_asset_registry_client, + mock_license_registry_client, + mock_workflow_clients, + ): + request = MintAndRegisterRequest( + spg_nft_contract=ADDRESS, + ip_metadata=IP_METADATA, + deriv_data=DerivativeDataInput( + parent_ip_ids=[IP_ID], license_terms_ids=[1] + ), + royalty_shares=[RoyaltyShareInput(recipient=ADDRESS, percentage=50.0)], + ) + workflow_mocks = mock_workflow_clients() + royalty_token_distribution_client = workflow_mocks[ + "royalty_token_distribution_client" + ] + with ( + mock_get_public_minting(public_minting=False), + mock_pi_license_template_client(), + mock_derivative_ip_asset_registry_client(), + mock_license_registry_client(), + workflow_mocks["royalty_token_distribution_patch"], + workflow_mocks["license_attachment_patch"], + workflow_mocks["derivative_workflows_patch"], + ): + result = transform_request(request, mock_web3, account, CHAIN_ID) + + royalty_token_distribution_client.contract.encode_abi.assert_called_once() + call_args = royalty_token_distribution_client.contract.encode_abi.call_args + assert result.is_use_multicall3 is False + assert ( + result.workflow_address == "royalty_token_distribution_client_address" + ) + assert result.extra_data is None + # Verify encode_abi was called with correct method and arguments + assert call_args[1]["abi_element_identifier"] == ( + "mintAndRegisterIpAndMakeDerivativeAndDistributeRoyaltyTokens" + ) + # Verify args + args = call_args[1]["args"] + assert args[0] == ADDRESS # spg_nft_contract + assert args[1] == ACCOUNT_ADDRESS # recipient + assert ( + args[2] == IPMetadata.from_input(IP_METADATA).get_validated_data() + ) # metadata + assert args[4][0]["recipient"] == ADDRESS # royalty_shares + assert args[4][0]["percentage"] == 50 * 10**6 # royalty_shares + assert args[5] is True # allow_duplicates (default for this method) + assert result.workflow_multicall_reference is not None + + def test_mint_and_register_ip_and_make_derivative( + self, + mock_web3, + mock_get_public_minting, + mock_pi_license_template_client, + mock_derivative_ip_asset_registry_client, + mock_license_registry_client, + mock_workflow_clients, + ): + request = MintAndRegisterRequest( + spg_nft_contract=ADDRESS, + recipient=ACCOUNT_ADDRESS, + ip_metadata=IP_METADATA, + deriv_data=DerivativeDataInput( + parent_ip_ids=[IP_ID], license_terms_ids=[1] + ), + allow_duplicates=False, + ) + workflow_mocks = mock_workflow_clients() + derivative_workflows_client = workflow_mocks["derivative_workflows_client"] + with ( + mock_get_public_minting(public_minting=True), + mock_pi_license_template_client(), + mock_derivative_ip_asset_registry_client(), + mock_license_registry_client(), + workflow_mocks["royalty_token_distribution_patch"], + workflow_mocks["license_attachment_patch"], + workflow_mocks["derivative_workflows_patch"], + ): + result = transform_request(request, mock_web3, account, CHAIN_ID) + # Assert real encoding result (not mock value) + derivative_workflows_client.contract.encode_abi.assert_called_once() + call_args = derivative_workflows_client.contract.encode_abi.call_args + assert result.is_use_multicall3 is True + assert result.workflow_address == "derivative_workflows_client_address" + assert result.extra_data is None + assert call_args[1]["abi_element_identifier"] == ( + "mintAndRegisterIpAndMakeDerivative" + ) + assert call_args[1]["args"][0] == ADDRESS # spg_nft_contract + assert ( + call_args[1]["args"][2] + == IPMetadata.from_input(IP_METADATA).get_validated_data() + ) # metadata + assert call_args[1]["args"][3] == ACCOUNT_ADDRESS # recipient + assert call_args[1]["args"][4] is False # allow_duplicates + assert result.workflow_multicall_reference is not None + + def test_raises_error_for_invalid_mint_and_register_request_type( + self, + mock_web3, + mock_get_public_minting, + mock_workflow_clients, + ): + request = MintAndRegisterRequest( + spg_nft_contract=ADDRESS, + ip_metadata=IP_METADATA, + ) + workflow_mocks = mock_workflow_clients() + with ( + mock_get_public_minting(), + workflow_mocks["royalty_token_distribution_patch"], + workflow_mocks["license_attachment_patch"], + workflow_mocks["derivative_workflows_patch"], + ): + with pytest.raises( + ValueError, match="Invalid mint and register request type" + ): + transform_request(request, mock_web3, account, CHAIN_ID) + + +class TestHandleRegisterRequest: + def test_register_ip_and_attach_pil_terms_and_deploy_royalty_vault( + self, + mock_web3, + mock_ip_asset_registry_client, + mock_sign_util, + mock_module_clients, + mock_royalty_module_client, + mock_module_registry_client, + mock_workflow_clients, + ): + request = RegisterRegistrationRequest( + nft_contract=ADDRESS, + token_id=1, + ip_metadata=IP_METADATA, + license_terms_data=LICENSE_TERMS_DATA, + royalty_shares=[RoyaltyShareInput(recipient=ADDRESS, percentage=50.0)], + deadline=2000, + ) + workflow_mocks = mock_workflow_clients() + royalty_token_distribution_client = workflow_mocks[ + "royalty_token_distribution_client" + ] + module_patches = mock_module_clients() + with ( + mock_ip_asset_registry_client(), + mock_sign_util(deadline=2000), + mock_royalty_module_client(), + mock_module_registry_client, + workflow_mocks["royalty_token_distribution_patch"], + workflow_mocks["license_attachment_patch"], + workflow_mocks["derivative_workflows_patch"], + workflow_mocks["get_function_signature"], + module_patches["core_metadata_module_patch"], + module_patches["licensing_module_patch"], + ): + result = transform_request(request, mock_web3, account, CHAIN_ID) + + royalty_token_distribution_client.contract.encode_abi.assert_called_once() + call_args = royalty_token_distribution_client.contract.encode_abi.call_args + assert call_args[1]["abi_element_identifier"] == ( + "registerIpAndAttachPILTermsAndDeployRoyaltyVault" + ) + args = call_args[1]["args"] + assert args[0] == ADDRESS # nft_contract + assert args[1] == 1 # token_id + assert ( + args[2] == IPMetadata.from_input(IP_METADATA).get_validated_data() + ) # metadata + assert args[3][0] == LICENSE_TERMS_DATA_CAMEL_CASE # license_terms_data + assert args[4]["signer"] == ACCOUNT_ADDRESS # signature data + assert args[4]["deadline"] == 2000 + assert args[4]["signature"] == b"signature" + assert result.is_use_multicall3 is False + assert ( + result.workflow_address == "royalty_token_distribution_client_address" + ) + assert result.extra_data is not None + royalty_shares = result.extra_data["royalty_shares"] + royalty_total_amount = cast(dict[str, int], result.extra_data)[ + "royalty_total_amount" + ] + assert royalty_total_amount == 50 * 10**6 + assert len(royalty_shares) == 1 + royalty_share_dict = cast(list[dict[str, str | int]], royalty_shares)[0] + assert royalty_share_dict["recipient"] == ADDRESS + assert royalty_share_dict["percentage"] == 50 * 10**6 + assert result.extra_data["deadline"] == 2000 + assert result.workflow_multicall_reference is not None + + def test_register_ip_and_make_derivative_and_deploy_royalty_vault( + self, + mock_web3, + mock_ip_asset_registry_client, + mock_sign_util, + mock_module_clients, + mock_pi_license_template_client, + mock_derivative_ip_asset_registry_client, + mock_license_registry_client, + mock_workflow_clients, + ): + request = RegisterRegistrationRequest( + nft_contract=ADDRESS, + token_id=1, + deriv_data=DerivativeDataInput( + parent_ip_ids=[IP_ID], license_terms_ids=[1] + ), + royalty_shares=[RoyaltyShareInput(recipient=ADDRESS, percentage=50.0)], + ) + workflow_mocks = mock_workflow_clients() + royalty_token_distribution_client = workflow_mocks[ + "royalty_token_distribution_client" + ] + module_patches = mock_module_clients() + with ( + mock_ip_asset_registry_client(), + mock_sign_util(), + mock_pi_license_template_client(), + mock_derivative_ip_asset_registry_client(), + mock_license_registry_client(), + workflow_mocks["royalty_token_distribution_patch"], + workflow_mocks["license_attachment_patch"], + workflow_mocks["derivative_workflows_patch"], + workflow_mocks["get_function_signature"], + module_patches["core_metadata_module_patch"], + module_patches["licensing_module_patch"], + ): + result = transform_request(request, mock_web3, account, CHAIN_ID) + + # Verify encode_abi was called with correct method and arguments + royalty_token_distribution_client.contract.encode_abi.assert_called_once() + call_args = royalty_token_distribution_client.contract.encode_abi.call_args + assert result.is_use_multicall3 is False + assert ( + result.workflow_address == "royalty_token_distribution_client_address" + ) + assert result.extra_data is not None + royalty_shares = result.extra_data["royalty_shares"] + assert len(royalty_shares) == 1 + royalty_share_dict = cast(list[dict[str, str | int]], royalty_shares)[0] + assert royalty_share_dict["recipient"] == ADDRESS + assert royalty_share_dict["percentage"] == 50 * 10**6 + assert call_args[1]["abi_element_identifier"] == ( + "registerIpAndMakeDerivativeAndDeployRoyaltyVault" + ) + # Verify args + args = call_args[1]["args"] + assert args[0] == ADDRESS # nft_contract + assert args[1] == 1 # token_id + assert args[2] == IPMetadata.from_input().get_validated_data() # metadata + assert args[4]["signer"] == ACCOUNT_ADDRESS + assert args[4]["deadline"] == 1000 + assert args[4]["signature"] == b"signature" + assert result.workflow_multicall_reference is not None + + def test_register_ip_and_attach_pil_terms( + self, + mock_web3, + mock_ip_asset_registry_client, + mock_sign_util, + mock_module_clients, + mock_royalty_module_client, + mock_module_registry_client, + mock_workflow_clients, + ): + request = RegisterRegistrationRequest( + nft_contract=ADDRESS, + token_id=1, + ip_metadata=IP_METADATA, + license_terms_data=LICENSE_TERMS_DATA, + ) + workflow_mocks = mock_workflow_clients() + license_attachment_client = workflow_mocks["license_attachment_client"] + module_patches = mock_module_clients() + with ( + mock_ip_asset_registry_client(), + mock_sign_util(), + mock_royalty_module_client(), + mock_module_registry_client, + workflow_mocks["royalty_token_distribution_patch"], + workflow_mocks["license_attachment_patch"], + workflow_mocks["derivative_workflows_patch"], + workflow_mocks["get_function_signature"], + module_patches["core_metadata_module_patch"], + module_patches["licensing_module_patch"], + ): + result = transform_request(request, mock_web3, account, CHAIN_ID) + + # Verify encode_abi was called with correct method and arguments + license_attachment_client.contract.encode_abi.assert_called_once() + call_args = license_attachment_client.contract.encode_abi.call_args + assert result.is_use_multicall3 is False + assert result.workflow_address == "license_attachment_client_address" + assert result.extra_data is not None + license_terms_data = result.extra_data.get("license_terms_data") + assert license_terms_data is not None + assert license_terms_data[0] == LICENSE_TERMS_DATA_CAMEL_CASE + assert call_args[1]["abi_element_identifier"] == ( + "registerIpAndAttachPILTerms" + ) + # Verify args + args = call_args[1]["args"] + assert args[0] == ADDRESS # nft_contract + assert args[1] == 1 # token_id + assert ( + args[2] == IPMetadata.from_input(IP_METADATA).get_validated_data() + ) # metadata + assert args[3][0] == LICENSE_TERMS_DATA_CAMEL_CASE # license_terms_data + assert args[4]["signer"] == ACCOUNT_ADDRESS + assert args[4]["deadline"] == 1000 + assert args[4]["signature"] == b"signature" + assert result.workflow_multicall_reference is not None + + def test_register_ip_and_make_derivative( + self, + mock_web3, + mock_ip_asset_registry_client, + mock_sign_util, + mock_module_clients, + mock_pi_license_template_client, + mock_derivative_ip_asset_registry_client, + mock_license_registry_client, + mock_workflow_clients, + ): + request = RegisterRegistrationRequest( + nft_contract=ADDRESS, + token_id=1, + ip_metadata=IP_METADATA, + deriv_data=DerivativeDataInput( + parent_ip_ids=[IP_ID], license_terms_ids=[1] + ), + ) + workflow_mocks = mock_workflow_clients() + derivative_workflows_client = workflow_mocks["derivative_workflows_client"] + module_patches = mock_module_clients() + with ( + mock_ip_asset_registry_client(), + mock_sign_util(), + mock_pi_license_template_client(), + mock_derivative_ip_asset_registry_client(), + mock_license_registry_client(), + workflow_mocks["royalty_token_distribution_patch"], + workflow_mocks["license_attachment_patch"], + workflow_mocks["derivative_workflows_patch"], + workflow_mocks["get_function_signature"], + module_patches["core_metadata_module_patch"], + module_patches["licensing_module_patch"], + ): + result = transform_request(request, mock_web3, account, CHAIN_ID) + + # Verify encode_abi was called with correct method and arguments + derivative_workflows_client.contract.encode_abi.assert_called_once() + call_args = derivative_workflows_client.contract.encode_abi.call_args + assert result.is_use_multicall3 is False + assert result.workflow_address == "derivative_workflows_client_address" + assert result.extra_data is None + assert call_args[1]["abi_element_identifier"] == ( + "registerIpAndMakeDerivative" + ) + # Verify args + args = call_args[1]["args"] + assert args[0] == ADDRESS # nft_contract + assert args[1] == 1 # token_id + assert ( + args[3] == IPMetadata.from_input(IP_METADATA).get_validated_data() + ) # metadata + assert args[4]["signer"] == ACCOUNT_ADDRESS + assert args[4]["deadline"] == 1000 + assert args[4]["signature"] == b"signature" + assert result.workflow_multicall_reference is not None + + def test_raises_error_when_ip_not_registered( + self, + mock_web3, + mock_ip_asset_registry_client, + ): + request = RegisterRegistrationRequest( + nft_contract=ADDRESS, + token_id=1, + ip_metadata=IP_METADATA, + license_terms_data=LICENSE_TERMS_DATA, + ) + with ( + mock_ip_asset_registry_client(is_registered=True), + pytest.raises( + ValueError, match="The NFT with id 1 is already registered as IP." + ), + ): + transform_request(request, mock_web3, account, CHAIN_ID) + + +class TestTransformDistributeRoyaltyTokensRequest: + @pytest.fixture + def mock_ip_account_impl_client(self): + """Mock IPAccountImplClient.""" + + def _mock(): + return patch( + "story_protocol_python_sdk.utils.registration.transform_registration_request.IPAccountImplClient", + return_value=MagicMock(state=MagicMock(return_value=123)), + ) + + return _mock + + @pytest.fixture + def mock_ip_royalty_vault_client(self): + """Mock IpRoyaltyVaultImplClient.""" + + def _mock(): + mock_contract = MagicMock() + mock_contract.encode_abi.return_value = b"encoded_approve" + return patch( + "story_protocol_python_sdk.utils.registration.transform_registration_request.IpRoyaltyVaultImplClient", + return_value=MagicMock(contract=mock_contract), + ) + + return _mock + + @pytest.fixture + def mock_royalty_token_distribution_workflows_client(self): + """Mock RoyaltyTokenDistributionWorkflowsClient.""" + + def _mock(): + mock_contract = MagicMock() + mock_contract.address = ( + "royalty_token_distribution_workflows_client_address" + ) + mock_contract.encode_abi.return_value = b"encoded_distribute" + mock_build_multicall = MagicMock() + return patch( + "story_protocol_python_sdk.utils.registration.transform_registration_request.RoyaltyTokenDistributionWorkflowsClient", + return_value=MagicMock( + contract=mock_contract, + build_multicall_transaction=mock_build_multicall, + ), + ) + + return _mock + + def test_transforms_distribute_royalty_tokens_request_successfully( + self, + mock_web3, + mock_ip_account_impl_client, + mock_ip_royalty_vault_client, + mock_royalty_token_distribution_workflows_client, + mock_sign_util, + mock_account, + ): + """Test successful transformation of distribute royalty tokens request.""" + ip_id = IP_ID + royalty_vault = "0xRoyaltyVault" + royalty_shares = [ + RoyaltyShareInput(recipient="0xRecipient1", percentage=50), + RoyaltyShareInput(recipient="0xRecipient2", percentage=50), + ] + deadline = 1000 + total_amount = 100 + + with ( + mock_ip_account_impl_client(), + mock_ip_royalty_vault_client(), + mock_royalty_token_distribution_workflows_client(), + mock_sign_util(), + ): + result = transform_distribute_royalty_tokens_request( + ip_id=ip_id, + royalty_vault=royalty_vault, + deadline=deadline, + web3=mock_web3, + account=mock_account, + chain_id=CHAIN_ID, + royalty_shares=royalty_shares, + total_amount=total_amount, + ) + + # Verify result structure + assert result.encoded_tx_data == b"encoded_distribute" + assert result.is_use_multicall3 is False + assert ( + result.workflow_address + == "royalty_token_distribution_workflows_client_address" + ) + assert result.extra_data is None + assert result.workflow_multicall_reference is not None + + # Verify validated_request structure + assert result.validated_request[0] == ip_id + assert result.validated_request[1] == royalty_shares + signature_data = cast(dict, result.validated_request[2]) + assert signature_data["signer"] == ADDRESS + assert signature_data["deadline"] == deadline + assert signature_data["signature"] == b"signature"