diff --git a/cashu/core/htlc.py b/cashu/core/htlc.py deleted file mode 100644 index bb4968bae..000000000 --- a/cashu/core/htlc.py +++ /dev/null @@ -1,12 +0,0 @@ -from .p2pk import P2PKSecret -from .secret import Secret, SecretKind - - -# HTLCSecret inherits properties from P2PKSecret -class HTLCSecret(P2PKSecret, Secret): - @classmethod - def from_secret(cls, secret: Secret): - assert SecretKind(secret.kind) == SecretKind.HTLC, "Secret is not a HTLC secret" - # NOTE: exclude tags in .dict() because it doesn't deserialize it properly - # need to add it back in manually with tags=secret.tags - return cls(**secret.dict(exclude={"tags"}), tags=secret.tags) diff --git a/cashu/core/nuts/nut11.py b/cashu/core/nuts/nut11.py new file mode 100644 index 000000000..2726aa064 --- /dev/null +++ b/cashu/core/nuts/nut11.py @@ -0,0 +1,115 @@ +import hashlib +from datetime import datetime, timedelta +from enum import Enum +from typing import List, Optional, Union + +from loguru import logger + +from ..base import BlindedMessage, Proof +from ..crypto.secp import PrivateKey, PublicKey +from ..secret import Secret, SecretKind, Tags + + +class SigFlags(Enum): + # require signatures only on the inputs (default signature flag) + SIG_INPUTS = "SIG_INPUTS" + # require signatures on inputs and outputs + SIG_ALL = "SIG_ALL" + + +class P2PKSecret(Secret): + @classmethod + def from_secret(cls, secret: Secret): + assert SecretKind(secret.kind) == SecretKind.P2PK, "Secret is not a P2PK secret" + # NOTE: exclude tags in .dict() because it doesn't deserialize it properly + # need to add it back in manually with tags=secret.tags + return cls(**secret.dict(exclude={"tags"}), tags=secret.tags) + + @property + def locktime(self) -> Union[None, int]: + locktime = self.tags.get_tag("locktime") + return int(locktime) if locktime else None + + @property + def sigflag(self) -> SigFlags: + sigflag = self.tags.get_tag("sigflag") + return SigFlags(sigflag) if sigflag else SigFlags.SIG_INPUTS + + @property + def n_sigs(self) -> int: + n_sigs = self.tags.get_tag_int("n_sigs") + return int(n_sigs) if n_sigs else 1 + + @property + def n_sigs_refund(self) -> Union[None, int]: + n_sigs_refund = self.tags.get_tag_int("n_sigs_refund") + return n_sigs_refund + + +def schnorr_sign(message: bytes, private_key: PrivateKey) -> bytes: + signature = private_key.schnorr_sign( + hashlib.sha256(message).digest(), None, raw=True + ) + return signature + + +def verify_schnorr_signature( + message: bytes, pubkey: PublicKey, signature: bytes +) -> bool: + return pubkey.schnorr_verify( + hashlib.sha256(message).digest(), signature, None, raw=True + ) + + +def create_p2pk_lock( + self, + data: str, + locktime_seconds: Optional[int] = None, + tags: Optional[Tags] = None, + sig_all: bool = False, + n_sigs: int = 1, +) -> P2PKSecret: + """Generate a P2PK secret with the given pubkeys, locktime, tags, and signature flag. + + Args: + data (str): Public key to lock to. + locktime_seconds (Optional[int], optional): Locktime in seconds. Defaults to None. + tags (Optional[Tags], optional): Tags to add to the secret. Defaults to None. + sig_all (bool, optional): Whether to use SIG_ALL spending condition. Defaults to False. + n_sigs (int, optional): Number of signatures required. Defaults to 1. + + Returns: + P2PKSecret: P2PK secret with the given pubkeys, locktime, tags, and signature flag. + """ + logger.debug(f"Provided tags: {tags}") + if not tags: + tags = Tags() + logger.debug(f"Before tags: {tags}") + if locktime_seconds: + tags["locktime"] = str( + int((datetime.now() + timedelta(seconds=locktime_seconds)).timestamp()) + ) + tags["sigflag"] = SigFlags.SIG_ALL.value if sig_all else SigFlags.SIG_INPUTS.value + if n_sigs > 1: + tags["n_sigs"] = str(n_sigs) + logger.debug(f"After tags: {tags}") + return P2PKSecret( + kind=SecretKind.P2PK.value, + data=data, + tags=tags, + ) + + +def sigall_message_to_sign(proofs: List[Proof], outputs: List[BlindedMessage]) -> str: + """ + Creates the message to sign for sigall spending conditions. + The message is a concatenation of all proof secrets and signatures + all output attributes (amount, id, B_). + """ + + # Concatenate all proof secrets + message = "".join([p.secret + p.C for p in proofs]) + + # Concatenate all output attributes + message += "".join([str(o.amount) + o.id + o.B_ for o in outputs]) + + return message diff --git a/cashu/core/nuts/nut14.py b/cashu/core/nuts/nut14.py index 5b7ce9898..579d22f8a 100644 --- a/cashu/core/nuts/nut14.py +++ b/cashu/core/nuts/nut14.py @@ -3,8 +3,17 @@ from ..base import Proof from ..errors import TransactionError -from ..htlc import HTLCSecret from ..secret import Secret, SecretKind +from .nut11 import P2PKSecret + + +class HTLCSecret(P2PKSecret, Secret): + @classmethod + def from_secret(cls, secret: Secret): + assert SecretKind(secret.kind) == SecretKind.HTLC, "Secret is not a HTLC secret" + # NOTE: exclude tags in .dict() because it doesn't deserialize it properly + # need to add it back in manually with tags=secret.tags + return cls(**secret.dict(exclude={"tags"}), tags=secret.tags) def verify_htlc_spending_conditions( @@ -27,11 +36,10 @@ def verify_htlc_spending_conditions( try: if len(proof.htlcpreimage) != 64: raise TransactionError("HTLC preimage must be 64 characters hex.") - if not sha256( - bytes.fromhex(proof.htlcpreimage) - ).digest() == bytes.fromhex(htlc_secret.data): + if not sha256(bytes.fromhex(proof.htlcpreimage)).digest() == bytes.fromhex( + htlc_secret.data + ): raise TransactionError("HTLC preimage does not match.") except ValueError: raise TransactionError("invalid preimage for HTLC: not a hex string.") return True - diff --git a/cashu/core/p2pk.py b/cashu/core/p2pk.py deleted file mode 100644 index 79afab108..000000000 --- a/cashu/core/p2pk.py +++ /dev/null @@ -1,57 +0,0 @@ -import hashlib -from enum import Enum -from typing import Union - -from .crypto.secp import PrivateKey, PublicKey -from .secret import Secret, SecretKind - - -class SigFlags(Enum): - # require signatures only on the inputs (default signature flag) - SIG_INPUTS = "SIG_INPUTS" - # require signatures on inputs and outputs - SIG_ALL = "SIG_ALL" - - -class P2PKSecret(Secret): - @classmethod - def from_secret(cls, secret: Secret): - assert SecretKind(secret.kind) == SecretKind.P2PK, "Secret is not a P2PK secret" - # NOTE: exclude tags in .dict() because it doesn't deserialize it properly - # need to add it back in manually with tags=secret.tags - return cls(**secret.dict(exclude={"tags"}), tags=secret.tags) - - @property - def locktime(self) -> Union[None, int]: - locktime = self.tags.get_tag("locktime") - return int(locktime) if locktime else None - - @property - def sigflag(self) -> SigFlags: - sigflag = self.tags.get_tag("sigflag") - return SigFlags(sigflag) if sigflag else SigFlags.SIG_INPUTS - - @property - def n_sigs(self) -> int: - n_sigs = self.tags.get_tag_int("n_sigs") - return int(n_sigs) if n_sigs else 1 - - @property - def n_sigs_refund(self) -> Union[None, int]: - n_sigs_refund = self.tags.get_tag_int("n_sigs_refund") - return n_sigs_refund - - -def schnorr_sign(message: bytes, private_key: PrivateKey) -> bytes: - signature = private_key.schnorr_sign( - hashlib.sha256(message).digest(), None, raw=True - ) - return signature - - -def verify_schnorr_signature( - message: bytes, pubkey: PublicKey, signature: bytes -) -> bool: - return pubkey.schnorr_verify( - hashlib.sha256(message).digest(), signature, None, raw=True - ) diff --git a/cashu/mint/conditions.py b/cashu/mint/conditions.py index cbb4e84c3..4b807548d 100644 --- a/cashu/mint/conditions.py +++ b/cashu/mint/conditions.py @@ -8,13 +8,13 @@ from ..core.errors import ( TransactionError, ) -from ..core.htlc import HTLCSecret -from ..core.nuts.nut14 import verify_htlc_spending_conditions -from ..core.p2pk import ( +from ..core.nuts import nut11, nut14 +from ..core.nuts.nut11 import ( P2PKSecret, SigFlags, verify_schnorr_signature, ) +from ..core.nuts.nut14 import HTLCSecret from ..core.secret import Secret, SecretKind @@ -163,7 +163,7 @@ def _verify_input_spending_conditions(self, proof: Proof) -> bool: # HTLC if SecretKind(secret.kind) == SecretKind.HTLC: htlc_secret = HTLCSecret.from_secret(secret) - verify_htlc_spending_conditions(proof) + nut14.verify_htlc_spending_conditions(proof) return self._verify_p2pk_sig_inputs(proof, htlc_secret) # no spending condition present @@ -285,8 +285,8 @@ def _verify_sigall_spending_conditions( if not pubkeys: return True - message_to_sign = message_to_sign or "".join( - [p.secret for p in proofs] + [o.B_ for o in outputs] + message_to_sign = message_to_sign or nut11.sigall_message_to_sign( + proofs, outputs ) # validation diff --git a/cashu/mint/ledger.py b/cashu/mint/ledger.py index c13e52946..a4be15305 100644 --- a/cashu/mint/ledger.py +++ b/cashu/mint/ledger.py @@ -45,6 +45,7 @@ PostMeltQuoteResponse, PostMintQuoteRequest, ) +from ..core.nuts import nut11 from ..core.settings import settings from ..core.split import amount_split from ..lightning.base import ( @@ -886,9 +887,7 @@ async def melt( ) # verify SIG_ALL signatures - message_to_sign = ( - "".join([p.secret for p in proofs] + [o.B_ for o in outputs or []]) + quote - ) + message_to_sign = nut11.sigall_message_to_sign(proofs, outputs or []) + quote self._verify_sigall_spending_conditions(proofs, outputs or [], message_to_sign) # verify that the amount of the input proofs is equal to the amount of the quote diff --git a/cashu/wallet/htlc.py b/cashu/wallet/htlc.py index 1c003235f..a22372471 100644 --- a/cashu/wallet/htlc.py +++ b/cashu/wallet/htlc.py @@ -4,7 +4,7 @@ from ..core.base import HTLCWitness, Proof from ..core.db import Database -from ..core.htlc import ( +from ..core.nuts.nut14 import ( HTLCSecret, ) from ..core.secret import SecretKind, Tags diff --git a/cashu/wallet/p2pk.py b/cashu/wallet/p2pk.py index 89686aea8..d291ee5b4 100644 --- a/cashu/wallet/p2pk.py +++ b/cashu/wallet/p2pk.py @@ -3,8 +3,6 @@ from loguru import logger -from cashu.core.htlc import HTLCSecret - from ..core.base import ( BlindedMessage, HTLCWitness, @@ -13,11 +11,13 @@ ) from ..core.crypto.secp import PrivateKey from ..core.db import Database -from ..core.p2pk import ( +from ..core.nuts import nut11 +from ..core.nuts.nut11 import ( P2PKSecret, SigFlags, schnorr_sign, ) +from ..core.nuts.nut14 import HTLCSecret from ..core.secret import Secret, SecretKind, Tags from .protocols import SupportsDb, SupportsPrivateKey @@ -157,8 +157,8 @@ def add_witness_swap_sig_all( secrets = set([Secret.deserialize(p.secret) for p in proofs]) if not len(secrets) == 1: raise Exception("Secrets not identical") - message_to_sign = message_to_sign or "".join( - [p.secret for p in proofs] + [o.B_ for o in outputs] + message_to_sign = message_to_sign or nut11.sigall_message_to_sign( + proofs, outputs ) signature = self.schnorr_sign_message(message_to_sign) # add witness to only the first proof @@ -195,9 +195,7 @@ def sign_proofs_inplace_melt( ) -> List[Proof]: # sign proofs if they are P2PK SIG_INPUTS proofs = self.add_witnesses_sig_inputs(proofs) - message_to_sign = ( - "".join([p.secret for p in proofs] + [o.B_ for o in outputs]) + quote_id - ) + message_to_sign = nut11.sigall_message_to_sign(proofs, outputs) + quote_id # sign first proof if swap is SIG_ALL return self.add_witness_swap_sig_all(proofs, outputs, message_to_sign) diff --git a/cashu/wallet/wallet.py b/cashu/wallet/wallet.py index 6ffeb97b5..e7e8cda57 100644 --- a/cashu/wallet/wallet.py +++ b/cashu/wallet/wallet.py @@ -39,7 +39,7 @@ PostMeltQuoteResponse, ) from ..core.nuts import nut20 -from ..core.p2pk import Secret +from ..core.secret import Secret from ..core.settings import settings from . import migrations from .compat import WalletCompat diff --git a/tests/mint/test_mint_p2pk.py b/tests/mint/test_mint_p2pk.py index 1b28180de..d862dc66e 100644 --- a/tests/mint/test_mint_p2pk.py +++ b/tests/mint/test_mint_p2pk.py @@ -2,6 +2,7 @@ import pytest_asyncio from cashu.core.base import P2PKWitness +from cashu.core.nuts import nut11 from cashu.mint.ledger import Ledger from cashu.wallet.wallet import Wallet as Wallet1 from tests.conftest import SERVER_ENDPOINT @@ -192,7 +193,7 @@ async def test_ledger_verify_sigall_validation(wallet1: Wallet1, ledger: Ledger) outputs, rs = wallet1._construct_outputs(output_amounts, secrets, rs) # Create the message to sign (all inputs + all outputs) - message_to_sign = "".join([p.secret for p in send_proofs] + [o.B_ for o in outputs]) + message_to_sign = nut11.sigall_message_to_sign(send_proofs, outputs) # Sign the message with the wallet's private key signature = wallet1.schnorr_sign_message(message_to_sign) diff --git a/tests/mint/test_mint_p2pk_comprehensive.py b/tests/mint/test_mint_p2pk_comprehensive.py index fed5312b8..d4d52cb5c 100644 --- a/tests/mint/test_mint_p2pk_comprehensive.py +++ b/tests/mint/test_mint_p2pk_comprehensive.py @@ -7,7 +7,8 @@ from cashu.core.base import BlindedMessage, P2PKWitness from cashu.core.migrations import migrate_databases -from cashu.core.p2pk import P2PKSecret, SigFlags +from cashu.core.nuts import nut11 +from cashu.core.nuts.nut11 import P2PKSecret, SigFlags from cashu.core.secret import Secret, SecretKind, Tags from cashu.mint.ledger import Ledger from cashu.wallet import migrations @@ -108,6 +109,39 @@ async def test_p2pk_sig_inputs_basic(wallet1: Wallet, wallet2: Wallet, ledger: L assert len(promises) == len(outputs) +@pytest.mark.asyncio +async def test_p2pk_sig_all_message_aggregation( + wallet1: Wallet, wallet2: Wallet, ledger: Ledger +): + # Mint tokens to wallet1 + mint_quote = await wallet1.request_mint(64) + await pay_if_regtest(mint_quote.request) + await wallet1.mint(64, quote_id=mint_quote.quote) + + # Create locked tokens with SIG_ALL + pubkey_wallet2 = await wallet2.create_p2pk_pubkey() + secret_lock = await wallet1.create_p2pk_lock(pubkey_wallet2, sig_all=True) + _, send_proofs = await wallet1.swap_to_send( + wallet1.proofs, 16, secret_lock=secret_lock + ) + + # Verify that sent tokens have P2PK secrets with SIG_ALL flag + for proof in send_proofs: + p2pk_secret = Secret.deserialize(proof.secret) + assert p2pk_secret.kind == SecretKind.P2PK.value + assert P2PKSecret.from_secret(p2pk_secret).sigflag == SigFlags.SIG_ALL + + # Create outputs for redemption + outputs = await create_test_outputs(wallet2, 16) + + message_to_sign_expected = "".join( + [p.secret + p.C for p in send_proofs] + + [str(o.amount) + o.id + o.B_ for o in outputs] + ) + message_to_sign_actual = nut11.sigall_message_to_sign(send_proofs, outputs) + assert message_to_sign_actual == message_to_sign_expected + + @pytest.mark.asyncio async def test_p2pk_sig_all_valid(wallet1: Wallet, wallet2: Wallet, ledger: Ledger): """Test P2PK with SIG_ALL where the signature covers both inputs and outputs.""" @@ -133,7 +167,7 @@ async def test_p2pk_sig_all_valid(wallet1: Wallet, wallet2: Wallet, ledger: Ledg outputs = await create_test_outputs(wallet2, 16) # Create a message from concatenated inputs and outputs - message_to_sign = "".join([p.secret for p in send_proofs] + [o.B_ for o in outputs]) + message_to_sign = nut11.sigall_message_to_sign(send_proofs, outputs) # Sign with wallet2's private key signature = wallet2.schnorr_sign_message(message_to_sign) @@ -611,7 +645,7 @@ async def test_p2pk_sig_all_with_multiple_pubkeys( outputs = await create_test_outputs(wallet1, 16) # Create message to sign (all inputs + all outputs) - message_to_sign = "".join([p.secret for p in send_proofs] + [o.B_ for o in outputs]) + message_to_sign = nut11.sigall_message_to_sign(send_proofs, outputs) # Sign with wallet1's key signature1 = wallet1.schnorr_sign_message(message_to_sign) diff --git a/tests/wallet/test_wallet_htlc.py b/tests/wallet/test_wallet_htlc.py index b8fb6b25f..60fd2232b 100644 --- a/tests/wallet/test_wallet_htlc.py +++ b/tests/wallet/test_wallet_htlc.py @@ -9,9 +9,9 @@ from cashu.core.base import HTLCWitness, Proof from cashu.core.crypto.secp import PrivateKey -from cashu.core.htlc import HTLCSecret from cashu.core.migrations import migrate_databases -from cashu.core.p2pk import SigFlags +from cashu.core.nuts.nut11 import SigFlags +from cashu.core.nuts.nut14 import HTLCSecret from cashu.core.secret import SecretKind from cashu.wallet import migrations from cashu.wallet.wallet import Wallet diff --git a/tests/wallet/test_wallet_p2pk.py b/tests/wallet/test_wallet_p2pk.py index 794ffb8f1..e03d72601 100644 --- a/tests/wallet/test_wallet_p2pk.py +++ b/tests/wallet/test_wallet_p2pk.py @@ -12,7 +12,7 @@ from cashu.core.base import P2PKWitness, Proof from cashu.core.crypto.secp import PrivateKey, PublicKey from cashu.core.migrations import migrate_databases -from cashu.core.p2pk import P2PKSecret, SigFlags +from cashu.core.nuts.nut11 import P2PKSecret, SigFlags from cashu.core.secret import Secret, SecretKind, Tags from cashu.wallet import migrations from cashu.wallet.wallet import Wallet diff --git a/tests/wallet/test_wallet_p2pk_methods.py b/tests/wallet/test_wallet_p2pk_methods.py index a6c8f9ae5..92325ecc1 100644 --- a/tests/wallet/test_wallet_p2pk_methods.py +++ b/tests/wallet/test_wallet_p2pk_methods.py @@ -8,7 +8,8 @@ from cashu.core.base import P2PKWitness from cashu.core.crypto.secp import PrivateKey from cashu.core.migrations import migrate_databases -from cashu.core.p2pk import P2PKSecret, SigFlags +from cashu.core.nuts import nut11 +from cashu.core.nuts.nut11 import P2PKSecret, SigFlags from cashu.core.secret import SecretKind, Tags from cashu.wallet import migrations from cashu.wallet.wallet import Wallet @@ -199,7 +200,7 @@ async def test_add_witness_swap_sig_all(wallet1: Wallet): assert len(witness.signatures) == 1 # Verify the signature includes both inputs and outputs - message_to_sign = "".join([p.secret for p in proofs] + [o.B_ for o in outputs]) + message_to_sign = nut11.sigall_message_to_sign(proofs, outputs) signature = wallet1.schnorr_sign_message(message_to_sign) assert witness.signatures[0] == signature