diff --git a/.env.sepolia b/.env.sepolia new file mode 100644 index 0000000..2fece7d --- /dev/null +++ b/.env.sepolia @@ -0,0 +1,15 @@ +# Shared +NODE_ENV=development + +# Sepolia RPC (Infura/Alchemy/Blast/etc.) +STARKNET_RPC_URL=https://starknet-sepolia.infura.io/v3/REPLACE_WITH_KEY +STARKNET_NETWORK=sepolia + +# UA² contracts (attach-only; populate before running e2e) +UA2_CLASS_HASH= +UA2_IMPLEMENTATION_ADDR= +UA2_PROXY_ADDR= + +# Demo app +NEXT_PUBLIC_NETWORK=sepolia +NEXT_PUBLIC_UA2_PROXY_ADDR= diff --git a/.gitignore b/.gitignore index 6d04e22..cb14157 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ dist/ coverage/ .env .env.* +!.env.sepolia !.env.example !.env.sepolia.example !packages/contracts/scripts/.env.sepolia.example diff --git a/README.md b/README.md index 5d84045..d07fe71 100644 --- a/README.md +++ b/README.md @@ -60,18 +60,60 @@ scarb build > [mise](https://mise.jdx.dev/). If `scarb` is not on your `PATH`, install the pinned version with > `mise install scarb@2.12.0` (or ensure `/root/.asdf/shims` is exported when using `asdf`). -### 3. Deploy to Sepolia +### 3. Declare & deploy with `sncast` + +Mirror the same workflow locally and on Sepolia so copy/pasting always works. Replace the +placeholders (`<...>`) before running. ```bash # still inside packages/contracts -export STARKNET_RPC_URL= -export UA2_OWNER_PUBKEY= -./scripts/deploy_ua2.sh + +# Devnet example (see docs/runbook-sepolia.md for full flow) +RPC=http://127.0.0.1:5050 +NAME=devnet + +sncast account create --name "$NAME" --url "$RPC" +sncast account deploy --name "$NAME" --url "$RPC" + +sncast --account "$NAME" \ + declare \ + --contract-name UA2Account \ + --url "$RPC" \ + --max-fee 9638049920000000000 + +UA2_CLASS_HASH=0xCLASS_HASH_FROM_OUTPUT + +OWNER_PUBKEY=0xYOUR_OWNER_FELT +sncast --account "$NAME" \ + deploy \ + --class-hash "$UA2_CLASS_HASH" \ + --constructor-calldata "$OWNER_PUBKEY" \ + --url "$RPC" \ + --max-fee 9638049920000000000 + +UA2_PROXY_ADDR=0xDEPLOYED_ADDRESS + +sncast --account "$NAME" \ + call \ + --contract-address "$UA2_PROXY_ADDR" \ + --function get_owner \ + --url "$RPC" + +# Sepolia mirrors the same steps; just switch RPC/NAME and fund the account with STRK (FRI) +RPC=https://starknet-sepolia.infura.io/v3/ +NAME=sepolia ``` -The helper script declares the class if needed and writes `UA2_CLASS_HASH`, `UA2_IMPLEMENTATION_ADDR`, -and `UA2_PROXY_ADDR` to `packages/contracts/.ua2-sepolia-addresses.json`. Copy the relevant values into -your `.env.sepolia` (created from `.env.sepolia.example`) and set `NEXT_PUBLIC_UA2_PROXY_ADDR` for the demo app. +If `sncast` reports "fee too low", rerun the declare/deploy with the suggested higher +`--max-fee` (fees are denominated in **FRI (STRK)**). Copy the resulting class hash, +implementation hash, and proxy address into `.env` / `.env.sepolia` so the SDK and demo app +point at the correct contracts. `./scripts/deploy_ua2.sh` is still available when you want an +automated run. + +> [!NOTE] +> On devnet, mint FRI to the printed account address via `devnet_mint`. On Sepolia, +> top up the account with STRK/ETH from your faucet or bridge of choice before +> deploying. ### 4. Run demo app @@ -114,6 +156,10 @@ For full walkthrough: [`docs/runbook-sepolia.md`](./docs/runbook-sepolia.md) ```bash npm run e2e:devnet ``` + + > [!TIP] + > Use the devnet + `sncast` recipe in [`docs/runbook-sepolia.md`](./docs/runbook-sepolia.md) to + > create/fund the named account and deploy the UA² class before running the suite. * **E2E on Sepolia:** ```bash diff --git a/docs/architecture.md b/docs/architecture.md index 690471c..bd1bb08 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -14,7 +14,7 @@ - `recovery_active: bool` - `recovery_proposed_owner: felt252` - `recovery_eta: u64` - - `recovery_confirms: LegacyMap` + `recovery_confirm_count: u32` + - `recovery_confirm_count: u32` - `recovery_proposal_id: u64` - `recovery_guardian_last_confirm: LegacyMap` - `session: Map` diff --git a/docs/interfaces.md b/docs/interfaces.md index 7a7b38d..9fc37a4 100644 --- a/docs/interfaces.md +++ b/docs/interfaces.md @@ -26,7 +26,8 @@ ```rust struct SessionPolicy { is_active: bool, - expires_at: u64, + valid_after: u64, + valid_until: u64, max_calls: u32, calls_used: u32, max_value_per_call: Uint256, @@ -35,11 +36,15 @@ struct SessionPolicy { The selector and target allowlists are not embedded in the struct. They are stored in dedicated `LegacyMap` slots keyed by `(key_hash, target)` and `(key_hash, selector)` respectively, and are populated by calling `add_session_with_allowlists` alongside the base policy. +> **Notes:** +> * Supplying empty target and selector lists is permitted, but such a session cannot execute any calls (all lookups fall back to `false`). +> * To keep storage writes + gas predictable in v0, prefer allowlists with ≲32 entries per list. + --- ## Events -* `SessionAdded(key_hash: felt252, expires_at: u64, max_calls: u32)` +* `SessionAdded(key_hash: felt252, valid_after: u64, valid_until: u64, max_calls: u32)` * `SessionRevoked(key_hash: felt252)` * `SessionUsed(key_hash: felt252, used: u32)` * `SessionNonceAdvanced(key_hash: felt252, new_nonce: u128)` @@ -61,15 +66,25 @@ The contract reverts with the following identifiers: * `ERR_SESSION_EXPIRED` * `ERR_SESSION_INACTIVE` +* `ERR_SESSION_STALE` +* `ERR_SESSION_NOT_READY` +* `ERR_SESSION_TARGETS_LEN` +* `ERR_SESSION_SELECTORS_LEN` * `ERR_POLICY_CALLCAP` -* `ERR_POLICY_SELECTOR_DENIED` * `ERR_POLICY_TARGET_DENIED` -* `ERR_VALUE_LIMIT_EXCEEDED` +* `ERR_POLICY_SELECTOR_DENIED` * `ERR_POLICY_CALLCOUNT_MISMATCH` +* `ERR_VALUE_LIMIT_EXCEEDED` * `ERR_BAD_SESSION_NONCE` * `ERR_SESSION_SIG_INVALID` +* `ERR_BAD_VALID_WINDOW` +* `ERR_BAD_MAX_CALLS` +* `ERR_SIGNATURE_MISSING` +* `ERR_OWNER_SIG_INVALID` +* `ERR_GUARDIAN_SIG_INVALID` * `ERR_GUARDIAN_EXISTS` * `ERR_NOT_GUARDIAN` +* `ERR_GUARDIAN_CALL_DENIED` * `ERR_BAD_THRESHOLD` * `ERR_RECOVERY_IN_PROGRESS` * `ERR_NO_RECOVERY` @@ -77,6 +92,7 @@ The contract reverts with the following identifiers: * `ERR_ALREADY_CONFIRMED` * `ERR_BEFORE_ETA` * `ERR_NOT_ENOUGH_CONFIRMS` +* `ERR_NOT_OWNER` * `ERR_ZERO_OWNER` * `ERR_SAME_OWNER` diff --git a/docs/rfc-ua2-sdk.md b/docs/rfc-ua2-sdk.md index 89a95ea..ade75b0 100644 --- a/docs/rfc-ua2-sdk.md +++ b/docs/rfc-ua2-sdk.md @@ -130,7 +130,7 @@ This RFC proposes a neutral, modular SDK that standardizes these primitives with * `recovery_active: bool` * `recovery_proposed_owner: felt252` * `recovery_eta: u64` - * `recovery_confirms: LegacyMap` + `recovery_confirm_count: u32` + * `recovery_confirm_count: u32` * `recovery_proposal_id: u64` * `recovery_guardian_last_confirm: LegacyMap` * `session: Map` @@ -145,30 +145,31 @@ OZ Account is the canonical base for `__validate__`/`__execute__` pattern and mu ``` struct SessionPolicy { is_active: bool, - expires_at: u64, // block timestamp + valid_after: u64, // block timestamp + valid_until: u64, // block timestamp max_calls: u32, calls_used: u32, - max_value_per_call: Uint256, // wei-like units for token/native + max_value_per_call: Uint256, // wei-like units for ERC-20 transfers (native send unsupported in v0) } ``` -Selector and target allowlists are stored separately under `sessionTargetAllow(session_key_hash, ContractAddress)` and `sessionSelectorAllow(session_key_hash, felt252)` legacy maps. The owner typically calls `add_session_with_allowlists` to write the base policy and seed those maps in a single transaction. +Selector and target allowlists are stored separately under `sessionTargetAllow(session_key_hash, ContractAddress)` and `sessionSelectorAllow(session_key_hash, felt252)` legacy maps. The owner typically calls `add_session_with_allowlists` to write the base policy and seed those maps in a single transaction. Empty allowlists are technically valid but render the session unusable, and we recommend keeping each list ≤32 entries in v0 to avoid excessive storage writes. **Validation path:** * If signature is by `owner_pubkey`: standard path. * Else if signature verifies to a registered **session key**: - * Check `is_active`, `now <= expires_at`, and `calls_used + tx_call_count <= max_calls`. + * Check `is_active`, `now >= valid_after`, `now <= valid_until`, and `calls_used + tx_call_count <= max_calls`. * Require allowlist booleans for `(key_hash, target)` and `(key_hash, selector)` to be `true`. - * Enforce ERC-20 transfer amounts ≤ `max_value_per_call`. + * Enforce ERC-20 `transfer` / `transferFrom` amounts ≤ `max_value_per_call` (native `call.value` transfers are out-of-scope for v0). * Require session nonce match, then verify the ECDSA signature over the poseidon-hashed call set. * Call `apply_session_usage` to bump counters/nonce and emit `SessionUsed` + `SessionNonceAdvanced`. **Events:** ``` -event SessionAdded(key_hash: felt252, expires_at: u64, max_calls: u32); +event SessionAdded(key_hash: felt252, valid_after: u64, valid_until: u64, max_calls: u32); event SessionRevoked(key_hash: felt252); event SessionUsed(key_hash: felt252, used: u32); event SessionNonceAdvanced(key_hash: felt252, new_nonce: u128); @@ -252,7 +253,7 @@ await ua.sessions.revoke(sess.id); ## 10. Security Considerations * **Domain separation:** Session signatures bind to `(chain_id, account_addr)`. ([docs.starknet.io][1]) -* **Expiry & limits:** Every session must have a hard `expires_at`; default small `maxCalls`. +* **Expiry & limits:** Every session must declare `valid_after`/`valid_until`; default small `maxCalls`. * **Revocation:** `revokeSession(key_hash)` immediately blocks use. Events let dApps react. * **Replay protection:** Optional per-session nonce (`sessionNonce`) incremented in validation. * **Guardian griefing:** Require **m-of-n** quorum and **timelock**; owner can cancel a pending recovery. @@ -265,7 +266,7 @@ await ua.sessions.revoke(sess.id); * **Validation path**: O(#calls * (selector + target checks)). Use **bitset/bitmap** encodings for selectors if needed; start with arrays for simplicity, upgrade later. * **Storage**: `calls_used` is incremented once per tx (after checking aggregate calls), not per inner call, to minimize writes. -* **Policy packing**: keep `expires_at` in `u64`; `maxCalls` in `u32`; selectors as `felt252[]`. +* **Policy packing**: keep `valid_after`/`valid_until` in `u64`; `maxCalls` in `u32`; selectors as `felt252[]`. --- @@ -320,7 +321,7 @@ await ua.sessions.revoke(sess.id); **Events** ``` -event SessionAdded(key_hash: felt252, expires_at: u64, max_calls: u32); +event SessionAdded(key_hash: felt252, valid_after: u64, valid_until: u64, max_calls: u32); event SessionRevoked(key_hash: felt252); event SessionUsed(key_hash: felt252, used: u32); event SessionNonceAdvanced(key_hash: felt252, new_nonce: u128); diff --git a/docs/runbook-sepolia.md b/docs/runbook-sepolia.md index 41ff158..e8e7bf2 100644 --- a/docs/runbook-sepolia.md +++ b/docs/runbook-sepolia.md @@ -122,16 +122,79 @@ cd ../../ --- -## 4) (Optional) Local devnet E2E +## 4) Local devnet with `sncast` (Docker) -If you run a local Starknet devnet: +If you want to smoke-test on a local devnet before heading to Sepolia, run the same +flow we use in CI. The commands below were copy/pasted against +`shardlabs/starknet-devnet-rs:latest` and `sncast 0.33.x`. ```bash -# Start a local devnet in another terminal (example; adjust to your stack) -docker run --rm -p 5050:5050 shardlabs/starknet-devnet:latest +# In a separate terminal +docker run -it --rm -p 127.0.0.1:5050:5050 \ + shardlabs/starknet-devnet-rs:latest \ + --seed 0 --accounts 10 ``` -Then run the JavaScript E2E against devnet: +Back in the repo root: + +```bash +RPC=http://127.0.0.1:5050 +NAME=devnet + +# 1. Create a named account (writes to ~/.starknet_accounts/devnet) +sncast account create --name "$NAME" --url "$RPC" + +# 2. Fund it with FRI (STRK) for fees – copy the address from the previous output +ADDR=0xPASTE_ADDRESS_FROM_OUTPUT +curl -s "$RPC" -H 'content-type: application/json' \ + -d '{"jsonrpc":"2.0","id":1,"method":"devnet_mint","params":{"address":"'"$ADDR"'","amount":100000000000000000000,"unit":"FRI"}}' + +# (Optional) give it some ETH (WEI) so you can test paymasters that refund in ETH +curl -s "$RPC" -H 'content-type: application/json' \ + -d '{"jsonrpc":"2.0","id":2,"method":"devnet_mint","params":{"address":"'"$ADDR"'","amount":100000000000000000000,"unit":"WEI"}}' + +# 3. Deploy the account (accept the default-account prompt) +sncast account deploy --name "$NAME" --url "$RPC" + +# 4. Make sure scarb is discoverable (sncast shells out to it) +scarb manifest-path + +# 5. Declare the UA² class (bump --max-fee if sncast suggests a higher estimate) +sncast --account "$NAME" \ + declare \ + --contract-name UA2Account \ + --url "$RPC" \ + --max-fee 9638049920000000000 + +# capture the class hash from the output +UA2_CLASS_HASH=0xYOUR_CLASS_HASH + +# 6. Deploy the declared class (any felt pubkey works for local testing) +OWNER_PUBKEY=0x4173f320ca395828b2630fdb693cfb761047fc3822a66c40f9156c4bc8d7836 +sncast --account "$NAME" \ + deploy \ + --class-hash "$UA2_CLASS_HASH" \ + --constructor-calldata "$OWNER_PUBKEY" \ + --url "$RPC" \ + --max-fee 9638049920000000000 + +# capture the proxy address for the next step +UA2_ADDR=0xPASTE_DEPLOYED_ADDRESS + +# 7. Smoke test – zero-arg view, so no --calldata flag required +sncast --account "$NAME" \ + call \ + --contract-address "$UA2_ADDR" \ + --function get_owner \ + --url "$RPC" +``` + +Every command above should succeed when pasted into a fresh shell. If a declare or +deploy fails with "fee too low", rerun it with the quoted `--max-fee` adjusted to the +estimate reported by `sncast`. Fees are denominated in **FRI (STRK)** on devnet. + +With the contract running locally you can point the demo app to the devnet values or +run the TypeScript integration tests: ```bash npm run e2e:devnet @@ -173,31 +236,59 @@ ua2_account = "${UA2_PROXY_ADDR}" ## 6) Declare & deploy on Sepolia -From repo root: +You can still use `./scripts/deploy_ua2.sh`, but when debugging or verifying +interactively we recommend mirroring the devnet flow with `sncast` so the same +commands work everywhere. Replace `<...>` placeholders before running. ```bash cd packages/contracts -export STARKNET_RPC_URL= -export UA2_OWNER_PUBKEY= -./scripts/deploy_ua2.sh - -# Script output includes: -# class hash: 0x... -# contract_address: 0x... +RPC=https://starknet-sepolia.infura.io/v3/ +NAME=sepolia + +# 1. Create or reuse a named account backed by your keystore/ledger +sncast account create --name "$NAME" --url "$RPC" +sncast account deploy --name "$NAME" --url "$RPC" + +# 2. Declare the class (fees are in FRI/STRK; raise --max-fee if needed) +sncast --account "$NAME" \ + declare \ + --contract-name UA2Account \ + --url "$RPC" \ + --max-fee 9638049920000000000 + +UA2_CLASS_HASH=0xCLASS_HASH_FROM_OUTPUT + +# 3. Deploy the class to get a live proxy +OWNER_PUBKEY=0xYOUR_OWNER_FELT +sncast --account "$NAME" \ + deploy \ + --class-hash "$UA2_CLASS_HASH" \ + --constructor-calldata "$OWNER_PUBKEY" \ + --url "$RPC" \ + --max-fee 9638049920000000000 + +UA2_PROXY_ADDR=0xDEPLOYED_ADDRESS + +# 4. Verify with a read-only call (no calldata flag for zero-arg functions) +sncast --account "$NAME" \ + call \ + --contract-address "$UA2_PROXY_ADDR" \ + --function get_owner \ + --url "$RPC" ``` -The script writes the resolved values to `packages/contracts/.ua2-sepolia-addresses.json`. Paste the -latest entries into **`.env.sepolia`**: +Copy the resulting `UA2_CLASS_HASH`, implementation address (from the deploy receipt), +and `UA2_PROXY_ADDR` into `.env.sepolia` so the SDK and demo app target the right +contracts. If `sncast` reports an estimated fee above the provided max, re-run the +command with the suggested value. -``` -UA2_CLASS_HASH=0x... -UA2_IMPLEMENTATION_ADDR=0x... -UA2_PROXY_ADDR=0x... -NEXT_PUBLIC_UA2_PROXY_ADDR=0x... -``` +When you prefer automation, `./scripts/deploy_ua2.sh` is still available and writes the +same values into `packages/contracts/.ua2-sepolia-addresses.json`. -Commit env (without secrets) or keep local. The `.env.sepolia.example` template matches these keys. +> [!NOTE] +> If you open a new shell before the smoke tests below, re-export `RPC` and `NAME` +> so `sncast` can find the correct endpoint and account. --- @@ -205,9 +296,11 @@ Commit env (without secrets) or keep local. The `.env.sepolia.example` template ```bash # Read owner -sncast --profile sepolia call \ - --address $UA2_PROXY_ADDR \ - --function get_owner +sncast --account sepolia \ + call \ + --contract-address "$UA2_PROXY_ADDR" \ + --function get_owner \ + --url "$RPC" ``` Expected: @@ -220,8 +313,9 @@ Add a dummy session key (owner-signed tx): ```bash # Example: add session with 8h expiry, 50 max calls, single target, two selectors -sncast --profile sepolia invoke \ - --address $UA2_PROXY_ADDR \ +sncast --account sepolia \ + invoke \ + --contract-address "$UA2_PROXY_ADDR" \ --function add_session_with_allowlists \ --calldata \ \ @@ -234,10 +328,12 @@ sncast --profile sepolia invoke \ \ 2 \ \ - + \ + --url "$RPC" \ + --max-fee 9638049920000000000 ``` -The calldata order is: session key, `is_active`, `expires_at`, `max_calls`, `calls_used`, `max_value_per_call.low`, +The calldata order is: session key, `valid_after`, `valid_until`, `max_calls`, `max_value_per_call.low`, `max_value_per_call.high`, number of allowed targets, each target address, number of allowed selectors, each selector felt. If the contract checks revert, verify you supplied proper calldata as per `docs/interfaces.md`. diff --git a/docs/test-plan.md b/docs/test-plan.md index 004c574..5c7d483 100644 --- a/docs/test-plan.md +++ b/docs/test-plan.md @@ -160,7 +160,8 @@ Script: `packages/example/scripts/e2e-devnet.ts` 1. Deploy UA² to devnet (or attach if auto-deployed by script). 2. Create a session with: - * `expires_at = now + 2h` + * `valid_after = now` + * `valid_until = now + 2h` * `max_calls = 5` * selectors = `[transfer]` * targets = `[ERC20_TEST_ADDR]` diff --git a/docs/validation.md b/docs/validation.md index 47b5d5b..97ab6b5 100644 --- a/docs/validation.md +++ b/docs/validation.md @@ -11,16 +11,19 @@ 1. Compute `key_hash` = pedersen(session_pubkey). 2. Lookup the base `SessionPolicy` in `session` storage and fetch allowlist booleans from `session_target_allow` / `session_selector_allow` using `(key_hash, target)` and `(key_hash, selector)` keys. 3. Require `is_active == true` (`ERR_SESSION_INACTIVE`). -4. Require `block.timestamp <= expires_at` (`ERR_SESSION_EXPIRED`). -5. Require `calls_used + tx_call_count <= max_calls` (`ERR_POLICY_CALLCAP`). -6. For each call in `tx.multicall`: +4. Require `block.timestamp >= valid_after` (`ERR_SESSION_NOT_READY`). +5. Require `block.timestamp <= valid_until` (`ERR_SESSION_EXPIRED`). +6. Require `calls_used + tx_call_count <= max_calls` (`ERR_POLICY_CALLCAP`). +7. For each call in `tx.multicall`: - Assert target allowlist entry is `true` (`ERR_POLICY_TARGET_DENIED`). - Assert selector allowlist entry is `true` (`ERR_POLICY_SELECTOR_DENIED`). - - If selector == `ERC20::transfer`, ensure amount ≤ `max_value_per_call` (`ERR_VALUE_LIMIT_EXCEEDED`). -7. Require provided session nonce == stored nonce (`ERR_BAD_SESSION_NONCE`). -8. Verify ECDSA signature against the computed session message (`ERR_SESSION_SIG_INVALID`). -9. Call `apply_session_usage` to atomically bump `calls_used`, advance the nonce, and emit `SessionUsed` + `SessionNonceAdvanced`. -10. Proceed to `__execute__`. + - If selector == `ERC20::transfer` **or** `ERC20::transferFrom`, ensure amount ≤ `max_value_per_call` (`ERR_VALUE_LIMIT_EXCEEDED`). + - Direct native-token `call.value` transfers are out-of-scope in v0; stick to ERC-20 flows when using value caps. +8. Require provided session nonce == stored nonce (`ERR_BAD_SESSION_NONCE`). +9. Verify ECDSA signature against the computed session message (`ERR_SESSION_SIG_INVALID`). + - Message binds `{chainId, accountAddress, sessionPubkey, callHash, validUntil, nonce}` to prevent cross-channel replay. +10. Call `apply_session_usage` to atomically bump `calls_used`, advance the nonce, and emit `SessionUsed` + `SessionNonceAdvanced`. +11. Proceed to `__execute__`. If any check fails → revert with specific error code. diff --git a/packages/contracts/Scarb.toml b/packages/contracts/Scarb.toml index 1492961..b78328e 100644 --- a/packages/contracts/Scarb.toml +++ b/packages/contracts/Scarb.toml @@ -7,6 +7,7 @@ description = "UA2 smart contract suite" [dependencies] openzeppelin = { git = "https://github.com/OpenZeppelin/cairo-contracts.git", tag = "v2.0.0" } starknet = "2.12.0" +cairo_test = "2.12.2" [dev-dependencies] snforge_std = { git = "https://github.com/foundry-rs/starknet-foundry.git", tag = "v0.50.0" } diff --git a/packages/contracts/src/errors.cairo b/packages/contracts/src/errors.cairo new file mode 100644 index 0000000..98b795d --- /dev/null +++ b/packages/contracts/src/errors.cairo @@ -0,0 +1,32 @@ +pub const ERR_SESSION_EXPIRED: felt252 = 'ERR_SESSION_EXPIRED'; +pub const ERR_SESSION_INACTIVE: felt252 = 'ERR_SESSION_INACTIVE'; +pub const ERR_SESSION_STALE: felt252 = 'ERR_SESSION_STALE'; +pub const ERR_POLICY_CALLCAP: felt252 = 'ERR_POLICY_CALLCAP'; +pub const ERR_POLICY_SELECTOR_DENIED: felt252 = 'ERR_POLICY_SELECTOR_DENIED'; +pub const ERR_POLICY_TARGET_DENIED: felt252 = 'ERR_POLICY_TARGET_DENIED'; +pub const ERR_VALUE_LIMIT_EXCEEDED: felt252 = 'ERR_VALUE_LIMIT_EXCEEDED'; +pub const ERR_SESSION_NOT_READY: felt252 = 'ERR_SESSION_NOT_READY'; +pub const ERR_SESSION_TARGETS_LEN: felt252 = 'ERR_SESSION_TARGETS_LEN'; +pub const ERR_SESSION_SELECTORS_LEN: felt252 = 'ERR_SESSION_SELECTORS_LEN'; +pub const ERR_POLICY_CALLCOUNT_MISMATCH: felt252 = 'ERR_POLICY_CALLCOUNT_MISMATCH'; +pub const ERR_BAD_SESSION_NONCE: felt252 = 'ERR_BAD_SESSION_NONCE'; +pub const ERR_SESSION_SIG_INVALID: felt252 = 'ERR_SESSION_SIG_INVALID'; +pub const ERR_BAD_VALID_WINDOW: felt252 = 'ERR_BAD_VALID_WINDOW'; +pub const ERR_BAD_MAX_CALLS: felt252 = 'ERR_BAD_MAX_CALLS'; +pub const ERR_UNSUPPORTED_AUTH_MODE: felt252 = 'ERR_UNSUPPORTED_AUTH_MODE'; +pub const ERR_GUARDIAN_EXISTS: felt252 = 'ERR_GUARDIAN_EXISTS'; +pub const ERR_NOT_GUARDIAN: felt252 = 'ERR_NOT_GUARDIAN'; +pub const ERR_BAD_THRESHOLD: felt252 = 'ERR_BAD_THRESHOLD'; +pub const ERR_RECOVERY_IN_PROGRESS: felt252 = 'ERR_RECOVERY_IN_PROGRESS'; +pub const ERR_NO_RECOVERY: felt252 = 'ERR_NO_RECOVERY'; +pub const ERR_RECOVERY_MISMATCH: felt252 = 'ERR_RECOVERY_MISMATCH'; +pub const ERR_ALREADY_CONFIRMED: felt252 = 'ERR_ALREADY_CONFIRMED'; +pub const ERR_BEFORE_ETA: felt252 = 'ERR_BEFORE_ETA'; +pub const ERR_NOT_ENOUGH_CONFIRMS: felt252 = 'ERR_NOT_ENOUGH_CONFIRMS'; +pub const ERR_ZERO_OWNER: felt252 = 'ERR_ZERO_OWNER'; +pub const ERR_SAME_OWNER: felt252 = 'ERR_SAME_OWNER'; +pub const ERR_SIGNATURE_MISSING: felt252 = 'ERR_SIGNATURE_MISSING'; +pub const ERR_OWNER_SIG_INVALID: felt252 = 'ERR_OWNER_SIG_INVALID'; +pub const ERR_GUARDIAN_SIG_INVALID: felt252 = 'ERR_GUARDIAN_SIG_INVALID'; +pub const ERR_GUARDIAN_CALL_DENIED: felt252 = 'ERR_GUARDIAN_CALL_DENIED'; +pub const ERR_NOT_OWNER: felt252 = 'NOT_OWNER'; diff --git a/packages/contracts/src/lib.cairo b/packages/contracts/src/lib.cairo index d8f762a..8bef8ee 100644 --- a/packages/contracts/src/lib.cairo +++ b/packages/contracts/src/lib.cairo @@ -1,2 +1,4 @@ +pub mod errors; pub mod ua2_account; pub mod mock_erc20; +pub mod session; diff --git a/packages/contracts/src/mock_erc20.cairo b/packages/contracts/src/mock_erc20.cairo index 095f6b7..4f5b448 100644 --- a/packages/contracts/src/mock_erc20.cairo +++ b/packages/contracts/src/mock_erc20.cairo @@ -7,6 +7,7 @@ pub mod MockERC20 { #[storage] pub struct Storage { + last_from: ContractAddress, last_to: ContractAddress, last_amount: u256, } @@ -18,6 +19,19 @@ pub mod MockERC20 { true } + #[external(v0)] + fn transferFrom( + ref self: ContractState, + from: ContractAddress, + to: ContractAddress, + amount: u256, + ) -> bool { + self.last_from.write(from); + self.last_to.write(to); + self.last_amount.write(amount); + true + } + #[external(v0)] fn get_last(self: @ContractState) -> (ContractAddress, u256) { (self.last_to.read(), self.last_amount.read()) diff --git a/packages/contracts/src/session.cairo b/packages/contracts/src/session.cairo new file mode 100644 index 0000000..f41584f --- /dev/null +++ b/packages/contracts/src/session.cairo @@ -0,0 +1,24 @@ +use core::array::Array; +use core::integer::u256; +use starknet::ContractAddress; + +/// Canonical session configuration passed by the account owner. +/// +/// * `pubkey` is the Stark curve public key for the session signer (felt252). +/// * `valid_after` / `valid_until` gate the validity window using block timestamps (u64). +/// * `max_calls` limits how many calls can be executed over the lifetime of the session (u32). +/// * `value_cap` bounds the maximum value per call in wei-like units (u256). +/// * `targets_len` mirrors the number of addresses contained in `targets`. +/// * `selectors_len` mirrors the number of function selectors contained in `selectors`. +#[derive(Drop, Serde)] +pub struct Session { + pub pubkey: felt252, + pub valid_after: u64, + pub valid_until: u64, + pub max_calls: u32, + pub value_cap: u256, + pub targets_len: u32, + pub targets: Array, + pub selectors_len: u32, + pub selectors: Array, +} diff --git a/packages/contracts/src/ua2_account.cairo b/packages/contracts/src/ua2_account.cairo index 1966313..d72af76 100644 --- a/packages/contracts/src/ua2_account.cairo +++ b/packages/contracts/src/ua2_account.cairo @@ -4,59 +4,47 @@ use openzeppelin::introspection::src5::SRC5Component; #[starknet::contract(account)] #[feature("deprecated_legacy_map")] pub mod UA2Account { - use super::{AccountComponent, SRC5Component}; use core::array::{Array, ArrayTrait, SpanTrait}; - use core::option::Option; - use core::traits::{Into, TryInto}; - use core::integer::u256; - use core::serde::Serde; use core::ecdsa::check_ecdsa_signature; + use core::integer::u256; + use core::option::Option; + use core::pedersen::pedersen; use core::poseidon::poseidon_hash_span; + use core::serde::Serde; + use core::traits::{Into, TryInto}; use openzeppelin::account::interface; use starknet::account::Call; - use core::pedersen::pedersen; use starknet::storage::Map; use starknet::syscalls::call_contract_syscall; use starknet::{ - ContractAddress, - SyscallResultTrait, - get_caller_address, - get_contract_address, + ContractAddress, SyscallResultTrait, get_caller_address, get_contract_address, get_execution_info, }; + use crate::errors::{ + ERR_ALREADY_CONFIRMED, ERR_BAD_MAX_CALLS, ERR_BAD_SESSION_NONCE, ERR_BAD_THRESHOLD, + ERR_BAD_VALID_WINDOW, ERR_BEFORE_ETA, ERR_GUARDIAN_CALL_DENIED, ERR_GUARDIAN_EXISTS, + ERR_GUARDIAN_SIG_INVALID, ERR_NOT_ENOUGH_CONFIRMS, ERR_NOT_GUARDIAN, ERR_NOT_OWNER, + ERR_NO_RECOVERY, ERR_OWNER_SIG_INVALID, ERR_POLICY_CALLCAP, ERR_POLICY_CALLCOUNT_MISMATCH, + ERR_POLICY_SELECTOR_DENIED, ERR_POLICY_TARGET_DENIED, ERR_RECOVERY_IN_PROGRESS, + ERR_RECOVERY_MISMATCH, ERR_SAME_OWNER, ERR_SESSION_EXPIRED, ERR_SESSION_INACTIVE, + ERR_SESSION_NOT_READY, ERR_SESSION_SELECTORS_LEN, ERR_SESSION_SIG_INVALID, + ERR_SESSION_STALE, ERR_SESSION_TARGETS_LEN, ERR_SIGNATURE_MISSING, + ERR_UNSUPPORTED_AUTH_MODE, ERR_VALUE_LIMIT_EXCEEDED, ERR_ZERO_OWNER, + }; + use crate::session::Session; + use super::{AccountComponent, SRC5Component}; component!(path: AccountComponent, storage: account, event: AccountEvent); component!(path: SRC5Component, storage: src5, event: SRC5Event); - const ERR_SESSION_EXPIRED: felt252 = 'ERR_SESSION_EXPIRED'; - const ERR_SESSION_INACTIVE: felt252 = 'ERR_SESSION_INACTIVE'; - const ERR_POLICY_CALLCAP: felt252 = 'ERR_POLICY_CALLCAP'; - const ERR_POLICY_SELECTOR_DENIED: felt252 = 'ERR_POLICY_SELECTOR_DENIED'; - const ERR_POLICY_TARGET_DENIED: felt252 = 'ERR_POLICY_TARGET_DENIED'; - const ERR_VALUE_LIMIT_EXCEEDED: felt252 = 'ERR_VALUE_LIMIT_EXCEEDED'; - const ERR_POLICY_CALLCOUNT_MISMATCH: felt252 = 'ERR_POLICY_CALLCOUNT_MISMATCH'; - const ERR_BAD_SESSION_NONCE: felt252 = 'ERR_BAD_SESSION_NONCE'; - const ERR_SESSION_SIG_INVALID: felt252 = 'ERR_SESSION_SIG_INVALID'; - const ERR_GUARDIAN_EXISTS: felt252 = 'ERR_GUARDIAN_EXISTS'; - const ERR_NOT_GUARDIAN: felt252 = 'ERR_NOT_GUARDIAN'; - const ERR_BAD_THRESHOLD: felt252 = 'ERR_BAD_THRESHOLD'; - const ERR_RECOVERY_IN_PROGRESS: felt252 = 'ERR_RECOVERY_IN_PROGRESS'; - const ERR_NO_RECOVERY: felt252 = 'ERR_NO_RECOVERY'; - const ERR_RECOVERY_MISMATCH: felt252 = 'ERR_RECOVERY_MISMATCH'; - const ERR_ALREADY_CONFIRMED: felt252 = 'ERR_ALREADY_CONFIRMED'; - const ERR_BEFORE_ETA: felt252 = 'ERR_BEFORE_ETA'; - const ERR_NOT_ENOUGH_CONFIRMS: felt252 = 'ERR_NOT_ENOUGH_CONFIRMS'; - const ERR_ZERO_OWNER: felt252 = 'ERR_ZERO_OWNER'; - const ERR_SAME_OWNER: felt252 = 'ERR_SAME_OWNER'; - const ERR_SIGNATURE_MISSING: felt252 = 'ERR_SIGNATURE_MISSING'; - const ERR_OWNER_SIG_INVALID: felt252 = 'ERR_OWNER_SIG_INVALID'; - const ERR_GUARDIAN_SIG_INVALID: felt252 = 'ERR_GUARDIAN_SIG_INVALID'; - const ERR_GUARDIAN_CALL_DENIED: felt252 = 'ERR_GUARDIAN_CALL_DENIED'; - const ERC20_TRANSFER_SEL: felt252 = 0x83afd3f4caedc6eebf44246fe54e38c95e3179a5ec9ea81740eca5b482d12e; + const ERC20_TRANSFER_SEL: felt252 = + 0x83afd3f4caedc6eebf44246fe54e38c95e3179a5ec9ea81740eca5b482d12e; + const ERC20_TRANSFER_FROM_SEL: felt252 = starknet::selector!("transferFrom"); const APPLY_SESSION_USAGE_SELECTOR: felt252 = starknet::selector!("apply_session_usage"); const PROPOSE_RECOVERY_SELECTOR: felt252 = starknet::selector!("propose_recovery"); const CONFIRM_RECOVERY_SELECTOR: felt252 = starknet::selector!("confirm_recovery"); const EXECUTE_RECOVERY_SELECTOR: felt252 = starknet::selector!("execute_recovery"); + const CANCEL_RECOVERY_SELECTOR: felt252 = starknet::selector!("cancel_recovery"); #[storage] pub struct Storage { @@ -69,6 +57,7 @@ pub mod UA2Account { session_nonce: Map, session_target_allow: LegacyMap<(felt252, ContractAddress), bool>, session_selector_allow: LegacyMap<(felt252, felt252), bool>, + session_owner_epoch: u64, guardians: LegacyMap, guardian_count: u32, guardian_threshold: u8, @@ -76,7 +65,6 @@ pub mod UA2Account { recovery_active: bool, recovery_proposed_owner: felt252, recovery_eta: u64, - recovery_confirms: LegacyMap, recovery_confirm_count: u32, recovery_proposal_id: u64, recovery_guardian_last_confirm: LegacyMap, @@ -85,16 +73,19 @@ pub mod UA2Account { #[derive(Copy, Drop, Serde, starknet::Store)] pub struct SessionPolicy { pub is_active: bool, - pub expires_at: u64, + pub valid_after: u64, + pub valid_until: u64, pub max_calls: u32, pub calls_used: u32, pub max_value_per_call: u256, + pub owner_epoch: u64, } #[derive(Drop, starknet::Event)] pub struct SessionAdded { pub key_hash: felt252, - pub expires_at: u64, + pub valid_after: u64, + pub valid_until: u64, pub max_calls: u32, } @@ -140,6 +131,21 @@ pub mod UA2Account { pub new_owner: felt252, } + #[derive(Drop, starknet::Event)] + pub struct GuardianProposed { + pub guardian: ContractAddress, + pub proposal_id: u64, + pub new_owner: felt252, + pub eta: u64, + } + + #[derive(Drop, starknet::Event)] + pub struct GuardianFinalized { + pub guardian: ContractAddress, + pub proposal_id: u64, + pub new_owner: felt252, + } + #[derive(Drop, starknet::Event)] pub struct RecoveryProposed { pub new_owner: felt252, @@ -184,6 +190,8 @@ pub mod UA2Account { ThresholdSet: ThresholdSet, RecoveryDelaySet: RecoveryDelaySet, OwnerRotated: OwnerRotated, + GuardianProposed: GuardianProposed, + GuardianFinalized: GuardianFinalized, RecoveryProposed: RecoveryProposed, RecoveryConfirmed: RecoveryConfirmed, RecoveryCanceled: RecoveryCanceled, @@ -193,6 +201,7 @@ pub mod UA2Account { #[constructor] fn constructor(ref self: ContractState, public_key: felt252) { self.owner_pubkey.write(public_key); + self.session_owner_epoch.write(0_u64); self.account.initializer(public_key); } @@ -204,7 +213,7 @@ pub mod UA2Account { fn assert_owner() { let caller: ContractAddress = get_caller_address(); let contract_address: ContractAddress = get_contract_address(); - assert(caller == contract_address, 'NOT_OWNER'); + assert(caller == contract_address, ERR_NOT_OWNER); } fn require(condition: bool, error: felt252) { @@ -232,6 +241,24 @@ pub mod UA2Account { self.recovery_confirm_count.write(0_u32); } + fn _bump_owner_epoch(ref self: ContractState) { + let current_epoch = self.session_owner_epoch.read(); + self.session_owner_epoch.write(current_epoch + 1_u64); + } + + fn _record_recovery_confirmation( + ref self: ContractState, guardian: ContractAddress, proposal_id: u64, last_confirm: u64, + ) -> u32 { + assert(last_confirm != proposal_id, ERR_ALREADY_CONFIRMED); + + self.recovery_guardian_last_confirm.write(guardian, proposal_id); + + let new_count = self.recovery_confirm_count.read() + 1_u32; + self.recovery_confirm_count.write(new_count); + + new_count + } + #[external(v0)] fn add_guardian(ref self: ContractState, addr: ContractAddress) { assert_owner(); @@ -307,18 +334,23 @@ pub mod UA2Account { let last_confirm = self.recovery_guardian_last_confirm.read(caller); if last_confirm != proposal_id { - self.recovery_confirms.write(caller, true); - self.recovery_guardian_last_confirm.write(caller, proposal_id); - self.recovery_confirm_count.write(1_u32); - self.emit( - Event::RecoveryConfirmed(RecoveryConfirmed { - guardian: caller, - new_owner, - count: 1_u32, - }), + let confirm_count = _record_recovery_confirmation( + ref self, caller, proposal_id, last_confirm, ); + self + .emit( + Event::RecoveryConfirmed( + RecoveryConfirmed { guardian: caller, new_owner, count: confirm_count }, + ), + ); } + self + .emit( + Event::GuardianProposed( + GuardianProposed { guardian: caller, proposal_id, new_owner, eta }, + ), + ); self.emit(Event::RecoveryProposed(RecoveryProposed { new_owner, eta })); } @@ -335,27 +367,14 @@ pub mod UA2Account { let proposal_id = self.recovery_proposal_id.read(); let last_confirm = self.recovery_guardian_last_confirm.read(caller); + let new_count = _record_recovery_confirmation(ref self, caller, proposal_id, last_confirm); - if last_confirm != proposal_id { - self.recovery_confirms.write(caller, false); - } - - let already_confirmed = self.recovery_confirms.read(caller); - assert(already_confirmed == false, ERR_ALREADY_CONFIRMED); - - self.recovery_confirms.write(caller, true); - self.recovery_guardian_last_confirm.write(caller, proposal_id); - - let new_count = self.recovery_confirm_count.read() + 1_u32; - self.recovery_confirm_count.write(new_count); - - self.emit( - Event::RecoveryConfirmed(RecoveryConfirmed { - guardian: caller, - new_owner, - count: new_count, - }), - ); + self + .emit( + Event::RecoveryConfirmed( + RecoveryConfirmed { guardian: caller, new_owner, count: new_count }, + ), + ); } #[external(v0)] @@ -383,11 +402,13 @@ pub mod UA2Account { assert(new_owner != current, ERR_SAME_OWNER); self.owner_pubkey.write(new_owner); + _bump_owner_epoch(ref self); self.emit(Event::OwnerRotated(OwnerRotated { new_owner })); } #[external(v0)] fn execute_recovery(ref self: ContractState) { + let caller = get_caller_address(); let active = self.recovery_active.read(); assert(active == true, ERR_NO_RECOVERY); @@ -401,12 +422,20 @@ pub mod UA2Account { assert(now >= eta, ERR_BEFORE_ETA); let new_owner = self.recovery_proposed_owner.read(); + let proposal_id = self.recovery_proposal_id.read(); self.owner_pubkey.write(new_owner); + _bump_owner_epoch(ref self); _clear_recovery_state(ref self); self.emit(Event::OwnerRotated(OwnerRotated { new_owner })); self.emit(Event::RecoveryExecuted(RecoveryExecuted { new_owner })); + self + .emit( + Event::GuardianFinalized( + GuardianFinalized { guardian: caller, proposal_id, new_owner }, + ), + ); } fn u256_le(lhs: u256, rhs: u256) -> bool { @@ -420,32 +449,61 @@ pub mod UA2Account { } #[external(v0)] - fn add_session_with_allowlists( - ref self: ContractState, - key: felt252, - mut policy: SessionPolicy, - targets: Array, - selectors: Array, - ) { + fn add_session_with_allowlists(ref self: ContractState, session: Session) { assert_owner(); - self.add_session(key, policy); - let key_hash = derive_key_hash(key); + let Session { + pubkey, + valid_after, + valid_until, + max_calls, + value_cap, + targets_len, + targets, + selectors_len, + selectors, + } = session; + + let declared_targets_len: usize = match targets_len.try_into() { + Option::Some(value) => value, + Option::None(_) => { + assert(false, ERR_SESSION_TARGETS_LEN); + 0_usize + }, + }; + let actual_targets_len = ArrayTrait::::len(@targets); + require(actual_targets_len == declared_targets_len, ERR_SESSION_TARGETS_LEN); - let mut i = 0_usize; - let targets_len = ArrayTrait::::len(@targets); - while i < targets_len { - let target = *ArrayTrait::::at(@targets, i); - self.session_target_allow.write((key_hash, target), true); - i += 1_usize; + let declared_selectors_len: usize = match selectors_len.try_into() { + Option::Some(value) => value, + Option::None(_) => { + assert(false, ERR_SESSION_SELECTORS_LEN); + 0_usize + }, + }; + let actual_selectors_len = ArrayTrait::::len(@selectors); + require(actual_selectors_len == declared_selectors_len, ERR_SESSION_SELECTORS_LEN); + + let mut policy = SessionPolicy { + is_active: false, + valid_after, + valid_until, + max_calls, + calls_used: 0_u32, + max_value_per_call: value_cap, + owner_epoch: 0_u64, + }; + + self.add_session(pubkey, policy); + + let key_hash = derive_key_hash(pubkey); + + for target_ref in targets.span() { + self.session_target_allow.write((key_hash, *target_ref), true); } - i = 0_usize; - let selectors_len = ArrayTrait::::len(@selectors); - while i < selectors_len { - let selector = *ArrayTrait::::at(@selectors, i); - self.session_selector_allow.write((key_hash, selector), true); - i += 1_usize; + for selector_ref in selectors.span() { + self.session_selector_allow.write((key_hash, *selector_ref), true); } } @@ -460,9 +518,12 @@ pub mod UA2Account { let mut policy = self.session.read(key_hash); require(policy.is_active, ERR_SESSION_INACTIVE); + let current_epoch = self.session_owner_epoch.read(); + require(policy.owner_epoch == current_epoch, ERR_SESSION_STALE); let now = get_block_timestamp(); - require(now <= policy.expires_at, ERR_SESSION_EXPIRED); + require(now >= policy.valid_after, ERR_SESSION_NOT_READY); + require(now <= policy.valid_until, ERR_SESSION_EXPIRED); require(policy.calls_used == prior_calls_used, ERR_POLICY_CALLCOUNT_MISMATCH); @@ -487,18 +548,29 @@ pub mod UA2Account { fn add_session(ref self: ContractState, key: felt252, mut policy: SessionPolicy) { assert_owner(); - assert(policy.expires_at > 0_u64, 'BAD_EXPIRY'); - assert(policy.max_calls > 0_u32, 'BAD_MAX_CALLS'); + assert(policy.valid_until > policy.valid_after, ERR_BAD_VALID_WINDOW); + assert(policy.max_calls > 0_u32, ERR_BAD_MAX_CALLS); let key_hash = derive_key_hash(key); policy.is_active = true; policy.calls_used = 0_u32; + policy.owner_epoch = self.session_owner_epoch.read(); self.session.write(key_hash, policy); self.session_nonce.write(key_hash, 0_u128); - self.emit(Event::SessionAdded(SessionAdded { key_hash, expires_at: policy.expires_at, max_calls: policy.max_calls })); + self + .emit( + Event::SessionAdded( + SessionAdded { + key_hash, + valid_after: policy.valid_after, + valid_until: policy.valid_until, + max_calls: policy.max_calls, + }, + ), + ); } fn get_session(self: @ContractState, key_hash: felt252) -> SessionPolicy { @@ -538,6 +610,39 @@ pub mod UA2Account { sum } + fn enforce_value_limit(calldata: Span, amount_index: usize, limit: u256) { + let required_len = amount_index + 2_usize; + let calldata_len = calldata.len(); + require(calldata_len >= required_len, ERR_VALUE_LIMIT_EXCEEDED); + + let amount_low_felt = *calldata.at(amount_index); + let amount_high_felt = *calldata.at(amount_index + 1_usize); + + let amount_low: u128 = match amount_low_felt.try_into() { + Option::Some(value) => value, + Option::None(_) => { + assert(false, ERR_VALUE_LIMIT_EXCEEDED); + 0_u128 + }, + }; + + let amount_high: u128 = match amount_high_felt.try_into() { + Option::Some(value) => value, + Option::None(_) => { + assert(false, ERR_VALUE_LIMIT_EXCEEDED); + 0_u128 + }, + }; + + let amount = u256 { low: amount_low, high: amount_high }; + + if amount.high > limit.high { + assert(false, ERR_VALUE_LIMIT_EXCEEDED); + } else if amount.high == limit.high { + assert(amount.low <= limit.low, ERR_VALUE_LIMIT_EXCEEDED); + } + } + fn poseidon_chain(acc: felt252, value: felt252) -> felt252 { let mut values = array![acc, value]; poseidon_hash_span(values.span()) @@ -579,6 +684,40 @@ pub mod UA2Account { low + high * high_limit } + fn bool_to_u32(value: bool) -> u32 { + if value { + 1_u32 + } else { + 0_u32 + } + } + + fn calls_are_recovery_only(calls: @Array) -> bool { + let calls_len = ArrayTrait::::len(calls); + if calls_len == 0_usize { + return false; + } + + let contract_address = get_contract_address(); + + for call_ref in calls.span() { + let Call { to, selector, calldata: _ } = *call_ref; + + if to != contract_address { + return false; + } + + let allowed = selector == CONFIRM_RECOVERY_SELECTOR + || selector == EXECUTE_RECOVERY_SELECTOR + || selector == CANCEL_RECOVERY_SELECTOR; + if allowed == false { + return false; + } + } + + true + } + fn extract_owner_signature(signature: Span) -> Array { let signature_len = signature.len(); require(signature_len > 0_usize, ERR_SIGNATURE_MISSING); @@ -605,9 +744,7 @@ pub mod UA2Account { } fn validate_guardian_authorization( - self: @ContractState, - signature: Span, - calls: @Array, + self: @ContractState, signature: Span, calls: @Array, ) { let signature_len = signature.len(); require(signature_len == 2_usize, ERR_GUARDIAN_SIG_INVALID); @@ -639,7 +776,8 @@ pub mod UA2Account { let allowed = selector == PROPOSE_RECOVERY_SELECTOR || selector == CONFIRM_RECOVERY_SELECTOR - || selector == EXECUTE_RECOVERY_SELECTOR; + || selector == EXECUTE_RECOVERY_SELECTOR + || selector == CANCEL_RECOVERY_SELECTOR; require(allowed, ERR_GUARDIAN_CALL_DENIED); } } @@ -647,25 +785,21 @@ pub mod UA2Account { fn compute_session_message_hash( chain_id: felt252, account_felt: felt252, + session_pubkey: felt252, key_hash: felt252, - nonce: u128, call_digest: felt252, + valid_until: u64, + nonce: u128, ) -> felt252 { let mut values = array![ - SESSION_DOMAIN_TAG, - chain_id, - account_felt, - key_hash, - nonce.into(), - call_digest, + SESSION_DOMAIN_TAG, chain_id, account_felt, session_pubkey, key_hash, call_digest, + valid_until.into(), nonce.into(), ]; poseidon_hash_span(values.span()) } fn validate_session_policy( - self: @ContractState, - signature: Span, - calls: @Array, + self: @ContractState, signature: Span, calls: @Array, ) -> SessionValidation { let signature_len = signature.len(); require(signature_len >= 6_usize, ERR_SESSION_SIG_INVALID); @@ -699,9 +833,12 @@ pub mod UA2Account { let policy = self.session.read(key_hash); require(policy.is_active, ERR_SESSION_INACTIVE); + let current_epoch = self.session_owner_epoch.read(); + require(policy.owner_epoch == current_epoch, ERR_SESSION_STALE); let now = get_block_timestamp(); - require(now <= policy.expires_at, ERR_SESSION_EXPIRED); + require(now >= policy.valid_after, ERR_SESSION_NOT_READY); + require(now <= policy.valid_until, ERR_SESSION_EXPIRED); let calls_len = ArrayTrait::::len(calls); let tx_call_count: u32 = match calls_len.try_into() { @@ -725,36 +862,9 @@ pub mod UA2Account { assert(selector_allowed == true, ERR_POLICY_SELECTOR_DENIED); if selector == ERC20_TRANSFER_SEL { - let calldata_len = calldata.len(); - require(calldata_len >= 3_usize, ERR_VALUE_LIMIT_EXCEEDED); - - let amount_low_felt = *calldata.at(1_usize); - let amount_high_felt = *calldata.at(2_usize); - - let amount_low: u128 = match amount_low_felt.try_into() { - Option::Some(value) => value, - Option::None(_) => { - assert(false, ERR_VALUE_LIMIT_EXCEEDED); - 0_u128 - }, - }; - - let amount_high: u128 = match amount_high_felt.try_into() { - Option::Some(value) => value, - Option::None(_) => { - assert(false, ERR_VALUE_LIMIT_EXCEEDED); - 0_u128 - }, - }; - - let amount = u256 { low: amount_low, high: amount_high }; - let limit = policy.max_value_per_call; - - if amount.high > limit.high { - assert(false, ERR_VALUE_LIMIT_EXCEEDED); - } else if amount.high == limit.high { - assert(amount.low <= limit.low, ERR_VALUE_LIMIT_EXCEEDED); - } + enforce_value_limit(calldata, 1_usize, policy.max_value_per_call); + } else if selector == ERC20_TRANSFER_FROM_SEL { + enforce_value_limit(calldata, 2_usize, policy.max_value_per_call); } } @@ -770,9 +880,11 @@ pub mod UA2Account { let message = compute_session_message_hash( chain_id, account_felt, + session_key, key_hash, - provided_nonce, call_digest, + policy.valid_until, + provided_nonce, ); let sig_r = *signature.at(4_usize); @@ -811,7 +923,7 @@ pub mod UA2Account { APPLY_SESSION_USAGE_SELECTOR, accounting_calldata.span(), ) - .unwrap_syscall(); + .unwrap_syscall(); return; } @@ -824,9 +936,9 @@ pub mod UA2Account { let owner_signature = extract_owner_signature(signature); let owner_signature_span = owner_signature.span(); - let owner_valid = AccountComponent::InternalImpl::::_is_valid_signature( - self.account, tx_hash, owner_signature_span - ); + let owner_valid = AccountComponent::InternalImpl::< + ContractState, + >::_is_valid_signature(self.account, tx_hash, owner_signature_span); require(owner_valid, ERR_OWNER_SIG_INVALID); @@ -840,23 +952,38 @@ pub mod UA2Account { let signature_len = signature.len(); require(signature_len > 0_usize, ERR_SIGNATURE_MISSING); - let mode = *signature.at(0_usize); + let recovery_active = self.recovery_active.read(); + if recovery_active { + let allowed = calls_are_recovery_only(@calls); + assert(allowed, ERR_RECOVERY_IN_PROGRESS); + return starknet::VALIDATED; + } - if mode == MODE_SESSION { + let mode_hint = *signature.at(0_usize); + let using_session = mode_hint == MODE_SESSION; + let using_guardian = mode_hint == MODE_GUARDIAN; + let using_owner = mode_hint == MODE_OWNER || (!using_session && !using_guardian); + + let mut modes = bool_to_u32(using_owner); + modes += bool_to_u32(using_session); + modes += bool_to_u32(using_guardian); + require(modes == 1_u32, ERR_UNSUPPORTED_AUTH_MODE); + + if using_session { let _validation = validate_session_policy(self, signature, @calls); return starknet::VALIDATED; } - if mode == MODE_GUARDIAN { + if using_guardian { validate_guardian_authorization(self, signature, @calls); return starknet::VALIDATED; } let owner_signature = extract_owner_signature(signature); let owner_signature_span = owner_signature.span(); - let owner_valid = AccountComponent::InternalImpl::::_is_valid_signature( - self.account, tx_hash, owner_signature_span - ); + let owner_valid = AccountComponent::InternalImpl::< + ContractState, + >::_is_valid_signature(self.account, tx_hash, owner_signature_span); require(owner_valid, ERR_OWNER_SIG_INVALID); @@ -866,21 +993,21 @@ pub mod UA2Account { fn is_valid_signature( self: @ContractState, hash: felt252, signature: Array, ) -> felt252 { - AccountComponent::AccountMixinImpl::::is_valid_signature( - self, hash, signature - ) + AccountComponent::AccountMixinImpl::< + ContractState, + >::is_valid_signature(self, hash, signature) } fn supports_interface(self: @ContractState, interface_id: felt252) -> bool { - AccountComponent::AccountMixinImpl::::supports_interface( - self, interface_id - ) + AccountComponent::AccountMixinImpl::< + ContractState, + >::supports_interface(self, interface_id) } fn __validate_declare__(self: @ContractState, class_hash: felt252) -> felt252 { - AccountComponent::AccountMixinImpl::::__validate_declare__( - self, class_hash - ) + AccountComponent::AccountMixinImpl::< + ContractState, + >::__validate_declare__(self, class_hash) } fn __validate_deploy__( @@ -889,9 +1016,9 @@ pub mod UA2Account { contract_address_salt: felt252, public_key: felt252, ) -> felt252 { - AccountComponent::AccountMixinImpl::::__validate_deploy__( - self, class_hash, contract_address_salt, public_key - ) + AccountComponent::AccountMixinImpl::< + ContractState, + >::__validate_deploy__(self, class_hash, contract_address_salt, public_key) } fn get_public_key(self: @ContractState) -> felt252 { @@ -901,29 +1028,27 @@ pub mod UA2Account { fn set_public_key( ref self: ContractState, new_public_key: felt252, signature: Span, ) { - AccountComponent::AccountMixinImpl::::set_public_key( - ref self, new_public_key, signature - ); + AccountComponent::AccountMixinImpl::< + ContractState, + >::set_public_key(ref self, new_public_key, signature); } fn isValidSignature( self: @ContractState, hash: felt252, signature: Array, ) -> felt252 { - AccountComponent::AccountMixinImpl::::isValidSignature( - self, hash, signature - ) + AccountComponent::AccountMixinImpl::< + ContractState, + >::isValidSignature(self, hash, signature) } fn getPublicKey(self: @ContractState) -> felt252 { AccountComponent::AccountMixinImpl::::getPublicKey(self) } - fn setPublicKey( - ref self: ContractState, newPublicKey: felt252, signature: Span, - ) { - AccountComponent::AccountMixinImpl::::setPublicKey( - ref self, newPublicKey, signature - ); + fn setPublicKey(ref self: ContractState, newPublicKey: felt252, signature: Span) { + AccountComponent::AccountMixinImpl::< + ContractState, + >::setPublicKey(ref self, newPublicKey, signature); } } impl AccountInternalImpl = AccountComponent::InternalImpl; diff --git a/packages/contracts/tests/lib.cairo b/packages/contracts/tests/lib.cairo index 1b30edd..c9fdf68 100644 --- a/packages/contracts/tests/lib.cairo +++ b/packages/contracts/tests/lib.cairo @@ -72,8 +72,10 @@ pub mod session_test_utils { fn compute_message( account_address: ContractAddress, + session_pubkey: felt252, key_hash: felt252, nonce: u128, + valid_until: u64, calls: @Array, ) -> felt252 { let execution_info = get_execution_info().unbox(); @@ -85,9 +87,11 @@ pub mod session_test_utils { SESSION_DOMAIN_TAG, chain_id, account_felt, + session_pubkey, key_hash, - nonce.into(), call_digest, + valid_until.into(), + nonce.into(), ]; poseidon_hash_span(values.span()) } @@ -108,10 +112,11 @@ pub mod session_test_utils { account_address: ContractAddress, session_pubkey: felt252, nonce: u128, + valid_until: u64, calls: @Array, ) -> Array { let key_hash = pedersen(session_pubkey, 0); - let message = compute_message(account_address, key_hash, nonce, calls); + let message = compute_message(account_address, session_pubkey, key_hash, nonce, valid_until, calls); let key_pair = session_keypair(); let signature = StarkCurveSignerImpl::sign(key_pair, message); let (r, s) = match signature { diff --git a/packages/contracts/tests/test_guardians_admin.cairo b/packages/contracts/tests/test_guardians_admin.cairo index 3b1c545..1764e17 100644 --- a/packages/contracts/tests/test_guardians_admin.cairo +++ b/packages/contracts/tests/test_guardians_admin.cairo @@ -19,11 +19,9 @@ use ua2_contracts::ua2_account::UA2Account::{ RecoveryDelaySet, ThresholdSet, }; +use ua2_contracts::errors::{ERR_BAD_THRESHOLD, ERR_GUARDIAN_EXISTS, ERR_NOT_GUARDIAN}; const OWNER_PUBKEY: felt252 = 0x12345; -const ERR_GUARDIAN_EXISTS: felt252 = 'ERR_GUARDIAN_EXISTS'; -const ERR_BAD_THRESHOLD: felt252 = 'ERR_BAD_THRESHOLD'; -const ERR_NOT_GUARDIAN: felt252 = 'ERR_NOT_GUARDIAN'; fn deploy_account() -> ContractAddress { let declare_result = declare("UA2Account").unwrap(); diff --git a/packages/contracts/tests/test_owner_rotate.cairo b/packages/contracts/tests/test_owner_rotate.cairo index 4241c2d..d616d51 100644 --- a/packages/contracts/tests/test_owner_rotate.cairo +++ b/packages/contracts/tests/test_owner_rotate.cairo @@ -1,23 +1,24 @@ -use core::array::{ArrayTrait, SpanTrait}; +use core::array::{Array, ArrayTrait, SpanTrait}; +use core::integer::u256; use core::result::ResultTrait; +use core::serde::Serde; +use core::traits::{Into, TryInto}; use snforge_std::{ - declare, - spy_events, - start_cheat_caller_address, - stop_cheat_caller_address, - ContractClassTrait, - DeclareResultTrait, - EventSpyAssertionsTrait, + ContractClassTrait, DeclareResultTrait, EventSpyAssertionsTrait, declare, spy_events, + start_cheat_caller_address, start_cheat_signature, stop_cheat_caller_address, + stop_cheat_signature, }; -use starknet::{ContractAddress, SyscallResult, SyscallResultTrait}; +use starknet::account::Call; use starknet::syscalls::call_contract_syscall; -use ua2_contracts::ua2_account::UA2Account::{Event, OwnerRotated}; +use starknet::{ContractAddress, SyscallResult, SyscallResultTrait}; +use ua2_contracts::errors::{ERR_NOT_OWNER, ERR_SAME_OWNER, ERR_SESSION_STALE, ERR_ZERO_OWNER}; +use ua2_contracts::session::Session; +use ua2_contracts::ua2_account::UA2Account::{Event, OwnerRotated, SessionPolicy}; +use crate::session_test_utils::{build_session_signature, session_key}; const OWNER_PUBKEY: felt252 = 0x111; const NEW_OWNER: felt252 = 0x222; -const ERR_ZERO_OWNER: felt252 = 'ERR_ZERO_OWNER'; -const ERR_SAME_OWNER: felt252 = 'ERR_SAME_OWNER'; -const ERR_NOT_OWNER: felt252 = 'NOT_OWNER'; +const TRANSFER_SELECTOR: felt252 = starknet::selector!("transfer"); fn deploy_account() -> ContractAddress { let declare_result = declare("UA2Account").unwrap(); @@ -26,21 +27,93 @@ fn deploy_account() -> ContractAddress { contract_address } +fn deploy_account_and_mock() -> (ContractAddress, ContractAddress) { + let account_address = deploy_account(); + + let mock_declare = declare("MockERC20").unwrap(); + let mock_class = mock_declare.contract_class(); + let (mock_address, _) = mock_class.deploy(@array![]).unwrap_syscall(); + + (account_address, mock_address) +} + fn call_with_felt( - contract_address: ContractAddress, - selector: felt252, - value: felt252, + contract_address: ContractAddress, selector: felt252, value: felt252, ) -> SyscallResult> { let mut calldata = array![]; calldata.append(value); call_contract_syscall(contract_address, selector, calldata.span()) } +fn add_session( + account_address: ContractAddress, + session_pubkey: felt252, + mock_address: ContractAddress, + policy: SessionPolicy, +) { + start_cheat_caller_address(account_address, account_address); + + let mut targets: Array = array![]; + targets.append(mock_address); + + let mut selectors: Array = array![]; + selectors.append(TRANSFER_SELECTOR); + + let session = Session { + pubkey: session_pubkey, + valid_after: policy.valid_after, + valid_until: policy.valid_until, + max_calls: policy.max_calls, + value_cap: policy.max_value_per_call, + targets_len: 1_u32, + targets, + selectors_len: 1_u32, + selectors, + }; + + let mut calldata = array![]; + Serde::::serialize(@session, ref calldata); + + call_contract_syscall( + account_address, starknet::selector!("add_session_with_allowlists"), calldata.span(), + ) + .unwrap_syscall(); + + stop_cheat_caller_address(account_address); +} + +fn build_transfer_call(mock_address: ContractAddress, to: ContractAddress, amount: u256) -> Call { + let mut calldata = array![]; + calldata.append(to.into()); + calldata.append(amount.low.into()); + calldata.append(amount.high.into()); + + Call { to: mock_address, selector: TRANSFER_SELECTOR, calldata: calldata.span() } +} + +fn execute_session_call( + account_address: ContractAddress, calls: @Array, signature: @Array, +) -> SyscallResult> { + let zero_contract: ContractAddress = 0.try_into().unwrap(); + start_cheat_caller_address(account_address, zero_contract); + start_cheat_signature(account_address, signature.span()); + + let mut execute_calldata = array![]; + Serde::>::serialize(calls, ref execute_calldata); + + let result = call_contract_syscall( + account_address, starknet::selector!("__execute__"), execute_calldata.span(), + ); + + stop_cheat_signature(account_address); + stop_cheat_caller_address(account_address); + + result +} + fn assert_reverted_with(result: SyscallResult>, expected: felt252) { match result { - Result::Ok(_) => { - assert(false, 'expected revert'); - }, + Result::Ok(_) => { assert(false, 'expected revert'); }, Result::Err(panic_data) => { let data = panic_data.span(); assert(data.len() > 0_usize, 'missing panic data'); @@ -56,30 +129,22 @@ fn owner_rotation_happy() { let mut spy = spy_events(); start_cheat_caller_address(contract_address, contract_address); - call_with_felt( - contract_address, - starknet::selector!("rotate_owner"), - NEW_OWNER, - ) - .unwrap_syscall(); + call_with_felt(contract_address, starknet::selector!("rotate_owner"), NEW_OWNER) + .unwrap_syscall(); stop_cheat_caller_address(contract_address); let empty = array![]; let owner_result = call_contract_syscall( - contract_address, - starknet::selector!("get_owner"), - empty.span(), + contract_address, starknet::selector!("get_owner"), empty.span(), ) - .unwrap_syscall(); + .unwrap_syscall(); let owner = *owner_result.at(0_usize); assert(owner == NEW_OWNER, 'owner not rotated'); - spy.assert_emitted(@array![ - ( - contract_address, - Event::OwnerRotated(OwnerRotated { new_owner: NEW_OWNER }), - ), - ]); + spy + .assert_emitted( + @array![(contract_address, Event::OwnerRotated(OwnerRotated { new_owner: NEW_OWNER }))], + ); } #[test] @@ -87,17 +152,11 @@ fn owner_rotation_rejects_zero_and_same() { let contract_address = deploy_account(); start_cheat_caller_address(contract_address, contract_address); - let zero_owner = call_with_felt( - contract_address, - starknet::selector!("rotate_owner"), - 0, - ); + let zero_owner = call_with_felt(contract_address, starknet::selector!("rotate_owner"), 0); assert_reverted_with(zero_owner, ERR_ZERO_OWNER); let same_owner = call_with_felt( - contract_address, - starknet::selector!("rotate_owner"), - OWNER_PUBKEY, + contract_address, starknet::selector!("rotate_owner"), OWNER_PUBKEY, ); assert_reverted_with(same_owner, ERR_SAME_OWNER); stop_cheat_caller_address(contract_address); @@ -107,20 +166,54 @@ fn owner_rotation_rejects_zero_and_same() { fn non_owner_cannot_rotate() { let contract_address = deploy_account(); - let result = call_with_felt( - contract_address, - starknet::selector!("rotate_owner"), - 0xBBB, - ); + let result = call_with_felt(contract_address, starknet::selector!("rotate_owner"), 0xBBB); assert_reverted_with(result, ERR_NOT_OWNER); let empty = array![]; let owner_result = call_contract_syscall( - contract_address, - starknet::selector!("get_owner"), - empty.span(), + contract_address, starknet::selector!("get_owner"), empty.span(), ) - .unwrap_syscall(); + .unwrap_syscall(); let owner = *owner_result.at(0_usize); assert(owner == OWNER_PUBKEY, 'owner should remain original'); } + +#[test] +fn test_rotate_owner_revokes_sessions() { + let (account_address, mock_address) = deploy_account_and_mock(); + + let session_pubkey = session_key(); + let policy = SessionPolicy { + is_active: true, + valid_after: 0_u64, + valid_until: 10_000_u64, + max_calls: 5_u32, + calls_used: 0_u32, + max_value_per_call: u256 { low: 10_000_u128, high: 0_u128 }, + owner_epoch: 0_u64, + }; + + add_session(account_address, session_pubkey, mock_address, policy); + + let to: ContractAddress = account_address; + let amount = u256 { low: 1_000_u128, high: 0_u128 }; + let call = build_transfer_call(mock_address, to, amount); + let calls = array![call]; + + let first_signature: Array = build_session_signature( + account_address, session_pubkey, 0_u128, policy.valid_until, @calls, + ); + execute_session_call(account_address, @calls, @first_signature).unwrap_syscall(); + + start_cheat_caller_address(account_address, account_address); + call_with_felt(account_address, starknet::selector!("rotate_owner"), NEW_OWNER) + .unwrap_syscall(); + stop_cheat_caller_address(account_address); + + let second_signature: Array = build_session_signature( + account_address, session_pubkey, 1_u128, policy.valid_until, @calls, + ); + let result = execute_session_call(account_address, @calls, @second_signature); + + assert_reverted_with(result, ERR_SESSION_STALE); +} diff --git a/packages/contracts/tests/test_recovery_edgecases.cairo b/packages/contracts/tests/test_recovery_edgecases.cairo index 2350e45..d20937b 100644 --- a/packages/contracts/tests/test_recovery_edgecases.cairo +++ b/packages/contracts/tests/test_recovery_edgecases.cairo @@ -3,25 +3,39 @@ use core::result::ResultTrait; use core::traits::{Into, TryInto}; use snforge_std::{ declare, + spy_events, start_cheat_block_timestamp, stop_cheat_block_timestamp, start_cheat_caller_address, stop_cheat_caller_address, ContractClassTrait, DeclareResultTrait, + EventSpyAssertionsTrait, }; use starknet::{ContractAddress, SyscallResult, SyscallResultTrait}; use starknet::syscalls::call_contract_syscall; +use ua2_contracts::errors::{ + ERR_ALREADY_CONFIRMED, + ERR_BEFORE_ETA, + ERR_NO_RECOVERY, + ERR_NOT_ENOUGH_CONFIRMS, + ERR_RECOVERY_IN_PROGRESS, + ERR_RECOVERY_MISMATCH, +}; +use ua2_contracts::ua2_account::UA2Account::{ + Event, + GuardianFinalized, + GuardianProposed, + OwnerRotated, + RecoveryCanceled, + RecoveryConfirmed, + RecoveryExecuted, + RecoveryProposed, +}; const OWNER_PUBKEY: felt252 = 0x12345; const RECOVERY_OWNER_A: felt252 = 0xAAA111; const RECOVERY_OWNER_B: felt252 = 0xBBB222; -const ERR_NOT_ENOUGH_CONFIRMS: felt252 = 'ERR_NOT_ENOUGH_CONFIRMS'; -const ERR_RECOVERY_IN_PROGRESS: felt252 = 'ERR_RECOVERY_IN_PROGRESS'; -const ERR_RECOVERY_MISMATCH: felt252 = 'ERR_RECOVERY_MISMATCH'; -const ERR_ALREADY_CONFIRMED: felt252 = 'ERR_ALREADY_CONFIRMED'; -const ERR_BEFORE_ETA: felt252 = 'ERR_BEFORE_ETA'; -const ERR_NO_RECOVERY: felt252 = 'ERR_NO_RECOVERY'; fn deploy_account() -> ContractAddress { let declare_result = declare("UA2Account").unwrap(); @@ -237,3 +251,403 @@ fn recovery_edge_cases() { stop_cheat_block_timestamp(contract_address); } + +#[test] +fn recovery_cancel_emits_event() { + let contract_address = deploy_account(); + let g1: ContractAddress = 0x111.try_into().unwrap(); + + start_cheat_caller_address(contract_address, contract_address); + let mut add_calldata = array![]; + add_calldata.append(g1.into()); + call_contract_syscall( + contract_address, + starknet::selector!("add_guardian"), + add_calldata.span(), + ) + .unwrap_syscall(); + + let mut threshold_calldata = array![]; + threshold_calldata.append(1_u8.into()); + call_contract_syscall( + contract_address, + starknet::selector!("set_guardian_threshold"), + threshold_calldata.span(), + ) + .unwrap_syscall(); + + let mut delay_calldata = array![]; + delay_calldata.append(0_u64.into()); + call_contract_syscall( + contract_address, + starknet::selector!("set_recovery_delay"), + delay_calldata.span(), + ) + .unwrap_syscall(); + stop_cheat_caller_address(contract_address); + + let mut spy = spy_events(); + + start_cheat_block_timestamp(contract_address, 10_u64); + + start_cheat_caller_address(contract_address, g1); + let mut propose_calldata = array![]; + propose_calldata.append(RECOVERY_OWNER_A); + call_contract_syscall( + contract_address, + starknet::selector!("propose_recovery"), + propose_calldata.span(), + ) + .unwrap_syscall(); + stop_cheat_caller_address(contract_address); + + start_cheat_caller_address(contract_address, contract_address); + let empty = array![]; + call_contract_syscall( + contract_address, + starknet::selector!("cancel_recovery"), + empty.span(), + ) + .unwrap_syscall(); + stop_cheat_caller_address(contract_address); + + stop_cheat_block_timestamp(contract_address); + + spy.assert_emitted(@array![ + ( + contract_address, + Event::RecoveryConfirmed(RecoveryConfirmed { + guardian: g1, + new_owner: RECOVERY_OWNER_A, + count: 1_u32, + }), + ), + ( + contract_address, + Event::GuardianProposed(GuardianProposed { + guardian: g1, + proposal_id: 1_u64, + new_owner: RECOVERY_OWNER_A, + eta: 10_u64, + }), + ), + ( + contract_address, + Event::RecoveryProposed(RecoveryProposed { + new_owner: RECOVERY_OWNER_A, + eta: 10_u64, + }), + ), + ( + contract_address, + Event::RecoveryCanceled(RecoveryCanceled {}), + ), + ]); +} + +#[test] +fn test_guardian_timelock_enforced() { + let contract_address = deploy_account(); + let g1: ContractAddress = 0x111.try_into().unwrap(); + + start_cheat_caller_address(contract_address, contract_address); + let mut add_calldata = array![]; + add_calldata.append(g1.into()); + call_contract_syscall( + contract_address, + starknet::selector!("add_guardian"), + add_calldata.span(), + ) + .unwrap_syscall(); + + let mut threshold_calldata = array![]; + threshold_calldata.append(1_u8.into()); + call_contract_syscall( + contract_address, + starknet::selector!("set_guardian_threshold"), + threshold_calldata.span(), + ) + .unwrap_syscall(); + + let mut delay_calldata = array![]; + delay_calldata.append(500_u64.into()); + call_contract_syscall( + contract_address, + starknet::selector!("set_recovery_delay"), + delay_calldata.span(), + ) + .unwrap_syscall(); + stop_cheat_caller_address(contract_address); + + start_cheat_block_timestamp(contract_address, 100_u64); + + start_cheat_caller_address(contract_address, g1); + let mut propose_calldata = array![]; + propose_calldata.append(RECOVERY_OWNER_A); + call_contract_syscall( + contract_address, + starknet::selector!("propose_recovery"), + propose_calldata.span(), + ) + .unwrap_syscall(); + + let mut execute_calldata = array![]; + let before_eta = call_contract_syscall( + contract_address, + starknet::selector!("execute_recovery"), + execute_calldata.span(), + ); + stop_cheat_caller_address(contract_address); + + stop_cheat_block_timestamp(contract_address); + + assert_reverted_with(before_eta, ERR_BEFORE_ETA); +} + +#[test] +fn test_guardian_finalize() { + let contract_address = deploy_account(); + let g1: ContractAddress = 0x111.try_into().unwrap(); + let g2: ContractAddress = 0x222.try_into().unwrap(); + + start_cheat_caller_address(contract_address, contract_address); + let mut add_one = array![]; + add_one.append(g1.into()); + call_contract_syscall( + contract_address, + starknet::selector!("add_guardian"), + add_one.span(), + ) + .unwrap_syscall(); + + let mut add_two = array![]; + add_two.append(g2.into()); + call_contract_syscall( + contract_address, + starknet::selector!("add_guardian"), + add_two.span(), + ) + .unwrap_syscall(); + + let mut threshold_calldata = array![]; + threshold_calldata.append(2_u8.into()); + call_contract_syscall( + contract_address, + starknet::selector!("set_guardian_threshold"), + threshold_calldata.span(), + ) + .unwrap_syscall(); + + let mut delay_calldata = array![]; + delay_calldata.append(0_u64.into()); + call_contract_syscall( + contract_address, + starknet::selector!("set_recovery_delay"), + delay_calldata.span(), + ) + .unwrap_syscall(); + stop_cheat_caller_address(contract_address); + + let mut spy = spy_events(); + + start_cheat_block_timestamp(contract_address, 500_u64); + + start_cheat_caller_address(contract_address, g1); + let mut propose_calldata = array![]; + propose_calldata.append(RECOVERY_OWNER_A); + call_contract_syscall( + contract_address, + starknet::selector!("propose_recovery"), + propose_calldata.span(), + ) + .unwrap_syscall(); + stop_cheat_caller_address(contract_address); + + start_cheat_caller_address(contract_address, g2); + let mut confirm_calldata = array![]; + confirm_calldata.append(RECOVERY_OWNER_A); + call_contract_syscall( + contract_address, + starknet::selector!("confirm_recovery"), + confirm_calldata.span(), + ) + .unwrap_syscall(); + stop_cheat_caller_address(contract_address); + + start_cheat_caller_address(contract_address, g1); + let mut execute_calldata = array![]; + call_contract_syscall( + contract_address, + starknet::selector!("execute_recovery"), + execute_calldata.span(), + ) + .unwrap_syscall(); + stop_cheat_caller_address(contract_address); + + stop_cheat_block_timestamp(contract_address); + + let mut empty = array![]; + let owner_result = call_contract_syscall( + contract_address, + starknet::selector!("get_owner"), + empty.span(), + ) + .unwrap_syscall(); + let owner = *owner_result.at(0_usize); + assert(owner == RECOVERY_OWNER_A, 'owner not rotated by guardians'); + + spy.assert_emitted(@array![ + ( + contract_address, + Event::RecoveryConfirmed(RecoveryConfirmed { + guardian: g1, + new_owner: RECOVERY_OWNER_A, + count: 1_u32, + }), + ), + ( + contract_address, + Event::GuardianProposed(GuardianProposed { + guardian: g1, + proposal_id: 1_u64, + new_owner: RECOVERY_OWNER_A, + eta: 500_u64, + }), + ), + ( + contract_address, + Event::RecoveryProposed(RecoveryProposed { + new_owner: RECOVERY_OWNER_A, + eta: 500_u64, + }), + ), + ( + contract_address, + Event::RecoveryConfirmed(RecoveryConfirmed { + guardian: g2, + new_owner: RECOVERY_OWNER_A, + count: 2_u32, + }), + ), + ( + contract_address, + Event::OwnerRotated(OwnerRotated { new_owner: RECOVERY_OWNER_A }), + ), + ( + contract_address, + Event::RecoveryExecuted(RecoveryExecuted { new_owner: RECOVERY_OWNER_A }), + ), + ( + contract_address, + Event::GuardianFinalized(GuardianFinalized { + guardian: g1, + proposal_id: 1_u64, + new_owner: RECOVERY_OWNER_A, + }), + ), + ]); +} + +#[test] +fn test_guardian_cancel() { + let contract_address = deploy_account(); + let g1: ContractAddress = 0x111.try_into().unwrap(); + + start_cheat_caller_address(contract_address, contract_address); + let mut add_calldata = array![]; + add_calldata.append(g1.into()); + call_contract_syscall( + contract_address, + starknet::selector!("add_guardian"), + add_calldata.span(), + ) + .unwrap_syscall(); + + let mut threshold_calldata = array![]; + threshold_calldata.append(1_u8.into()); + call_contract_syscall( + contract_address, + starknet::selector!("set_guardian_threshold"), + threshold_calldata.span(), + ) + .unwrap_syscall(); + + let mut delay_calldata = array![]; + delay_calldata.append(10_u64.into()); + call_contract_syscall( + contract_address, + starknet::selector!("set_recovery_delay"), + delay_calldata.span(), + ) + .unwrap_syscall(); + stop_cheat_caller_address(contract_address); + + let mut spy = spy_events(); + + start_cheat_block_timestamp(contract_address, 42_u64); + + start_cheat_caller_address(contract_address, g1); + let mut propose_calldata = array![]; + propose_calldata.append(RECOVERY_OWNER_A); + call_contract_syscall( + contract_address, + starknet::selector!("propose_recovery"), + propose_calldata.span(), + ) + .unwrap_syscall(); + stop_cheat_caller_address(contract_address); + + start_cheat_caller_address(contract_address, contract_address); + let mut cancel_calldata = array![]; + call_contract_syscall( + contract_address, + starknet::selector!("cancel_recovery"), + cancel_calldata.span(), + ) + .unwrap_syscall(); + stop_cheat_caller_address(contract_address); + + start_cheat_caller_address(contract_address, g1); + let mut execute_calldata = array![]; + let no_recovery = call_contract_syscall( + contract_address, + starknet::selector!("execute_recovery"), + execute_calldata.span(), + ); + stop_cheat_caller_address(contract_address); + + stop_cheat_block_timestamp(contract_address); + + assert_reverted_with(no_recovery, ERR_NO_RECOVERY); + + spy.assert_emitted(@array![ + ( + contract_address, + Event::RecoveryConfirmed(RecoveryConfirmed { + guardian: g1, + new_owner: RECOVERY_OWNER_A, + count: 1_u32, + }), + ), + ( + contract_address, + Event::GuardianProposed(GuardianProposed { + guardian: g1, + proposal_id: 1_u64, + new_owner: RECOVERY_OWNER_A, + eta: 52_u64, + }), + ), + ( + contract_address, + Event::RecoveryProposed(RecoveryProposed { + new_owner: RECOVERY_OWNER_A, + eta: 52_u64, + }), + ), + ( + contract_address, + Event::RecoveryCanceled(RecoveryCanceled {}), + ), + ]); +} diff --git a/packages/contracts/tests/test_recovery_happy.cairo b/packages/contracts/tests/test_recovery_happy.cairo index 8470cb1..3a5bfd8 100644 --- a/packages/contracts/tests/test_recovery_happy.cairo +++ b/packages/contracts/tests/test_recovery_happy.cairo @@ -17,6 +17,8 @@ use starknet::syscalls::call_contract_syscall; use ua2_contracts::ua2_account::UA2Account::{ Event, GuardianAdded, + GuardianFinalized, + GuardianProposed, OwnerRotated, RecoveryConfirmed, RecoveryDelaySet, @@ -24,10 +26,10 @@ use ua2_contracts::ua2_account::UA2Account::{ RecoveryProposed, ThresholdSet, }; +use ua2_contracts::errors::ERR_NO_RECOVERY; const OWNER_PUBKEY: felt252 = 0x12345; const NEW_OWNER: felt252 = 0xABCDEF0123; -const ERR_NO_RECOVERY: felt252 = 'ERR_NO_RECOVERY'; fn deploy_account() -> ContractAddress { let declare_result = declare("UA2Account").unwrap(); @@ -159,6 +161,15 @@ fn recovery_happy_path() { count: 1_u32, }), ), + ( + contract_address, + Event::GuardianProposed(GuardianProposed { + guardian: g1, + proposal_id: 1_u64, + new_owner: NEW_OWNER, + eta: 100_u64, + }), + ), ( contract_address, Event::RecoveryProposed(RecoveryProposed { @@ -186,5 +197,13 @@ fn recovery_happy_path() { new_owner: NEW_OWNER, }), ), + ( + contract_address, + Event::GuardianFinalized(GuardianFinalized { + guardian: g1, + proposal_id: 1_u64, + new_owner: NEW_OWNER, + }), + ), ]); } diff --git a/packages/contracts/tests/test_rotate_vs_recovery.cairo b/packages/contracts/tests/test_rotate_vs_recovery.cairo index fc19466..f09b0d0 100644 --- a/packages/contracts/tests/test_rotate_vs_recovery.cairo +++ b/packages/contracts/tests/test_rotate_vs_recovery.cairo @@ -10,9 +10,9 @@ use snforge_std::{ }; use starknet::{ContractAddress, SyscallResult, SyscallResultTrait}; use starknet::syscalls::call_contract_syscall; +use ua2_contracts::errors::ERR_RECOVERY_IN_PROGRESS; const OWNER_PUBKEY: felt252 = 0x111; -const ERR_RECOVERY_IN_PROGRESS: felt252 = 'ERR_RECOVERY_IN_PROGRESS'; const NEW_RECOVERY_OWNER: felt252 = 0xDEAD; const ROTATED_OWNER: felt252 = 0xBEEF; diff --git a/packages/contracts/tests/test_session_nonce_ok.cairo b/packages/contracts/tests/test_session_nonce_ok.cairo index 6857d93..eb3d8e4 100644 --- a/packages/contracts/tests/test_session_nonce_ok.cairo +++ b/packages/contracts/tests/test_session_nonce_ok.cairo @@ -2,37 +2,20 @@ use core::array::{Array, ArrayTrait}; use core::integer::u256; use core::serde::Serde; use core::traits::Into; - use snforge_std::{ - declare, - spy_events, - start_cheat_block_timestamp, - stop_cheat_block_timestamp, - start_cheat_caller_address, - stop_cheat_caller_address, - start_cheat_signature, - stop_cheat_signature, - ContractClassTrait, - DeclareResultTrait, - EventSpyAssertionsTrait, + ContractClassTrait, DeclareResultTrait, EventSpyAssertionsTrait, declare, spy_events, + start_cheat_block_timestamp, start_cheat_caller_address, start_cheat_signature, + stop_cheat_block_timestamp, stop_cheat_caller_address, stop_cheat_signature, }; use starknet::account::Call; use starknet::syscalls::call_contract_syscall; use starknet::{ContractAddress, SyscallResultTrait}; +use ua2_contracts::session::Session; use ua2_contracts::ua2_account::UA2Account::{ - Event, - ISessionManagerDispatcher, - ISessionManagerDispatcherTrait, - SessionNonceAdvanced, - SessionPolicy, - SessionUsed, -}; - -use crate::session_test_utils::{ - build_session_signature, - session_key, - session_key_hash, + Event, ISessionManagerDispatcher, ISessionManagerDispatcherTrait, SessionNonceAdvanced, + SessionPolicy, SessionUsed, }; +use crate::session_test_utils::{build_session_signature, session_key, session_key_hash}; const OWNER_PUBKEY: felt252 = 0x12345; const TRANSFER_SELECTOR: felt252 = starknet::selector!("transfer"); @@ -57,27 +40,31 @@ fn add_session( ) { start_cheat_caller_address(account_address, account_address); + let mut targets: Array = array![]; + targets.append(mock_address); + + let mut selectors: Array = array![]; + selectors.append(TRANSFER_SELECTOR); + + let session = Session { + pubkey: session_pubkey, + valid_after: policy.valid_after, + valid_until: policy.valid_until, + max_calls: policy.max_calls, + value_cap: policy.max_value_per_call, + targets_len: 1_u32, + targets, + selectors_len: 1_u32, + selectors, + }; + let mut calldata = array![]; - calldata.append(session_pubkey); - let active_flag: felt252 = if policy.is_active { 1 } else { 0 }; - calldata.append(active_flag); - calldata.append(policy.expires_at.into()); - calldata.append(policy.max_calls.into()); - calldata.append(policy.calls_used.into()); - calldata.append(policy.max_value_per_call.low.into()); - calldata.append(policy.max_value_per_call.high.into()); - - calldata.append(1.into()); - calldata.append(mock_address.into()); - calldata.append(1.into()); - calldata.append(TRANSFER_SELECTOR); + Serde::::serialize(@session, ref calldata); call_contract_syscall( - account_address, - starknet::selector!("add_session_with_allowlists"), - calldata.span(), + account_address, starknet::selector!("add_session_with_allowlists"), calldata.span(), ) - .unwrap_syscall(); + .unwrap_syscall(); stop_cheat_caller_address(account_address); } @@ -92,9 +79,7 @@ fn build_transfer_call(mock_address: ContractAddress, to: ContractAddress, amoun } fn execute_with_signature( - account_address: ContractAddress, - calls: @Array, - signature: @Array, + account_address: ContractAddress, calls: @Array, signature: @Array, ) { let zero_contract: ContractAddress = 0.try_into().unwrap(); start_cheat_caller_address(account_address, zero_contract); @@ -104,11 +89,9 @@ fn execute_with_signature( Serde::>::serialize(calls, ref execute_calldata); call_contract_syscall( - account_address, - starknet::selector!("__execute__"), - execute_calldata.span(), + account_address, starknet::selector!("__execute__"), execute_calldata.span(), ) - .unwrap_syscall(); + .unwrap_syscall(); stop_cheat_signature(account_address); stop_cheat_caller_address(account_address); @@ -123,10 +106,12 @@ fn test_session_nonce_ok() { let key_hash = session_key_hash(); let policy = SessionPolicy { is_active: true, - expires_at: 10_000_u64, + valid_after: 0_u64, + valid_until: 10_000_u64, max_calls: 5_u32, calls_used: 0_u32, max_value_per_call: u256 { low: 10_000_u128, high: 0_u128 }, + owner_epoch: 0_u64, }; add_session(account_address, session_pubkey, mock_address, policy); @@ -138,36 +123,40 @@ fn test_session_nonce_ok() { let call = build_transfer_call(mock_address, to, amount); let calls = array![call]; - let signature0: Array = - build_session_signature(account_address, session_pubkey, 0_u128, @calls); + let signature0: Array = build_session_signature( + account_address, session_pubkey, 0_u128, policy.valid_until, @calls, + ); execute_with_signature(account_address, @calls, @signature0); - let signature1: Array = - build_session_signature(account_address, session_pubkey, 1_u128, @calls); + let signature1: Array = build_session_signature( + account_address, session_pubkey, 1_u128, policy.valid_until, @calls, + ); execute_with_signature(account_address, @calls, @signature1); stop_cheat_block_timestamp(account_address); - spy.assert_emitted(@array![ - ( - account_address, - Event::SessionUsed(SessionUsed { key_hash, used: 1_u32 }), - ), - ( - account_address, - Event::SessionNonceAdvanced(SessionNonceAdvanced { key_hash, new_nonce: 1_u128 }), - ), - ( - account_address, - Event::SessionUsed(SessionUsed { key_hash, used: 1_u32 }), - ), - ( - account_address, - Event::SessionNonceAdvanced(SessionNonceAdvanced { key_hash, new_nonce: 2_u128 }), - ), - ]); + spy + .assert_emitted( + @array![ + (account_address, Event::SessionUsed(SessionUsed { key_hash, used: 1_u32 })), + ( + account_address, + Event::SessionNonceAdvanced( + SessionNonceAdvanced { key_hash, new_nonce: 1_u128 }, + ), + ), + (account_address, Event::SessionUsed(SessionUsed { key_hash, used: 1_u32 })), + ( + account_address, + Event::SessionNonceAdvanced( + SessionNonceAdvanced { key_hash, new_nonce: 2_u128 }, + ), + ), + ], + ); let dispatcher = ISessionManagerDispatcher { contract_address: account_address }; let updated_policy = dispatcher.get_session(key_hash); assert(updated_policy.calls_used == 2_u32, 'unexpected call count'); + assert(updated_policy.owner_epoch == 0_u64, 'unexpected session epoch'); } diff --git a/packages/contracts/tests/test_session_nonce_replay_and_mismatch.cairo b/packages/contracts/tests/test_session_nonce_replay_and_mismatch.cairo index 5be05a3..26f42ff 100644 --- a/packages/contracts/tests/test_session_nonce_replay_and_mismatch.cairo +++ b/packages/contracts/tests/test_session_nonce_replay_and_mismatch.cairo @@ -1,31 +1,24 @@ use core::array::{Array, ArrayTrait}; use core::integer::u256; +use core::result::Result; use core::serde::Serde; use core::traits::{Into, TryInto}; -use core::result::Result; - use snforge_std::{ - declare, - start_cheat_block_timestamp, - stop_cheat_block_timestamp, - start_cheat_caller_address, - stop_cheat_caller_address, - start_cheat_signature, - stop_cheat_signature, - ContractClassTrait, - DeclareResultTrait, + ContractClassTrait, DeclareResultTrait, declare, start_cheat_block_timestamp, + start_cheat_caller_address, start_cheat_signature, stop_cheat_block_timestamp, + stop_cheat_caller_address, stop_cheat_signature, }; use starknet::account::Call; use starknet::syscalls::call_contract_syscall; use starknet::{ContractAddress, SyscallResult, SyscallResultTrait}; +use ua2_contracts::errors::{ERR_BAD_SESSION_NONCE, ERR_SESSION_SIG_INVALID}; +use ua2_contracts::session::Session; use ua2_contracts::ua2_account::UA2Account::SessionPolicy; - use crate::session_test_utils::{build_session_signature, session_key}; const OWNER_PUBKEY: felt252 = 0x12345; const TRANSFER_SELECTOR: felt252 = starknet::selector!("transfer"); -const ERR_BAD_SESSION_NONCE: felt252 = 'ERR_BAD_SESSION_NONCE'; -const ERR_SESSION_SIG_INVALID: felt252 = 'ERR_SESSION_SIG_INVALID'; +const ALT_SESSION_PUBKEY: felt252 = 0xABCDEF12345; fn deploy_account_and_mock() -> (ContractAddress, ContractAddress) { let account_declare = declare("UA2Account").unwrap(); @@ -47,27 +40,31 @@ fn add_session( ) { start_cheat_caller_address(account_address, account_address); + let mut targets: Array = array![]; + targets.append(mock_address); + + let mut selectors: Array = array![]; + selectors.append(TRANSFER_SELECTOR); + + let session = Session { + pubkey: session_pubkey, + valid_after: policy.valid_after, + valid_until: policy.valid_until, + max_calls: policy.max_calls, + value_cap: policy.max_value_per_call, + targets_len: 1_u32, + targets, + selectors_len: 1_u32, + selectors, + }; + let mut calldata = array![]; - calldata.append(session_pubkey); - let active_flag: felt252 = if policy.is_active { 1 } else { 0 }; - calldata.append(active_flag); - calldata.append(policy.expires_at.into()); - calldata.append(policy.max_calls.into()); - calldata.append(policy.calls_used.into()); - calldata.append(policy.max_value_per_call.low.into()); - calldata.append(policy.max_value_per_call.high.into()); - - calldata.append(1.into()); - calldata.append(mock_address.into()); - calldata.append(1.into()); - calldata.append(TRANSFER_SELECTOR); + Serde::::serialize(@session, ref calldata); call_contract_syscall( - account_address, - starknet::selector!("add_session_with_allowlists"), - calldata.span(), + account_address, starknet::selector!("add_session_with_allowlists"), calldata.span(), ) - .unwrap_syscall(); + .unwrap_syscall(); stop_cheat_caller_address(account_address); } @@ -82,9 +79,7 @@ fn build_transfer_call(mock_address: ContractAddress, to: ContractAddress, amoun } fn execute_with_signature( - account_address: ContractAddress, - calls: @Array, - signature: @Array, + account_address: ContractAddress, calls: @Array, signature: @Array, ) -> SyscallResult> { let zero_contract: ContractAddress = 0.try_into().unwrap(); start_cheat_caller_address(account_address, zero_contract); @@ -94,9 +89,7 @@ fn execute_with_signature( Serde::>::serialize(calls, ref execute_calldata); let result = call_contract_syscall( - account_address, - starknet::selector!("__execute__"), - execute_calldata.span(), + account_address, starknet::selector!("__execute__"), execute_calldata.span(), ); stop_cheat_signature(account_address); @@ -107,9 +100,7 @@ fn execute_with_signature( fn assert_reverted_with(result: SyscallResult>, expected: felt252) { match result { - Result::Ok(_) => { - assert(false, 'expected revert'); - }, + Result::Ok(_) => { assert(false, 'expected revert'); }, Result::Err(panic_data) => { let panic_span = panic_data.span(); assert(panic_span.len() > 0_usize, 'missing panic data'); @@ -126,10 +117,12 @@ fn test_session_nonce_replay_and_mismatch() { let session_pubkey = session_key(); let policy = SessionPolicy { is_active: true, - expires_at: 10_000_u64, + valid_after: 0_u64, + valid_until: 10_000_u64, max_calls: 5_u32, calls_used: 0_u32, max_value_per_call: u256 { low: 10_000_u128, high: 0_u128 }, + owner_epoch: 0_u64, }; add_session(account_address, session_pubkey, mock_address, policy); @@ -141,24 +134,28 @@ fn test_session_nonce_replay_and_mismatch() { let call = build_transfer_call(mock_address, to, amount); let calls = array![call]; - let signature0: Array = - build_session_signature(account_address, session_pubkey, 0_u128, @calls); + let signature0: Array = build_session_signature( + account_address, session_pubkey, 0_u128, policy.valid_until, @calls, + ); execute_with_signature(account_address, @calls, @signature0).unwrap_syscall(); let replay_result = execute_with_signature(account_address, @calls, @signature0); assert_reverted_with(replay_result, ERR_BAD_SESSION_NONCE); - let skip_signature: Array = - build_session_signature(account_address, session_pubkey, 2_u128, @calls); + let skip_signature: Array = build_session_signature( + account_address, session_pubkey, 2_u128, policy.valid_until, @calls, + ); let skip_result = execute_with_signature(account_address, @calls, @skip_signature); assert_reverted_with(skip_result, ERR_BAD_SESSION_NONCE); - let signature1: Array = - build_session_signature(account_address, session_pubkey, 1_u128, @calls); + let signature1: Array = build_session_signature( + account_address, session_pubkey, 1_u128, policy.valid_until, @calls, + ); execute_with_signature(account_address, @calls, @signature1).unwrap_syscall(); - let signature2: Array = - build_session_signature(account_address, session_pubkey, 2_u128, @calls); + let signature2: Array = build_session_signature( + account_address, session_pubkey, 2_u128, policy.valid_until, @calls, + ); let tampered_amount = u256 { low: 1_001_u128, high: 0_u128 }; let tampered_call = build_transfer_call(mock_address, to, tampered_amount); @@ -169,3 +166,53 @@ fn test_session_nonce_replay_and_mismatch() { stop_cheat_block_timestamp(account_address); } + +#[test] +fn cross_session_signature_reuse_fails() { + let (account_address, mock_address) = deploy_account_and_mock(); + + let policy = SessionPolicy { + is_active: true, + valid_after: 0_u64, + valid_until: 10_000_u64, + max_calls: 5_u32, + calls_used: 0_u32, + max_value_per_call: u256 { low: 10_000_u128, high: 0_u128 }, + owner_epoch: 0_u64, + }; + + let session_pubkey_a = session_key(); + add_session(account_address, session_pubkey_a, mock_address, policy); + add_session(account_address, ALT_SESSION_PUBKEY, mock_address, policy); + + start_cheat_block_timestamp(account_address, 5_000_u64); + + let to: ContractAddress = account_address; + let amount = u256 { low: 1_000_u128, high: 0_u128 }; + let call = build_transfer_call(mock_address, to, amount); + let calls = array![call]; + + let signature_a: Array = build_session_signature( + account_address, + session_pubkey_a, + 0_u128, + policy.valid_until, + @calls, + ); + + let mut swapped_signature = ArrayTrait::::new(); + let mut index = 0_usize; + for felt_ref in signature_a.span() { + if index == 1_usize { + swapped_signature.append(ALT_SESSION_PUBKEY); + } else { + swapped_signature.append(*felt_ref); + } + index += 1_usize; + } + + let result = execute_with_signature(account_address, @calls, @swapped_signature); + assert_reverted_with(result, ERR_SESSION_SIG_INVALID); + + stop_cheat_block_timestamp(account_address); +} diff --git a/packages/contracts/tests/test_sessions.cairo b/packages/contracts/tests/test_sessions.cairo index 2a485ee..8e40d42 100644 --- a/packages/contracts/tests/test_sessions.cairo +++ b/packages/contracts/tests/test_sessions.cairo @@ -1,28 +1,31 @@ +use core::array::{Array, ArrayTrait, SpanTrait}; use core::integer::u256; +use core::option::Option; +use core::pedersen::pedersen; use core::result::ResultTrait; +use core::serde::Serde; +use core::traits::{Into, TryInto}; use snforge_std::{ - declare, - spy_events, - start_cheat_caller_address, - stop_cheat_caller_address, - ContractClassTrait, - DeclareResultTrait, - EventSpyAssertionsTrait, + ContractClassTrait, DeclareResultTrait, EventSpyAssertionsTrait, declare, spy_events, + start_cheat_block_timestamp, start_cheat_caller_address, start_cheat_signature, + stop_cheat_block_timestamp, stop_cheat_caller_address, stop_cheat_signature, }; -use starknet::SyscallResultTrait; -use ua2_contracts::ua2_account::UA2Account::{ - self, - SessionAdded, - SessionPolicy, - SessionRevoked, +use starknet::account::Call; +use starknet::syscalls::call_contract_syscall; +use starknet::{ContractAddress, SyscallResult, SyscallResultTrait}; +use ua2_contracts::errors::{ + ERR_POLICY_SELECTOR_DENIED, ERR_POLICY_TARGET_DENIED, ERR_SESSION_EXPIRED, + ERR_VALUE_LIMIT_EXCEEDED, }; +use ua2_contracts::session::Session; use ua2_contracts::ua2_account::UA2Account::{ - ISessionManagerDispatcher, - ISessionManagerDispatcherTrait, + self, ISessionManagerDispatcher, ISessionManagerDispatcherTrait, SessionAdded, SessionPolicy, + SessionRevoked, }; -use core::pedersen::pedersen; +use crate::session_test_utils::{build_session_signature, session_key}; const OWNER_PUBKEY: felt252 = 0x12345; +const TRANSFER_SELECTOR: felt252 = starknet::selector!("transfer"); fn deploy_account() -> (starknet::ContractAddress, ISessionManagerDispatcher) { let declare_result = declare("UA2Account").unwrap(); @@ -32,8 +35,117 @@ fn deploy_account() -> (starknet::ContractAddress, ISessionManagerDispatcher) { (contract_address, dispatcher) } +fn deploy_account_and_mock() -> (ContractAddress, ISessionManagerDispatcher, ContractAddress) { + let (account_address, dispatcher) = deploy_account(); + + let mock_declare = declare("MockERC20").unwrap(); + let mock_class = mock_declare.contract_class(); + let (mock_address, _) = mock_class.deploy(@array![]).unwrap_syscall(); + + (account_address, dispatcher, mock_address) +} + +fn add_session_allowlist( + account_address: ContractAddress, + session_pubkey: felt252, + policy: SessionPolicy, + mut targets: Array, + mut selectors: Array, +) -> felt252 { + start_cheat_caller_address(account_address, account_address); + + let targets_len_usize = targets.len(); + let targets_len: u32 = match targets_len_usize.try_into() { + Option::Some(value) => value, + Option::None(_) => { + assert(false, 'bad targets len'); + 0_u32 + }, + }; + + let selectors_len_usize = selectors.len(); + let selectors_len: u32 = match selectors_len_usize.try_into() { + Option::Some(value) => value, + Option::None(_) => { + assert(false, 'bad selectors len'); + 0_u32 + }, + }; + + let session = Session { + pubkey: session_pubkey, + valid_after: policy.valid_after, + valid_until: policy.valid_until, + max_calls: policy.max_calls, + value_cap: policy.max_value_per_call, + targets_len, + targets, + selectors_len, + selectors, + }; + + let mut calldata = array![]; + Serde::::serialize(@session, ref calldata); + + call_contract_syscall( + account_address, + starknet::selector!("add_session_with_allowlists"), + calldata.span(), + ) + .unwrap_syscall(); + + stop_cheat_caller_address(account_address); + + pedersen(session_pubkey, 0) +} + +fn build_transfer_call(mock_address: ContractAddress, to: ContractAddress, amount: u256) -> Call { + let mut calldata = array![]; + calldata.append(to.into()); + calldata.append(amount.low.into()); + calldata.append(amount.high.into()); + + Call { to: mock_address, selector: TRANSFER_SELECTOR, calldata: calldata.span() } +} + +fn execute_session_call( + account_address: ContractAddress, + calls: @Array, + signature: @Array, +) -> SyscallResult> { + let zero: ContractAddress = 0.try_into().unwrap(); + start_cheat_caller_address(account_address, zero); + start_cheat_signature(account_address, signature.span()); + + let mut execute_calldata = array![]; + Serde::>::serialize(calls, ref execute_calldata); + + let result = call_contract_syscall( + account_address, + starknet::selector!("__execute__"), + execute_calldata.span(), + ); + + stop_cheat_signature(account_address); + stop_cheat_caller_address(account_address); + + result +} + +fn assert_reverted_with(result: SyscallResult>, expected: felt252) { + match result { + Result::Ok(_) => { assert(false, 'expected revert'); }, + Result::Err(panic_data) => { + let data = panic_data.span(); + assert(data.len() > 0_usize, 'missing panic data'); + let actual = *data.at(0_usize); + assert(actual == expected, 'unexpected revert reason'); + }, + } +} + #[test] -fn add_get_revoke_session_works() { +fn test_session_add_ok() { let (contract_address, dispatcher) = deploy_account(); start_cheat_caller_address(contract_address, contract_address); @@ -42,19 +154,22 @@ fn add_get_revoke_session_works() { let key_hash = pedersen(key, 0); let policy = SessionPolicy { is_active: false, - expires_at: 3_600_u64, + valid_after: 0_u64, + valid_until: 3_600_u64, max_calls: 5_u32, calls_used: 2_u32, max_value_per_call: u256 { low: 0, high: 0 }, + owner_epoch: 0_u64, }; dispatcher.add_session(key, policy); let stored_policy = dispatcher.get_session(key_hash); assert(stored_policy.is_active == true, 'session inactive'); - assert(stored_policy.expires_at == 3_600_u64, 'expiry mismatch'); + assert(stored_policy.valid_until == 3_600_u64, 'expiry mismatch'); assert(stored_policy.max_calls == 5_u32, 'max calls mismatch'); assert(stored_policy.calls_used == 0_u32, 'calls used not reset'); + assert(stored_policy.owner_epoch == 0_u64, 'unexpected session epoch'); dispatcher.revoke_session(key_hash); @@ -75,10 +190,12 @@ fn events_emitted() { let key_hash = pedersen(key, 0); let policy = SessionPolicy { is_active: true, - expires_at: 7_200_u64, + valid_after: 0_u64, + valid_until: 7_200_u64, max_calls: 10_u32, calls_used: 0_u32, max_value_per_call: u256 { low: 0, high: 0 }, + owner_epoch: 0_u64, }; dispatcher.add_session(key, policy); @@ -86,18 +203,193 @@ fn events_emitted() { stop_cheat_caller_address(contract_address); - spy.assert_emitted(@array![ - ( - contract_address, - UA2Account::Event::SessionAdded(SessionAdded { - key_hash, - expires_at: 7_200_u64, - max_calls: 10_u32, - }), - ), - ( - contract_address, - UA2Account::Event::SessionRevoked(SessionRevoked { key_hash }), - ), - ]); + spy + .assert_emitted( + @array![ + ( + contract_address, + UA2Account::Event::SessionAdded( + SessionAdded { + key_hash, valid_after: 0_u64, valid_until: 7_200_u64, max_calls: 10_u32, + }, + ), + ), + (contract_address, UA2Account::Event::SessionRevoked(SessionRevoked { key_hash })), + ], + ); +} + +#[test] +fn test_session_expired_rejects() { + let (account_address, _, mock_address) = deploy_account_and_mock(); + + let valid_after = 0_u64; + let valid_until = 100_u64; + let policy = SessionPolicy { + is_active: true, + valid_after, + valid_until, + max_calls: 3_u32, + calls_used: 0_u32, + max_value_per_call: u256 { low: 1_000_u128, high: 0_u128 }, + owner_epoch: 0_u64, + }; + + let session_pubkey = session_key(); + let mut targets = array![mock_address]; + let mut selectors = array![TRANSFER_SELECTOR]; + add_session_allowlist(account_address, session_pubkey, policy, targets, selectors); + + start_cheat_block_timestamp(account_address, 1_000_u64); + + let to: ContractAddress = account_address; + let amount = u256 { low: 100_u128, high: 0_u128 }; + let call = build_transfer_call(mock_address, to, amount); + let calls = array![call]; + + let signature: Array = build_session_signature( + account_address, + session_pubkey, + 0_u128, + valid_until, + @calls, + ); + + let result = execute_session_call(account_address, @calls, @signature); + + stop_cheat_block_timestamp(account_address); + + assert_reverted_with(result, ERR_SESSION_EXPIRED); +} + +#[test] +fn test_session_selector_denied() { + let (account_address, _, mock_address) = deploy_account_and_mock(); + + let valid_after = 0_u64; + let valid_until = 5_000_u64; + let policy = SessionPolicy { + is_active: true, + valid_after, + valid_until, + max_calls: 5_u32, + calls_used: 0_u32, + max_value_per_call: u256 { low: 10_000_u128, high: 0_u128 }, + owner_epoch: 0_u64, + }; + + let session_pubkey = session_key(); + let mut targets = array![mock_address]; + let mut selectors = array![TRANSFER_SELECTOR]; + add_session_allowlist(account_address, session_pubkey, policy, targets, selectors); + + start_cheat_block_timestamp(account_address, 1_000_u64); + + let mut calldata = array![]; + let selector = starknet::selector!("get_last"); + let call = Call { to: mock_address, selector, calldata: calldata.span() }; + let calls = array![call]; + + let signature: Array = build_session_signature( + account_address, + session_pubkey, + 0_u128, + valid_until, + @calls, + ); + + let result = execute_session_call(account_address, @calls, @signature); + + stop_cheat_block_timestamp(account_address); + + assert_reverted_with(result, ERR_POLICY_SELECTOR_DENIED); +} + +#[test] +fn test_session_target_denied() { + let (account_address, _, mock_address) = deploy_account_and_mock(); + + let valid_after = 0_u64; + let valid_until = 5_000_u64; + let policy = SessionPolicy { + is_active: true, + valid_after, + valid_until, + max_calls: 5_u32, + calls_used: 0_u32, + max_value_per_call: u256 { low: 10_000_u128, high: 0_u128 }, + owner_epoch: 0_u64, + }; + + let session_pubkey = session_key(); + let mut targets = array![mock_address]; + let mut selectors = array![TRANSFER_SELECTOR]; + add_session_allowlist(account_address, session_pubkey, policy, targets, selectors); + + start_cheat_block_timestamp(account_address, 1_000_u64); + + let mut calldata = array![]; + let call = Call { + to: account_address, + selector: starknet::selector!("get_owner"), + calldata: calldata.span(), + }; + let calls = array![call]; + + let signature: Array = build_session_signature( + account_address, + session_pubkey, + 0_u128, + valid_until, + @calls, + ); + + let result = execute_session_call(account_address, @calls, @signature); + + stop_cheat_block_timestamp(account_address); + + assert_reverted_with(result, ERR_POLICY_TARGET_DENIED); +} + +#[test] +fn test_session_value_cap() { + let (account_address, _, mock_address) = deploy_account_and_mock(); + + let valid_after = 0_u64; + let valid_until = 5_000_u64; + let policy = SessionPolicy { + is_active: true, + valid_after, + valid_until, + max_calls: 5_u32, + calls_used: 0_u32, + max_value_per_call: u256 { low: 1_000_u128, high: 0_u128 }, + owner_epoch: 0_u64, + }; + + let session_pubkey = session_key(); + let mut targets = array![mock_address]; + let mut selectors = array![TRANSFER_SELECTOR]; + add_session_allowlist(account_address, session_pubkey, policy, targets, selectors); + + start_cheat_block_timestamp(account_address, 1_000_u64); + + let to: ContractAddress = account_address; + let amount = u256 { low: 5_000_u128, high: 0_u128 }; + let call = build_transfer_call(mock_address, to, amount); + let calls = array![call]; + + let signature: Array = build_session_signature( + account_address, + session_pubkey, + 0_u128, + valid_until, + @calls, + ); + + let result = execute_session_call(account_address, @calls, @signature); + + stop_cheat_block_timestamp(account_address); + + assert_reverted_with(result, ERR_VALUE_LIMIT_EXCEEDED); } diff --git a/packages/contracts/tests/test_validate_allowlists.cairo b/packages/contracts/tests/test_validate_allowlists.cairo index 6db77f0..f9e4d0c 100644 --- a/packages/contracts/tests/test_validate_allowlists.cairo +++ b/packages/contracts/tests/test_validate_allowlists.cairo @@ -1,41 +1,50 @@ use core::array::{Array, ArrayTrait, SpanTrait}; use core::integer::u256; use core::option::Option; -use core::traits::{Into, TryInto}; use core::serde::Serde; +use core::traits::{Into, TryInto}; use snforge_std::{ - declare, - spy_events, - start_cheat_block_timestamp, - stop_cheat_block_timestamp, - start_cheat_caller_address, - stop_cheat_caller_address, - start_cheat_signature, - stop_cheat_signature, - ContractClassTrait, - DeclareResultTrait, - EventSpyAssertionsTrait, + ContractClassTrait, DeclareResultTrait, EventSpyAssertionsTrait, declare, spy_events, + start_cheat_block_timestamp, start_cheat_caller_address, start_cheat_signature, + stop_cheat_block_timestamp, stop_cheat_caller_address, stop_cheat_signature, }; use starknet::account::Call; use starknet::syscalls::call_contract_syscall; use starknet::{ContractAddress, SyscallResultTrait}; +use ua2_contracts::session::Session; use ua2_contracts::ua2_account::UA2Account::{ - Event, - ISessionManagerDispatcher, - ISessionManagerDispatcherTrait, - SessionNonceAdvanced, - SessionPolicy, - SessionUsed, -}; - -use crate::session_test_utils::{ - build_session_signature, - session_key, - session_key_hash, + Event, ISessionManagerDispatcher, ISessionManagerDispatcherTrait, SessionNonceAdvanced, + SessionPolicy, SessionUsed, }; +use crate::session_test_utils::{build_session_signature, session_key, session_key_hash}; const OWNER_PUBKEY: felt252 = 0x12345; const TRANSFER_SELECTOR: felt252 = starknet::selector!("transfer"); +const TRANSFER_FROM_SELECTOR: felt252 = starknet::selector!("transferFrom"); + +fn build_transfer_call(mock_address: ContractAddress, to: ContractAddress, amount: u256) -> Call { + let mut calldata = array![]; + calldata.append(to.into()); + calldata.append(amount.low.into()); + calldata.append(amount.high.into()); + + Call { to: mock_address, selector: TRANSFER_SELECTOR, calldata: calldata.span() } +} + +fn build_transfer_from_call( + mock_address: ContractAddress, + from: ContractAddress, + to: ContractAddress, + amount: u256, +) -> Call { + let mut calldata = array![]; + calldata.append(from.into()); + calldata.append(to.into()); + calldata.append(amount.low.into()); + calldata.append(amount.high.into()); + + Call { to: mock_address, selector: TRANSFER_FROM_SELECTOR, calldata: calldata.span() } +} #[test] fn session_allows_whitelisted_calls() { @@ -48,45 +57,56 @@ fn session_allows_whitelisted_calls() { let mock_class = mock_declare.contract_class(); let (mock_address, _) = mock_class.deploy(@array![]).unwrap_syscall(); - let expires_at = 10_000_u64; + let valid_after = 0_u64; + let valid_until = 10_000_u64; let policy = SessionPolicy { is_active: true, - expires_at, + valid_after, + valid_until, max_calls: 1_u32, calls_used: 0_u32, max_value_per_call: u256 { low: 1_000_u128, high: 0_u128 }, + owner_epoch: 0_u64, }; start_cheat_block_timestamp(account_address, 5_000_u64); start_cheat_caller_address(account_address, account_address); - let mut allowlist_calldata = array![]; let session_pubkey = session_key(); let key_hash = session_key_hash(); - allowlist_calldata.append(session_pubkey); - allowlist_calldata.append(1.into()); - allowlist_calldata.append(expires_at.into()); - allowlist_calldata.append(policy.max_calls.into()); - allowlist_calldata.append(policy.calls_used.into()); - allowlist_calldata.append(policy.max_value_per_call.low.into()); - allowlist_calldata.append(policy.max_value_per_call.high.into()); - allowlist_calldata.append(1.into()); - allowlist_calldata.append(mock_address.into()); - allowlist_calldata.append(1.into()); - allowlist_calldata.append(TRANSFER_SELECTOR); + let mut targets: Array = array![]; + targets.append(mock_address); + let mut selectors: Array = array![]; + selectors.append(TRANSFER_SELECTOR); + + let session = Session { + pubkey: session_pubkey, + valid_after, + valid_until, + max_calls: policy.max_calls, + value_cap: policy.max_value_per_call, + targets_len: 1_u32, + targets, + selectors_len: 1_u32, + selectors, + }; + + let mut allowlist_calldata = array![]; + Serde::::serialize(@session, ref allowlist_calldata); call_contract_syscall( account_address, starknet::selector!("add_session_with_allowlists"), allowlist_calldata.span(), ) - .unwrap_syscall(); + .unwrap_syscall(); stop_cheat_caller_address(account_address); let session_dispatcher = ISessionManagerDispatcher { contract_address: account_address }; let stored_policy = session_dispatcher.get_session(key_hash); assert(stored_policy.is_active == true, 'session inactive'); + assert(stored_policy.owner_epoch == 0_u64, 'unexpected session epoch'); let amount = u256 { low: 500_u128, high: 0_u128 }; let to: ContractAddress = account_address; @@ -101,9 +121,146 @@ fn session_allows_whitelisted_calls() { let zero_contract: ContractAddress = 0.try_into().unwrap(); start_cheat_caller_address(account_address, zero_contract); - let signature: Array = - build_session_signature(account_address, session_pubkey, 0_u128, @calls); + let signature: Array = build_session_signature( + account_address, session_pubkey, 0_u128, policy.valid_until, @calls, + ); + start_cheat_signature(account_address, signature.span()); + let mut execute_calldata = array![]; + Serde::>::serialize(@calls, ref execute_calldata); + call_contract_syscall( + account_address, starknet::selector!("__execute__"), execute_calldata.span(), + ) + .unwrap_syscall(); + + stop_cheat_signature(account_address); + stop_cheat_caller_address(account_address); + stop_cheat_block_timestamp(account_address); + + spy + .assert_emitted( + @array![ + (account_address, Event::SessionUsed(SessionUsed { key_hash, used: 1_u32 })), + ( + account_address, + Event::SessionNonceAdvanced( + SessionNonceAdvanced { key_hash, new_nonce: 1_u128 }, + ), + ), + ], + ); + + let get_last_result = call_contract_syscall( + mock_address, starknet::selector!("get_last"), array![].span(), + ) + .unwrap_syscall(); + + let recorded_to_felt = *get_last_result.at(0); + let recorded_low_felt = *get_last_result.at(1); + let recorded_high_felt = *get_last_result.at(2); + + let recorded_to: ContractAddress = match recorded_to_felt.try_into() { + Option::Some(addr) => addr, + Option::None(_) => { + assert(false, 'invalid recorded address'); + 0.try_into().unwrap() + }, + }; + + let recorded_amount = u256 { + low: match recorded_low_felt.try_into() { + Option::Some(value) => value, + Option::None(_) => { + assert(false, 'invalid amount low'); + 0.try_into().unwrap() + }, + }, + high: match recorded_high_felt.try_into() { + Option::Some(value) => value, + Option::None(_) => { + assert(false, 'invalid amount high'); + 0.try_into().unwrap() + }, + }, + }; + + assert(recorded_to == to, 'incorrect transfer recipient'); + assert(recorded_amount == amount, 'incorrect transfer amount'); +} + +#[test] +fn session_allows_transfer_from_calls() { + let account_declare = declare("UA2Account").unwrap(); + let account_class = account_declare.contract_class(); + let (account_address, _) = account_class.deploy(@array![OWNER_PUBKEY]).unwrap_syscall(); + + let mock_declare = declare("MockERC20").unwrap(); + let mock_class = mock_declare.contract_class(); + let (mock_address, _) = mock_class.deploy(@array![]).unwrap_syscall(); + + let valid_after = 0_u64; + let valid_until = 10_000_u64; + + let policy = SessionPolicy { + is_active: true, + valid_after, + valid_until, + max_calls: 1_u32, + calls_used: 0_u32, + max_value_per_call: u256 { low: 1_000_u128, high: 0_u128 }, + owner_epoch: 0_u64, + }; + + start_cheat_block_timestamp(account_address, 5_000_u64); + + start_cheat_caller_address(account_address, account_address); + let session_pubkey = session_key(); + + let mut targets: Array = array![]; + targets.append(mock_address); + let mut selectors: Array = array![]; + selectors.append(TRANSFER_FROM_SELECTOR); + + let session = Session { + pubkey: session_pubkey, + valid_after, + valid_until, + max_calls: policy.max_calls, + value_cap: policy.max_value_per_call, + targets_len: 1_u32, + targets, + selectors_len: 1_u32, + selectors, + }; + + let mut allowlist_calldata = array![]; + Serde::::serialize(@session, ref allowlist_calldata); + + call_contract_syscall( + account_address, + starknet::selector!("add_session_with_allowlists"), + allowlist_calldata.span(), + ) + .unwrap_syscall(); + stop_cheat_caller_address(account_address); + + let amount = u256 { low: 600_u128, high: 0_u128 }; + let from: ContractAddress = account_address; + let to: ContractAddress = 0x555.try_into().unwrap(); + + let call = build_transfer_from_call(mock_address, from, to, amount); + let calls = array![call]; + + let zero_contract: ContractAddress = 0.try_into().unwrap(); + start_cheat_caller_address(account_address, zero_contract); + let signature: Array = build_session_signature( + account_address, + session_pubkey, + 0_u128, + policy.valid_until, + @calls, + ); start_cheat_signature(account_address, signature.span()); + let mut execute_calldata = array![]; Serde::>::serialize(@calls, ref execute_calldata); call_contract_syscall( @@ -111,29 +268,17 @@ fn session_allows_whitelisted_calls() { starknet::selector!("__execute__"), execute_calldata.span(), ) - .unwrap_syscall(); + .unwrap_syscall(); stop_cheat_signature(account_address); stop_cheat_caller_address(account_address); - stop_cheat_block_timestamp(account_address); - - spy.assert_emitted(@array![ - ( - account_address, - Event::SessionUsed(SessionUsed { key_hash, used: 1_u32 }), - ), - ( - account_address, - Event::SessionNonceAdvanced(SessionNonceAdvanced { key_hash, new_nonce: 1_u128 }), - ), - ]); let get_last_result = call_contract_syscall( mock_address, starknet::selector!("get_last"), array![].span(), ) - .unwrap_syscall(); + .unwrap_syscall(); let recorded_to_felt = *get_last_result.at(0); let recorded_low_felt = *get_last_result.at(1); @@ -166,4 +311,6 @@ fn session_allows_whitelisted_calls() { assert(recorded_to == to, 'incorrect transfer recipient'); assert(recorded_amount == amount, 'incorrect transfer amount'); + + stop_cheat_block_timestamp(account_address); } diff --git a/packages/contracts/tests/test_validate_auth.cairo b/packages/contracts/tests/test_validate_auth.cairo index 7aabdce..837e9b0 100644 --- a/packages/contracts/tests/test_validate_auth.cairo +++ b/packages/contracts/tests/test_validate_auth.cairo @@ -25,13 +25,12 @@ use snforge_std::signature::stark_curve::{ use starknet::account::Call; use starknet::syscalls::call_contract_syscall; use starknet::{ContractAddress, SyscallResultTrait}; +use ua2_contracts::errors::{ERR_GUARDIAN_CALL_DENIED, ERR_NO_RECOVERY}; use ua2_contracts::ua2_account::UA2Account::{self, RecoveryDelaySet}; const MODE_OWNER: felt252 = 0; const MODE_GUARDIAN: felt252 = 2; const OWNER_PRIVATE_KEY: felt252 = 0x123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde; -const ERR_NO_RECOVERY: felt252 = 'ERR_NO_RECOVERY'; -const ERR_GUARDIAN_CALL_DENIED: felt252 = 'ERR_GUARDIAN_CALL_DENIED'; fn owner_keypair() -> KeyPair { StarkCurveKeyPairImpl::from_secret_key(OWNER_PRIVATE_KEY) diff --git a/packages/contracts/tests/test_validate_denials.cairo b/packages/contracts/tests/test_validate_denials.cairo index 86540f5..1b6adc2 100644 --- a/packages/contracts/tests/test_validate_denials.cairo +++ b/packages/contracts/tests/test_validate_denials.cairo @@ -1,34 +1,28 @@ use core::array::{Array, ArrayTrait, SpanTrait}; use core::integer::u256; +use core::option::Option; use core::result::Result; use core::serde::Serde; use core::traits::{Into, TryInto}; - use snforge_std::{ - declare, - start_cheat_block_timestamp, - stop_cheat_block_timestamp, - start_cheat_caller_address, - stop_cheat_caller_address, - start_cheat_signature, - stop_cheat_signature, - ContractClassTrait, - DeclareResultTrait, + ContractClassTrait, DeclareResultTrait, declare, start_cheat_block_timestamp, + start_cheat_caller_address, start_cheat_signature, stop_cheat_block_timestamp, + stop_cheat_caller_address, stop_cheat_signature, }; use starknet::account::Call; use starknet::syscalls::call_contract_syscall; use starknet::{ContractAddress, SyscallResult, SyscallResultTrait}; +use ua2_contracts::errors::{ + ERR_POLICY_CALLCAP, ERR_POLICY_SELECTOR_DENIED, ERR_POLICY_TARGET_DENIED, ERR_SESSION_EXPIRED, + ERR_SESSION_NOT_READY, ERR_SESSION_SELECTORS_LEN, ERR_SESSION_TARGETS_LEN, ERR_VALUE_LIMIT_EXCEEDED, +}; +use ua2_contracts::session::Session; use ua2_contracts::ua2_account::UA2Account::SessionPolicy; - use crate::session_test_utils::{build_session_signature, session_key}; const OWNER_PUBKEY: felt252 = 0x12345; const TRANSFER_SELECTOR: felt252 = starknet::selector!("transfer"); -const ERR_POLICY_SELECTOR_DENIED: felt252 = 'ERR_POLICY_SELECTOR_DENIED'; -const ERR_POLICY_TARGET_DENIED: felt252 = 'ERR_POLICY_TARGET_DENIED'; -const ERR_SESSION_EXPIRED: felt252 = 'ERR_SESSION_EXPIRED'; -const ERR_POLICY_CALLCAP: felt252 = 'ERR_POLICY_CALLCAP'; -const ERR_VALUE_LIMIT_EXCEEDED: felt252 = 'ERR_VALUE_LIMIT_EXCEEDED'; +const TRANSFER_FROM_SELECTOR: felt252 = starknet::selector!("transferFrom"); fn deploy_account_and_mock() -> (ContractAddress, ContractAddress) { let account_declare = declare("UA2Account").unwrap(); @@ -51,55 +45,78 @@ fn add_session_with_lists( ) { start_cheat_caller_address(account_address, account_address); - let mut calldata = array![]; - calldata.append(key); - let active_flag: felt252 = if policy.is_active { 1 } else { 0 }; - calldata.append(active_flag); - calldata.append(policy.expires_at.into()); - calldata.append(policy.max_calls.into()); - calldata.append(policy.calls_used.into()); - calldata.append(policy.max_value_per_call.low.into()); - calldata.append(policy.max_value_per_call.high.into()); - - let targets_len = ArrayTrait::::len(targets); - calldata.append(targets_len.into()); - let mut i = 0_usize; - while i < targets_len { - let target = *ArrayTrait::::at(targets, i); - calldata.append(target.into()); - i += 1_usize; + let mut owned_targets: Array = array![]; + for target_ref in targets.span() { + owned_targets.append(*target_ref); } - - let selectors_len = ArrayTrait::::len(selectors); - calldata.append(selectors_len.into()); - i = 0_usize; - while i < selectors_len { - let selector = *ArrayTrait::::at(selectors, i); - calldata.append(selector); - i += 1_usize; + let mut owned_selectors: Array = array![]; + for selector_ref in selectors.span() { + owned_selectors.append(*selector_ref); } + let targets_len_usize = ArrayTrait::::len(@owned_targets); + let selectors_len_usize = ArrayTrait::::len(@owned_selectors); + + let targets_len: u32 = match targets_len_usize.try_into() { + Option::Some(value) => value, + Option::None(_) => { + assert(false, 'targets too long'); + 0_u32 + }, + }; + let selectors_len: u32 = match selectors_len_usize.try_into() { + Option::Some(value) => value, + Option::None(_) => { + assert(false, 'selectors too long'); + 0_u32 + }, + }; + + let session = Session { + pubkey: key, + valid_after: policy.valid_after, + valid_until: policy.valid_until, + max_calls: policy.max_calls, + value_cap: policy.max_value_per_call, + targets_len, + targets: owned_targets, + selectors_len, + selectors: owned_selectors, + }; + + let mut calldata = array![]; + Serde::::serialize(@session, ref calldata); + call_contract_syscall( - account_address, - starknet::selector!("add_session_with_allowlists"), - calldata.span(), + account_address, starknet::selector!("add_session_with_allowlists"), calldata.span(), ) - .unwrap_syscall(); + .unwrap_syscall(); stop_cheat_caller_address(account_address); } -fn build_transfer_call( +fn build_transfer_call(mock_address: ContractAddress, to: ContractAddress, amount: u256) -> Call { + let mut calldata = array![]; + calldata.append(to.into()); + calldata.append(amount.low.into()); + calldata.append(amount.high.into()); + + Call { to: mock_address, selector: TRANSFER_SELECTOR, calldata: calldata.span() } +} + +fn build_transfer_from_call( mock_address: ContractAddress, + from: ContractAddress, to: ContractAddress, amount: u256, ) -> Call { let mut calldata = array![]; + calldata.append(from.into()); calldata.append(to.into()); calldata.append(amount.low.into()); calldata.append(amount.high.into()); - Call { to: mock_address, selector: TRANSFER_SELECTOR, calldata: calldata.span() } + Call { to: mock_address, selector: TRANSFER_FROM_SELECTOR, calldata: calldata.span() } } fn execute_session_calls( @@ -107,20 +124,20 @@ fn execute_session_calls( calls: @Array, nonce: u128, session_pubkey: felt252, + valid_until: u64, ) -> SyscallResult> { let zero_contract: ContractAddress = 0.try_into().unwrap(); start_cheat_caller_address(account_address, zero_contract); - let signature: Array = - build_session_signature(account_address, session_pubkey, nonce, calls); + let signature: Array = build_session_signature( + account_address, session_pubkey, nonce, valid_until, calls, + ); start_cheat_signature(account_address, signature.span()); let mut execute_calldata = array![]; Serde::>::serialize(calls, ref execute_calldata); let result = call_contract_syscall( - account_address, - starknet::selector!("__execute__"), - execute_calldata.span(), + account_address, starknet::selector!("__execute__"), execute_calldata.span(), ); stop_cheat_signature(account_address); @@ -131,9 +148,7 @@ fn execute_session_calls( fn assert_reverted_with(result: SyscallResult>, expected: felt252) { match result { - Result::Ok(_) => { - assert(false, 'expected revert'); - }, + Result::Ok(_) => { assert(false, 'expected revert'); }, Result::Err(panic_data) => { let panic_span = panic_data.span(); assert(panic_span.len() > 0_usize, 'missing panic data'); @@ -143,16 +158,98 @@ fn assert_reverted_with(result: SyscallResult>, expected: felt252) } } +#[test] +fn rejects_length_mismatch() { + let (account_address, _) = deploy_account_and_mock(); + + let session_pubkey = session_key(); + + start_cheat_caller_address(account_address, account_address); + + let mut empty_targets = ArrayTrait::::new(); + let mut selectors_one = ArrayTrait::::new(); + selectors_one.append(TRANSFER_SELECTOR); + + let session_targets_mismatch = Session { + pubkey: session_pubkey, + valid_after: 0_u64, + valid_until: 10_000_u64, + max_calls: 1_u32, + value_cap: u256 { low: 1_000_u128, high: 0_u128 }, + targets_len: 1_u32, + targets: empty_targets, + selectors_len: 1_u32, + selectors: selectors_one, + }; + + let mut calldata = array![]; + Serde::::serialize(@session_targets_mismatch, ref calldata); + + let result = call_contract_syscall( + account_address, starknet::selector!("add_session_with_allowlists"), calldata.span(), + ); + + match result { + Result::Ok(_) => { assert(false, 'expected targets len mismatch'); }, + Result::Err(panic_data) => { + let data = panic_data.span(); + assert(data.len() > 0_usize, 'missing panic data'); + let reason = *data.at(0_usize); + assert(reason == ERR_SESSION_TARGETS_LEN, 'unexpected targets len error'); + }, + } + + let mut targets_one = ArrayTrait::::new(); + targets_one.append(account_address); + let mut selectors_mismatch = ArrayTrait::::new(); + selectors_mismatch.append(TRANSFER_SELECTOR); + + let session_selectors_mismatch = Session { + pubkey: session_pubkey, + valid_after: 0_u64, + valid_until: 10_000_u64, + max_calls: 1_u32, + value_cap: u256 { low: 1_000_u128, high: 0_u128 }, + targets_len: 1_u32, + targets: targets_one, + selectors_len: 2_u32, + selectors: selectors_mismatch, + }; + + let mut calldata_selectors = array![]; + Serde::::serialize(@session_selectors_mismatch, ref calldata_selectors); + + let selectors_result = call_contract_syscall( + account_address, + starknet::selector!("add_session_with_allowlists"), + calldata_selectors.span(), + ); + + match selectors_result { + Result::Ok(_) => { assert(false, 'expected selectors len mismatch'); }, + Result::Err(panic_data) => { + let data = panic_data.span(); + assert(data.len() > 0_usize, 'missing panic data'); + let reason = *data.at(0_usize); + assert(reason == ERR_SESSION_SELECTORS_LEN, 'unexpected selectors len error'); + }, + } + + stop_cheat_caller_address(account_address); +} + #[test] fn denies_selector_not_allowed() { let (account_address, mock_address) = deploy_account_and_mock(); let policy = SessionPolicy { is_active: true, - expires_at: 10_000_u64, + valid_after: 0_u64, + valid_until: 10_000_u64, max_calls: 5_u32, calls_used: 0_u32, max_value_per_call: u256 { low: 10_000_u128, high: 0_u128 }, + owner_epoch: 0_u64, }; let mut targets = array![mock_address]; @@ -168,7 +265,9 @@ fn denies_selector_not_allowed() { let call = build_transfer_call(mock_address, to, amount); let calls = array![call]; - let result = execute_session_calls(account_address, @calls, 0_u128, session_pubkey); + let result = execute_session_calls( + account_address, @calls, 0_u128, session_pubkey, policy.valid_until, + ); assert_reverted_with(result, ERR_POLICY_SELECTOR_DENIED); @@ -181,10 +280,12 @@ fn denies_target_not_allowed() { let policy = SessionPolicy { is_active: true, - expires_at: 10_000_u64, + valid_after: 0_u64, + valid_until: 10_000_u64, max_calls: 5_u32, calls_used: 0_u32, max_value_per_call: u256 { low: 10_000_u128, high: 0_u128 }, + owner_epoch: 0_u64, }; let targets = array![]; @@ -200,7 +301,49 @@ fn denies_target_not_allowed() { let call = build_transfer_call(mock_address, to, amount); let calls = array![call]; - let result = execute_session_calls(account_address, @calls, 0_u128, session_pubkey); + let result = execute_session_calls( + account_address, @calls, 0_u128, session_pubkey, policy.valid_until, + ); + + assert_reverted_with(result, ERR_POLICY_TARGET_DENIED); + + stop_cheat_block_timestamp(account_address); +} + +#[test] +fn empty_allowlists_reject_calls() { + let (account_address, mock_address) = deploy_account_and_mock(); + + let policy = SessionPolicy { + is_active: true, + valid_after: 0_u64, + valid_until: 10_000_u64, + max_calls: 5_u32, + calls_used: 0_u32, + max_value_per_call: u256 { low: 10_000_u128, high: 0_u128 }, + owner_epoch: 0_u64, + }; + + let targets = array![]; + let selectors = array![]; + + let session_pubkey = session_key(); + add_session_with_lists(account_address, session_pubkey, policy, @targets, @selectors); + + start_cheat_block_timestamp(account_address, 5_000_u64); + + let to: ContractAddress = account_address; + let amount = u256 { low: 1_000_u128, high: 0_u128 }; + let call = build_transfer_call(mock_address, to, amount); + let calls = array![call]; + + let result = execute_session_calls( + account_address, + @calls, + 0_u128, + session_pubkey, + policy.valid_until, + ); assert_reverted_with(result, ERR_POLICY_TARGET_DENIED); @@ -213,10 +356,12 @@ fn denies_expired_session() { let policy = SessionPolicy { is_active: true, - expires_at: 6_000_u64, + valid_after: 0_u64, + valid_until: 6_000_u64, max_calls: 5_u32, calls_used: 0_u32, max_value_per_call: u256 { low: 10_000_u128, high: 0_u128 }, + owner_epoch: 0_u64, }; let mut targets = array![mock_address]; @@ -232,23 +377,67 @@ fn denies_expired_session() { let call = build_transfer_call(mock_address, to, amount); let calls = array![call]; - let result = execute_session_calls(account_address, @calls, 0_u128, session_pubkey); + let result = execute_session_calls( + account_address, @calls, 0_u128, session_pubkey, policy.valid_until, + ); assert_reverted_with(result, ERR_SESSION_EXPIRED); stop_cheat_block_timestamp(account_address); } +#[test] +fn denies_session_not_ready() { + let (account_address, mock_address) = deploy_account_and_mock(); + + let policy = SessionPolicy { + is_active: true, + valid_after: 6_000_u64, + valid_until: 12_000_u64, + max_calls: 5_u32, + calls_used: 0_u32, + max_value_per_call: u256 { low: 10_000_u128, high: 0_u128 }, + owner_epoch: 0_u64, + }; + + let mut targets = array![mock_address]; + let mut selectors = array![TRANSFER_SELECTOR]; + + let session_pubkey = session_key(); + add_session_with_lists(account_address, session_pubkey, policy, @targets, @selectors); + + start_cheat_block_timestamp(account_address, 5_000_u64); + + let to: ContractAddress = account_address; + let amount = u256 { low: 1_000_u128, high: 0_u128 }; + let call = build_transfer_call(mock_address, to, amount); + let calls = array![call]; + + let result = execute_session_calls( + account_address, + @calls, + 0_u128, + session_pubkey, + policy.valid_until, + ); + + assert_reverted_with(result, ERR_SESSION_NOT_READY); + + stop_cheat_block_timestamp(account_address); +} + #[test] fn denies_over_call_cap() { let (account_address, mock_address) = deploy_account_and_mock(); let policy = SessionPolicy { is_active: true, - expires_at: 10_000_u64, + valid_after: 0_u64, + valid_until: 10_000_u64, max_calls: 1_u32, calls_used: 0_u32, max_value_per_call: u256 { low: 10_000_u128, high: 0_u128 }, + owner_epoch: 0_u64, }; let mut targets = array![mock_address]; @@ -265,7 +454,9 @@ fn denies_over_call_cap() { let call_two = build_transfer_call(mock_address, to, amount); let calls = array![call_one, call_two]; - let result = execute_session_calls(account_address, @calls, 0_u128, session_pubkey); + let result = execute_session_calls( + account_address, @calls, 0_u128, session_pubkey, policy.valid_until, + ); assert_reverted_with(result, ERR_POLICY_CALLCAP); @@ -278,10 +469,12 @@ fn denies_over_value_cap() { let policy = SessionPolicy { is_active: true, - expires_at: 10_000_u64, + valid_after: 0_u64, + valid_until: 10_000_u64, max_calls: 5_u32, calls_used: 0_u32, max_value_per_call: u256 { low: 1_000_u128, high: 0_u128 }, + owner_epoch: 0_u64, }; let mut targets = array![mock_address]; @@ -297,7 +490,50 @@ fn denies_over_value_cap() { let call = build_transfer_call(mock_address, to, amount); let calls = array![call]; - let result = execute_session_calls(account_address, @calls, 0_u128, session_pubkey); + let result = execute_session_calls( + account_address, @calls, 0_u128, session_pubkey, policy.valid_until, + ); + + assert_reverted_with(result, ERR_VALUE_LIMIT_EXCEEDED); + + stop_cheat_block_timestamp(account_address); +} + +#[test] +fn denies_transfer_from_over_value_cap() { + let (account_address, mock_address) = deploy_account_and_mock(); + + let policy = SessionPolicy { + is_active: true, + valid_after: 0_u64, + valid_until: 10_000_u64, + max_calls: 5_u32, + calls_used: 0_u32, + max_value_per_call: u256 { low: 1_000_u128, high: 0_u128 }, + owner_epoch: 0_u64, + }; + + let mut targets = array![mock_address]; + let mut selectors = array![TRANSFER_FROM_SELECTOR]; + + let session_pubkey = session_key(); + add_session_with_lists(account_address, session_pubkey, policy, @targets, @selectors); + + start_cheat_block_timestamp(account_address, 5_000_u64); + + let from: ContractAddress = account_address; + let to: ContractAddress = account_address; + let amount = u256 { low: 5_000_u128, high: 0_u128 }; + let call = build_transfer_from_call(mock_address, from, to, amount); + let calls = array![call]; + + let result = execute_session_calls( + account_address, + @calls, + 0_u128, + session_pubkey, + policy.valid_until, + ); assert_reverted_with(result, ERR_VALUE_LIMIT_EXCEEDED); diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 6888a51..1c23d85 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -22,7 +22,12 @@ export type { Uint256, Session, SessionPolicyInput, + SessionPolicyResolved, + SessionPolicyStruct, + SessionPolicyCalldata, SessionsManager, + SessionUsage, + SessionUseOptions, AccountCall, AccountTransaction, CallTransport, diff --git a/packages/core/src/sessions.ts b/packages/core/src/sessions.ts index 2b80b6a..866b348 100644 --- a/packages/core/src/sessions.ts +++ b/packages/core/src/sessions.ts @@ -6,9 +6,13 @@ import type { Session, SessionPolicyCalldata, SessionPolicyInput, + SessionPolicyResolved, + SessionPolicyStruct, + SessionLimits, + SessionUsage, + SessionUseOptions, SessionsManager, UA2AccountLike, - Uint256, } from './types'; import { toUint256 } from './utils/u256'; import { toFelt } from './utils/felt'; @@ -42,14 +46,14 @@ class SessionsImpl implements SessionsManager { } async create(policy: SessionPolicyInput): Promise { - const active = policy.active ?? true; + const resolved = resolvePolicy(policy); const pubkey = genFeltKey(); - const keyHash = pubkey; // v0.1: use felt pubkey directly; can hash later + const sessionId = pubkey; // v0.1: use felt pubkey directly; can hash later const createdAt = Date.now(); // Build calldata for Cairo's SessionPolicy struct and allowlists. - const { policyCalldata, allowCalldata } = buildPolicyCalldata(policy, active); - const calldata = buildAddSessionCalldata(pubkey, keyHash, policyCalldata, allowCalldata); + const { policyCalldata, allowCalldata } = buildPolicyCalldata(resolved); + const calldata = buildAddSessionCalldata(sessionId, pubkey, policyCalldata, allowCalldata); // If we have a transport + ua2 address, we could call add_session_with_allowlists here. // Keeping it local-only for now (no RPC in tests). @@ -58,9 +62,9 @@ class SessionsImpl implements SessionsManager { } const sess: Session = { - id: keyHash, + id: sessionId, pubkey, - policy, + policy: resolved, createdAt, }; this.sessions.push(sess); @@ -79,59 +83,96 @@ class SessionsImpl implements SessionsManager { // Return a shallow copy for immutability. return [...this.sessions]; } + + async use(sessionId: Felt, opts?: SessionUseOptions): Promise { + return useSession(this, sessionId, opts); + } } /* ------------------ Policy / Calldata helpers ------------------ */ -function buildPolicyCalldata(inp: SessionPolicyInput, active: boolean): { +function resolvePolicy(inp: SessionPolicyInput): SessionPolicyResolved { + const validAfter = Math.max(0, Math.floor(inp.validAfter)); + const rawValidUntil = Math.max(0, Math.floor(inp.validUntil)); + const validUntil = rawValidUntil <= validAfter ? validAfter + 1 : rawValidUntil; + const active = inp.active ?? true; + const callsUsed = Math.max(0, Math.floor(inp.callsUsed ?? 0)); + const maxCalls = Math.max(0, Math.floor(inp.limits.maxCalls)); + const [maxLow, maxHigh] = inp.limits.maxValuePerCall; + + return { + ...inp, + validAfter, + validUntil, + limits: { + maxCalls, + maxValuePerCall: [toFelt(maxLow), toFelt(maxHigh)], + }, + allow: { + targets: [...(inp.allow.targets ?? [])], + selectors: [...(inp.allow.selectors ?? [])], + }, + active, + callsUsed, + }; +} + +function buildPolicyCalldata(policy: SessionPolicyResolved): { + policyStruct: SessionPolicyStruct; policyCalldata: SessionPolicyCalldata; allowCalldata: { targets: Felt[]; selectors: Felt[]; }; } { - const expires_at = toFelt(inp.expiresAt >>> 0); // as u64 -> felt - const max_calls = toFelt(inp.limits.maxCalls >>> 0); - const calls_used = toFelt(0); - const [low, high] = inp.limits.maxValuePerCall; - const is_active = toFelt(active ? 1 : 0); + const policyStruct: SessionPolicyStruct = { + is_active: policy.active, + valid_after: policy.validAfter, + valid_until: policy.validUntil, + max_calls: policy.limits.maxCalls, + calls_used: policy.callsUsed, + max_value_per_call: policy.limits.maxValuePerCall, + }; + + const [low, high] = policyStruct.max_value_per_call; const policyCalldata: SessionPolicyCalldata = { - is_active, - expires_at, - max_calls, - calls_used, - max_value_per_call_low: low, - max_value_per_call_high: high, + is_active: toFelt(policyStruct.is_active ? 1 : 0), + valid_after: toFelt(BigInt(policyStruct.valid_after)), + valid_until: toFelt(BigInt(policyStruct.valid_until)), + max_calls: toFelt(policyStruct.max_calls >>> 0), + calls_used: toFelt(policyStruct.calls_used >>> 0), + max_value_per_call_low: toFelt(low), + max_value_per_call_high: toFelt(high), }; - const targets = (inp.allow.targets ?? []).map(toFelt); - const selectors = (inp.allow.selectors ?? []).map(toFelt); + const targets = (policy.allow.targets ?? []).map(toFelt); + const selectors = (policy.allow.selectors ?? []).map(toFelt); - return { policyCalldata, allowCalldata: { targets, selectors } }; + return { policyStruct, policyCalldata, allowCalldata: { targets, selectors } }; } function buildAddSessionCalldata( + sessionId: Felt, pubkey: Felt, - keyHash: Felt, policy: SessionPolicyCalldata, allow: { targets: Felt[]; selectors: Felt[] } ): Felt[] { const policyArray: Felt[] = [ policy.is_active, - policy.expires_at, + policy.valid_after, + policy.valid_until, policy.max_calls, policy.calls_used, policy.max_value_per_call_low, - policy.max_value_per_call_high, ]; const targetsLen = toFelt(allow.targets.length); const selectorsLen = toFelt(allow.selectors.length); return [ + sessionId, pubkey, - keyHash, ...policyArray, targetsLen, ...allow.targets, @@ -141,26 +182,12 @@ function buildAddSessionCalldata( } /** Convenience builder for users: encode numeric string amount as Uint256. */ -export function limits(maxCalls: number, maxValue: string | number | bigint): { - maxCalls: number; - maxValuePerCall: Uint256; -} { +export function limits(maxCalls: number, maxValue: string | number | bigint): SessionLimits { return { maxCalls, maxValuePerCall: toUint256(maxValue) }; } /* ------------------ Session helpers ------------------ */ -export interface SessionUseOptions { - /** Override "now" in milliseconds (defaults to Date.now()). */ - now?: number; -} - -export interface SessionUsage { - session: Session; - /** Ensure the provided calls comply with the session policy. */ - ensureAllowed(calls: AccountCall[] | AccountCall): void; -} - export async function useSession( manager: SessionsManager, sessionId: Felt, @@ -190,16 +217,26 @@ function ensureSessionActive(session: Session, nowMs?: number) { } const nowSeconds = Math.floor((nowMs ?? Date.now()) / 1000); - if (session.policy.expiresAt <= nowSeconds) { - throw new SessionExpiredError(`Session ${session.id} expired at ${session.policy.expiresAt}.`); + if (nowSeconds < session.policy.validAfter) { + throw new SessionExpiredError( + `Session ${session.id} not active until ${session.policy.validAfter}.` + ); + } + if (session.policy.validUntil <= nowSeconds) { + throw new SessionExpiredError( + `Session ${session.id} expired at ${session.policy.validUntil}.` + ); } } function ensurePolicy(session: Session, calls: AccountCall[]) { const { allow, limits } = session.policy; - if (calls.length > limits.maxCalls) { - throw new PolicyViolationError('calls', `${calls.length} > ${limits.maxCalls}`); + if (session.policy.callsUsed + calls.length > limits.maxCalls) { + throw new PolicyViolationError( + 'calls', + `${session.policy.callsUsed + calls.length} > ${limits.maxCalls}` + ); } const allowedTargets = new Set((allow.targets ?? []).map((t) => toFelt(t))); @@ -219,7 +256,9 @@ function ensurePolicy(session: Session, calls: AccountCall[]) { /* ------------------ Guard builder ------------------ */ export interface GuardBuilderInit { - expiresAt?: number; + validAfter?: number; + validUntil?: number; + expiresAt?: number; // legacy alias for validUntil expiresInSeconds?: number; maxCalls?: number; maxValue?: string | number | bigint; @@ -229,6 +268,8 @@ export interface GuardBuilderInit { } export interface GuardBuilder { + validAfter(timestamp: number): GuardBuilder; + validUntil(timestamp: number): GuardBuilder; target(addr: Felt): GuardBuilder; targets(addresses: Iterable): GuardBuilder; selector(sel: Felt): GuardBuilder; @@ -242,7 +283,8 @@ export interface GuardBuilder { } export function guard(init: GuardBuilderInit = {}): GuardBuilder { - let expiresAt = resolveExpiry(init); + let validAfter = Math.max(0, Math.floor(init.validAfter ?? 0)); + let validUntil = resolveValidUntil(init, validAfter); let maxCallsCount = init.maxCalls ?? 1; let maxValueInput: string | number | bigint = init.maxValue ?? 0; let isActive = init.active ?? true; @@ -250,6 +292,15 @@ export function guard(init: GuardBuilderInit = {}): GuardBuilder { const selectors = new Set((init.selectors ?? []).map((s) => toFelt(s))); const builder: GuardBuilder = { + validAfter(timestamp: number) { + validAfter = Math.max(0, Math.floor(timestamp)); + validUntil = Math.max(validUntil, validAfter + 1); + return builder; + }, + validUntil(timestamp: number) { + validUntil = normalizeValidUntil(timestamp, validAfter); + return builder; + }, target(addr: Felt) { targets.add(toFelt(addr)); return builder; @@ -275,12 +326,12 @@ export function guard(init: GuardBuilderInit = {}): GuardBuilder { return builder; }, expiresAt(timestamp: number) { - expiresAt = Math.max(0, Math.floor(timestamp)); + validUntil = normalizeValidUntil(timestamp, validAfter); return builder; }, expiresIn(seconds: number) { const now = Math.floor(Date.now() / 1000); - expiresAt = now + Math.max(0, Math.floor(seconds)); + validUntil = normalizeValidUntil(now + Math.max(0, Math.floor(seconds)), validAfter); return builder; }, active(flag: boolean) { @@ -289,7 +340,8 @@ export function guard(init: GuardBuilderInit = {}): GuardBuilder { }, build(): SessionPolicyInput { return { - expiresAt, + validAfter, + validUntil, limits: { maxCalls: maxCallsCount, maxValuePerCall: toUint256(maxValueInput), @@ -299,6 +351,7 @@ export function guard(init: GuardBuilderInit = {}): GuardBuilder { selectors: Array.from(selectors), }, active: isActive, + callsUsed: 0, }; }, }; @@ -306,16 +359,24 @@ export function guard(init: GuardBuilderInit = {}): GuardBuilder { return builder; } -function resolveExpiry(init: GuardBuilderInit): number { +function resolveValidUntil(init: GuardBuilderInit, validAfter: number): number { + if (typeof init.validUntil === 'number') { + return normalizeValidUntil(init.validUntil, validAfter); + } if (typeof init.expiresAt === 'number') { - return Math.max(0, Math.floor(init.expiresAt)); + return normalizeValidUntil(init.expiresAt, validAfter); } if (typeof init.expiresInSeconds === 'number') { const now = Math.floor(Date.now() / 1000); - return now + Math.max(0, Math.floor(init.expiresInSeconds)); + return normalizeValidUntil(now + Math.max(0, Math.floor(init.expiresInSeconds)), validAfter); } const defaultExpirySeconds = Math.floor(Date.now() / 1000) + 3600; // 1 hour default - return defaultExpirySeconds; + return normalizeValidUntil(defaultExpirySeconds, validAfter); +} + +function normalizeValidUntil(value: number, validAfter: number): number { + const normalized = Math.max(0, Math.floor(value)); + return normalized <= validAfter ? validAfter + 1 : normalized; } export const sessions = { diff --git a/packages/core/src/types.ts b/packages/core/src/types.ts index d955048..97d3b36 100644 --- a/packages/core/src/types.ts +++ b/packages/core/src/types.ts @@ -103,20 +103,25 @@ export interface SessionAllow { } export interface SessionPolicyInput { + /** Earliest timestamp the session can be used (seconds since epoch). */ + validAfter: number; /** Expiration timestamp (seconds since epoch). */ - expiresAt: number; + validUntil: number; /** Limits per session. */ limits: SessionLimits; /** Allowlist constraints. */ allow: SessionAllow; /** Whether the session is active on creation. Default true. */ active?: boolean; + /** Number of calls already consumed by this session (mirrors on-chain `calls_used`). */ + callsUsed?: number; } /** The on-chain policy struct shape (Cairo ordering). */ export interface SessionPolicyCalldata { is_active: Felt; // 0x0 or 0x1 - expires_at: Felt; // u64 -> felt + valid_after: Felt; // u64 -> felt + valid_until: Felt; // u64 -> felt max_calls: Felt; // u32 -> felt calls_used: Felt; // u32 -> felt (init 0) max_value_per_call_low: Felt; @@ -124,6 +129,22 @@ export interface SessionPolicyCalldata { // Arrays come separately as (len, items...) } +/** Cairo struct with native JS types for ergonomics. */ +export interface SessionPolicyStruct { + is_active: boolean; + valid_after: number; + valid_until: number; + max_calls: number; + calls_used: number; + max_value_per_call: Uint256; +} + +/** Session policy resolved with defaults and counters for local mirrors. */ +export interface SessionPolicyResolved extends SessionPolicyInput { + active: boolean; + callsUsed: number; +} + /** Returned by SDK when you create a session. */ export interface Session { /** Internal id = keyHash felt (same as supplied key or its hash). */ @@ -131,7 +152,7 @@ export interface Session { /** Public session key felt (simplified for now). */ pubkey: Felt; /** Policy you requested. */ - policy: SessionPolicyInput; + policy: SessionPolicyResolved; /** Created at (ms). */ createdAt: number; } @@ -144,6 +165,8 @@ export interface SessionsManager { revoke(sessionId: Felt): Promise; /** List locally known sessions. */ list(): Promise; + /** Load and validate a session for client-side policy enforcement. */ + use(sessionId: Felt, opts?: SessionUseOptions): Promise; } /* ------------------ Transport Abstraction (stub) ------------------ */ @@ -196,6 +219,19 @@ export interface PaymasterContext { entrypoint?: string; } +/* ------------------ Session usage helpers ------------------ */ + +export interface SessionUseOptions { + /** Override "now" in milliseconds (defaults to Date.now()). */ + now?: number; +} + +export interface SessionUsage { + session: Session; + /** Ensure the provided calls comply with the session policy. */ + ensureAllowed(calls: AccountCall[] | AccountCall): void; +} + export interface PaymasterRunner { execute(calls: AccountCall[] | AccountCall, maxFee?: Felt): Promise; call(to: Felt, selector: Felt, calldata?: Felt[], maxFee?: Felt): Promise; diff --git a/packages/core/tests/paymasters.spec.ts b/packages/core/tests/paymasters.spec.ts index 4b07d75..085b023 100644 --- a/packages/core/tests/paymasters.spec.ts +++ b/packages/core/tests/paymasters.spec.ts @@ -84,6 +84,25 @@ describe('paymasters', () => { expect(last.data.slice(-3)).toEqual(['0x2', '0xcafe', '0xbeef']); }); + it('marks execution as unsponsored when paymaster adds no metadata', async () => { + const { transport, sent } = mkFakeTransport(); + + const silent: Paymaster = { + name: 'empty', + async sponsor(tx: AccountTransaction): Promise { + return { ...tx }; + }, + }; + + const runner = withPaymaster({ account, ua2Address, transport, paymaster: silent }); + const res = await runner.execute({ to: '0x1', selector: '0x2', calldata: [] }); + + expect(res.sponsored).toBe(false); + expect(res.sponsorName).toBe('empty'); + expect(sent.length).toBe(1); + expect(sent[0].data.slice(-1)[0]).toBe('0x0'); + }); + it('propagates sponsor rejections with documented error', async () => { const { transport, sent } = mkFakeTransport(); const expected = new PaymasterDeniedError('sponsor offline'); diff --git a/packages/core/tests/sessions.spec.ts b/packages/core/tests/sessions.spec.ts index b7a31d0..48e1ad9 100644 --- a/packages/core/tests/sessions.spec.ts +++ b/packages/core/tests/sessions.spec.ts @@ -1,7 +1,12 @@ import { describe, it, expect } from 'vitest'; import { connect } from '../src/connect'; import { limits, guard, useSession, makeSessionsManager } from '../src/sessions'; -import type { ConnectOptions, SessionPolicyInput, AccountCall } from '../src/types'; +import type { + ConnectOptions, + SessionPolicyInput, + AccountCall, + SessionLimits, +} from '../src/types'; import { PolicyViolationError, SessionExpiredError } from '../src/errors'; import { toUint256 } from '../src/utils/u256'; @@ -17,7 +22,8 @@ describe('Sessions API', () => { const client = await connect(baseOpts); const pol: SessionPolicyInput = { - expiresAt: 1_700_000_000, // seconds + validAfter: 0, + validUntil: 1_700_000_000, // seconds limits: limits(10, '10000000000000000'), // 0.01 ETH-ish allow: { targets: ['0xDEAD', '0xBEEF'], @@ -56,7 +62,8 @@ describe('Sessions API', () => { }); const policy: SessionPolicyInput = { - expiresAt: 1_888_888_888, + validAfter: 0, + validUntil: 1_888_888_888, limits: limits(3, '0x10'), allow: { targets: ['0xDEAD', '0xBEEF'], @@ -73,14 +80,14 @@ describe('Sessions API', () => { expect(call.entry).toBe('add_session_with_allowlists'); const data = call.data; - expect(data[0]).toBe(data[1]); + expect(data[0]).toBe(data[1]); // session id mirrors pubkey expect(data.slice(2, 8)).toEqual([ '0x1', + '0x0', '0x70962838', '0x3', '0x0', '0x10', - '0x0', ]); expect(data[8]).toBe('0x2'); expect(data.slice(9, 11)).toEqual(['0xdead', '0xbeef']); @@ -88,11 +95,59 @@ describe('Sessions API', () => { expect(data[12]).toBe('0xcafe'); }); + it('normalizes policy input and serializes calldata as felts', async () => { + const sent: { addr: string; entry: string; data: string[] }[] = []; + const transport = { + async invoke(addr: string, entry: string, data: string[]) { + sent.push({ addr, entry, data: [...data] }); + return { txHash: '0x888' as const }; + }, + }; + + const manager = makeSessionsManager({ + account: { address: '0xACC', chainId: '0xSEPOLIA', label: 'test' }, + transport, + ua2Address: '0xacc0', + }); + + const rawPolicy: SessionPolicyInput = { + validAfter: -5.9, + validUntil: -10, + limits: { maxCalls: -3, maxValuePerCall: ['0x2', '0x0'] as SessionLimits['maxValuePerCall'] }, + allow: { targets: ['0xABCD'], selectors: [] }, + active: false, + callsUsed: -4, + }; + + const session = await manager.create(rawPolicy); + expect(session.policy.validAfter).toBe(0); + expect(session.policy.validUntil).toBe(1); // clamped to > validAfter + expect(session.policy.limits.maxCalls).toBe(0); + expect(session.policy.callsUsed).toBe(0); + expect(session.policy.allow.targets).toEqual(['0xABCD']); + expect(session.policy.allow.selectors).toEqual([]); + + const [sessionId, pubkey, ...rest] = sent[0].data; + expect(sessionId).toBe(pubkey); + expect(rest).toEqual([ + '0x0', // inactive + '0x0', // valid_after + '0x1', // valid_until normalized + '0x0', // max_calls clamped + '0x0', // calls_used reset + '0x2', // max_value_per_call_low normalized + '0x1', // allow.targets length + '0xabcd', + '0x0', // selectors length + ]); + }); + it('revokes a session locally (active=false)', async () => { const client = await connect(baseOpts); const s = await client.sessions.create({ - expiresAt: 1_800_000_000, + validAfter: 0, + validUntil: 1_800_000_000, limits: limits(1, 0), allow: { targets: [], selectors: [] }, active: true @@ -120,10 +175,13 @@ describe('Sessions API', () => { }).build(); const session = await client.sessions.create(policy); - const usage = await useSession(client.sessions, session.id); + const usage = await client.sessions.use(session.id); expect(usage.session.id).toBe(session.id); expect(() => usage.ensureAllowed(allowedCall)).not.toThrow(); + usage.session.policy.callsUsed = usage.session.policy.limits.maxCalls; + expect(() => usage.ensureAllowed(allowedCall)).toThrowError(PolicyViolationError); + usage.session.policy.callsUsed = 0; expect(() => usage.ensureAllowed({ ...allowedCall, to: '0xBEEF' }) ).toThrowError(PolicyViolationError); @@ -132,7 +190,8 @@ describe('Sessions API', () => { await expect(useSession(client.sessions, session.id)).rejects.toBeInstanceOf(SessionExpiredError); const expired = await client.sessions.create({ - expiresAt: Math.floor(Date.now() / 1000) - 1, + validAfter: 0, + validUntil: Math.floor(Date.now() / 1000) - 1, limits: limits(1, 0), allow: { targets: [], selectors: [] }, active: true, @@ -141,6 +200,34 @@ describe('Sessions API', () => { await expect(useSession(client.sessions, expired.id)).rejects.toBeInstanceOf(SessionExpiredError); }); + it('respects provided clock when validating session lifecycle', async () => { + const client = await connect(baseOpts); + const nowSeconds = Math.floor(Date.now() / 1000); + + const session = await client.sessions.create({ + validAfter: nowSeconds + 5, + validUntil: nowSeconds + 10, + limits: limits(2, 0), + allow: { targets: [], selectors: [] }, + active: true, + }); + + await expect( + client.sessions.use(session.id, { now: (session.policy.validAfter - 1) * 1000 }) + ).rejects.toBeInstanceOf(SessionExpiredError); + + const usage = await client.sessions.use(session.id, { now: session.policy.validAfter * 1000 }); + expect(usage.session.id).toBe(session.id); + + await expect( + client.sessions.use(session.id, { now: (session.policy.validUntil + 1) * 1000 }) + ).rejects.toBeInstanceOf(SessionExpiredError); + + await expect(client.sessions.use('0xdeadbeef', { now: Date.now() })).rejects.toBeInstanceOf( + SessionExpiredError + ); + }); + it('guard builder shapes policies with defaults', () => { const policy = guard({ maxValue: '10', expiresInSeconds: 10 }) .target('0x1') @@ -151,6 +238,6 @@ describe('Sessions API', () => { expect(policy.allow.selectors).toContain('0x2'); expect(policy.limits.maxCalls).toBeGreaterThanOrEqual(1); expect(policy.limits.maxValuePerCall[0]).toBeDefined(); - expect(policy.expiresAt).toBeGreaterThan(Math.floor(Date.now() / 1000)); + expect(policy.validUntil).toBeGreaterThan(Math.floor(Date.now() / 1000)); }); }); diff --git a/packages/example/scripts/e2e-devnet.ts b/packages/example/scripts/e2e-devnet.ts index 99a59fa..52aae47 100644 --- a/packages/example/scripts/e2e-devnet.ts +++ b/packages/example/scripts/e2e-devnet.ts @@ -1,3 +1,6 @@ +import fsPromises from 'node:fs/promises'; +import path from 'node:path'; + import { limits, makeSessionsManager, type SessionPolicyInput } from '@ua2/core'; import { Account } from 'starknet'; @@ -5,8 +8,10 @@ import { AccountCallTransport, assertReverted, assertSucceeded, + deriveSessionKeyHash, ensureUa2Deployed, initialSessionUsageState, + logReceipt, optionalEnv, readOwner, selectorFor, @@ -14,31 +19,62 @@ import { toFelt, updateSessionUsage, waitForReceipt, - type Network, + normalizeHex, + PROJECT_ROOT, } from './shared.js'; +type DevnetAddresses = { + ua2Address: string; + classHash: string; + network: string; + updatedAt: string; +}; + +type DeploymentInfo = Awaited>; + +const ADDRESSES_FILE = path.resolve(PROJECT_ROOT, '.ua2-devnet-addresses.json'); +const RECEIPT_TIMEOUT_MS = 60_000; + async function main(): Promise { - const network: Network = 'devnet'; - const toolkit = await setupToolkit(network); + console.log('[ua2] e2e devnet starting'); - const attachedAddress = optionalEnv([ - `UA2_${network.toUpperCase()}_PROXY_ADDR`, + const toolkit = await setupToolkit('devnet'); + + const envAddress = optionalEnv([ + 'UA2_DEVNET_PROXY_ADDR', + 'UA2_DEVNET_ADDR', 'UA2_PROXY_ADDR', + 'UA2_ADDR', ]); - const { address: ua2Address } = await ensureUa2Deployed(toolkit, attachedAddress); - const ownerAccount = new Account(toolkit.provider, ua2Address, toolkit.ownerKey); - const ownerTransport = new AccountCallTransport(ownerAccount); + const cached = envAddress ? undefined : await readCachedAddresses(); - const ownerBefore = await readOwner(toolkit.provider, ua2Address); - if (ownerBefore.toLowerCase() !== toolkit.ownerPubKey.toLowerCase()) { + let deployment = await pickDeployment(toolkit, envAddress, cached?.ua2Address); + const ua2Address = normalizeHex(deployment.address); + const source = envAddress ? 'env' : cached ? 'cache' : 'deploy'; + + if (!envAddress) { + await writeCachedAddresses({ + ua2Address, + classHash: normalizeHex(deployment.classHash), + network: toolkit.network, + updatedAt: new Date().toISOString(), + }); + const relPath = path.relative(PROJECT_ROOT, ADDRESSES_FILE); + console.log(`[ua2] cached addresses written to ${relPath}`); + } + + console.log(`[ua2] using UA² account ${ua2Address} (${source})`); + + const ownerOnChain = await readOwner(toolkit.provider, ua2Address); + if (ownerOnChain.toLowerCase() !== toolkit.ownerPubKey.toLowerCase()) { console.warn( - `[ua2] Warning: UA² owner on-chain (${ownerBefore}) does not match configured owner pubkey (${toolkit.ownerPubKey}).` + `[ua2] warning: on-chain owner ${ownerOnChain} differs from configured owner ${toolkit.ownerPubKey}` ); } - console.log('E2E DEVNET'); - console.log(`- deploy/attach ✓ (${ua2Address})`); + const ownerAccount = new Account(toolkit.provider, ua2Address, toolkit.ownerKey); + const ownerTransport = new AccountCallTransport(ownerAccount); const sessions = makeSessionsManager({ account: { address: ua2Address, chainId: toolkit.chainId }, @@ -46,185 +82,224 @@ async function main(): Promise { ua2Address, }); - const expiresAt = Math.floor(Date.now() / 1000) + 2 * 60 * 60; + const nowSeconds = Math.floor(Date.now() / 1000); + const validAfter = nowSeconds - 30; + const validUntil = validAfter + 2 * 60 * 60; const sessionTargetValue = optionalEnv( - [`UA2_${network.toUpperCase()}_SESSION_TARGET`, 'UA2_SESSION_TARGET', 'UA2_E2E_TARGET_ADDR'], + ['UA2_DEVNET_SESSION_TARGET', 'UA2_SESSION_TARGET', 'UA2_E2E_TARGET_ADDR'], toolkit.guardianAddress ) ?? toolkit.guardianAddress; const sessionTarget = toFelt(sessionTargetValue); const transferSelector = selectorFor('transfer'); const policy: SessionPolicyInput = { - expiresAt, - limits: limits(5, 10n ** 15n), + validAfter, + validUntil, + limits: limits(1, 0), allow: { targets: [sessionTarget], selectors: [transferSelector], }, + active: true, }; + console.log('\n[1] create tight session policy'); const session = await sessions.create(policy); if (!ownerTransport.lastTxHash) { - throw new Error('Session creation did not produce a transaction hash.'); + throw new Error('add_session_with_allowlists did not emit a transaction hash'); } - const createReceipt = await waitForReceipt(toolkit.provider, ownerTransport.lastTxHash, 'session create'); - assertSucceeded(createReceipt, 'session create'); - console.log(`- create session ✓ (${session.id})`); + const sessionReceipt = await waitForReceipt( + toolkit.provider, + ownerTransport.lastTxHash, + 'create session', + RECEIPT_TIMEOUT_MS + ); + assertSucceeded(sessionReceipt, 'create session'); + logReceipt('create session', ownerTransport.lastTxHash, sessionReceipt); + + const sessionKeyHash = deriveSessionKeyHash(session.pubkey); + console.log(`[ua2] session key hash ${sessionKeyHash}`); let usage = initialSessionUsageState(); - usage = await applySessionUsage(toolkit, ownerAccount, ua2Address, session.id, usage, 1, 'session use #1'); - usage = await applySessionUsage(toolkit, ownerAccount, ua2Address, session.id, usage, 1, 'session use #2'); - usage = await applySessionUsage(toolkit, ownerAccount, ua2Address, session.id, usage, 1, 'session use #3'); - console.log('- in-policy x3 ✓'); + console.log('[2] in-policy session call succeeds'); + usage = await expectSessionUsageSuccess( + toolkit, + ownerAccount, + ua2Address, + sessionKeyHash, + usage, + 1, + 'in-policy session call' + ); - const violationTx = await ownerAccount.execute({ - contractAddress: ua2Address, - entrypoint: 'apply_session_usage', - calldata: [ - session.id, - toFelt(usage.callsUsed), - toFelt(3), - toFelt(usage.nonce), - ], - }); - const violationReceipt = await waitForReceipt( - toolkit.provider, - violationTx.transaction_hash, - 'policy violation' + console.log('[3] out-of-policy session call reverts (ERR_POLICY_CALLCAP)'); + await expectSessionUsageRevert( + toolkit, + ownerAccount, + ua2Address, + sessionKeyHash, + usage, + 1, + 'out-of-policy session call', + 'ERR_POLICY_CALLCAP' ); - assertReverted(violationReceipt, 'ERR_POLICY_CALLCAP', 'policy violation'); - console.log('- out-of-policy revert ✓ (ERR_POLICY_CALLCAP)'); + console.log('[4] revoke session'); const revokeTx = await ownerAccount.execute({ contractAddress: ua2Address, entrypoint: 'revoke_session', - calldata: [session.id], + calldata: [sessionKeyHash], }); - const revokeReceipt = await waitForReceipt(toolkit.provider, revokeTx.transaction_hash, 'session revoke'); - assertSucceeded(revokeReceipt, 'session revoke'); - - const postRevokeTx = await ownerAccount.execute({ - contractAddress: ua2Address, - entrypoint: 'apply_session_usage', - calldata: [ - session.id, - toFelt(usage.callsUsed), - toFelt(1), - toFelt(usage.nonce), - ], - }); - const postRevokeReceipt = await waitForReceipt( + const revokeReceipt = await waitForReceipt( toolkit.provider, - postRevokeTx.transaction_hash, - 'post-revoke session usage' + revokeTx.transaction_hash, + 'revoke session', + RECEIPT_TIMEOUT_MS ); - assertReverted(postRevokeReceipt, 'ERR_SESSION_INACTIVE', 'post revoke session use'); - console.log('- revoke + denied ✓'); + assertSucceeded(revokeReceipt, 'revoke session'); + logReceipt('revoke session', revokeTx.transaction_hash, revokeReceipt); - await runGuardianRecovery(toolkit, ua2Address); - console.log('- guardian recovery ✓'); + console.log('[5] session call after revoke reverts (ERR_SESSION_INACTIVE)'); + await expectSessionUsageRevert( + toolkit, + ownerAccount, + ua2Address, + sessionKeyHash, + usage, + 1, + 'post-revoke session call', + 'ERR_SESSION_INACTIVE' + ); - console.log('E2E DEVNET ✓ complete'); + console.log('\nUA² devnet e2e PASS ✅'); } -async function applySessionUsage( +async function pickDeployment( + toolkit: Awaited>, + envAddress?: string, + cachedAddress?: string +): Promise { + if (envAddress) { + return ensureUa2Deployed(toolkit, envAddress); + } + + if (cachedAddress) { + const attached = await ensureUa2Deployed(toolkit, cachedAddress); + if (attached.classHash !== '0x0') { + return attached; + } + console.warn('[ua2] cached UA² address missing on-chain class hash; redeploying'); + } + + console.log('[ua2] deploying UA² account to devnet…'); + return ensureUa2Deployed(toolkit); +} + +async function expectSessionUsageSuccess( toolkit: Awaited>, owner: Account, ua2Address: string, - sessionId: string, + sessionKeyHash: string, state: ReturnType, calls: number, label: string ) { - const tx = await owner.execute({ - contractAddress: ua2Address, - entrypoint: 'apply_session_usage', - calldata: [ - sessionId, - toFelt(state.callsUsed), - toFelt(calls), - toFelt(state.nonce), - ], - }); - const receipt = await waitForReceipt(toolkit.provider, tx.transaction_hash, label); + const { receipt, txHash } = await sendApplySessionUsage( + toolkit, + owner, + ua2Address, + sessionKeyHash, + state, + calls, + label + ); assertSucceeded(receipt, label); + logReceipt(label, txHash, receipt); return updateSessionUsage(state, calls); } -async function runGuardianRecovery(toolkit: Awaited>, ua2Address: string) { - const ownerAccount = new Account(toolkit.provider, ua2Address, toolkit.ownerKey); - - await sendAndAwait(toolkit, ownerAccount, ua2Address, 'add_guardian', [toolkit.guardianAddress], 'add guardian', [ - 'ERR_GUARDIAN_EXISTS', - ]); - await sendAndAwait(toolkit, ownerAccount, ua2Address, 'set_guardian_threshold', [toFelt(1)], 'set guardian threshold'); - await sendAndAwait(toolkit, ownerAccount, ua2Address, 'set_recovery_delay', [toFelt(0)], 'set recovery delay'); - - const guardianPropose = await toolkit.guardian.execute({ - contractAddress: ua2Address, - entrypoint: 'propose_recovery', - calldata: [toolkit.guardianPubKey], - }); - const guardianProposeReceipt = await waitForReceipt( - toolkit.provider, - guardianPropose.transaction_hash, - 'guardian propose recovery' - ); - assertSucceeded(guardianProposeReceipt, 'guardian propose recovery'); - - const executeTx = await toolkit.guardian.execute({ - contractAddress: ua2Address, - entrypoint: 'execute_recovery', - calldata: [], - }); - const executeReceipt = await waitForReceipt(toolkit.provider, executeTx.transaction_hash, 'execute recovery'); - assertSucceeded(executeReceipt, 'execute recovery'); - - const ownerAfter = await readOwner(toolkit.provider, ua2Address); - if (ownerAfter.toLowerCase() !== toolkit.guardianPubKey.toLowerCase()) { - throw new Error(`Recovery did not update owner. expected ${toolkit.guardianPubKey}, got ${ownerAfter}`); - } - - const recoveredOwner = new Account(toolkit.provider, ua2Address, toolkit.guardianKey); - await sendAndAwait( +async function expectSessionUsageRevert( + toolkit: Awaited>, + owner: Account, + ua2Address: string, + sessionKeyHash: string, + state: ReturnType, + calls: number, + label: string, + expectedReason: string +): Promise { + const { receipt, txHash } = await sendApplySessionUsage( toolkit, - recoveredOwner, + owner, ua2Address, - 'rotate_owner', - [toolkit.ownerPubKey], - 'rotate owner back' + sessionKeyHash, + state, + calls, + label ); + assertReverted(receipt, expectedReason, label); + logReceipt(label, txHash, receipt); } -async function sendAndAwait( +async function sendApplySessionUsage( toolkit: Awaited>, - signer: Account, + owner: Account, ua2Address: string, - entrypoint: string, - calldata: readonly string[], - label: string, - ignorableReasons: string[] = [] -): Promise { - const tx = await signer.execute({ + sessionKeyHash: string, + state: ReturnType, + calls: number, + label: string +) { + const tx = await owner.execute({ contractAddress: ua2Address, - entrypoint, - calldata: [...calldata], + entrypoint: 'apply_session_usage', + calldata: [ + sessionKeyHash, + toFelt(state.callsUsed), + toFelt(calls), + toFelt(state.nonce), + ], }); - const receipt = await waitForReceipt(toolkit.provider, tx.transaction_hash, label); - const execution = (receipt?.execution_status ?? '').toString(); - if (execution === 'REVERTED') { - const reason = (receipt?.revert_reason ?? '').toString(); - if (ignorableReasons.some((expected) => reason.includes(expected))) { - return; + const receipt = await waitForReceipt( + toolkit.provider, + tx.transaction_hash, + label, + RECEIPT_TIMEOUT_MS + ); + return { receipt, txHash: tx.transaction_hash }; +} + +async function readCachedAddresses(): Promise { + try { + const raw = await fsPromises.readFile(ADDRESSES_FILE, 'utf8'); + const parsed = JSON.parse(raw) as Partial; + if (!parsed.ua2Address) return undefined; + return { + ua2Address: normalizeHex(parsed.ua2Address), + classHash: parsed.classHash ? normalizeHex(parsed.classHash) : '0x0', + network: parsed.network ?? 'devnet', + updatedAt: parsed.updatedAt ?? new Date().toISOString(), + }; + } catch (err: any) { + if (err && (err as NodeJS.ErrnoException).code === 'ENOENT') { + return undefined; } - throw new Error(`${label} failed: ${reason || 'unknown revert reason'}`); + throw err; } - assertSucceeded(receipt, label); +} + +async function writeCachedAddresses(record: DevnetAddresses): Promise { + await fsPromises.mkdir(path.dirname(ADDRESSES_FILE), { recursive: true }); + await fsPromises.writeFile( + ADDRESSES_FILE, + `${JSON.stringify(record, null, 2)}\n`, + 'utf8' + ); } void main().catch((err) => { - console.error('[ua2] E2E devnet failed:', err); + console.error('\n[ua2] e2e devnet failed:', err); process.exitCode = 1; }); diff --git a/packages/example/scripts/e2e-sepolia.ts b/packages/example/scripts/e2e-sepolia.ts index 970eec7..3e98acd 100644 --- a/packages/example/scripts/e2e-sepolia.ts +++ b/packages/example/scripts/e2e-sepolia.ts @@ -5,7 +5,9 @@ import { AccountCallTransport, assertReverted, assertSucceeded, + deriveSessionKeyHash, initialSessionUsageState, + logReceipt, optionalEnv, readOwner, selectorFor, @@ -13,216 +15,213 @@ import { toFelt, updateSessionUsage, waitForReceipt, - type Network, + normalizeHex, } from './shared.js'; +const RECEIPT_TIMEOUT_MS = 240_000; + async function main(): Promise { - const network: Network = 'sepolia'; - const toolkit = await setupToolkit(network); + console.log('[ua2] e2e sepolia (attach-only) starting'); + + const toolkit = await setupToolkit('sepolia'); - const ua2Address = optionalEnv([ - `UA2_${network.toUpperCase()}_PROXY_ADDR`, + const ua2AddressRaw = optionalEnv([ + 'UA2_SEPOLIA_PROXY_ADDR', 'UA2_PROXY_ADDR', + 'UA2_ADDR', ]); - if (!ua2Address) { + if (!ua2AddressRaw) { throw new Error('UA2_PROXY_ADDR is required for Sepolia E2E runs.'); } - const normalizedAddress = ua2Address; - const ownerAccount = new Account(toolkit.provider, normalizedAddress, toolkit.ownerKey); - const ownerTransport = new AccountCallTransport(ownerAccount); + const ua2Address = normalizeHex(ua2AddressRaw); + console.log(`[ua2] attaching to UA² account ${ua2Address}`); - const ownerBefore = await readOwner(toolkit.provider, normalizedAddress); - if (ownerBefore.toLowerCase() !== toolkit.ownerPubKey.toLowerCase()) { + const ownerOnChain = await readOwner(toolkit.provider, ua2Address); + if (ownerOnChain.toLowerCase() !== toolkit.ownerPubKey.toLowerCase()) { console.warn( - `[ua2] Warning: UA² owner on-chain (${ownerBefore}) does not match configured owner pubkey (${toolkit.ownerPubKey}).` + `[ua2] warning: on-chain owner ${ownerOnChain} differs from configured owner ${toolkit.ownerPubKey}` ); } - console.log('E2E SEPOLIA'); - console.log(`- attach ✓ (${normalizedAddress})`); + const ownerAccount = new Account(toolkit.provider, ua2Address, toolkit.ownerKey); + const ownerTransport = new AccountCallTransport(ownerAccount); const sessions = makeSessionsManager({ - account: { address: normalizedAddress, chainId: toolkit.chainId }, + account: { address: ua2Address, chainId: toolkit.chainId }, transport: ownerTransport, - ua2Address: normalizedAddress, + ua2Address, }); - const expiresAt = Math.floor(Date.now() / 1000) + 2 * 60 * 60; + const nowSeconds = Math.floor(Date.now() / 1000); + const validAfter = nowSeconds - 60; + const validUntil = validAfter + 30 * 60; const sessionTargetValue = optionalEnv( - [`UA2_${network.toUpperCase()}_SESSION_TARGET`, 'UA2_SESSION_TARGET', 'UA2_E2E_TARGET_ADDR'], + ['UA2_SEPOLIA_SESSION_TARGET', 'UA2_SESSION_TARGET', 'UA2_E2E_TARGET_ADDR'], toolkit.guardianAddress ) ?? toolkit.guardianAddress; const sessionTarget = toFelt(sessionTargetValue); const transferSelector = selectorFor('transfer'); const policy: SessionPolicyInput = { - expiresAt, - limits: limits(5, 10n ** 15n), + validAfter, + validUntil, + limits: limits(1, 10n ** 15n), allow: { targets: [sessionTarget], selectors: [transferSelector], }, + active: true, }; + console.log('\n[1] create tight session policy'); const session = await sessions.create(policy); if (!ownerTransport.lastTxHash) { - throw new Error('Session creation did not produce a transaction hash.'); + throw new Error('add_session_with_allowlists did not emit a transaction hash'); } - const createReceipt = await waitForReceipt(toolkit.provider, ownerTransport.lastTxHash, 'session create'); - assertSucceeded(createReceipt, 'session create'); - logReceipt('create session', ownerTransport.lastTxHash, createReceipt); + const sessionReceipt = await waitForReceipt( + toolkit.provider, + ownerTransport.lastTxHash, + 'create session', + RECEIPT_TIMEOUT_MS + ); + assertSucceeded(sessionReceipt, 'create session'); + logReceipt('create session', ownerTransport.lastTxHash, sessionReceipt); + + const sessionKeyHash = deriveSessionKeyHash(session.pubkey); + console.log(`[ua2] session key hash ${sessionKeyHash}`); let usage = initialSessionUsageState(); - usage = await applySessionUsage(toolkit, ownerAccount, normalizedAddress, session.id, usage, 1, 'session use #1'); - usage = await applySessionUsage(toolkit, ownerAccount, normalizedAddress, session.id, usage, 1, 'session use #2'); - const third = await applySessionUsage(toolkit, ownerAccount, normalizedAddress, session.id, usage, 1, 'session use #3'); - usage = third; - console.log('- call via session ✓'); + console.log('[2] in-policy session call succeeds'); + usage = await expectSessionUsageSuccess( + toolkit, + ownerAccount, + ua2Address, + sessionKeyHash, + usage, + 1, + 'in-policy session call' + ); + + console.log('[3] out-of-policy session call reverts (ERR_POLICY_CALLCAP)'); + await expectSessionUsageRevert( + toolkit, + ownerAccount, + ua2Address, + sessionKeyHash, + usage, + 1, + 'out-of-policy session call', + 'ERR_POLICY_CALLCAP' + ); + console.log('[4] revoke session'); const revokeTx = await ownerAccount.execute({ - contractAddress: normalizedAddress, + contractAddress: ua2Address, entrypoint: 'revoke_session', - calldata: [session.id], - }); - const revokeReceipt = await waitForReceipt(toolkit.provider, revokeTx.transaction_hash, 'session revoke'); - assertSucceeded(revokeReceipt, 'session revoke'); - logReceipt('revoke session', revokeTx.transaction_hash, revokeReceipt); - - const postRevokeTx = await ownerAccount.execute({ - contractAddress: normalizedAddress, - entrypoint: 'apply_session_usage', - calldata: [ - session.id, - toFelt(usage.callsUsed), - toFelt(1), - toFelt(usage.nonce), - ], + calldata: [sessionKeyHash], }); - const postRevokeReceipt = await waitForReceipt( + const revokeReceipt = await waitForReceipt( toolkit.provider, - postRevokeTx.transaction_hash, - 'post-revoke session usage' + revokeTx.transaction_hash, + 'revoke session', + RECEIPT_TIMEOUT_MS ); - assertReverted(postRevokeReceipt, 'ERR_SESSION_INACTIVE', 'post revoke session use'); - logReceipt('post-revoke session', postRevokeTx.transaction_hash, postRevokeReceipt); + assertSucceeded(revokeReceipt, 'revoke session'); + logReceipt('revoke session', revokeTx.transaction_hash, revokeReceipt); - await runGuardianRecovery(toolkit, normalizedAddress); - console.log('- guardian recovery ✓'); + console.log('[5] session call after revoke reverts (ERR_SESSION_INACTIVE)'); + await expectSessionUsageRevert( + toolkit, + ownerAccount, + ua2Address, + sessionKeyHash, + usage, + 1, + 'post-revoke session call', + 'ERR_SESSION_INACTIVE' + ); - console.log('E2E SEPOLIA ✓ complete'); + console.log('\nUA² sepolia e2e PASS ✅'); } -async function applySessionUsage( +async function expectSessionUsageSuccess( toolkit: Awaited>, owner: Account, ua2Address: string, - sessionId: string, + sessionKeyHash: string, state: ReturnType, calls: number, label: string ) { - const tx = await owner.execute({ - contractAddress: ua2Address, - entrypoint: 'apply_session_usage', - calldata: [ - sessionId, - toFelt(state.callsUsed), - toFelt(calls), - toFelt(state.nonce), - ], - }); - const receipt = await waitForReceipt(toolkit.provider, tx.transaction_hash, label); + const { receipt, txHash } = await sendApplySessionUsage( + toolkit, + owner, + ua2Address, + sessionKeyHash, + state, + calls, + label + ); assertSucceeded(receipt, label); - logReceipt(label, tx.transaction_hash, receipt); + logReceipt(label, txHash, receipt); return updateSessionUsage(state, calls); } -async function runGuardianRecovery(toolkit: Awaited>, ua2Address: string) { - const ownerAccount = new Account(toolkit.provider, ua2Address, toolkit.ownerKey); - - await sendAndAwait(toolkit, ownerAccount, ua2Address, 'add_guardian', [toolkit.guardianAddress], 'add guardian', [ - 'ERR_GUARDIAN_EXISTS', - ]); - await sendAndAwait(toolkit, ownerAccount, ua2Address, 'set_guardian_threshold', [toFelt(1)], 'set guardian threshold'); - await sendAndAwait(toolkit, ownerAccount, ua2Address, 'set_recovery_delay', [toFelt(0)], 'set recovery delay'); - - const guardianPropose = await toolkit.guardian.execute({ - contractAddress: ua2Address, - entrypoint: 'propose_recovery', - calldata: [toolkit.guardianPubKey], - }); - const guardianProposeReceipt = await waitForReceipt( - toolkit.provider, - guardianPropose.transaction_hash, - 'guardian propose recovery' - ); - assertSucceeded(guardianProposeReceipt, 'guardian propose recovery'); - logReceipt('guardian propose recovery', guardianPropose.transaction_hash, guardianProposeReceipt); - - const executeTx = await toolkit.guardian.execute({ - contractAddress: ua2Address, - entrypoint: 'execute_recovery', - calldata: [], - }); - const executeReceipt = await waitForReceipt(toolkit.provider, executeTx.transaction_hash, 'execute recovery'); - assertSucceeded(executeReceipt, 'execute recovery'); - logReceipt('guardian execute recovery', executeTx.transaction_hash, executeReceipt); - - const ownerAfter = await readOwner(toolkit.provider, ua2Address); - if (ownerAfter.toLowerCase() !== toolkit.guardianPubKey.toLowerCase()) { - throw new Error(`Recovery did not update owner. expected ${toolkit.guardianPubKey}, got ${ownerAfter}`); - } - - const recoveredOwner = new Account(toolkit.provider, ua2Address, toolkit.guardianKey); - await sendAndAwait( +async function expectSessionUsageRevert( + toolkit: Awaited>, + owner: Account, + ua2Address: string, + sessionKeyHash: string, + state: ReturnType, + calls: number, + label: string, + expectedReason: string +): Promise { + const { receipt, txHash } = await sendApplySessionUsage( toolkit, - recoveredOwner, + owner, ua2Address, - 'rotate_owner', - [toolkit.ownerPubKey], - 'rotate owner back' + sessionKeyHash, + state, + calls, + label ); + assertReverted(receipt, expectedReason, label); + logReceipt(label, txHash, receipt); } -async function sendAndAwait( +async function sendApplySessionUsage( toolkit: Awaited>, - signer: Account, + owner: Account, ua2Address: string, - entrypoint: string, - calldata: readonly string[], - label: string, - ignorableReasons: string[] = [] -): Promise { - const tx = await signer.execute({ + sessionKeyHash: string, + state: ReturnType, + calls: number, + label: string +) { + const tx = await owner.execute({ contractAddress: ua2Address, - entrypoint, - calldata: [...calldata], + entrypoint: 'apply_session_usage', + calldata: [ + sessionKeyHash, + toFelt(state.callsUsed), + toFelt(calls), + toFelt(state.nonce), + ], }); - const receipt = await waitForReceipt(toolkit.provider, tx.transaction_hash, label); - const execution = (receipt?.execution_status ?? '').toString(); - if (execution === 'REVERTED') { - const reason = (receipt?.revert_reason ?? '').toString(); - if (ignorableReasons.some((expected) => reason.includes(expected))) { - return; - } - throw new Error(`${label} failed: ${reason || 'unknown revert reason'}`); - } - assertSucceeded(receipt, label); - logReceipt(label, tx.transaction_hash, receipt); -} - -function logReceipt(label: string, txHash: string, receipt: any): void { - const status = receipt?.finality_status ?? receipt?.status ?? 'UNKNOWN'; - const execution = receipt?.execution_status ?? 'UNKNOWN'; - console.log( - ` • ${label}: tx=${txHash} finality=${status} execution=${execution}` + const receipt = await waitForReceipt( + toolkit.provider, + tx.transaction_hash, + label, + RECEIPT_TIMEOUT_MS ); + return { receipt, txHash: tx.transaction_hash }; } void main().catch((err) => { - console.error('[ua2] E2E sepolia failed:', err); + console.error('\n[ua2] e2e sepolia failed:', err); process.exitCode = 1; }); diff --git a/packages/example/scripts/shared.ts b/packages/example/scripts/shared.ts index 875d709..2c20fa4 100644 --- a/packages/example/scripts/shared.ts +++ b/packages/example/scripts/shared.ts @@ -25,7 +25,7 @@ export interface Toolkit { } const __dirname = path.dirname(fileURLToPath(import.meta.url)); -const PROJECT_ROOT = path.resolve(__dirname, '../../..'); +export const PROJECT_ROOT = path.resolve(__dirname, '../../..'); const CONTRACTS_ROOT = path.resolve(PROJECT_ROOT, 'packages/contracts'); export class AccountCallTransport implements CallTransport { @@ -233,6 +233,12 @@ export function normalizeHex(value: string): Felt { return ('0x' + hex) as Felt; } +export function deriveSessionKeyHash(pubkey: Felt): Felt { + const normalizedKey = normalizeHex(pubkey); + const hashValue = hash.computePedersenHash(normalizedKey, '0x0'); + return normalizeHex(hashValue); +} + export function normalizePrivateKey(value: string): string { const trimmed = value.trim(); return trimmed.startsWith('0x') || trimmed.startsWith('0X') ? trimmed : `0x${trimmed}`; @@ -326,6 +332,12 @@ export function selectorFor(name: string): Felt { return normalizeHex(hash.getSelectorFromName(name)); } +export function logReceipt(label: string, txHash: string, receipt: any): void { + const status = receipt?.finality_status ?? receipt?.status ?? 'UNKNOWN'; + const execution = receipt?.execution_status ?? 'UNKNOWN'; + console.log(` • ${label}: tx=${txHash} finality=${status} execution=${execution}`); +} + export interface SessionUsageState { callsUsed: number; nonce: bigint; diff --git a/packages/example/src/App.tsx b/packages/example/src/App.tsx index cbff4ce..fae65c8 100644 --- a/packages/example/src/App.tsx +++ b/packages/example/src/App.tsx @@ -254,9 +254,10 @@ function SessionCreateForm({ create, isReady }: SessionCreateFormProps): JSX.Ele throw new Error('Expiry must be a positive number of minutes.'); } const maxValueBigInt = BigInt(maxValue); - const expiresAt = Math.floor(Date.now() / 1000) + parsedExpires * 60; + const validAfter = Math.floor(Date.now() / 1000); const policy = { - expiresAt, + validAfter, + validUntil: validAfter + parsedExpires * 60, limits: limits(parsedMaxCalls, maxValueBigInt), allow: { targets: target.trim() ? [target.trim()] : [], @@ -428,7 +429,7 @@ function SessionList({ sessions, revoke, refresh, client }: SessionListProps): J
- Expires at: {session.policy.expiresAt} · Max calls: {session.policy.limits.maxCalls} + Expires at: {session.policy.validUntil} · Max calls: {session.policy.limits.maxCalls}