From f2b516a1a6975a58065f87ac0cf90d3292306c80 Mon Sep 17 00:00:00 2001 From: drunest <6984754+drunest@users.noreply.github.com> Date: Sat, 11 Jan 2025 09:22:18 -0600 Subject: [PATCH 01/27] feat: add solana key type --- commune/key.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/commune/key.py b/commune/key.py index fafad9608..28bf87df0 100644 --- a/commune/key.py +++ b/commune/key.py @@ -38,6 +38,7 @@ from ecdsa.curves import SECP256k1 from eth_keys.datatypes import Signature, PrivateKey from eth_utils import to_checksum_address, keccak as eth_utils_keccak +from solders.keypair import Keypair as SolanaKeypair BIP39_PBKDF2_ROUNDS = 2048 BIP39_SALT_MODIFIER = "mnemonic" @@ -305,6 +306,7 @@ class KeyType: ED25519 = 0 SR25519 = 1 ECDSA = 2 + SOLANA = 3 KeyType.crypto_types = [k for k in KeyType.__dict__.keys() if not k.startswith('_')] KeyType.crypto_type_map = {k.lower():v for k,v in KeyType.__dict__.items() if k in KeyType.crypto_types } KeyType.crypto_types = list(KeyType.crypto_type_map.keys()) @@ -393,6 +395,8 @@ def set_private_key(self, public_key = private_key_obj.public_key.to_address() key_address = private_key_obj.public_key.to_checksum_address() hash_type = 'h160' + elif crypto_type == KeyType.SOLANA: + pass else: raise ValueError('crypto_type "{}" not supported'.format(crypto_type)) if type(public_key) is str: @@ -798,7 +802,9 @@ def create_from_mnemonic(cls, mnemonic: str = None, ss58_format=ss58_format, cry raise ValueError("ECDSA mnemonic only supports english") private_key = mnemonic_to_ecdsa_private_key(mnemonic) keypair = cls.create_from_private_key(private_key, ss58_format=ss58_format, crypto_type=crypto_type) - + elif crypto_type == KeyType.SOLANA: + private_key = SolanaKeypair.from_seed_phrase_and_passphrase(mnemonic, "").secret() + keypair = cls.create_from_private_key(private_key, ss58_format=ss58_format, crypto_type=crypto_type) else: keypair = cls.create_from_seed( seed_hex=binascii.hexlify(bytearray(bip39_to_mini_secret(mnemonic, "", language_code))).decode("ascii"), @@ -834,6 +840,10 @@ def create_from_seed(cls, seed_hex: Union[bytes, str], ss58_format: Optional[int public_key, private_key = sr25519.pair_from_seed(seed_hex) elif crypto_type == KeyType.ED25519: private_key, public_key = ed25519_zebra.ed_from_seed(seed_hex) + elif crypto_type == KeyType.SOLANA: + keypair = SolanaKeypair.from_seed(seed_hex) + public_key = keypair.pubkey() + private_key = keypair.secret() else: raise ValueError('crypto_type "{}" not supported'.format(crypto_type)) @@ -900,6 +910,11 @@ def create_from_uri( passphrase=suri_parts['password'] ) derived_keypair = cls.create_from_private_key(private_key, ss58_format=ss58_format, crypto_type=crypto_type) + elif crypto_type == KeyType.SOLANA: + if language_code != "en": + raise ValueError("Solana mnemonic only supports english") + private_key = SolanaKeypair.from_seed_phrase_and_passphrase(suri_parts['phrase'], passphrase=suri_parts['password']).secret() + derived_keypair = cls.create_from_private_key(private_key, ss58_format=ss58_format, crypto_type=crypto_type) else: if suri_parts['password']: From 60eae0fdc258fcb6da00899ec5d97ab102eb3662 Mon Sep 17 00:00:00 2001 From: drunest <6984754+drunest@users.noreply.github.com> Date: Sat, 11 Jan 2025 09:37:35 -0600 Subject: [PATCH 02/27] feat: aadd solana key type in set private key --- commune/key.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/commune/key.py b/commune/key.py index 28bf87df0..b3d72d151 100644 --- a/commune/key.py +++ b/commune/key.py @@ -396,7 +396,12 @@ def set_private_key(self, key_address = private_key_obj.public_key.to_checksum_address() hash_type = 'h160' elif crypto_type == KeyType.SOLANA: - pass + private_key = private_key[0:32] + keypair = SolanaKeypair.from_seed(private_key) + public_key = keypair.pubkey() + private_key = keypair.secret() + key_address = ss58_encode(public_key, ss58_format=ss58_format) + hash_type = 'ss58' else: raise ValueError('crypto_type "{}" not supported'.format(crypto_type)) if type(public_key) is str: From ea543449c7c3fee42981eebb82a5047e29d1cc85 Mon Sep 17 00:00:00 2001 From: drunest <6984754+drunest@users.noreply.github.com> Date: Sat, 11 Jan 2025 10:29:20 -0600 Subject: [PATCH 03/27] feat: sign and verify for solana key type --- commune/key.py | 35 ++++++++++++++++++++++++++--------- test.py | 0 2 files changed, 26 insertions(+), 9 deletions(-) create mode 100644 test.py diff --git a/commune/key.py b/commune/key.py index b3d72d151..aef459e1e 100644 --- a/commune/key.py +++ b/commune/key.py @@ -39,6 +39,8 @@ from eth_keys.datatypes import Signature, PrivateKey from eth_utils import to_checksum_address, keccak as eth_utils_keccak from solders.keypair import Keypair as SolanaKeypair +from solders.signature import Signature as SolanaSignature +from solders.pubkey import Pubkey as SolanaPubkey BIP39_PBKDF2_ROUNDS = 2048 BIP39_SALT_MODIFIER = "mnemonic" @@ -129,6 +131,11 @@ def ecdsa_verify(signature: bytes, data: bytes, address: bytes) -> bool: recovered_pubkey = signature_obj.recover_public_key_from_msg(data) return recovered_pubkey.to_canonical_address() == address +def solana_verify(signature: bytes, message: bytes, public_key: bytes) -> bool: + signature = SolanaSignature.from_bytes(signature) + pubkey = SolanaPubkey.from_bytes(public_key) + return signature.verify(message, pubkey) + NONCE_LENGTH = 24 SCRYPT_LENGTH = 32 + (3 * 4) PKCS8_DIVIDER = bytes([161, 35, 3, 33, 0]) @@ -1023,6 +1030,9 @@ def create_from_encrypted_json(cls, json_data: Union[str, dict], passphrase: str crypto_type = KeyType.ED25519 # Strip the nonce part of the private key private_key = private_key[0:32] + elif 'solana' in json_data['encoding']['content']: + crypto_type = KeyType.SOLANA + private_key = private_key[0:32] else: raise NotImplementedError("Unknown KeyType found in JSON") @@ -1047,24 +1057,27 @@ def export_to_encrypted_json(self, passphrase: str, name: str = None) -> dict: if not name: name = self.ss58_address - if self.crypto_type != KeyType.SR25519: + if self.crypto_type == KeyType.SR25519: + # Secret key from PolkadotJS is an Ed25519 expanded secret key, so has to be converted + # https://github.com/polkadot-js/wasm/blob/master/packages/wasm-crypto/src/rs/sr25519.rs#L125 + converted_private_key = sr25519.convert_secret_key_to_ed25519(self.private_key) + encoded = encode_pair(self.public_key, converted_private_key, passphrase) + encoding_content = ["pkcs8", "sr25519"] + elif self.crypto_type == KeyType.SOLANA: + keypair = SolanaKeypair.from_seed(self.private_key) + encoded = encode_pair(self.public_key, keypair.secret(), passphrase) + encoding_content = ["pkcs8", "solana"] + else: raise NotImplementedError(f"Cannot create JSON for crypto_type '{self.crypto_type}'") - # Secret key from PolkadotJS is an Ed25519 expanded secret key, so has to be converted - # https://github.com/polkadot-js/wasm/blob/master/packages/wasm-crypto/src/rs/sr25519.rs#L125 - converted_private_key = sr25519.convert_secret_key_to_ed25519(self.private_key) - - encoded = encode_pair(self.public_key, converted_private_key, passphrase) - json_data = { "encoded": b64encode(encoded).decode(), - "encoding": {"content": ["pkcs8", "sr25519"], "type": ["scrypt", "xsalsa20-poly1305"], "version": "3"}, + "encoding": {"content": encoding_content, "type": ["scrypt", "xsalsa20-poly1305"], "version": "3"}, "address": self.ss58_address, "meta": { "name": name, "tags": [], "whenCreated": int(time.time()) } } - return json_data seperator = "::signature=" @@ -1097,6 +1110,8 @@ def sign(self, data: Union[ScaleBytes, bytes, str], to_json = False) -> bytes: signature = ed25519_zebra.ed_sign(self.private_key, data) elif self.crypto_type == KeyType.ECDSA: signature = ecdsa_sign(self.private_key, data) + elif self.crypto_type == KeyType.SOLANA: + signature = SolanaKeypair.from_seed(self.private_key).sign_message(data) else: raise Exception("Crypto type not supported") @@ -1213,6 +1228,8 @@ def verify(self, crypto_verify_fn = ed25519_zebra.ed_verify elif self.crypto_type == KeyType.ECDSA: crypto_verify_fn = ecdsa_verify + elif self.crypto_type == KeyType.SOLANA: + crypto_verify_fn = solana_verify else: raise Exception("Crypto type not supported") verified = crypto_verify_fn(signature, data, public_key) diff --git a/test.py b/test.py new file mode 100644 index 000000000..e69de29bb From 7d3be72d578412acf358288c8120ce7469b1ff24 Mon Sep 17 00:00:00 2001 From: drunest <6984754+drunest@users.noreply.github.com> Date: Sat, 11 Jan 2025 10:35:23 -0600 Subject: [PATCH 04/27] feat: solana sign fn --- commune/key.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/commune/key.py b/commune/key.py index aef459e1e..25f83c8b3 100644 --- a/commune/key.py +++ b/commune/key.py @@ -131,6 +131,10 @@ def ecdsa_verify(signature: bytes, data: bytes, address: bytes) -> bool: recovered_pubkey = signature_obj.recover_public_key_from_msg(data) return recovered_pubkey.to_canonical_address() == address +def solana_sign(private_key: bytes, message: bytes) -> bytes: + keypair = SolanaKeypair.from_seed(private_key) + return bytes(keypair.sign_message(message)) + def solana_verify(signature: bytes, message: bytes, public_key: bytes) -> bool: signature = SolanaSignature.from_bytes(signature) pubkey = SolanaPubkey.from_bytes(public_key) @@ -1111,7 +1115,7 @@ def sign(self, data: Union[ScaleBytes, bytes, str], to_json = False) -> bytes: elif self.crypto_type == KeyType.ECDSA: signature = ecdsa_sign(self.private_key, data) elif self.crypto_type == KeyType.SOLANA: - signature = SolanaKeypair.from_seed(self.private_key).sign_message(data) + signature = solana_sign(self.private_key, data) else: raise Exception("Crypto type not supported") From 6f89d56ff152eb98d54a9360bd1868512f09cc8f Mon Sep 17 00:00:00 2001 From: drunest <6984754+drunest@users.noreply.github.com> Date: Sun, 12 Jan 2025 08:24:53 -0600 Subject: [PATCH 05/27] fix: change pubkey to bytes type --- commune/key.py | 12 ++++-------- test.py | 0 2 files changed, 4 insertions(+), 8 deletions(-) delete mode 100644 test.py diff --git a/commune/key.py b/commune/key.py index 25f83c8b3..800e6566f 100644 --- a/commune/key.py +++ b/commune/key.py @@ -41,6 +41,7 @@ from solders.keypair import Keypair as SolanaKeypair from solders.signature import Signature as SolanaSignature from solders.pubkey import Pubkey as SolanaPubkey +from base58 import b58encode BIP39_PBKDF2_ROUNDS = 2048 BIP39_SALT_MODIFIER = "mnemonic" @@ -409,10 +410,10 @@ def set_private_key(self, elif crypto_type == KeyType.SOLANA: private_key = private_key[0:32] keypair = SolanaKeypair.from_seed(private_key) - public_key = keypair.pubkey() + public_key = keypair.pubkey().__bytes__() private_key = keypair.secret() - key_address = ss58_encode(public_key, ss58_format=ss58_format) - hash_type = 'ss58' + key_address = b58encode(bytes(public_key)).decode('utf-8') + hash_type = 'base58' else: raise ValueError('crypto_type "{}" not supported'.format(crypto_type)) if type(public_key) is str: @@ -719,7 +720,6 @@ def new_key(cls, c.print(f'generating {crypto_type} keypair, {suri}') crypto_type = cls.resolve_crypto_type(crypto_type) - if suri: key = cls.create_from_uri(suri, crypto_type=crypto_type) elif mnemonic: @@ -856,10 +856,6 @@ def create_from_seed(cls, seed_hex: Union[bytes, str], ss58_format: Optional[int public_key, private_key = sr25519.pair_from_seed(seed_hex) elif crypto_type == KeyType.ED25519: private_key, public_key = ed25519_zebra.ed_from_seed(seed_hex) - elif crypto_type == KeyType.SOLANA: - keypair = SolanaKeypair.from_seed(seed_hex) - public_key = keypair.pubkey() - private_key = keypair.secret() else: raise ValueError('crypto_type "{}" not supported'.format(crypto_type)) diff --git a/test.py b/test.py deleted file mode 100644 index e69de29bb..000000000 From 9b11cedef6b657b09e0387b5abcb47ab6ca2715c Mon Sep 17 00:00:00 2001 From: drunest <6984754+drunest@users.noreply.github.com> Date: Mon, 13 Jan 2025 10:27:32 -0600 Subject: [PATCH 06/27] fix: verify signature --- commune/key.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/commune/key.py b/commune/key.py index 800e6566f..b04ddfca0 100644 --- a/commune/key.py +++ b/commune/key.py @@ -138,8 +138,8 @@ def solana_sign(private_key: bytes, message: bytes) -> bytes: def solana_verify(signature: bytes, message: bytes, public_key: bytes) -> bool: signature = SolanaSignature.from_bytes(signature) - pubkey = SolanaPubkey.from_bytes(public_key) - return signature.verify(message, pubkey) + pubkey = SolanaPubkey(public_key) + return signature.verify(pubkey, message) NONCE_LENGTH = 24 SCRYPT_LENGTH = 32 + (3 * 4) @@ -450,13 +450,13 @@ def ticket(cls , data=None, key=None, **kwargs): return cls.get_key(key).sign({'data':data, 'time': c.time()} , to_json=True, **kwargs) @classmethod - def mv_key(cls, path, new_path): + def mv_key(cls, path, new_path, crypto_type='sr25519'): assert cls.key_exists(path), f'key does not exist at {path}' - cls.put(new_path, cls.get_key(path).to_json()) + cls.put(new_path, cls.get_key(path, crypto_type=crypto_type).to_json()) cls.rm_key(path) assert cls.key_exists(new_path), f'key does not exist at {new_path}' assert not cls.key_exists(path), f'key still exists at {path}' - new_key = cls.get_key(new_path) + new_key = cls.get_key(new_path, crypto_type=crypto_type) return {'success': True, 'from': path , 'to': new_path, 'key': new_key} @classmethod From b505e2f25133a377e191be1ead8d7d788285b15f Mon Sep 17 00:00:00 2001 From: drunest <6984754+drunest@users.noreply.github.com> Date: Mon, 13 Jan 2025 10:28:34 -0600 Subject: [PATCH 07/27] test: add solana key type --- tests/test_key.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/test_key.py b/tests/test_key.py index 27c211a51..ae5fda6b6 100644 --- a/tests/test_key.py +++ b/tests/test_key.py @@ -1,11 +1,11 @@ import commune as c - +crypto_type='solana' def test_encryption(values = [10, 'fam', 'hello world']): cls = c.module('key') for value in values: value = str(value) - key = cls.new_key() + key = cls.new_key(crypto_type=crypto_type) enc = key.encrypt(value) dec = key.decrypt(enc) assert dec == value, f'encryption failed, {dec} != {value}' @@ -14,7 +14,7 @@ def test_encryption(values = [10, 'fam', 'hello world']): def test_encryption_with_password(value = 10, password = 'fam'): cls = c.module('key') value = str(value) - key = cls.new_key() + key = cls.new_key(crypto_type=crypto_type) enc = key.encrypt(value, password=password) dec = key.decrypt(enc, password=password) assert dec == value, f'encryption failed, {dec} != {value}' @@ -22,11 +22,11 @@ def test_encryption_with_password(value = 10, password = 'fam'): def test_key_encryption(test_key='test.key'): self = c.module('key') - key = self.add_key(test_key, refresh=True) - og_key = self.get_key(test_key) + key = self.add_key(test_key, refresh=True, crypto_type=crypto_type) + og_key = self.get_key(test_key, crypto_type=crypto_type) r = self.encrypt_key(test_key) self.decrypt_key(test_key, password=r['password']) - key = self.get_key(test_key) + key = self.get_key(test_key, crypto_type=crypto_type) assert key.ss58_address == og_key.ss58_address, f'key encryption failed, {key.ss58_address} != {self.ss58_address}' return {'success': True, 'msg': 'test_key_encryption passed'} @@ -36,15 +36,15 @@ def test_key_management(key1='test.key' , key2='test2.key'): self.rm_key(key1) if self.key_exists(key2): self.rm_key(key2) - self.add_key(key1) - k1 = self.get_key(key1) + self.add_key(key1, crypto_type=crypto_type) + k1 = self.get_key(key1, crypto_type=crypto_type) assert self.key_exists(key1), f'Key management failed, key still exists' - self.mv_key(key1, key2) - k2 = self.get_key(key2) + self.mv_key(key1, key2, crypto_type=crypto_type) + k2 = self.get_key(key2, crypto_type=crypto_type) assert k1.ss58_address == k2.ss58_address, f'Key management failed, {k1.ss58_address} != {k2.ss58_address}' assert self.key_exists(key2), f'Key management failed, key does not exist' assert not self.key_exists(key1), f'Key management failed, key still exists' - self.mv_key(key2, key1) + self.mv_key(key2, key1, crypto_type=crypto_type) assert self.key_exists(key1), f'Key management failed, key does not exist' assert not self.key_exists(key2), f'Key management failed, key still exists' self.rm_key(key1) @@ -55,7 +55,7 @@ def test_key_management(key1='test.key' , key2='test2.key'): def test_signing(): - self = c.module('key')() + self = c.module('key')(crypto_type=crypto_type) sig = self.sign('test') assert self.verify('test',sig, self.public_key) return {'success':True} @@ -63,7 +63,7 @@ def test_signing(): def test_key_encryption(password='1234'): cls = c.module('key') path = 'test.enc' - cls.add_key('test.enc', refresh=True) + cls.add_key('test.enc', refresh=True, crypto_type=crypto_type) assert cls.is_key_encrypted(path) == False, f'file {path} is encrypted' cls.encrypt_key(path, password=password) assert cls.is_key_encrypted(path) == True, f'file {path} is not encrypted' @@ -76,13 +76,13 @@ def test_key_encryption(password='1234'): def test_move_key(): self = c.module('key')() - self.add_key('testfrom') + self.add_key('testfrom', crypto_type=crypto_type) assert self.key_exists('testfrom') - og_key = self.get_key('testfrom') - self.mv_key('testfrom', 'testto') - assert self.key_exists('testto') + og_key = self.get_key('testfrom', crypto_type=crypto_type) + self.mv_key('testfrom', 'testto', crypto_type=crypto_type) + assert self.key_exists('testto', crypto_type=crypto_type) assert not self.key_exists('testfrom') - new_key = self.get_key('testto') + new_key = self.get_key('testto', crypto_type=crypto_type) assert og_key.ss58_address == new_key.ss58_address self.rm_key('testto') assert not self.key_exists('testto') From f9aeea3f7fc2c1aa47ef1f66e25e3c312882690c Mon Sep 17 00:00:00 2001 From: drunest <6984754+drunest@users.noreply.github.com> Date: Wed, 15 Jan 2025 12:41:50 -0600 Subject: [PATCH 08/27] ref: define constant --- commune/key/types/index.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 commune/key/types/index.py diff --git a/commune/key/types/index.py b/commune/key/types/index.py new file mode 100644 index 000000000..af3f083ff --- /dev/null +++ b/commune/key/types/index.py @@ -0,0 +1,32 @@ +from ecdsa.curves import SECP256k1 + +SS58_FORMAT = 42 + +ERROR_INVALID_KEY = "Invalid key provided." +ERROR_KEY_GENERATION_FAILED = "Key generation failed." +ERROR_KEY_VALIDATION_FAILED = "Key validation failed." + +DEV_PHRASE = 'bottom drive obey lake curtain smoke basket hold race lonely fit walk' + +JUNCTION_ID_LEN = 32 +RE_JUNCTION = r'(\/\/?)([^/]+)' + +NONCE_LENGTH = 24 +SCRYPT_LENGTH = 32 + (3 * 4) +PKCS8_DIVIDER = bytes([161, 35, 3, 33, 0]) +PKCS8_HEADER = bytes([48, 83, 2, 1, 1, 48, 5, 6, 3, 43, 101, 112, 4, 34, 4, 32]) +PUB_LENGTH = 32 +SALT_LENGTH = 32 +SEC_LENGTH = 64 +SEED_LENGTH = 32 + +SCRYPT_N = 1 << 15 +SCRYPT_P = 1 +SCRYPT_R = 8 + +BIP39_PBKDF2_ROUNDS = 2048 +BIP39_SALT_MODIFIER = "mnemonic" +BIP32_PRIVDEV = 0x80000000 +BIP32_CURVE = SECP256k1 +BIP32_SEED_MODIFIER = b"Bitcoin seed" +ETH_DERIVATION_PATH = "m/44'/60'/0'/0" \ No newline at end of file From 525ef82d0adac236ccf784422490c26a858456fa Mon Sep 17 00:00:00 2001 From: drunest <6984754+drunest@users.noreply.github.com> Date: Wed, 15 Jan 2025 12:42:52 -0600 Subject: [PATCH 09/27] ref: add base key class --- commune/key/key.py | 1011 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1011 insertions(+) create mode 100644 commune/key/key.py diff --git a/commune/key/key.py b/commune/key/key.py new file mode 100644 index 000000000..dd953bdb9 --- /dev/null +++ b/commune/key/key.py @@ -0,0 +1,1011 @@ +import json +from typing import Union, Optional +import time +import os +import binascii +import re +import secrets +import base64 +from base64 import b64encode +import hashlib +from Crypto import Random +from Crypto.Cipher import AES +import nacl.bindings +import nacl.public +from scalecodec.utils.ss58 import ss58_encode, ss58_decode, get_ss58_format +from scalecodec.base import ScaleBytes +from bip39 import bip39_generate, bip39_validate +import commune as c + +from scalecodec.utils.ss58 import is_valid_ss58_address + +from .types.index import * +from .utils import * + +class KeyType: + ED25519 = 0 + SR25519 = 1 + ECDSA = 2 + SOLANA = 3 +KeyType.crypto_types = [k for k in KeyType.__dict__.keys() if not k.startswith('_')] +KeyType.crypto_type_map = {k.lower():v for k,v in KeyType.__dict__.items() if k in KeyType.crypto_types } +KeyType.crypto_types = list(KeyType.crypto_type_map.keys()) + +class Key(c.Module): + crypto_types = KeyType.crypto_types + crypto_type_map = KeyType.crypto_type_map + crypto_types = list(crypto_type_map.keys()) + def __new__( + cls, + crypto_type: Union[str, int] = KeyType.SR25519, + **kwargs, + ): + crypto_type = cls.resolve_crypto_type(crypto_type) + if crypto_type == KeyType.SR25519: + from .dot_sr25519 import DotSR25519 + return super().__new__(DotSR25519) + elif crypto_type == KeyType.ED25519: + from .dot_ed25519 import DotED25519 + return super().__new__(DotED25519) + elif crypto_type == KeyType.ECDSA: + from .eth import ECDSA + return super().__new__(ECDSA) + elif crypto_type == KeyType.SOLANA: + from .sol import Solana + return super().__new__(Solana) + else: + raise NotImplementedError(f"unsupported crypto_type {crypto_type}") + + @property + def short_address(self): + n = 4 + return self.ss58_address[:n] + "..." + self.ss58_address[-n:] + + def set_crypto_type(self, crypto_type): + crypto_type = self.resolve_crypto_type(crypto_type) + if crypto_type != self.crypto_type: + kwargs = { + "private_key": self.private_key, + "ss58_format": self.ss58_format, + "derive_path": self.derive_path, + "path": self.path, + "crypto_type": crypto_type, # update crypto_type + } + return self.set_private_key(**kwargs) + else: + return { + "success": False, + "message": f"crypto_type already set to {crypto_type}", + } + + def set_private_key( + self, + private_key: Union[bytes, str] = None, + ss58_format: int = 42, + derive_path: str = None, + path: str = None, + **kwargs, + ): + """ + Allows generation of Keys from a variety of input combination, such as a public/private key combination, + mnemonic or URI containing soft and hard derivation paths. With these Keys data can be signed and verified + + Parameters + ---------- + ss58_address: Substrate address + public_key: hex string or bytes of public_key key + private_key: hex string or bytes of private key + ss58_format: Substrate address format, default to 42 when omitted + seed_hex: hex string of seed + """ + raise NotImplementedError("set_private_key not implemented") + + @classmethod + def add_key( + cls, + path: str, + mnemonic: str = None, + password: str = None, + refresh: bool = False, + private_key=None, + crypto_type: Union[str, int] = KeyType.SR25519, + **kwargs, + ): + if cls.key_exists(path) and not refresh: + c.print(f"key already exists at {path}") + return cls.get(path) + key = cls.new_key(mnemonic=mnemonic, private_key=private_key, crypto_type=crypto_type, **kwargs) + key.path = path + key_json = key.to_json() + if password != None: + key_json = cls.encrypt(data=key_json, password=password) + c.print(cls.put(path, key_json)) + cls.update() + return json.loads(key_json) + + @classmethod + def ticket(cls, data=None, key=None, **kwargs): + return cls.get_key(key).sign( + {"data": data, "time": c.time()}, to_json=True, **kwargs + ) + + @classmethod + def mv_key(cls, path, new_path, crypto_type=KeyType.SR25519): + assert cls.key_exists(path), f"key does not exist at {path}" + cls.put(new_path, cls.get_key(path, crypto_type=crypto_type).to_json()) + cls.rm_key(path) + assert cls.key_exists(new_path), f"key does not exist at {new_path}" + assert not cls.key_exists(path), f"key still exists at {path}" + new_key = cls.get_key(new_path, crypto_type=crypto_type) + return {"success": True, "from": path, "to": new_path, "key": new_key} + + @classmethod + def copy_key(cls, path, new_path): + assert cls.key_exists(path), f"key does not exist at {path}" + cls.put(new_path, cls.get_key(path).to_json()) + assert cls.key_exists(new_path), f"key does not exist at {new_path}" + assert cls.get_key(path) == cls.get_key(new_path), "key does not match" + new_key = cls.get_key(new_path) + return {"success": True, "from": path, "to": new_path, "key": new_key} + + @classmethod + def add_keys(cls, name, n=100, verbose: bool = False, **kwargs): + response = [] + for i in range(n): + key_name = f"{name}.{i}" + if bool == True: + c.print(f"generating key {key_name}") + response.append(cls.add_key(key_name, **kwargs)) + + def key2encrypted(self): + keys = self.keys() + key2encrypted = {} + for k in keys: + key2encrypted[k] = self.is_key_encrypted(k) + return key2encrypted + + def encrypted_keys(self): + return [k for k, v in self.key2encrypted().items() if v == True] + + @classmethod + def key_info(cls, path="module", **kwargs): + return cls.get_key_json(path) + + @classmethod + def load_key(cls, path=None): + key_info = cls.get(path) + key_info = c.jload(key_info) + if key_info["path"] == None: + key_info["path"] = path.replace(".json", "").split("/")[-1] + + cls.add_key(**key_info) + return {"status": "success", "message": f"key loaded from {path}"} + + @classmethod + def save_keys(cls, path="saved_keys.json", **kwargs): + path = cls.resolve_path(path) + c.print(f"saving mems to {path}") + key2mnemonic = cls.key2mnemonic() + c.put_json(path, key2mnemonic) + return { + "success": True, + "msg": "saved keys", + "path": path, + "n": len(key2mnemonic), + } + + @classmethod + def load_keys(cls, path="saved_keys.json", refresh=False, **kwargs): + key2mnemonic = c.get_json(path) + for k, mnemonic in key2mnemonic.items(): + try: + cls.add_key(k, mnemonic=mnemonic, refresh=refresh, **kwargs) + except Exception: + # c.print(f'failed to load mem {k} due to {e}', color='red') + pass + return {"loaded_mems": list(key2mnemonic.keys()), "path": path} + + loadkeys = loadmems = load_keys + + @classmethod + def key2mnemonic(cls, search=None) -> dict[str, str]: + """ + keyname (str) --> mnemonic (str) + + """ + mems = {} + for key in cls.keys(search): + try: + mems[key] = cls.get_mnemonic(key) + except Exception as e: + c.print(f"failed to get mem for {key} due to {e}") + if search: + mems = {k: v for k, v in mems.items() if search in k or search in v} + return mems + + @classmethod + def get_key( + cls, + path: str, + crypto_type: Union[str, int] = KeyType.SR25519, + password: str = None, + create_if_not_exists: bool = True, + **kwargs, + ): + crypto_type = cls.resolve_crypto_type(crypto_type) + if hasattr(path, "key_address"): + key = path + return key + path = path or "module" + # if ss58_address is provided, get key from address + if cls.valid_ss58_address(path): + path = cls.address2key().get(path) + if not cls.key_exists(path): + if create_if_not_exists: + key = cls.add_key(path, crypto_type=crypto_type, **kwargs) + c.print(f"key does not exist, generating new key -> {key['path']}") + else: + print(path) + raise ValueError(f"key does not exist at --> {path}") + key_json = cls.get(path) + # if key is encrypted, decrypt it + if cls.is_encrypted(key_json): + key_json = c.decrypt(data=key_json, password=password) + if key_json == None: + c.print( + { + "status": "error", + "message": f"key is encrypted, please {path} provide password", + } + ) + return None + key_json = c.jload(key_json) if isinstance(key_json, str) else key_json + key = cls.from_json(key_json, crypto_type=crypto_type) + key.path = path + return key + + @classmethod + def get_keys(cls, search=None, clean_failed_keys=False): + keys = {} + for key in cls.keys(): + if str(search) in key or search == None: + try: + keys[key] = cls.get_key(key) + except Exception: + continue + if keys[key] == None: + if clean_failed_keys: + cls.rm_key(key) + keys.pop(key) + return keys + + @classmethod + def key2address(cls, search=None, max_age=10, update=False, **kwargs): + path = "key2address" + key2address = cls.get(path, None, max_age=max_age, update=update) + if key2address == None: + key2address = {k: v.ss58_address for k, v in cls.get_keys(search).items()} + cls.put(path, key2address) + return key2address + + @classmethod + def n(cls, search=None, **kwargs): + return len(cls.key2address(search, **kwargs)) + + @classmethod + def address2key(cls, search: Optional[str] = None, update: bool = False): + address2key = {v: k for k, v in cls.key2address(update=update).items()} + if search != None: + return address2key.get(search, None) + return address2key + + @classmethod + def get_address(cls, key): + return cls.get_key(key).ss58_address + + get_addy = get_address + + @classmethod + def key_paths(cls): + return cls.ls() + + address_seperator = "_address=" + + @classmethod + def key2path(cls) -> dict: + """ + defines the path for each key + """ + path2key_fn = lambda path: ".".join(path.split("/")[-1].split(".")[:-1]) + key2path = {path2key_fn(path): path for path in cls.key_paths()} + return key2path + + @classmethod + def keys(cls, search: str = None, **kwargs): + keys = list(cls.key2path().keys()) + if search != None: + keys = [key for key in keys if search in key] + return keys + + @classmethod + def n(cls, *args, **kwargs): + return len(cls.key2address(*args, **kwargs)) + + @classmethod + def key_exists(cls, key, **kwargs): + path = cls.get_key_path(key) + import os + + return os.path.exists(path) + + @classmethod + def get_key_path(cls, key): + storage_dir = cls.storage_dir() + key_path = storage_dir + "/" + key + ".json" + return key_path + + @classmethod + def get_key_json(cls, key): + storage_dir = cls.storage_dir() + key_path = storage_dir + "/" + key + ".json" + return c.get(key_path) + + @classmethod + def get_key_address(cls, key): + return cls.get_key(key).ss58_address + + @classmethod + def rm_key(cls, key=None): + key2path = cls.key2path() + keys = list(key2path.keys()) + if key not in keys: + raise Exception(f"key {key} not found, available keys: {keys}") + c.rm(key2path[key]) + return {"deleted": [key]} + + @classmethod + def crypto_name2type(cls, name: str): + crypto_type_map = cls.crypto_type_map + name = name.lower() + if name not in crypto_type_map: + raise ValueError(f"crypto_type {name} not supported {crypto_type_map}") + return crypto_type_map[name] + + @classmethod + def crypto_type2name(cls, crypto_type: str): + crypto_type_map = {v: k for k, v in cls.crypto_type_map.items()} + return crypto_type_map[crypto_type] + + @classmethod + def resolve_crypto_type_name(cls, crypto_type): + return cls.crypto_type2name(cls.resolve_crypto_type(crypto_type)) + + @classmethod + def resolve_crypto_type(cls, crypto_type): + if isinstance(crypto_type, int) or ( + isinstance(crypto_type, str) and c.is_int(crypto_type) + ): + crypto_type = int(crypto_type) + crypto_type_map = cls.crypto_type_map + reverse_crypto_type_map = {v: k for k, v in crypto_type_map.items()} + assert crypto_type in reverse_crypto_type_map, ( + f"crypto_type {crypto_type} not supported {crypto_type_map}" + ) + crypto_type = reverse_crypto_type_map[crypto_type] + if isinstance(crypto_type, str): + crypto_type = crypto_type.lower() + crypto_type = cls.crypto_name2type(crypto_type) + return int(crypto_type) + + @classmethod + def new_private_key(cls): + return cls.new_key().private_key.hex() + + @classmethod + def new_key( + cls, + mnemonic: str = None, + suri: str = None, + private_key: str = None, + verbose: bool = False, + crypto_type: Union[str, int] = KeyType.SR25519, + **kwargs, + ): + """ + yo rody, this is a class method you can gen keys whenever fam + """ + if verbose: + c.print(f"generating polkadot keypair, {suri}") + + if suri: + key = cls.create_from_uri(suri, crypto_type=crypto_type) + elif mnemonic: + key = cls.create_from_mnemonic(mnemonic, crypto_type=crypto_type) + elif private_key: + key = cls.create_from_private_key(private_key, crypto_type=crypto_type) + else: + mnemonic = cls.generate_mnemonic() + key = cls.create_from_mnemonic(mnemonic, crypto_type=crypto_type) + return key + + create = gen = new_key + + def to_json(self, password: str = None) -> dict: + state_dict = c.copy(self.__dict__) + for k, v in state_dict.items(): + if type(v) in [bytes]: + state_dict[k] = v.hex() + if password != None: + state_dict[k] = self.encrypt(data=state_dict[k], password=password) + if "_ss58_address" in state_dict: + state_dict["ss58_address"] = state_dict.pop("_ss58_address") + + state_dict = json.dumps(state_dict) + + return state_dict + + @classmethod + def from_json(cls, obj: Union[str, dict], password: str = None, crypto_type: Union[str, int] = KeyType.SR25519) -> dict: + if type(obj) == str: + obj = json.loads(obj) + if obj == None: + return None + for k, v in obj.items(): + if cls.is_encrypted(obj[k]) and password != None: + obj[k] = cls.decrypt(data=obj[k], password=password) + if "ss58_address" in obj: + obj["_ss58_address"] = obj.pop("ss58_address") + obj["crypto_type"] = crypto_type + return cls(**obj) + + @classmethod + def generate_mnemonic(cls, words: int = 12, language_code: str = "en") -> str: + """ + params: + words: The amount of words to generate, valid values are 12, 15, 18, 21 and 24 + language_code: The language to use, valid values are: 'en', 'zh-hans', + 'zh-hant', 'fr', 'it', 'ja', 'ko', 'es'. + Defaults to `"en"` + """ + mnemonic = bip39_generate(words, language_code) + assert cls.validate_mnemonic(mnemonic, language_code), "mnemonic is invalid" + return mnemonic + + @classmethod + def validate_mnemonic(cls, mnemonic: str, language_code: str = "en") -> bool: + """ + Verify if specified mnemonic is valid + + Parameters + ---------- + mnemonic: Seed phrase + language_code: The language to use, valid values are: 'en', 'zh-hans', 'zh-hant', 'fr', 'it', 'ja', 'ko', 'es'. Defaults to `"en"` + + Returns + ------- + bool + """ + return bip39_validate(mnemonic, language_code) + + @classmethod + def create_from_mnemonic( + cls, mnemonic: str, ss58_format: int = SS58_FORMAT, language_code: str = "en", crypto_type: Union[str, int] = KeyType.SR25519 + ) -> "Key": + """ + Create a Key for given memonic + + Parameters + ---------- + mnemonic: Seed phrase + ss58_format: Substrate address format + crypto_type: Use `KeyType.SR25519` or `KeyType.ED25519` cryptography for generating the Key + language_code: The language to use, valid values are: 'en', 'zh-hans', 'zh-hant', 'fr', 'it', 'ja', 'ko', 'es'. Defaults to `"en"` + + Returns + ------- + Key + """ + crypto_type = cls.resolve_crypto_type(crypto_type) + if crypto_type == KeyType.SR25519: + from commune.key.dot_sr25519 import DotSR25519 + return DotSR25519.create_from_mnemonic(mnemonic=mnemonic, ss58_format=ss58_format, language_code=language_code, crypto_type=crypto_type) + elif crypto_type == KeyType.ED25519: + from commune.key.dot_ed25519 import DotED25519 + return DotED25519.create_from_mnemonic(mnemonic=mnemonic, ss58_format=ss58_format, language_code=language_code, crypto_type=crypto_type) + elif crypto_type == KeyType.ECDSA: + from commune.key.eth import ECDSA + return ECDSA.create_from_mnemonic(mnemonic=mnemonic, ss58_format=ss58_format, language_code=language_code, crypto_type=crypto_type) + elif crypto_type == KeyType.SOLANA: + from commune.key.sol import Solana + return Solana.create_from_mnemonic(mnemonic=mnemonic, ss58_format=ss58_format, language_code=language_code, crypto_type=crypto_type) + else: + raise NotImplementedError("create_from_mnemonic not implemented") + + from_mnemonic = from_mem = create_from_mnemonic + + @classmethod + def create_from_seed( + cls, + seed_hex: Union[bytes, str] = None, + ss58_format: Optional[int] = SS58_FORMAT, + crypto_type: Union[str, int] = KeyType.SR25519 + ) -> "Key": + """ + Create a Key for given seed + + Parameters + ---------- + seed_hex: hex string of seed + ss58_format: Substrate address format + crypto_type: Use KeyType.SR25519 or KeyType.ED25519 cryptography for generating the Key + + Returns + ------- + Key + """ + crypto_type = cls.resolve_crypto_type(crypto_type) + if crypto_type == KeyType.SR25519: + from commune.key.dot_sr25519 import DotSR25519 + return DotSR25519.create_from_seed(seed_hex=seed_hex, ss58_format=ss58_format, crypto_type=crypto_type) + elif crypto_type == KeyType.ED25519: + from commune.key.dot_ed25519 import DotED25519 + return DotED25519.create_from_seed(seed_hex=seed_hex, ss58_format=ss58_format, crypto_type=crypto_type) + elif crypto_type == KeyType.ECDSA: + from commune.key.eth import ECDSA + return ECDSA.create_from_seed(seed_hex=seed_hex, ss58_format=ss58_format, crypto_type=crypto_type) + elif crypto_type == KeyType.SOLANA: + from commune.key.sol import Solana + return Solana.create_from_seed(seed_hex=seed_hex, ss58_format=ss58_format, crypto_type=crypto_type) + else: + raise NotImplementedError("create_from_seed not implemented") + + @classmethod + def create_from_password(cls, password: str, crypto_type: Union[str, int] = KeyType.SR25519, **kwargs): + key = cls.create_from_uri(password, crypto_type=crypto_type, **kwargs) + key.set_crypto_type(crypto_type) + + str2key = pwd2key = password2key = from_password = create_from_password + + @classmethod + def create_from_uri( + cls, + suri: str, + ss58_format: int = SS58_FORMAT, + crypto_type: Union[str, int] = KeyType.SR25519, + language_code: str = "en", + ) -> "Key": + """ + Creates Key for specified suri in following format: `[mnemonic]/[soft-path]//[hard-path]` + + Parameters + ---------- + suri: + ss58_format: Substrate address format + language_code: The language to use, valid values are: 'en', 'zh-hans', 'zh-hant', 'fr', 'it', 'ja', 'ko', 'es'. Defaults to `"en"` + + Returns + ------- + Key + """ + if crypto_type == KeyType.SR25519: + from commune.key.dot_sr25519 import DotSR25519 + return DotSR25519.create_from_uri(suri=suri, ss58_format=ss58_format, language_code=language_code) + elif crypto_type == KeyType.ED25519: + from commune.key.dot_ed25519 import DotED25519 + return DotED25519.create_from_uri(suri=suri, ss58_format=ss58_format, language_code=language_code) + elif crypto_type == KeyType.ECDSA: + from commune.key.eth import ECDSA + return ECDSA.create_from_uri(suri=suri, ss58_format=ss58_format, language_code=language_code) + elif crypto_type == KeyType.SOLANA: + from commune.key.sol import Solana + return Solana.create_from_uri(suri=suri, ss58_format=ss58_format, language_code=language_code) + else: + raise NotImplementedError("create_from_uri not implemented") + + + @classmethod + def create_from_private_key( + cls, + private_key: Union[bytes, str], + public_key: Union[bytes, str] = None, + ss58_address: str = None, + ss58_format: int = SS58_FORMAT, + crypto_type: Union[str, int] = KeyType.SR25519, + ) -> "Key": + """ + Creates Key for specified public/private keys + Parameters + ---------- + private_key: hex string or bytes of private key + public_key: hex string or bytes of public key + ss58_address: Substrate address + ss58_format: Substrate address format, default = 42 + crypto_type: Use KeyType.[SR25519|ED25519|ECDSA] cryptography for generating the Key + + Returns + ------- + Key + """ + return cls( + private_key=private_key, + public_key=public_key, + crypto_type=crypto_type, + ss58_format=ss58_format, + ss58_address=ss58_address + ) + + from_private_key = create_from_private_key + + @classmethod + def create_from_encrypted_json( + cls, json_data: Union[str, dict], passphrase: str, ss58_format: int = None + ) -> "Key": + """ + Create a Key from a PolkadotJS format encrypted JSON file + + Parameters + ---------- + json_data: Dict or JSON string containing PolkadotJS export format + passphrase: Used to encrypt the keypair + ss58_format: Which network ID to use to format the SS58 address (42 for testnet) + + Returns + ------- + Key + """ + + if type(json_data) is str: + json_data = json.loads(json_data) + private_key, public_key = decode_pair_from_encrypted_json(json_data, passphrase) + if 'sr25519' in json_data['encoding']['content']: + crypto_type = KeyType.SR25519 + elif 'ed25519' in json_data['encoding']['content']: + crypto_type = KeyType.ED25519 + # Strip the nonce part of the private key + private_key = private_key[0:32] + elif 'solana' in json_data['encoding']['content']: + crypto_type = KeyType.SOLANA + private_key = private_key[0:32] + else: + raise NotImplementedError("Unknown KeyType found in JSON") + + if ss58_format is None and 'address' in json_data: + ss58_format = get_ss58_format(json_data['address']) + + return cls.create_from_private_key(private_key, public_key, ss58_format=ss58_format, crypto_type=crypto_type) + + def export_to_encrypted_json(self, passphrase: str, name: str = None) -> dict: + """ + Export Key to PolkadotJS format encrypted JSON file + + Parameters + ---------- + passphrase: Used to encrypt the keypair + name: Display name of Key used + + Returns + ------- + dict + """ + raise NotImplementedError("export_to_encrypted_json not implemented") + + seperator = "::signature=" + + def sign(self, data: Union[ScaleBytes, bytes, str], to_json=False) -> bytes: + """ + Creates a signature for given data + Parameters + ---------- + data: data to sign in `Scalebytes`, bytes or hex string format + Returns + ------- + signature in bytes + + """ + raise NotImplementedError("sign not implemented") + + @classmethod + def bytes2str(cls, data: bytes, mode: str = "utf-8") -> str: + if hasattr(data, "hex"): + return data.hex() + else: + if isinstance(data, str): + return data + return bytes.decode(data, mode) + + @classmethod + def python2str(cls, input): + from copy import deepcopy + import json + + input = deepcopy(input) + input_type = type(input) + if input_type == str: + return input + if input_type in [dict]: + input = json.dumps(input) + elif input_type in [bytes]: + input = cls.bytes2str(input) + elif input_type in [list, tuple, set]: + input = json.dumps(list(input)) + elif input_type in [int, float, bool]: + input = str(input) + return input + + def verify( + self, + data: Union[ScaleBytes, bytes, str, dict], + signature: Union[bytes, str], + public_key: Optional[str], + return_address, + ss58_format, + max_age, + address, + **kwargs, + ) -> bool: + """ + Verifies data with specified signature + + Parameters + ---------- + data: data to be verified in `Scalebytes`, bytes or hex string format + signature: signature in bytes or hex string format + public_key: public key in bytes or hex string format + + Returns + ------- + True if data is signed with this Key, otherwise False + """ + raise NotImplementedError("verify not implemented") + + def is_ticket(self, data): + return all( + [k in data for k in ["data", "signature", "address", "crypto_type"]] + ) and any([k in data for k in ["time", "timestamp"]]) + + def resolve_encryption_password(self, password: str = None) -> str: + if password == None: + password = self.private_key + if isinstance(password, str): + password = password.encode() + return hashlib.sha256(password).digest() + + def resolve_encryption_data(self, data): + if not isinstance(data, str): + data = str(data) + return data + + def encrypt(self, data, password=None): + data = self.resolve_encryption_data(data) + password = self.resolve_encryption_password(password) + data = data + (AES.block_size - len(data) % AES.block_size) * chr( + AES.block_size - len(data) % AES.block_size + ) + iv = Random.new().read(AES.block_size) + cipher = AES.new(password, AES.MODE_CBC, iv) + encrypted_bytes = base64.b64encode(iv + cipher.encrypt(data.encode())) + return encrypted_bytes.decode() + + def decrypt(self, data, password=None): + password = self.resolve_encryption_password(password) + data = base64.b64decode(data) + iv = data[: AES.block_size] + cipher = AES.new(password, AES.MODE_CBC, iv) + data = cipher.decrypt(data[AES.block_size :]) + data = data[: -ord(data[len(data) - 1 :])].decode("utf-8") + return data + + def encrypt_message( + self, + message: Union[bytes, str], + recipient_public_key: bytes, + nonce: bytes = secrets.token_bytes(24), + ) -> bytes: + """ + Encrypts message with for specified recipient + + Parameters + ---------- + message: message to be encrypted, bytes or string + recipient_public_key: recipient's public key + nonce: the nonce to use in the encryption + + Returns + ------- + Encrypted message + """ + if not self.private_key: + raise Exception("No private key set to encrypt") + if self.crypto_type != KeyType.ED25519: + raise Exception("Only ed25519 keypair type supported") + curve25519_public_key = nacl.bindings.crypto_sign_ed25519_pk_to_curve25519( + recipient_public_key + ) + recipient = nacl.public.PublicKey(curve25519_public_key) + private_key = nacl.bindings.crypto_sign_ed25519_sk_to_curve25519( + self.private_key + self.public_key + ) + sender = nacl.public.PrivateKey(private_key) + box = nacl.public.Box(sender, recipient) + return box.encrypt( + message if isinstance(message, bytes) else message.encode("utf-8"), nonce + ) + + def decrypt_message( + self, encrypted_message_with_nonce: bytes, sender_public_key: bytes + ) -> bytes: + """ + Decrypts message from a specified sender + + Parameters + ---------- + encrypted_message_with_nonce: message to be decrypted + sender_public_key: sender's public key + + Returns + ------- + Decrypted message + """ + + if not self.private_key: + raise Exception("No private key set to decrypt") + if self.crypto_type != KeyType.ED25519: + raise Exception("Only ed25519 keypair type supported") + private_key = nacl.bindings.crypto_sign_ed25519_sk_to_curve25519( + self.private_key + self.public_key + ) + recipient = nacl.public.PrivateKey(private_key) + curve25519_public_key = nacl.bindings.crypto_sign_ed25519_pk_to_curve25519( + sender_public_key + ) + sender = nacl.public.PublicKey(curve25519_public_key) + return nacl.public.Box(recipient, sender).decrypt(encrypted_message_with_nonce) + + encrypted_prefix = "ENCRYPTED::" + + @classmethod + def encrypt_key(cls, path="test.enc", password=None): + assert cls.key_exists(path), f"file {path} does not exist" + assert not cls.is_key_encrypted(path), f"{path} already encrypted" + data = cls.get(path) + enc_text = {"data": c.encrypt(data, password=password), "encrypted": True} + cls.put(path, enc_text) + return {"number_of_characters_encrypted": len(enc_text), "path": path} + + @classmethod + def is_key_encrypted(cls, key, data=None): + data = data or cls.get(key) + return cls.is_encrypted(data) + + @classmethod + def decrypt_key(cls, path="test.enc", password=None, key=None): + assert cls.key_exists(path), f"file {path} does not exist" + assert cls.is_key_encrypted(path), f"{path} not encrypted" + data = cls.get(path) + assert cls.is_encrypted(data), f"{path} not encrypted" + dec_text = c.decrypt(data["data"], password=password) + cls.put(path, dec_text) + assert not cls.is_key_encrypted(path), f"failed to decrypt {path}" + loaded_key = c.get_key(path) + return { + "path": path, + "key_address": loaded_key.ss58_address, + "crypto_type": loaded_key.crypto_type, + } + + @classmethod + def get_mnemonic(cls, key): + return cls.get_key(key).mnemonic + + def __str__(self): + return ( + f"" + ) + + def save(self, path=None): + if path == None: + path = self.path + c.put_json(path, self.to_json()) + return {"saved": path} + + def __repr__(self): + return self.__str__() + + @classmethod + def from_private_key(cls, private_key: str): + return cls(private_key=private_key) + + @classmethod + def valid_ss58_address(cls, address: str, ss58_format:int = SS58_FORMAT) -> bool: + """ + Checks if the given address is a valid ss58 address. + """ + try: + return is_valid_ss58_address(address, valid_ss58_format=ss58_format) + except Exception: + return False + + @classmethod + def is_encrypted(cls, data): + if isinstance(data, str): + if os.path.exists(data): + data = c.get_json(data) + else: + try: + data = json.loads(data) + except: + return False + if isinstance(data, dict): + return bool(data.get("encrypted", False)) + else: + return False + + @staticmethod + def ss58_encode(*args, **kwargs): + return ss58_encode(*args, **kwargs) + + @staticmethod + def ss58_decode(*args, **kwargs): + return ss58_decode(*args, **kwargs) + + @classmethod + def resolve_key_address(cls, key): + key2address = c.key2address() + if key in key2address: + address = key2address[key] + else: + address = key + return address + + @classmethod + def valid_h160_address(cls, address): + # Check if it starts with '0x' + if not address.startswith("0x"): + return False + + # Remove '0x' prefix + address = address[2:] + + # Check length + if len(address) != 40: + return False + + # Check if it contains only valid hex characters + if not re.match("^[0-9a-fA-F]{40}$", address): + return False + + return True + + def storage_migration(self): + key2path = self.key2path() + new_key2path = {} + for k_name, k_path in key2path.items(): + try: + key = c.get_key(k_name) + new_k_path = ( + "/".join(k_path.split("/")[:-1]) + + "/" + + f"{k_name}_address={key.ss58_address}_type={key.crypto_type}.json" + ) + new_key2path[k_name] = new_k_path + except Exception as e: + c.print(f"failed to migrate {k_name} due to {e}", color="red") + + return new_key2path + + def storage_migration(self): + key2path = self.key2path() + new_key2path = {} + for k_name, k_path in key2path.items(): + try: + key = c.get_key(k_name) + new_k_path = ( + "/".join(k_path.split("/")[:-1]) + + "/" + + f"{k_name}_address={key.ss58_address}_type={key.crypto_type}.json" + ) + new_key2path[k_name] = new_k_path + except Exception as e: + c.print(f"failed to migrate {k_name} due to {e}", color="red") + + return new_key2path \ No newline at end of file From f52fac225e9f257946f8c4601714d1d77f6d85f6 Mon Sep 17 00:00:00 2001 From: drunest <6984754+drunest@users.noreply.github.com> Date: Wed, 15 Jan 2025 12:43:28 -0600 Subject: [PATCH 10/27] ref: add sr25519 key for polkadot --- commune/key/dot_sr25519.py | 383 +++++++++++++++++++++++++++++++++++++ 1 file changed, 383 insertions(+) create mode 100644 commune/key/dot_sr25519.py diff --git a/commune/key/dot_sr25519.py b/commune/key/dot_sr25519.py new file mode 100644 index 000000000..4911f5450 --- /dev/null +++ b/commune/key/dot_sr25519.py @@ -0,0 +1,383 @@ +import sr25519 +import commune as c + +from .types.index import * +from .utils import * + +import json + +import time +import os +import binascii +import re +import secrets +import base64 +from base64 import b64encode +import hashlib +from Crypto import Random +from Crypto.Cipher import AES +import nacl.bindings +import nacl.public +from scalecodec.base import ScaleBytes +from bip39 import bip39_to_mini_secret +from scalecodec.utils.ss58 import ( + ss58_decode, + ss58_encode, + is_valid_ss58_address, + get_ss58_format, +) +from typing import Union, Optional +from .key import Key, KeyType +class DotSR25519(Key): + def __init__( + self, + private_key: Union[bytes, str] = None, + ss58_format: int = SS58_FORMAT, + derive_path: str = None, + path: str = None, + crypto_type: Union[str, int] = KeyType.SR25519, + **kwargs + ): + self.crypto_type = KeyType.SR25519 + self.set_private_key(private_key=private_key, ss58_format = ss58_format, derive_path=derive_path, path=path, **kwargs) + + def set_private_key( + self, + private_key: Union[bytes, str] = None, + ss58_format: int = 42, + derive_path: str = None, + path: str = None, + **kwargs, + ): + """ + Allows generation of Keys from a variety of input combination, such as a public/private key combination, + mnemonic or URI containing soft and hard derivation paths. With these Keys data can be signed and verified + + Parameters + ---------- + ss58_address: Substrate address + public_key: hex string or bytes of public_key key + private_key: hex string or bytes of private key + ss58_format: Substrate address format, default to 42 when omitted + seed_hex: hex string of seed + """ + # If no arguments are provided, generate a random keypair + if private_key == None: + private_key = self.new_key().private_key + if type(private_key) == str: + private_key = c.str2bytes(private_key) + if self.crypto_type == KeyType.SR25519: + if len(private_key) != 64: + private_key = sr25519.pair_from_seed(private_key)[1] + public_key = sr25519.public_from_secret_key(private_key) + key_address = ss58_encode(public_key, ss58_format=ss58_format) + hash_type = "ss58" + else: + raise ValueError('crypto_type "{}" not supported'.format(self.crypto_type)) + if type(public_key) is str: + public_key = bytes.fromhex(public_key.replace("0x", "")) + + self.hash_type = hash_type + self.public_key = public_key + self.address = self.key_address = self.ss58_address = key_address + self.private_key = private_key + self.derive_path = derive_path + self.path = path + self.ss58_format = ss58_format + self.key_address = self.ss58_address + self.key_type = self.crypto_type2name(self.crypto_type) + return {"key_address": key_address, "crypto_type": self.crypto_type} + + + @classmethod + def create_from_mnemonic( + cls, mnemonic: str = None, ss58_format=SS58_FORMAT, language_code: str = "en", crypto_type=KeyType.SR25519 + ) -> "DotSR25519": + """ + Create a Key for given memonic + + Parameters + ---------- + mnemonic: Seed phrase + ss58_format: Substrate address format + language_code: The language to use, valid values are: 'en', 'zh-hans', 'zh-hant', 'fr', 'it', 'ja', 'ko', 'es'. Defaults to `"en"` + + Returns + ------- + Key + """ + if not mnemonic: + mnemonic = cls.generate_mnemonic(language_code=language_code) + + keypair = cls.create_from_seed( + seed_hex=binascii.hexlify( + bytearray(bip39_to_mini_secret(mnemonic, "", language_code)) + ).decode("ascii"), + ss58_format=ss58_format, + ) + + keypair.mnemonic = mnemonic + + return keypair + + from_mnemonic = from_mem = create_from_mnemonic + + @classmethod + def create_from_seed( + cls, seed_hex: Union[bytes, str], ss58_format: Optional[int] = SS58_FORMAT + ) -> "DotSR25519": + """ + Create a Key for given seed + + Parameters + ---------- + seed_hex: hex string of seed + ss58_format: Substrate address format + + Returns + ------- + Key + """ + if type(seed_hex) is str: + seed_hex = bytes.fromhex(seed_hex.replace("0x", "")) + public_key, private_key = sr25519.pair_from_seed(seed_hex) + + ss58_address = ss58_encode(public_key, ss58_format) + kwargs = dict( + ss58_address=ss58_address, + public_key=public_key, + private_key=private_key, + ss58_format=ss58_format, + crypto_type=KeyType.SR25519, + ) + + return cls(**kwargs) + + @classmethod + def create_from_uri( + cls, + suri: str, + ss58_format: Optional[int] = SS58_FORMAT, + language_code: str = "en", + crypto_type: Union[str, int] = KeyType.SR25519, + ) -> "DotSR25519": + """ + Creates Key for specified suri in following format: `[mnemonic]/[soft-path]//[hard-path]` + + Parameters + ---------- + suri: + ss58_format: Substrate address format + language_code: The language to use, valid values are: 'en', 'zh-hans', 'zh-hant', 'fr', 'it', 'ja', 'ko', 'es'. Defaults to `"en"` + + Returns + ------- + Key + """ + crypto_type = cls.resolve_crypto_type(crypto_type) + suri = str(suri) + if not suri.startswith("//"): + suri = "//" + suri + + if suri and suri.startswith("/"): + suri = DEV_PHRASE + suri + + suri_regex = re.match( + r"^(?P.[^/]+( .[^/]+)*)(?P(//?[^/]+)*)(///(?P.*))?$", + suri, + ) + + suri_parts = suri_regex.groupdict() + + if suri_parts["password"]: + raise NotImplementedError( + f"Passwords in suri not supported for crypto_type '{KeyType.SR25519}'" + ) + + derived_keypair = cls.create_from_mnemonic( + suri_parts["phrase"], ss58_format=ss58_format, language_code=language_code + ) + + if suri_parts["path"] != "": + derived_keypair.derive_path = suri_parts["path"] + + derive_junctions = extract_derive_path(suri_parts["path"]) + + child_pubkey = derived_keypair.public_key + child_privkey = derived_keypair.private_key + + for junction in derive_junctions: + if junction.is_hard: + _, child_pubkey, child_privkey = sr25519.hard_derive_keypair( + (junction.chain_code, child_pubkey, child_privkey), b"" + ) + + else: + _, child_pubkey, child_privkey = sr25519.derive_keypair( + (junction.chain_code, child_pubkey, child_privkey), b"" + ) + + derived_keypair = DotSR25519( + public_key=child_pubkey, + private_key=child_privkey, + ss58_format=ss58_format, + ) + + return derived_keypair + + from_mnem = from_mnemonic = create_from_mnemonic + + def export_to_encrypted_json(self, passphrase: str, name: str = None) -> dict: + """ + Export Key to PolkadotJS format encrypted JSON file + + Parameters + ---------- + passphrase: Used to encrypt the keypair + name: Display name of Key used + + Returns + ------- + dict + """ + if not name: + name = self.ss58_address + + # Secret key from PolkadotJS is an Ed25519 expanded secret key, so has to be converted + # https://github.com/polkadot-js/wasm/blob/master/packages/wasm-crypto/src/rs/sr25519.rs#L125 + converted_private_key = sr25519.convert_secret_key_to_ed25519(self.private_key) + encoded = encode_pair(self.public_key, converted_private_key, passphrase) + encoding_content = ["pkcs8", "sr25519"] + + json_data = { + "encoded": b64encode(encoded).decode(), + "encoding": { + "content": encoding_content, + "type": ["scrypt", "xsalsa20-poly1305"], + "version": "3", + }, + "address": self.ss58_address, + "meta": {"name": name, "tags": [], "whenCreated": int(time.time())}, + } + return json_data + + seperator = "::signature=" + + def sign(self, data: Union[ScaleBytes, bytes, str], to_json=False) -> bytes: + """ + Creates a signature for given data + Parameters + ---------- + data: data to sign in `Scalebytes`, bytes or hex string format + Returns + ------- + signature in bytes + + """ + if not isinstance(data, str): + data = c.python2str(data) + if type(data) is ScaleBytes: + data = bytes(data.data) + elif data[0:2] == "0x": + data = bytes.fromhex(data[2:]) + elif type(data) is str: + data = data.encode() + if not self.private_key: + raise Exception("No private key set to create signatures") + signature = sr25519.sign((self.public_key, self.private_key), data) + + if to_json: + return { + "data": data.decode(), + "crypto_type": self.crypto_type, + "signature": signature.hex(), + "address": self.ss58_address, + } + return signature + + def verify( + self, + data: Union[ScaleBytes, bytes, str, dict], + signature: Union[bytes, str] = None, + public_key: Optional[str] = None, + return_address=False, + ss58_format=SS58_FORMAT, + max_age=None, + address=None, + **kwargs, + ) -> bool: + """ + Verifies data with specified signature + + Parameters + ---------- + data: data to be verified in `Scalebytes`, bytes or hex string format + signature: signature in bytes or hex string format + public_key: public key in bytes or hex string format + + Returns + ------- + True if data is signed with this Key, otherwise False + """ + data = c.copy(data) + + if isinstance(data, dict): + if self.is_ticket(data): + address = data.pop("address") + signature = data.pop("signature") + elif "data" in data and "signature" in data and "address" in data: + signature = data.pop("signature") + address = data.pop("address", address) + data = data.pop("data") + else: + assert signature != None, "signature not found in data" + assert address != None, "address not found in data" + + if max_age != None: + if isinstance(data, int): + staleness = c.timestamp() - int(data) + elif "timestamp" in data or "time" in data: + timestamp = data.get("timestamp", data.get("time")) + staleness = c.timestamp() - int(timestamp) + else: + raise ValueError( + "data should be a timestamp or a dict with a timestamp key" + ) + assert staleness < max_age, ( + f"data is too old, {staleness} seconds old, max_age is {max_age}" + ) + + if not isinstance(data, str): + data = c.python2str(data) + if address != None: + if self.valid_ss58_address(address): + public_key = ss58_decode(address) + if public_key == None: + public_key = self.public_key + if isinstance(public_key, str): + public_key = bytes.fromhex(public_key.replace("0x", "")) + if type(data) is ScaleBytes: + data = bytes(data.data) + elif data[0:2] == "0x": + data = bytes.fromhex(data[2:]) + elif type(data) is str: + data = data.encode() + if type(signature) is str and signature[0:2] == "0x": + signature = bytes.fromhex(signature[2:]) + elif type(signature) is str: + signature = bytes.fromhex(signature) + if type(signature) is not bytes: + raise TypeError("Signature should be of type bytes or a hex-string") + + crypto_verify_fn = sr25519.verify + + verified = crypto_verify_fn(signature, data, public_key) + if not verified: + # Another attempt with the data wrapped, as discussed in https://github.com/polkadot-js/extension/pull/743 + # Note: As Python apps are trusted sources on its own, no need to wrap data when signing from this lib + verified = crypto_verify_fn( + signature, b"" + data + b"", public_key + ) + if return_address: + return ss58_encode(public_key, ss58_format=ss58_format) + return verified From 7455a77d3f9c10714c28d6afe3f55ded249e4fb2 Mon Sep 17 00:00:00 2001 From: drunest <6984754+drunest@users.noreply.github.com> Date: Wed, 15 Jan 2025 12:43:56 -0600 Subject: [PATCH 11/27] ref: add ed25519 key for polkadot --- commune/key/dot_ed25519.py | 348 +++++++++++++++++++++++++++++++++++++ 1 file changed, 348 insertions(+) create mode 100644 commune/key/dot_ed25519.py diff --git a/commune/key/dot_ed25519.py b/commune/key/dot_ed25519.py new file mode 100644 index 000000000..d0a9e72b2 --- /dev/null +++ b/commune/key/dot_ed25519.py @@ -0,0 +1,348 @@ +import ed25519_zebra +import sr25519 +import commune as c + +from .types.index import * +from .utils import * + +import binascii +import re +import nacl.public +from scalecodec.base import ScaleBytes +from bip39 import bip39_to_mini_secret +from scalecodec.utils.ss58 import ( + ss58_decode, + ss58_encode +) +from typing import Union, Optional +from .key import Key, KeyType +class DotED25519(Key): + def __init__( + self, + private_key: Union[bytes, str] = None, + ss58_format: int = SS58_FORMAT, + derive_path: str = None, + path: str = None, + crypto_type: Union[str, int] = KeyType.ED25519, + **kwargs + ): + self.crypto_type = KeyType.ED25519 + self.set_private_key(private_key=private_key, ss58_format = ss58_format, derive_path=derive_path, path=path, **kwargs) + + def set_private_key( + self, + private_key: Union[bytes, str] = None, + ss58_format: int = 42, + derive_path: str = None, + path: str = None, + **kwargs, + ): + """ + Allows generation of Keys from a variety of input combination, such as a public/private key combination, + mnemonic or URI containing soft and hard derivation paths. With these Keys data can be signed and verified + + Parameters + ---------- + ss58_address: Substrate address + public_key: hex string or bytes of public_key key + private_key: hex string or bytes of private key + ss58_format: Substrate address format, default to 42 when omitted + seed_hex: hex string of seed + """ + # If no arguments are provided, generate a random keypair + + + if private_key == None: + private_key = self.new_key(crypto_type=self.crypto_type).private_key + if type(private_key) == str: + private_key = c.str2bytes(private_key) + if self.crypto_type == KeyType.ED25519: + private_key = private_key[:32] if len(private_key) == 64 else private_key + private_key, public_key = ed25519_zebra.ed_from_seed(private_key) + key_address = ss58_encode(public_key, ss58_format=ss58_format) + hash_type = 'ss58' + else: + raise ValueError('crypto_type "{}" not supported'.format(self.crypto_type)) + if type(public_key) is str: + public_key = bytes.fromhex(public_key.replace("0x", "")) + + self.hash_type = hash_type + self.public_key = public_key + self.address = self.key_address = self.ss58_address = key_address + self.private_key = private_key + self.derive_path = derive_path + self.path = path + self.ss58_format = ss58_format + self.key_address = self.ss58_address + self.key_type = self.crypto_type2name(self.crypto_type) + return {"key_address": key_address, "crypto_type": self.crypto_type} + + + @classmethod + def create_from_mnemonic( + cls, mnemonic: str = None, ss58_format=SS58_FORMAT, language_code: str = "en", crypto_type=KeyType.ED25519 + ) -> "DotED25519": + """ + Create a Key for given memonic + + Parameters + ---------- + mnemonic: Seed phrase + ss58_format: Substrate address format + language_code: The language to use, valid values are: 'en', 'zh-hans', 'zh-hant', 'fr', 'it', 'ja', 'ko', 'es'. Defaults to `"en"` + + Returns + ------- + Key + """ + if not mnemonic: + mnemonic = cls.generate_mnemonic(language_code=language_code) + + keypair = cls.create_from_seed( + seed_hex=binascii.hexlify( + bytearray(bip39_to_mini_secret(mnemonic, "", language_code)) + ).decode("ascii"), + ss58_format=ss58_format, + ) + + keypair.mnemonic = mnemonic + + return keypair + + @classmethod + def create_from_seed( + cls, seed_hex: Union[bytes, str], ss58_format: Optional[int] = SS58_FORMAT + ) -> "DotED25519": + """ + Create a Key for given seed + + Parameters + ---------- + seed_hex: hex string of seed + ss58_format: Substrate address format + + Returns + ------- + Key + """ + if type(seed_hex) is str: + seed_hex = bytes.fromhex(seed_hex.replace("0x", "")) + private_key, public_key = ed25519_zebra.ed_from_seed(seed_hex) + ss58_address = ss58_encode(public_key, ss58_format) + kwargs = dict( + ss58_address=ss58_address, + public_key=public_key, + private_key=private_key, + ss58_format=ss58_format, + crypto_type=KeyType.ED25519, + ) + + return cls(**kwargs) + + @classmethod + def create_from_uri( + cls, + suri: str, + ss58_format: Optional[int] = SS58_FORMAT, + language_code: str = "en", + crypto_type: Union[str, int] = KeyType.ED25519, + ) -> "DotED25519": + """ + Creates Key for specified suri in following format: `[mnemonic]/[soft-path]//[hard-path]` + + Parameters + ---------- + suri: + ss58_format: Substrate address format + language_code: The language to use, valid values are: 'en', 'zh-hans', 'zh-hant', 'fr', 'it', 'ja', 'ko', 'es'. Defaults to `"en"` + + Returns + ------- + DotED25519 + """ + crypto_type = cls.resolve_crypto_type(crypto_type) + suri = str(suri) + if not suri.startswith("//"): + suri = "//" + suri + + if suri and suri.startswith("/"): + suri = DEV_PHRASE + suri + + suri_regex = re.match( + r"^(?P.[^/]+( .[^/]+)*)(?P(//?[^/]+)*)(///(?P.*))?$", + suri, + ) + + suri_parts = suri_regex.groupdict() + + if suri_parts["password"]: + raise NotImplementedError( + f"Passwords in suri not supported for crypto_type '{KeyType.ED25519}'" + ) + + derived_keypair = cls.create_from_mnemonic( + suri_parts["phrase"], ss58_format=ss58_format, language_code=language_code + ) + + if suri_parts["path"] != "": + derived_keypair.derive_path = suri_parts["path"] + + derive_junctions = extract_derive_path(suri_parts["path"]) + + child_pubkey = derived_keypair.public_key + child_privkey = derived_keypair.private_key + + for junction in derive_junctions: + if junction.is_hard: + _, child_pubkey, child_privkey = sr25519.hard_derive_keypair( + (junction.chain_code, child_pubkey, child_privkey), b"" + ) + + else: + _, child_pubkey, child_privkey = sr25519.derive_keypair( + (junction.chain_code, child_pubkey, child_privkey), b"" + ) + + derived_keypair = DotED25519( + public_key=child_pubkey, + private_key=child_privkey, + ss58_format=ss58_format, + ) + + return derived_keypair + + def export_to_encrypted_json(self, passphrase: str, name: str = None) -> dict: + """ + Export Key to PolkadotJS format encrypted JSON file + + Parameters + ---------- + passphrase: Used to encrypt the keypair + name: Display name of Key used + + Returns + ------- + dict + """ + raise NotImplementedError(f"Cannot create JSON for crypto_type '{self.crypto_type}'") + + seperator = "::signature=" + + def sign(self, data: Union[ScaleBytes, bytes, str], to_json=False) -> bytes: + """ + Creates a signature for given data + Parameters + ---------- + data: data to sign in `Scalebytes`, bytes or hex string format + Returns + ------- + signature in bytes + + """ + if not isinstance(data, str): + data = c.python2str(data) + if type(data) is ScaleBytes: + data = bytes(data.data) + elif data[0:2] == "0x": + data = bytes.fromhex(data[2:]) + elif type(data) is str: + data = data.encode() + if not self.private_key: + raise Exception("No private key set to create signatures") + signature = ed25519_zebra.ed_sign(self.private_key, data) + + if to_json: + return { + "data": data.decode(), + "crypto_type": self.crypto_type, + "signature": signature.hex(), + "address": self.ss58_address, + } + return signature + + def verify( + self, + data: Union[ScaleBytes, bytes, str, dict], + signature: Union[bytes, str] = None, + public_key: Optional[str] = None, + return_address=False, + ss58_format=SS58_FORMAT, + max_age=None, + address=None, + **kwargs, + ) -> bool: + """ + Verifies data with specified signature + + Parameters + ---------- + data: data to be verified in `Scalebytes`, bytes or hex string format + signature: signature in bytes or hex string format + public_key: public key in bytes or hex string format + + Returns + ------- + True if data is signed with this Key, otherwise False + """ + data = c.copy(data) + + if isinstance(data, dict): + if self.is_ticket(data): + address = data.pop("address") + signature = data.pop("signature") + elif "data" in data and "signature" in data and "address" in data: + signature = data.pop("signature") + address = data.pop("address", address) + data = data.pop("data") + else: + assert signature != None, "signature not found in data" + assert address != None, "address not found in data" + + if max_age != None: + if isinstance(data, int): + staleness = c.timestamp() - int(data) + elif "timestamp" in data or "time" in data: + timestamp = data.get("timestamp", data.get("time")) + staleness = c.timestamp() - int(timestamp) + else: + raise ValueError( + "data should be a timestamp or a dict with a timestamp key" + ) + assert staleness < max_age, ( + f"data is too old, {staleness} seconds old, max_age is {max_age}" + ) + + if not isinstance(data, str): + data = c.python2str(data) + if address != None: + if self.valid_ss58_address(address): + public_key = ss58_decode(address) + if public_key == None: + public_key = self.public_key + if isinstance(public_key, str): + public_key = bytes.fromhex(public_key.replace("0x", "")) + if type(data) is ScaleBytes: + data = bytes(data.data) + elif data[0:2] == "0x": + data = bytes.fromhex(data[2:]) + elif type(data) is str: + data = data.encode() + if type(signature) is str and signature[0:2] == "0x": + signature = bytes.fromhex(signature[2:]) + elif type(signature) is str: + signature = bytes.fromhex(signature) + if type(signature) is not bytes: + raise TypeError("Signature should be of type bytes or a hex-string") + + crypto_verify_fn = ed25519_zebra.ed_verify + + verified = crypto_verify_fn(signature, data, public_key) + if not verified: + # Another attempt with the data wrapped, as discussed in https://github.com/polkadot-js/extension/pull/743 + # Note: As Python apps are trusted sources on its own, no need to wrap data when signing from this lib + verified = crypto_verify_fn( + signature, b"" + data + b"", public_key + ) + if return_address: + return ss58_encode(public_key, ss58_format=ss58_format) + return verified From 48f78b633066c20a4ff2649d9a1431ca127d5e2b Mon Sep 17 00:00:00 2001 From: drunest <6984754+drunest@users.noreply.github.com> Date: Wed, 15 Jan 2025 12:44:25 -0600 Subject: [PATCH 12/27] ref: add ecdsa key for ethereum --- commune/key/eth.py | 331 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 331 insertions(+) create mode 100644 commune/key/eth.py diff --git a/commune/key/eth.py b/commune/key/eth.py new file mode 100644 index 000000000..9bff9d4a0 --- /dev/null +++ b/commune/key/eth.py @@ -0,0 +1,331 @@ +import sr25519 +import commune as c + +from .types.index import * +from .utils import * + +import json + +import time +import os +import binascii +import re +import secrets +import base64 +from base64 import b64encode +import hashlib +from Crypto import Random +from Crypto.Cipher import AES +import nacl.bindings +import nacl.public +from scalecodec.base import ScaleBytes +from bip39 import bip39_to_mini_secret +from scalecodec.utils.ss58 import ( + ss58_decode, + ss58_encode, + is_valid_ss58_address, + get_ss58_format, +) +from typing import Union, Optional +from .key import Key, KeyType +class ECDSA(Key): + def __init__( + self, + private_key: Union[bytes, str] = None, + ss58_format: int = SS58_FORMAT, + derive_path: str = None, + path: str = None, + crypto_type: Union[str, int] = KeyType.ECDSA, + **kwargs + ): + self.crypto_type = KeyType.ECDSA + self.set_private_key(private_key=private_key, ss58_format = ss58_format, derive_path=derive_path, path=path, **kwargs) + + def set_private_key( + self, + private_key: Union[bytes, str] = None, + ss58_format: int = 42, + derive_path: str = None, + path: str = None, + **kwargs, + ): + """ + Allows generation of Keys from a variety of input combination, such as a public/private key combination, + mnemonic or URI containing soft and hard derivation paths. With these Keys data can be signed and verified + + Parameters + ---------- + ss58_address: Substrate address + public_key: hex string or bytes of public_key key + private_key: hex string or bytes of private key + ss58_format: Substrate address format, default to 42 when omitted + seed_hex: hex string of seed + """ + # If no arguments are provided, generate a random keypair + if private_key == None: + private_key = self.new_key(crypto_type=self.crypto_type).private_key + if type(private_key) == str: + private_key = c.str2bytes(private_key) + if self.crypto_type == KeyType.ECDSA: + private_key = private_key[0:32] + private_key_obj = PrivateKey(private_key) + public_key = private_key_obj.public_key.to_address() + key_address = private_key_obj.public_key.to_checksum_address() + hash_type = 'h160' + else: + raise ValueError('crypto_type "{}" not supported'.format(self.crypto_type)) + if type(public_key) is str: + public_key = bytes.fromhex(public_key.replace("0x", "")) + + self.hash_type = hash_type + self.public_key = public_key + self.address = self.key_address = self.ss58_address = key_address + self.private_key = private_key + self.derive_path = derive_path + self.path = path + self.ss58_format = ss58_format + self.key_address = self.ss58_address + self.key_type = self.crypto_type2name(self.crypto_type) + return {"key_address": key_address, "crypto_type": self.crypto_type} + + + @classmethod + def create_from_mnemonic( + cls, mnemonic: str = None, ss58_format=SS58_FORMAT, language_code: str = "en", crypto_type=KeyType.ECDSA + ) -> "ECDSA": + """ + Create a Key for given memonic + + Parameters + ---------- + mnemonic: Seed phrase + ss58_format: Substrate address format + language_code: The language to use, valid values are: 'en', 'zh-hans', 'zh-hant', 'fr', 'it', 'ja', 'ko', 'es'. Defaults to `"en"` + + Returns + ------- + Key + """ + if not mnemonic: + mnemonic = cls.generate_mnemonic(language_code=language_code) + + if language_code != "en": + raise ValueError("ECDSA mnemonic only supports english") + private_key = mnemonic_to_ecdsa_private_key(mnemonic) + keypair = cls.create_from_private_key(private_key, ss58_format=ss58_format, crypto_type=crypto_type) + + keypair.mnemonic = mnemonic + + return keypair + + from_mnemonic = from_mem = create_from_mnemonic + + @classmethod + def create_from_seed( + cls, seed_hex: Union[bytes, str], ss58_format: Optional[int] = SS58_FORMAT + ) -> "ECDSA": + """ + Create a Key for given seed + + Parameters + ---------- + seed_hex: hex string of seed + ss58_format: Substrate address format + + Returns + ------- + Key + """ + raise ValueError('crypto_type "{}" not supported'.format(KeyType.ECDSA)) + + @classmethod + def create_from_uri( + cls, + suri: str, + ss58_format: Optional[int] = SS58_FORMAT, + language_code: str = "en", + crypto_type: Union[str, int] = KeyType.ECDSA, + ) -> "ECDSA": + """ + Creates Key for specified suri in following format: `[mnemonic]/[soft-path]//[hard-path]` + + Parameters + ---------- + suri: + ss58_format: Substrate address format + language_code: The language to use, valid values are: 'en', 'zh-hans', 'zh-hant', 'fr', 'it', 'ja', 'ko', 'es'. Defaults to `"en"` + + Returns + ------- + Key + """ + crypto_type = cls.resolve_crypto_type(crypto_type) + suri = str(suri) + if not suri.startswith("//"): + suri = "//" + suri + + if suri and suri.startswith("/"): + suri = DEV_PHRASE + suri + + suri_regex = re.match( + r"^(?P.[^/]+( .[^/]+)*)(?P(//?[^/]+)*)(///(?P.*))?$", + suri, + ) + + suri_parts = suri_regex.groupdict() + + if language_code != "en": + raise ValueError("ECDSA mnemonic only supports english") + print(suri_parts) + private_key = mnemonic_to_ecdsa_private_key( + mnemonic=suri_parts['phrase'], + str_derivation_path=suri_parts['path'], + passphrase=suri_parts['password'] + ) + derived_keypair = cls.create_from_private_key(private_key, ss58_format=ss58_format, crypto_type=crypto_type) + + return derived_keypair + + from_mnem = from_mnemonic = create_from_mnemonic + + def export_to_encrypted_json(self, passphrase: str, name: str = None) -> dict: + """ + Export Key to PolkadotJS format encrypted JSON file + + Parameters + ---------- + passphrase: Used to encrypt the keypair + name: Display name of Key used + + Returns + ------- + dict + """ + if not name: + name = self.ss58_address + + # Secret key from PolkadotJS is an Ed25519 expanded secret key, so has to be converted + # https://github.com/polkadot-js/wasm/blob/master/packages/wasm-crypto/src/rs/sr25519.rs#L125 + raise NotImplementedError("export_to_encrypted_json not implemented") + + seperator = "::signature=" + + def sign(self, data: Union[ScaleBytes, bytes, str], to_json=False) -> bytes: + """ + Creates a signature for given data + Parameters + ---------- + data: data to sign in `Scalebytes`, bytes or hex string format + Returns + ------- + signature in bytes + + """ + if not isinstance(data, str): + data = c.python2str(data) + if type(data) is ScaleBytes: + data = bytes(data.data) + elif data[0:2] == "0x": + data = bytes.fromhex(data[2:]) + elif type(data) is str: + data = data.encode() + if not self.private_key: + raise Exception("No private key set to create signatures") + + signature = ecdsa_sign(self.private_key, data) + + if to_json: + return { + "data": data.decode(), + "crypto_type": self.crypto_type, + "signature": signature.hex(), + "address": self.ss58_address, + } + return signature + + def verify( + self, + data: Union[ScaleBytes, bytes, str, dict], + signature: Union[bytes, str] = None, + public_key: Optional[str] = None, + return_address=False, + ss58_format=SS58_FORMAT, + max_age=None, + address=None, + **kwargs, + ) -> bool: + """ + Verifies data with specified signature + + Parameters + ---------- + data: data to be verified in `Scalebytes`, bytes or hex string format + signature: signature in bytes or hex string format + public_key: public key in bytes or hex string format + + Returns + ------- + True if data is signed with this Key, otherwise False + """ + data = c.copy(data) + + if isinstance(data, dict): + if self.is_ticket(data): + address = data.pop("address") + signature = data.pop("signature") + elif "data" in data and "signature" in data and "address" in data: + signature = data.pop("signature") + address = data.pop("address", address) + data = data.pop("data") + else: + assert signature != None, "signature not found in data" + assert address != None, "address not found in data" + + if max_age != None: + if isinstance(data, int): + staleness = c.timestamp() - int(data) + elif "timestamp" in data or "time" in data: + timestamp = data.get("timestamp", data.get("time")) + staleness = c.timestamp() - int(timestamp) + else: + raise ValueError( + "data should be a timestamp or a dict with a timestamp key" + ) + assert staleness < max_age, ( + f"data is too old, {staleness} seconds old, max_age is {max_age}" + ) + + if not isinstance(data, str): + data = c.python2str(data) + if address != None: + if self.valid_ss58_address(address): + public_key = ss58_decode(address) + if public_key == None: + public_key = self.public_key + if isinstance(public_key, str): + public_key = bytes.fromhex(public_key.replace("0x", "")) + if type(data) is ScaleBytes: + data = bytes(data.data) + elif data[0:2] == "0x": + data = bytes.fromhex(data[2:]) + elif type(data) is str: + data = data.encode() + if type(signature) is str and signature[0:2] == "0x": + signature = bytes.fromhex(signature[2:]) + elif type(signature) is str: + signature = bytes.fromhex(signature) + if type(signature) is not bytes: + raise TypeError("Signature should be of type bytes or a hex-string") + + crypto_verify_fn = ecdsa_verify + + verified = crypto_verify_fn(signature, data, public_key) + if not verified: + # Another attempt with the data wrapped, as discussed in https://github.com/polkadot-js/extension/pull/743 + # Note: As Python apps are trusted sources on its own, no need to wrap data when signing from this lib + verified = crypto_verify_fn( + signature, b"" + data + b"", public_key + ) + if return_address: + return ss58_encode(public_key, ss58_format=ss58_format) + return verified From 36112f4f8a160f0dd4d4b2ea08c94545b148f686 Mon Sep 17 00:00:00 2001 From: drunest <6984754+drunest@users.noreply.github.com> Date: Wed, 15 Jan 2025 12:44:52 -0600 Subject: [PATCH 13/27] ref: add solana key --- commune/key/sol.py | 306 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 306 insertions(+) create mode 100644 commune/key/sol.py diff --git a/commune/key/sol.py b/commune/key/sol.py new file mode 100644 index 000000000..6c3a949bd --- /dev/null +++ b/commune/key/sol.py @@ -0,0 +1,306 @@ +import sr25519 +import commune as c + +from .types.index import * +from .utils import * + +import re +import nacl.public +from scalecodec.base import ScaleBytes +from scalecodec.utils.ss58 import ( + ss58_decode, +) +from base58 import b58encode +from typing import Union, Optional +from .key import Key, KeyType +class Solana(Key): + def __init__( + self, + private_key: Union[bytes, str] = None, + ss58_format: int = SS58_FORMAT, + derive_path: str = None, + path: str = None, + crypto_type: Union[str, int] = KeyType.SOLANA, + **kwargs + ): + self.crypto_type = KeyType.SOLANA + self.set_private_key(private_key=private_key, ss58_format = ss58_format, derive_path=derive_path, path=path, **kwargs) + + def set_private_key( + self, + private_key: Union[bytes, str] = None, + ss58_format: int = 44, + derive_path: str = None, + path: str = None, + **kwargs, + ): + """ + Allows generation of Keys from a variety of input combination, such as a public/private key combination, + mnemonic or URI containing soft and hard derivation paths. With these Keys data can be signed and verified + + Parameters + ---------- + ss58_address: Substrate address + public_key: hex string or bytes of public_key key + private_key: hex string or bytes of private key + ss58_format: Substrate address format, default to 42 when omitted + seed_hex: hex string of seed + """ + # If no arguments are provided, generate a random keypair + if private_key == None: + private_key = self.new_key(crypto_type=self.crypto_type).private_key + if type(private_key) == str: + private_key = c.str2bytes(private_key) + if self.crypto_type == KeyType.SOLANA: + private_key = private_key[0:32] + keypair = SolanaKeypair.from_seed(private_key) + public_key = keypair.pubkey().__bytes__() + private_key = keypair.secret() + key_address = b58encode(bytes(public_key)).decode('utf-8') + hash_type = 'base58' + else: + raise ValueError('crypto_type "{}" not supported'.format(self.crypto_type)) + if type(public_key) is str: + public_key = bytes.fromhex(public_key.replace("0x", "")) + + self.hash_type = hash_type + self.public_key = public_key + self.address = self.key_address = self.ss58_address = key_address + self.private_key = private_key + self.derive_path = derive_path + self.path = path + self.key_address = self.ss58_address + self.key_type = self.crypto_type2name(self.crypto_type) + return {"key_address": key_address, "crypto_type": self.crypto_type} + + + @classmethod + def create_from_mnemonic( + cls, mnemonic: str = None, ss58_format=SS58_FORMAT, language_code: str = "en", crypto_type=KeyType.SOLANA + ) -> "Solana": + """ + Create a Key for given memonic + + Parameters + ---------- + mnemonic: Seed phrase + ss58_format: Substrate address format + language_code: The language to use, valid values are: 'en', 'zh-hans', 'zh-hant', 'fr', 'it', 'ja', 'ko', 'es'. Defaults to `"en"` + + Returns + ------- + Key + """ + if not mnemonic: + mnemonic = cls.generate_mnemonic(language_code=language_code) + + if language_code != "en": + raise ValueError("Solana mnemonic only supports english") + + private_key = SolanaKeypair.from_seed_phrase_and_passphrase(mnemonic, "").secret() + keypair = cls.create_from_private_key(private_key, ss58_format=ss58_format, crypto_type=crypto_type) + + keypair.mnemonic = mnemonic + + return keypair + + from_mnemonic = from_mem = create_from_mnemonic + + @classmethod + def create_from_seed( + cls, seed_hex: Union[bytes, str], ss58_format: Optional[int] = SS58_FORMAT + ) -> "Solana": + """ + Create a Key for given seed + + Parameters + ---------- + seed_hex: hex string of seed + ss58_format: Substrate address format + + Returns + ------- + Key + """ + raise ValueError('crypto_type "{}" not supported'.format(KeyType.SOLANA)) + + @classmethod + def create_from_uri( + cls, + suri: str, + ss58_format: Optional[int] = SS58_FORMAT, + language_code: str = "en", + crypto_type: Union[str, int] = KeyType.SOLANA, + ) -> "Solana": + """ + Creates Key for specified suri in following format: `[mnemonic]/[soft-path]//[hard-path]` + + Parameters + ---------- + suri: + ss58_format: Substrate address format + language_code: The language to use, valid values are: 'en', 'zh-hans', 'zh-hant', 'fr', 'it', 'ja', 'ko', 'es'. Defaults to `"en"` + + Returns + ------- + Key + """ + crypto_type = cls.resolve_crypto_type(crypto_type) + suri = str(suri) + if not suri.startswith("//"): + suri = "//" + suri + + if suri and suri.startswith("/"): + suri = DEV_PHRASE + suri + + suri_regex = re.match( + r"^(?P.[^/]+( .[^/]+)*)(?P(//?[^/]+)*)(///(?P.*))?$", + suri, + ) + + suri_parts = suri_regex.groupdict() + + if language_code != "en": + raise ValueError("Solana mnemonic only supports english") + private_key = SolanaKeypair.from_seed_phrase_and_passphrase(suri_parts['phrase'], passphrase=suri_parts['password']).secret() + derived_keypair = cls.create_from_private_key(private_key, ss58_format=ss58_format, crypto_type=crypto_type) + return derived_keypair + + from_mnem = from_mnemonic = create_from_mnemonic + + def export_to_encrypted_json(self, passphrase: str, name: str = None) -> dict: + """ + Export Key to PolkadotJS format encrypted JSON file + + Parameters + ---------- + passphrase: Used to encrypt the keypair + name: Display name of Key used + + Returns + ------- + dict + """ + + # Secret key from PolkadotJS is an Ed25519 expanded secret key, so has to be converted + # https://github.com/polkadot-js/wasm/blob/master/packages/wasm-crypto/src/rs/sr25519.rs#L125 + raise NotImplementedError("export_to_encrypted_json not implemented") + + seperator = "::signature=" + + def sign(self, data: Union[ScaleBytes, bytes, str], to_json=False) -> bytes: + """ + Creates a signature for given data + Parameters + ---------- + data: data to sign in `Scalebytes`, bytes or hex string format + Returns + ------- + signature in bytes + + """ + if not isinstance(data, str): + data = c.python2str(data) + if type(data) is ScaleBytes: + data = bytes(data.data) + elif data[0:2] == "0x": + data = bytes.fromhex(data[2:]) + elif type(data) is str: + data = data.encode() + if not self.private_key: + raise Exception("No private key set to create signatures") + + signature = solana_sign(self.private_key, data) + + if to_json: + return { + "data": data.decode(), + "crypto_type": self.crypto_type, + "signature": signature.hex(), + "address": self.ss58_address, + } + return signature + + def verify( + self, + data: Union[ScaleBytes, bytes, str, dict], + signature: Union[bytes, str] = None, + public_key: Optional[str] = None, + return_address=False, + ss58_format=SS58_FORMAT, + max_age=None, + address=None, + **kwargs, + ) -> bool: + """ + Verifies data with specified signature + + Parameters + ---------- + data: data to be verified in `Scalebytes`, bytes or hex string format + signature: signature in bytes or hex string format + public_key: public key in bytes or hex string format + + Returns + ------- + True if data is signed with this Key, otherwise False + """ + data = c.copy(data) + + if isinstance(data, dict): + if self.is_ticket(data): + address = data.pop("address") + signature = data.pop("signature") + elif "data" in data and "signature" in data and "address" in data: + signature = data.pop("signature") + address = data.pop("address", address) + data = data.pop("data") + else: + assert signature != None, "signature not found in data" + assert address != None, "address not found in data" + + if max_age != None: + if isinstance(data, int): + staleness = c.timestamp() - int(data) + elif "timestamp" in data or "time" in data: + timestamp = data.get("timestamp", data.get("time")) + staleness = c.timestamp() - int(timestamp) + else: + raise ValueError( + "data should be a timestamp or a dict with a timestamp key" + ) + assert staleness < max_age, ( + f"data is too old, {staleness} seconds old, max_age is {max_age}" + ) + + if not isinstance(data, str): + data = c.python2str(data) + if public_key == None: + public_key = self.public_key + if isinstance(public_key, str): + public_key = bytes.fromhex(public_key.replace("0x", "")) + if type(data) is ScaleBytes: + data = bytes(data.data) + elif data[0:2] == "0x": + data = bytes.fromhex(data[2:]) + elif type(data) is str: + data = data.encode() + if type(signature) is str and signature[0:2] == "0x": + signature = bytes.fromhex(signature[2:]) + elif type(signature) is str: + signature = bytes.fromhex(signature) + if type(signature) is not bytes: + raise TypeError("Signature should be of type bytes or a hex-string") + + crypto_verify_fn = solana_verify + + verified = crypto_verify_fn(signature, data, public_key) + if not verified: + # Another attempt with the data wrapped, as discussed in https://github.com/polkadot-js/extension/pull/743 + # Note: As Python apps are trusted sources on its own, no need to wrap data when signing from this lib + verified = crypto_verify_fn( + signature, b"" + data + b"", public_key + ) + if return_address: + return b58encode(public_key).decode('utf-8') + return verified From e691cc53501dbe71252992914d687876b4a0e7f6 Mon Sep 17 00:00:00 2001 From: drunest <6984754+drunest@users.noreply.github.com> Date: Wed, 15 Jan 2025 12:45:42 -0600 Subject: [PATCH 14/27] ref: add utils --- commune/key/utils.py | 264 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 264 insertions(+) create mode 100644 commune/key/utils.py diff --git a/commune/key/utils.py b/commune/key/utils.py new file mode 100644 index 000000000..6fa8c07c6 --- /dev/null +++ b/commune/key/utils.py @@ -0,0 +1,264 @@ +import re +import json +import base64 +import hashlib +import hmac +import struct +from nacl.secret import SecretBox +from nacl.pwhash import scrypt + +from typing import Union +from scalecodec.types import Bytes +from hashlib import blake2b +from math import ceil +from os import urandom +from sr25519 import pair_from_ed25519_secret_key +from eth_keys.datatypes import PrivateKey, Signature +from eth_utils import to_checksum_address, keccak as eth_utils_keccak +from solders.keypair import Keypair as SolanaKeypair +from solders.pubkey import Pubkey as SolanaPubkey +from solders.signature import Signature as SolanaSignature + +from .types.index import * + +class PublicKey: + def __init__(self, private_key): + self.point = int.from_bytes(private_key, byteorder='big') * BIP32_CURVE.generator + + def __bytes__(self): + xstr = int(self.point.x()).to_bytes(32, byteorder='big') + parity = int(self.point.y()) & 1 + return (2 + parity).to_bytes(1, byteorder='big') + xstr + + def address(self): + x = int(self.point.x()) + y = int(self.point.y()) + s = x.to_bytes(32, 'big') + y.to_bytes(32, 'big') + return to_checksum_address(eth_utils_keccak(s)[12:]) + +def mnemonic_to_bip39seed(mnemonic, passphrase): + mnemonic = bytes(mnemonic, 'utf8') + salt = bytes(BIP39_SALT_MODIFIER + passphrase, 'utf8') + return hashlib.pbkdf2_hmac('sha512', mnemonic, salt, BIP39_PBKDF2_ROUNDS) + +def bip39seed_to_bip32masternode(seed): + h = hmac.new(BIP32_SEED_MODIFIER, seed, hashlib.sha512).digest() + key, chain_code = h[:32], h[32:] + return key, chain_code + +def derive_bip32childkey(parent_key, parent_chain_code, i): + assert len(parent_key) == 32 + assert len(parent_chain_code) == 32 + k = parent_chain_code + if (i & BIP32_PRIVDEV) != 0: + key = b'\x00' + parent_key + else: + key = bytes(PublicKey(parent_key)) + d = key + struct.pack('>L', i) + while True: + h = hmac.new(k, d, hashlib.sha512).digest() + key, chain_code = h[:32], h[32:] + a = int.from_bytes(key, byteorder='big') + b = int.from_bytes(parent_key, byteorder='big') + key = (a + b) % int(BIP32_CURVE.order) + if a < BIP32_CURVE.order and key != 0: + key = key.to_bytes(32, byteorder='big') + break + d = b'\x01' + h[32:] + struct.pack('>L', i) + return key, chain_code + +def parse_derivation_path(str_derivation_path): + path = [] + if str_derivation_path[0:2] != 'm/': + raise ValueError("Can't recognize derivation path. It should look like \"m/44'/60/0'/0\".") + for i in str_derivation_path.lstrip('m/').split('/'): + if "'" in i: + path.append(BIP32_PRIVDEV + int(i[:-1])) + else: + path.append(int(i)) + return path + + +def mnemonic_to_ecdsa_private_key(mnemonic: str, str_derivation_path: str = None, passphrase: str = "") -> bytes: + + if str_derivation_path is None: + str_derivation_path = f'{ETH_DERIVATION_PATH}/0' + + derivation_path = parse_derivation_path(str_derivation_path) + bip39seed = mnemonic_to_bip39seed(mnemonic, passphrase) + master_private_key, master_chain_code = bip39seed_to_bip32masternode(bip39seed) + private_key, chain_code = master_private_key, master_chain_code + for i in derivation_path: + private_key, chain_code = derive_bip32childkey(private_key, chain_code, i) + return private_key + + +def ecdsa_sign(private_key: bytes, message: bytes) -> bytes: + signer = PrivateKey(private_key) + return signer.sign_msg(message).to_bytes() + + +def ecdsa_verify(signature: bytes, data: bytes, address: bytes) -> bool: + signature_obj = Signature(signature) + recovered_pubkey = signature_obj.recover_public_key_from_msg(data) + return recovered_pubkey.to_canonical_address() == address + +def solana_sign(private_key: bytes, message: bytes) -> bytes: + keypair = SolanaKeypair.from_seed(private_key) + return bytes(keypair.sign_message(message)) + +def solana_verify(signature: bytes, message: bytes, public_key: bytes) -> bool: + signature = SolanaSignature.from_bytes(signature) + pubkey = SolanaPubkey(public_key) + return signature.verify(pubkey, message) + +class DeriveJunction: + def __init__(self, chain_code, is_hard=False): + self.chain_code = chain_code + self.is_hard = is_hard + + @classmethod + def from_derive_path(cls, path: str, is_hard=False): + + if path.isnumeric(): + byte_length = ceil(int(path).bit_length() / 8) + chain_code = int(path).to_bytes(byte_length, 'little').ljust(32, b'\x00') + + else: + path_scale = Bytes() + path_scale.encode(path) + + if len(path_scale.data) > JUNCTION_ID_LEN: + chain_code = blake2b(path_scale.data.data, digest_size=32).digest() + else: + chain_code = bytes(path_scale.data.data.ljust(32, b'\x00')) + + return cls(chain_code=chain_code, is_hard=is_hard) + +def extract_derive_path(derive_path: str): + + path_check = '' + junctions = [] + paths = re.findall(RE_JUNCTION, derive_path) + + if paths: + path_check = ''.join(''.join(path) for path in paths) + + for path_separator, path_value in paths: + junctions.append(DeriveJunction.from_derive_path( + path=path_value, is_hard=path_separator == '//') + ) + + if path_check != derive_path: + raise ValueError('Reconstructed path "{}" does not match input'.format(path_check)) + + return junctions + + +def decode_pair_from_encrypted_json(json_data: Union[str, dict], passphrase: str) -> tuple: + """ + Decodes encrypted PKCS#8 message from PolkadotJS JSON format + + Parameters + ---------- + json_data + passphrase + + Returns + ------- + tuple containing private and public key + """ + if type(json_data) is str: + json_data = json.loads(json_data) + + # Check requirements + if json_data.get('encoding', {}).get('version') != "3": + raise ValueError("Unsupported JSON format") + + encrypted = base64.b64decode(json_data['encoded']) + + if 'scrypt' in json_data['encoding']['type']: + salt = encrypted[0:32] + n = int.from_bytes(encrypted[32:36], byteorder='little') + p = int.from_bytes(encrypted[36:40], byteorder='little') + r = int.from_bytes(encrypted[40:44], byteorder='little') + + password = scrypt(passphrase.encode(), salt, n=n, r=r, p=p, dklen=32, maxmem=2 ** 26) + encrypted = encrypted[SCRYPT_LENGTH:] + + else: + password = passphrase.encode().rjust(32, b'\x00') + + if "xsalsa20-poly1305" not in json_data['encoding']['type']: + raise ValueError("Unsupported encoding type") + + nonce = encrypted[0:NONCE_LENGTH] + message = encrypted[NONCE_LENGTH:] + + secret_box = SecretBox(key=password) + decrypted = secret_box.decrypt(message, nonce) + + # Decode PKCS8 message + secret_key, public_key = decode_pkcs8(decrypted) + + if 'sr25519' in json_data['encoding']['content']: + # Secret key from PolkadotJS is an Ed25519 expanded secret key, so has to be converted + # https://github.com/polkadot-js/wasm/blob/master/packages/wasm-crypto/src/rs/sr25519.rs#L125 + converted_public_key, secret_key = pair_from_ed25519_secret_key(secret_key) + assert(public_key == converted_public_key) + + return secret_key, public_key + + +def decode_pkcs8(ciphertext: bytes) -> tuple: + current_offset = 0 + + header = ciphertext[current_offset:len(PKCS8_HEADER)] + if header != PKCS8_HEADER: + raise ValueError("Invalid Pkcs8 header found in body") + + current_offset += len(PKCS8_HEADER) + + secret_key = ciphertext[current_offset:current_offset + SEC_LENGTH] + current_offset += SEC_LENGTH + + divider = ciphertext[current_offset:current_offset + len(PKCS8_DIVIDER)] + + if divider != PKCS8_DIVIDER: + raise ValueError("Invalid Pkcs8 divider found in body") + + current_offset += len(PKCS8_DIVIDER) + + public_key = ciphertext[current_offset: current_offset + PUB_LENGTH] + + return secret_key, public_key + +def encode_pkcs8(public_key: bytes, private_key: bytes) -> bytes: + return PKCS8_HEADER + private_key + PKCS8_DIVIDER + public_key + +def encode_pair(public_key: bytes, private_key: bytes, passphrase: str) -> bytes: + """ + Encode a public/private pair to PKCS#8 format, encrypted with provided passphrase + + Parameters + ---------- + public_key: 32 bytes public key + private_key: 64 bytes private key + passphrase: passphrase to encrypt the PKCS#8 message + + Returns + ------- + (Encrypted) PKCS#8 message bytes + """ + message = encode_pkcs8(public_key, private_key) + + salt = urandom(SALT_LENGTH) + password = scrypt(passphrase.encode(), salt, n=SCRYPT_N, r=SCRYPT_R, p=SCRYPT_P, dklen=32, maxmem=2 ** 26) + + secret_box = SecretBox(key=password) + message = secret_box.encrypt(message) + + scrypt_params = SCRYPT_N.to_bytes(4, 'little') + SCRYPT_P.to_bytes(4, 'little') + SCRYPT_R.to_bytes(4, 'little') + + return salt + scrypt_params + message.nonce + message.ciphertext + From 88186f2f252cdd06529edc899216a9deb55f4e77 Mon Sep 17 00:00:00 2001 From: drunest <6984754+drunest@users.noreply.github.com> Date: Wed, 15 Jan 2025 12:46:31 -0600 Subject: [PATCH 15/27] chore: define init file --- commune/key/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 commune/key/__init__.py diff --git a/commune/key/__init__.py b/commune/key/__init__.py new file mode 100644 index 000000000..e69de29bb From a207f9b3327c2e2cde780c5c09b314c851e75e97 Mon Sep 17 00:00:00 2001 From: drunest <6984754+drunest@users.noreply.github.com> Date: Wed, 15 Jan 2025 12:47:27 -0600 Subject: [PATCH 16/27] chore: change key import --- commune/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/commune/__init__.py b/commune/__init__.py index 9b3a1a0bf..bd85c776f 100755 --- a/commune/__init__.py +++ b/commune/__init__.py @@ -6,7 +6,7 @@ from .vali import Vali # the vali module from .server import Server # the server module from .client import Client # the client module -from .key import Key # the key module +from .key.key import Key # the key module # set the module functions as globalsw c.add_to_globals(globals()) key = c.get_key # override key function with file key in commune/key.py TODO: remove this line with a better solution From 8411412dd6eef6c1ecf807501556474581ab7302 Mon Sep 17 00:00:00 2001 From: drunest <6984754+drunest@users.noreply.github.com> Date: Wed, 15 Jan 2025 12:47:57 -0600 Subject: [PATCH 17/27] chore: change key import --- commune/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/commune/module.py b/commune/module.py index 486f7c118..54c97f108 100755 --- a/commune/module.py +++ b/commune/module.py @@ -313,7 +313,7 @@ def is_module_folder(cls, module = None) -> bool: @classmethod def get_key(cls,key:str = None , **kwargs) -> None: - from commune.key import Key + from commune.key.key import Key if not isinstance(key, str) and hasattr(key,"module_name" ): key = key.module_name() return Key.get_key(key, **kwargs) From 2abd4caefe3e499234eef2bdcb3fa707ec949a98 Mon Sep 17 00:00:00 2001 From: drunest <6984754+drunest@users.noreply.github.com> Date: Wed, 15 Jan 2025 12:48:33 -0600 Subject: [PATCH 18/27] ref: test all key types --- tests/test_key.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_key.py b/tests/test_key.py index ae5fda6b6..badef093e 100644 --- a/tests/test_key.py +++ b/tests/test_key.py @@ -75,7 +75,7 @@ def test_key_encryption(password='1234'): return {'success': True, 'msg': 'test_key_encryption passed'} def test_move_key(): - self = c.module('key')() + self = c.module('key')(crypto_type=crypto_type) self.add_key('testfrom', crypto_type=crypto_type) assert self.key_exists('testfrom') og_key = self.get_key('testfrom', crypto_type=crypto_type) From 93af4311c50d597b85500ee3d43acfb308eccc74 Mon Sep 17 00:00:00 2001 From: drunest <6984754+drunest@users.noreply.github.com> Date: Wed, 15 Jan 2025 12:49:12 -0600 Subject: [PATCH 19/27] ref: remove origin key file --- commune/key.py | 1468 ------------------------------------------------ 1 file changed, 1468 deletions(-) delete mode 100644 commune/key.py diff --git a/commune/key.py b/commune/key.py deleted file mode 100644 index b04ddfca0..000000000 --- a/commune/key.py +++ /dev/null @@ -1,1468 +0,0 @@ - -import json -from typing import Union, Optional -import time -import os -import binascii -import re -import secrets -import base64 -from base64 import b64encode -import hashlib -from Crypto import Random -from Crypto.Cipher import AES -import nacl.bindings -import nacl.public -from eth_keys.datatypes import PrivateKey -from scalecodec.utils.ss58 import ss58_encode, ss58_decode, get_ss58_format -from scalecodec.base import ScaleBytes -from bip39 import bip39_to_mini_secret, bip39_generate, bip39_validate -import sr25519 -import ed25519_zebra -import commune as c -import re -from hashlib import blake2b -from math import ceil -from scalecodec.utils.ss58 import ss58_decode, ss58_encode, is_valid_ss58_address, get_ss58_format -import base64 -import json -from os import urandom -from typing import Union -from nacl.hashlib import scrypt -from nacl.secret import SecretBox -from sr25519 import pair_from_ed25519_secret_key -from scalecodec.types import Bytes -import hashlib -import hmac -import struct -from ecdsa.curves import SECP256k1 -from eth_keys.datatypes import Signature, PrivateKey -from eth_utils import to_checksum_address, keccak as eth_utils_keccak -from solders.keypair import Keypair as SolanaKeypair -from solders.signature import Signature as SolanaSignature -from solders.pubkey import Pubkey as SolanaPubkey -from base58 import b58encode - -BIP39_PBKDF2_ROUNDS = 2048 -BIP39_SALT_MODIFIER = "mnemonic" -BIP32_PRIVDEV = 0x80000000 -BIP32_CURVE = SECP256k1 -BIP32_SEED_MODIFIER = b'Bitcoin seed' -ETH_DERIVATION_PATH = "m/44'/60'/0'/0" - -class PublicKey: - def __init__(self, private_key): - self.point = int.from_bytes(private_key, byteorder='big') * BIP32_CURVE.generator - - def __bytes__(self): - xstr = int(self.point.x()).to_bytes(32, byteorder='big') - parity = int(self.point.y()) & 1 - return (2 + parity).to_bytes(1, byteorder='big') + xstr - - def address(self): - x = int(self.point.x()) - y = int(self.point.y()) - s = x.to_bytes(32, 'big') + y.to_bytes(32, 'big') - return to_checksum_address(eth_utils_keccak(s)[12:]) - -def mnemonic_to_bip39seed(mnemonic, passphrase): - mnemonic = bytes(mnemonic, 'utf8') - salt = bytes(BIP39_SALT_MODIFIER + passphrase, 'utf8') - return hashlib.pbkdf2_hmac('sha512', mnemonic, salt, BIP39_PBKDF2_ROUNDS) - -def bip39seed_to_bip32masternode(seed): - h = hmac.new(BIP32_SEED_MODIFIER, seed, hashlib.sha512).digest() - key, chain_code = h[:32], h[32:] - return key, chain_code - -def derive_bip32childkey(parent_key, parent_chain_code, i): - assert len(parent_key) == 32 - assert len(parent_chain_code) == 32 - k = parent_chain_code - if (i & BIP32_PRIVDEV) != 0: - key = b'\x00' + parent_key - else: - key = bytes(PublicKey(parent_key)) - d = key + struct.pack('>L', i) - while True: - h = hmac.new(k, d, hashlib.sha512).digest() - key, chain_code = h[:32], h[32:] - a = int.from_bytes(key, byteorder='big') - b = int.from_bytes(parent_key, byteorder='big') - key = (a + b) % int(BIP32_CURVE.order) - if a < BIP32_CURVE.order and key != 0: - key = key.to_bytes(32, byteorder='big') - break - d = b'\x01' + h[32:] + struct.pack('>L', i) - return key, chain_code - -def parse_derivation_path(str_derivation_path): - path = [] - if str_derivation_path[0:2] != 'm/': - raise ValueError("Can't recognize derivation path. It should look like \"m/44'/60/0'/0\".") - for i in str_derivation_path.lstrip('m/').split('/'): - if "'" in i: - path.append(BIP32_PRIVDEV + int(i[:-1])) - else: - path.append(int(i)) - return path - - -def mnemonic_to_ecdsa_private_key(mnemonic: str, str_derivation_path: str = None, passphrase: str = "") -> bytes: - - if str_derivation_path is None: - str_derivation_path = f'{ETH_DERIVATION_PATH}/0' - - derivation_path = parse_derivation_path(str_derivation_path) - bip39seed = mnemonic_to_bip39seed(mnemonic, passphrase) - master_private_key, master_chain_code = bip39seed_to_bip32masternode(bip39seed) - private_key, chain_code = master_private_key, master_chain_code - for i in derivation_path: - private_key, chain_code = derive_bip32childkey(private_key, chain_code, i) - return private_key - - -def ecdsa_sign(private_key: bytes, message: bytes) -> bytes: - signer = PrivateKey(private_key) - return signer.sign_msg(message).to_bytes() - - -def ecdsa_verify(signature: bytes, data: bytes, address: bytes) -> bool: - signature_obj = Signature(signature) - recovered_pubkey = signature_obj.recover_public_key_from_msg(data) - return recovered_pubkey.to_canonical_address() == address - -def solana_sign(private_key: bytes, message: bytes) -> bytes: - keypair = SolanaKeypair.from_seed(private_key) - return bytes(keypair.sign_message(message)) - -def solana_verify(signature: bytes, message: bytes, public_key: bytes) -> bool: - signature = SolanaSignature.from_bytes(signature) - pubkey = SolanaPubkey(public_key) - return signature.verify(pubkey, message) - -NONCE_LENGTH = 24 -SCRYPT_LENGTH = 32 + (3 * 4) -PKCS8_DIVIDER = bytes([161, 35, 3, 33, 0]) -PKCS8_HEADER = bytes([48, 83, 2, 1, 1, 48, 5, 6, 3, 43, 101, 112, 4, 34, 4, 32]) -PUB_LENGTH = 32 -SALT_LENGTH = 32 -SEC_LENGTH = 64 -SEED_LENGTH = 32 - -SCRYPT_N = 1 << 15 -SCRYPT_P = 1 -SCRYPT_R = 8 - - -def decode_pair_from_encrypted_json(json_data: Union[str, dict], passphrase: str) -> tuple: - """ - Decodes encrypted PKCS#8 message from PolkadotJS JSON format - - Parameters - ---------- - json_data - passphrase - - Returns - ------- - tuple containing private and public key - """ - if type(json_data) is str: - json_data = json.loads(json_data) - - # Check requirements - if json_data.get('encoding', {}).get('version') != "3": - raise ValueError("Unsupported JSON format") - - encrypted = base64.b64decode(json_data['encoded']) - - if 'scrypt' in json_data['encoding']['type']: - salt = encrypted[0:32] - n = int.from_bytes(encrypted[32:36], byteorder='little') - p = int.from_bytes(encrypted[36:40], byteorder='little') - r = int.from_bytes(encrypted[40:44], byteorder='little') - - password = scrypt(passphrase.encode(), salt, n=n, r=r, p=p, dklen=32, maxmem=2 ** 26) - encrypted = encrypted[SCRYPT_LENGTH:] - - else: - password = passphrase.encode().rjust(32, b'\x00') - - if "xsalsa20-poly1305" not in json_data['encoding']['type']: - raise ValueError("Unsupported encoding type") - - nonce = encrypted[0:NONCE_LENGTH] - message = encrypted[NONCE_LENGTH:] - - secret_box = SecretBox(key=password) - decrypted = secret_box.decrypt(message, nonce) - - # Decode PKCS8 message - secret_key, public_key = decode_pkcs8(decrypted) - - if 'sr25519' in json_data['encoding']['content']: - # Secret key from PolkadotJS is an Ed25519 expanded secret key, so has to be converted - # https://github.com/polkadot-js/wasm/blob/master/packages/wasm-crypto/src/rs/sr25519.rs#L125 - converted_public_key, secret_key = pair_from_ed25519_secret_key(secret_key) - assert(public_key == converted_public_key) - - return secret_key, public_key - - -def decode_pkcs8(ciphertext: bytes) -> tuple: - current_offset = 0 - - header = ciphertext[current_offset:len(PKCS8_HEADER)] - if header != PKCS8_HEADER: - raise ValueError("Invalid Pkcs8 header found in body") - - current_offset += len(PKCS8_HEADER) - - secret_key = ciphertext[current_offset:current_offset + SEC_LENGTH] - current_offset += SEC_LENGTH - - divider = ciphertext[current_offset:current_offset + len(PKCS8_DIVIDER)] - - if divider != PKCS8_DIVIDER: - raise ValueError("Invalid Pkcs8 divider found in body") - - current_offset += len(PKCS8_DIVIDER) - - public_key = ciphertext[current_offset: current_offset + PUB_LENGTH] - - return secret_key, public_key - -def encode_pkcs8(public_key: bytes, private_key: bytes) -> bytes: - return PKCS8_HEADER + private_key + PKCS8_DIVIDER + public_key - -def encode_pair(public_key: bytes, private_key: bytes, passphrase: str) -> bytes: - """ - Encode a public/private pair to PKCS#8 format, encrypted with provided passphrase - - Parameters - ---------- - public_key: 32 bytes public key - private_key: 64 bytes private key - passphrase: passphrase to encrypt the PKCS#8 message - - Returns - ------- - (Encrypted) PKCS#8 message bytes - """ - message = encode_pkcs8(public_key, private_key) - - salt = urandom(SALT_LENGTH) - password = scrypt(passphrase.encode(), salt, n=SCRYPT_N, r=SCRYPT_R, p=SCRYPT_P, dklen=32, maxmem=2 ** 26) - - secret_box = SecretBox(key=password) - message = secret_box.encrypt(message) - - scrypt_params = SCRYPT_N.to_bytes(4, 'little') + SCRYPT_P.to_bytes(4, 'little') + SCRYPT_R.to_bytes(4, 'little') - - return salt + scrypt_params + message.nonce + message.ciphertext - - - - -JUNCTION_ID_LEN = 32 -RE_JUNCTION = r'(\/\/?)([^/]+)' - - -class DeriveJunction: - def __init__(self, chain_code, is_hard=False): - self.chain_code = chain_code - self.is_hard = is_hard - - @classmethod - def from_derive_path(cls, path: str, is_hard=False): - - if path.isnumeric(): - byte_length = ceil(int(path).bit_length() / 8) - chain_code = int(path).to_bytes(byte_length, 'little').ljust(32, b'\x00') - - else: - path_scale = Bytes() - path_scale.encode(path) - - if len(path_scale.data) > JUNCTION_ID_LEN: - chain_code = blake2b(path_scale.data.data, digest_size=32).digest() - else: - chain_code = bytes(path_scale.data.data.ljust(32, b'\x00')) - - return cls(chain_code=chain_code, is_hard=is_hard) - - -def extract_derive_path(derive_path: str): - - path_check = '' - junctions = [] - paths = re.findall(RE_JUNCTION, derive_path) - - if paths: - path_check = ''.join(''.join(path) for path in paths) - - for path_separator, path_value in paths: - junctions.append(DeriveJunction.from_derive_path( - path=path_value, is_hard=path_separator == '//') - ) - - if path_check != derive_path: - raise ValueError('Reconstructed path "{}" does not match input'.format(path_check)) - - return junctions - -DEV_PHRASE = 'bottom drive obey lake curtain smoke basket hold race lonely fit walk' - -class KeyType: - ED25519 = 0 - SR25519 = 1 - ECDSA = 2 - SOLANA = 3 -KeyType.crypto_types = [k for k in KeyType.__dict__.keys() if not k.startswith('_')] -KeyType.crypto_type_map = {k.lower():v for k,v in KeyType.__dict__.items() if k in KeyType.crypto_types } -KeyType.crypto_types = list(KeyType.crypto_type_map.keys()) - -class Key(c.Module): - crypto_types = KeyType.crypto_types - crypto_type_map = KeyType.crypto_type_map - crypto_types = list(crypto_type_map.keys()) - ss58_format = 42 - crypto_type = 'sr25519' - def __init__(self, - private_key: Union[bytes, str] = None, - ss58_format: int = ss58_format, - crypto_type: int = crypto_type, - derive_path: str = None, - path:str = None, - **kwargs): - self.set_private_key(private_key=private_key, - ss58_format=ss58_format, - crypto_type=crypto_type, - derive_path=derive_path, - path=path, **kwargs) - - - @property - def short_address(self): - n = 4 - return self.ss58_address[:n] + '...' + self.ss58_address[-n:] - - def set_crypto_type(self, crypto_type): - crypto_type = self.resolve_crypto_type(crypto_type) - if crypto_type != self.crypto_type: - kwargs = { - 'private_key': self.private_key, - 'ss58_format': self.ss58_format, - 'derive_path': self.derive_path, - 'path': self.path, - 'crypto_type': crypto_type # update crypto_type - } - return self.set_private_key(**kwargs) - else: - return {'success': False, 'message': f'crypto_type already set to {crypto_type}'} - - def set_private_key(self, - private_key: Union[bytes, str] = None, - ss58_format: int = ss58_format, - crypto_type: int = crypto_type, - derive_path: str = None, - path:str = None, - **kwargs - ): - """ - Allows generation of Keys from a variety of input combination, such as a public/private key combination, - mnemonic or URI containing soft and hard derivation paths. With these Keys data can be signed and verified - - Parameters - ---------- - ss58_address: Substrate address - public_key: hex string or bytes of public_key key - private_key: hex string or bytes of private key - ss58_format: Substrate address format, default to 42 when omitted - seed_hex: hex string of seed - crypto_type: Use KeyType.SR25519 or KeyType.ED25519 cryptography for generating the Key - """ - crypto_type = self.resolve_crypto_type(crypto_type) - # If no arguments are provided, generate a random keypair - if private_key == None: - private_key = self.new_key(crypto_type=crypto_type).private_key - if type(private_key) == str: - private_key = c.str2bytes(private_key) - crypto_type = self.resolve_crypto_type(crypto_type) - if crypto_type == KeyType.SR25519: - if len(private_key) != 64: - private_key = sr25519.pair_from_seed(private_key)[1] - public_key = sr25519.public_from_secret_key(private_key) - key_address = ss58_encode(public_key, ss58_format=ss58_format) - hash_type = 'ss58' - elif crypto_type == KeyType.ED25519: - private_key = private_key[:32] if len(private_key) == 64 else private_key - public_key, private_key = ed25519_zebra.ed_from_seed(private_key) - key_address = ss58_encode(public_key, ss58_format=ss58_format) - hash_type = 'ss58' - elif crypto_type == KeyType.ECDSA: - private_key = private_key[0:32] - private_key_obj = PrivateKey(private_key) - public_key = private_key_obj.public_key.to_address() - key_address = private_key_obj.public_key.to_checksum_address() - hash_type = 'h160' - elif crypto_type == KeyType.SOLANA: - private_key = private_key[0:32] - keypair = SolanaKeypair.from_seed(private_key) - public_key = keypair.pubkey().__bytes__() - private_key = keypair.secret() - key_address = b58encode(bytes(public_key)).decode('utf-8') - hash_type = 'base58' - else: - raise ValueError('crypto_type "{}" not supported'.format(crypto_type)) - if type(public_key) is str: - public_key = bytes.fromhex(public_key.replace('0x', '')) - - self.hash_type = hash_type - self.public_key = public_key - self.address = self.key_address = self.ss58_address = key_address - self.private_key = private_key - self.crypto_type = crypto_type - self.derive_path = derive_path - self.path = path - self.ss58_format = ss58_format - self.key_address = self.ss58_address - self.key_type = self.crypto_type2name(self.crypto_type) - return {'key_address':key_address, 'crypto_type':crypto_type} - - @classmethod - def add_key(cls, path:str, mnemonic:str = None, password:str=None, refresh:bool=False, private_key=None, **kwargs): - if cls.key_exists(path) and not refresh : - c.print(f'key already exists at {path}') - return cls.get(path) - key = cls.new_key(mnemonic=mnemonic, private_key=private_key, **kwargs) - key.path = path - key_json = key.to_json() - if password != None: - key_json = cls.encrypt(data=key_json, password=password) - c.print(cls.put(path, key_json)) - cls.update() - return json.loads(key_json) - - @classmethod - def ticket(cls , data=None, key=None, **kwargs): - return cls.get_key(key).sign({'data':data, 'time': c.time()} , to_json=True, **kwargs) - - @classmethod - def mv_key(cls, path, new_path, crypto_type='sr25519'): - assert cls.key_exists(path), f'key does not exist at {path}' - cls.put(new_path, cls.get_key(path, crypto_type=crypto_type).to_json()) - cls.rm_key(path) - assert cls.key_exists(new_path), f'key does not exist at {new_path}' - assert not cls.key_exists(path), f'key still exists at {path}' - new_key = cls.get_key(new_path, crypto_type=crypto_type) - return {'success': True, 'from': path , 'to': new_path, 'key': new_key} - - @classmethod - def copy_key(cls, path, new_path): - assert cls.key_exists(path), f'key does not exist at {path}' - cls.put(new_path, cls.get_key(path).to_json()) - assert cls.key_exists(new_path), f'key does not exist at {new_path}' - assert cls.get_key(path) == cls.get_key(new_path), f'key does not match' - new_key = cls.get_key(new_path) - return {'success': True, 'from': path , 'to': new_path, 'key': new_key} - - - @classmethod - def add_keys(cls, name, n=100, verbose:bool = False, **kwargs): - response = [] - for i in range(n): - key_name = f'{name}.{i}' - if bool == True: - c.print(f'generating key {key_name}') - response.append(cls.add_key(key_name, **kwargs)) - - return response - - def key2encrypted(self): - keys = self.keys() - key2encrypted = {} - for k in keys: - key2encrypted[k] = self.is_key_encrypted(k) - return key2encrypted - - def encrypted_keys(self): - return [k for k,v in self.key2encrypted().items() if v == True] - - @classmethod - def key_info(cls, path='module', **kwargs): - return cls.get_key_json(path) - - @classmethod - def load_key(cls, path=None): - key_info = cls.get(path) - key_info = c.jload(key_info) - if key_info['path'] == None: - key_info['path'] = path.replace('.json', '').split('/')[-1] - - cls.add_key(**key_info) - return {'status': 'success', 'message': f'key loaded from {path}'} - - - @classmethod - def save_keys(cls, path='saved_keys.json', **kwargs): - path = cls.resolve_path(path) - c.print(f'saving mems to {path}') - key2mnemonic = cls.key2mnemonic() - c.put_json(path, key2mnemonic) - return {'success': True, 'msg': 'saved keys', 'path':path, 'n': len(key2mnemonic)} - - @classmethod - def load_keys(cls, path='saved_keys.json', refresh=False, **kwargs): - key2mnemonic = c.get_json(path) - for k,mnemonic in key2mnemonic.items(): - try: - cls.add_key(k, mnemonic=mnemonic, refresh=refresh, **kwargs) - except Exception as e: - # c.print(f'failed to load mem {k} due to {e}', color='red') - pass - return {'loaded_mems':list(key2mnemonic.keys()), 'path':path} - loadkeys = loadmems = load_keys - - @classmethod - def key2mnemonic(cls, search=None) -> dict[str, str]: - """ - keyname (str) --> mnemonic (str) - - """ - mems = {} - for key in cls.keys(search): - try: - mems[key] = cls.get_mnemonic(key) - except Exception as e: - c.print(f'failed to get mem for {key} due to {e}') - if search: - mems = {k:v for k,v in mems.items() if search in k or search in v} - return mems - - @classmethod - def get_key(cls, - path:str,password:str=None, - create_if_not_exists:bool = True, - crypto_type=crypto_type, - **kwargs): - for k in ['crypto_type', 'key_type', 'type']: - if k in kwargs: - crypto_type = kwargs.pop(k) - break - if hasattr(path, 'key_address'): - key = path - return key - path = path or 'module' - # if ss58_address is provided, get key from address - if cls.valid_ss58_address(path): - path = cls.address2key().get(path) - if not cls.key_exists(path): - if create_if_not_exists: - key = cls.add_key(path, **kwargs) - c.print(f'key does not exist, generating new key -> {key["path"]}') - else: - print(path) - raise ValueError(f'key does not exist at --> {path}') - key_json = cls.get(path) - # if key is encrypted, decrypt it - if cls.is_encrypted(key_json): - key_json = c.decrypt(data=key_json, password=password) - if key_json == None: - c.print({'status': 'error', 'message': f'key is encrypted, please {path} provide password'}) - return None - key_json = c.jload(key_json) if isinstance(key_json, str) else key_json - key = cls.from_json(key_json, crypto_type=crypto_type) - key.path = path - return key - - @classmethod - def get_keys(cls, search=None, clean_failed_keys=False): - keys = {} - for key in cls.keys(): - if str(search) in key or search == None: - try: - keys[key] = cls.get_key(key) - except Exception as e: - continue - if keys[key] == None: - if clean_failed_keys: - cls.rm_key(key) - keys.pop(key) - return keys - - @classmethod - def key2address(cls, search=None, max_age=10, update=False, **kwargs): - path = 'key2address' - key2address = cls.get(path, None, max_age=max_age, update=update) - if key2address == None: - key2address = { k: v.ss58_address for k,v in cls.get_keys(search).items()} - cls.put(path, key2address) - return key2address - - @classmethod - def n(cls, search=None, **kwargs): - return len(cls.key2address(search, **kwargs)) - - @classmethod - def address2key(cls, search:Optional[str]=None, update:bool=False): - address2key = { v: k for k,v in cls.key2address(update=update).items()} - if search != None : - return address2key.get(search, None) - return address2key - - @classmethod - def get_address(cls, key): - return cls.get_key(key).ss58_address - get_addy = get_address - @classmethod - def key_paths(cls): - return cls.ls() - address_seperator = '_address=' - @classmethod - def key2path(cls) -> dict: - """ - defines the path for each key - """ - path2key_fn = lambda path: '.'.join(path.split('/')[-1].split('.')[:-1]) - key2path = {path2key_fn(path):path for path in cls.key_paths()} - return key2path - - @classmethod - def keys(cls, search : str = None, **kwargs): - keys = list(cls.key2path().keys()) - if search != None: - keys = [key for key in keys if search in key] - return keys - - @classmethod - def n(cls, *args, **kwargs): - return len(cls.key2address(*args, **kwargs)) - - @classmethod - def key_exists(cls, key, **kwargs): - path = cls.get_key_path(key) - import os - return os.path.exists(path) - - @classmethod - def get_key_path(cls, key): - storage_dir = cls.storage_dir() - key_path = storage_dir + '/' + key + '.json' - return key_path - @classmethod - def get_key_json(cls, key): - storage_dir = cls.storage_dir() - key_path = storage_dir + '/' + key + '.json' - return c.get(key_path) - @classmethod - def get_key_address(cls, key): - return cls.get_key(key).ss58_address - - @classmethod - def rm_key(cls, key=None): - key2path = cls.key2path() - keys = list(key2path.keys()) - if key not in keys: - raise Exception(f'key {key} not found, available keys: {keys}') - c.rm(key2path[key]) - return {'deleted':[key]} - - - - @classmethod - def crypto_name2type(cls, name:str): - crypto_type_map = cls.crypto_type_map - name = name.lower() - if not name in crypto_type_map: - raise ValueError(f'crypto_type {name} not supported {crypto_type_map}') - return crypto_type_map[name] - - @classmethod - def crypto_type2name(cls, crypto_type:str): - crypto_type_map ={v:k for k,v in cls.crypto_type_map.items()} - return crypto_type_map[crypto_type] - - @classmethod - def resolve_crypto_type_name(cls, crypto_type): - return cls.crypto_type2name(cls.resolve_crypto_type(crypto_type)) - - @classmethod - def resolve_crypto_type(cls, crypto_type): - if isinstance(crypto_type, int) or (isinstance(crypto_type, str) and c.is_int(crypto_type)): - crypto_type = int(crypto_type) - crypto_type_map = cls.crypto_type_map - reverse_crypto_type_map = {v:k for k,v in crypto_type_map.items()} - assert crypto_type in reverse_crypto_type_map, f'crypto_type {crypto_type} not supported {crypto_type_map}' - crypto_type = reverse_crypto_type_map[crypto_type] - if isinstance(crypto_type, str): - crypto_type = crypto_type.lower() - crypto_type = cls.crypto_name2type(crypto_type) - return int(crypto_type) - - @classmethod - def new_private_key(cls, crypto_type='ecdsa'): - return cls.new_key(crypto_type=crypto_type).private_key.hex() - - @classmethod - def new_key(cls, - mnemonic:str = None, - suri:str = None, - private_key: str = None, - crypto_type: Union[int,str] = 'sr25519', - verbose:bool=False, - **kwargs): - ''' - yo rody, this is a class method you can gen keys whenever fam - ''' - if verbose: - c.print(f'generating {crypto_type} keypair, {suri}') - - crypto_type = cls.resolve_crypto_type(crypto_type) - if suri: - key = cls.create_from_uri(suri, crypto_type=crypto_type) - elif mnemonic: - key = cls.create_from_mnemonic(mnemonic, crypto_type=crypto_type) - elif private_key: - key = cls.create_from_private_key(private_key,crypto_type=crypto_type) - else: - mnemonic = cls.generate_mnemonic() - key = cls.create_from_mnemonic(mnemonic, crypto_type=crypto_type) - - return key - - create = gen = new_key - - def to_json(self, password: str = None ) -> dict: - state_dict = c.copy(self.__dict__) - for k,v in state_dict.items(): - if type(v) in [bytes]: - state_dict[k] = v.hex() - if password != None: - state_dict[k] = self.encrypt(data=state_dict[k], password=password) - if '_ss58_address' in state_dict: - state_dict['ss58_address'] = state_dict.pop('_ss58_address') - - state_dict = json.dumps(state_dict) - - return state_dict - - @classmethod - def from_json(cls, obj: Union[str, dict], password: str = None, crypto_type=None) -> dict: - if type(obj) == str: - obj = json.loads(obj) - if obj == None: - return None - for k,v in obj.items(): - if cls.is_encrypted(obj[k]) and password != None: - obj[k] = cls.decrypt(data=obj[k], password=password) - if 'ss58_address' in obj: - obj['_ss58_address'] = obj.pop('ss58_address') - if crypto_type != None: - obj['crypto_type'] = crypto_type - return cls(**obj) - - - @classmethod - def generate_mnemonic(cls, words: int = 12, language_code: str = "en") -> str: - """ - params: - words: The amount of words to generate, valid values are 12, 15, 18, 21 and 24 - language_code: The language to use, valid values are: 'en', 'zh-hans', - 'zh-hant', 'fr', 'it', 'ja', 'ko', 'es'. - Defaults to `"en"` - """ - mnemonic = bip39_generate(words, language_code) - assert cls.validate_mnemonic(mnemonic, language_code), 'mnemonic is invalid' - return mnemonic - - @classmethod - def validate_mnemonic(cls, mnemonic: str, language_code: str = "en") -> bool: - """ - Verify if specified mnemonic is valid - - Parameters - ---------- - mnemonic: Seed phrase - language_code: The language to use, valid values are: 'en', 'zh-hans', 'zh-hant', 'fr', 'it', 'ja', 'ko', 'es'. Defaults to `"en"` - - Returns - ------- - bool - """ - return bip39_validate(mnemonic, language_code) - - - @classmethod - def create_from_mnemonic(cls, mnemonic: str = None, ss58_format=ss58_format, crypto_type=KeyType.SR25519, language_code: str = "en") -> 'Key': - """ - Create a Key for given memonic - - Parameters - ---------- - mnemonic: Seed phrase - ss58_format: Substrate address format - crypto_type: Use `KeyType.SR25519` or `KeyType.ED25519` cryptography for generating the Key - language_code: The language to use, valid values are: 'en', 'zh-hans', 'zh-hant', 'fr', 'it', 'ja', 'ko', 'es'. Defaults to `"en"` - - Returns - ------- - Key - """ - if not mnemonic: - mnemonic = cls.generate_mnemonic(language_code=language_code) - - if crypto_type == KeyType.ECDSA: - if language_code != "en": - raise ValueError("ECDSA mnemonic only supports english") - private_key = mnemonic_to_ecdsa_private_key(mnemonic) - keypair = cls.create_from_private_key(private_key, ss58_format=ss58_format, crypto_type=crypto_type) - elif crypto_type == KeyType.SOLANA: - private_key = SolanaKeypair.from_seed_phrase_and_passphrase(mnemonic, "").secret() - keypair = cls.create_from_private_key(private_key, ss58_format=ss58_format, crypto_type=crypto_type) - else: - keypair = cls.create_from_seed( - seed_hex=binascii.hexlify(bytearray(bip39_to_mini_secret(mnemonic, "", language_code))).decode("ascii"), - ss58_format=ss58_format, - crypto_type=crypto_type, - ) - - keypair.mnemonic = mnemonic - - return keypair - - from_mnemonic = from_mem = create_from_mnemonic - - @classmethod - def create_from_seed(cls, seed_hex: Union[bytes, str], ss58_format: Optional[int] = ss58_format, crypto_type=KeyType.SR25519) -> 'Key': - """ - Create a Key for given seed - - Parameters - ---------- - seed_hex: hex string of seed - ss58_format: Substrate address format - crypto_type: Use KeyType.SR25519 or KeyType.ED25519 cryptography for generating the Key - - Returns - ------- - Key - """ - crypto_type = cls.resolve_crypto_type(crypto_type) - if type(seed_hex) is str: - seed_hex = bytes.fromhex(seed_hex.replace('0x', '')) - if crypto_type == KeyType.SR25519: - public_key, private_key = sr25519.pair_from_seed(seed_hex) - elif crypto_type == KeyType.ED25519: - private_key, public_key = ed25519_zebra.ed_from_seed(seed_hex) - else: - raise ValueError('crypto_type "{}" not supported'.format(crypto_type)) - - ss58_address = ss58_encode(public_key, ss58_format) - - kwargs = dict( - ss58_address=ss58_address, - public_key=public_key, - private_key=private_key, - ss58_format=ss58_format, - crypto_type=crypto_type, - ) - - return cls(**kwargs) - @classmethod - def create_from_password(cls, password:str, crypto_type=2, **kwargs): - key= cls.create_from_uri(password, crypto_type=1, **kwargs) - key.set_crypto_type(crypto_type) - return key - - str2key = pwd2key = password2key = from_password = create_from_password - - @classmethod - def create_from_uri( - cls, - suri: str, - ss58_format: Optional[int] = ss58_format, - crypto_type=KeyType.SR25519, - language_code: str = "en" - ) -> 'Key': - """ - Creates Key for specified suri in following format: `[mnemonic]/[soft-path]//[hard-path]` - - Parameters - ---------- - suri: - ss58_format: Substrate address format - crypto_type: Use KeyType.SR25519 or KeyType.ED25519 cryptography for generating the Key - language_code: The language to use, valid values are: 'en', 'zh-hans', 'zh-hant', 'fr', 'it', 'ja', 'ko', 'es'. Defaults to `"en"` - - Returns - ------- - Key - """ - crypto_type = cls.resolve_crypto_type(crypto_type) - suri = str(suri) - if not suri.startswith('//'): - suri = '//' + suri - - if suri and suri.startswith('/'): - suri = DEV_PHRASE + suri - - suri_regex = re.match(r'^(?P.[^/]+( .[^/]+)*)(?P(//?[^/]+)*)(///(?P.*))?$', suri) - - suri_parts = suri_regex.groupdict() - - if crypto_type == KeyType.ECDSA: - if language_code != "en": - raise ValueError("ECDSA mnemonic only supports english") - print(suri_parts) - private_key = mnemonic_to_ecdsa_private_key( - mnemonic=suri_parts['phrase'], - str_derivation_path=suri_parts['path'], - passphrase=suri_parts['password'] - ) - derived_keypair = cls.create_from_private_key(private_key, ss58_format=ss58_format, crypto_type=crypto_type) - elif crypto_type == KeyType.SOLANA: - if language_code != "en": - raise ValueError("Solana mnemonic only supports english") - private_key = SolanaKeypair.from_seed_phrase_and_passphrase(suri_parts['phrase'], passphrase=suri_parts['password']).secret() - derived_keypair = cls.create_from_private_key(private_key, ss58_format=ss58_format, crypto_type=crypto_type) - else: - - if suri_parts['password']: - raise NotImplementedError(f"Passwords in suri not supported for crypto_type '{crypto_type}'") - - derived_keypair = cls.create_from_mnemonic( - suri_parts['phrase'], ss58_format=ss58_format, crypto_type=crypto_type, language_code=language_code - ) - - if suri_parts['path'] != '': - - derived_keypair.derive_path = suri_parts['path'] - - if crypto_type not in [KeyType.SR25519]: - raise NotImplementedError('Derivation paths for this crypto type not supported') - - derive_junctions = extract_derive_path(suri_parts['path']) - - child_pubkey = derived_keypair.public_key - child_privkey = derived_keypair.private_key - - for junction in derive_junctions: - - if junction.is_hard: - - _, child_pubkey, child_privkey = sr25519.hard_derive_keypair( - (junction.chain_code, child_pubkey, child_privkey), - b'' - ) - - else: - - _, child_pubkey, child_privkey = sr25519.derive_keypair( - (junction.chain_code, child_pubkey, child_privkey), - b'' - ) - - derived_keypair = Key(public_key=child_pubkey, private_key=child_privkey, ss58_format=ss58_format) - - return derived_keypair - from_mnem = from_mnemonic = create_from_mnemonic - @classmethod - def create_from_private_key( - cls, - private_key: Union[bytes, str], - public_key: Union[bytes, str] = None, - ss58_address: str = None, - ss58_format: int = ss58_format, - crypto_type: int = KeyType.SR25519 - ) -> 'Key': - """ - Creates Key for specified public/private keys - Parameters - ---------- - private_key: hex string or bytes of private key - public_key: hex string or bytes of public key - ss58_address: Substrate address - ss58_format: Substrate address format, default = 42 - crypto_type: Use KeyType.[SR25519|ED25519|ECDSA] cryptography for generating the Key - - Returns - ------- - Key - """ - - return cls(ss58_address=ss58_address, - public_key=public_key, - private_key=private_key, - ss58_format=ss58_format, - crypto_type=crypto_type - ) - from_private_key = create_from_private_key - - @classmethod - def create_from_encrypted_json(cls, json_data: Union[str, dict], passphrase: str, - ss58_format: int = None) -> 'Key': - """ - Create a Key from a PolkadotJS format encrypted JSON file - - Parameters - ---------- - json_data: Dict or JSON string containing PolkadotJS export format - passphrase: Used to encrypt the keypair - ss58_format: Which network ID to use to format the SS58 address (42 for testnet) - - Returns - ------- - Key - """ - - crypto_type = cls.resolve_crypto_type(crypto_type) - - if type(json_data) is str: - json_data = json.loads(json_data) - - private_key, public_key = decode_pair_from_encrypted_json(json_data, passphrase) - - if 'sr25519' in json_data['encoding']['content']: - crypto_type = KeyType.SR25519 - elif 'ed25519' in json_data['encoding']['content']: - crypto_type = KeyType.ED25519 - # Strip the nonce part of the private key - private_key = private_key[0:32] - elif 'solana' in json_data['encoding']['content']: - crypto_type = KeyType.SOLANA - private_key = private_key[0:32] - else: - raise NotImplementedError("Unknown KeyType found in JSON") - - if ss58_format is None and 'address' in json_data: - ss58_format = get_ss58_format(json_data['address']) - - return cls.create_from_private_key(private_key, public_key, ss58_format=ss58_format, crypto_type=crypto_type) - - def export_to_encrypted_json(self, passphrase: str, name: str = None) -> dict: - """ - Export Key to PolkadotJS format encrypted JSON file - - Parameters - ---------- - passphrase: Used to encrypt the keypair - name: Display name of Key used - - Returns - ------- - dict - """ - if not name: - name = self.ss58_address - - if self.crypto_type == KeyType.SR25519: - # Secret key from PolkadotJS is an Ed25519 expanded secret key, so has to be converted - # https://github.com/polkadot-js/wasm/blob/master/packages/wasm-crypto/src/rs/sr25519.rs#L125 - converted_private_key = sr25519.convert_secret_key_to_ed25519(self.private_key) - encoded = encode_pair(self.public_key, converted_private_key, passphrase) - encoding_content = ["pkcs8", "sr25519"] - elif self.crypto_type == KeyType.SOLANA: - keypair = SolanaKeypair.from_seed(self.private_key) - encoded = encode_pair(self.public_key, keypair.secret(), passphrase) - encoding_content = ["pkcs8", "solana"] - else: - raise NotImplementedError(f"Cannot create JSON for crypto_type '{self.crypto_type}'") - - json_data = { - "encoded": b64encode(encoded).decode(), - "encoding": {"content": encoding_content, "type": ["scrypt", "xsalsa20-poly1305"], "version": "3"}, - "address": self.ss58_address, - "meta": { - "name": name, "tags": [], "whenCreated": int(time.time()) - } - } - return json_data - - seperator = "::signature=" - - - def sign(self, data: Union[ScaleBytes, bytes, str], to_json = False) -> bytes: - """ - Creates a signature for given data - Parameters - ---------- - data: data to sign in `Scalebytes`, bytes or hex string format - Returns - ------- - signature in bytes - - """ - if not isinstance(data, str): - data = c.python2str(data) - if type(data) is ScaleBytes: - data = bytes(data.data) - elif data[0:2] == '0x': - data = bytes.fromhex(data[2:]) - elif type(data) is str: - data = data.encode() - if not self.private_key: - raise Exception('No private key set to create signatures') - if self.crypto_type == KeyType.SR25519: - signature = sr25519.sign((self.public_key, self.private_key), data) - elif self.crypto_type == KeyType.ED25519: - signature = ed25519_zebra.ed_sign(self.private_key, data) - elif self.crypto_type == KeyType.ECDSA: - signature = ecdsa_sign(self.private_key, data) - elif self.crypto_type == KeyType.SOLANA: - signature = solana_sign(self.private_key, data) - else: - raise Exception("Crypto type not supported") - - if to_json: - return {'data':data.decode(), - 'crypto_type':self.crypto_type, - 'signature':signature.hex(), - 'address': self.ss58_address,} - return signature - - - @classmethod - def bytes2str(cls, data: bytes, mode: str = 'utf-8') -> str: - - if hasattr(data, 'hex'): - return data.hex() - else: - if isinstance(data, str): - return data - return bytes.decode(data, mode) - - @classmethod - def python2str(cls, input): - from copy import deepcopy - import json - - input = deepcopy(input) - input_type = type(input) - if input_type == str: - return input - if input_type in [dict]: - input = json.dumps(input) - elif input_type in [bytes]: - input = cls.bytes2str(input) - elif input_type in [list, tuple, set]: - input = json.dumps(list(input)) - elif input_type in [int, float, bool]: - input = str(input) - return input - - def verify(self, - data: Union[ScaleBytes, bytes, str, dict], - signature: Union[bytes, str] = None, - public_key:Optional[str]= None, - return_address = False, - ss58_format = ss58_format, - max_age = None, - address = None, - **kwargs - ) -> bool: - """ - Verifies data with specified signature - - Parameters - ---------- - data: data to be verified in `Scalebytes`, bytes or hex string format - signature: signature in bytes or hex string format - public_key: public key in bytes or hex string format - - Returns - ------- - True if data is signed with this Key, otherwise False - """ - data = c.copy(data) - - if isinstance(data, dict): - if self.is_ticket(data): - address = data.pop('address') - signature = data.pop('signature') - elif 'data' in data and 'signature' in data and 'address' in data: - signature = data.pop('signature') - address = data.pop('address', address) - data = data.pop('data') - else: - assert signature != None, 'signature not found in data' - assert address != None, 'address not found in data' - - if max_age != None: - if isinstance(data, int): - staleness = c.timestamp() - int(data) - elif 'timestamp' in data or 'time' in data: - timestamp = data.get('timestamp', data.get('time')) - staleness = c.timestamp() - int(timestamp) - else: - raise ValueError('data should be a timestamp or a dict with a timestamp key') - assert staleness < max_age, f'data is too old, {staleness} seconds old, max_age is {max_age}' - - if not isinstance(data, str): - data = c.python2str(data) - if address != None: - if self.valid_ss58_address(address): - public_key = ss58_decode(address) - if public_key == None: - public_key = self.public_key - if isinstance(public_key, str): - public_key = bytes.fromhex(public_key.replace('0x', '')) - if type(data) is ScaleBytes: - data = bytes(data.data) - elif data[0:2] == '0x': - data = bytes.fromhex(data[2:]) - elif type(data) is str: - data = data.encode() - if type(signature) is str and signature[0:2] == '0x': - signature = bytes.fromhex(signature[2:]) - elif type(signature) is str: - signature = bytes.fromhex(signature) - if type(signature) is not bytes: - raise TypeError("Signature should be of type bytes or a hex-string") - - - if self.crypto_type == KeyType.SR25519: - crypto_verify_fn = sr25519.verify - elif self.crypto_type == KeyType.ED25519: - crypto_verify_fn = ed25519_zebra.ed_verify - elif self.crypto_type == KeyType.ECDSA: - crypto_verify_fn = ecdsa_verify - elif self.crypto_type == KeyType.SOLANA: - crypto_verify_fn = solana_verify - else: - raise Exception("Crypto type not supported") - verified = crypto_verify_fn(signature, data, public_key) - if not verified: - # Another attempt with the data wrapped, as discussed in https://github.com/polkadot-js/extension/pull/743 - # Note: As Python apps are trusted sources on its own, no need to wrap data when signing from this lib - verified = crypto_verify_fn(signature, b'' + data + b'', public_key) - if return_address: - return ss58_encode(public_key, ss58_format=ss58_format) - return verified - - def is_ticket(self, data): - return all([k in data for k in ['data','signature', 'address', 'crypto_type']]) and any([k in data for k in ['time', 'timestamp']]) - - def resolve_encryption_password(self, password:str=None) -> str: - if password == None: - password = self.private_key - if isinstance(password, str): - password = password.encode() - return hashlib.sha256(password).digest() - - def resolve_encryption_data(self, data): - if not isinstance(data, str): - data = str(data) - return data - - def encrypt(self, data, password=None): - data = self.resolve_encryption_data(data) - password = self.resolve_encryption_password(password) - data = data + (AES.block_size - len(data) % AES.block_size) * chr(AES.block_size - len(data) % AES.block_size) - iv = Random.new().read(AES.block_size) - cipher = AES.new(password, AES.MODE_CBC, iv) - encrypted_bytes = base64.b64encode(iv + cipher.encrypt(data.encode())) - return encrypted_bytes.decode() - - def decrypt(self, data, password=None): - password = self.resolve_encryption_password(password) - data = base64.b64decode(data) - iv = data[:AES.block_size] - cipher = AES.new(password, AES.MODE_CBC, iv) - data = cipher.decrypt(data[AES.block_size:]) - data = data[:-ord(data[len(data)-1:])].decode('utf-8') - return data - - def encrypt_message( - self, - message: Union[bytes, str], - recipient_public_key: bytes, - nonce: bytes = secrets.token_bytes(24), - ) -> bytes: - """ - Encrypts message with for specified recipient - - Parameters - ---------- - message: message to be encrypted, bytes or string - recipient_public_key: recipient's public key - nonce: the nonce to use in the encryption - - Returns - ------- - Encrypted message - """ - if not self.private_key: - raise Exception('No private key set to encrypt') - if self.crypto_type != KeyType.ED25519: - raise Exception('Only ed25519 keypair type supported') - curve25519_public_key = nacl.bindings.crypto_sign_ed25519_pk_to_curve25519(recipient_public_key) - recipient = nacl.public.PublicKey(curve25519_public_key) - private_key = nacl.bindings.crypto_sign_ed25519_sk_to_curve25519(self.private_key + self.public_key) - sender = nacl.public.PrivateKey(private_key) - box = nacl.public.Box(sender, recipient) - return box.encrypt(message if isinstance(message, bytes) else message.encode("utf-8"), nonce) - - def decrypt_message(self, encrypted_message_with_nonce: bytes, sender_public_key: bytes) -> bytes: - """ - Decrypts message from a specified sender - - Parameters - ---------- - encrypted_message_with_nonce: message to be decrypted - sender_public_key: sender's public key - - Returns - ------- - Decrypted message - """ - - if not self.private_key: - raise Exception('No private key set to decrypt') - if self.crypto_type != KeyType.ED25519: - raise Exception('Only ed25519 keypair type supported') - private_key = nacl.bindings.crypto_sign_ed25519_sk_to_curve25519(self.private_key + self.public_key) - recipient = nacl.public.PrivateKey(private_key) - curve25519_public_key = nacl.bindings.crypto_sign_ed25519_pk_to_curve25519(sender_public_key) - sender = nacl.public.PublicKey(curve25519_public_key) - return nacl.public.Box(recipient, sender).decrypt(encrypted_message_with_nonce) - - encrypted_prefix = 'ENCRYPTED::' - - @classmethod - def encrypt_key(cls, path = 'test.enc', password=None): - assert cls.key_exists(path), f'file {path} does not exist' - assert not cls.is_key_encrypted(path), f'{path} already encrypted' - data = cls.get(path) - enc_text = {'data': c.encrypt(data, password=password), - 'encrypted': True} - cls.put(path, enc_text) - return {'number_of_characters_encrypted':len(enc_text), 'path':path } - - @classmethod - def is_key_encrypted(cls, key, data=None): - data = data or cls.get(key) - return cls.is_encrypted(data) - - @classmethod - def decrypt_key(cls, path = 'test.enc', password=None, key=None): - assert cls.key_exists(path), f'file {path} does not exist' - assert cls.is_key_encrypted(path), f'{path} not encrypted' - data = cls.get(path) - assert cls.is_encrypted(data), f'{path} not encrypted' - dec_text = c.decrypt(data['data'], password=password) - cls.put(path, dec_text) - assert not cls.is_key_encrypted(path), f'failed to decrypt {path}' - loaded_key = c.get_key(path) - return { 'path':path , - 'key_address': loaded_key.ss58_address, - 'crypto_type': loaded_key.crypto_type} - - @classmethod - def get_mnemonic(cls, key): - return cls.get_key(key).mnemonic - - def __str__(self): - return f'' - - def save(self, path=None): - if path == None: - path = self.path - c.put_json(path, self.to_json()) - return {'saved':path} - - def __repr__(self): - return self.__str__() - - @classmethod - def from_private_key(cls, private_key:str): - return cls(private_key=private_key) - - @classmethod - def valid_ss58_address(cls, address: str, ss58_format=ss58_format ) -> bool: - """ - Checks if the given address is a valid ss58 address. - """ - try: - return is_valid_ss58_address( address , valid_ss58_format =ss58_format ) - except Exception as e: - return False - - @classmethod - def is_encrypted(cls, data): - if isinstance(data, str): - if os.path.exists(data): - data = c.get_json(data) - else: - try: - data = json.loads(data) - except: - return False - if isinstance(data, dict): - return bool(data.get('encrypted', False)) - else: - return False - - @staticmethod - def ss58_encode(*args, **kwargs): - return ss58_encode(*args, **kwargs) - - @staticmethod - def ss58_decode(*args, **kwargs): - return ss58_decode(*args, **kwargs) - - @classmethod - def get_key_address(cls, key): - return cls.get_key(key).ss58_address - - @classmethod - def resolve_key_address(cls, key): - key2address = c.key2address() - if key in key2address: - address = key2address[key] - else: - address = key - return address - - @classmethod - def valid_h160_address(cls, address): - # Check if it starts with '0x' - if not address.startswith('0x'): - return False - - # Remove '0x' prefix - address = address[2:] - - # Check length - if len(address) != 40: - return False - - # Check if it contains only valid hex characters - if not re.match('^[0-9a-fA-F]{40}$', address): - return False - - return True - - def storage_migration(self): - key2path = self.key2path() - new_key2path = {} - for k_name, k_path in key2path.items(): - try: - key = c.get_key(k_name) - new_k_path = '/'.join(k_path.split('/')[:-1]) + '/' + f'{k_name}_address={key.ss58_address}_type={key.crypto_type}.json' - new_key2path[k_name] = new_k_path - except Exception as e: - c.print(f'failed to migrate {k_name} due to {e}', color='red') - - return new_key2path - -# if __name__ == "__main__": -# Key.run() - - - - - - - From 850d1b6a0d4fcea654318a4791c31748880a253f Mon Sep 17 00:00:00 2001 From: drunest <6984754+drunest@users.noreply.github.com> Date: Wed, 15 Jan 2025 18:55:48 -0600 Subject: [PATCH 20/27] ref: commune key module --- commune/key/__init__.py | 1 + commune/key/{types/index.py => constants.py} | 0 commune/key/key.py | 34 +++--- commune/key/test_key.py | 109 ++++++++++++++++++ commune/key/types/__init__.py | 6 + .../{dot_ed25519.py => types/dot/ed25519.py} | 6 +- .../{dot_sr25519.py => types/dot/sr25519.py} | 6 +- commune/key/{ => types}/eth.py | 6 +- commune/key/{ => types}/sol.py | 6 +- commune/key/utils.py | 2 +- 10 files changed, 146 insertions(+), 30 deletions(-) rename commune/key/{types/index.py => constants.py} (100%) create mode 100644 commune/key/test_key.py create mode 100644 commune/key/types/__init__.py rename commune/key/{dot_ed25519.py => types/dot/ed25519.py} (99%) rename commune/key/{dot_sr25519.py => types/dot/sr25519.py} (99%) rename commune/key/{ => types}/eth.py (99%) rename commune/key/{ => types}/sol.py (99%) diff --git a/commune/key/__init__.py b/commune/key/__init__.py index e69de29bb..2a2207852 100644 --- a/commune/key/__init__.py +++ b/commune/key/__init__.py @@ -0,0 +1 @@ +from .key import Key \ No newline at end of file diff --git a/commune/key/types/index.py b/commune/key/constants.py similarity index 100% rename from commune/key/types/index.py rename to commune/key/constants.py diff --git a/commune/key/key.py b/commune/key/key.py index dd953bdb9..d4e438eec 100644 --- a/commune/key/key.py +++ b/commune/key/key.py @@ -19,7 +19,7 @@ from scalecodec.utils.ss58 import is_valid_ss58_address -from .types.index import * +from commune.key.constants import * from .utils import * class KeyType: @@ -42,16 +42,16 @@ def __new__( ): crypto_type = cls.resolve_crypto_type(crypto_type) if crypto_type == KeyType.SR25519: - from .dot_sr25519 import DotSR25519 + from .types.dot.sr25519 import DotSR25519 return super().__new__(DotSR25519) elif crypto_type == KeyType.ED25519: - from .dot_ed25519 import DotED25519 + from commune.key.types.dot.ed25519 import DotED25519 return super().__new__(DotED25519) elif crypto_type == KeyType.ECDSA: - from .eth import ECDSA + from .types.eth import ECDSA return super().__new__(ECDSA) elif crypto_type == KeyType.SOLANA: - from .sol import Solana + from .types.sol import Solana return super().__new__(Solana) else: raise NotImplementedError(f"unsupported crypto_type {crypto_type}") @@ -507,16 +507,16 @@ def create_from_mnemonic( """ crypto_type = cls.resolve_crypto_type(crypto_type) if crypto_type == KeyType.SR25519: - from commune.key.dot_sr25519 import DotSR25519 + from commune.key.types.dot.sr25519 import DotSR25519 return DotSR25519.create_from_mnemonic(mnemonic=mnemonic, ss58_format=ss58_format, language_code=language_code, crypto_type=crypto_type) elif crypto_type == KeyType.ED25519: - from commune.key.dot_ed25519 import DotED25519 + from commune.key.types.dot.ed25519 import DotED25519 return DotED25519.create_from_mnemonic(mnemonic=mnemonic, ss58_format=ss58_format, language_code=language_code, crypto_type=crypto_type) elif crypto_type == KeyType.ECDSA: - from commune.key.eth import ECDSA + from commune.key.types.eth import ECDSA return ECDSA.create_from_mnemonic(mnemonic=mnemonic, ss58_format=ss58_format, language_code=language_code, crypto_type=crypto_type) elif crypto_type == KeyType.SOLANA: - from commune.key.sol import Solana + from commune.key.types.sol import Solana return Solana.create_from_mnemonic(mnemonic=mnemonic, ss58_format=ss58_format, language_code=language_code, crypto_type=crypto_type) else: raise NotImplementedError("create_from_mnemonic not implemented") @@ -545,16 +545,16 @@ def create_from_seed( """ crypto_type = cls.resolve_crypto_type(crypto_type) if crypto_type == KeyType.SR25519: - from commune.key.dot_sr25519 import DotSR25519 + from commune.key.types.dot.sr25519 import DotSR25519 return DotSR25519.create_from_seed(seed_hex=seed_hex, ss58_format=ss58_format, crypto_type=crypto_type) elif crypto_type == KeyType.ED25519: - from commune.key.dot_ed25519 import DotED25519 + from commune.key.types.dot.ed25519 import DotED25519 return DotED25519.create_from_seed(seed_hex=seed_hex, ss58_format=ss58_format, crypto_type=crypto_type) elif crypto_type == KeyType.ECDSA: - from commune.key.eth import ECDSA + from commune.key.types.eth import ECDSA return ECDSA.create_from_seed(seed_hex=seed_hex, ss58_format=ss58_format, crypto_type=crypto_type) elif crypto_type == KeyType.SOLANA: - from commune.key.sol import Solana + from commune.key.types.sol import Solana return Solana.create_from_seed(seed_hex=seed_hex, ss58_format=ss58_format, crypto_type=crypto_type) else: raise NotImplementedError("create_from_seed not implemented") @@ -588,16 +588,16 @@ def create_from_uri( Key """ if crypto_type == KeyType.SR25519: - from commune.key.dot_sr25519 import DotSR25519 + from commune.key.types.dot.sr25519 import DotSR25519 return DotSR25519.create_from_uri(suri=suri, ss58_format=ss58_format, language_code=language_code) elif crypto_type == KeyType.ED25519: - from commune.key.dot_ed25519 import DotED25519 + from commune.key.types.dot.ed25519 import DotED25519 return DotED25519.create_from_uri(suri=suri, ss58_format=ss58_format, language_code=language_code) elif crypto_type == KeyType.ECDSA: - from commune.key.eth import ECDSA + from commune.key.types.eth import ECDSA return ECDSA.create_from_uri(suri=suri, ss58_format=ss58_format, language_code=language_code) elif crypto_type == KeyType.SOLANA: - from commune.key.sol import Solana + from commune.key.types.sol import Solana return Solana.create_from_uri(suri=suri, ss58_format=ss58_format, language_code=language_code) else: raise NotImplementedError("create_from_uri not implemented") diff --git a/commune/key/test_key.py b/commune/key/test_key.py new file mode 100644 index 000000000..ebe2aed98 --- /dev/null +++ b/commune/key/test_key.py @@ -0,0 +1,109 @@ +import pytest +import commune as c + +crypto_types = ['ecdsa', 'solana', 'sr25519', 'ed25519'] + +@pytest.mark.parametrize('crypto_type', crypto_types) +def test_encryption(crypto_type, values = [10, 'fam', 'hello world']): + cls = c.module('key') + for value in values: + value = str(value) + key = cls.new_key(crypto_type=crypto_type) + enc = key.encrypt(value) + dec = key.decrypt(enc) + assert dec == value, f'encryption failed, {dec} != {value}' + return {'encrypted':enc, 'decrypted': dec} + +@pytest.mark.parametrize('crypto_type', crypto_types) +def test_encryption_with_password(crypto_type, value = 10, password = 'fam'): + cls = c.module('key') + value = str(value) + key = cls.new_key(crypto_type=crypto_type) + enc = key.encrypt(value, password=password) + dec = key.decrypt(enc, password=password) + assert dec == value, f'encryption failed, {dec} != {value}' + return {'encrypted':enc, 'decrypted': dec} + +@pytest.mark.parametrize('crypto_type', crypto_types) +def test_key_encryption(crypto_type, test_key='test.key'): + self = c.module('key') + key = self.add_key(test_key, refresh=True, crypto_type=crypto_type) + og_key = self.get_key(test_key, crypto_type=crypto_type) + r = self.encrypt_key(test_key) + self.decrypt_key(test_key, password=r['password']) + key = self.get_key(test_key, crypto_type=crypto_type) + assert key.ss58_address == og_key.ss58_address, f'key encryption failed, {key.ss58_address} != {self.ss58_address}' + return {'success': True, 'msg': 'test_key_encryption passed'} + +@pytest.mark.parametrize('crypto_type', crypto_types) +def test_key_management(crypto_type, key1='test.key' , key2='test2.key'): + key = c.module('key') + if key.key_exists(key1): + key.rm_key(key1) + if key.key_exists(key2): + key.rm_key(key2) + key.add_key(key1, crypto_type=crypto_type) + k1 = key.get_key(key1, crypto_type=crypto_type) + assert key.key_exists(key1), f'Key management failed, key still exists' + key.mv_key(key1, key2, crypto_type=crypto_type) + k2 = key.get_key(key2, crypto_type=crypto_type) + assert k1.ss58_address == k2.ss58_address, f'Key management failed, {k1.ss58_address} != {k2.ss58_address}' + assert key.key_exists(key2), f'Key management failed, key does not exist' + assert not key.key_exists(key1), f'Key management failed, key still exists' + key.mv_key(key2, key1, crypto_type=crypto_type) + assert key.key_exists(key1), f'Key management failed, key does not exist' + assert not key.key_exists(key2), f'Key management failed, key still exists' + key.rm_key(key1) + # key.rm_key(key2) + assert not key.key_exists(key1), f'Key management failed, key still exists' + assert not key.key_exists(key2), f'Key management failed, key still exists' + return {'success': True, 'msg': 'test_key_management passed'} + +@pytest.mark.parametrize('crypto_type', crypto_types) +def test_signing(crypto_type,): + self = c.module('key')(crypto_type=crypto_type) + sig = self.sign('test') + assert self.verify('test',sig, self.public_key) + return {'success':True} + +@pytest.mark.parametrize('crypto_type', crypto_types) +def test_key_encryption(crypto_type, password='1234'): + cls = c.module('key') + path = 'test.enc' + cls.add_key('test.enc', refresh=True, crypto_type=crypto_type) + assert cls.is_key_encrypted(path) == False, f'file {path} is encrypted' + cls.encrypt_key(path, password=password) + assert cls.is_key_encrypted(path) == True, f'file {path} is not encrypted' + cls.decrypt_key(path, password=password) + assert cls.is_key_encrypted(path) == False, f'file {path} is encrypted' + cls.rm(path) + print('file deleted', path, c.exists, 'fam') + assert not c.exists(path), f'file {path} not deleted' + return {'success': True, 'msg': 'test_key_encryption passed'} + +@pytest.mark.parametrize('crypto_type', crypto_types) +def test_move_key(crypto_type): + self = c.module('key')(crypto_type=crypto_type) + self.add_key('testfrom', crypto_type=crypto_type) + assert self.key_exists('testfrom') + og_key = self.get_key('testfrom', crypto_type=crypto_type) + self.mv_key('testfrom', 'testto', crypto_type=crypto_type) + assert self.key_exists('testto', crypto_type=crypto_type) + assert not self.key_exists('testfrom') + new_key = self.get_key('testto', crypto_type=crypto_type) + assert og_key.ss58_address == new_key.ss58_address + self.rm_key('testto') + assert not self.key_exists('testto') + return {'success':True, 'msg':'test_move_key passed', 'key':new_key.ss58_address} + + +def test_ss58_encoding(): + self = c.module('key') + keypair = self.create_from_uri('//Alice') + ss58_address = keypair.ss58_address + public_key = keypair.public_key + assert keypair.ss58_address == self.ss58_encode(public_key, ss58_format=42) + assert keypair.ss58_address == self.ss58_encode(public_key, ss58_format=42) + assert keypair.public_key.hex() == self.ss58_decode(ss58_address) + assert keypair.public_key.hex() == self.ss58_decode(ss58_address) + return {'success':True} diff --git a/commune/key/types/__init__.py b/commune/key/types/__init__.py new file mode 100644 index 000000000..fda498480 --- /dev/null +++ b/commune/key/types/__init__.py @@ -0,0 +1,6 @@ +from .dot.ed25519 import DotED25519 +from .dot.sr25519 import DotSR25519 +from .eth import ECDSA +from .sol import Solana + +__all__ = ["DotED25519", "DotSR25519", "ECDSA", "Solana"] \ No newline at end of file diff --git a/commune/key/dot_ed25519.py b/commune/key/types/dot/ed25519.py similarity index 99% rename from commune/key/dot_ed25519.py rename to commune/key/types/dot/ed25519.py index d0a9e72b2..0a2e8c04a 100644 --- a/commune/key/dot_ed25519.py +++ b/commune/key/types/dot/ed25519.py @@ -2,8 +2,8 @@ import sr25519 import commune as c -from .types.index import * -from .utils import * +from commune.key.constants import * +from commune.key.utils import * import binascii import re @@ -15,7 +15,7 @@ ss58_encode ) from typing import Union, Optional -from .key import Key, KeyType +from commune.key.key import Key, KeyType class DotED25519(Key): def __init__( self, diff --git a/commune/key/dot_sr25519.py b/commune/key/types/dot/sr25519.py similarity index 99% rename from commune/key/dot_sr25519.py rename to commune/key/types/dot/sr25519.py index 4911f5450..eb77cd8cc 100644 --- a/commune/key/dot_sr25519.py +++ b/commune/key/types/dot/sr25519.py @@ -1,8 +1,8 @@ import sr25519 import commune as c -from .types.index import * -from .utils import * +from commune.key.constants import * +from commune.key.utils import * import json @@ -27,7 +27,7 @@ get_ss58_format, ) from typing import Union, Optional -from .key import Key, KeyType +from ...key import Key, KeyType class DotSR25519(Key): def __init__( self, diff --git a/commune/key/eth.py b/commune/key/types/eth.py similarity index 99% rename from commune/key/eth.py rename to commune/key/types/eth.py index 9bff9d4a0..28b9bbd60 100644 --- a/commune/key/eth.py +++ b/commune/key/types/eth.py @@ -1,8 +1,8 @@ import sr25519 import commune as c -from .types.index import * -from .utils import * +from commune.key.constants import * +from commune.key.utils import * import json @@ -27,7 +27,7 @@ get_ss58_format, ) from typing import Union, Optional -from .key import Key, KeyType +from ..key import Key, KeyType class ECDSA(Key): def __init__( self, diff --git a/commune/key/sol.py b/commune/key/types/sol.py similarity index 99% rename from commune/key/sol.py rename to commune/key/types/sol.py index 6c3a949bd..e59b2d2f3 100644 --- a/commune/key/sol.py +++ b/commune/key/types/sol.py @@ -1,8 +1,8 @@ import sr25519 import commune as c -from .types.index import * -from .utils import * +from commune.key.constants import * +from commune.key.utils import * import re import nacl.public @@ -12,7 +12,7 @@ ) from base58 import b58encode from typing import Union, Optional -from .key import Key, KeyType +from ..key import Key, KeyType class Solana(Key): def __init__( self, diff --git a/commune/key/utils.py b/commune/key/utils.py index 6fa8c07c6..fd2657c9e 100644 --- a/commune/key/utils.py +++ b/commune/key/utils.py @@ -19,7 +19,7 @@ from solders.pubkey import Pubkey as SolanaPubkey from solders.signature import Signature as SolanaSignature -from .types.index import * +from .constants import * class PublicKey: def __init__(self, private_key): From 52e4c88d5752f2aeba66cee59540345536342187 Mon Sep 17 00:00:00 2001 From: drunest <6984754+drunest@users.noreply.github.com> Date: Wed, 15 Jan 2025 18:56:31 -0600 Subject: [PATCH 21/27] ref: reconstruct network module --- commune/network/network.py | 101 +++++++++++++++++++++++++++++++++++++ 1 file changed, 101 insertions(+) create mode 100644 commune/network/network.py diff --git a/commune/network/network.py b/commune/network/network.py new file mode 100644 index 000000000..da3e9c674 --- /dev/null +++ b/commune/network/network.py @@ -0,0 +1,101 @@ +from typing import * +import os +import commune as c +class Network(c.Module): + min_stake = 0 + blocktime = block_time = 8 + n = 100 + tempo = 60 + blocks_per_day = 24*60*60/block_time + # the default + endpoints = ['namespace'] + def __init__(self, network:str='local', tempo=tempo, path=None, **kwargs): + self.set_network(network=network, tempo=tempo, path=path) + + def set_network(self, network:str, tempo:int=60, path=None, **kwargs): + self.network = network + self.tempo = tempo + self.modules_path = self.resolve_path(path or f'{self.network}/modules') + return {'network': self.network, 'tempo': self.tempo, 'modules_path': self.modules_path} + + def params(self,*args, **kwargs): + return { 'network': self.network, + 'tempo' : self.tempo, + 'n': self.n} + + def net(self): + return c.network() + + def modules(self, + search=None, + max_age=tempo, + update=False, + features=['name', 'address', 'key'], + timeout=8, + **kwargs): + modules = c.get(self.modules_path, max_age=max_age, update=update) + if modules == None: + modules = [] + addresses = ['0.0.0.0'+':'+str(p) for p in c.used_ports()] + futures = [c.submit(c.call, [s + '/info'], timeout=timeout) for s in addresses] + try: + for f in c.as_completed(futures, timeout=timeout): + data = f.result() + if all([k in data for k in features]): + modules.append({k: data[k] for k in features}) + except Exception as e: + c.print('Error getting modules', e) + modules = [] + c.put(self.modules_path, modules) + if search != None: + modules = [m for m in modules if search in m['name']] + return modules + + def namespace(self, search=None, max_age:int = tempo, update:bool = False, **kwargs) -> dict: + return {m['name']: '0.0.0.0' + ':' + m['address'].split(':')[-1] for m in self.modules(search=search, max_age=max_age, update=update)} + + def add_server(self, name:str, address:str, key:str) -> None: + data = {'name': name, 'address': address, 'key': key} + modules = self.modules() + modules.append(data) + c.put(self.modules_path, modules) + return {'success': True, 'msg': f'Block {name}.'} + + def register_from_signature(self, signature=None): + import json + assert c.verify(signature), 'Signature is not valid.' + data = json.loads(signature['data']) + return self.add_server(data['name'], data['address']) + + def remove_server(self, name:str, features=['name', 'key', 'address']) -> Dict: + modules = self.modules() + modules = [m for m in modules if not any([m[f] == name for f in features])] + c.put(self.modules_path, modules) + + def resolve_network(self, network:str) -> str: + return network or self.network + + def names(self, *args, **kwargs) -> List[str]: + return list(self.namespace(*args, **kwargs).keys()) + + def addresses(self,*args, **kwargs) -> List[str]: + return list(self.namespace(*args, **kwargs).values()) + + def servers(self, search=None, **kwargs) -> List[str]: + namespace = self.namespace(search=search,**kwargs) + return list(namespace.keys()) + + def server_exists(self, name:str, **kwargs) -> bool: + servers = self.servers(**kwargs) + return bool(name in servers) + + def networks(self) -> List[str]: + return ['local', 'subspace', 'subtensor'] + + def infos(self, *args, **kwargs) -> Dict: + return [c.call(address+'/info') for name, address in self.namespace(*args, **kwargs).items()] + +if __name__ == "__main__": + Network.run() + + From 91b7711e575ec5a304fd69afc3d2c4029ca1c040 Mon Sep 17 00:00:00 2001 From: drunest <6984754+drunest@users.noreply.github.com> Date: Wed, 15 Jan 2025 18:57:06 -0600 Subject: [PATCH 22/27] ref: ref: reconstruct server module --- commune/server/server.py | 664 ++++++++++++++++++++++++++++++++++ commune/server/test_server.py | 69 ++++ 2 files changed, 733 insertions(+) create mode 100644 commune/server/server.py create mode 100644 commune/server/test_server.py diff --git a/commune/server/server.py b/commune/server/server.py new file mode 100644 index 000000000..eec9dae89 --- /dev/null +++ b/commune/server/server.py @@ -0,0 +1,664 @@ +import commune as c +from typing import * +from fastapi import FastAPI, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware +from sse_starlette.sse import EventSourceResponse +import uvicorn +import os +import json +import asyncio + +class Middleware(BaseHTTPMiddleware): + def __init__(self, app, max_bytes: int): + super().__init__(app) + self.max_bytes = max_bytes + async def dispatch(self, request: Request, call_next): + content_length = request.headers.get('content-length') + if content_length: + if int(content_length) > self.max_bytes: + return JSONResponse(status_code=413, content={"error": "Request too large"}) + body = await request.body() + if len(body) > self.max_bytes: + return JSONResponse(status_code=413, content={"error": "Request too large"}) + response = await call_next(request) + return response + +class Server: + tag_seperator:str='::' + user_data_lifetime = 3600 + pm2_dir = os.path.expanduser('~/.pm2') + period : int = 3600 # the period for + max_request_staleness : int = 4 # (in seconds) the time it takes for the request to be too old + max_network_staleness: int = 60 # (in seconds) the time it takes for. the network to refresh + multipliers : Dict[str, float] = {'stake': 1, 'stake_to': 1,'stake_from': 1} + rates : Dict[str, int]= {'max': 10, 'local': 10000, 'stake': 1000, 'owner': 10000, 'admin': 10000} # the maximum rate ): + helper_functions = ['info', 'metadata', 'schema', 'free', 'name', 'functions','key_address', 'crypto_type','fns', 'forward', 'rate_limit'] # the helper functions + functions_attributes =['helper_functions', 'whitelist', "whitelist_functions", 'endpoints', 'functions', 'fns', "exposed_functions",'server_functions', 'public_functions'] # the attributes for the functions + def __init__( + self, + module: Union[c.Module, object] = None, + key:str = None, # key for the server (str) + name: str = None, # the name of the server + functions:Optional[List[Union[str, callable]]] = None, # list of endpoints + port: Optional[int] = None, # the port the server is running on + network:str = 'subspace', # the network used for incentives + fn2cost : Dict[str, float] = None, # the cost of the function + free : bool = False, # if the server is free (checks signature) + kwargs : dict = None, # the kwargs for the module + crypto_type = 'sr25519', # the crypto type of the key + users_path: Optional[str] = None, # the path to the user data + serializer: str = 'serializer', # the serializer used for the data + ) -> 'Server': + module = module or 'module' + kwargs = kwargs or {} + if self.tag_seperator in str(name): + # module::fam -> module=module, name=module::fam key=module::fam (default) + module, tag = name.split(self.tag_seperator) + module = c.module(module)(**kwargs) + if isinstance(module, str): + name = name or module + module = c.module(module)(**kwargs) + print(f'Launching', module, name, functions) + # NOTE: ONLY ENABLE FREEMODE IF YOU ARE ON A CLOSED NETWORK, + self.free = free + self.module = module + self.module.name = name + self.set_key(key=key, crypto_type=crypto_type) + self.set_port(port) + self.set_network(network) + self.set_functions(functions=functions, fn2cost=fn2cost) + self.set_user_path(users_path) + self.serializer = c.module(serializer)() + self.start_server() + + def set_functions(self, functions:Optional[List[str]] , fn2cost=None): + if self.free: + c.print('THE FOUNDING FATHERS WOULD BE PROUD OF YOU SON OF A BITCH', color='red') + else: + if hasattr(self.module, 'free'): + self.free = self.module.free + self.module.free = self.free + functions = functions or [] + for i, fn in enumerate(functions): + if callable(fn): + print('Adding function', f) + setattr(self, fn.__name__, fn) + functions[i] = fn.__name__ + functions = sorted(list(set(functions + self.helper_functions))) + module = self.module + for k in self.functions_attributes: + if hasattr(module, k) and isinstance(getattr(module, k), list): + print('Found ', k) + functions = getattr(module, k) + break + # get function decorators form c.endpoint() + for f in dir(module): + try: + if hasattr(getattr(module, f), '__metadata__'): + functions.append(f) + except Exception as e: + c.print(f'Error in get_endpoints: {e} for {f}') + module.functions = sorted(list(set(functions))) + ## get the schema for the functions + schema = {} + for fn in functions : + if hasattr(module, fn): + schema[fn] = c.schema(getattr(module, fn )) + else: + print(f'Function {fn} not found in {module.name}') + module.schema = dict(sorted(schema.items())) + module.fn2cost = module.fn2cost if hasattr(module, 'fn2cost') else (fn2cost or {}) + assert isinstance(module.fn2cost, dict), f'fn2cost must be a dict, not {type(module.fn2cost)}' + + ### get the info for the module + module.info = { + "functions": functions, + "schema": schema, + "name": module.name, + "address": module.address, + "key": module.key.ss58_address, + "crypto_type": module.key.crypto_type, + "fn2cost": module.fn2cost, + "free": module.free, + "time": c.time() + } + + def set_key(self, key, crypto_type): + module = self.module + module.key = c.get_key(key or module.name, create_if_not_exists=True, crypto_type=crypto_type) + module.key_address = module.key.key_address + module.crypto_type = module.key.crypto_type + return {'success':True, 'message':f'Set key to {module.key.ss58_address}'} + + def start_server(self, + max_bytes = 10 * 1024 * 1024 , # max bytes within the request (bytes) + allow_origins = ["*"], # allowed origins + allow_credentials =True, # allow credentials + allow_methods = ["*"], # allowed methods + allow_headers = ["*"] , # allowed headers + ): + module = self.module + c.thread(self.sync_loop) + self.loop = asyncio.get_event_loop() + app = FastAPI() + app.add_middleware(Middleware, max_bytes=max_bytes) + app.add_middleware(CORSMiddleware, + allow_origins=allow_origins, + allow_credentials=allow_credentials, + allow_methods=allow_methods, + allow_headers=allow_headers) + def api_forward(fn:str, request: Request): + return self.forward(fn, request) + app.post("/{fn}")(api_forward) + c.print(f'Served(name={module.name}, address={module.address}, key={module.key.key_address})', color='purple') + c.print(c.add_server(name=module.name, address=module.address, key=module.key.ss58_address)) + self.module = module + uvicorn.run(app, host='0.0.0.0', port=module.port, loop='asyncio') + + def set_port(self, port:Optional[int]=None, port_attributes = ['port', 'server_port'], ip = None): + module = self.module + name = module.name + for k in port_attributes: + if hasattr(module, k): + port = getattr(module, k) + break + if port in [None, 'None']: + namespace = c.namespace() + if name in namespace: + c.kill(name) + try: + port = int(namespace.get(module.name).split(':')[-1]) + except: + port = c.free_port() + else: + port = c.free_port() + + while c.port_used(port): + c.kill_port(port) + c.sleep(1) + print(f'Waiting for port {port} to be free') + + module.port = port + ip = ip or '0.0.0.0' + module.address = ip + ':' + str(module.port) + self.module = module + return {'success':True, 'message':f'Set port to {port}'} + + def is_admin(self, address): + return c.is_admin(address) + + def gate(self, fn:str, data:dict, headers:dict ) -> bool: + if self.free: + assert fn in self.module.functions , f"Function {fn} not in endpoints={self.module.functions}" + return True + auth = {'data': data, 'time': str(headers['time'])} + signature = headers['signature'] + + assert c.verify(auth=auth,signature=signature, address=headers['key']), 'Invalid signature' + request_staleness = c.time() - float(headers['time']) + assert request_staleness < self.max_request_staleness, f"Request is too old ({request_staleness}s > {self.max_request_staleness}s (MAX)" + auth={'data': data, 'time': str(headers['time'])} + module = self.module + address = headers['key'] + if c.is_admin(address): + rate_limit = self.rates['admin'] + elif address == module.key.ss58_address: + rate_limit = self.rates['owner'] + elif address in self.address2key: + rate_limit = self.rates['local'] + else: + stake_score = self.state['stake'].get(address, 0) + self.multipliers['stake'] + stake_to_score = (sum(self.state['stake_to'].get(address, {}).values())) * self.multipliers['stake_to'] + stake_from_score = self.state['stake_from'].get(module.key.ss58_address, {}).get(address, 0) * self.multipliers['stake_from'] + stake = stake_score + stake_to_score + stake_from_score + self.rates['stake'] = self.rates['stake'] * module.fn2cost.get(fn, 1) + rate_limit = min((stake / self.rates['stake']), self.rates['max']) + count = self.user_call_count(headers['key']) + assert count <= rate_limit, f'rate limit exceeded {count} > {rate_limit}' + return True + + def get_data(self, request: Request): + data = self.loop.run_until_complete(request.json()) + # data = self.serializer.deserialize(data) + if isinstance(data, str): + data = json.loads(data) + if 'kwargs' in data or 'params' in data: + kwargs = dict(data.get('kwargs', data.get('params', {}))) + else: + kwargs = data + if 'args' in data: + args = list(data.get('args', [])) + else: + args = [] + data = {'args': args, 'kwargs': kwargs} + return data + + def get_headers(self, request: Request): + headers = dict(request.headers) + headers['time'] = float(headers.get('time', c.time())) + headers['key'] = headers.get('key', headers.get('address', None)) + return headers + + def forward(self, fn:str, request: Request, catch_exception:bool=True) -> dict: + if catch_exception: + try: + return self.forward(fn, request, catch_exception=False) + except Exception as e: + result = c.detailed_error(e) + c.print(result, color='red') + return result + module = self.module + data = self.get_data(request) + headers = self.get_headers(request) + self.gate(fn=fn, data=data, headers=headers) + is_admin = bool(c.is_admin(headers['key'])) + is_owner = bool(headers['key'] == module.key.ss58_address) + if hasattr(module, fn): + fn_obj = getattr(module, fn) + elif (is_admin or is_owner) and hasattr(self, fn): + fn_obj = getattr(self, fn) + else: + raise Exception(f"{fn} not found in {module.name}") + result = fn_obj(*data['args'], **data['kwargs']) if callable(fn_obj) else fn_obj + latency = c.time() - float(headers['time']) + if c.is_generator(result): + output = '' + def generator_wrapper(generator): + for item in generator: + output += str(item) + yield item + result = EventSourceResponse(generator_wrapper(result)) + else: + output = self.serializer.serialize(result) + if not self.free: + user_data = { + 'fn': fn, + 'data': data, # the data of the request + 'output': output, # the response + 'time': headers["time"], # the time of the request + 'latency': latency, # the latency of the request + 'key': headers['key'], # the key of the user + 'cost': module.fn2cost.get(fn, 1), # the cost of the function + } + user_path = self.user_path(f'{user_data["key"]}/{user_data["fn"]}/{c.time()}.json') + c.put(user_path, user_data) + return result + + def sync_loop(self, sync_loop_initial_sleep=10): + c.sleep(sync_loop_initial_sleep) + while True: + try: + r = self.sync() + except Exception as e: + r = c.detailed_error(e) + c.print('Error in sync_loop -->', r, color='red') + c.sleep(self.max_network_staleness) + + def set_network(self, network): + self.network = network + self.network_path = self.resolve_path(f'networks/{self.network}/state.json') + self.address2key = c.address2key() + c.thread(self.sync_loop) + return {'success':True, 'message':f'Set network to {network}', 'network':network, 'network_path':self.network_path} + + def sync(self, update=True , state_keys = ['stake_from', 'stake_to']): + self.network_path = self.resolve_path(f'networks/{self.network}/state.json') + print(f'Sync({self.network_path})') + if hasattr(self, 'state'): + latency = c.time() - self.state.get('time', 0) + if latency < self.max_network_staleness: + return {'msg': 'state is fresh'} + max_age = self.max_network_staleness + network_path = self.network_path + state = c.get(network_path, {}, max_age=max_age, updpate=update) + state = {} + state['stake'] = {} + state['stake_to'] = {} + state['stake_from'] = {} + if update: + try : + c.namespace(max_age=max_age) + self.subspace = c.module('subspace')(network=self.network) + state['stake_from'] = self.subspace.stake_from(fmt='j', update=update, max_age=max_age) + state['stake_to'] = self.subspace.stake_to(fmt='j', update=update, max_age=max_age) + state['stake'] = {k: sum(v.values()) for k,v in state['stake_from'].items()} + except Exception as e: + c.print(f'Error {e} while syncing network') + is_valid_state = lambda x: all([k in x for k in state_keys]) + assert is_valid_state(state), f'Format for network state is {[k for k in state_keys if k not in state]}' + c.put(network_path, state) + self.state = state + return {'msg': 'state synced successfully'} + + @classmethod + def wait_for_server(cls, + name: str , + network: str = 'local', + timeout:int = 600, + max_age = 1, + sleep_interval: int = 1) -> bool : + + time_waiting = 0 + # rotating status thing + c.print(f'waiting for {name} to start...', color='cyan') + + while time_waiting < timeout: + namespace = c.namespace(network=network, max_age=max_age) + if name in namespace: + try: + result = c.call(namespace[name]+'/info') + if 'key' in result: + c.print(f'{name} is running', color='green') + return result + except Exception as e: + c.print(f'Error getting info for {name} --> {e}', color='red') + c.sleep(sleep_interval) + time_waiting += sleep_interval + raise TimeoutError(f'Waited for {timeout} seconds for {name} to start') + + @classmethod + def endpoint(cls, + cost = 1, + user2rate : dict = None, + rate_limit : int = 100, # calls per minute + timestale : int = 60, + public:bool = False, + **kwargs): + def decorator_fn(fn): + metadata = { + 'schema':c.schema(fn), + 'cost': cost, + 'rate_limit': rate_limit, + 'user2rate': user2rate, + 'timestale': timestale, + 'public': public, + } + fn.__dict__['__metadata__'] = metadata + return fn + return decorator_fn + + serverfn = endpoint + + @classmethod + def kill(cls, name:str, verbose:bool = True, **kwargs): + try: + if name == 'all': + return cls.kill_all(verbose=verbose) + c.cmd(f"pm2 delete {name}", verbose=False) + cls.rm_logs(name) + result = {'message':f'Killed {name}', 'success':True} + except Exception as e: + result = {'message':f'Error killing {name}', 'success':False, 'error':e} + + c.remove_server(name) + return result + + @classmethod + def kill_all_processes(cls, verbose:bool = True, timeout=20): + servers = cls.processes() + futures = [c.submit(cls.kill, kwargs={'name':s, 'update': False}, return_future=True) for s in servers] + results = c.wait(futures, timeout=timeout) + + return results + + @classmethod + def kill_all_servers(cls, network='local', timeout=20, verbose=True): + servers = c.servers(network=network) + futures = [c.submit(cls.kill, kwargs={'module':s, 'update': False}, return_future=True) for s in servers] + return c.wait(futures, timeout=timeout) + + @classmethod + def kill_all(cls, mode='process', verbose:bool = True, timeout=20): + if mode == 'process': + results = cls.kill_all_processes(verbose=verbose, timeout=timeout) + elif mode == 'server': + results = cls.kill_all_servers(verbose=verbose, timeout=timeout) + else: + raise NotImplementedError(f'mode {mode} not implemented') + c.namespace(update=True) + return results + + @classmethod + def killall(cls, **kwargs): + return cls.kill_all(**kwargs) + + @classmethod + def logs_path_map(cls, name=None): + logs_path_map = {} + for l in c.ls(f'{cls.pm2_dir}/logs/'): + key = '-'.join(l.split('/')[-1].split('-')[:-1]).replace('-',':') + logs_path_map[key] = logs_path_map.get(key, []) + [l] + for k in logs_path_map.keys(): + logs_path_map[k] = {l.split('-')[-1].split('.')[0]: l for l in list(logs_path_map[k])} + if name != None: + return logs_path_map.get(name, {}) + return logs_path_map + + @classmethod + def rm_logs( cls, name): + logs_map = cls.logs_path_map(name) + for k in logs_map.keys(): + c.rm(logs_map[k]) + + @classmethod + def logs(cls, module:str, tail: int =100, mode: str ='cmd', **kwargs): + + if mode == 'local': + text = '' + for m in ['out','error']: + # I know, this is fucked + path = f'{cls.pm2_dir}/logs/{module.replace("/", "-")}-{m}.log'.replace(':', '-').replace('_', '-') + try: + text += c.get_text(path, tail=tail) + except Exception as e: + c.print('ERROR GETTING LOGS -->' , e) + continue + return text + elif mode == 'cmd': + return c.cmd(f"pm2 logs {module}", verbose=True) + else: + raise NotImplementedError(f'mode {mode} not implemented') + + def get_logs(self, tail=100, mode='local'): + return self.logs(self.module.name, tail=tail, mode=mode) + + @classmethod + def kill_many(cls, search=None, verbose:bool = True, timeout=10): + futures = [] + for name in c.servers(search=search): + f = c.submit(c.kill, dict(name=name, verbose=verbose), return_future=True, timeout=timeout) + futures.append(f) + return c.wait(futures) + + @classmethod + def start_process(cls, + fn: str = 'serve', + module:str = None, + name:Optional[str]=None, + args : list = None, + kwargs: dict = None, + interpreter:str='python3', + autorestart: bool = True, + verbose: bool = False , + force:bool = True, + run_fn: str = 'run_fn', + cwd : str = None, + env : Dict[str, str] = None, + refresh:bool=True , + **extra_kwargs): + env = env or {} + if '/' in fn: + module, fn = fn.split('/') + module = module or cls.module_name() + name = name or module + if refresh: + cls.kill(name) + cmd = f"pm2 start {c.filepath()} --name {name} --interpreter {interpreter}" + cmd = cmd if autorestart else ' --no-autorestart' + cmd = cmd + ' -f ' if force else cmd + kwargs = {'module': module , 'fn': fn, 'args': args or [], 'kwargs': kwargs or {} } + kwargs_str = json.dumps(kwargs).replace('"', "'") + cmd = cmd + f' -- --fn {run_fn} --kwargs "{kwargs_str}"' + stdout = c.cmd(cmd, env=env, verbose=verbose, cwd=cwd) + return {'success':True, 'msg':f'Launched {module}', 'cmd': cmd, 'stdout':stdout} + remote_fn = launch = start_process + + @classmethod + def restart(cls, name:str): + assert name in cls.processes() + c.print(f'Restarting {name}', color='cyan') + c.cmd(f"pm2 restart {name}", verbose=False) + cls.rm_logs(name) + return {'success':True, 'message':f'Restarted {name}'} + + @classmethod + def processes(cls, search=None, **kwargs) -> List[str]: + output_string = c.cmd('pm2 status', verbose=False) + module_list = [] + for line in output_string.split('\n')[3:]: + if line.count('│') > 2: + name = line.split('│')[2].strip() + module_list += [name] + if search != None: + module_list = [m for m in module_list if search in m] + module_list = sorted(list(set(module_list))) + return module_list + + @classmethod + def procs(cls, **kwargs): + return cls.processes(**kwargs) + + + @classmethod + def process_exists(cls, name:str, **kwargs) -> bool: + return name in cls.processes(**kwargs) + + @classmethod + def serve(cls, + module: Any = None, + kwargs:Optional[dict] = None, # kwargs for the module + port :Optional[int] = None, # name of the server if None, it will be the module name + name = None, # name of the server if None, it will be the module name + remote:bool = True, # runs the server remotely (pm2, ray) + functions = None, # list of functions to serve, if none, it will be the endpoints of the module + key = None, # the key for the server + free = False, + cwd = None, + **extra_kwargs + ): + module = module or 'module' + name = name or module + kwargs = {**(kwargs or {}), **extra_kwargs} + c.print(f'Serving(module={module} params={kwargs} name={name} function={functions})') + if not isinstance(module, str): + remote = False + if remote: + rkwargs = {k : v for k, v in c.locals2kwargs(locals()).items() if k not in ['extra_kwargs', 'response', 'namespace']} + rkwargs['remote'] = False + cls.start_process(fn='serve', name=name, kwargs=rkwargs, cwd=cwd) + return cls.wait_for_server(name) + return Server(module=module, name=name, functions = functions, kwargs=kwargs, port=port, key=key, free = free) + + def extract_time(self, x): + try: + x = float(x.split('/')[-1].split('.')[0]) + except Exception as e: + x = 0 + return x + + @classmethod + def fleet(cls, module, n=10, timeout=10): + futures = [ c.submit(c.serve, {'module':module + '::' + str(i)}, timeout=timeout) for i in range(n)] + progress = c.progress(futures) + results = [] + for f in c.as_completed(futures, timeout=timeout): + r = f.result() + results.append(r) + progress.update() + return results + def remove_user_data(self, address): + return c.rm(self.user_path(address)) + + def users(self): + return os.listdir(self.users_path) + + def user2count(self): + user2count = {} + for user in self.users(): + user2count[user] = self.user_call_count(user) + return user2count + + def history(self, user): + return self.user_data(user) + + def user2fn2count(self): + user2fn2count = {} + for user in self.users(): + user2fn2count[user] = {} + for user_data in self.user_data(user): + fn = user_data['fn'] + user2fn2count[user][fn] = user2fn2count[user].get(fn, 0) + 1 + return user2fn2count + + def user_call_paths(self, address ): + user_paths = c.glob(self.user_path(address)) + return sorted(user_paths, key=self.extract_time) + + def user_data(self, address, stream=False): + user_paths = self.user_call_paths(address) + if stream: + def stream_fn(): + for user_path in user_paths: + yield c.get(user_path) + return stream_fn() + + else: + return [c.get(user_path) for user_path in user_paths] + + def user_path(self, key_address): + return self.users_path + '/' + key_address + + def user_call_count(self, user): + self.check_user_data(user) + return len(self.user_call_paths(user)) + + def users(self): + return os.listdir(self.users_path) + + def user_path2time(self, address): + user_paths = self.user_call_paths(address) + user_path2time = {user_path: self.extract_time(user_path) for user_path in user_paths} + return user_path2time + + def user_call_path2latency(self, address): + user_paths = self.user_call_paths(address) + t0 = c.time() + user_path2time = {user_path: t0 - self.extract_time(user_path) for user_path in user_paths} + return user_path2time + + def check_user_data(self, address): + path2latency = self.user_call_path2latency(address) + for path, latency in path2latency.items(): + if latency > self.user_data_lifetime: + c.print(f'Removing stale path {path} ({latency}/{self.period})') + if os.path.exists(path): + os.remove(path) + + def resolve_path(self, path): + return c.resolve_path(path, storage_dir=self.storage_dir()) + + def set_user_path(self, users_path): + self.users_path = users_path or self.resolve_path(f'users/{self.module.name}') + + + def add_endpoint(self, name, fn): + setattr(self, name, fn) + self.endpoints.append(name) + assert hasattr(self, name), f'{name} not added to {self.__class__.__name__}' + return {'success':True, 'message':f'Added {fn} to {self.__class__.__name__}'} + +if __name__ == '__main__': + Server.run() + diff --git a/commune/server/test_server.py b/commune/server/test_server.py new file mode 100644 index 000000000..3719fc1a6 --- /dev/null +++ b/commune/server/test_server.py @@ -0,0 +1,69 @@ + +import commune as c + + +def test(): + self = c.module('serializer')() + import torch, time + data_list = [ + torch.ones(1000), + torch.zeros(1000), + torch.rand(1000), + [1,2,3,4,5], + {'a':1, 'b':2, 'c':3}, + 'hello world', + c.df([{'name': 'joe', 'fam': 1}]), + 1, + 1.0, + True, + False, + None + + ] + for data in data_list: + t1 = time.time() + ser_data = self.serialize(data) + des_data = self.deserialize(ser_data) + des_ser_data = self.serialize(des_data) + t2 = time.time() + + latency = t2 - t1 + emoji = '✅' if str(des_ser_data) == str(ser_data) else '❌' + print(type(data),emoji) + return {'msg': 'PASSED test_serialize_deserialize'} + + +def test_basics() -> dict: + servers = c.servers() + c.print(servers) + name = f'module::test' + c.serve(name) + c.kill(name) + assert name not in c.servers() + return {'success': True, 'msg': 'server test passed'} + +def test_serving(name = 'module::test'): + module = c.serve(name) + module = c.connect(name) + r = module.info() + assert 'name' in r, f"get failed {r}" + c.kill(name) + assert name not in c.servers(update=1) + return {'success': True, 'msg': 'server test passed'} + +def test_serving_with_different_key(module = 'module', timeout=10): + tag = 'test_serving_with_different_key' + key_name = module + '::'+ tag + module_name = module + '::'+ tag + '_b' + if not c.key_exists(key_name): + key = c.add_key(key_name) + c.print(c.serve(module_name, key=key_name)) + key = c.get_key(key_name) + c.sleep(2) + info = c.call(f'{module_name}/info', timeout=2) + assert info.get('key', None) == key.ss58_address , f" {info}" + c.kill(module_name) + c.rm_key(key_name) + assert not c.key_exists(key_name) + assert not c.server_exists(module_name) + return {'success': True, 'msg': 'server test passed'} \ No newline at end of file From ddbeb77c180db89ca579c1be99faaeea0575bf7c Mon Sep 17 00:00:00 2001 From: drunest <6984754+drunest@users.noreply.github.com> Date: Wed, 15 Jan 2025 18:57:36 -0600 Subject: [PATCH 23/27] ref: reonstruct subspace module --- commune/subspace/test_subspace.py | 40 +++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 commune/subspace/test_subspace.py diff --git a/commune/subspace/test_subspace.py b/commune/subspace/test_subspace.py new file mode 100644 index 000000000..a7365cf82 --- /dev/null +++ b/commune/subspace/test_subspace.py @@ -0,0 +1,40 @@ +import commune as c + +def test_global_params(): + self = c.module('subspace')() + global_params = self.global_params() + assert isinstance(global_params, dict) + return {'msg': 'global_params test passed', 'success': True} + +def test_subnet_params(subnet=0): + self = c.module('subspace')() + subnet_params = self.subnet_params(subnet=subnet) + assert isinstance(subnet_params, dict), f'{subnet_params} is not a dict' + return {'msg': 'subnet_params test passed', 'success': True} + + +def test_module_params(keys=['dividends', 'incentive'], subnet=0): + self = c.module('subspace')() + key = self.keys(subnet)[0] + module_info = self.get_module(key, subnet=subnet) + assert isinstance(module_info, dict) + for k in keys: + assert k in module_info, f'{k} not in {module_info}' + + return {'msg': 'module_params test passed', 'success': True, 'module_info': module_info} + + +def test_substrate(): + self = c.module('subspace')() + for i in range(3): + t1 = c.time() + c.print(self.substrate) + t2 = c.time() + c.print(f'{t2-t1:.2f} seconds') + return {'msg': 'substrate test passed', 'success': True} + + + + + + From 165229abce1d5e3b0c27a0a2bbfcfd2f1fbe1682 Mon Sep 17 00:00:00 2001 From: drunest <6984754+drunest@users.noreply.github.com> Date: Wed, 15 Jan 2025 18:58:05 -0600 Subject: [PATCH 24/27] ref: reconstruct validator module --- commune/vali/test_validator.py | 5 + commune/vali/vali.py | 264 +++++++++++++++++++++++++++++++++ 2 files changed, 269 insertions(+) create mode 100644 commune/vali/test_validator.py create mode 100644 commune/vali/vali.py diff --git a/commune/vali/test_validator.py b/commune/vali/test_validator.py new file mode 100644 index 000000000..c3a8112be --- /dev/null +++ b/commune/vali/test_validator.py @@ -0,0 +1,5 @@ + +import commune as c +import pandas as pd + +test_net = c.module('vali').test \ No newline at end of file diff --git a/commune/vali/vali.py b/commune/vali/vali.py new file mode 100644 index 000000000..06435509b --- /dev/null +++ b/commune/vali/vali.py @@ -0,0 +1,264 @@ + +import commune as c +import os +import pandas as pd +from typing import * + +class Vali(c.Module): + endpoints = ['score', 'scoreboard'] + voting_networks = ['bittensor', 'subspace'] + networks = ['local'] + voting_networks + epoch_time = 0 + vote_time = 0 + vote_staleness = 0 # the time since the last vote + epochs = 0 # the number of epochs + futures = [] # the futures for the parallel tasks + results = [] # the results of the parallel tasks + _clients = {} # the clients for the parallel tasks + + def __init__(self, + network= 'local', # for local subspace:test or test # for testnet subspace:main or main # for mainnet + subnet : Optional[Union[str, int]] = None, # (OPTIONAL) the name of the subnetwork + search : Optional[str] = None, # (OPTIONAL) the search string for the network + batch_size : int = 128, # the batch size of the most parallel tasks + max_workers : Optional[int]= None , # the number of parallel workers in the executor + score : Union['callable', int]= None, # score function + key : str = None, + path : str= None, # the storage path for the module eval, if not null then the module eval is stored in this directory + tempo : int = None , + timeout : int = 3, # timeout per evaluation of the module + update : bool =False, # update during the first epoch + run_loop : bool = True, # This is the key that we need to change to false + **kwargs): + + self.timeout = timeout or 3 + self.max_workers = max_workers or c.cpu_count() * 5 + self.batch_size = batch_size or 128 + self.executor = c.module('executor')(max_workers=self.max_workers, maxsize=self.batch_size) + self.set_key(key) + self.set_network(network=network, subnet=subnet, tempo=tempo, search=search, path=path, score=score, update=update) + if run_loop: + c.thread(self.run_loop) + init_vali = __init__ + + + def set_key(self, key): + self.key = c.get_key(key or self.module_name()) + return {'success': True, 'msg': 'Key set', 'key': self.key} + + def set_network(self, network:str, + subnet:str=None, + tempo:int=60, + search:str=None, + path:str=None, + score = None, + update=False): + + + if not network in self.networks and '/' not in network: + network = f'subspace/{network}' + [network, subnet] = network.split('/') if '/' in network else [network, subnet] + self.subnet = subnet + self.network = network + self.network_module = c.module(self.network)() + self.tempo = tempo + self.search = search + self.path = os.path.abspath(path or self.resolve_path(f'{network}/{subnet}' if subnet else network)) + self.is_voting_network = any([v in self.network for v in self.voting_networks]) + + self.set_score(score) + self.sync(update=update) + + + def score(self, module): + return int('name' in module.info()) + + def set_score(self, score): + if callable(score): + setattr(self, 'score', score ) + assert callable(self.score), f'SCORE NOT SET {self.score}' + return {'success': True, 'msg': 'Score function set'} + + def run_loop(self): + while True: + try: + self.epoch() + except Exception as e: + c.print('XXXXXXXXXX EPOCH ERROR ----> XXXXXXXXXX ',c.detailed_error(e), color='red') + @property + def time_until_next_epoch(self): + return int(self.epoch_time + self.tempo - c.time()) + + + def get_client(self, module:dict): + if module['key'] in self._clients: + client = self._clients[module['key']] + else: + client = c.connect(module['address'], key=self.key) + self._clients[module['key']] = client + return client + + def score_module(self, module:dict, **kwargs): + """ + module: dict + name: str + address: str + key: str + time: int + """ + + module['time'] = c.time() # the timestamp + client = self.get_client(module) + module['score'] = self.score(client, **kwargs) + module['latency'] = c.time() - module['time'] + module['path'] = self.path +'/'+ module['key'] + return module + + def score_modules(self, modules: List[dict]): + module_results = [] + futures = [self.executor.submit(self.score_module, [m], timeout=self.timeout) for m in modules] + try: + for f in c.as_completed(futures, timeout=self.timeout): + m = f.result() + if m.get('score', 0) > 0: + c.put_json(m['path'], m) + module_results.append(m) + except Exception as e: + c.print(f'ERROR({c.detailed_error(e)})', color='red', verbose=1) + + return module_results + + def epoch(self): + next_epoch = self.time_until_next_epoch + progress = c.tqdm(total=next_epoch, desc='Next Epoch') + for _ in range(next_epoch): + progress.update(1) + c.sleep(1) + self.sync() + c.print(f'Epoch(network={self.network} epoch={self.epochs} n={self.n} )', color='yellow') + batches = [self.modules[i:i+self.batch_size] for i in range(0, self.n, self.batch_size)] + progress = c.tqdm(total=len(batches), desc='Evaluating Modules') + results = [] + for i, module_batch in enumerate(batches): + print(f'Batch(i={i}/{len(batches)})') + results += self.score_modules(module_batch) + progress.update(1) + self.epochs += 1 + self.epoch_time = c.time() + print(self.scoreboard()) + self.vote(results) + return results + + def sync(self, update = False): + max_age = 0 if update else (self.tempo or 60) + self.modules = self.network_module.modules(subnet=self.subnet, max_age=max_age) + self.params = self.network_module.params(subnet=self.subnet, max_age=max_age) + self.tempo = self.tempo or (self.params['tempo'] * self.network_module.block_time)//2 + print(self.tempo) + if self.search != None: + self.modules = [m for m in self.modules if self.search in m['name']] + self.n = len(self.modules) + self.network_info = {'n': self.n, 'network': self.network , 'subnet': self.subnet, 'params': self.params} + c.print(f' 0 : + df += [{k: r.get(k, None) for k in keys}] + else : + self.rm(path) + df = c.df(df) + if len(df) > 0: + if isinstance(by, str): + by = [by] + df = df.sort_values(by=by, ascending=ascending) + # if to_dict is true, we return the dataframe as a list of dictionaries + if to_dict: + return df.to_dict(orient='records') + if len(df) > page_size: + pages = len(df)//page_size + page = page or 0 + df = df[page*page_size:(page+1)*page_size] + + return df + + def module_paths(self): + paths = self.ls(self.path) + return paths + + @classmethod + def run_epoch(cls, network='local', run_loop=False, update=False, **kwargs): + return cls(network=network, run_loop=run_loop, update=update, **kwargs).epoch() + + @staticmethod + def test( + n=2, + tag = 'vali_test_net', + miner='module', + trials = 5, + tempo = 4, + update=True, + path = '/tmp/commune/vali_test', + network='local' + ): + test_miners = [f'{miner}::{tag}{i}' for i in range(n)] + modules = test_miners + search = tag + assert len(modules) == n, f'Number of miners not equal to n {len(modules)} != {n}' + for m in modules: + c.serve(m) + namespace = c.namespace() + for m in modules: + assert m in namespace, f'Miner not in namespace {m}' + vali = Vali(network=network, search=search, path=path, update=update, tempo=tempo, run_loop=False) + print(vali.modules) + scoreboard = [] + while len(scoreboard) < n: + c.sleep(1) + scoreboard = vali.epoch() + trials -= 1 + assert trials > 0, f'Trials exhausted {trials}' + for miner in modules: + c.print(c.kill(miner)) + return {'success': True, 'msg': 'subnet test passed'} + + def refresh_scoreboard(self): + path = self.path + c.rm(path) + return {'success': True, 'msg': 'Leaderboard removed', 'path': path} \ No newline at end of file From f76452788c359674c2e33361f71c5b2b708ec159 Mon Sep 17 00:00:00 2001 From: drunest <6984754+drunest@users.noreply.github.com> Date: Wed, 15 Jan 2025 18:58:59 -0600 Subject: [PATCH 25/27] chore: use modular approach with folder --- commune/__init__.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/commune/__init__.py b/commune/__init__.py index bd85c776f..850bd9e45 100755 --- a/commune/__init__.py +++ b/commune/__init__.py @@ -1,13 +1,14 @@ +from .module import Module # the module module +M = c = Block = Agent = Module # alias c.Module as c.Block, c.Lego, c.M +from .vali.vali import Vali # the vali module +from .server.server import Server # the server module +from .client import Client # the client module +from .key.key import Key # the key module - -from .module import Module # the module module -M = c = Block = Agent = Module # alias c.Module as c.Block, c.Lego, c.M -from .vali import Vali # the vali module -from .server import Server # the server module -from .client import Client # the client module -from .key.key import Key # the key module # set the module functions as globalsw c.add_to_globals(globals()) -key = c.get_key # override key function with file key in commune/key.py TODO: remove this line with a better solution +key = ( + c.get_key +) # override key function with file key in commune/key.py TODO: remove this line with a better solution network = c.network From 2f4fdf832342f09e5dd27b0d063b1600b80d675c Mon Sep 17 00:00:00 2001 From: drunest <6984754+drunest@users.noreply.github.com> Date: Wed, 15 Jan 2025 19:00:19 -0600 Subject: [PATCH 26/27] test: use decentralized tests --- commune/network.py | 101 ------ commune/server.py | 664 ---------------------------------------- commune/vali.py | 264 ---------------- tests/test_key.py | 103 ------- tests/test_server.py | 69 ----- tests/test_subspace.py | 40 --- tests/test_validator.py | 5 - 7 files changed, 1246 deletions(-) delete mode 100644 commune/network.py delete mode 100644 commune/server.py delete mode 100644 commune/vali.py delete mode 100644 tests/test_key.py delete mode 100644 tests/test_server.py delete mode 100644 tests/test_subspace.py delete mode 100644 tests/test_validator.py diff --git a/commune/network.py b/commune/network.py deleted file mode 100644 index da3e9c674..000000000 --- a/commune/network.py +++ /dev/null @@ -1,101 +0,0 @@ -from typing import * -import os -import commune as c -class Network(c.Module): - min_stake = 0 - blocktime = block_time = 8 - n = 100 - tempo = 60 - blocks_per_day = 24*60*60/block_time - # the default - endpoints = ['namespace'] - def __init__(self, network:str='local', tempo=tempo, path=None, **kwargs): - self.set_network(network=network, tempo=tempo, path=path) - - def set_network(self, network:str, tempo:int=60, path=None, **kwargs): - self.network = network - self.tempo = tempo - self.modules_path = self.resolve_path(path or f'{self.network}/modules') - return {'network': self.network, 'tempo': self.tempo, 'modules_path': self.modules_path} - - def params(self,*args, **kwargs): - return { 'network': self.network, - 'tempo' : self.tempo, - 'n': self.n} - - def net(self): - return c.network() - - def modules(self, - search=None, - max_age=tempo, - update=False, - features=['name', 'address', 'key'], - timeout=8, - **kwargs): - modules = c.get(self.modules_path, max_age=max_age, update=update) - if modules == None: - modules = [] - addresses = ['0.0.0.0'+':'+str(p) for p in c.used_ports()] - futures = [c.submit(c.call, [s + '/info'], timeout=timeout) for s in addresses] - try: - for f in c.as_completed(futures, timeout=timeout): - data = f.result() - if all([k in data for k in features]): - modules.append({k: data[k] for k in features}) - except Exception as e: - c.print('Error getting modules', e) - modules = [] - c.put(self.modules_path, modules) - if search != None: - modules = [m for m in modules if search in m['name']] - return modules - - def namespace(self, search=None, max_age:int = tempo, update:bool = False, **kwargs) -> dict: - return {m['name']: '0.0.0.0' + ':' + m['address'].split(':')[-1] for m in self.modules(search=search, max_age=max_age, update=update)} - - def add_server(self, name:str, address:str, key:str) -> None: - data = {'name': name, 'address': address, 'key': key} - modules = self.modules() - modules.append(data) - c.put(self.modules_path, modules) - return {'success': True, 'msg': f'Block {name}.'} - - def register_from_signature(self, signature=None): - import json - assert c.verify(signature), 'Signature is not valid.' - data = json.loads(signature['data']) - return self.add_server(data['name'], data['address']) - - def remove_server(self, name:str, features=['name', 'key', 'address']) -> Dict: - modules = self.modules() - modules = [m for m in modules if not any([m[f] == name for f in features])] - c.put(self.modules_path, modules) - - def resolve_network(self, network:str) -> str: - return network or self.network - - def names(self, *args, **kwargs) -> List[str]: - return list(self.namespace(*args, **kwargs).keys()) - - def addresses(self,*args, **kwargs) -> List[str]: - return list(self.namespace(*args, **kwargs).values()) - - def servers(self, search=None, **kwargs) -> List[str]: - namespace = self.namespace(search=search,**kwargs) - return list(namespace.keys()) - - def server_exists(self, name:str, **kwargs) -> bool: - servers = self.servers(**kwargs) - return bool(name in servers) - - def networks(self) -> List[str]: - return ['local', 'subspace', 'subtensor'] - - def infos(self, *args, **kwargs) -> Dict: - return [c.call(address+'/info') for name, address in self.namespace(*args, **kwargs).items()] - -if __name__ == "__main__": - Network.run() - - diff --git a/commune/server.py b/commune/server.py deleted file mode 100644 index eec9dae89..000000000 --- a/commune/server.py +++ /dev/null @@ -1,664 +0,0 @@ -import commune as c -from typing import * -from fastapi import FastAPI, Request -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse -from starlette.middleware.base import BaseHTTPMiddleware -from sse_starlette.sse import EventSourceResponse -import uvicorn -import os -import json -import asyncio - -class Middleware(BaseHTTPMiddleware): - def __init__(self, app, max_bytes: int): - super().__init__(app) - self.max_bytes = max_bytes - async def dispatch(self, request: Request, call_next): - content_length = request.headers.get('content-length') - if content_length: - if int(content_length) > self.max_bytes: - return JSONResponse(status_code=413, content={"error": "Request too large"}) - body = await request.body() - if len(body) > self.max_bytes: - return JSONResponse(status_code=413, content={"error": "Request too large"}) - response = await call_next(request) - return response - -class Server: - tag_seperator:str='::' - user_data_lifetime = 3600 - pm2_dir = os.path.expanduser('~/.pm2') - period : int = 3600 # the period for - max_request_staleness : int = 4 # (in seconds) the time it takes for the request to be too old - max_network_staleness: int = 60 # (in seconds) the time it takes for. the network to refresh - multipliers : Dict[str, float] = {'stake': 1, 'stake_to': 1,'stake_from': 1} - rates : Dict[str, int]= {'max': 10, 'local': 10000, 'stake': 1000, 'owner': 10000, 'admin': 10000} # the maximum rate ): - helper_functions = ['info', 'metadata', 'schema', 'free', 'name', 'functions','key_address', 'crypto_type','fns', 'forward', 'rate_limit'] # the helper functions - functions_attributes =['helper_functions', 'whitelist', "whitelist_functions", 'endpoints', 'functions', 'fns', "exposed_functions",'server_functions', 'public_functions'] # the attributes for the functions - def __init__( - self, - module: Union[c.Module, object] = None, - key:str = None, # key for the server (str) - name: str = None, # the name of the server - functions:Optional[List[Union[str, callable]]] = None, # list of endpoints - port: Optional[int] = None, # the port the server is running on - network:str = 'subspace', # the network used for incentives - fn2cost : Dict[str, float] = None, # the cost of the function - free : bool = False, # if the server is free (checks signature) - kwargs : dict = None, # the kwargs for the module - crypto_type = 'sr25519', # the crypto type of the key - users_path: Optional[str] = None, # the path to the user data - serializer: str = 'serializer', # the serializer used for the data - ) -> 'Server': - module = module or 'module' - kwargs = kwargs or {} - if self.tag_seperator in str(name): - # module::fam -> module=module, name=module::fam key=module::fam (default) - module, tag = name.split(self.tag_seperator) - module = c.module(module)(**kwargs) - if isinstance(module, str): - name = name or module - module = c.module(module)(**kwargs) - print(f'Launching', module, name, functions) - # NOTE: ONLY ENABLE FREEMODE IF YOU ARE ON A CLOSED NETWORK, - self.free = free - self.module = module - self.module.name = name - self.set_key(key=key, crypto_type=crypto_type) - self.set_port(port) - self.set_network(network) - self.set_functions(functions=functions, fn2cost=fn2cost) - self.set_user_path(users_path) - self.serializer = c.module(serializer)() - self.start_server() - - def set_functions(self, functions:Optional[List[str]] , fn2cost=None): - if self.free: - c.print('THE FOUNDING FATHERS WOULD BE PROUD OF YOU SON OF A BITCH', color='red') - else: - if hasattr(self.module, 'free'): - self.free = self.module.free - self.module.free = self.free - functions = functions or [] - for i, fn in enumerate(functions): - if callable(fn): - print('Adding function', f) - setattr(self, fn.__name__, fn) - functions[i] = fn.__name__ - functions = sorted(list(set(functions + self.helper_functions))) - module = self.module - for k in self.functions_attributes: - if hasattr(module, k) and isinstance(getattr(module, k), list): - print('Found ', k) - functions = getattr(module, k) - break - # get function decorators form c.endpoint() - for f in dir(module): - try: - if hasattr(getattr(module, f), '__metadata__'): - functions.append(f) - except Exception as e: - c.print(f'Error in get_endpoints: {e} for {f}') - module.functions = sorted(list(set(functions))) - ## get the schema for the functions - schema = {} - for fn in functions : - if hasattr(module, fn): - schema[fn] = c.schema(getattr(module, fn )) - else: - print(f'Function {fn} not found in {module.name}') - module.schema = dict(sorted(schema.items())) - module.fn2cost = module.fn2cost if hasattr(module, 'fn2cost') else (fn2cost or {}) - assert isinstance(module.fn2cost, dict), f'fn2cost must be a dict, not {type(module.fn2cost)}' - - ### get the info for the module - module.info = { - "functions": functions, - "schema": schema, - "name": module.name, - "address": module.address, - "key": module.key.ss58_address, - "crypto_type": module.key.crypto_type, - "fn2cost": module.fn2cost, - "free": module.free, - "time": c.time() - } - - def set_key(self, key, crypto_type): - module = self.module - module.key = c.get_key(key or module.name, create_if_not_exists=True, crypto_type=crypto_type) - module.key_address = module.key.key_address - module.crypto_type = module.key.crypto_type - return {'success':True, 'message':f'Set key to {module.key.ss58_address}'} - - def start_server(self, - max_bytes = 10 * 1024 * 1024 , # max bytes within the request (bytes) - allow_origins = ["*"], # allowed origins - allow_credentials =True, # allow credentials - allow_methods = ["*"], # allowed methods - allow_headers = ["*"] , # allowed headers - ): - module = self.module - c.thread(self.sync_loop) - self.loop = asyncio.get_event_loop() - app = FastAPI() - app.add_middleware(Middleware, max_bytes=max_bytes) - app.add_middleware(CORSMiddleware, - allow_origins=allow_origins, - allow_credentials=allow_credentials, - allow_methods=allow_methods, - allow_headers=allow_headers) - def api_forward(fn:str, request: Request): - return self.forward(fn, request) - app.post("/{fn}")(api_forward) - c.print(f'Served(name={module.name}, address={module.address}, key={module.key.key_address})', color='purple') - c.print(c.add_server(name=module.name, address=module.address, key=module.key.ss58_address)) - self.module = module - uvicorn.run(app, host='0.0.0.0', port=module.port, loop='asyncio') - - def set_port(self, port:Optional[int]=None, port_attributes = ['port', 'server_port'], ip = None): - module = self.module - name = module.name - for k in port_attributes: - if hasattr(module, k): - port = getattr(module, k) - break - if port in [None, 'None']: - namespace = c.namespace() - if name in namespace: - c.kill(name) - try: - port = int(namespace.get(module.name).split(':')[-1]) - except: - port = c.free_port() - else: - port = c.free_port() - - while c.port_used(port): - c.kill_port(port) - c.sleep(1) - print(f'Waiting for port {port} to be free') - - module.port = port - ip = ip or '0.0.0.0' - module.address = ip + ':' + str(module.port) - self.module = module - return {'success':True, 'message':f'Set port to {port}'} - - def is_admin(self, address): - return c.is_admin(address) - - def gate(self, fn:str, data:dict, headers:dict ) -> bool: - if self.free: - assert fn in self.module.functions , f"Function {fn} not in endpoints={self.module.functions}" - return True - auth = {'data': data, 'time': str(headers['time'])} - signature = headers['signature'] - - assert c.verify(auth=auth,signature=signature, address=headers['key']), 'Invalid signature' - request_staleness = c.time() - float(headers['time']) - assert request_staleness < self.max_request_staleness, f"Request is too old ({request_staleness}s > {self.max_request_staleness}s (MAX)" - auth={'data': data, 'time': str(headers['time'])} - module = self.module - address = headers['key'] - if c.is_admin(address): - rate_limit = self.rates['admin'] - elif address == module.key.ss58_address: - rate_limit = self.rates['owner'] - elif address in self.address2key: - rate_limit = self.rates['local'] - else: - stake_score = self.state['stake'].get(address, 0) + self.multipliers['stake'] - stake_to_score = (sum(self.state['stake_to'].get(address, {}).values())) * self.multipliers['stake_to'] - stake_from_score = self.state['stake_from'].get(module.key.ss58_address, {}).get(address, 0) * self.multipliers['stake_from'] - stake = stake_score + stake_to_score + stake_from_score - self.rates['stake'] = self.rates['stake'] * module.fn2cost.get(fn, 1) - rate_limit = min((stake / self.rates['stake']), self.rates['max']) - count = self.user_call_count(headers['key']) - assert count <= rate_limit, f'rate limit exceeded {count} > {rate_limit}' - return True - - def get_data(self, request: Request): - data = self.loop.run_until_complete(request.json()) - # data = self.serializer.deserialize(data) - if isinstance(data, str): - data = json.loads(data) - if 'kwargs' in data or 'params' in data: - kwargs = dict(data.get('kwargs', data.get('params', {}))) - else: - kwargs = data - if 'args' in data: - args = list(data.get('args', [])) - else: - args = [] - data = {'args': args, 'kwargs': kwargs} - return data - - def get_headers(self, request: Request): - headers = dict(request.headers) - headers['time'] = float(headers.get('time', c.time())) - headers['key'] = headers.get('key', headers.get('address', None)) - return headers - - def forward(self, fn:str, request: Request, catch_exception:bool=True) -> dict: - if catch_exception: - try: - return self.forward(fn, request, catch_exception=False) - except Exception as e: - result = c.detailed_error(e) - c.print(result, color='red') - return result - module = self.module - data = self.get_data(request) - headers = self.get_headers(request) - self.gate(fn=fn, data=data, headers=headers) - is_admin = bool(c.is_admin(headers['key'])) - is_owner = bool(headers['key'] == module.key.ss58_address) - if hasattr(module, fn): - fn_obj = getattr(module, fn) - elif (is_admin or is_owner) and hasattr(self, fn): - fn_obj = getattr(self, fn) - else: - raise Exception(f"{fn} not found in {module.name}") - result = fn_obj(*data['args'], **data['kwargs']) if callable(fn_obj) else fn_obj - latency = c.time() - float(headers['time']) - if c.is_generator(result): - output = '' - def generator_wrapper(generator): - for item in generator: - output += str(item) - yield item - result = EventSourceResponse(generator_wrapper(result)) - else: - output = self.serializer.serialize(result) - if not self.free: - user_data = { - 'fn': fn, - 'data': data, # the data of the request - 'output': output, # the response - 'time': headers["time"], # the time of the request - 'latency': latency, # the latency of the request - 'key': headers['key'], # the key of the user - 'cost': module.fn2cost.get(fn, 1), # the cost of the function - } - user_path = self.user_path(f'{user_data["key"]}/{user_data["fn"]}/{c.time()}.json') - c.put(user_path, user_data) - return result - - def sync_loop(self, sync_loop_initial_sleep=10): - c.sleep(sync_loop_initial_sleep) - while True: - try: - r = self.sync() - except Exception as e: - r = c.detailed_error(e) - c.print('Error in sync_loop -->', r, color='red') - c.sleep(self.max_network_staleness) - - def set_network(self, network): - self.network = network - self.network_path = self.resolve_path(f'networks/{self.network}/state.json') - self.address2key = c.address2key() - c.thread(self.sync_loop) - return {'success':True, 'message':f'Set network to {network}', 'network':network, 'network_path':self.network_path} - - def sync(self, update=True , state_keys = ['stake_from', 'stake_to']): - self.network_path = self.resolve_path(f'networks/{self.network}/state.json') - print(f'Sync({self.network_path})') - if hasattr(self, 'state'): - latency = c.time() - self.state.get('time', 0) - if latency < self.max_network_staleness: - return {'msg': 'state is fresh'} - max_age = self.max_network_staleness - network_path = self.network_path - state = c.get(network_path, {}, max_age=max_age, updpate=update) - state = {} - state['stake'] = {} - state['stake_to'] = {} - state['stake_from'] = {} - if update: - try : - c.namespace(max_age=max_age) - self.subspace = c.module('subspace')(network=self.network) - state['stake_from'] = self.subspace.stake_from(fmt='j', update=update, max_age=max_age) - state['stake_to'] = self.subspace.stake_to(fmt='j', update=update, max_age=max_age) - state['stake'] = {k: sum(v.values()) for k,v in state['stake_from'].items()} - except Exception as e: - c.print(f'Error {e} while syncing network') - is_valid_state = lambda x: all([k in x for k in state_keys]) - assert is_valid_state(state), f'Format for network state is {[k for k in state_keys if k not in state]}' - c.put(network_path, state) - self.state = state - return {'msg': 'state synced successfully'} - - @classmethod - def wait_for_server(cls, - name: str , - network: str = 'local', - timeout:int = 600, - max_age = 1, - sleep_interval: int = 1) -> bool : - - time_waiting = 0 - # rotating status thing - c.print(f'waiting for {name} to start...', color='cyan') - - while time_waiting < timeout: - namespace = c.namespace(network=network, max_age=max_age) - if name in namespace: - try: - result = c.call(namespace[name]+'/info') - if 'key' in result: - c.print(f'{name} is running', color='green') - return result - except Exception as e: - c.print(f'Error getting info for {name} --> {e}', color='red') - c.sleep(sleep_interval) - time_waiting += sleep_interval - raise TimeoutError(f'Waited for {timeout} seconds for {name} to start') - - @classmethod - def endpoint(cls, - cost = 1, - user2rate : dict = None, - rate_limit : int = 100, # calls per minute - timestale : int = 60, - public:bool = False, - **kwargs): - def decorator_fn(fn): - metadata = { - 'schema':c.schema(fn), - 'cost': cost, - 'rate_limit': rate_limit, - 'user2rate': user2rate, - 'timestale': timestale, - 'public': public, - } - fn.__dict__['__metadata__'] = metadata - return fn - return decorator_fn - - serverfn = endpoint - - @classmethod - def kill(cls, name:str, verbose:bool = True, **kwargs): - try: - if name == 'all': - return cls.kill_all(verbose=verbose) - c.cmd(f"pm2 delete {name}", verbose=False) - cls.rm_logs(name) - result = {'message':f'Killed {name}', 'success':True} - except Exception as e: - result = {'message':f'Error killing {name}', 'success':False, 'error':e} - - c.remove_server(name) - return result - - @classmethod - def kill_all_processes(cls, verbose:bool = True, timeout=20): - servers = cls.processes() - futures = [c.submit(cls.kill, kwargs={'name':s, 'update': False}, return_future=True) for s in servers] - results = c.wait(futures, timeout=timeout) - - return results - - @classmethod - def kill_all_servers(cls, network='local', timeout=20, verbose=True): - servers = c.servers(network=network) - futures = [c.submit(cls.kill, kwargs={'module':s, 'update': False}, return_future=True) for s in servers] - return c.wait(futures, timeout=timeout) - - @classmethod - def kill_all(cls, mode='process', verbose:bool = True, timeout=20): - if mode == 'process': - results = cls.kill_all_processes(verbose=verbose, timeout=timeout) - elif mode == 'server': - results = cls.kill_all_servers(verbose=verbose, timeout=timeout) - else: - raise NotImplementedError(f'mode {mode} not implemented') - c.namespace(update=True) - return results - - @classmethod - def killall(cls, **kwargs): - return cls.kill_all(**kwargs) - - @classmethod - def logs_path_map(cls, name=None): - logs_path_map = {} - for l in c.ls(f'{cls.pm2_dir}/logs/'): - key = '-'.join(l.split('/')[-1].split('-')[:-1]).replace('-',':') - logs_path_map[key] = logs_path_map.get(key, []) + [l] - for k in logs_path_map.keys(): - logs_path_map[k] = {l.split('-')[-1].split('.')[0]: l for l in list(logs_path_map[k])} - if name != None: - return logs_path_map.get(name, {}) - return logs_path_map - - @classmethod - def rm_logs( cls, name): - logs_map = cls.logs_path_map(name) - for k in logs_map.keys(): - c.rm(logs_map[k]) - - @classmethod - def logs(cls, module:str, tail: int =100, mode: str ='cmd', **kwargs): - - if mode == 'local': - text = '' - for m in ['out','error']: - # I know, this is fucked - path = f'{cls.pm2_dir}/logs/{module.replace("/", "-")}-{m}.log'.replace(':', '-').replace('_', '-') - try: - text += c.get_text(path, tail=tail) - except Exception as e: - c.print('ERROR GETTING LOGS -->' , e) - continue - return text - elif mode == 'cmd': - return c.cmd(f"pm2 logs {module}", verbose=True) - else: - raise NotImplementedError(f'mode {mode} not implemented') - - def get_logs(self, tail=100, mode='local'): - return self.logs(self.module.name, tail=tail, mode=mode) - - @classmethod - def kill_many(cls, search=None, verbose:bool = True, timeout=10): - futures = [] - for name in c.servers(search=search): - f = c.submit(c.kill, dict(name=name, verbose=verbose), return_future=True, timeout=timeout) - futures.append(f) - return c.wait(futures) - - @classmethod - def start_process(cls, - fn: str = 'serve', - module:str = None, - name:Optional[str]=None, - args : list = None, - kwargs: dict = None, - interpreter:str='python3', - autorestart: bool = True, - verbose: bool = False , - force:bool = True, - run_fn: str = 'run_fn', - cwd : str = None, - env : Dict[str, str] = None, - refresh:bool=True , - **extra_kwargs): - env = env or {} - if '/' in fn: - module, fn = fn.split('/') - module = module or cls.module_name() - name = name or module - if refresh: - cls.kill(name) - cmd = f"pm2 start {c.filepath()} --name {name} --interpreter {interpreter}" - cmd = cmd if autorestart else ' --no-autorestart' - cmd = cmd + ' -f ' if force else cmd - kwargs = {'module': module , 'fn': fn, 'args': args or [], 'kwargs': kwargs or {} } - kwargs_str = json.dumps(kwargs).replace('"', "'") - cmd = cmd + f' -- --fn {run_fn} --kwargs "{kwargs_str}"' - stdout = c.cmd(cmd, env=env, verbose=verbose, cwd=cwd) - return {'success':True, 'msg':f'Launched {module}', 'cmd': cmd, 'stdout':stdout} - remote_fn = launch = start_process - - @classmethod - def restart(cls, name:str): - assert name in cls.processes() - c.print(f'Restarting {name}', color='cyan') - c.cmd(f"pm2 restart {name}", verbose=False) - cls.rm_logs(name) - return {'success':True, 'message':f'Restarted {name}'} - - @classmethod - def processes(cls, search=None, **kwargs) -> List[str]: - output_string = c.cmd('pm2 status', verbose=False) - module_list = [] - for line in output_string.split('\n')[3:]: - if line.count('│') > 2: - name = line.split('│')[2].strip() - module_list += [name] - if search != None: - module_list = [m for m in module_list if search in m] - module_list = sorted(list(set(module_list))) - return module_list - - @classmethod - def procs(cls, **kwargs): - return cls.processes(**kwargs) - - - @classmethod - def process_exists(cls, name:str, **kwargs) -> bool: - return name in cls.processes(**kwargs) - - @classmethod - def serve(cls, - module: Any = None, - kwargs:Optional[dict] = None, # kwargs for the module - port :Optional[int] = None, # name of the server if None, it will be the module name - name = None, # name of the server if None, it will be the module name - remote:bool = True, # runs the server remotely (pm2, ray) - functions = None, # list of functions to serve, if none, it will be the endpoints of the module - key = None, # the key for the server - free = False, - cwd = None, - **extra_kwargs - ): - module = module or 'module' - name = name or module - kwargs = {**(kwargs or {}), **extra_kwargs} - c.print(f'Serving(module={module} params={kwargs} name={name} function={functions})') - if not isinstance(module, str): - remote = False - if remote: - rkwargs = {k : v for k, v in c.locals2kwargs(locals()).items() if k not in ['extra_kwargs', 'response', 'namespace']} - rkwargs['remote'] = False - cls.start_process(fn='serve', name=name, kwargs=rkwargs, cwd=cwd) - return cls.wait_for_server(name) - return Server(module=module, name=name, functions = functions, kwargs=kwargs, port=port, key=key, free = free) - - def extract_time(self, x): - try: - x = float(x.split('/')[-1].split('.')[0]) - except Exception as e: - x = 0 - return x - - @classmethod - def fleet(cls, module, n=10, timeout=10): - futures = [ c.submit(c.serve, {'module':module + '::' + str(i)}, timeout=timeout) for i in range(n)] - progress = c.progress(futures) - results = [] - for f in c.as_completed(futures, timeout=timeout): - r = f.result() - results.append(r) - progress.update() - return results - def remove_user_data(self, address): - return c.rm(self.user_path(address)) - - def users(self): - return os.listdir(self.users_path) - - def user2count(self): - user2count = {} - for user in self.users(): - user2count[user] = self.user_call_count(user) - return user2count - - def history(self, user): - return self.user_data(user) - - def user2fn2count(self): - user2fn2count = {} - for user in self.users(): - user2fn2count[user] = {} - for user_data in self.user_data(user): - fn = user_data['fn'] - user2fn2count[user][fn] = user2fn2count[user].get(fn, 0) + 1 - return user2fn2count - - def user_call_paths(self, address ): - user_paths = c.glob(self.user_path(address)) - return sorted(user_paths, key=self.extract_time) - - def user_data(self, address, stream=False): - user_paths = self.user_call_paths(address) - if stream: - def stream_fn(): - for user_path in user_paths: - yield c.get(user_path) - return stream_fn() - - else: - return [c.get(user_path) for user_path in user_paths] - - def user_path(self, key_address): - return self.users_path + '/' + key_address - - def user_call_count(self, user): - self.check_user_data(user) - return len(self.user_call_paths(user)) - - def users(self): - return os.listdir(self.users_path) - - def user_path2time(self, address): - user_paths = self.user_call_paths(address) - user_path2time = {user_path: self.extract_time(user_path) for user_path in user_paths} - return user_path2time - - def user_call_path2latency(self, address): - user_paths = self.user_call_paths(address) - t0 = c.time() - user_path2time = {user_path: t0 - self.extract_time(user_path) for user_path in user_paths} - return user_path2time - - def check_user_data(self, address): - path2latency = self.user_call_path2latency(address) - for path, latency in path2latency.items(): - if latency > self.user_data_lifetime: - c.print(f'Removing stale path {path} ({latency}/{self.period})') - if os.path.exists(path): - os.remove(path) - - def resolve_path(self, path): - return c.resolve_path(path, storage_dir=self.storage_dir()) - - def set_user_path(self, users_path): - self.users_path = users_path or self.resolve_path(f'users/{self.module.name}') - - - def add_endpoint(self, name, fn): - setattr(self, name, fn) - self.endpoints.append(name) - assert hasattr(self, name), f'{name} not added to {self.__class__.__name__}' - return {'success':True, 'message':f'Added {fn} to {self.__class__.__name__}'} - -if __name__ == '__main__': - Server.run() - diff --git a/commune/vali.py b/commune/vali.py deleted file mode 100644 index 06435509b..000000000 --- a/commune/vali.py +++ /dev/null @@ -1,264 +0,0 @@ - -import commune as c -import os -import pandas as pd -from typing import * - -class Vali(c.Module): - endpoints = ['score', 'scoreboard'] - voting_networks = ['bittensor', 'subspace'] - networks = ['local'] + voting_networks - epoch_time = 0 - vote_time = 0 - vote_staleness = 0 # the time since the last vote - epochs = 0 # the number of epochs - futures = [] # the futures for the parallel tasks - results = [] # the results of the parallel tasks - _clients = {} # the clients for the parallel tasks - - def __init__(self, - network= 'local', # for local subspace:test or test # for testnet subspace:main or main # for mainnet - subnet : Optional[Union[str, int]] = None, # (OPTIONAL) the name of the subnetwork - search : Optional[str] = None, # (OPTIONAL) the search string for the network - batch_size : int = 128, # the batch size of the most parallel tasks - max_workers : Optional[int]= None , # the number of parallel workers in the executor - score : Union['callable', int]= None, # score function - key : str = None, - path : str= None, # the storage path for the module eval, if not null then the module eval is stored in this directory - tempo : int = None , - timeout : int = 3, # timeout per evaluation of the module - update : bool =False, # update during the first epoch - run_loop : bool = True, # This is the key that we need to change to false - **kwargs): - - self.timeout = timeout or 3 - self.max_workers = max_workers or c.cpu_count() * 5 - self.batch_size = batch_size or 128 - self.executor = c.module('executor')(max_workers=self.max_workers, maxsize=self.batch_size) - self.set_key(key) - self.set_network(network=network, subnet=subnet, tempo=tempo, search=search, path=path, score=score, update=update) - if run_loop: - c.thread(self.run_loop) - init_vali = __init__ - - - def set_key(self, key): - self.key = c.get_key(key or self.module_name()) - return {'success': True, 'msg': 'Key set', 'key': self.key} - - def set_network(self, network:str, - subnet:str=None, - tempo:int=60, - search:str=None, - path:str=None, - score = None, - update=False): - - - if not network in self.networks and '/' not in network: - network = f'subspace/{network}' - [network, subnet] = network.split('/') if '/' in network else [network, subnet] - self.subnet = subnet - self.network = network - self.network_module = c.module(self.network)() - self.tempo = tempo - self.search = search - self.path = os.path.abspath(path or self.resolve_path(f'{network}/{subnet}' if subnet else network)) - self.is_voting_network = any([v in self.network for v in self.voting_networks]) - - self.set_score(score) - self.sync(update=update) - - - def score(self, module): - return int('name' in module.info()) - - def set_score(self, score): - if callable(score): - setattr(self, 'score', score ) - assert callable(self.score), f'SCORE NOT SET {self.score}' - return {'success': True, 'msg': 'Score function set'} - - def run_loop(self): - while True: - try: - self.epoch() - except Exception as e: - c.print('XXXXXXXXXX EPOCH ERROR ----> XXXXXXXXXX ',c.detailed_error(e), color='red') - @property - def time_until_next_epoch(self): - return int(self.epoch_time + self.tempo - c.time()) - - - def get_client(self, module:dict): - if module['key'] in self._clients: - client = self._clients[module['key']] - else: - client = c.connect(module['address'], key=self.key) - self._clients[module['key']] = client - return client - - def score_module(self, module:dict, **kwargs): - """ - module: dict - name: str - address: str - key: str - time: int - """ - - module['time'] = c.time() # the timestamp - client = self.get_client(module) - module['score'] = self.score(client, **kwargs) - module['latency'] = c.time() - module['time'] - module['path'] = self.path +'/'+ module['key'] - return module - - def score_modules(self, modules: List[dict]): - module_results = [] - futures = [self.executor.submit(self.score_module, [m], timeout=self.timeout) for m in modules] - try: - for f in c.as_completed(futures, timeout=self.timeout): - m = f.result() - if m.get('score', 0) > 0: - c.put_json(m['path'], m) - module_results.append(m) - except Exception as e: - c.print(f'ERROR({c.detailed_error(e)})', color='red', verbose=1) - - return module_results - - def epoch(self): - next_epoch = self.time_until_next_epoch - progress = c.tqdm(total=next_epoch, desc='Next Epoch') - for _ in range(next_epoch): - progress.update(1) - c.sleep(1) - self.sync() - c.print(f'Epoch(network={self.network} epoch={self.epochs} n={self.n} )', color='yellow') - batches = [self.modules[i:i+self.batch_size] for i in range(0, self.n, self.batch_size)] - progress = c.tqdm(total=len(batches), desc='Evaluating Modules') - results = [] - for i, module_batch in enumerate(batches): - print(f'Batch(i={i}/{len(batches)})') - results += self.score_modules(module_batch) - progress.update(1) - self.epochs += 1 - self.epoch_time = c.time() - print(self.scoreboard()) - self.vote(results) - return results - - def sync(self, update = False): - max_age = 0 if update else (self.tempo or 60) - self.modules = self.network_module.modules(subnet=self.subnet, max_age=max_age) - self.params = self.network_module.params(subnet=self.subnet, max_age=max_age) - self.tempo = self.tempo or (self.params['tempo'] * self.network_module.block_time)//2 - print(self.tempo) - if self.search != None: - self.modules = [m for m in self.modules if self.search in m['name']] - self.n = len(self.modules) - self.network_info = {'n': self.n, 'network': self.network , 'subnet': self.subnet, 'params': self.params} - c.print(f' 0 : - df += [{k: r.get(k, None) for k in keys}] - else : - self.rm(path) - df = c.df(df) - if len(df) > 0: - if isinstance(by, str): - by = [by] - df = df.sort_values(by=by, ascending=ascending) - # if to_dict is true, we return the dataframe as a list of dictionaries - if to_dict: - return df.to_dict(orient='records') - if len(df) > page_size: - pages = len(df)//page_size - page = page or 0 - df = df[page*page_size:(page+1)*page_size] - - return df - - def module_paths(self): - paths = self.ls(self.path) - return paths - - @classmethod - def run_epoch(cls, network='local', run_loop=False, update=False, **kwargs): - return cls(network=network, run_loop=run_loop, update=update, **kwargs).epoch() - - @staticmethod - def test( - n=2, - tag = 'vali_test_net', - miner='module', - trials = 5, - tempo = 4, - update=True, - path = '/tmp/commune/vali_test', - network='local' - ): - test_miners = [f'{miner}::{tag}{i}' for i in range(n)] - modules = test_miners - search = tag - assert len(modules) == n, f'Number of miners not equal to n {len(modules)} != {n}' - for m in modules: - c.serve(m) - namespace = c.namespace() - for m in modules: - assert m in namespace, f'Miner not in namespace {m}' - vali = Vali(network=network, search=search, path=path, update=update, tempo=tempo, run_loop=False) - print(vali.modules) - scoreboard = [] - while len(scoreboard) < n: - c.sleep(1) - scoreboard = vali.epoch() - trials -= 1 - assert trials > 0, f'Trials exhausted {trials}' - for miner in modules: - c.print(c.kill(miner)) - return {'success': True, 'msg': 'subnet test passed'} - - def refresh_scoreboard(self): - path = self.path - c.rm(path) - return {'success': True, 'msg': 'Leaderboard removed', 'path': path} \ No newline at end of file diff --git a/tests/test_key.py b/tests/test_key.py deleted file mode 100644 index badef093e..000000000 --- a/tests/test_key.py +++ /dev/null @@ -1,103 +0,0 @@ - -import commune as c -crypto_type='solana' -def test_encryption(values = [10, 'fam', 'hello world']): - cls = c.module('key') - for value in values: - value = str(value) - key = cls.new_key(crypto_type=crypto_type) - enc = key.encrypt(value) - dec = key.decrypt(enc) - assert dec == value, f'encryption failed, {dec} != {value}' - return {'encrypted':enc, 'decrypted': dec} - -def test_encryption_with_password(value = 10, password = 'fam'): - cls = c.module('key') - value = str(value) - key = cls.new_key(crypto_type=crypto_type) - enc = key.encrypt(value, password=password) - dec = key.decrypt(enc, password=password) - assert dec == value, f'encryption failed, {dec} != {value}' - return {'encrypted':enc, 'decrypted': dec} - -def test_key_encryption(test_key='test.key'): - self = c.module('key') - key = self.add_key(test_key, refresh=True, crypto_type=crypto_type) - og_key = self.get_key(test_key, crypto_type=crypto_type) - r = self.encrypt_key(test_key) - self.decrypt_key(test_key, password=r['password']) - key = self.get_key(test_key, crypto_type=crypto_type) - assert key.ss58_address == og_key.ss58_address, f'key encryption failed, {key.ss58_address} != {self.ss58_address}' - return {'success': True, 'msg': 'test_key_encryption passed'} - -def test_key_management(key1='test.key' , key2='test2.key'): - self = c.module('key') - if self.key_exists(key1): - self.rm_key(key1) - if self.key_exists(key2): - self.rm_key(key2) - self.add_key(key1, crypto_type=crypto_type) - k1 = self.get_key(key1, crypto_type=crypto_type) - assert self.key_exists(key1), f'Key management failed, key still exists' - self.mv_key(key1, key2, crypto_type=crypto_type) - k2 = self.get_key(key2, crypto_type=crypto_type) - assert k1.ss58_address == k2.ss58_address, f'Key management failed, {k1.ss58_address} != {k2.ss58_address}' - assert self.key_exists(key2), f'Key management failed, key does not exist' - assert not self.key_exists(key1), f'Key management failed, key still exists' - self.mv_key(key2, key1, crypto_type=crypto_type) - assert self.key_exists(key1), f'Key management failed, key does not exist' - assert not self.key_exists(key2), f'Key management failed, key still exists' - self.rm_key(key1) - # self.rm_key(key2) - assert not self.key_exists(key1), f'Key management failed, key still exists' - assert not self.key_exists(key2), f'Key management failed, key still exists' - return {'success': True, 'msg': 'test_key_management passed'} - - -def test_signing(): - self = c.module('key')(crypto_type=crypto_type) - sig = self.sign('test') - assert self.verify('test',sig, self.public_key) - return {'success':True} - -def test_key_encryption(password='1234'): - cls = c.module('key') - path = 'test.enc' - cls.add_key('test.enc', refresh=True, crypto_type=crypto_type) - assert cls.is_key_encrypted(path) == False, f'file {path} is encrypted' - cls.encrypt_key(path, password=password) - assert cls.is_key_encrypted(path) == True, f'file {path} is not encrypted' - cls.decrypt_key(path, password=password) - assert cls.is_key_encrypted(path) == False, f'file {path} is encrypted' - cls.rm(path) - print('file deleted', path, c.exists, 'fam') - assert not c.exists(path), f'file {path} not deleted' - return {'success': True, 'msg': 'test_key_encryption passed'} - -def test_move_key(): - self = c.module('key')(crypto_type=crypto_type) - self.add_key('testfrom', crypto_type=crypto_type) - assert self.key_exists('testfrom') - og_key = self.get_key('testfrom', crypto_type=crypto_type) - self.mv_key('testfrom', 'testto', crypto_type=crypto_type) - assert self.key_exists('testto', crypto_type=crypto_type) - assert not self.key_exists('testfrom') - new_key = self.get_key('testto', crypto_type=crypto_type) - assert og_key.ss58_address == new_key.ss58_address - self.rm_key('testto') - assert not self.key_exists('testto') - return {'success':True, 'msg':'test_move_key passed', 'key':new_key.ss58_address} - - -def test_ss58_encoding(): - self = c.module('key') - keypair = self.create_from_uri('//Alice') - ss58_address = keypair.ss58_address - public_key = keypair.public_key - assert keypair.ss58_address == self.ss58_encode(public_key, ss58_format=42) - assert keypair.ss58_address == self.ss58_encode(public_key, ss58_format=42) - assert keypair.public_key.hex() == self.ss58_decode(ss58_address) - assert keypair.public_key.hex() == self.ss58_decode(ss58_address) - return {'success':True} - - diff --git a/tests/test_server.py b/tests/test_server.py deleted file mode 100644 index a8e3fa1f2..000000000 --- a/tests/test_server.py +++ /dev/null @@ -1,69 +0,0 @@ - - -import commune as c - -def test(): - self = c.module('serializer')() - import torch, time - data_list = [ - torch.ones(1000), - torch.zeros(1000), - torch.rand(1000), - [1,2,3,4,5], - {'a':1, 'b':2, 'c':3}, - 'hello world', - c.df([{'name': 'joe', 'fam': 1}]), - 1, - 1.0, - True, - False, - None - - ] - for data in data_list: - t1 = time.time() - ser_data = self.serialize(data) - des_data = self.deserialize(ser_data) - des_ser_data = self.serialize(des_data) - t2 = time.time() - - latency = t2 - t1 - emoji = '✅' if str(des_ser_data) == str(ser_data) else '❌' - print(type(data),emoji) - return {'msg': 'PASSED test_serialize_deserialize'} - - -def test_basics() -> dict: - servers = c.servers() - c.print(servers) - name = f'module::test' - c.serve(name) - c.kill(name) - assert name not in c.servers() - return {'success': True, 'msg': 'server test passed'} - -def test_serving(name = 'module::test'): - module = c.serve(name) - module = c.connect(name) - r = module.info() - assert 'name' in r, f"get failed {r}" - c.kill(name) - assert name not in c.servers(update=1) - return {'success': True, 'msg': 'server test passed'} - -def test_serving_with_different_key(module = 'module', timeout=10): - tag = 'test_serving_with_different_key' - key_name = module + '::'+ tag - module_name = module + '::'+ tag + '_b' - if not c.key_exists(key_name): - key = c.add_key(key_name) - c.print(c.serve(module_name, key=key_name)) - key = c.get_key(key_name) - c.sleep(2) - info = c.call(f'{module_name}/info', timeout=2) - assert info.get('key', None) == key.ss58_address , f" {info}" - c.kill(module_name) - c.rm_key(key_name) - assert not c.key_exists(key_name) - assert not c.server_exists(module_name) - return {'success': True, 'msg': 'server test passed'} \ No newline at end of file diff --git a/tests/test_subspace.py b/tests/test_subspace.py deleted file mode 100644 index a7365cf82..000000000 --- a/tests/test_subspace.py +++ /dev/null @@ -1,40 +0,0 @@ -import commune as c - -def test_global_params(): - self = c.module('subspace')() - global_params = self.global_params() - assert isinstance(global_params, dict) - return {'msg': 'global_params test passed', 'success': True} - -def test_subnet_params(subnet=0): - self = c.module('subspace')() - subnet_params = self.subnet_params(subnet=subnet) - assert isinstance(subnet_params, dict), f'{subnet_params} is not a dict' - return {'msg': 'subnet_params test passed', 'success': True} - - -def test_module_params(keys=['dividends', 'incentive'], subnet=0): - self = c.module('subspace')() - key = self.keys(subnet)[0] - module_info = self.get_module(key, subnet=subnet) - assert isinstance(module_info, dict) - for k in keys: - assert k in module_info, f'{k} not in {module_info}' - - return {'msg': 'module_params test passed', 'success': True, 'module_info': module_info} - - -def test_substrate(): - self = c.module('subspace')() - for i in range(3): - t1 = c.time() - c.print(self.substrate) - t2 = c.time() - c.print(f'{t2-t1:.2f} seconds') - return {'msg': 'substrate test passed', 'success': True} - - - - - - diff --git a/tests/test_validator.py b/tests/test_validator.py deleted file mode 100644 index c3a8112be..000000000 --- a/tests/test_validator.py +++ /dev/null @@ -1,5 +0,0 @@ - -import commune as c -import pandas as pd - -test_net = c.module('vali').test \ No newline at end of file From 590a0c9cb00d16dc2cdb6cdedc657a04ed405c14 Mon Sep 17 00:00:00 2001 From: drunest <6984754+drunest@users.noreply.github.com> Date: Fri, 17 Jan 2025 07:45:13 -0600 Subject: [PATCH 27/27] fix: some errors occured by conflict --- commune/__init__.py | 2 +- commune/key/key.py | 6 +++--- commune/server/test_server.py | 2 +- tests/test_vali.py | 5 ----- 4 files changed, 5 insertions(+), 10 deletions(-) delete mode 100644 tests/test_vali.py diff --git a/commune/__init__.py b/commune/__init__.py index b9b2ae53d..4b8928925 100755 --- a/commune/__init__.py +++ b/commune/__init__.py @@ -3,7 +3,7 @@ M = c = Block = Agent = Module # alias c.Module as c.Block, c.Lego, c.M from .vali.vali import Vali # the vali module from .server.server import Server # the server module -from .client import Client # the client module +from .server.client import Client # the client module from .key.key import Key # the key module # set the module functions as globalsw diff --git a/commune/key/key.py b/commune/key/key.py index 9dbe5ffe9..55e48f6a2 100644 --- a/commune/key/key.py +++ b/commune/key/key.py @@ -84,8 +84,8 @@ def set_crypto_type(self, crypto_type): def set_private_key(self, private_key: Union[bytes, str] = None, - ss58_format: int = ss58_format, - crypto_type: int = crypto_type, + ss58_format: int = SS58_FORMAT, + crypto_type: int = KeyType.SR25519, derive_path: str = None, path:str = None, **kwargs @@ -588,7 +588,7 @@ def create_from_password(cls, password: str, crypto_type: Union[str, int] = KeyT str2key = pwd2key = password2key = from_password = create_from_password @classmethod - def create_from_uri + def create_from_uri( cls, suri: str, ss58_format: int = SS58_FORMAT, diff --git a/commune/server/test_server.py b/commune/server/test_server.py index 3719fc1a6..140f40ed7 100644 --- a/commune/server/test_server.py +++ b/commune/server/test_server.py @@ -2,7 +2,7 @@ import commune as c -def test(): +def test_serializer(): self = c.module('serializer')() import torch, time data_list = [ diff --git a/tests/test_vali.py b/tests/test_vali.py deleted file mode 100644 index c3a8112be..000000000 --- a/tests/test_vali.py +++ /dev/null @@ -1,5 +0,0 @@ - -import commune as c -import pandas as pd - -test_net = c.module('vali').test \ No newline at end of file