diff --git a/.gitignore b/.gitignore index 834039a..96b46d6 100644 --- a/.gitignore +++ b/.gitignore @@ -195,3 +195,4 @@ src/eas_sdk/proto/ buf.lock CLAUDE.md +node_modules/ diff --git a/examples/full_example.py b/examples/full_example.py index abaee19..4d341b8 100644 --- a/examples/full_example.py +++ b/examples/full_example.py @@ -36,39 +36,41 @@ def __init__(self): def setup_eas(self) -> bool: """Set up EAS instance.""" print("๐Ÿ”ง Setting up EAS instance...") - + try: # Method 1: From environment (recommended) self.eas = EAS.from_environment() print(" โœ… EAS created from environment variables") return True - + except Exception as env_error: print(f" โš ๏ธ Environment setup failed: {env_error}") - + # Method 2: Direct configuration (fallback) try: print(" ๐Ÿ”„ Trying direct configuration...") - + # Get configuration for testnet config = get_network_config("sepolia") - + # You would need to provide these values private_key = os.getenv("PRIVATE_KEY") - from_account = os.getenv("FROM_ACCOUNT") - + from_account = os.getenv("FROM_ACCOUNT") + if not private_key or not from_account: - print(" โŒ PRIVATE_KEY and FROM_ACCOUNT environment variables required") + print( + " โŒ PRIVATE_KEY and FROM_ACCOUNT environment variables required" + ) return False - + self.eas = EAS.from_chain( chain_name="sepolia", private_key=private_key, - from_account=from_account + from_account=from_account, ) print(" โœ… EAS created with direct configuration") return True - + except Exception as direct_error: print(f" โŒ Direct configuration failed: {direct_error}") return False @@ -76,22 +78,22 @@ def setup_eas(self) -> bool: def register_schema(self) -> bool: """Register a schema for our attestations.""" print("\n๐Ÿ“ Registering schema...") - + try: # Define a useful schema for user profiles schema = "string name,string email,uint256 reputation,bool verified,bytes32 profileHash" - + print(f" Schema: {schema}") - + result = self.eas.register_schema( schema=schema, - network_name="sepolia", # Use testnet for safety - revocable=True # Allow revocations + chain_name="sepolia", # Use testnet for safety + revocable=True, # Allow revocations ) - + if result.success: # Extract schema UID from transaction logs or response - schema_uid = result.data.get('schema_uid') + schema_uid = result.data.get("schema_uid") if schema_uid: self.schema_uid = schema_uid print(f" โœ… Schema registered: {schema_uid}") @@ -103,7 +105,7 @@ def register_schema(self) -> bool: else: print(f" โŒ Schema registration failed: {result.error}") return False - + except Exception as e: print(f" โŒ Schema registration error: {e}") return False @@ -111,43 +113,47 @@ def register_schema(self) -> bool: def create_attestation(self) -> bool: """Create an attestation using our schema.""" print("\n๐Ÿ… Creating attestation...") - + if not self.schema_uid: print(" โŒ No schema UID available") return False - + try: # Prepare attestation data - recipient = "0x742d35Cc6634C0532925a3b8D16c30B9b2C4e40B" # Example recipient - + recipient = ( + "0x742d35Cc6634C0532925a3b8D16c30B9b2C4e40B" # Example recipient + ) + # Encode data according to our schema # Schema: "string name,string email,uint256 reputation,bool verified,bytes32 profileHash" attestation_data = encode( ["string", "string", "uint256", "bool", "bytes32"], [ "Alice Johnson", - "alice@example.com", + "alice@example.com", 1000, # reputation score True, # verified - bytes.fromhex("a1b2c3d4e5f6789012345678901234567890123456789012345678901234567890") - ] + bytes.fromhex( + "a1b2c3d4e5f6789012345678901234567890123456789012345678901234567890" + ), + ], ) - + print(f" ๐Ÿ‘ค Recipient: {recipient}") print(f" ๐Ÿ“ฆ Data length: {len(attestation_data)} bytes") - + result = self.eas.create_attestation( schema_uid=self.schema_uid, recipient=recipient, encoded_data=attestation_data, expiration=0, # No expiration revocable=True, - value=0 # No ETH value + value=0, # No ETH value ) - + if result.success: # Extract attestation UID from transaction logs - attestation_uid = result.data.get('attestation_uid') + attestation_uid = result.data.get("attestation_uid") if attestation_uid: self.attestation_uid = attestation_uid print(f" โœ… Attestation created: {attestation_uid}") @@ -160,7 +166,7 @@ def create_attestation(self) -> bool: else: print(f" โŒ Attestation creation failed: {result.error}") return False - + except Exception as e: print(f" โŒ Attestation creation error: {e}") return False @@ -168,14 +174,14 @@ def create_attestation(self) -> bool: def query_attestation(self) -> bool: """Query the attestation we just created.""" print("\n๐Ÿ” Querying attestation...") - + if not self.attestation_uid: print(" โš ๏ธ No attestation UID to query") return False - + try: attestation = self.eas.get_attestation(self.attestation_uid) - + if attestation: print(f" โœ… Attestation found:") print(f" Schema: {attestation[1]}") # schema UID @@ -192,7 +198,7 @@ def query_attestation(self) -> bool: else: print(" โŒ Attestation not found") return False - + except Exception as e: print(f" โŒ Query error: {e}") return False @@ -200,7 +206,7 @@ def query_attestation(self) -> bool: def demonstrate_offchain_attestation(self) -> bool: """Demonstrate off-chain attestation.""" print("\n๐ŸŒ Creating off-chain attestation...") - + try: # Create off-chain attestation message message = { @@ -211,23 +217,25 @@ def demonstrate_offchain_attestation(self) -> bool: "expirationTime": 0, "revocable": True, "refUID": None, - "data": json.dumps({ - "name": "Bob Smith", - "email": "bob@example.com", - "reputation": 750, - "verified": False, - "note": "Off-chain attestation example" - }).encode() + "data": json.dumps( + { + "name": "Bob Smith", + "email": "bob@example.com", + "reputation": 750, + "verified": False, + "note": "Off-chain attestation example", + } + ).encode(), } - + offchain_attestation = self.eas.attest_offchain(message) - + print(f" โœ… Off-chain attestation created") print(f" ๐Ÿ†” UID: {offchain_attestation['message']['uid']}") print(f" โœ๏ธ Signature: {offchain_attestation['signature']['r'][:10]}...") - + return True - + except Exception as e: print(f" โŒ Off-chain attestation error: {e}") return False @@ -235,14 +243,14 @@ def demonstrate_offchain_attestation(self) -> bool: def demonstrate_revocation(self) -> bool: """Demonstrate attestation revocation.""" print("\nโŒ Revoking attestation...") - + if not self.attestation_uid: print(" โš ๏ธ No attestation UID to revoke") return False - + try: result = self.eas.revoke_attestation(self.attestation_uid) - + if result.success: print(f" โœ… Attestation revoked") print(f" ๐Ÿ”— Transaction: {result.tx_hash}") @@ -250,7 +258,7 @@ def demonstrate_revocation(self) -> bool: else: print(f" โŒ Revocation failed: {result.error}") return False - + except Exception as e: print(f" โŒ Revocation error: {e}") return False @@ -259,31 +267,33 @@ def run_complete_workflow(self): """Run the complete EAS workflow.""" print("๐Ÿš€ EAS SDK Complete Workflow Example") print("=" * 50) - + # Step 1: Setup if not self.setup_eas(): print("\nโŒ Setup failed. Cannot continue.") return - + # Step 2: Register schema if not self.register_schema(): print("\nโš ๏ธ Using fallback schema for remaining examples...") # Use a known schema for testnets - self.schema_uid = "0x83c23d3c24c90bc5d1b8b44a7c2cc50e4d9efca2e80d78a3ce5f8e4d10e5d4e5" - + self.schema_uid = ( + "0x83c23d3c24c90bc5d1b8b44a7c2cc50e4d9efca2e80d78a3ce5f8e4d10e5d4e5" + ) + # Step 3: Create attestation if self.create_attestation(): # Step 4: Query attestation self.query_attestation() - + # Step 5: Demonstrate revocation print("\nโฑ๏ธ Waiting 5 seconds before revocation...") time.sleep(5) self.demonstrate_revocation() - + # Step 6: Off-chain attestation (always works) self.demonstrate_offchain_attestation() - + print("\nโœจ Complete workflow finished!") print("\n๐Ÿ“š What you've learned:") print(" โ€ข How to initialize EAS from environment variables") @@ -291,7 +301,7 @@ def run_complete_workflow(self): print(" โ€ข How to create and query attestations") print(" โ€ข How to create off-chain attestations") print(" โ€ข How to revoke attestations") - + print("\n๐Ÿ› ๏ธ Development Tips:") print(" โ€ข Use testnet (sepolia) for development") print(" โ€ข Check transaction status before proceeding") @@ -306,4 +316,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/multi_chain_examples.py b/examples/multi_chain_examples.py index 3d19c92..c4bda4a 100644 --- a/examples/multi_chain_examples.py +++ b/examples/multi_chain_examples.py @@ -128,7 +128,7 @@ def demonstrate_factory_methods(): from EAS.config import create_eas_instance eas = create_eas_instance( - network_name='mainnet', # or 'sepolia', 'goerli', etc. + chain_name='mainnet', # or 'sepolia', 'goerli', etc. from_account='0x1234...', private_key='0x1234...' ) diff --git a/examples/quick_start.py b/examples/quick_start.py index a4d7c05..f84246c 100644 --- a/examples/quick_start.py +++ b/examples/quick_start.py @@ -31,10 +31,10 @@ def main(): """Quick start example showing common EAS operations.""" - + print("๐Ÿš€ EAS SDK Quick Start Example") print("=" * 40) - + # Example 1: Initialize EAS from environment print("\n1. Initialize EAS from environment variables") try: @@ -51,11 +51,11 @@ def main(): # Example 2: List supported chains print("\n2. List available chains") from eas.config import list_supported_chains, get_mainnet_chains, get_testnet_chains - + all_chains = list_supported_chains() mainnet_chains = get_mainnet_chains() testnet_chains = get_testnet_chains() - + print(f" ๐Ÿ“Š Total supported chains: {len(all_chains)}") print(f" ๐Ÿฆ Mainnet chains: {', '.join(mainnet_chains[:3])}...") print(f" ๐Ÿงช Testnet chains: {', '.join(testnet_chains[:3])}...") @@ -65,15 +65,17 @@ def main(): try: # Simple identity schema schema = "string name,uint256 age,bool verified" - result = eas.register_schema(schema, network_name="sepolia") # Use testnet - + result = eas.register_schema(schema, chain_name="sepolia") # Use testnet + if result.success: - print(f" โœ… Schema registered with UID: {result.data.get('schema_uid', 'N/A')}") - schema_uid = result.data.get('schema_uid') + print( + f" โœ… Schema registered with UID: {result.data.get('schema_uid', 'N/A')}" + ) + schema_uid = result.data.get("schema_uid") else: print(f" โŒ Schema registration failed: {result.error}") schema_uid = None - + except Exception as e: print(f" โš ๏ธ Schema registration skipped: {e}") schema_uid = None @@ -84,14 +86,16 @@ def main(): # Use a known schema or the one we just created if not schema_uid: # Fallback to a common test schema (adjust as needed) - schema_uid = "0x83c23d3c24c90bc5d1b8b44a7c2cc50e4d9efca2e80d78a3ce5f8e4d10e5d4e5" - + schema_uid = ( + "0x83c23d3c24c90bc5d1b8b44a7c2cc50e4d9efca2e80d78a3ce5f8e4d10e5d4e5" + ) + # Note: In a real application, you'd encode data properly # This is just for demonstration print(f" ๐Ÿ“ Using schema UID: {schema_uid}") print(" โš ๏ธ Attestation creation requires proper data encoding") print(" ๐Ÿ“š See full_example.py for complete attestation workflow") - + except Exception as e: print(f" โš ๏ธ Attestation example skipped: {e}") @@ -109,4 +113,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/pyproject.toml b/pyproject.toml index d38c44f..086eeeb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "eas-sdk" -version = "0.1.5" +version = "0.1.6" description = "Python SDK for Ethereum Attestation Service (EAS)" readme = "README.md" license = {text = "MIT"} diff --git a/src/main/eas/cli.py b/src/main/eas/cli.py index 6e362dd..b4fe259 100644 --- a/src/main/eas/cli.py +++ b/src/main/eas/cli.py @@ -163,7 +163,10 @@ def query_eas_graphql( def show_schema_impl( - schema_uid: str, output_format: str = "eas", network: str = "mainnet" + schema_uid: str, + output_format: str = "eas", + chain_name: Optional[str] = None, + chain_id: Optional[int] = None, ) -> None: """ Display schema information from EAS GraphQL API. @@ -175,9 +178,18 @@ def show_schema_impl( """ try: # Get GraphQL endpoint - endpoint = EAS_GRAPHQL_ENDPOINTS.get(network) + # Resolve chain name if needed + if chain_name is not None: + resolved_name = chain_name + else: + from eas.config import get_chain_name_from_id + + assert chain_id is not None # Type hint for mypy + resolved_name = get_chain_name_from_id(chain_id) + + endpoint = EAS_GRAPHQL_ENDPOINTS.get(resolved_name) if not endpoint: - raise ValueError(f"Unsupported network: {network}") + raise ValueError(f"Unsupported chain: {resolved_name}") # GraphQL query for schema query = """ @@ -297,7 +309,10 @@ def format_attestation_yaml(attestation_data: Dict[str, Any]) -> None: def show_attestation_impl( - attestation_uid: str, output_format: str = "eas", network: str = "mainnet" + attestation_uid: str, + output_format: str = "eas", + chain_name: Optional[str] = None, + chain_id: Optional[int] = None, ) -> None: """ Display attestation information from EAS GraphQL API. @@ -309,9 +324,18 @@ def show_attestation_impl( """ try: # Get GraphQL endpoint - endpoint = EAS_GRAPHQL_ENDPOINTS.get(network) + # Resolve chain name if needed + if chain_name is not None: + resolved_name = chain_name + else: + from eas.config import get_chain_name_from_id + + assert chain_id is not None # Type hint for mypy + resolved_name = get_chain_name_from_id(chain_id) + + endpoint = EAS_GRAPHQL_ENDPOINTS.get(resolved_name) if not endpoint: - raise ValueError(f"Unsupported network: {network}") + raise ValueError(f"Unsupported chain: {resolved_name}") # GraphQL query for attestation query = """ @@ -479,7 +503,8 @@ def encode_schema_impl( encoding: str = "json", namespace: Optional[str] = None, message_type: Optional[str] = None, - network: str = "mainnet", + chain_name: Optional[str] = None, + chain_id: Optional[int] = None, ) -> None: """ Retrieve attestation data and encode it using schema-based encoding. @@ -493,7 +518,16 @@ def encode_schema_impl( network: Network to query (mainnet, sepolia, etc.) """ try: - endpoint = _get_endpoint_for_network(network) + # Resolve chain name if needed + if chain_name is not None: + resolved_name = chain_name + else: + from eas.config import get_chain_name_from_id + + assert chain_id is not None # Type hint for mypy + resolved_name = get_chain_name_from_id(chain_id) + + endpoint = _get_endpoint_for_network(resolved_name) # Fetch and parse attestation data parsed_data = _fetch_attestation_data(endpoint, attestation_uid) @@ -580,7 +614,10 @@ def _display_generated_code(generated_code: str, output_format: str) -> None: def generate_schema_impl( - schema_uid: str, output_format: str = "eas", network: str = "mainnet" + schema_uid: str, + output_format: str = "eas", + chain_name: Optional[str] = None, + chain_id: Optional[int] = None, ) -> None: """ Generate code from EAS schema definition. @@ -591,7 +628,16 @@ def generate_schema_impl( network: Network to query (mainnet, sepolia, optimism, etc.) """ try: - endpoint = _get_endpoint_for_network(network) + # Resolve chain name if needed + if chain_name is not None: + resolved_name = chain_name + else: + from eas.config import get_chain_name_from_id + + assert chain_id is not None # Type hint for mypy + resolved_name = get_chain_name_from_id(chain_id) + + endpoint = _get_endpoint_for_network(resolved_name) # Fetch and parse schema data parsed_data = _fetch_schema_data(endpoint, schema_uid) @@ -665,11 +711,12 @@ def extract_proto_impl( @click.group() @click.option( - "--network", + "--chain-name", "-n", type=click.Choice( [ "mainnet", + "ethereum", "sepolia", "base-sepolia", "optimism", @@ -679,25 +726,44 @@ def extract_proto_impl( ], case_sensitive=False, ), - default="mainnet", - help="Network to query (default: mainnet)", + help="Chain name to query (e.g., ethereum, base, sepolia)", ) -@click.version_option(version="0.1.4", prog_name="EAS Tools") +@click.option( + "--chain-id", + "-i", + type=int, + help="Chain ID to query (e.g., 1, 8453, 11155111)", +) +@click.version_option(version="0.1.5", prog_name="EAS Tools") @click.pass_context -def main(ctx: click.Context, network: str) -> None: +def main( + ctx: click.Context, chain_name: Optional[str], chain_id: Optional[int] +) -> None: """๐Ÿ› ๏ธ EAS Tools - Ethereum Attestation Service CLI - Query and interact with EAS data across multiple networks. - The --network flag applies to all subcommands. + Query and interact with EAS data across multiple chains. + Specify either --chain-name or --chain-id (not both). \b Examples: - eas-tools -n base-sepolia attestation show 0xceff... - eas-tools -n mainnet schema show 0x86ad... + eas-tools --chain-name base-sepolia attestation show 0xceff... + eas-tools --chain-id 8453 schema show 0x86ad... eas-tools dev chains """ + # Validate XOR requirement + if (chain_name is None) == (chain_id is None): + if chain_name is None and chain_id is None: + # Default to mainnet for backward compatibility + chain_name = "mainnet" + else: + click.echo( + "Error: Specify either --chain-name or --chain-id (not both)", err=True + ) + ctx.exit(1) + ctx.ensure_object(dict) - ctx.obj["network"] = network + ctx.obj["chain_name"] = chain_name + ctx.obj["chain_id"] = chain_id # Schema commands group @@ -728,8 +794,9 @@ def show(ctx: click.Context, schema_uid: str, output_format: str) -> None: Example: eas-tools -n base-sepolia schema show 0x86ad448d1844cd6d7c13cf5d8effbc70a596af78bd0a01b747e2acb5f74c6d9b """ - network = ctx.obj["network"] - show_schema_impl(schema_uid, output_format, network) + chain_name = ctx.obj["chain_name"] + chain_id = ctx.obj["chain_id"] + show_schema_impl(schema_uid, output_format, chain_name, chain_id) # Attestation commands group @@ -762,8 +829,9 @@ def show_attestation( Example: eas-tools -n base-sepolia attestation show 0xceffa19c412727fa6ea41ce8f685a397d93d744c5314f19c39fa7b007a985c41 """ - network = ctx.obj["network"] - show_attestation_impl(attestation_uid, output_format, network) + chain_name = ctx.obj["chain_name"] + chain_id = ctx.obj["chain_id"] + show_attestation_impl(attestation_uid, output_format, chain_name, chain_id) @attestation.command() @@ -815,9 +883,10 @@ def decode( Example: eas-tools -n base-sepolia attestation decode 0xceff... """ - network = ctx.obj["network"] + chain_name = ctx.obj["chain_name"] + chain_id = ctx.obj["chain_id"] encode_schema_impl( - attestation_uid, format, encoding, namespace, message_type, network + attestation_uid, format, encoding, namespace, message_type, chain_name, chain_id ) @@ -839,8 +908,9 @@ def generate(ctx: click.Context, schema_uid: str, output_format: str) -> None: Example: eas-tools -n base-sepolia schema generate 0x86ad... --format proto """ - network = ctx.obj["network"] - generate_schema_impl(schema_uid, output_format, network) + chain_name = ctx.obj["chain_name"] + chain_id = ctx.obj["chain_id"] + generate_schema_impl(schema_uid, output_format, chain_name, chain_id) # Query commands group @@ -1110,8 +1180,17 @@ def attestations( eas-tools query attestations --limit 50 --offset 100 --format json """ try: - network = ctx.obj["network"] - client = EASQueryClient(network=network) + chain_name = ctx.obj["chain_name"] + chain_id = ctx.obj["chain_id"] + # Create client using appropriate parameter + if chain_name is not None: + client = EASQueryClient(network=chain_name) + else: + # Convert chain_id to chain_name for client + from eas.config import get_chain_name_from_id + + chain_name = get_chain_name_from_id(chain_id) + client = EASQueryClient(network=chain_name) # Build filter filters = AttestationFilter( @@ -1134,7 +1213,9 @@ def attestations( sort_order=SortOrder.DESC, ) - console.print(f"๐Ÿ” Searching for attestations on {network}...") + console.print( + f"๐Ÿ” Searching for attestations on {chain_name or str(chain_id)}..." + ) results = client.find_attestations(filters) format_attestation_results(results, format) @@ -1204,8 +1285,17 @@ def schemas( eas-tools query schemas --limit 25 --format json """ try: - network = ctx.obj["network"] - client = EASQueryClient(network=network) + chain_name = ctx.obj["chain_name"] + chain_id = ctx.obj["chain_id"] + # Create client using appropriate parameter + if chain_name is not None: + client = EASQueryClient(network=chain_name) + else: + # Convert chain_id to chain_name for client + from eas.config import get_chain_name_from_id + + chain_name = get_chain_name_from_id(chain_id) + client = EASQueryClient(network=chain_name) # Build filter filters = SchemaFilter( @@ -1221,7 +1311,7 @@ def schemas( sort_order=SortOrder.DESC, ) - console.print(f"๐Ÿ” Searching for schemas on {network}...") + console.print(f"๐Ÿ” Searching for schemas on {chain_name or str(chain_id)}...") results = client.find_schemas(filters) format_schema_results(results, format) @@ -1280,9 +1370,18 @@ def revoke( eas-tools -n base-sepolia revoke 0xceff... --dry-run """ try: - network = ctx.obj["network"] + chain_name = ctx.obj["chain_name"] + chain_id = ctx.obj["chain_id"] + + # Get the display name for the chain + if chain_name is not None: + display_name = chain_name + else: + from eas.config import get_chain_name_from_id + + display_name = get_chain_name_from_id(chain_id) - console.print(f"๐Ÿ” Preparing to revoke attestation on {network}...") + console.print(f"๐Ÿ” Preparing to revoke attestation on {display_name}...") console.print(f" Attestation UID: {attestation_uid}") # Get private key from CLI option or environment @@ -1315,7 +1414,7 @@ def revoke( if dry_run: console.print("\n๐Ÿ” DRY RUN - Transaction will not be submitted") - console.print(f" Network: {network}") + console.print(f" Chain: {display_name}") console.print(f" Attestation UID: {attestation_uid}") console.print(f" From account: {from_account}") if gas_limit: @@ -1325,7 +1424,7 @@ def revoke( # Set chain environment variable if not already set if not os.environ.get("EAS_CHAIN"): - os.environ["EAS_CHAIN"] = network + os.environ["EAS_CHAIN"] = display_name if not os.environ.get("EAS_PRIVATE_KEY"): os.environ["EAS_PRIVATE_KEY"] = private_key if not os.environ.get("EAS_FROM_ACCOUNT"): @@ -1351,17 +1450,17 @@ def revoke( console.print(f" Block number: {result.block_number}") console.print("\n๐Ÿ”— View on explorer:") - if network == "mainnet": + if display_name == "mainnet" or display_name == "ethereum": console.print(f" https://etherscan.io/tx/{result.tx_hash}") - elif network == "sepolia": + elif display_name == "sepolia": console.print(f" https://sepolia.etherscan.io/tx/{result.tx_hash}") - elif network == "base": + elif display_name == "base": console.print(f" https://basescan.org/tx/{result.tx_hash}") - elif network == "base-sepolia": + elif display_name == "base-sepolia": console.print(f" https://sepolia.basescan.org/tx/{result.tx_hash}") - elif network == "optimism": + elif display_name == "optimism": console.print(f" https://optimistic.etherscan.io/tx/{result.tx_hash}") - elif network == "arbitrum": + elif display_name == "arbitrum": console.print(f" https://arbiscan.io/tx/{result.tx_hash}") except EASValidationError as e: diff --git a/src/main/eas/config.py b/src/main/eas/config.py index b93451e..61ccfb5 100644 --- a/src/main/eas/config.py +++ b/src/main/eas/config.py @@ -182,12 +182,20 @@ } -def get_network_config(network_name: str) -> Dict[str, Any]: +def get_network_config( + *, + chain_name: Optional[str] = None, + chain_id: Optional[int] = None, +) -> Dict[str, Any]: """ - Get network configuration by name with enhanced security validation. + Get network configuration by chain name or chain ID with enhanced security validation. Args: - network_name: Name of the network (e.g., 'ethereum', 'base', 'sepolia', 'arbitrum') + chain_name: Name of the network (e.g., 'ethereum', 'base', 'sepolia', 'arbitrum') + chain_id: Chain ID of the network (e.g., 1, 8453, 11155111, 42161) + + Note: + Exactly one of chain_name or chain_id must be provided (XOR). Returns: Network configuration dictionary containing: @@ -201,46 +209,106 @@ def get_network_config(network_name: str) -> Dict[str, Any]: - explorer_url: Block explorer URL Raises: - ValueError: If network name is not supported + ValueError: If neither or both parameters are provided, or if chain is not supported SecurityError: If network configuration fails security validation """ - # Validate network name securely - try: - network_name = SecureEnvironmentValidator.validate_chain_name(network_name) - except SecurityError as e: - raise ValueError(f"Invalid network name: {str(e)}") - - if network_name not in SUPPORTED_CHAINS: - supported_networks = list(SUPPORTED_CHAINS.keys()) - mainnet_networks = [ - name - for name, config in SUPPORTED_CHAINS.items() - if config.get("network_type") == "mainnet" - ] - testnet_networks = [ - name - for name, config in SUPPORTED_CHAINS.items() - if config.get("network_type") == "testnet" - ] - - error_msg = ( - f"Unsupported network: '{network_name}'. " - f"\nSupported networks ({len(supported_networks)} total):" - f"\nMainnets: {mainnet_networks}" - f"\nTestnets: {testnet_networks}" + # Validate XOR requirement + if (chain_name is None) == (chain_id is None): + raise ValueError( + "Exactly one of 'chain_name' or 'chain_id' must be provided (not both, not neither)" ) - raise ValueError(error_msg) - config = SUPPORTED_CHAINS[network_name].copy() + config: Optional[Dict[str, Any]] = None + lookup_key: str = "" + + if chain_name is not None: + # Validate chain name securely + try: + chain_name = SecureEnvironmentValidator.validate_chain_name(chain_name) + except SecurityError as e: + raise ValueError(f"Invalid chain name: {str(e)}") + + if chain_name not in SUPPORTED_CHAINS: + supported_networks = list(SUPPORTED_CHAINS.keys()) + mainnet_networks = [ + name + for name, config in SUPPORTED_CHAINS.items() + if config.get("network_type") == "mainnet" + ] + testnet_networks = [ + name + for name, config in SUPPORTED_CHAINS.items() + if config.get("network_type") == "testnet" + ] + + error_msg = ( + f"Unsupported chain name: '{chain_name}'. " + f"\nSupported chains ({len(supported_networks)} total):" + f"\nMainnets: {mainnet_networks}" + f"\nTestnets: {testnet_networks}" + ) + raise ValueError(error_msg) + + config = SUPPORTED_CHAINS[chain_name].copy() + lookup_key = chain_name + + else: # chain_id is not None + # Validate chain ID securely + try: + validated_chain_id = SecureEnvironmentValidator.validate_chain_id( + str(chain_id) + ) + except SecurityError as e: + raise ValueError(f"Invalid chain ID: {str(e)}") + + # Find config by chain_id + for name, chain_config in SUPPORTED_CHAINS.items(): + if chain_config.get("chain_id") == validated_chain_id: + config = chain_config.copy() + lookup_key = name + break + + if config is None: + supported_chain_ids = [ + config["chain_id"] + for config in SUPPORTED_CHAINS.values() + if "chain_id" in config and isinstance(config["chain_id"], int) + ] + mainnet_chain_ids = [ + config["chain_id"] + for config in SUPPORTED_CHAINS.values() + if ( + config.get("network_type") == "mainnet" + and "chain_id" in config + and isinstance(config["chain_id"], int) + ) + ] + testnet_chain_ids = [ + config["chain_id"] + for config in SUPPORTED_CHAINS.values() + if ( + config.get("network_type") == "testnet" + and "chain_id" in config + and isinstance(config["chain_id"], int) + ) + ] + + error_msg = ( + f"Unsupported chain ID: {chain_id}. " + f"\nSupported chain IDs ({len(supported_chain_ids)} total):" + f"\nMainnets: {sorted(mainnet_chain_ids)}" + f"\nTestnets: {sorted(testnet_chain_ids)}" + ) + raise ValueError(error_msg) + + assert config is not None # Should never happen due to validation above # Enhanced configuration validation with security checks - validate_chain_config(config, network_name) + validate_chain_config(config, lookup_key) # Verify contract addresses against known good values - if not _verify_contract_integrity(config, network_name): - raise SecurityError( - f"Contract address integrity check failed for {network_name}" - ) + if not _verify_contract_integrity(config, lookup_key): + raise SecurityError(f"Contract address integrity check failed for {lookup_key}") return config @@ -285,6 +353,55 @@ def get_testnet_chains() -> List[str]: return sorted(testnet_chains) +def get_chain_id_from_name(chain_name: str) -> int: + """ + Get chain ID from chain name. + + Args: + chain_name: Name of the chain (e.g., 'ethereum', 'base', 'sepolia') + + Returns: + Chain ID as integer + + Raises: + ValueError: If chain name is not supported + """ + config = get_network_config(chain_name=chain_name) + return int(config["chain_id"]) + + +def get_chain_name_from_id(chain_id: int) -> str: + """ + Get chain name from chain ID. + + Args: + chain_id: Chain ID (e.g., 1, 8453, 11155111) + + Returns: + Chain name as string + + Raises: + ValueError: If chain ID is not supported + """ + config = get_network_config(chain_id=chain_id) + return str(config["name"]) + + +def list_supported_chain_ids() -> List[int]: + """ + Get a list of all supported chain IDs. + + Returns: + List of supported chain IDs sorted numerically + """ + chain_ids = [ + config["chain_id"] + for config in SUPPORTED_CHAINS.values() + if "chain_id" in config and isinstance(config["chain_id"], int) + ] + return sorted(chain_ids) + + def validate_chain_config(config: Dict[str, Any], chain_name: str) -> None: """ Validate a chain configuration dictionary with enhanced security checks. @@ -381,10 +498,11 @@ def get_example_attestation_data() -> Dict[str, Any]: def create_eas_instance( - network_name: Optional[str] = None, + network_name: Optional[str] = None, # Deprecated parameter name from_account: Optional[str] = None, private_key: Optional[str] = None, rpc_url: Optional[str] = None, + chain_name: Optional[str] = None, # New parameter name ) -> "EAS": """ DEPRECATED: Legacy factory method with security warnings. @@ -411,10 +529,21 @@ def create_eas_instance( from .core import EAS - # Use environment variables if not provided (with validation) - network_name = network_name or os.getenv("NETWORK", "sepolia") - # Ensure network_name is not None (it has default "sepolia") - assert network_name is not None, "network_name should not be None after assignment" + # Handle both old and new parameter names for backward compatibility + if (network_name is None) == (chain_name is None): + if network_name is None and chain_name is None: + # Use environment variables if not provided (with validation) + network_name = os.getenv("NETWORK", "sepolia") + else: + raise ValueError( + "Cannot specify both network_name and chain_name. Use chain_name for new code." + ) + + # Use the provided parameter or default + final_chain_name = chain_name or network_name + assert ( + final_chain_name is not None + ), "chain_name should not be None after assignment" from_account_env = os.getenv("FROM_ACCOUNT") from_account = from_account or from_account_env @@ -428,7 +557,9 @@ def create_eas_instance( # Enhanced parameter validation try: # Validate all inputs using security validator - network_name = SecureEnvironmentValidator.validate_chain_name(network_name) + final_chain_name = SecureEnvironmentValidator.validate_chain_name( + final_chain_name + ) from_account = SecureEnvironmentValidator.validate_address(from_account) private_key = SecureEnvironmentValidator.validate_private_key(private_key) @@ -442,7 +573,7 @@ def create_eas_instance( except SecurityError as e: raise ValueError(f"Security validation failed: {str(e)}") - config = get_network_config(network_name) + config = get_network_config(chain_name=final_chain_name) # Override with validated parameters if rpc_url: diff --git a/src/main/eas/core.py b/src/main/eas/core.py index d3cd357..1f1a1e5 100644 --- a/src/main/eas/core.py +++ b/src/main/eas/core.py @@ -65,7 +65,9 @@ def __init__( @classmethod def from_chain( cls, - chain_name: str, + *, + chain_name: Optional[str] = None, + chain_id: Optional[int] = None, private_key: str, from_account: str, rpc_url: Optional[str] = None, @@ -73,38 +75,56 @@ def from_chain( **kwargs: Any, ) -> "EAS": """ - Create an EAS instance from a chain name with automatic configuration resolution. + Create an EAS instance from a chain name or ID with automatic configuration resolution. Args: chain_name: Name of the blockchain network (e.g., 'ethereum', 'base', 'sepolia') + chain_id: Chain ID of the network (e.g., 1, 8453, 11155111, 42161) private_key: Private key for transaction signing from_account: Account address for transactions rpc_url: Optional custom RPC URL (overrides chain default) contract_address: Optional custom contract address (overrides chain default) **kwargs: Additional arguments passed to EAS constructor + Note: + Exactly one of chain_name or chain_id must be provided (XOR). + Returns: EAS instance configured for the specified chain Raises: - ValueError: If chain name is invalid or required parameters are missing + ValueError: If neither or both chain parameters are provided, or if chain is not supported TypeError: If parameters have incorrect types ConnectionError: If unable to connect to the network - Example: - # Using chain defaults - eas = EAS.from_chain('ethereum', private_key, from_account) + Examples: + # Using chain name + eas = EAS.from_chain(chain_name='ethereum', private_key=pk, from_account=addr) + + # Using chain ID + eas = EAS.from_chain(chain_id=8453, private_key=pk, from_account=addr) # With custom RPC - eas = EAS.from_chain('base', private_key, from_account, + eas = EAS.from_chain(chain_name='base', private_key=pk, from_account=addr, rpc_url='https://my-custom-base-rpc.com') """ from .config import get_network_config, validate_chain_config + # Validate XOR requirement + if (chain_name is None) == (chain_id is None): + raise ValueError( + "Exactly one of 'chain_name' or 'chain_id' must be provided (not both, not neither)" + ) + # Enhanced security validation using SecureEnvironmentValidator try: # Validate all inputs with comprehensive security checks - chain_name = SecureEnvironmentValidator.validate_chain_name(chain_name) + if chain_name is not None: + chain_name = SecureEnvironmentValidator.validate_chain_name(chain_name) + if chain_id is not None: + # Validate chain_id format + SecureEnvironmentValidator.validate_chain_id(str(chain_id)) + private_key = SecureEnvironmentValidator.validate_private_key(private_key) from_account = SecureEnvironmentValidator.validate_address(from_account) except SecurityError as e: @@ -112,9 +132,17 @@ def from_chain( raise ValueError(f"Security validation failed: {str(e)}") try: - # Get chain configuration - config = get_network_config(chain_name) - validate_chain_config(config, chain_name) + # Get chain configuration using new API + if chain_name is not None: + config = get_network_config(chain_name=chain_name) + lookup_key = chain_name + else: + config = get_network_config(chain_id=chain_id) + lookup_key = config[ + "name" + ] # Use the chain name from config for logging + + validate_chain_config(config, lookup_key) # Override with provided parameters final_rpc_url = rpc_url if rpc_url is not None else config["rpc_url"] @@ -189,14 +217,22 @@ def from_chain( ) except (ValueError, TypeError) as e: - logger.error("eas_from_chain_failed", chain_name=chain_name, error=str(e)) + # Use the original chain identifier for better error messages + error_chain_id = chain_name if chain_name is not None else str(chain_id) + logger.error( + "eas_from_chain_failed", chain_name=error_chain_id, error=str(e) + ) raise except Exception as e: + # Use the original chain identifier for better error messages + error_chain_id = chain_name if chain_name is not None else str(chain_id) logger.error( - "eas_from_chain_unexpected_error", chain_name=chain_name, error=str(e) + "eas_from_chain_unexpected_error", + chain_name=error_chain_id, + error=str(e), ) raise EASError( - f"Failed to create EAS instance for chain '{chain_name}': {str(e)}" + f"Failed to create EAS instance for chain '{error_chain_id}': {str(e)}" ) @classmethod @@ -205,12 +241,16 @@ def from_environment(cls, **kwargs: Any) -> "EAS": Create an EAS instance from environment variables with comprehensive configuration support. Environment Variables: - EAS_CHAIN: Chain name (required, e.g., 'ethereum', 'base', 'sepolia') + EAS_CHAIN: Chain name (e.g., 'ethereum', 'base', 'sepolia') - XOR with EAS_CHAIN_ID + EAS_CHAIN_ID: Chain ID (e.g., 1, 8453, 11155111) - XOR with EAS_CHAIN EAS_PRIVATE_KEY: Private key for signing (required) EAS_FROM_ACCOUNT: Account address for transactions (required) EAS_RPC_URL: Custom RPC URL (optional, overrides chain default) EAS_CONTRACT_ADDRESS: Custom contract address (optional, overrides chain default) + Note: + Exactly one of EAS_CHAIN or EAS_CHAIN_ID must be provided (XOR). + Args: **kwargs: Additional arguments passed to EAS constructor @@ -222,13 +262,17 @@ def from_environment(cls, **kwargs: Any) -> "EAS": TypeError: If environment variables have incorrect format ConnectionError: If unable to connect to the network - Example: - # Set environment variables + Examples: + # Using chain name export EAS_CHAIN=ethereum export EAS_PRIVATE_KEY=0x1234... export EAS_FROM_ACCOUNT=0xabcd... + eas = EAS.from_environment() - # Create EAS instance + # Using chain ID + export EAS_CHAIN_ID=8453 + export EAS_PRIVATE_KEY=0x1234... + export EAS_FROM_ACCOUNT=0xabcd... eas = EAS.from_environment() """ # Required environment variables @@ -238,13 +282,28 @@ def from_environment(cls, **kwargs: Any) -> "EAS": # "EAS_FROM_ACCOUNT": "from account address", # } + # Check for XOR requirement between EAS_CHAIN and EAS_CHAIN_ID + chain_name_env = os.getenv("EAS_CHAIN") + chain_id_env = os.getenv("EAS_CHAIN_ID") + + if (chain_name_env is None) == (chain_id_env is None): + raise ValueError( + "Exactly one of 'EAS_CHAIN' or 'EAS_CHAIN_ID' environment variables " + "must be provided (not both, not neither)" + ) + # Use comprehensive batch environment variable validation required_env_types = { - "EAS_CHAIN": "chain_name", "EAS_PRIVATE_KEY": "private_key", "EAS_FROM_ACCOUNT": "address", } + # Add chain validation based on which is provided + if chain_name_env is not None: + required_env_types["EAS_CHAIN"] = "chain_name" + else: + required_env_types["EAS_CHAIN_ID"] = "chain_id" + optional_env_types = { "EAS_RPC_URL": "rpc_url", "EAS_CONTRACT_ADDRESS": "address", @@ -258,7 +317,12 @@ def from_environment(cls, **kwargs: Any) -> "EAS": logger.error("environment_validation_failed", error=str(e)) raise ValueError(f"Environment variable validation failed: {str(e)}") - chain_name = env_values["EAS_CHAIN"] + # Extract chain identifier + chain_name = env_values.get("EAS_CHAIN") + chain_id_str = env_values.get("EAS_CHAIN_ID") + chain_id: Optional[int] = None + if chain_id_str is not None: + chain_id = int(chain_id_str) # Convert string to int private_key = env_values["EAS_PRIVATE_KEY"] from_account = env_values["EAS_FROM_ACCOUNT"] rpc_url = env_values.get("EAS_RPC_URL") @@ -268,6 +332,7 @@ def from_environment(cls, **kwargs: Any) -> "EAS": logger.info( "creating_eas_from_environment", chain_name=chain_name, + chain_id=chain_id, from_account=SecureEnvironmentValidator.sanitize_for_logging( from_account, "address" ), @@ -284,6 +349,7 @@ def from_environment(cls, **kwargs: Any) -> "EAS": # Use from_chain method with environment variable values return cls.from_chain( chain_name=chain_name, + chain_id=chain_id, private_key=private_key, from_account=from_account, rpc_url=rpc_url, @@ -292,9 +358,10 @@ def from_environment(cls, **kwargs: Any) -> "EAS": ) except Exception as e: + error_chain_id = chain_name if chain_name is not None else str(chain_id) logger.error( "eas_from_environment_failed", - chain_name=chain_name, + chain_name=error_chain_id, from_account=( SecureEnvironmentValidator.sanitize_for_logging( from_account, "address" @@ -614,10 +681,10 @@ def get_attestation_url(self, attestation_uid: str) -> str: return f"{base_url}/attestation/view/{attestation_uid}" - def __init_schema_registry(self, network_name: str) -> SchemaRegistry: - """Initialize schema registry for the current network.""" + def __init_schema_registry(self, chain_name: str) -> SchemaRegistry: + """Initialize schema registry for the current chain.""" try: - registry_address = SchemaRegistry.get_registry_address(network_name) + registry_address = SchemaRegistry.get_registry_address(chain_name) return SchemaRegistry( web3=self.w3, registry_address=registry_address, @@ -631,7 +698,7 @@ def __init_schema_registry(self, network_name: str) -> SchemaRegistry: def register_schema( self, schema: str, - network_name: str = "base-sepolia", + chain_name: str = "base-sepolia", resolver: Optional[str] = None, revocable: bool = True, ) -> TransactionResult: @@ -640,14 +707,14 @@ def register_schema( Args: schema: Schema definition string (e.g., "uint256 id,string name") - network_name: Network to register on (default: base-sepolia) + chain_name: Chain to register on (default: base-sepolia) resolver: Optional resolver contract address revocable: Whether attestations using this schema can be revoked Returns: TransactionResult with schema UID and transaction details """ - registry = self.__init_schema_registry(network_name) + registry = self.__init_schema_registry(chain_name) result = registry.register_schema(schema, resolver, revocable) return cast(TransactionResult, result) diff --git a/src/main/eas/exceptions.py b/src/main/eas/exceptions.py index 98da7c0..3f852aa 100644 --- a/src/main/eas/exceptions.py +++ b/src/main/eas/exceptions.py @@ -164,7 +164,7 @@ def __init__( self, message: str, rpc_url: Optional[str] = None, - network_name: Optional[str] = None, + chain_name: Optional[str] = None, ): context = {} if rpc_url: @@ -174,8 +174,8 @@ def __init__( context["rpc_url"] = SecureEnvironmentValidator.sanitize_for_logging( rpc_url, "url" ) - if network_name: - context["network_name"] = network_name + if chain_name: + context["chain_name"] = chain_name super().__init__(message, context) diff --git a/src/main/eas/schema_registry.py b/src/main/eas/schema_registry.py index 3510a5d..330a2a4 100644 --- a/src/main/eas/schema_registry.py +++ b/src/main/eas/schema_registry.py @@ -240,9 +240,9 @@ def get_schema(self, uid: str) -> Dict[str, Any]: ) @classmethod - def get_registry_address(cls, network_name: str) -> str: - """Get the schema registry contract address for a network.""" - # Network-specific registry addresses + def get_registry_address(cls, chain_name: str) -> str: + """Get the schema registry contract address for a chain.""" + # Chain-specific registry addresses # Note: These would need to be updated with actual EAS Schema Registry addresses registry_addresses = { "mainnet": "0x0a7E2Ff54e76B8E6659aedc9103FB21c038050D0", @@ -251,11 +251,11 @@ def get_registry_address(cls, network_name: str) -> str: "base-sepolia": "0x4200000000000000000000000000000000000020", # Example - needs actual address } - if network_name not in registry_addresses: + if chain_name not in registry_addresses: raise EASValidationError( - f"Unsupported network for schema registry: {network_name}", - field_name="network_name", - field_value=network_name, + f"Unsupported chain for schema registry: {chain_name}", + field_name="chain_name", + field_value=chain_name, ) - return registry_addresses[network_name] + return registry_addresses[chain_name] diff --git a/src/main/eas/types.py b/src/main/eas/types.py index 6b64e71..84f8f02 100644 --- a/src/main/eas/types.py +++ b/src/main/eas/types.py @@ -133,7 +133,7 @@ class ChainConfig(TypedDict): chain_id: Required[ChainId] rpc_url: Required[RpcUrl] contract_version: Required[ContractVersion] - network_name: NotRequired[str] + chain_name: NotRequired[str] class EIP712Domain(TypedDict): diff --git a/src/test/test_multi_chain_support.py b/src/test/test_multi_chain_support.py index 458356e..85b5549 100644 --- a/src/test/test_multi_chain_support.py +++ b/src/test/test_multi_chain_support.py @@ -13,6 +13,7 @@ get_testnet_chains, list_supported_chains, ) +from eas.exceptions import EASSecurityError class TestMultiChainSupport: @@ -42,7 +43,7 @@ def test_get_mainnet_chains(self): assert len(mainnet_chains) > 0, "Should have at least one mainnet chain" for chain in mainnet_chains: - config = get_network_config(chain) + config = get_network_config(chain_name=chain) assert ( config.get("network_type", "mainnet") == "mainnet" ), f"{chain} should be a mainnet chain" @@ -56,7 +57,7 @@ def test_get_testnet_chains(self): for chain in testnet_chains: try: - config = get_network_config(chain) + config = get_network_config(chain_name=chain) assert ( config.get("network_type", "mainnet") == "testnet" ), f"{chain} should be a testnet chain" @@ -81,7 +82,7 @@ def test_get_network_config_valid_chains(self): for chain in list_supported_chains(): try: - config = get_network_config(chain) + config = get_network_config(chain_name=chain) # Common configuration validation assert "rpc_url" in config, f"RPC URL missing for {chain}" @@ -112,9 +113,10 @@ def test_get_network_config_valid_chains(self): def test_get_network_config_invalid_chain(self): """Test error handling for unsupported chain names""" with pytest.raises( - ValueError, match="(Unsupported chain|Invalid network name)" + (ValueError, EASSecurityError), + match="(Unsupported chain|Invalid network name|Invalid chain name format)", ): - get_network_config("non_existent_chain") + get_network_config(chain_name="non_existent_chain") @patch("main.eas.core.web3.Web3") def test_eas_from_chain_valid_chain(self, mock_web3_class): @@ -132,7 +134,11 @@ def test_eas_from_chain_valid_chain(self, mock_web3_class): test_from_account = "0xd796b20681bD6BEe28f0c938271FA99261c84fE8" for chain in supported_chains: - eas = EAS.from_chain(chain, test_private_key, test_from_account) + eas = EAS.from_chain( + chain_name=chain, + private_key=test_private_key, + from_account=test_from_account, + ) # Validate basic properties assert eas.chain_id is not None @@ -156,9 +162,9 @@ def test_eas_from_chain_with_overrides(self, mock_web3_class): test_from_account = "0xd796b20681bD6BEe28f0c938271FA99261c84fE8" eas = EAS.from_chain( - "ethereum", - test_private_key, - test_from_account, + chain_name="ethereum", + private_key=test_private_key, + from_account=test_from_account, rpc_url=custom_rpc, contract_address=custom_contract, ) @@ -174,10 +180,14 @@ def test_eas_from_chain_invalid_chain(self): test_from_account = "0xd796b20681bD6BEe28f0c938271FA99261c84fE8" with pytest.raises( - ValueError, - match="(Unsupported chain|Invalid network name|Security validation failed)", + (ValueError, EASSecurityError), + match="(Unsupported chain|Invalid network name|Security validation failed|Invalid chain name format)", ): - EAS.from_chain("non_existent_chain", test_private_key, test_from_account) + EAS.from_chain( + chain_name="non_existent_chain", + private_key=test_private_key, + from_account=test_from_account, + ) def test_eas_from_environment(self, mock_env_vars): """Test eas.from_environment() parsing""" @@ -200,7 +210,7 @@ def test_eas_from_environment_missing_vars(self, mock_env_vars): for var in ["EAS_CHAIN", "EAS_PRIVATE_KEY", "EAS_FROM_ACCOUNT"]: os.environ.pop(var, None) - with pytest.raises(ValueError, match="Missing required environment variables"): + with pytest.raises(ValueError, match="Exactly one of.*environment variables"): EAS.from_environment() @patch("main.eas.core.web3.Web3") @@ -259,7 +269,9 @@ def test_multiple_eas_instances(self, mock_web3_class): eas_instances = {} for chain in chains_to_test: eas_instances[chain] = EAS.from_chain( - chain, test_private_key, test_from_account + chain_name=chain, + private_key=test_private_key, + from_account=test_from_account, ) # Verify unique chain IDs and contract addresses @@ -291,7 +303,11 @@ def test_performance_factory_methods(self, mock_web3_class): # Measure initialization time for from_chain start_time = time.time() - eas = EAS.from_chain("ethereum", test_private_key, test_from_account) + eas = EAS.from_chain( + chain_name="ethereum", + private_key=test_private_key, + from_account=test_from_account, + ) from_chain_time = time.time() - start_time assert eas is not None diff --git a/src/test/test_security_validation.py b/src/test/test_security_validation.py index 8c0bf02..e38390a 100644 --- a/src/test/test_security_validation.py +++ b/src/test/test_security_validation.py @@ -487,20 +487,34 @@ def test_eas_from_chain_security_validation( ), ): - eas = EAS.from_chain("ethereum", valid_key, valid_address) + eas = EAS.from_chain( + chain_name="ethereum", private_key=valid_key, from_account=valid_address + ) assert eas is not None # Invalid chain name should fail with pytest.raises(ValueError, match="Security validation failed"): - EAS.from_chain("ethereum; rm -rf /", valid_key, valid_address) + EAS.from_chain( + chain_name="ethereum; rm -rf /", + private_key=valid_key, + from_account=valid_address, + ) # Invalid private key should fail with pytest.raises(ValueError, match="Security validation failed"): - EAS.from_chain("ethereum", "invalid-key", valid_address) + EAS.from_chain( + chain_name="ethereum", + private_key="invalid-key", + from_account=valid_address, + ) # Invalid address should fail with pytest.raises(ValueError, match="Security validation failed"): - EAS.from_chain("ethereum", valid_key, "invalid-address") + EAS.from_chain( + chain_name="ethereum", + private_key=valid_key, + from_account="invalid-address", + ) @patch.dict( os.environ, @@ -551,7 +565,7 @@ def test_eas_from_environment_security_validation( os.environ, { "EAS_CHAIN": "ethereum; rm -rf /", # Injection attempt - "EAS_PRIVATE_KEY": "0x1234567890123456789012345678901234567890123456789012345678901234", + "EAS_PRIVATE_KEY": "0xa7c5ba7114b7119bb78dfc8e8ccd9f4ad8c6c9f2e8d7ab234fac8b1d5c7e9f12", "EAS_FROM_ACCOUNT": "0x1234567890123456789012345678901234567890", }, ) @@ -559,7 +573,9 @@ def test_environment_injection_prevention(self): """Test that environment variable injection is prevented""" from eas import EAS - with pytest.raises(ValueError, match="dangerous patterns"): + with pytest.raises( + ValueError, match="(dangerous patterns|Invalid chain name format)" + ): EAS.from_environment() diff --git a/src/test/test_write_operations.py b/src/test/test_write_operations.py index ea3c886..eee5821 100644 --- a/src/test/test_write_operations.py +++ b/src/test/test_write_operations.py @@ -79,7 +79,7 @@ def test_get_registry_address(self): assert len(address) == 42 # Test unknown network - with pytest.raises(EASValidationError, match="Unsupported network"): + with pytest.raises(EASValidationError, match="Unsupported chain"): SchemaRegistry.get_registry_address("unknown-network") @@ -277,7 +277,7 @@ def test_schema_registration_with_real_network(self): try: result = eas.register_schema( schema=test_schema, - network_name="base-sepolia", + chain_name="base-sepolia", resolver=None, revocable=True, )