diff --git a/src/allocators/OnChainAllocator.sol b/src/allocators/OnChainAllocator.sol index d27f49a..edd1ec6 100644 --- a/src/allocators/OnChainAllocator.sol +++ b/src/allocators/OnChainAllocator.sol @@ -128,7 +128,11 @@ contract OnChainAllocator is IOnChainAllocator { uint32 expires, bytes32 typehash, bytes32 witness - ) public returns (bytes32 claimHash, uint256[] memory registeredAmounts, uint256 nonce) { + ) public payable returns (bytes32 claimHash, uint256[] memory registeredAmounts, uint256 nonce) { + // Check for empty commitments + if (commitments.length == 0) { + revert InvalidCommitments(); + } if (expires <= block.timestamp) { revert InvalidExpiration(expires, block.timestamp); } @@ -136,37 +140,76 @@ contract OnChainAllocator is IOnChainAllocator { recipient = AL.getRecipient(recipient); nonce = _getAndUpdateNonce(msg.sender, recipient); + // Transformed locks will be stored in idsAndAmounts uint256[2][] memory idsAndAmounts = new uint256[2][](commitments.length); + // Init minResetPeriod to the max value uint256 minResetPeriod = type(uint256).max; - for (uint256 i = 0; i < commitments.length; i++) { + + // Process native token (zero address) first + uint256 i; + if (commitments[i].token == address(0)) { + // The Compact will revert if invalid value is provided for native token + // Possible cases: + // 1. The callvalue is zero but the first token is native + // 2. the callvalue is nonzero but the first token is non-native + // 3. the first token is native and the callvalue doesn't equal the first amount + + // Handle first and third points + if (commitments[i].amount == 0 || commitments[i].amount != msg.value) { + revert InvalidAmount(commitments[i].amount); + } + minResetPeriod = _checkInput(commitments[i], recipient, expires, minResetPeriod); + idsAndAmounts[i][0] = AL.toId(commitments[i].lockTag, commitments[i].token); + idsAndAmounts[i][1] = msg.value; + + unchecked { + ++i; + } + } else { + // Handle second point + if (msg.value != 0) { + revert InvalidAmount(msg.value); + } + } + + // Process the rest of the commitments + for (; i < commitments.length; i++) { + minResetPeriod = _checkInput(commitments[i], recipient, expires, minResetPeriod); + + address token = commitments[i].token; + // Safe to cast - _checkInput validated that the value fits the uint224 uint224 amount = uint224(commitments[i].amount); - // If the amount is 0, we use the balance of the contract to deposit. + // If the amount is 0, we use the balance of the contract to deposit if (amount == 0) { - uint256 balance = IERC20(commitments[i].token).balanceOf(address(this)); + uint256 balance = IERC20(token).balanceOf(address(this)); // Check the amount fits in the supported range if (balance > type(uint224).max) { revert InvalidAmount(balance); } amount = uint224(balance); } + + // Store the lock in idsAndAmounts + idsAndAmounts[i][0] = AL.toId(commitments[i].lockTag, token); idsAndAmounts[i][1] = amount; // Approve the compact contract to spend the tokens. - if (IERC20(commitments[i].token).allowance(address(this), COMPACT_CONTRACT) < amount) { - SafeTransferLib.safeApproveWithRetry(commitments[i].token, COMPACT_CONTRACT, type(uint256).max); + if (IERC20(token).allowance(address(this), COMPACT_CONTRACT) < amount) { + SafeTransferLib.safeApproveWithRetry(token, COMPACT_CONTRACT, type(uint256).max); } } - // Ensure expiration is not bigger then the smallest reset period + + // Ensure expiration is less then the smallest reset period if (expires >= block.timestamp + minResetPeriod) { revert InvalidExpiration(expires, block.timestamp + minResetPeriod); } // Deposit the tokens and register the claim in the compact - (claimHash, registeredAmounts) = ITheCompact(COMPACT_CONTRACT).batchDepositAndRegisterFor( + (claimHash, registeredAmounts) = ITheCompact(COMPACT_CONTRACT).batchDepositAndRegisterFor{value: msg.value}( recipient, idsAndAmounts, arbiter, nonce, expires, typehash, witness ); diff --git a/src/interfaces/IOnChainAllocator.sol b/src/interfaces/IOnChainAllocator.sol index 540543a..b79ff16 100644 --- a/src/interfaces/IOnChainAllocator.sol +++ b/src/interfaces/IOnChainAllocator.sol @@ -94,5 +94,5 @@ interface IOnChainAllocator is IOnChainAllocation { uint32 expires, bytes32 typehash, bytes32 witness - ) external returns (bytes32 claimHash, uint256[] memory registeredAmounts, uint256 nonce); + ) external payable returns (bytes32 claimHash, uint256[] memory registeredAmounts, uint256 nonce); } diff --git a/test/OnChainAllocator.t.sol b/test/OnChainAllocator.t.sol index da45976..90523d3 100644 --- a/test/OnChainAllocator.t.sol +++ b/test/OnChainAllocator.t.sol @@ -1246,6 +1246,46 @@ contract OnChainAllocatorTest is Test, TestHelper { ); } + function test_allocateAndRegister_revert_InvalidCommitments() public { + Lock[] memory commitments = new Lock[](0); + + vm.expectRevert(abi.encodeWithSelector(IOnChainAllocator.InvalidCommitments.selector)); + allocator.allocateAndRegister( + recipient, commitments, arbiter, defaultExpiration, BATCH_COMPACT_TYPEHASH, bytes32(0) + ); + } + + function test_allocateAndRegister_revert_InvalidAmount_native() public { + Lock[] memory commitments = new Lock[](1); + commitments[0] = _makeLock(address(0), 1 ether); + + vm.expectRevert(abi.encodeWithSelector(IOnChainAllocator.InvalidAmount.selector, commitments[0].amount)); + allocator.allocateAndRegister{value: commitments[0].amount + 1}( + recipient, commitments, arbiter, defaultExpiration, BATCH_COMPACT_TYPEHASH, bytes32(0) + ); + } + + function test_allocateAndRegister_revert_InvalidAmount_native_with_zero_deposit() public { + Lock[] memory commitments = new Lock[](1); + commitments[0] = _makeLock(address(0), 0 ether); + + vm.expectRevert(abi.encodeWithSelector(IOnChainAllocator.InvalidAmount.selector, 0)); + allocator.allocateAndRegister{value: 0}( + recipient, commitments, arbiter, defaultExpiration, BATCH_COMPACT_TYPEHASH, bytes32(0) + ); + } + + function test_allocateAndRegister_revert_InvalidAmount_non_native_with_non_zero_native_call() public { + Lock[] memory commitments = new Lock[](1); + commitments[0] = _makeLock(address(usdc), 1); + + uint256 amount = 1; + vm.expectRevert(abi.encodeWithSelector(IOnChainAllocator.InvalidAmount.selector, amount)); + allocator.allocateAndRegister{value: amount}( + recipient, commitments, arbiter, defaultExpiration, BATCH_COMPACT_TYPEHASH, bytes32(0) + ); + } + function test_allocateAndRegister_revert_invalidExpiration() public { Lock[] memory commitments = new Lock[](1); commitments[0] = _makeLock(address(usdc), defaultAmount); @@ -1422,6 +1462,43 @@ contract OnChainAllocatorTest is Test, TestHelper { ); } + function test_allocateAndRegister_success_multiple() public { + uint256 amount1 = 1 ether; + uint256 amount2 = 2 ether; + + usdc.mint(address(allocator), amount2); + + Lock[] memory commitments = new Lock[](2); + commitments[0] = _makeLock(address(0), amount1); + commitments[1] = _makeLock(address(usdc), amount2); + + vm.deal(caller, amount1); + vm.prank(caller); + (bytes32 claimHash, uint256[] memory registeredAmounts, uint256 nonce) = allocator.allocateAndRegister{ + value: amount1 + }(recipient, commitments, arbiter, defaultExpiration, BATCH_COMPACT_TYPEHASH, bytes32(0)); + + uint256 id1 = _toId(Scope.Multichain, ResetPeriod.TenMinutes, address(allocator), address(0)); + uint256 id2 = _toId(Scope.Multichain, ResetPeriod.TenMinutes, address(allocator), address(usdc)); + + assertEq(registeredAmounts.length, 2); + assertEq(registeredAmounts[0], amount1); + assertEq(registeredAmounts[1], amount2); + + assertEq(ERC6909(address(compact)).balanceOf(recipient, id1), amount1); + assertEq(ERC6909(address(compact)).balanceOf(recipient, id2), amount2); + + uint256[2][] memory idsAndAmounts = new uint256[2][](2); + idsAndAmounts[0][0] = id1; + idsAndAmounts[0][1] = amount1; + idsAndAmounts[1][0] = id2; + idsAndAmounts[1][1] = amount2; + + assertTrue( + allocator.isClaimAuthorized(claimHash, arbiter, recipient, nonce, defaultExpiration, idsAndAmounts, '') + ); + } + function test_constructor_allowsPreRegisteredAllocator_create2() public { OnChainAllocatorFactory factory = new OnChainAllocatorFactory();