Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
- NonAssetOp: 其余(login/claim/mint 等)
"""

from typing import Optional
from .types import TxInput, TxCategory
from .constants import (
ASSET_OP_SELECTORS,
Expand Down
2 changes: 1 addition & 1 deletion src/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from .config import GateConfig
from .gate import ExecutionGate
from .evaluation import Evaluator, TxSample, load_dataset
from .evaluation import Evaluator, load_dataset

# 配置日志
logging.basicConfig(
Expand Down
2 changes: 1 addition & 1 deletion src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

from dataclasses import dataclass, field
from typing import Optional, Dict, Any
from typing import Dict, Any
import os
import json

Expand Down
32 changes: 15 additions & 17 deletions src/delta_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,6 @@ def extract_from_logs(

for log in logs:
event_sig, indexed = _parse_log_topics(log)
contract_addr = log.get("address", "").lower()
data = log.get("data", "0x")

# ERC20/ERC721 Transfer 事件
if event_sig == TOPIC_ERC20_TRANSFER.lower():
Expand All @@ -113,17 +111,17 @@ def extract_from_logs(

# ERC20 Approval 事件
elif event_sig == TOPIC_ERC20_APPROVAL.lower():
change = self._parse_approval_event(log, user_addr)
if change:
permission_changes.append(change)
affected_tokens.add(change.token_address)
permission_change = self._parse_approval_event(log, user_addr)
if permission_change:
permission_changes.append(permission_change)
affected_tokens.add(permission_change.token_address)

# ERC721/ERC1155 ApprovalForAll 事件
elif event_sig == TOPIC_APPROVAL_FOR_ALL.lower():
change = self._parse_approval_for_all_event(log, user_addr)
if change:
permission_changes.append(change)
affected_tokens.add(change.token_address)
permission_change = self._parse_approval_for_all_event(log, user_addr)
if permission_change:
permission_changes.append(permission_change)
affected_tokens.add(permission_change.token_address)

# ERC1155 TransferSingle 事件
elif event_sig == TOPIC_ERC1155_TRANSFER_SINGLE.lower():
Expand Down Expand Up @@ -176,7 +174,7 @@ def _parse_transfer_event(
user_addr: str
) -> Optional[AssetChange]:
"""解析 Transfer 事件"""
event_sig, indexed = _parse_log_topics(log)
_, indexed = _parse_log_topics(log)
contract_addr = log.get("address", "").lower()
data = log.get("data", "0x")

Expand Down Expand Up @@ -231,7 +229,7 @@ def _parse_approval_event(
user_addr: str
) -> Optional[PermissionChange]:
"""解析 Approval 事件(ERC20)"""
event_sig, indexed = _parse_log_topics(log)
_, indexed = _parse_log_topics(log)
contract_addr = log.get("address", "").lower()
data = log.get("data", "0x")

Expand Down Expand Up @@ -263,7 +261,7 @@ def _parse_approval_for_all_event(
user_addr: str
) -> Optional[PermissionChange]:
"""解析 ApprovalForAll 事件"""
event_sig, indexed = _parse_log_topics(log)
_, indexed = _parse_log_topics(log)
contract_addr = log.get("address", "").lower()
data = log.get("data", "0x")

Expand Down Expand Up @@ -298,7 +296,7 @@ def _parse_erc1155_transfer_single(
user_addr: str
) -> Optional[AssetChange]:
"""解析 ERC1155 TransferSingle 事件"""
event_sig, indexed = _parse_log_topics(log)
_, indexed = _parse_log_topics(log)
contract_addr = log.get("address", "").lower()
data = log.get("data", "0x")

Expand Down Expand Up @@ -343,11 +341,11 @@ def _parse_erc1155_transfer_batch(
user_addr: str
) -> List[AssetChange]:
"""解析 ERC1155 TransferBatch 事件"""
event_sig, indexed = _parse_log_topics(log)
_, indexed = _parse_log_topics(log)
contract_addr = log.get("address", "").lower()
data = log.get("data", "0x")

changes = []
changes: List[AssetChange] = []

if len(indexed) < 3:
return changes
Expand Down Expand Up @@ -487,7 +485,7 @@ async def extract_delta_for_tx(
Returns:
(TxInput, SimMeta, DeltaS)
"""
from .simulator import Simulator, SimulatorConfig
from .simulator import Simulator

config = SimulatorConfig(rpc_url=rpc_url, enable_trace=enable_trace)
simulator = Simulator(config)
Expand Down
5 changes: 2 additions & 3 deletions src/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
import time
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
from typing import List, Dict, Any, Optional
from collections import defaultdict

from .types import Decision, InvariantId, GateDecision
from .types import Decision, GateDecision
from .config import GateConfig
from .gate import ExecutionGate

Expand Down
4 changes: 2 additions & 2 deletions src/gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

import asyncio
import logging
from typing import Optional, Tuple
from typing import Optional

from .types import TxInput, SimMeta, DeltaS, GateDecision, Decision
from .types import TxInput, GateDecision
from .config import GateConfig, DEFAULT_CONFIG
from .simulator import Simulator
from .delta_extractor import DeltaExtractor
Expand Down
1 change: 0 additions & 1 deletion src/invariants.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
InvariantViolation,
RiskLabel,
FailOpenReason,
is_unlimited_allowance,
)
from .config import InvariantConfig, DEFAULT_CONFIG
from .classifier import is_likely_swap
Expand Down
24 changes: 19 additions & 5 deletions src/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@
import asyncio
import logging
from typing import Optional, Dict, Any, List, Tuple
from dataclasses import dataclass

from .types import TxInput, SimMeta, FailOpenReason, TxCategory
from .types import TxInput, SimMeta, FailOpenReason
from .config import SimulatorConfig, DEFAULT_CONFIG
from .classifier import classify_transaction

Expand Down Expand Up @@ -114,7 +113,12 @@ async def get_transaction(self, tx_hash: str) -> Optional[Dict[str, Any]]:
"""
try:
tx_data = await self.client.call("eth_getTransactionByHash", [tx_hash])
return tx_data
if tx_data is None:
return None
if isinstance(tx_data, dict):
return tx_data
logger.warning(f"Unexpected transaction payload type for {tx_hash}: {type(tx_data)}")
return None
except Exception as e:
logger.warning(f"Failed to get transaction {tx_hash}: {e}")
return None
Expand All @@ -123,7 +127,12 @@ async def get_transaction_receipt(self, tx_hash: str) -> Optional[Dict[str, Any]
"""获取交易收据(包含日志)"""
try:
receipt = await self.client.call("eth_getTransactionReceipt", [tx_hash])
return receipt
if receipt is None:
return None
if isinstance(receipt, dict):
return receipt
logger.warning(f"Unexpected receipt payload type for {tx_hash}: {type(receipt)}")
return None
except Exception as e:
logger.warning(f"Failed to get receipt {tx_hash}: {e}")
return None
Expand Down Expand Up @@ -261,7 +270,12 @@ async def _get_trace(self, tx_hash: str) -> Optional[Dict[str, Any]]:
"debug_traceTransaction",
[tx_hash, {"tracer": "callTracer"}]
)
return trace
if trace is None:
return None
if isinstance(trace, dict):
return trace
logger.debug(f"Unexpected trace payload type for {tx_hash}: {type(trace)}")
return None
except Exception as e:
logger.debug(f"Failed to get trace for {tx_hash}: {e}")
return None
Expand Down
1 change: 0 additions & 1 deletion src/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional, List, Dict, Any
from decimal import Decimal


# ============================================================
Expand Down
1 change: 0 additions & 1 deletion tests/test_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
测试交易分类器
"""

import pytest
from src.types import TxInput, TxCategory
from src.classifier import classify_transaction, is_likely_swap
from src.constants import SELECTOR_TRANSFER, SELECTOR_APPROVE, SELECTOR_SWAP_EXACT_TOKENS
Expand Down
4 changes: 1 addition & 3 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
"""

import httpx
import pytest
from pytest_httpx import HTTPXMock

from src.gate import ExecutionGate
Expand All @@ -45,8 +44,7 @@
USER, ATTACKER, SPENDER, NFT_CONTRACT, DEX_ROUTER,
TOKEN_A, TOKEN_B, TOKEN_C,
ONE_ETHER,
eth_call_ok, receipt_ok, receipt_empty,
transfer_log, approval_log, approval_for_all_log,
eth_call_ok, receipt_ok, transfer_log, approval_log, approval_for_all_log,
)

UNLIMITED = 2**256 - 1
Expand Down
1 change: 0 additions & 1 deletion tests/test_invariants.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,6 @@ def test_increase_allowance_classified_as_permission_op(self):
to_address="0xtoken",
data=SELECTOR_INCREASE_ALLOWANCE + "0" * 128,
)
from src.classifier import classify_transaction
assert classify_transaction(tx) == TxCategory.PERMISSION_OP

def test_1inch_swap_is_likely_swap(self):
Expand Down
3 changes: 0 additions & 3 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,16 @@
测试类型定义
"""

import pytest
from src.types import (
TxInput,
TxCategory,
AssetChange,
PermissionChange,
DeltaS,
DeltaScope,
GateDecision,
Decision,
InvariantViolation,
InvariantId,
ETH_ADDRESS,
is_unlimited_allowance,
UNLIMITED_ALLOWANCE,
)
Expand Down
Loading